[None][feat] Layer-wise benchmarks: make model init more general and support weights loading (#10562)
Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
Parent: 50c78179dd
Commit: 38296a472b
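The user-facing addition is optional real-weights loading for the layer-wise benchmarks via a new `--load-format` option. A representative invocation, taken from the README changes below (the model path is just the example used there):

```bash
# Benchmark a few layers with real weights loaded from a local checkpoint
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml \
    --model "$LLM_MODELS_ROOT/DeepSeek-R1/DeepSeek-R1-0528-FP4-v2" --load-format AUTO
```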
@@ -15,28 +15,29 @@ pip install -e ../..
**Step 3:** In the container, run benchmarks and generate profiles:

```bash
# Set autotune cache path
export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache

# Run DeepSeek-R1 NVFP4
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml

# Run with weights loaded. Requires local model directory
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model "$LLM_MODELS_ROOT/DeepSeek-R1/DeepSeek-R1-0528-FP4-v2" --load-format AUTO
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model "$LLM_MODELS_ROOT/DeepSeek-R1/DeepSeek-R1-0528-FP4-v2" --load-format AUTO

# Run DeepSeek-V3.2-Exp
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --moe-backend-for-prefill DEEPGEMM

# Run DeepSeek-V3.2-Exp with 32k context length
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --seq-len-kv-cache 32769
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --moe-backend-for-prefill DEEPGEMM --seq-len-kv-cache 32769

# Run with attention TP
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --no-enable-attention-dp
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --no-enable-attention-dp

# Run with attention TP and TRTLLMGen
NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM
NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM

# Run with MTP3
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --batch-size 32 --seq-len-q 4
@@ -51,9 +52,13 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --moe-backend WID
# Scale TEP=16 to 4 GPUs: reduce the number of attention heads and experts
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp

# Run Nemotron-3-Nano
NP=1 ./mpi_launch.sh ./run.sh config_ctx.yaml --model nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 --layer-indices 4,5,6 --mamba-ssm-cache-dtype float16
NP=1 ./mpi_launch.sh ./run.sh config_gen.yaml --model nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 --layer-indices 4,5,6 --mamba-ssm-cache-dtype float16

# Run Qwen3-Next
NP=2 ./mpi_launch.sh ./run.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 4
NP=2 ./mpi_launch.sh ./run.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 512
NP=2 ./mpi_launch.sh ./run.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --mamba-ssm-cache-dtype float16 --batch-size 4
NP=2 ./mpi_launch.sh ./run.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --mamba-ssm-cache-dtype float16 --batch-size 512

# Run with DeepEP A2A
NP=4 ./mpi_launch.sh -x TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./run.sh config_ctx.yaml --moe-backend WIDEEP
@@ -112,14 +117,11 @@ python3 scripts/build_wheel.py --cuda_architectures native --no-venv --skip_buil
**Step 3:** Run benchmarks to generate profiles. Run the following command on the controller node, where `NODES` ≤ the number of allocated nodes:

```bash
# Set autotune cache path
export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache

# Run DeepSeek-R1 NVFP4 with wide ep: uses MNNVL A2A if applicable
NODES=4 NP=16 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP

# Run with TRTLLMGen
NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend TRTLLM
NODES=4 NP=16 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend TRTLLM

# Run with DeepEPLowLatency
NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
@@ -172,7 +174,9 @@ You will receive three reports, each containing kernel timing statistics grouped
## Developer utilities

1. Reduce startup time when debugging a model
    1. Disable autotuner: add `--no-enable-autotuner` option
    1. Set autotuner cache or disable autotuner
        1. Set autotuner cache: add the `TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache` environment variable. Use this at your own risk; you may need to delete the cache if `NP` changes or the code changes
        2. Disable autotuner: add `--no-enable-autotuner` option
    2. Disable nsys profile: set `PROFILE=0` environment variable
2. Capture more information
    1. Enable GPU metrics: set `GPU_METRICS=1` environment variable
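Putting the utilities above together, a quick debug loop might look like the sketch below. It only combines the options and environment variables listed in this section with the DeepSeek-R1 generation command from Step 3, and it assumes your launcher forwards these environment variables to all ranks:

```bash
# Option 1: persist the autotuner cache between runs and skip nsys profiling
export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache
PROFILE=0 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml

# Option 2: disable the autotuner entirely
PROFILE=0 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --no-enable-autotuner

# When profiling is enabled, also capture GPU metrics
GPU_METRICS=1 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml
```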
@@ -182,4 +186,8 @@ You will receive three reports, each containing kernel timing statistics grouped

1. Error `fp8 blockscale gemm only support Hopper` on Blackwell.

    The default MoE backend "CUTLASS" does not support FP8 weights. Please choose the same MoE backend as your end-to-end config. A typical choice is adding `--moe-backend DEEPGEMM`, `--moe-backend TRTLLM`, or `--moe-backend WIDEEP` option.
    The default MoE backend "CUTLASS" does not support FP8 weights. Please choose the same MoE backend as your end-to-end config. A typical choice is adding the `--moe-backend DEEPGEMM` (or `TRTLLM`, `WIDEEP`) and `--moe-backend-for-prefill DEEPGEMM` (or `WIDEEP`) options.

2. Error `huggingface_hub.errors.HfHubHTTPError: 429 Client Error: Too Many Requests for url: https://huggingface.co/nvidia/DeepSeek-R1-0528-FP4-v2/resolve/main/config.json`.

    Please use a local model through the `--model` option, or follow Hugging Face's instructions: "We had to rate limit your IP. To continue using our service, create a HF account or login to your existing account, and make sure you pass a HF_TOKEN if you're using the API."
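Both remedies correspond to commands that already appear earlier in this README; for example:

```bash
# Remedy for (1): pick the MoE backend that matches your end-to-end config
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --moe-backend-for-prefill DEEPGEMM

# Remedy for (2): point --model at a local directory instead of a Hugging Face repo id
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model "$LLM_MODELS_ROOT/DeepSeek-R1/DeepSeek-R1-0528-FP4-v2" --load-format AUTO
```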
@@ -6,7 +6,6 @@ import re
import sqlite3
import subprocess
import sys
from collections import defaultdict
from pathlib import Path

import jinja2
@ -139,7 +138,7 @@ for start, text in df.itertuples(index=False):
|
||||
"runs": [],
|
||||
"runs_end": [],
|
||||
"ranges": [],
|
||||
"range_in_module": [],
|
||||
"kernel_count_per_range": [],
|
||||
}
|
||||
)
|
||||
|
||||
@ -161,28 +160,7 @@ for start, end, text in df.itertuples(index=False):
|
||||
problem_set[problem_id]["runs_end"].append(end)
|
||||
else:
|
||||
problem_set[problem_id]["ranges"].append((start, end, text))
|
||||
|
||||
# Determine whether each range is the first range that matches `args.module`,
|
||||
# and store the result in `problem["range_in_module"]`
|
||||
for problem in problem_set:
|
||||
if args.module is not None:
|
||||
problem["range_in_module"] = [False] * len(problem["ranges"])
|
||||
run_ids = [bisect.bisect(problem["runs"], start) - 1 for start, _, _ in problem["ranges"]]
|
||||
run2ranges = defaultdict(list)
|
||||
for i, run_id in enumerate(run_ids):
|
||||
run2ranges[run_id].append(i)
|
||||
for run_id, ranges in run2ranges.items():
|
||||
ranges = sorted(ranges, key=lambda i: problem["ranges"][i][0])
|
||||
num_matches = 0
|
||||
for range_id in ranges:
|
||||
if problem["ranges"][range_id][2] == args.module:
|
||||
problem["range_in_module"][range_id] = True
|
||||
num_matches += 1
|
||||
if num_matches != 1:
|
||||
raise ValueError(
|
||||
f'Module "{args.module}" appears {num_matches} times'
|
||||
f' in "{problem["text"]}"\'s {run_id + 1}-th run'
|
||||
)
|
||||
problem_set[problem_id]["kernel_count_per_range"].append(0)
|
||||
|
||||
query = """SELECT name FROM sqlite_master WHERE type = ?"""
|
||||
df = pd.read_sql_query(query, conn, params=("table",))
|
||||
@ -228,19 +206,17 @@ for (
|
||||
problem_id = bisect.bisect(problem_start, start) - 1
|
||||
problem = problem_set[problem_id]
|
||||
run_id = bisect.bisect(problem["runs"], runtime_start) - 1
|
||||
if (
|
||||
run_id == -1
|
||||
or run_id == len(problem["runs"])
|
||||
or runtime_start >= problem["runs_end"][run_id]
|
||||
):
|
||||
run_id = -1
|
||||
if run_id == -1 or runtime_start >= problem["runs_end"][run_id]:
|
||||
continue
|
||||
ranges = [
|
||||
i
|
||||
for i, (range_start, range_end, text) in enumerate(problem["ranges"])
|
||||
if capture_start >= range_start and capture_end <= range_end
|
||||
]
|
||||
if args.module is None or any(problem["range_in_module"][i] for i in ranges):
|
||||
range_names = [problem["ranges"][i][2] for i in ranges]
|
||||
for range_id in ranges:
|
||||
problem["kernel_count_per_range"][range_id] += 1
|
||||
range_names = [problem["ranges"][i][2] for i in ranges]
|
||||
if args.module is None or args.module in range_names:
|
||||
kernel_list.append(
|
||||
(
|
||||
problem_id,
|
||||
@ -262,6 +238,22 @@ string_ids = dict(zip(df["id"], df["value"]))
|
||||
|
||||
conn.close()
|
||||
|
||||
# Check ambiguous modules
|
||||
if args.module:
|
||||
for problem in problem_set:
|
||||
num_matches_per_run = [0] * (len(problem["runs"]) + 1)
|
||||
for (range_start, _, text), kernel_count in zip(
|
||||
problem["ranges"], problem["kernel_count_per_range"]
|
||||
):
|
||||
if text == args.module and kernel_count > 0:
|
||||
num_matches_per_run[bisect.bisect(problem["runs"], range_start)] += 1
|
||||
for run_id_plus_one, num_matches in enumerate(num_matches_per_run):
|
||||
if num_matches > 1:
|
||||
raise ValueError(
|
||||
f'Module is ambiguous: "{args.module}" appears {num_matches} times'
|
||||
f' in "{problem["text"]}"\'s {run_id_plus_one}-th run'
|
||||
)
|
||||
|
||||
kernel_list.sort(key=lambda t: (t[6], t[8]))
|
||||
kernels = [[[] for _ in problem["runs"]] for problem in problem_set]
|
||||
for (
|
||||
@ -276,8 +268,7 @@ for (
|
||||
capture_start,
|
||||
capture_end,
|
||||
) in kernel_list:
|
||||
if run_id != -1:
|
||||
kernels[problem_id][run_id].append((demangledName, start, end, ranges))
|
||||
kernels[problem_id][run_id].append((demangledName, start, end, ranges))
|
||||
for problem_id in range(len(kernels)):
|
||||
required_seq = [demangledName for demangledName, _, _, _ in kernels[problem_id][0]]
|
||||
for run_id in range(len(kernels[problem_id])):
|
||||
@ -287,86 +278,8 @@ for problem_id in range(len(kernels)):
|
||||
|
||||
parser_keywords = [
|
||||
("cuBLASGemm", "nvjet"),
|
||||
("splitKreduce", "splitKreduce_kernel"),
|
||||
("fusedAGemm", "fused_a_gemm_kernel"),
|
||||
("RMSNorm", "RMSNormKernel"),
|
||||
("torchCat", "CatArrayBatchedCopy"),
|
||||
("applyMLARope", "applyMLARope"),
|
||||
("fmhaSm100f", "fmhaSm100fKernel_Qkv"),
|
||||
("fmhaReduction", "fmhaReductionKernel"),
|
||||
("quant", "quantize_with_block_size"),
|
||||
("AllGather", "ncclDevKernel_AllGather_"),
|
||||
("ReduceScatter", "ncclDevKernel_ReduceScatter_"),
|
||||
("allreduce_oneshot", "allreduce_fusion_kernel_oneshot_lamport"),
|
||||
("allreduce_twoshot", "allreduce_fusion_kernel_twoshot_sync"),
|
||||
("expandInput", "expandInputRowsKernel"),
|
||||
("computeStrides", "computeStridesTmaWarpSpecializedKernel"),
|
||||
("cutlassGroupGemm", "cutlass::device_kernel<cutlass::gemm::kernel::GemmUniversal"),
|
||||
("doActivation", "doActivationKernel"),
|
||||
("cutlassGemm", "GemmUniversal"),
|
||||
("deepseek_v3_topk", "deepseek_v3_topk_kernel"),
|
||||
("CountAndIndice", "computeCountAndIndiceDevice"),
|
||||
("Cumsum", "computeCumsumDevice"),
|
||||
("moveIndice", "moveIndiceDevice"),
|
||||
("moeAllToAll", "moeAllToAllKernel"),
|
||||
("moeA2APrepareDispatch", "moe_comm::moeA2APrepareDispatchKernel"),
|
||||
("moeA2ADispatch", "moe_comm::moeA2ADispatchKernel"),
|
||||
("moeA2ASanitizeExpertIds", "moe_comm::moeA2ASanitizeExpertIdsKernel"),
|
||||
("moeA2APrepareCombine", "moe_comm::moeA2APrepareCombineKernel"),
|
||||
("moeA2ACombine", "moe_comm::moeA2ACombineKernel"),
|
||||
("memsetExpertIds", "memsetExpertIdsDevice"),
|
||||
("blockSum", "blockExpertPrefixSumKernel"),
|
||||
("globalSum", "globalExpertPrefixSumKernel"),
|
||||
("globalSumLarge", "globalExpertPrefixSumLargeKernel"),
|
||||
("mergePrefix", "mergeExpertPrefixSumKernel"),
|
||||
("fusedBuildExpertMaps", "fusedBuildExpertMapsSortFirstTokenKernel"),
|
||||
("swiglu", "silu_and_mul_kernel"),
|
||||
("torchAdd", "CUDAFunctor_add"),
|
||||
("torchFill", "at::native::FillFunctor"),
|
||||
("triton_fused_add_sum", "triton_red_fused_add_sum_0"),
|
||||
("torchCopy", "at::native::bfloat16_copy_kernel_cuda"),
|
||||
("torchDistribution", "distribution_elementwise_grid_stride_kernel"),
|
||||
("torchArange", "at::native::arange_cuda_out"),
|
||||
("torchDirectCopy", "at::native::direct_copy_kernel_cuda"),
|
||||
("torchBitonicSort", "at::native::bitonicSortKVInPlace"),
|
||||
("routingInitExpertCounts", "routingInitExpertCounts"),
|
||||
("routingIndicesCluster", "routingIndicesClusterKernel"),
|
||||
("routingIndicesCoop", "routingIndicesCoopKernel"),
|
||||
("router_gemm", "router_gemm_kernel"),
|
||||
("bmm_4_44_32", "bmm_E2m1_E2m1E2m1_Fp32_t"),
|
||||
("finalize", "finalize::finalizeKernel"),
|
||||
("bmm_16_44_32", "bmm_Bfloat16_E2m1E2m1_Fp32_"),
|
||||
("deep_gemm_gemm", "deep_gemm::sm100_fp8_gemm_1d1d_impl<"),
|
||||
("per_token_quant", "_per_token_quant_and_transform_kernel"),
|
||||
("triton_fused_layer_norm", "triton_per_fused__to_copy_native_layer_norm_0"),
|
||||
("flashinferRoPE", "flashinfer::BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel<"),
|
||||
("flashinferRoPE", "flashinfer::BatchQKApplyRotaryPosIdsCosSinCacheKernel<"),
|
||||
("fp8_blockscale_gemm", "tensorrt_llm::kernels::fp8_blockscale_gemm"),
|
||||
("triton_fused_mul_squeeze", "triton_poi_fused_mul_squeeze_0"),
|
||||
("indexerKCacheScatter", "tensorrt_llm::kernels::indexerKCacheScatterUnifiedKernel"),
|
||||
("deep_gemm_mqa_logits", "deep_gemm::sm100_fp8_paged_mqa_logits<"),
|
||||
("topKPerRowDecode", "tensorrt_llm::kernels::topKPerRowDecode<"),
|
||||
("torchAdd<int>", "at::native::CUDAFunctorOnSelf_add"),
|
||||
("convert_req_index", "_convert_req_index_to_global_index_kernel_with_stride_factor"),
|
||||
("preprocess_after_permute", "_preprocess_after_permute_kernel"),
|
||||
("masked_index_copy_quant", "_masked_index_copy_group_quant_fp8"),
|
||||
("swiglu_quant", "_silu_and_mul_post_quant_kernel"),
|
||||
("masked_index_gather", "masked_index_gather_kernel"),
|
||||
("finalizeMoeRouting", "tensorrt_llm::kernels::cutlass_kernels::finalizeMoeRoutingKernel<"),
|
||||
("fused_qkvzba_split", "fused_qkvzba_split_reshape_cat_kernel"),
|
||||
("causal_conv1d_update", "tensorrt_llm::kernels::causal_conv1d::causal_conv1d_update_kernel<"),
|
||||
("fused_delta_rule_update", "fused_sigmoid_gating_delta_rule_update_kernel"),
|
||||
("layer_norm_fwd_1pass", "_layer_norm_fwd_1pass_kernel"),
|
||||
("torchGatherTopK", "at::native::sbtopk::gatherTopK<"),
|
||||
("softmax_warp_forward", "softmax_warp_forward<"),
|
||||
("torchSigmoid", "at::native::sigmoid_kernel_cuda"),
|
||||
("torchMul", "at::native::binary_internal::MulFunctor<"),
|
||||
("computeSeqAndPaddingOffsets", "tensorrt_llm::kernels::computeSeqAndPaddingOffsets<"),
|
||||
("applyBiasRopeUpdateKVCache", "tensorrt_llm::kernels::applyBiasRopeUpdateKVCacheV2<"),
|
||||
("routingIndicesHistogramScores", "routingRenormalize::routingIndicesHistogramScoresKernel<"),
|
||||
("routingIndicesHistogram", "routingIndicesHistogramKernel<"),
|
||||
("routingIndicesOffsets", "routingIndicesOffsetsKernel<"),
|
||||
("torchReduceSum", ["at::native::reduce_kernel<", "at::native::sum_functor<"]),
|
||||
("CuteDSLMoePermute", "cute_dsl::moePermuteKernel"),
|
||||
(
|
||||
"CuteDSLGemm",
|
||||
@ -380,6 +293,19 @@ parser_keywords = [
|
||||
"CuteDSLGroupedGemmFinalize",
|
||||
["cute_dsl_kernels", "blockscaled_contiguous_grouped_gemm_finalize_fusion"],
|
||||
),
|
||||
("torchAdd", "at::native::CUDAFunctorOnSelf_add"),
|
||||
("torchAdd", "CUDAFunctor_add"),
|
||||
("torchClamp", "at::native::<unnamed>::launch_clamp_scalar("),
|
||||
("torchCompare", "at::native::<unnamed>::CompareFunctor<"),
|
||||
("torchCopy", "at::native::bfloat16_copy_kernel_cuda"),
|
||||
("torchCopy", "at::native::direct_copy_kernel_cuda("),
|
||||
("torchFill", "at::native::FillFunctor"),
|
||||
("torchIndexPut", "at::native::index_put_kernel_impl<"),
|
||||
("torchMul", "at::native::binary_internal::MulFunctor<"),
|
||||
("torchPow", "at::native::<unnamed>::pow_tensor_scalar_kernel_impl<"),
|
||||
("torchReduceSum", ["at::native::reduce_kernel<", "at::native::sum_functor<"]),
|
||||
("torchSigmoid", "at::native::sigmoid_kernel_cuda"),
|
||||
("torchWhere", "at::native::<unnamed>::where_kernel_impl("),
|
||||
]
|
||||
warned_names = set()
|
||||
|
||||
@ -395,15 +321,19 @@ def parse_kernel_name(demangledName):
|
||||
src = [src]
|
||||
if all(keyword in name for keyword in src):
|
||||
return dst
|
||||
if name not in warned_names:
|
||||
print(f"Unknown kernel name: {name}", file=sys.stderr)
|
||||
warned_names.add(name)
|
||||
if args.error_on_unknown_kernel:
|
||||
raise NotImplementedError(f"Unknown kernel name: {name}")
|
||||
if re.search(r"at::native::.*elementwise_kernel<", name):
|
||||
if name not in warned_names:
|
||||
print(f"Not parsed torch kernel name: {name}", file=sys.stderr)
|
||||
warned_names.add(name)
|
||||
assert "!unnamed!" not in name
|
||||
name = name.replace("<unnamed>", "!unnamed!")
|
||||
if "<" in name:
|
||||
name = name[: name.index("<")]
|
||||
if "(" in name:
|
||||
name = name[: name.index("(")]
|
||||
if "::" in name:
|
||||
name = name[name.rindex("::") + 2 :]
|
||||
name = name.replace("!unnamed!", "<unnamed>")
|
||||
return name
|
||||
|
||||
|
||||
@ -438,6 +368,8 @@ for runs in kernels:
|
||||
converted_seq.append((("Space",), np.mean(space_list[warmup_times:]).tolist()))
|
||||
converted_seq.append((("Total",), sum(t for _, t in converted_seq)))
|
||||
converted_seqs.append(converted_seq)
|
||||
if args.error_on_unknown_kernel and warned_names:
|
||||
raise ValueError("Unknown kernel names encountered")
|
||||
|
||||
merged_title = []
|
||||
for converted_seq in converted_seqs:
|
||||
@ -459,7 +391,7 @@ print("Problem set:")
|
||||
for problem in problem_set:
|
||||
print(
|
||||
f'- "{problem["text"]}" {len(problem["runs"])} runs'
|
||||
f" Ranges: [{', '.join(text for _, _, text in problem['ranges'])}]"
|
||||
f" Ranges: [{', '.join(text for _, end, text in problem['ranges'] if end <= problem['runs_end'][0])}]"
|
||||
)
|
||||
|
||||
stack = []
|
||||
|
||||
@ -11,12 +11,10 @@ import yaml
|
||||
|
||||
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
|
||||
from tensorrt_llm._torch.distributed import MPIDist, TorchDist
|
||||
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
|
||||
from tensorrt_llm._torch.modules.fused_moe.interface import AlltoallMethodType
|
||||
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
|
||||
from tensorrt_llm._utils import local_mpi_rank, mpi_disabled, mpi_rank, mpi_world_size
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls, mark_ranges
|
||||
from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, Runner, mark_ranges
|
||||
|
||||
|
||||
def comma_separated_ints(s):
|
||||
@ -46,9 +44,17 @@ group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument("--enable-attention-dp", action="store_true", dest="enable_attention_dp")
|
||||
group.add_argument("--no-enable-attention-dp", action="store_false", dest="enable_attention_dp")
|
||||
parser.set_defaults(enable_attention_dp=None)
|
||||
parser.add_argument("--kv-cache-dtype", type=str, choices=["fp8", "nvfp4", "auto"])
|
||||
parser.add_argument(
|
||||
"--mamba-ssm-cache-dtype", type=str, choices=["auto", "float16", "bfloat16", "float32"]
|
||||
)
|
||||
# Model init args
|
||||
parser.add_argument("--load-format", type=str, choices=["AUTO", "DUMMY"])
|
||||
parser.add_argument("--max-num-tokens", type=int)
|
||||
parser.add_argument("--moe-backend", type=str)
|
||||
parser.add_argument(
|
||||
"--moe-backend-for-prefill", type=str, choices=["CUTLASS", "DEEPGEMM", "WIDEEP"]
|
||||
)
|
||||
parser.add_argument("--moe-max-num-tokens", type=int)
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument(
|
||||
@ -110,8 +116,16 @@ if args.max_seq_len is None:
|
||||
args.max_seq_len = max(args.seq_len_q_list) + max(args.seq_len_kv_cache_list)
|
||||
if args.enable_attention_dp is None:
|
||||
args.enable_attention_dp = False
|
||||
if args.kv_cache_dtype is None:
|
||||
args.kv_cache_dtype = "auto"
|
||||
if args.mamba_ssm_cache_dtype is None:
|
||||
args.mamba_ssm_cache_dtype = "auto"
|
||||
if args.load_format is None:
|
||||
args.load_format = "DUMMY"
|
||||
if args.max_num_tokens is None:
|
||||
args.max_num_tokens = args.max_batch_size * max(args.seq_len_q_list)
|
||||
if args.moe_backend_for_prefill is None:
|
||||
args.moe_backend_for_prefill = "CUTLASS"
|
||||
if args.use_low_precision_moe_combine is None:
|
||||
args.use_low_precision_moe_combine = False
|
||||
if args.enable_autotuner is None:
|
||||
@ -128,7 +142,6 @@ torch.cuda.set_device(local_rank)
|
||||
|
||||
# Create KV cache manager
|
||||
logger.info("Layer-wise benchmarks: Create KV cache manager")
|
||||
Runner = get_runner_cls(args.model)
|
||||
mapping = Runner.create_mapping(enable_attention_dp=args.enable_attention_dp)
|
||||
kv_cache_manager = Runner.create_kv_cache_manager(
|
||||
args.model,
|
||||
@ -136,6 +149,8 @@ kv_cache_manager = Runner.create_kv_cache_manager(
|
||||
tokens_per_block=args.tokens_per_block,
|
||||
max_batch_size=args.max_batch_size,
|
||||
max_seq_len=args.max_seq_len,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype,
|
||||
layer_indices=args.layer_indices,
|
||||
)
|
||||
attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
|
||||
@ -151,12 +166,15 @@ logger.info("Layer-wise benchmarks: Create runner")
|
||||
runner = Runner(
|
||||
args.model,
|
||||
mapping,
|
||||
load_format=args.load_format,
|
||||
moe_backend=args.moe_backend,
|
||||
layer_indices=args.layer_indices,
|
||||
scaled_from=args.scaled_from,
|
||||
max_seq_len=args.max_seq_len,
|
||||
max_num_tokens=args.max_num_tokens,
|
||||
moe_max_num_tokens=args.moe_max_num_tokens,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype,
|
||||
use_low_precision_moe_combine=args.use_low_precision_moe_combine,
|
||||
use_cuda_graph=args.use_cuda_graph,
|
||||
)
|
||||
@ -190,18 +208,19 @@ if args.run_type == "GEN":
|
||||
max(1, 20480 // ctx_seq_len_q),
|
||||
)
|
||||
ctx_attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
|
||||
with mock.patch.object(
|
||||
CutlassFusedMoE, "select_alltoall_method_type", return_value=AlltoallMethodType.NotEnabled
|
||||
):
|
||||
with mock.patch.dict(os.environ, {"TRTLLM_FORCE_ALLTOALL_METHOD": "NotEnabled"}, clear=False):
|
||||
ctx_runner = Runner(
|
||||
args.model,
|
||||
mapping,
|
||||
moe_backend="CUTLASS",
|
||||
load_format=args.load_format,
|
||||
moe_backend=args.moe_backend_for_prefill,
|
||||
layer_indices=args.layer_indices,
|
||||
scaled_from=args.scaled_from,
|
||||
max_seq_len=args.max_seq_len,
|
||||
max_num_tokens=ctx_batch_size * ctx_seq_len_q,
|
||||
moe_max_num_tokens=16384,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype,
|
||||
use_low_precision_moe_combine=args.use_low_precision_moe_combine,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
@ -221,10 +240,7 @@ if args.run_type == "GEN":
|
||||
kv_cache_manager=kv_cache_manager,
|
||||
attn_workspace=ctx_attn_workspace,
|
||||
)
|
||||
with ctx_runner.replace_routing_method_ctx(
|
||||
balance_method=BalanceMethod.Balanced, balance_ratio=None
|
||||
):
|
||||
run_pack(check=True)
|
||||
run_pack(check=True)
|
||||
del ctx_runner
|
||||
del ctx_attn_workspace
|
||||
logger.info("Layer-wise benchmarks: Prefill KV cache ... Done")
|
||||
@ -293,6 +309,7 @@ for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in itertools.product(
|
||||
with runner.replace_routing_method_ctx(
|
||||
balance_method=BalanceMethod[args.balance_method], balance_ratio=balance_ratio
|
||||
):
|
||||
run_pack()
|
||||
if args.use_cuda_graph:
|
||||
with with_multi_stream(True):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
|
||||
@@ -833,7 +833,7 @@ class Qwen3NextGatedDeltaNet(nn.Module):
        return output


class Qwen3NextLinearDecoderLayer(nn.Module):
class Qwen3NextLinearDecoderLayer(DecoderLayer):

    def __init__(
        self,
@@ -1255,6 +1255,7 @@ class Qwen3NextForCausalLM(SpecDecOneEngineForCausalLM[Qwen3NextModel,
        new_weights = weight_mapper.preprocess_weights(weights)
        super().load_weights(new_weights, weight_mapper)

    def post_load_weights(self):
        for idx, layer in enumerate(
                self.model.layers[:self.config.num_hidden_layers]):
            if idx == self.config.num_hidden_layers - 1:

@@ -1,5 +1,4 @@
from .mark_utils import mark_ranges
from .runner_factory import get_runner_cls
from .runner_interface import BalanceMethod
from .runner import BalanceMethod, Runner

__all__ = ["BalanceMethod", "get_runner_cls", "mark_ranges"]
__all__ = ["BalanceMethod", "Runner", "mark_ranges"]

@ -1,101 +0,0 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3DecoderLayer
|
||||
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
|
||||
from tensorrt_llm._torch.pyexecutor.model_loader import initialize_dummy_weights
|
||||
from tensorrt_llm._torch.utils import AuxStreamType
|
||||
from tensorrt_llm.functional import AllReduceStrategy
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
from .runner_interface import RunnerBase
|
||||
from .runner_utils import RunnerMixin
|
||||
|
||||
|
||||
class DeepSeekV3Runner(RunnerMixin, RunnerBase):
|
||||
@staticmethod
|
||||
def has_mamba_metadata() -> bool:
|
||||
return False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pretrained_model_name_or_path: str,
|
||||
mapping: Mapping,
|
||||
*,
|
||||
moe_backend: str,
|
||||
layer_indices: List[int],
|
||||
scaled_from: Optional[int],
|
||||
max_seq_len: int,
|
||||
max_num_tokens: int,
|
||||
moe_max_num_tokens: int,
|
||||
use_low_precision_moe_combine: bool,
|
||||
use_cuda_graph: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_config = ModelConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
mapping=mapping,
|
||||
enable_min_latency=False,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
force_dynamic_quantization=False,
|
||||
spec_config=None,
|
||||
sparse_attention_config=None, # To be loaded from config
|
||||
max_num_tokens=max_num_tokens,
|
||||
max_seq_len=max_seq_len,
|
||||
moe_max_num_tokens=moe_max_num_tokens,
|
||||
moe_load_balancer=None,
|
||||
lora_config=None,
|
||||
allreduce_strategy=AllReduceStrategy.AUTO,
|
||||
mm_encoder_only=False,
|
||||
attn_backend="TRTLLM",
|
||||
moe_backend=moe_backend,
|
||||
moe_disable_finalize_fusion=False,
|
||||
use_low_precision_moe_combine=use_low_precision_moe_combine,
|
||||
skip_create_weights_in_init=True,
|
||||
)
|
||||
pretrained_config = self.model_config.pretrained_config
|
||||
|
||||
with self.scaled_from_ctx(scaled_from, mapping, pretrained_config):
|
||||
aux_stream_list = [torch.cuda.Stream() for _ in range(2)]
|
||||
aux_stream_dict = {
|
||||
AuxStreamType.Attention: aux_stream_list[0],
|
||||
AuxStreamType.MoeShared: aux_stream_list[0],
|
||||
AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
|
||||
}
|
||||
|
||||
layers = [
|
||||
DeepseekV3DecoderLayer(
|
||||
model_config=self.model_config,
|
||||
layer_idx=layer_idx,
|
||||
aux_stream_dict=aux_stream_dict,
|
||||
)
|
||||
for layer_idx in layer_indices
|
||||
]
|
||||
next_layer_layernorm = RMSNorm(
|
||||
hidden_size=pretrained_config.hidden_size,
|
||||
eps=pretrained_config.rms_norm_eps,
|
||||
dtype=pretrained_config.torch_dtype,
|
||||
)
|
||||
|
||||
# TODO: apply_layerwise_quant_config
|
||||
self.apply_quant_config_exclude_modules(layers, self.model_config.quant_config)
|
||||
for layer in layers:
|
||||
for module in layer.modules():
|
||||
if callable(getattr(module, "create_weights", None)):
|
||||
module.create_weights()
|
||||
layer.cuda()
|
||||
initialize_dummy_weights(layer)
|
||||
for module in layer.modules():
|
||||
if hasattr(module, "post_load_weights") and not getattr(
|
||||
module, "_weights_removed", False
|
||||
):
|
||||
module.post_load_weights()
|
||||
next_layer_layernorm.cuda()
|
||||
initialize_dummy_weights(next_layer_layernorm)
|
||||
for layer, next_layer in zip(layers[:-1], layers[1:]):
|
||||
layer.next_layer_layernorm = next_layer.input_layernorm
|
||||
layers[-1].next_layer_layernorm = next_layer_layernorm
|
||||
|
||||
self.layers = layers
|
||||
@@ -1,6 +1,7 @@
import nvtx

from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate, Deepseekv3MoE
from tensorrt_llm._torch.models.modeling_nemotron_h import MLPLayer, NemotronHMOE
from tensorrt_llm._torch.models.modeling_qwen3_next import (
    Qwen3NextGatedDeltaNet,
    Qwen3NextSparseMoeBlock,
@@ -8,11 +9,14 @@ from tensorrt_llm._torch.models.modeling_qwen3_next import (
from tensorrt_llm._torch.modules.attention import MLA, Attention
from tensorrt_llm._torch.modules.fused_moe.interface import MoE
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
from tensorrt_llm._torch.modules.mamba.mamba2_mixer import Mamba2Mixer


def mark_ranges():
    DeepseekV3Gate.forward = nvtx.annotate("DeepseekV3Gate")(DeepseekV3Gate.forward)
    Deepseekv3MoE.forward = nvtx.annotate("Deepseekv3MoE")(Deepseekv3MoE.forward)
    MLPLayer.forward = nvtx.annotate("MLPLayer")(MLPLayer.forward)
    NemotronHMOE.forward = nvtx.annotate("NemotronHMOE")(NemotronHMOE.forward)
    Qwen3NextGatedDeltaNet.forward = nvtx.annotate("Qwen3NextGatedDeltaNet")(
        Qwen3NextGatedDeltaNet.forward
    )
@@ -23,3 +27,4 @@ def mark_ranges():
    Attention.forward = nvtx.annotate("Attention")(Attention.forward)
    MoE.forward = nvtx.annotate("MoE")(MoE.forward)
    GatedMLP.forward = nvtx.annotate("GatedMLP")(GatedMLP.forward)
    Mamba2Mixer.forward = nvtx.annotate("Mamba2Mixer")(Mamba2Mixer.forward)

@ -1,95 +0,0 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.modeling_qwen3_next import ALL_DECODER_LAYER_TYPES
|
||||
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
|
||||
from tensorrt_llm._torch.pyexecutor.model_loader import initialize_dummy_weights
|
||||
from tensorrt_llm.functional import AllReduceStrategy
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
from .runner_interface import RunnerBase
|
||||
from .runner_utils import RunnerMixin
|
||||
|
||||
|
||||
class Qwen3NextRunner(RunnerMixin, RunnerBase):
|
||||
@staticmethod
|
||||
def has_mamba_metadata() -> bool:
|
||||
return True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pretrained_model_name_or_path: str,
|
||||
mapping: Mapping,
|
||||
*,
|
||||
moe_backend: str,
|
||||
layer_indices: List[int],
|
||||
scaled_from: Optional[int],
|
||||
max_seq_len: int,
|
||||
max_num_tokens: int,
|
||||
moe_max_num_tokens: int,
|
||||
use_low_precision_moe_combine: bool,
|
||||
use_cuda_graph: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_config = ModelConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
mapping=mapping,
|
||||
enable_min_latency=False,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
force_dynamic_quantization=False,
|
||||
spec_config=None,
|
||||
sparse_attention_config=None, # To be loaded from config
|
||||
max_num_tokens=max_num_tokens,
|
||||
max_seq_len=max_seq_len,
|
||||
moe_max_num_tokens=moe_max_num_tokens,
|
||||
moe_load_balancer=None,
|
||||
lora_config=None,
|
||||
allreduce_strategy=AllReduceStrategy.AUTO,
|
||||
mm_encoder_only=False,
|
||||
attn_backend="TRTLLM",
|
||||
moe_backend=moe_backend,
|
||||
moe_disable_finalize_fusion=False,
|
||||
use_low_precision_moe_combine=use_low_precision_moe_combine,
|
||||
skip_create_weights_in_init=True,
|
||||
)
|
||||
pretrained_config = self.model_config.pretrained_config
|
||||
|
||||
with self.scaled_from_ctx(scaled_from, mapping, pretrained_config):
|
||||
aux_stream = torch.cuda.Stream()
|
||||
layers = [
|
||||
ALL_DECODER_LAYER_TYPES[pretrained_config.layer_types[layer_idx]](
|
||||
self.model_config,
|
||||
layer_idx,
|
||||
aux_stream,
|
||||
)
|
||||
for layer_idx in layer_indices
|
||||
]
|
||||
next_layer_layernorm = RMSNorm(
|
||||
hidden_size=pretrained_config.hidden_size,
|
||||
eps=pretrained_config.rms_norm_eps,
|
||||
dtype=pretrained_config.torch_dtype,
|
||||
use_gemma=True,
|
||||
)
|
||||
|
||||
# TODO: apply_layerwise_quant_config
|
||||
self.apply_quant_config_exclude_modules(layers, self.model_config.quant_config)
|
||||
for layer in layers:
|
||||
for module in layer.modules():
|
||||
if callable(getattr(module, "create_weights", None)):
|
||||
module.create_weights()
|
||||
layer.cuda()
|
||||
initialize_dummy_weights(layer)
|
||||
for module in layer.modules():
|
||||
if hasattr(module, "post_load_weights") and not getattr(
|
||||
module, "_weights_removed", False
|
||||
):
|
||||
module.post_load_weights()
|
||||
next_layer_layernorm.cuda()
|
||||
initialize_dummy_weights(next_layer_layernorm)
|
||||
for layer, next_layer in zip(layers[:-1], layers[1:]):
|
||||
layer.next_layer_layernorm = next_layer.input_layernorm
|
||||
layers[-1].next_layer_layernorm = next_layer_layernorm
|
||||
|
||||
self.layers = layers
|
||||
@ -1,34 +1,49 @@
|
||||
import contextlib
|
||||
import functools
|
||||
import itertools
|
||||
import os
|
||||
import unittest.mock
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import IntEnum
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import tensorrt_llm._torch.model_config
|
||||
import tensorrt_llm.bindings
|
||||
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
|
||||
from tensorrt_llm._torch.metadata import KVCacheParams
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.modeling_utils import PostInitCaller, skip_forward
|
||||
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
|
||||
from tensorrt_llm._torch.modules.fused_moe.fused_moe_trtllm_gen import TRTLLMGenFusedMoE
|
||||
from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import WideEPMoE
|
||||
from tensorrt_llm._torch.modules.linear import Linear, WeightMode
|
||||
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
|
||||
from tensorrt_llm._torch.pyexecutor._util import get_kv_cache_manager_cls
|
||||
from tensorrt_llm._torch.pyexecutor.config_utils import is_mla, is_qwen3_next
|
||||
from tensorrt_llm._torch.pyexecutor.config_utils import (
|
||||
is_mla,
|
||||
is_nemotron_hybrid,
|
||||
is_qwen3_next,
|
||||
load_pretrained_config,
|
||||
)
|
||||
from tensorrt_llm._torch.pyexecutor.model_loader import (
|
||||
ModelLoader,
|
||||
_construct_checkpoint_loader,
|
||||
validate_and_set_kv_cache_quant,
|
||||
validate_and_set_mamba_ssm_cache_dtype,
|
||||
)
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
|
||||
from tensorrt_llm._torch.utils import get_model_extra_attrs, model_extra_attrs
|
||||
from tensorrt_llm._utils import local_mpi_size, mpi_rank, mpi_world_size, torch_dtype_to_binding
|
||||
from tensorrt_llm.bindings.executor import KvCacheConfig
|
||||
from tensorrt_llm.bindings.internal.batch_manager import CacheType
|
||||
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, MoeConfig, TorchLlmArgs
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
|
||||
from .runner_interface import BalanceMethod
|
||||
|
||||
class BalanceMethod(IntEnum):
|
||||
NotModified = 1
|
||||
Balanced = 2
|
||||
ImbalancedRanks = 3
|
||||
ImbalancedExperts = 4
|
||||
|
||||
|
||||
def ceil_div(a, b):
|
||||
@ -102,31 +117,28 @@ def test_get_balanced_selection():
|
||||
raise ValueError("tokens per expert is not balanced")
|
||||
|
||||
|
||||
def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size):
|
||||
num_tokens, top_k = imbalanced_experts.shape
|
||||
dtype = imbalanced_experts.dtype
|
||||
device = imbalanced_experts.device
|
||||
balanced_experts = get_balanced_selection_no_cache(
|
||||
num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
|
||||
)
|
||||
def get_num_balanced_tokens(num_tokens, top_k, num_experts, dp_size, balance_ratio):
|
||||
if balance_ratio == 0.0:
|
||||
num_balanced_tokens = 0
|
||||
return 0
|
||||
else:
|
||||
# Activate all experts
|
||||
min_num_balanced_tokens = min(num_tokens, ceil_div(num_experts, dp_size * top_k))
|
||||
num_balanced_tokens = min_num_balanced_tokens + round(
|
||||
return min_num_balanced_tokens + round(
|
||||
(num_tokens - min_num_balanced_tokens) * balance_ratio
|
||||
)
|
||||
mixed_experts = torch.cat(
|
||||
[balanced_experts[:num_balanced_tokens], imbalanced_experts[num_balanced_tokens:]]
|
||||
)
|
||||
return mixed_experts
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_all_to_one_selection(
|
||||
num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size
|
||||
):
|
||||
num_balanced_tokens = get_num_balanced_tokens(
|
||||
num_tokens, top_k, num_experts, dp_size, balance_ratio
|
||||
)
|
||||
balanced_experts = get_balanced_selection_no_cache(
|
||||
num_balanced_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
|
||||
)
|
||||
num_imbalanced_tokens = num_tokens - num_balanced_tokens
|
||||
experts_per_rank = num_experts // ep_size
|
||||
if top_k > experts_per_rank:
|
||||
raise ValueError(
|
||||
@ -134,29 +146,34 @@ def get_all_to_one_selection(
|
||||
)
|
||||
imbalanced_experts = (
|
||||
torch.arange(
|
||||
dp_rank * num_tokens * top_k,
|
||||
(dp_rank + 1) * num_tokens * top_k,
|
||||
dp_rank * num_imbalanced_tokens * top_k,
|
||||
(dp_rank + 1) * num_imbalanced_tokens * top_k,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
).view(num_tokens, top_k)
|
||||
).view(num_imbalanced_tokens, top_k)
|
||||
% experts_per_rank
|
||||
)
|
||||
imbalanced_experts = imbalanced_experts.sort(dim=-1).values
|
||||
return apply_balance_ratio(
|
||||
imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size
|
||||
)
|
||||
mixed_experts = torch.cat([balanced_experts, imbalanced_experts])
|
||||
return mixed_experts.sort(dim=-1).values
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_balanced_rank_imbalanced_expert_selection(
|
||||
num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size
|
||||
):
|
||||
num_balanced_tokens = get_num_balanced_tokens(
|
||||
num_tokens, top_k, num_experts, dp_size, balance_ratio
|
||||
)
|
||||
balanced_experts = get_balanced_selection_no_cache(
|
||||
num_balanced_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
|
||||
)
|
||||
num_imbalanced_tokens = num_tokens - num_balanced_tokens
|
||||
experts_per_rank = num_experts // ep_size
|
||||
active_experts_per_rank = ceil_div(top_k, ep_size)
|
||||
# Select expert from [0, active_experts_per_rank * ep_size),
|
||||
# then scale to [0, experts_per_rank * ep_size)
|
||||
narrow_experts = get_balanced_selection_no_cache(
|
||||
num_tokens,
|
||||
num_imbalanced_tokens,
|
||||
top_k,
|
||||
active_experts_per_rank * ep_size,
|
||||
dtype,
|
||||
@ -169,9 +186,8 @@ def get_balanced_rank_imbalanced_expert_selection(
|
||||
narrow_experts // active_experts_per_rank * experts_per_rank
|
||||
+ narrow_experts % active_experts_per_rank
|
||||
)
|
||||
return apply_balance_ratio(
|
||||
imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size
|
||||
)
|
||||
mixed_experts = torch.cat([balanced_experts, imbalanced_experts])
|
||||
return mixed_experts.sort(dim=-1).values
|
||||
|
||||
|
||||
def make_balanced_routing_method(
|
||||
@ -339,36 +355,91 @@ def make_forward_impl_check(moe_module, forward_impl_orig):
|
||||
return forward_impl
|
||||
|
||||
|
||||
class RunnerMixin(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def has_mamba_metadata() -> bool:
|
||||
pass
|
||||
class Runner:
|
||||
def __init__(
|
||||
self,
|
||||
pretrained_model_name_or_path: str,
|
||||
mapping: Mapping,
|
||||
*,
|
||||
load_format: str,
|
||||
moe_backend: str,
|
||||
layer_indices: list[int],
|
||||
scaled_from: Optional[int],
|
||||
max_seq_len: int,
|
||||
max_num_tokens: int,
|
||||
moe_max_num_tokens: int,
|
||||
kv_cache_dtype,
|
||||
mamba_ssm_cache_dtype: str,
|
||||
use_low_precision_moe_combine: bool,
|
||||
use_cuda_graph: bool,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
checkpoint_loader = _construct_checkpoint_loader("pytorch", None, "HF")
|
||||
# Please refer to `tensorrt_llm/_torch/pyexecutor/model_loader.py` for effective args
|
||||
llm_args = TorchLlmArgs(
|
||||
model=pretrained_model_name_or_path,
|
||||
load_format=load_format,
|
||||
**{} if use_cuda_graph else {"cuda_graph_config": None},
|
||||
moe_config=MoeConfig(
|
||||
backend=moe_backend,
|
||||
max_num_tokens=moe_max_num_tokens,
|
||||
disable_finalize_fusion=False,
|
||||
use_low_precision_moe_combine=use_low_precision_moe_combine,
|
||||
),
|
||||
attn_backend="TRTLLM",
|
||||
kv_cache_config=KvCacheConfig(
|
||||
dtype=kv_cache_dtype, mamba_ssm_cache_dtype=mamba_ssm_cache_dtype
|
||||
),
|
||||
)
|
||||
model_loader = ModelLoader(
|
||||
llm_args=llm_args,
|
||||
mapping=mapping,
|
||||
spec_config=None,
|
||||
sparse_attention_config=None,
|
||||
max_num_tokens=max_num_tokens,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
|
||||
with self.scaled_from_ctx(scaled_from, mapping), self.skip_unused_layers_ctx(layer_indices):
|
||||
model, _ = model_loader.load(
|
||||
checkpoint_dir=pretrained_model_name_or_path, checkpoint_loader=checkpoint_loader
|
||||
)
|
||||
|
||||
self.layers = [model.model.layers[i] for i in layer_indices]
|
||||
self.model_config = model.model_config
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def scaled_from_ctx(scaled_from, mapping, pretrained_config):
|
||||
def scaled_from_ctx(scaled_from, mapping):
|
||||
if scaled_from is None:
|
||||
yield
|
||||
return
|
||||
# To run the problem size of $B$ GPUs on $A$ GPUs, we need:
|
||||
# (1) Attention: If TP, reduce the number of attention heads; If DP, nothing to change.
|
||||
# (2) MoE: If EP, reduce the number of experts; If TP, reduce head size.
|
||||
# Maintain the result of AllToAll method selection because it is affected by EP size.
|
||||
if not mapping.enable_attention_dp:
|
||||
if hasattr(pretrained_config, "index_n_heads"):
|
||||
raise NotImplementedError("Not support Indexer TP for weak scaling")
|
||||
pretrained_config.num_attention_heads = (
|
||||
pretrained_config.num_attention_heads // scaled_from * mapping.tp_size
|
||||
)
|
||||
pretrained_config.num_key_value_heads = (
|
||||
pretrained_config.num_key_value_heads // scaled_from * mapping.tp_size
|
||||
)
|
||||
if mapping.moe_ep_size != mapping.world_size:
|
||||
raise NotImplementedError("Not support MoE TP for weak scaling")
|
||||
pretrained_config.n_routed_experts = (
|
||||
pretrained_config.n_routed_experts // scaled_from * mapping.moe_ep_size
|
||||
)
|
||||
|
||||
def make_load_pretrained_config(mapping, load_pretrained_config_orig):
|
||||
# To run the problem size of $B$ GPUs on $A$ GPUs, we need:
|
||||
# (1) Attention: If TP, reduce the number of attention heads; If DP, nothing to change.
|
||||
# (2) MoE: If EP, reduce the number of experts; If TP, reduce head size.
|
||||
# Maintain the result of AllToAll method selection because it is affected by EP size.
|
||||
def load_pretrained_config(*args, **kwargs):
|
||||
pretrained_config = load_pretrained_config_orig(*args, **kwargs)
|
||||
if not mapping.enable_attention_dp:
|
||||
if hasattr(pretrained_config, "index_n_heads"):
|
||||
raise NotImplementedError("Not support Indexer TP for weak scaling")
|
||||
pretrained_config.num_attention_heads = (
|
||||
pretrained_config.num_attention_heads // scaled_from * mapping.tp_size
|
||||
)
|
||||
pretrained_config.num_key_value_heads = (
|
||||
pretrained_config.num_key_value_heads // scaled_from * mapping.tp_size
|
||||
)
|
||||
if mapping.moe_ep_size != mapping.tp_size:
|
||||
raise NotImplementedError("Not support MoE TP for weak scaling")
|
||||
pretrained_config.n_routed_experts = (
|
||||
pretrained_config.n_routed_experts // scaled_from * mapping.moe_ep_size
|
||||
)
|
||||
return pretrained_config
|
||||
|
||||
return load_pretrained_config
|
||||
|
||||
def make_select_alltoall_method_type(select_alltoall_method_type_orig):
|
||||
def select_alltoall_method_type(
|
||||
@ -408,6 +479,9 @@ class RunnerMixin(ABC):
|
||||
select_alltoall_method_type_cutlass = CutlassFusedMoE.select_alltoall_method_type
|
||||
select_alltoall_method_type_trtllm_gen = TRTLLMGenFusedMoE.select_alltoall_method_type
|
||||
select_alltoall_method_type_wide_ep = WideEPMoE.select_alltoall_method_type
|
||||
tensorrt_llm._torch.model_config.load_pretrained_config = make_load_pretrained_config(
|
||||
mapping, load_pretrained_config
|
||||
)
|
||||
CutlassFusedMoE.select_alltoall_method_type = make_select_alltoall_method_type_2(
|
||||
select_alltoall_method_type_cutlass
|
||||
)
|
||||
@ -420,40 +494,50 @@ class RunnerMixin(ABC):
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
tensorrt_llm._torch.model_config.load_pretrained_config = load_pretrained_config
|
||||
CutlassFusedMoE.select_alltoall_method_type = select_alltoall_method_type_cutlass
|
||||
TRTLLMGenFusedMoE.select_alltoall_method_type = select_alltoall_method_type_trtllm_gen
|
||||
WideEPMoE.select_alltoall_method_type = select_alltoall_method_type_wide_ep
|
||||
|
||||
@staticmethod
|
||||
def apply_quant_config_exclude_modules(layers, quant_config):
|
||||
# Please refer to tensorrt_llm/_torch/models/modeling_utils.py
|
||||
new_quant_config = QuantConfig(kv_cache_quant_algo=quant_config.kv_cache_quant_algo)
|
||||
for layer in layers:
|
||||
for name, module in layer.named_modules():
|
||||
name = f"model.layers.{layer.layer_idx}.{name}"
|
||||
candidates = [name]
|
||||
if isinstance(module, Linear):
|
||||
weight_mode = module.weights_loading_config.weight_mode
|
||||
if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
|
||||
# sometimes gate and up proj are not packed in the checkpoint,
|
||||
# but they still share the same exclusion rule
|
||||
candidates += [
|
||||
name.replace("gate_up_proj", "gate_proj"),
|
||||
name.replace("gate_up_proj", "up_proj"),
|
||||
]
|
||||
elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
|
||||
# sometimes q_proj, k_proj and v_proj are not packed in the checkpoint,
|
||||
# but they still share the same exclusion rule
|
||||
candidates += [
|
||||
name.replace("qkv_proj", "q_proj"),
|
||||
name.replace("qkv_proj", "k_proj"),
|
||||
name.replace("qkv_proj", "v_proj"),
|
||||
]
|
||||
is_excluded = any(
|
||||
quant_config.is_module_excluded_from_quantization(n) for n in candidates
|
||||
@contextlib.contextmanager
|
||||
def skip_unused_layers_ctx(layer_indices):
|
||||
call_orig = PostInitCaller.__call__
|
||||
|
||||
def call_new(cls, *args, **kwargs):
|
||||
model = call_orig(cls, *args, **kwargs)
|
||||
for module in (
|
||||
model.prologue + model.model.prologue + model.model.epilogue + model.epilogue
|
||||
):
|
||||
skip_forward(module)
|
||||
num_hidden_layers = model.model_config.pretrained_config.num_hidden_layers
|
||||
if hasattr(model.model, "embed_tokens"):
|
||||
skip_forward(model.model.embed_tokens)
|
||||
for layer_idx in range(num_hidden_layers):
|
||||
layer = model.model.layers[layer_idx]
|
||||
if layer_idx not in layer_indices:
|
||||
# keep next layer's input_layernorm's weights for fusion
|
||||
skip_forward(
|
||||
layer,
|
||||
ignore_modules=[layer.input_layernorm]
|
||||
if layer_idx - 1 in layer_indices
|
||||
and hasattr(model.model.layers[layer_idx - 1], "next_layer_layernorm")
|
||||
else None,
|
||||
)
|
||||
if hasattr(model.model, "norm"):
|
||||
skip_forward(
|
||||
model.model.norm,
|
||||
ignore_modules=[model.model.norm]
|
||||
if num_hidden_layers - 1 in layer_indices
|
||||
else None,
|
||||
)
|
||||
if is_excluded and getattr(module, "quant_config", None) is not None:
|
||||
module.quant_config = new_quant_config
|
||||
return model
|
||||
|
||||
PostInitCaller.__call__ = call_new
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
PostInitCaller.__call__ = call_orig
|
||||
|
||||
def create_run_pack(
|
||||
self,
|
||||
@ -466,9 +550,8 @@ class RunnerMixin(ABC):
|
||||
kv_cache_manager: KVCacheManager,
|
||||
attn_workspace: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.model_config.moe_backend == "TRTLLM" and os.getenv("TRTLLM_ENABLE_PDL") != "1":
|
||||
raise ValueError("Suggest to set TRTLLM_ENABLE_PDL=1 when moe_backend is TRTLLM")
|
||||
world_size = mpi_world_size()
|
||||
pretrained_config = self.model_config.pretrained_config
|
||||
AttentionCls = get_attention_backend(
|
||||
self.model_config.attn_backend, self.model_config.sparse_attention_config
|
||||
)
|
||||
@ -499,7 +582,7 @@ class RunnerMixin(ABC):
|
||||
)
|
||||
attn_metadata.all_rank_num_tokens = [batch_size * seq_len_q] * world_size
|
||||
attn_metadata.prepare()
|
||||
hidden_size = self.model_config.pretrained_config.hidden_size
|
||||
hidden_size = pretrained_config.hidden_size
|
||||
position_ids = torch.tensor(
|
||||
[list(range(seq_len_kv_cache, seq_len_kv_cache + seq_len_q)) * batch_size],
|
||||
dtype=torch.int32,
|
||||
@ -513,9 +596,14 @@ class RunnerMixin(ABC):
|
||||
)
|
||||
kwargs = {}
|
||||
|
||||
if self.has_mamba_metadata():
|
||||
# Please refer to `tensorrt_llm/_torch/models/modeling_qwen3_next.py` for `mamba_metadata`
|
||||
mamba_metadata = Mamba2Metadata(attn_metadata.max_num_requests, chunk_size=128)
|
||||
if is_nemotron_hybrid(pretrained_config) or is_qwen3_next(pretrained_config):
|
||||
# Please refer to `tensorrt_llm/_torch/models/modeling_qwen3_next.py` for the magic number chunk_size=128
|
||||
mamba_metadata = Mamba2Metadata(
|
||||
attn_metadata.max_num_requests,
|
||||
chunk_size=128
|
||||
if is_qwen3_next(pretrained_config)
|
||||
else pretrained_config.chunk_size,
|
||||
)
|
||||
mamba_metadata.prepare(attn_metadata)
|
||||
kwargs["mamba_metadata"] = mamba_metadata
|
||||
|
||||
@ -524,8 +612,15 @@ class RunnerMixin(ABC):
|
||||
with model_extra_attrs(self.model_config.extra_attrs):
|
||||
get_model_extra_attrs()["attention_metadata"] = weakref.ref(attn_metadata)
|
||||
with torch.inference_mode():
|
||||
# TODO: to be more general, we should call DecoderModel.forward
|
||||
for layer in self.layers:
|
||||
output = layer(position_ids, output[0], attn_metadata, output[1], **kwargs)
|
||||
residual_fusion = hasattr(layer, "next_layer_layernorm")
|
||||
if residual_fusion:
|
||||
output = layer(
|
||||
position_ids, output[0], attn_metadata, output[1], **kwargs
|
||||
)
|
||||
else:
|
||||
output = layer(position_ids, output[0], attn_metadata, **kwargs), None
|
||||
if check:
|
||||
if output[0].isnan().any():
|
||||
raise ValueError("Has nan, please fix weights initialization")
|
||||
@ -554,12 +649,20 @@ class RunnerMixin(ABC):
|
||||
f' please set balance_method to "NotModified"'
|
||||
)
|
||||
original_methods = []
|
||||
dp_rank = self.model_config.mapping.rank // (
|
||||
self.model_config.mapping.world_size // self.model_config.mapping.dp_size
|
||||
dp_rank = (
|
||||
self.model_config.mapping.tp_rank
|
||||
if self.model_config.mapping.enable_attention_dp
|
||||
else 0
|
||||
)
|
||||
moe_modules = []
|
||||
for layer in self.layers:
|
||||
moe_module = layer.mlp.experts
|
||||
if layer.__class__.__name__ == "NemotronHLayer":
|
||||
if layer.layer_type == "E":
|
||||
moe_modules.append(layer.mixer.experts)
|
||||
else:
|
||||
moe_modules.append(layer.mlp.experts)
|
||||
|
||||
for moe_module in moe_modules:
|
||||
# Replace `routing_method.apply` for normal cases
|
||||
apply_method_orig = moe_module.routing_method.apply
|
||||
moe_module.routing_method.apply = make_balanced_routing_method(
|
||||
@ -579,8 +682,8 @@ class RunnerMixin(ABC):
|
||||
moe_module.run_moe = make_balanced_run_moe(
|
||||
moe_module,
|
||||
run_moe_orig,
|
||||
layer.mlp.experts.routing_method.top_k,
|
||||
layer.mlp.experts.num_experts,
|
||||
moe_module.routing_method.top_k,
|
||||
moe_module.num_experts,
|
||||
balance_method,
|
||||
balance_ratio,
|
||||
self.model_config.mapping.dp_size,
|
||||
@ -598,10 +701,9 @@ class RunnerMixin(ABC):
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for layer, (apply_method_orig, run_moe_orig, forward_impl_orig) in zip(
|
||||
self.layers, original_methods
|
||||
for moe_module, (apply_method_orig, run_moe_orig, forward_impl_orig) in zip(
|
||||
moe_modules, original_methods
|
||||
):
|
||||
moe_module = layer.mlp.experts
|
||||
moe_module.routing_method.apply = apply_method_orig
|
||||
if isinstance(moe_module, TRTLLMGenFusedMoE):
|
||||
moe_module.run_moe = run_moe_orig
|
||||
@ -614,10 +716,14 @@ class RunnerMixin(ABC):
|
||||
tokens_per_block,
|
||||
max_batch_size,
|
||||
max_seq_len,
|
||||
kv_cache_dtype,
|
||||
mamba_ssm_cache_dtype,
|
||||
layer_indices,
|
||||
):
|
||||
# Please refer to `tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` for `tokens_per_block`
|
||||
model_config = ModelConfig.from_pretrained(pretrained_model_name_or_path)
|
||||
validate_and_set_kv_cache_quant(model_config, kv_cache_dtype)
|
||||
validate_and_set_mamba_ssm_cache_dtype(model_config, mamba_ssm_cache_dtype)
|
||||
if model_config.enable_flash_mla:
|
||||
assert tokens_per_block == 64
|
||||
|
||||
@ -628,18 +734,17 @@ class RunnerMixin(ABC):
|
||||
max_tokens=max_batch_size * round_up(max_seq_len, tokens_per_block),
|
||||
enable_block_reuse=False,
|
||||
)
|
||||
kv_cache_dtype = torch_dtype_to_binding(
|
||||
{
|
||||
None: torch.bfloat16,
|
||||
"FP8": torch.float8_e4m3fn,
|
||||
}[model_config.quant_config.kv_cache_quant_algo]
|
||||
)
|
||||
kv_cache_dtype = {
|
||||
"FP8": tensorrt_llm.bindings.DataType.FP8,
|
||||
"NVFP4": tensorrt_llm.bindings.DataType.NVFP4,
|
||||
None: torch_dtype_to_binding(config.torch_dtype),
|
||||
}[model_config.quant_config.kv_cache_quant_algo]
|
||||
if is_mla(config):
|
||||
layer_mask = [i in layer_indices for i in range(config.num_hidden_layers)]
|
||||
num_layers = sum(layer_mask)
|
||||
kv_cache_manager = kv_cache_manager_cls(
|
||||
kv_cache_config,
|
||||
CacheType.SELFKONLY,
|
||||
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELFKONLY,
|
||||
num_layers=num_layers,
|
||||
num_kv_heads=1,
|
||||
head_dim=model_config.pretrained_config.kv_lora_rank
|
||||
@ -649,9 +754,46 @@ class RunnerMixin(ABC):
|
||||
max_batch_size=max_batch_size,
|
||||
mapping=mapping,
|
||||
dtype=kv_cache_dtype,
|
||||
spec_config=None,
|
||||
layer_mask=layer_mask,
|
||||
sparse_attn_config=model_config.sparse_attention_config,
|
||||
)
|
||||
elif is_nemotron_hybrid(config):
|
||||
mamba_layer_mask = [
|
||||
i in layer_indices and char == "M"
|
||||
for i, char in enumerate(config.hybrid_override_pattern)
|
||||
]
|
||||
layer_mask = [
|
||||
i in layer_indices and char == "*"
|
||||
for i, char in enumerate(config.hybrid_override_pattern)
|
||||
]
|
||||
num_mamba_layers = sum(mamba_layer_mask)
|
||||
num_layers = sum(layer_mask)
|
||||
kv_cache_manager = kv_cache_manager_cls(
|
||||
# mamba cache parameters
|
||||
config.ssm_state_size,
|
||||
config.conv_kernel,
|
||||
config.mamba_num_heads,
|
||||
config.n_groups,
|
||||
config.mamba_head_dim,
|
||||
num_mamba_layers,
|
||||
mamba_layer_mask,
|
||||
config.torch_dtype,
|
||||
model_config.quant_config.mamba_ssm_cache_dtype,
|
||||
# kv cache parameters
|
||||
kv_cache_config,
|
||||
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
|
||||
num_layers=num_layers,
|
||||
layer_mask=layer_mask,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
head_dim=config.head_dim,
|
||||
tokens_per_block=tokens_per_block,
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
mapping=mapping,
|
||||
dtype=kv_cache_dtype,
|
||||
spec_config=None,
|
||||
)
|
||||
elif is_qwen3_next(config):
|
||||
mamba_layer_mask = [
|
||||
i in layer_indices
|
||||
@ -680,7 +822,7 @@ class RunnerMixin(ABC):
|
||||
model_config.quant_config.mamba_ssm_cache_dtype,
|
||||
# kv cache parameters
|
||||
kv_cache_config,
|
||||
CacheType.SELF,
|
||||
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
|
||||
num_layers=num_layers,
|
||||
layer_mask=layer_mask,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
@@ -1,13 +0,0 @@
from tensorrt_llm._torch.pyexecutor.config_utils import load_pretrained_config

from .deepseekv3_runner import DeepSeekV3Runner
from .qwen3_next_runner import Qwen3NextRunner


def get_runner_cls(pretrained_model_name_or_path: str) -> type:
    pretrained_config = load_pretrained_config(pretrained_model_name_or_path)
    return {
        "deepseek_v3": DeepSeekV3Runner,
        "deepseek_v32": DeepSeekV3Runner,
        "qwen3_next": Qwen3NextRunner,
    }[pretrained_config.model_type]
@@ -1,49 +0,0 @@
from abc import ABC, abstractmethod
from enum import IntEnum
from typing import Optional

import torch

from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager


class BalanceMethod(IntEnum):
    NotModified = 1
    Balanced = 2
    ImbalancedRanks = 3
    ImbalancedExperts = 4


class RunnerBase(ABC):
    @abstractmethod
    def create_run_pack(
        self,
        run_type: str,
        batch_size: int,
        seq_len_q: int,
        seq_len_kv_cache: int,
        kv_cache_manager: KVCacheManager,
        attn_workspace: Optional[torch.Tensor] = None,
    ):
        pass

    @abstractmethod
    def replace_routing_method_ctx(self, balance_method: BalanceMethod, balance_ratio: float):
        pass

    @staticmethod
    @abstractmethod
    def create_kv_cache_manager(
        pretrained_model_name_or_path,
        mapping,
        tokens_per_block,
        max_batch_size,
        max_seq_len,
        layer_indices,
    ):
        pass

    @staticmethod
    @abstractmethod
    def create_mapping(enable_attention_dp: bool):
        pass
@@ -84,6 +84,7 @@ l0_b200:
- unittest/_torch/modeling -k "modeling_gpt_oss"
- unittest/_torch/modeling/test_modeling_exaone_moe.py
- unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_ctx_dep[1]
- unittest/tools/test_layer_wise_benchmarks.py::test_nemotron_gen_dep[1]
- unittest/tools/test_layer_wise_benchmarks.py::test_qwen3_next_gen_tep[1]
- unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_nvfp4[enable_configurable_moe-disable_finalize_fusion-TRTLLM-dtype1]
@@ -54,7 +54,6 @@ def test_deepseek_r1_ctx_tep(llm_root, world_size):
            **os.environ,
            "NP": f"{world_size:d}",
            "PROFILE_DIR": profile_dir,
            "TRTLLM_ENABLE_PDL": "1",
        },
    )
    check_call(
@@ -122,6 +121,35 @@ def test_deepseek_r1_gen_scaled_from_16_dep(llm_root, world_size):
    )


@pytest.mark.parametrize("world_size", [1, 4])
def test_nemotron_gen_dep(llm_root, world_size):
    if torch.cuda.device_count() < world_size:
        pytest.skip(f"needs {world_size:d} GPUs to run this test")
    model_root = llm_models_root(check=True)
    profile_dir = f"profiles/test_nemotron_gen_dep_{world_size}"
    check_call(
        [
            "./mpi_launch.sh",
            "./run.sh",
            "config_gen.yaml",
            "--model",
            model_root / "NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
            "--layer-indices=4,5,6",
            "--mamba-ssm-cache-dtype=float16",
        ],
        cwd=llm_root / "examples" / "layer_wise_benchmarks",
        env={
            **os.environ,
            "NP": f"{world_size:d}",
            "PROFILE_DIR": profile_dir,
        },
    )
    check_call(
        ["python3", "parse.py", "--profile-dir", profile_dir, f"--world-size={world_size}"],
        cwd=llm_root / "examples" / "layer_wise_benchmarks",
    )


@pytest.mark.parametrize("world_size", [1, 4])
def test_qwen3_next_gen_tep(llm_root, world_size):
    if torch.cuda.device_count() < world_size:
@@ -137,6 +165,7 @@ def test_qwen3_next_gen_tep(llm_root, world_size):
            model_root / "Qwen3" / "Qwen3-Next-80B-A3B-Instruct",
            "--layer-indices=6,7",
            "--no-enable-attention-dp",
            "--mamba-ssm-cache-dtype=float16",
            "--moe-backend=TRTLLM",
        ],
        cwd=llm_root / "examples" / "layer_wise_benchmarks",
@@ -144,7 +173,6 @@ def test_qwen3_next_gen_tep(llm_root, world_size):
            **os.environ,
            "NP": f"{world_size:d}",
            "PROFILE_DIR": profile_dir,
            "TRTLLM_ENABLE_PDL": "1",
        },
    )
    check_call(