Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[None][feat] Add Qwen3-Next to layer-wise benchmarks (#9065)
Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
Commit: cc4c980e03 (parent: fdb0787e85)
@@ -83,7 +83,6 @@ common-files: &common_files |
examples/infinitebench/compute_scores.py |
examples/infinitebench/construct_synthetic_dataset.py |
examples/infinitebench/eval_utils.py |
examples/layer_wise_benchmarks/run_single.py |
examples/llm-api/_tensorrt_engine/llm_eagle_decoding.py |
examples/llm-api/_tensorrt_engine/llm_eagle2_decoding.py |
examples/llm-api/_tensorrt_engine/llm_inference_customize.py |
@@ -811,7 +810,6 @@ common-files: &common_files |
tensorrt_llm/serve/tool_parser/utils.py |
tensorrt_llm/tools/__init__.py |
tensorrt_llm/tools/importlib_utils.py |
tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py |
tensorrt_llm/tools/multimodal_builder.py |
tensorrt_llm/tools/onnx_utils.py |
tensorrt_llm/tools/plugin_gen/__init__.py |
@@ -1188,7 +1186,6 @@ common-files: &common_files |
tests/unittest/tools/plugin_gen/test_core.py |
tests/unittest/tools/plugin_gen/test_plugin_gen.py |
tests/unittest/tools/plugin_gen/test_shape_infer.py |
tests/unittest/tools/test_layer_wise_benchmarks.py |
tests/unittest/tools/test_prepare_dataset.py |
tests/unittest/tools/test_test_to_stage_mapping.py |
tests/unittest/trt/__init__.py |

@@ -15,7 +15,7 @@ pip install -e ../..
**Step 3:** In the container, run benchmarks and generate profiles:

```bash
# Run DeepSeek-R1
# Run DeepSeek-R1 NVFP4
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml

@@ -24,7 +24,7 @@ NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSee
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM

# Run DeepSeek-V3.2-Exp with 32k context length
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --max-num-tokens $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --seq-len-kv-cache 32769

# Run with attention TP
@@ -48,6 +48,10 @@ NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --moe-back
# Scale TEP=16 to 4 GPUs: reduce the number of attention heads and experts
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp

# Run Qwen3-Next (balanced routing is not implemented)
NP=2 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified
NP=2 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified

# Run with DeepEP A2A
NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_ctx.yaml --moe-backend WIDEEP
NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP
@@ -76,7 +80,7 @@ It uses the image recorded in `../../jenkins/current_image_tags.properties`. The
**Step 3:** Run benchmarks to generate profiles. Run the following command on the controller node, where `NODES` ≤ the number of allocated nodes:

```bash
# Run DeepSeek-R1 with wide ep: uses MNNVL A2A if applicable
# Run DeepSeek-R1 NVFP4 with wide ep: uses MNNVL A2A if applicable
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 ./slurm_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP

# Run with attention TP and TRTLLMGen
@@ -93,3 +97,9 @@ SLURM_JOB_ID=$SLURM_JOB_ID NODES=2 NP=8 ./slurm_launch.sh ./run_single.sh config
## Parse profiles

Coming soon.

## Troubleshooting

1. Error `fp8 blockscale gemm only support Hopper` on Blackwell.

   The default MoE backend "CUTLASS" does not support FP8 weights. Please choose the same MoE backend as your end-to-end config. A typical choice is to add the `--moe-backend DEEPGEMM`, `--moe-backend TRTLLM`, or `--moe-backend WIDEEP` option.

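For example, if the end-to-end setup ran with the DEEPGEMM backend, the corresponding layer-wise rerun could look like the following sketch (flags and config file are illustrative; match them to your own end-to-end configuration):

```bash
# Hypothetical rerun on Blackwell with an MoE backend that supports FP8 weights
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --moe-backend DEEPGEMM
```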
@@ -9,7 +9,6 @@ max_seq_len: 9220 # 8192 + 1024 + 4
enable_attention_dp: true

# Model init args
max_num_tokens: 20480
moe_backend: CUTLASS
use_cuda_graph: false


@@ -9,7 +9,6 @@ max_seq_len: 9220 # 8192 + 1024 + 4
enable_attention_dp: true

# Model init args
max_num_tokens: 4096 # MTP3 as max
moe_backend: CUTLASS
use_cuda_graph: true

@@ -8,8 +8,7 @@ import yaml
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.tools.layer_wise_benchmarks.deepseekv3_runner import (
BalanceMethod, DeepSeekV3Runner)
from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls


def comma_separated_ints(s):
@@ -23,30 +22,25 @@ parser.add_argument("--model", type=str, help="Pretrained model name or path")
parser.add_argument(
"--layer-indices",
type=comma_separated_ints,
help="Comma separated indices of layers, should be a contiguous range")
help="Comma separated indices of layers, should be a contiguous range",
)
parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"])
parser.add_argument("--scaled-from", type=int)
# KV cache related args
parser.add_argument("--max-batch-size", type=int)
parser.add_argument("--tokens-per-block", type=int)
parser.add_argument("--max-seq-len", type=int)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--enable-attention-dp",
action="store_true",
dest="enable_attention_dp")
group.add_argument("--no-enable-attention-dp",
action="store_false",
dest="enable_attention_dp")
group.add_argument("--enable-attention-dp", action="store_true", dest="enable_attention_dp")
group.add_argument("--no-enable-attention-dp", action="store_false", dest="enable_attention_dp")
parser.set_defaults(enable_attention_dp=None)
# Model init args
parser.add_argument("--max-num-tokens", type=int)
parser.add_argument("--moe-backend", type=str)
parser.add_argument("--moe-max-num-tokens", type=int)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--use-cuda-graph",
action="store_true",
dest="use_cuda_graph")
group.add_argument("--no-use-cuda-graph",
action="store_false",
dest="use_cuda_graph")
group.add_argument("--use-cuda-graph", action="store_true", dest="use_cuda_graph")
group.add_argument("--no-use-cuda-graph", action="store_false", dest="use_cuda_graph")
parser.set_defaults(use_cuda_graph=None)
# Per iteration args
parser.add_argument("--batch-size", type=int)
@@ -59,8 +53,12 @@ with open(args.config_path) as f:
config = yaml.safe_load(f)
del args.config_path
for k, v in vars(args).items():
if v is None:
if v is None and k in config:
setattr(args, k, config[k])
if args.max_batch_size is None:
args.max_batch_size = args.batch_size
if args.max_num_tokens is None:
args.max_num_tokens = args.max_batch_size * args.seq_len_q
print(args)

# MPI args
@@ -70,43 +68,49 @@ local_rank = local_mpi_rank()
torch.cuda.set_device(local_rank)

# Create KV cache manager
mapping = DeepSeekV3Runner.create_mapping(
enable_attention_dp=args.enable_attention_dp)
max_batch_size = 2048
kv_cache_manager = DeepSeekV3Runner.create_kv_cache_manager(
Runner = get_runner_cls(args.model)
mapping = Runner.create_mapping(enable_attention_dp=args.enable_attention_dp)
kv_cache_manager = Runner.create_kv_cache_manager(
args.model,
mapping,
tokens_per_block=args.tokens_per_block,
max_batch_size=max_batch_size,
max_batch_size=args.max_batch_size,
max_seq_len=args.max_seq_len,
layer_indices=args.layer_indices)
attn_workspace = torch.empty((0, ), device="cuda", dtype=torch.int8)
layer_indices=args.layer_indices,
)
attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)

# Create other global objects
AutoTuner.get().clear_cache()
capture_stream = torch.cuda.Stream()

# Create Runner
runner = DeepSeekV3Runner(args.model,
mapping,
moe_backend=args.moe_backend,
layer_indices=args.layer_indices,
scaled_from=args.scaled_from,
max_seq_len=args.max_seq_len,
max_num_tokens=args.max_num_tokens,
use_cuda_graph=args.use_cuda_graph)
runner = Runner(
args.model,
mapping,
moe_backend=args.moe_backend,
layer_indices=args.layer_indices,
scaled_from=args.scaled_from,
max_seq_len=args.max_seq_len,
max_num_tokens=args.max_num_tokens,
moe_max_num_tokens=args.moe_max_num_tokens,
use_cuda_graph=args.use_cuda_graph,
)

# Warm up
assert args.batch_size <= max_batch_size
assert args.batch_size <= args.max_batch_size
assert args.seq_len_q + args.seq_len_kv_cache <= args.max_seq_len
run_pack = runner.create_run_pack(args.run_type,
batch_size=args.batch_size,
seq_len_q=args.seq_len_q,
seq_len_kv_cache=args.seq_len_kv_cache,
kv_cache_manager=kv_cache_manager,
attn_workspace=attn_workspace)
runner.replace_routing_method(balance_method=BalanceMethod[args.balance_method],
balance_ratio=args.balance_ratio)
run_pack = runner.create_run_pack(
args.run_type,
batch_size=args.batch_size,
seq_len_q=args.seq_len_q,
seq_len_kv_cache=args.seq_len_kv_cache,
kv_cache_manager=kv_cache_manager,
attn_workspace=attn_workspace,
)
runner.replace_routing_method(
balance_method=BalanceMethod[args.balance_method], balance_ratio=args.balance_ratio
)
capture_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(capture_stream):
run_pack()
@@ -120,21 +124,15 @@ torch.cuda.cudart().cudaProfilerStart()
if args.use_cuda_graph:
with with_multi_stream(True):
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g,
stream=capture_stream,
capture_error_mode="global"):
with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"):
run_pack()

warmup_times = 20
run_times = 100
events = [
torch.cuda.Event(enable_timing=True)
for _ in range(warmup_times + run_times + 1)
]
events = [torch.cuda.Event(enable_timing=True) for _ in range(warmup_times + run_times + 1)]
for i in range(warmup_times + run_times):
events[i].record()
with nvtx.annotate(
f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"):
with nvtx.annotate(f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"):
if args.use_cuda_graph:
g.replay()
else:
@@ -144,16 +142,16 @@ torch.cuda.synchronize()

# Print statistics
# Print before `cudaProfilerStop` to ensure messages are included in the profile
time_list = [
start.elapsed_time(stop) for start, stop in zip(events, events[1:])
]
time_list = [start.elapsed_time(stop) for start, stop in zip(events, events[1:])]
time_list = time_list[warmup_times:]
print(f"[RANK {rank}]"
f" min {np.min(time_list) * 1000:.1f}"
f" max {np.max(time_list) * 1000:.1f}"
f" mean {np.mean(time_list) * 1000:.1f}"
f" median {np.median(time_list) * 1000:.1f}"
f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
f" (us)")
print(
f"[RANK {rank}]"
f" min {np.min(time_list) * 1000:.1f}"
f" max {np.max(time_list) * 1000:.1f}"
f" mean {np.mean(time_list) * 1000:.1f}"
f" median {np.median(time_list) * 1000:.1f}"
f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
f" (us)"
)

torch.cuda.cudart().cudaProfilerStop()

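The argument-merging loop above only copies a config value into `args` when the CLI left it as `None` (and, after this change, only when the key actually exists in the YAML), so command-line flags always win over the config file. A small illustration of that precedence, reusing values that appear in the configs above:

```bash
# config_gen.yaml sets moe_backend: CUTLASS; the flag below takes precedence because
# args.moe_backend is no longer None when the config is merged in.
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --moe-backend DEEPGEMM
```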
@@ -123,7 +123,6 @@ exclude = [
"examples/infinitebench/compute_scores.py",
"examples/infinitebench/construct_synthetic_dataset.py",
"examples/infinitebench/eval_utils.py",
"examples/layer_wise_benchmarks/run_single.py",
"examples/llm-api/_tensorrt_engine/llm_eagle_decoding.py",
"examples/llm-api/_tensorrt_engine/llm_eagle2_decoding.py",
"examples/llm-api/_tensorrt_engine/llm_inference_customize.py",
@@ -851,7 +850,6 @@ exclude = [
"tensorrt_llm/serve/tool_parser/utils.py",
"tensorrt_llm/tools/__init__.py",
"tensorrt_llm/tools/importlib_utils.py",
"tensorrt_llm/tools/layer_wise_benchmarks/deepseekv3_runner.py",
"tensorrt_llm/tools/multimodal_builder.py",
"tensorrt_llm/tools/onnx_utils.py",
"tensorrt_llm/tools/plugin_gen/__init__.py",
@@ -1228,7 +1226,6 @@ exclude = [
"tests/unittest/tools/plugin_gen/test_core.py",
"tests/unittest/tools/plugin_gen/test_plugin_gen.py",
"tests/unittest/tools/plugin_gen/test_shape_infer.py",
"tests/unittest/tools/test_layer_wise_benchmarks.py",
"tests/unittest/tools/test_prepare_dataset.py",
"tests/unittest/tools/test_test_to_stage_mapping.py",
"tests/unittest/trt/__init__.py",

@@ -0,0 +1,7 @@
from .runner_factory import get_runner_cls
from .runner_interface import BalanceMethod

__all__ = [
    "BalanceMethod",
    "get_runner_cls",
]

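This hunk appears to add the package `__init__.py` for `tensorrt_llm.tools.layer_wise_benchmarks`. A minimal usage sketch of the exported API, mirroring what `run_single.py` does (the model path and flags are illustrative, and the calls assume an MPI launch such as the README's `mpi_launch.sh`):

```python
from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls

# Pick the runner class from the checkpoint's model_type (DeepSeek-V3/V3.2 or Qwen3-Next).
Runner = get_runner_cls("Qwen/Qwen3-Next-80B-A3B-Instruct")  # illustrative model path
mapping = Runner.create_mapping(enable_attention_dp=False)
# ... build a KV cache manager and the runner itself, then e.g.:
# runner.replace_routing_method(balance_method=BalanceMethod.NotModified, balance_ratio=0.0)
```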
@@ -1,49 +1,22 @@
import functools
import os
import weakref
from enum import IntEnum
from typing import List, Optional

import torch

import tensorrt_llm._torch.models.modeling_deepseekv3
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_deepseekv3 import (
DeepseekV3DecoderLayer, DeepseekV3Gate)
from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import WideEPMoE
from tensorrt_llm._torch.modules.linear import Linear, WeightMode
from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3DecoderLayer, DeepseekV3Gate
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._torch.pyexecutor._util import get_kv_cache_manager_cls
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._torch.utils import (AuxStreamType, get_model_extra_attrs,
model_extra_attrs)
from tensorrt_llm._utils import (local_mpi_size, mpi_rank, mpi_world_size,
torch_dtype_to_binding)
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm._torch.utils import AuxStreamType
from tensorrt_llm._utils import mpi_rank, mpi_world_size
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig


class BalanceMethod(IntEnum):
NotModified = 1
Balanced = 2
ImbalancedRanks = 3
ImbalancedExperts = 4


def ceil_div(a, b):
return (a + b - 1) // b


def round_up(a, b):
return ceil_div(a, b) * b
from .runner_interface import BalanceMethod, RunnerBase
from .runner_utils import RunnerMixin, ceil_div


class RoutingMethod(DeepseekV3Gate):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.world_size = mpi_world_size()
@@ -52,98 +25,120 @@ class RoutingMethod(DeepseekV3Gate):
self.balance_ratio = None

def apply(self, router_logits) -> (torch.Tensor, torch.Tensor):
token_selected_experts, token_final_scales = super().apply(
router_logits)
token_selected_experts, token_final_scales = super().apply(router_logits)
num_experts = self.weight.shape[0]
if self.balance_method == BalanceMethod.NotModified:
pass
elif self.balance_method == BalanceMethod.Balanced:
token_selected_experts = RoutingMethod.get_balanced_selection(
token_selected_experts.shape[0],
token_selected_experts.shape[1], num_experts,
token_selected_experts.dtype, self.world_size, self.rank)
token_selected_experts.shape[1],
num_experts,
token_selected_experts.dtype,
self.world_size,
self.rank,
)
elif self.balance_method == BalanceMethod.ImbalancedRanks:
token_selected_experts = RoutingMethod.get_all_to_one_selection(
token_selected_experts.shape[0],
token_selected_experts.shape[1], num_experts,
self.balance_ratio, token_selected_experts.dtype,
self.world_size, self.rank)
token_selected_experts.shape[1],
num_experts,
self.balance_ratio,
token_selected_experts.dtype,
self.world_size,
self.rank,
)
elif self.balance_method == BalanceMethod.ImbalancedExperts:
token_selected_experts = RoutingMethod.get_balanced_rank_imbalanced_expert_selection(
token_selected_experts.shape[0],
token_selected_experts.shape[1], num_experts,
self.balance_ratio, token_selected_experts.dtype,
self.world_size, self.rank)
token_selected_experts.shape[1],
num_experts,
self.balance_ratio,
token_selected_experts.dtype,
self.world_size,
self.rank,
)
else:
raise NotImplementedError(
f"Not support balance_method {self.balance_method}")
raise NotImplementedError(f"Not support balance_method {self.balance_method}")
return token_selected_experts, token_final_scales

@functools.cache
@staticmethod
def get_balanced_selection(num_tokens, top_k, num_experts, dtype,
world_size, rank):
a = torch.arange(num_tokens * world_size * top_k,
dtype=dtype,
device="cuda").view(num_tokens, world_size,
top_k)[:, rank]
experts = (a * (num_experts // world_size + 1) + a // num_experts *
(num_experts // world_size)) % num_experts
@functools.cache
def get_balanced_selection(num_tokens, top_k, num_experts, dtype, world_size, rank):
a = torch.arange(num_tokens * world_size * top_k, dtype=dtype, device="cuda").view(
num_tokens, world_size, top_k
)[:, rank]
experts = (
a * (num_experts // world_size + 1) + a // num_experts * (num_experts // world_size)
) % num_experts
return experts.contiguous()

@staticmethod
def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio,
world_size, rank):
def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank):
num_tokens, top_k = imbalanced_experts.shape
dtype = imbalanced_experts.dtype
balanced_experts = RoutingMethod.get_balanced_selection(
num_tokens, top_k, num_experts, dtype, world_size, rank)
num_tokens, top_k, num_experts, dtype, world_size, rank
)
num_balanced_tokens = round(num_tokens * balance_ratio)
if balance_ratio != 0:
# Activate all experts
num_balanced_tokens = max(num_balanced_tokens,
ceil_div(num_experts, world_size * top_k))
num_balanced_tokens = max(
num_balanced_tokens, ceil_div(num_experts, world_size * top_k)
)
mixed_experts = balanced_experts.clone()
mixed_experts[num_balanced_tokens:] = imbalanced_experts[
num_balanced_tokens:]
mixed_experts[num_balanced_tokens:] = imbalanced_experts[num_balanced_tokens:]
return mixed_experts

@functools.cache
@staticmethod
def get_all_to_one_selection(num_tokens, top_k, num_experts, balance_ratio,
dtype, world_size, rank):
@functools.cache
def get_all_to_one_selection(
num_tokens, top_k, num_experts, balance_ratio, dtype, world_size, rank
):
assert num_experts // world_size >= top_k
imbalanced_experts = torch.arange(
num_tokens * top_k, dtype=dtype, device="cuda").view(
num_tokens, top_k) % (num_experts // world_size)
return RoutingMethod.apply_balance_ratio(imbalanced_experts,
num_experts, balance_ratio,
world_size, rank)
imbalanced_experts = torch.arange(num_tokens * top_k, dtype=dtype, device="cuda").view(
num_tokens, top_k
) % (num_experts // world_size)
return RoutingMethod.apply_balance_ratio(
imbalanced_experts, num_experts, balance_ratio, world_size, rank
)

@functools.cache
@staticmethod
def get_balanced_rank_imbalanced_expert_selection(num_tokens, top_k,
num_experts,
balance_ratio, dtype,
world_size, rank):
@functools.cache
def get_balanced_rank_imbalanced_expert_selection(
num_tokens, top_k, num_experts, balance_ratio, dtype, world_size, rank
):
experts_per_rank = num_experts // world_size
activate_experts_per_rank = ceil_div(top_k, world_size)
a = torch.arange(num_tokens * top_k, dtype=dtype,
device="cuda").view(num_tokens, top_k)
a = torch.arange(num_tokens * top_k, dtype=dtype, device="cuda").view(num_tokens, top_k)
narrow_experts = a % (activate_experts_per_rank * world_size)
imbalanced_experts = narrow_experts * experts_per_rank % num_experts + narrow_experts // world_size % experts_per_rank
return RoutingMethod.apply_balance_ratio(imbalanced_experts,
num_experts, balance_ratio,
world_size, rank)
imbalanced_experts = (
narrow_experts * experts_per_rank % num_experts
+ narrow_experts // world_size % experts_per_rank
)
return RoutingMethod.apply_balance_ratio(
imbalanced_experts, num_experts, balance_ratio, world_size, rank
)


class DeepSeekV3Runner:

def __init__(self, pretrained_model_name_or_path: str, mapping: Mapping, *,
moe_backend: str, layer_indices: List[int],
scaled_from: Optional[int], max_seq_len: int,
max_num_tokens: int, use_cuda_graph: bool):
class DeepSeekV3Runner(RunnerMixin, RunnerBase):
@staticmethod
def has_mamba_metadata() -> bool:
return False

def __init__(
self,
pretrained_model_name_or_path: str,
mapping: Mapping,
*,
moe_backend: str,
layer_indices: List[int],
scaled_from: Optional[int],
max_seq_len: int,
max_num_tokens: int,
moe_max_num_tokens: int,
use_cuda_graph: bool,
):
# Temporally replace the gate class
gate_cls_orig = tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate
tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate = RoutingMethod
@@ -158,7 +153,7 @@ class DeepSeekV3Runner:
sparse_attention_config=None, # To be loaded from config
max_num_tokens=max_num_tokens,
max_seq_len=max_seq_len,
moe_max_num_tokens=None,
moe_max_num_tokens=moe_max_num_tokens,
moe_load_balancer=None,
lora_config=None,
allreduce_strategy=AllReduceStrategy.AUTO,
@@ -169,245 +164,56 @@ class DeepSeekV3Runner:
use_low_precision_moe_combine=False,
skip_create_weights_in_init=True,
)

pretrained_config = self.model_config.pretrained_config
if scaled_from is not None:
# To run the problem size of $B$ GPUs on $A$ GPUs, we need:
# (1) Attention: If TP, reduce the number of attention heads; If DP, nothing to change.
# (2) MoE: If EP, reduce the number of experts; If TP, reduce head size.
# Maintain the result of AllToAll method selection because it is affected by EP size.
if not mapping.enable_attention_dp:
if hasattr(pretrained_config, "index_n_heads"):
raise NotImplementedError(
"Not support Indexer TP for weak scaling")
pretrained_config.num_attention_heads = pretrained_config.num_attention_heads // scaled_from * mapping.tp_size
pretrained_config.num_key_value_heads = pretrained_config.num_key_value_heads // scaled_from * mapping.tp_size
if mapping.moe_ep_size != mapping.world_size:
raise NotImplementedError("Not support MoE TP for weak scaling")
pretrained_config.n_routed_experts = pretrained_config.n_routed_experts // scaled_from * mapping.moe_ep_size
select_alltoall_method_type_orig = WideEPMoE.select_alltoall_method_type

def select_alltoall_method_type(cls: type, mapping: Mapping,
top_k: int, *args, **kwargs):
# Replace the condition `mapping.moe_ep_size <= top_k` with `scaled_from <= top_k`
# by replacing `top_k` with `fake_top_k`
if scaled_from <= top_k:
fake_top_k = mapping.moe_ep_size + 1
else:
fake_top_k = mapping.moe_ep_size - 1
assert (mapping.moe_ep_size <= fake_top_k) == (scaled_from
<= top_k)
return select_alltoall_method_type_orig(mapping, fake_top_k,
*args, **kwargs)
with self.scaled_from_ctx(scaled_from, mapping, pretrained_config):
aux_stream_list = [torch.cuda.Stream() for _ in range(2)]
aux_stream_dict = {
AuxStreamType.Attention: aux_stream_list[0],
AuxStreamType.MoeShared: aux_stream_list[0],
AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
}

WideEPMoE.select_alltoall_method_type = select_alltoall_method_type
layers = [
DeepseekV3DecoderLayer(
model_config=self.model_config,
layer_idx=layer_idx,
aux_stream_dict=aux_stream_dict,
)
for layer_idx in layer_indices
]
next_layer_layernorm = RMSNorm(
hidden_size=pretrained_config.hidden_size,
eps=pretrained_config.rms_norm_eps,
dtype=pretrained_config.torch_dtype,
)

aux_stream_list = [torch.cuda.Stream() for _ in range(2)]
aux_stream_dict = {
AuxStreamType.Attention: aux_stream_list[0],
AuxStreamType.MoeShared: aux_stream_list[0],
AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
}
# TODO: apply_layerwise_quant_config
self.apply_quant_config_exclude_modules(layers, self.model_config.quant_config)
for layer in layers:
for module in layer.modules():
if callable(getattr(module, "create_weights", None)):
module.create_weights()
layer.cuda()
for module in layer.modules():
if hasattr(module, "post_load_weights") and not getattr(
module, "_weights_removed", False
):
module.post_load_weights()
next_layer_layernorm.cuda()
for layer, next_layer in zip(layers[:-1], layers[1:]):
layer.next_layer_layernorm = next_layer.input_layernorm
layers[-1].next_layer_layernorm = next_layer_layernorm

layers = [
DeepseekV3DecoderLayer(
model_config=self.model_config,
layer_idx=layer_idx,
aux_stream_dict=aux_stream_dict,
) for layer_idx in layer_indices
]
next_layer_layernorm = RMSNorm(
hidden_size=pretrained_config.hidden_size,
eps=pretrained_config.rms_norm_eps,
dtype=pretrained_config.torch_dtype)

# apply_quant_config_exclude_modules
# Please refer to tensorrt_llm/_torch/models/modeling_utils.py
quant_config = self.model_config.quant_config
new_quant_config = QuantConfig(
kv_cache_quant_algo=quant_config.kv_cache_quant_algo)
for layer in layers:
for name, module in layer.named_modules():
name = f"model.layers.{layer.layer_idx}.{name}"
candidates = [name]
if isinstance(module, Linear):
weight_mode = module.weights_loading_config.weight_mode
if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
# sometimes gate and up proj are not packed in the checkpoint,
# but they still share the same exclusion rule
candidates += [
name.replace('gate_up_proj', 'gate_proj'),
name.replace('gate_up_proj', 'up_proj')
]
elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
# sometimes q_proj, k_proj and v_proj are not packed in the checkpoint,
# but they still share the same exclusion rule
candidates += [
name.replace('qkv_proj', 'q_proj'),
name.replace('qkv_proj', 'k_proj'),
name.replace('qkv_proj', 'v_proj')
]
is_excluded = any(
quant_config.is_module_excluded_from_quantization(n)
for n in candidates)
if is_excluded and getattr(module, "quant_config",
None) is not None:
module.quant_config = new_quant_config
for name, module in layer.named_modules():
if callable(getattr(module, "create_weights", None)):
module.create_weights()
layer.cuda()
for name, module in layer.named_modules():
if hasattr(module, 'post_load_weights') and not getattr(
module, '_weights_removed', False):
module.post_load_weights()
next_layer_layernorm.cuda()
for layer, next_layer in zip(layers[:-1], layers[1:]):
layer.next_layer_layernorm = next_layer.input_layernorm
layers[-1].next_layer_layernorm = next_layer_layernorm

self.layers = layers
if scaled_from is not None:
WideEPMoE.select_alltoall_method_type = select_alltoall_method_type_orig
self.layers = layers
tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate = gate_cls_orig

def create_run_pack(self,
run_type: str,
batch_size: int,
seq_len_q: int,
seq_len_kv_cache: int,
kv_cache_manager: KVCacheManager,
attn_workspace: Optional[torch.Tensor] = None):
if self.model_config.moe_backend == "TRTLLM" and os.getenv(
"TRTLLM_ENABLE_PDL") != "1":
raise ValueError(
"Suggest to set TRTLLM_ENABLE_PDL=1 when moe_backend is TRTLLM")
world_size = mpi_world_size()
AttentionCls = get_attention_backend(
self.model_config.attn_backend,
self.model_config.sparse_attention_config)
attn_metadata = AttentionCls.Metadata(
seq_lens=torch.tensor([seq_len_q] * batch_size, dtype=torch.int),
request_ids=list(range(batch_size)),
max_num_requests=kv_cache_manager.max_batch_size,
num_contexts={
"CTX": batch_size,
"GEN": 0,
}[run_type],
prompt_lens=[{
"CTX": seq_len_q,
"GEN": seq_len_kv_cache,
}[run_type]] * batch_size,
max_num_tokens=batch_size * seq_len_q,
kv_cache_manager=kv_cache_manager,
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=[seq_len_kv_cache] * batch_size,
),
workspace=attn_workspace,
mapping=self.model_config.mapping,
sparse_attention_config=self.model_config.sparse_attention_config,
)
attn_metadata.all_rank_num_tokens = [batch_size * seq_len_q
] * world_size
attn_metadata.prepare()
with model_extra_attrs(self.model_config.extra_attrs):
get_model_extra_attrs()["attention_metadata"] = weakref.ref(
attn_metadata)
hidden_size = self.model_config.pretrained_config.hidden_size
position_ids = torch.tensor([
list(range(seq_len_kv_cache, seq_len_kv_cache + seq_len_q)) *
batch_size
],
dtype=torch.int32,
device="cuda")
hidden_states = torch.rand((batch_size * seq_len_q, hidden_size),
dtype=torch.bfloat16,
device="cuda")
residual = torch.rand((batch_size * seq_len_q, hidden_size),
dtype=torch.bfloat16,
device="cuda")

def run_pack():
output = hidden_states, residual
with model_extra_attrs(self.model_config.extra_attrs):
with torch.inference_mode():
for layer in self.layers:
output = layer(position_ids, output[0], attn_metadata,
output[1])
return output

return run_pack

def replace_routing_method(self, balance_method: BalanceMethod,
balance_ratio: float):
if self.model_config.moe_backend not in [
"CUTLASS", "DEEPGEMM", "TRTLLM", "WIDEEP"
]:
def replace_routing_method(self, balance_method: BalanceMethod, balance_ratio: float):
if self.model_config.moe_backend not in ["CUTLASS", "DEEPGEMM", "TRTLLM", "WIDEEP"]:
raise NotImplementedError(
f"Not support replace routing method for moe_backend \"{self.model_config.moe_backend}\","
f" please set balance_method to \"NotModified\"")
f'Not support replace routing method for moe_backend "{self.model_config.moe_backend}",'
f' please set balance_method to "NotModified"'
)
for layer in self.layers:
layer.mlp.gate.balance_method = balance_method
layer.mlp.gate.balance_ratio = balance_ratio

@staticmethod
def create_kv_cache_manager(pretrained_model_name_or_path, mapping,
tokens_per_block, max_batch_size, max_seq_len,
layer_indices):
# Please refer to `tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` for `tokens_per_block`
model_config = ModelConfig.from_pretrained(
pretrained_model_name_or_path)
if model_config.enable_flash_mla:
assert tokens_per_block == 64

# Please refer to `tensorrt_llm/_torch/pyexecutor/_util.py` for `kv_cache_manager`
kv_cache_manager_cls = get_kv_cache_manager_cls(model_config)
kv_cache_manager = kv_cache_manager_cls(
KvCacheConfig(
max_tokens=max_batch_size *
round_up(max_seq_len, tokens_per_block),
enable_block_reuse=False,
),
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELFKONLY,
num_layers=len(layer_indices),
num_kv_heads=1,
head_dim=model_config.pretrained_config.kv_lora_rank +
model_config.pretrained_config.qk_rope_head_dim,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
mapping=mapping,
dtype=torch_dtype_to_binding({
None: torch.bfloat16,
"FP8": torch.float8_e4m3fn,
}[model_config.quant_config.kv_cache_quant_algo]),
sparse_attn_config=model_config.sparse_attention_config,
)
kv_cache_manager.layer_offsets = {
layer_idx: i
for i, layer_idx in enumerate(layer_indices)
}
kv_cache_manager.add_dummy_requests(list(range(max_batch_size)),
[max_seq_len] * max_batch_size)
return kv_cache_manager

@staticmethod
def create_mapping(enable_attention_dp: bool):
world_size = mpi_world_size()
rank = mpi_rank()
mapping = Mapping(
world_size=world_size,
rank=rank,
gpus_per_node=local_mpi_size(),
cp_size=1,
tp_size=world_size,
pp_size=1,
moe_cluster_size=1,
moe_tp_size=1,
moe_ep_size=world_size,
attn_tp_size=world_size,
attn_cp_size=1,
enable_attention_dp=enable_attention_dp,
)
return mapping

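To make the balanced-routing arithmetic above concrete, here is a small self-contained sketch (with tiny, illustrative sizes that are not taken from the benchmark) reproducing the `get_balanced_selection` formula with plain Python integers:

```python
# Pure-Python sketch of get_balanced_selection for tiny sizes (illustrative numbers).
num_tokens, top_k, num_experts, world_size = 2, 2, 8, 2

def balanced_selection(rank):
    rows = []
    for t in range(num_tokens):
        row = []
        for j in range(top_k):
            # Same element the torch version picks via view(...)[:, rank]
            a = t * world_size * top_k + rank * top_k + j
            e = (a * (num_experts // world_size + 1)
                 + a // num_experts * (num_experts // world_size)) % num_experts
            row.append(e)
        rows.append(row)
    return rows

print(balanced_selection(rank=0))  # [[0, 5], [4, 1]]
print(balanced_selection(rank=1))  # [[2, 7], [6, 3]]
# Across the two ranks every expert 0..7 is selected exactly once, so expert load is even.
```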
@@ -0,0 +1,90 @@
from typing import List, Optional

import torch

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_qwen3_next import ALL_DECODER_LAYER_TYPES
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping

from .runner_interface import RunnerBase
from .runner_utils import RunnerMixin


class Qwen3NextRunner(RunnerMixin, RunnerBase):
    @staticmethod
    def has_mamba_metadata() -> bool:
        return True

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        mapping: Mapping,
        *,
        moe_backend: str,
        layer_indices: List[int],
        scaled_from: Optional[int],
        max_seq_len: int,
        max_num_tokens: int,
        moe_max_num_tokens: int,
        use_cuda_graph: bool,
    ):
        self.model_config = ModelConfig.from_pretrained(
            pretrained_model_name_or_path,
            mapping=mapping,
            enable_min_latency=False,
            use_cuda_graph=use_cuda_graph,
            force_dynamic_quantization=False,
            spec_config=None,
            sparse_attention_config=None, # To be loaded from config
            max_num_tokens=max_num_tokens,
            max_seq_len=max_seq_len,
            moe_max_num_tokens=moe_max_num_tokens,
            moe_load_balancer=None,
            lora_config=None,
            allreduce_strategy=AllReduceStrategy.AUTO,
            mm_encoder_only=False,
            attn_backend="TRTLLM",
            moe_backend=moe_backend,
            moe_disable_finalize_fusion=False,
            use_low_precision_moe_combine=False,
            skip_create_weights_in_init=True,
        )
        pretrained_config = self.model_config.pretrained_config

        with self.scaled_from_ctx(scaled_from, mapping, pretrained_config):
            aux_stream = torch.cuda.Stream()
            layers = [
                ALL_DECODER_LAYER_TYPES[pretrained_config.layer_types[layer_idx]](
                    self.model_config,
                    layer_idx,
                    aux_stream,
                )
                for layer_idx in layer_indices
            ]
            next_layer_layernorm = RMSNorm(
                hidden_size=pretrained_config.hidden_size,
                eps=pretrained_config.rms_norm_eps,
                dtype=pretrained_config.torch_dtype,
                use_gemma=True,
            )

            # TODO: apply_layerwise_quant_config
            self.apply_quant_config_exclude_modules(layers, self.model_config.quant_config)
            for layer in layers:
                for module in layer.modules():
                    if callable(getattr(module, "create_weights", None)):
                        module.create_weights()
                layer.cuda()
                for module in layer.modules():
                    if hasattr(module, "post_load_weights") and not getattr(
                        module, "_weights_removed", False
                    ):
                        module.post_load_weights()
            next_layer_layernorm.cuda()
            for layer, next_layer in zip(layers[:-1], layers[1:]):
                layer.next_layer_layernorm = next_layer.input_layernorm
            layers[-1].next_layer_layernorm = next_layer_layernorm

        self.layers = layers

tensorrt_llm/tools/layer_wise_benchmarks/runner_factory.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from tensorrt_llm._torch.pyexecutor.config_utils import load_pretrained_config

from .deepseekv3_runner import DeepSeekV3Runner
from .qwen3_next_runner import Qwen3NextRunner


def get_runner_cls(pretrained_model_name_or_path: str) -> type:
    pretrained_config = load_pretrained_config(pretrained_model_name_or_path)
    return {
        "deepseek_v3": DeepSeekV3Runner,
        "deepseek_v32": DeepSeekV3Runner,
        "qwen3_next": Qwen3NextRunner,
    }[pretrained_config.model_type]

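The factory keys on the pretrained config's `model_type` field, so only the three mapped families are supported. A hedged sketch of how a caller might guard against anything else (the error handling below is my own illustration, not part of the commit):

```python
from tensorrt_llm.tools.layer_wise_benchmarks.runner_factory import get_runner_cls

try:
    Runner = get_runner_cls("/path/to/checkpoint")  # placeholder path
except KeyError as e:
    # Only deepseek_v3, deepseek_v32 and qwen3_next are mapped; other model types raise KeyError.
    raise SystemExit(f"Layer-wise benchmarks do not support model_type {e}")
```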
tensorrt_llm/tools/layer_wise_benchmarks/runner_interface.py (new file, 49 lines)
@@ -0,0 +1,49 @@
from abc import ABC, abstractmethod
from enum import IntEnum
from typing import Optional

import torch

from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager


class BalanceMethod(IntEnum):
    NotModified = 1
    Balanced = 2
    ImbalancedRanks = 3
    ImbalancedExperts = 4


class RunnerBase(ABC):
    @abstractmethod
    def create_run_pack(
        self,
        run_type: str,
        batch_size: int,
        seq_len_q: int,
        seq_len_kv_cache: int,
        kv_cache_manager: KVCacheManager,
        attn_workspace: Optional[torch.Tensor] = None,
    ):
        pass

    @abstractmethod
    def replace_routing_method(self, balance_method: BalanceMethod, balance_ratio: float):
        pass

    @staticmethod
    @abstractmethod
    def create_kv_cache_manager(
        pretrained_model_name_or_path,
        mapping,
        tokens_per_block,
        max_batch_size,
        max_seq_len,
        layer_indices,
    ):
        pass

    @staticmethod
    @abstractmethod
    def create_mapping(enable_attention_dp: bool):
        pass

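A new model family would plug in by subclassing both `RunnerMixin` and `RunnerBase`, as the DeepSeek-V3 and Qwen3-Next runners do; the mixin already supplies `create_run_pack`, `create_kv_cache_manager` and `create_mapping`. A schematic skeleton under that assumption (class and argument names are placeholders, not part of the commit):

```python
from tensorrt_llm.tools.layer_wise_benchmarks.runner_interface import RunnerBase
from tensorrt_llm.tools.layer_wise_benchmarks.runner_utils import RunnerMixin


class MyModelRunner(RunnerMixin, RunnerBase):  # hypothetical runner
    @staticmethod
    def has_mamba_metadata() -> bool:
        return False  # True only for hybrid-attention models such as Qwen3-Next

    def __init__(self, pretrained_model_name_or_path, mapping, **kwargs):
        # Build the ModelConfig, instantiate the decoder layers to benchmark,
        # and store them on self.layers, mirroring DeepSeekV3Runner.__init__.
        ...
```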
tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py (new file, 303 lines)
@@ -0,0 +1,303 @@
import contextlib
import os
import weakref
from abc import ABC, abstractmethod
from typing import Optional

import torch

from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import WideEPMoE
from tensorrt_llm._torch.modules.linear import Linear, WeightMode
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
from tensorrt_llm._torch.pyexecutor._util import get_kv_cache_manager_cls
from tensorrt_llm._torch.pyexecutor.config_utils import is_mla, is_qwen3_next
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._torch.utils import get_model_extra_attrs, model_extra_attrs
from tensorrt_llm._utils import local_mpi_size, mpi_rank, mpi_world_size, torch_dtype_to_binding
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.bindings.internal.batch_manager import CacheType
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig

from .runner_interface import BalanceMethod


def ceil_div(a, b):
    return (a + b - 1) // b


def round_up(a, b):
    return ceil_div(a, b) * b


class RunnerMixin(ABC):
    @staticmethod
    @abstractmethod
    def has_mamba_metadata() -> bool:
        pass

    @staticmethod
    @contextlib.contextmanager
    def scaled_from_ctx(scaled_from, mapping, pretrained_config):
        if scaled_from is None:
            yield
            return
        # To run the problem size of $B$ GPUs on $A$ GPUs, we need:
        # (1) Attention: If TP, reduce the number of attention heads; If DP, nothing to change.
        # (2) MoE: If EP, reduce the number of experts; If TP, reduce head size.
        # Maintain the result of AllToAll method selection because it is affected by EP size.
        if not mapping.enable_attention_dp:
            if hasattr(pretrained_config, "index_n_heads"):
                raise NotImplementedError("Not support Indexer TP for weak scaling")
            pretrained_config.num_attention_heads = (
                pretrained_config.num_attention_heads // scaled_from * mapping.tp_size
            )
            pretrained_config.num_key_value_heads = (
                pretrained_config.num_key_value_heads // scaled_from * mapping.tp_size
            )
        if mapping.moe_ep_size != mapping.world_size:
            raise NotImplementedError("Not support MoE TP for weak scaling")
        pretrained_config.n_routed_experts = (
            pretrained_config.n_routed_experts // scaled_from * mapping.moe_ep_size
        )
        select_alltoall_method_type_orig = WideEPMoE.select_alltoall_method_type

        def select_alltoall_method_type(cls: type, mapping: Mapping, top_k: int, *args, **kwargs):
            # Replace the condition `mapping.moe_ep_size <= top_k` with `scaled_from <= top_k`
            # by replacing `top_k` with `fake_top_k`
            if scaled_from <= top_k:
                fake_top_k = mapping.moe_ep_size + 1
            else:
                fake_top_k = mapping.moe_ep_size - 1
            assert (mapping.moe_ep_size <= fake_top_k) == (scaled_from <= top_k)
            return select_alltoall_method_type_orig(mapping, fake_top_k, *args, **kwargs)

        WideEPMoE.select_alltoall_method_type = select_alltoall_method_type
        try:
            yield
        finally:
            WideEPMoE.select_alltoall_method_type = select_alltoall_method_type_orig

    @staticmethod
    def apply_quant_config_exclude_modules(layers, quant_config):
        # Please refer to tensorrt_llm/_torch/models/modeling_utils.py
        new_quant_config = QuantConfig(kv_cache_quant_algo=quant_config.kv_cache_quant_algo)
        for layer in layers:
            for name, module in layer.named_modules():
                name = f"model.layers.{layer.layer_idx}.{name}"
                candidates = [name]
                if isinstance(module, Linear):
                    weight_mode = module.weights_loading_config.weight_mode
                    if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
                        # sometimes gate and up proj are not packed in the checkpoint,
                        # but they still share the same exclusion rule
                        candidates += [
                            name.replace("gate_up_proj", "gate_proj"),
                            name.replace("gate_up_proj", "up_proj"),
                        ]
                    elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
                        # sometimes q_proj, k_proj and v_proj are not packed in the checkpoint,
                        # but they still share the same exclusion rule
                        candidates += [
                            name.replace("qkv_proj", "q_proj"),
                            name.replace("qkv_proj", "k_proj"),
                            name.replace("qkv_proj", "v_proj"),
                        ]
                is_excluded = any(
                    quant_config.is_module_excluded_from_quantization(n) for n in candidates
                )
                if is_excluded and getattr(module, "quant_config", None) is not None:
                    module.quant_config = new_quant_config

    def create_run_pack(
        self,
        run_type: str,
        batch_size: int,
        seq_len_q: int,
        seq_len_kv_cache: int,
        kv_cache_manager: KVCacheManager,
        attn_workspace: Optional[torch.Tensor] = None,
    ):
        if self.model_config.moe_backend == "TRTLLM" and os.getenv("TRTLLM_ENABLE_PDL") != "1":
            raise ValueError("Suggest to set TRTLLM_ENABLE_PDL=1 when moe_backend is TRTLLM")
        world_size = mpi_world_size()
        AttentionCls = get_attention_backend(
            self.model_config.attn_backend, self.model_config.sparse_attention_config
        )
        attn_metadata = AttentionCls.Metadata(
            seq_lens=torch.tensor([seq_len_q] * batch_size, dtype=torch.int),
            request_ids=list(range(batch_size)),
            max_num_requests=kv_cache_manager.max_batch_size,
            num_contexts={
                "CTX": batch_size,
                "GEN": 0,
            }[run_type],
            prompt_lens=[
                {
                    "CTX": seq_len_q,
                    "GEN": seq_len_kv_cache,
                }[run_type]
            ]
            * batch_size,
            max_num_tokens=batch_size * seq_len_q,
            kv_cache_manager=kv_cache_manager,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=[seq_len_kv_cache] * batch_size,
            ),
            workspace=attn_workspace,
            mapping=self.model_config.mapping,
            sparse_attention_config=self.model_config.sparse_attention_config,
        )
        attn_metadata.all_rank_num_tokens = [batch_size * seq_len_q] * world_size
        attn_metadata.prepare()
        with model_extra_attrs(self.model_config.extra_attrs):
            get_model_extra_attrs()["attention_metadata"] = weakref.ref(attn_metadata)
        hidden_size = self.model_config.pretrained_config.hidden_size
        position_ids = torch.tensor(
            [list(range(seq_len_kv_cache, seq_len_kv_cache + seq_len_q)) * batch_size],
            dtype=torch.int32,
            device="cuda",
        )
        hidden_states = torch.rand(
            (batch_size * seq_len_q, hidden_size), dtype=torch.bfloat16, device="cuda"
        )
        residual = torch.rand(
            (batch_size * seq_len_q, hidden_size), dtype=torch.bfloat16, device="cuda"
        )
        kwargs = {}

        if self.has_mamba_metadata():
            # Please refer to `tensorrt_llm/_torch/models/modeling_qwen3_next.py` for `mamba_metadata`
            mamba_metadata = Mamba2Metadata(attn_metadata.max_num_requests, chunk_size=128)
            mamba_metadata.prepare(attn_metadata)
            kwargs["mamba_metadata"] = mamba_metadata

        def run_pack():
            output = hidden_states, residual
            with model_extra_attrs(self.model_config.extra_attrs):
                with torch.inference_mode():
                    for layer in self.layers:
                        output = layer(position_ids, output[0], attn_metadata, output[1], **kwargs)
            return output

        return run_pack

    def replace_routing_method(self, balance_method: BalanceMethod, balance_ratio: float):
        if balance_method != BalanceMethod.NotModified:
            raise NotImplementedError("not support replacing routing method for this runner")

    @staticmethod
    def create_kv_cache_manager(
        pretrained_model_name_or_path,
        mapping,
        tokens_per_block,
        max_batch_size,
        max_seq_len,
        layer_indices,
    ):
        # Please refer to `tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` for `tokens_per_block`
        model_config = ModelConfig.from_pretrained(pretrained_model_name_or_path)
        if model_config.enable_flash_mla:
            assert tokens_per_block == 64

        # Please refer to `tensorrt_llm/_torch/pyexecutor/_util.py` for `kv_cache_manager`
        kv_cache_manager_cls = get_kv_cache_manager_cls(model_config)
        config = model_config.pretrained_config
        kv_cache_config = KvCacheConfig(
            max_tokens=max_batch_size * round_up(max_seq_len, tokens_per_block),
            enable_block_reuse=False,
        )
        kv_cache_dtype = torch_dtype_to_binding(
            {
                None: torch.bfloat16,
                "FP8": torch.float8_e4m3fn,
            }[model_config.quant_config.kv_cache_quant_algo]
        )
        if is_mla(config):
            layer_mask = [i in layer_indices for i in range(config.num_hidden_layers)]
            num_layers = sum(layer_mask)
            kv_cache_manager = kv_cache_manager_cls(
                kv_cache_config,
                CacheType.SELFKONLY,
                num_layers=sum(layer_mask),
                num_kv_heads=1,
                head_dim=model_config.pretrained_config.kv_lora_rank
                + model_config.pretrained_config.qk_rope_head_dim,
                tokens_per_block=tokens_per_block,
                max_seq_len=max_seq_len,
                max_batch_size=max_batch_size,
                mapping=mapping,
                dtype=kv_cache_dtype,
                layer_mask=layer_mask,
                sparse_attn_config=model_config.sparse_attention_config,
            )
        elif is_qwen3_next(config):
            mamba_layer_mask = [
                i in layer_indices
                if i % config.full_attention_interval != config.full_attention_interval - 1
                else False
                for i in range(config.num_hidden_layers)
            ]
            layer_mask = [
                False
                if i % config.full_attention_interval != config.full_attention_interval - 1
                else i in layer_indices
                for i in range(config.num_hidden_layers)
            ]
            num_mamba_layers = sum(mamba_layer_mask)
            num_layers = sum(layer_mask)
            kv_cache_manager = kv_cache_manager_cls(
                # mamba cache parameters
                config.linear_key_head_dim,
                config.linear_conv_kernel_dim,
                config.linear_num_value_heads,
                config.linear_num_key_heads,
                config.linear_value_head_dim,
                num_mamba_layers,
                mamba_layer_mask,
                config.torch_dtype,
                model_config.quant_config.mamba_ssm_cache_dtype,
                # kv cache parameters
                kv_cache_config,
                CacheType.SELF,
                num_layers=num_layers,
                layer_mask=layer_mask,
                num_kv_heads=config.num_key_value_heads,
                head_dim=config.head_dim,
                tokens_per_block=tokens_per_block,
                max_seq_len=max_seq_len,
                max_batch_size=max_batch_size,
                mapping=mapping,
                dtype=kv_cache_dtype,
                spec_config=None,
            )
        else:
            raise NotImplementedError("Unsupported config")
        kv_cache_manager.add_dummy_requests(
            list(range(max_batch_size)), [max_seq_len] * max_batch_size
        )
        return kv_cache_manager

    @staticmethod
    def create_mapping(enable_attention_dp: bool):
        world_size = mpi_world_size()
        rank = mpi_rank()
        mapping = Mapping(
            world_size=world_size,
            rank=rank,
            gpus_per_node=local_mpi_size(),
            cp_size=1,
            tp_size=world_size,
            pp_size=1,
            moe_cluster_size=1,
            moe_tp_size=1,
            moe_ep_size=world_size,
            attn_tp_size=world_size,
            attn_cp_size=1,
            enable_attention_dp=enable_attention_dp,
        )
        return mapping

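As a concrete reading of the Qwen3-Next branch above: with `--layer-indices 6,7` (the README example) and assuming `full_attention_interval` is 4 for this checkpoint (an assumption, read from the model config at runtime), layer 7 satisfies `i % interval == interval - 1` and is treated as a full-attention layer with a KV cache, while layer 6 falls into the mamba (linear-attention) mask. A small sketch of the same mask construction with plain lists and assumed sizes:

```python
# Illustrative only: full_attention_interval=4 and num_hidden_layers=8 are assumed values.
full_attention_interval, num_hidden_layers = 4, 8
layer_indices = [6, 7]

mamba_layer_mask = [
    i in layer_indices if i % full_attention_interval != full_attention_interval - 1 else False
    for i in range(num_hidden_layers)
]
layer_mask = [
    False if i % full_attention_interval != full_attention_interval - 1 else i in layer_indices
    for i in range(num_hidden_layers)
]
print(mamba_layer_mask)  # [False, False, False, False, False, False, True, False]
print(layer_mask)        # [False, False, False, False, False, False, False, True]
```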
@@ -76,6 +76,8 @@ l0_b200:
- unittest/_torch/modeling -k "modeling_llama"
- unittest/_torch/modeling -k "modeling_mixtral"
- unittest/_torch/modeling -k "modeling_gpt_oss"
- unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_ctx_dep[1]
- unittest/tools/test_layer_wise_benchmarks.py::test_qwen3_next_gen_tep[1]
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-1]
- unittest/_torch/auto_deploy/unit/singlegpu

@@ -18,8 +18,6 @@ l0_dgx_b200:
- unittest/_torch/multi_gpu_modeling -k "deepseek"
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]
- unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_ctx_tep
- unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_gen_scaled_from_16_dep
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]

@@ -6,7 +6,6 @@ import pytest
import yaml
from _model_test_utils import get_small_model_config
from click.testing import CliRunner
from utils.cpp_paths import llm_root # noqa: F401

from tensorrt_llm.commands.bench import main


@@ -25,6 +25,7 @@ import pytest
import torch
import tqdm
from mpi4py.futures import MPIPoolExecutor
from utils.cpp_paths import llm_root # noqa: F401
from utils.util import get_current_process_gpu_memory

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

@@ -3,67 +3,122 @@ from subprocess import check_call

import pytest
import torch
from utils.cpp_paths import llm_root # noqa: F401
from utils.llm_data import llm_models_root


@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="needs 4 GPUs to run this test")
def test_deepseek_r1_ctx_tep(llm_root):
@pytest.mark.parametrize("world_size", [1, 4])
def test_deepseek_r1_ctx_dep(llm_root, world_size):
if torch.cuda.device_count() < world_size:
pytest.skip(f"needs {world_size:d} GPUs to run this test")
model_root = llm_models_root(check=True)
check_call([
"./mpi_launch.sh",
"./run_single.sh",
"config_ctx.yaml",
"--model",
model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2",
"--no-enable-attention-dp",
"--moe-backend=TRTLLM",
],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={
**os.environ,
"NP": "4",
"TRTLLM_ENABLE_PDL": "1",
})
check_call(
[
"./mpi_launch.sh",
"./run_single.sh",
"config_ctx.yaml",
"--model",
model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2",
],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={
**os.environ,
"NP": f"{world_size:d}",
},
)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="needs 4 GPUs to run this test")
def test_deepseek_v32_ctx_dep(llm_root):
@pytest.mark.parametrize("world_size", [1, 4])
def test_deepseek_r1_ctx_tep(llm_root, world_size):
if torch.cuda.device_count() < world_size:
pytest.skip(f"needs {world_size:d} GPUs to run this test")
model_root = llm_models_root(check=True)
check_call([
"./mpi_launch.sh",
"./run_single.sh",
"config_ctx.yaml",
"--model",
model_root / "DeepSeek-V3.2-Exp-hf",
"--tokens-per-block=64",
"--moe-backend=DEEPGEMM",
],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={
**os.environ,
"NP": "4",
})
check_call(
[
"./mpi_launch.sh",
"./run_single.sh",
"config_ctx.yaml",
"--model",
model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2",
"--no-enable-attention-dp",
"--moe-backend=TRTLLM",
],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={
**os.environ,
"NP": f"{world_size:d}",
"TRTLLM_ENABLE_PDL": "1",
},
)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="needs 4 GPUs to run this test")
def test_deepseek_r1_gen_scaled_from_16_dep(llm_root):
@pytest.mark.parametrize("world_size", [1, 4])
def test_deepseek_v32_ctx_dep(llm_root, world_size):
if torch.cuda.device_count() < world_size:
pytest.skip(f"needs {world_size:d} GPUs to run this test")
model_root = llm_models_root(check=True)
check_call([
"./mpi_launch.sh",
"./run_single.sh",
"config_gen.yaml",
"--model",
model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2",
"--layer-indices=5,6",
"--scaled-from=16",
"--moe-backend=WIDEEP",
],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={
**os.environ,
"NP": "4",
})
check_call(
[
"./mpi_launch.sh",
"./run_single.sh",
"config_ctx.yaml",
"--model",
model_root / "DeepSeek-V3.2-Exp-hf",
"--tokens-per-block=64",
"--moe-backend=DEEPGEMM",
],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={
**os.environ,
"NP": f"{world_size:d}",
},
)


@pytest.mark.parametrize("world_size", [4])
def test_deepseek_r1_gen_scaled_from_16_dep(llm_root, world_size):
if torch.cuda.device_count() < world_size:
pytest.skip(f"needs {world_size:d} GPUs to run this test")
model_root = llm_models_root(check=True)
check_call(
[
"./mpi_launch.sh",
"./run_single.sh",
"config_gen.yaml",
"--model",
model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2",
"--layer-indices=5,6",
"--scaled-from=16",
"--moe-backend=WIDEEP",
],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={
**os.environ,
"NP": f"{world_size:d}",
},
)


@pytest.mark.parametrize("world_size", [1, 4])
def test_qwen3_next_gen_tep(llm_root, world_size):
if torch.cuda.device_count() < world_size:
pytest.skip(f"needs {world_size:d} GPUs to run this test")
model_root = llm_models_root(check=True)
check_call(
[
"./mpi_launch.sh",
"./run_single.sh",
"config_gen.yaml",
"--model",
model_root / "Qwen3" / "Qwen3-Next-80B-A3B-Instruct",
"--layer-indices=6,7",
"--no-enable-attention-dp",
"--moe-backend=TRTLLM",
"--balance-method=NotModified",
],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={
**os.environ,
"NP": f"{world_size:d}",
"TRTLLM_ENABLE_PDL": "1",
},
)

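The parametrized test IDs referenced in the test lists (e.g. `test_qwen3_next_gen_tep[1]`) can also be run directly. A sketch of a local single-GPU invocation, assuming you run from the repository's `tests/unittest` directory and that `llm_models_root()` can find the checkpoints (typically via an `LLM_MODELS_ROOT` environment variable):

```bash
cd tests/unittest
pytest "tools/test_layer_wise_benchmarks.py::test_qwen3_next_gen_tep[1]" -s
```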
@@ -6,7 +6,6 @@ from pathlib import Path
from typing import Any, Dict, List, Tuple

import pytest
from utils.cpp_paths import llm_root # noqa: F401
from utils.llm_data import llm_models_root

# Constants for test configuration