From a7fe043b13d36c45e21168753481474af32bdd50 Mon Sep 17 00:00:00 2001
From: Tailing Yuan
Date: Mon, 5 Jan 2026 11:23:04 +0800
Subject: [PATCH] [None][feat] Layer-wise benchmarks: support TEP balance, polish slurm scripts (#10237)

Signed-off-by: Tailing Yuan
---
 .gitignore                                    |   1 +
 examples/layer_wise_benchmarks/README.md      |  67 +++-
 .../middleware/exclude_slurm_envs             |   8 +
 examples/layer_wise_benchmarks/mpi_launch.sh  |  10 +-
 examples/layer_wise_benchmarks/parse.py       |  55 ++-
 examples/layer_wise_benchmarks/run.py         | 120 +++---
 examples/layer_wise_benchmarks/run.sh         |   8 +-
 .../slurm_init_containers.sh                  |  21 +-
 .../layer_wise_benchmarks/slurm_launch.sh     |   5 +-
 .../slurm_query_container_name.sh             |   5 +
 examples/layer_wise_benchmarks/template.html  |   8 +-
 .../tools/layer_wise_benchmarks/mark_utils.py |   3 +-
 .../layer_wise_benchmarks/runner_utils.py     | 341 +++++++++++++-----
 .../tools/test_layer_wise_benchmarks.py       |   2 -
 14 files changed, 473 insertions(+), 181 deletions(-)
 create mode 100755 examples/layer_wise_benchmarks/middleware/exclude_slurm_envs

diff --git a/.gitignore b/.gitignore
index 840a515397..40d086f60e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -77,6 +77,7 @@ cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmha_v2_cu/
 cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
 cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.cpp
 .devcontainer/.env
+/examples/layer_wise_benchmarks/autotuner_cache/
 /examples/layer_wise_benchmarks/profiles/
 
 # User config files
diff --git a/examples/layer_wise_benchmarks/README.md b/examples/layer_wise_benchmarks/README.md
index 63b8f6f3da..af2663c9aa 100644
--- a/examples/layer_wise_benchmarks/README.md
+++ b/examples/layer_wise_benchmarks/README.md
@@ -15,6 +15,9 @@ pip install -e ../..
 **Step 3:** In the container, run benchmarks and generate profiles:
 
 ```bash
+# Set autotune cache path
+export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache
+
 # Run DeepSeek-R1 NVFP4
 NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml
 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml
@@ -24,16 +27,16 @@ NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-
 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
 
 # Run DeepSeek-V3.2-Exp with 32k context length
-NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
-NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --seq-len-kv-cache 32769
+NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
+NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --seq-len-kv-cache 32769
 
 # Run with attention TP
 NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --no-enable-attention-dp
 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --no-enable-attention-dp
 
 # Run with attention TP and TRTLLMGen
-NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified
-NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified
+NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM
+NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM
 
 # Run with MTP3
 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --batch-size 32 --seq-len-q 4
@@ -48,7 +51,7 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --moe-backend WID
 # Scale TEP=16 to 4 GPUs: reduce the number of attention heads and experts
 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp
 
-# Run Qwen3-Next (balanced routing is not implemented)
+# Run Qwen3-Next
 NP=2 ./mpi_launch.sh ./run.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 4
 NP=2 ./mpi_launch.sh ./run.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 512
 
@@ -69,13 +72,13 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --balance-method ImbalancedExperts
 ### Run with Slurm
 
 > Tips:
-> 1. If you have a running Slurm job, please skip step 1 and go straight to step 2 and 3.
-> 2. Further, if you have installed `tensorrt_llm` in the Slurm job, you can also skip step 2 and run step 3 with `export CONTAINER_NAME=aaa` specified. If you don't know the container name, run `export CONTAINER_NAME=$(SLURM_JOB_ID=$SLURM_JOB_ID ./slurm_query_container_name.sh)` to get it.
+> 1. If you have a running Slurm job, you can set the environment variable `export SLURM_JOB_ID=aaa` and skip step 1.
+> 2. Further, if you have installed `tensorrt_llm` in the Slurm job, you can also skip step 2. Just run step 3 with `export CONTAINER_NAME=aaa` specified. If you don't know the container name, run `export CONTAINER_NAME=$(./slurm_query_container_name.sh)` to get it.
 
-**Step 1:** On the controller node, allocate one or multiple nodes, and record the `SLURM_JOB_ID`:
+**Step 1:** On the controller node, allocate one or multiple nodes, and export the `SLURM_JOB_ID`:
 
 ```bash
-SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh)
+export SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh)
 ```
 
 Please fill the variables in `./slurm_alloc.sh`.
@@ -83,28 +86,47 @@ Please fill the variables in `./slurm_alloc.sh`.
 **Step 2:** Start a container and install `tensorrt_llm`. Run the following command on the controller node:
 
 ```bash
-SLURM_JOB_ID=$SLURM_JOB_ID ./slurm_init_containers.sh
+./slurm_init_containers.sh
 ```
 
 It uses the image recorded in `../../jenkins/current_image_tags.properties`. The image will be downloaded to `../../enroot/` for once.
 
 > Tips: If you want to change the image, no need to reallocate Slurm jobs. Just start another container by running step 2 with `export CONTAINER_NAME=aaa`, and step 3 will run in the container specified by the `CONTAINER_NAME` env.
 
+**(Optional) Get an interactive shell**
+
+```bash
+NODES=1 NP=1 ./slurm_launch.sh --overlap --pty middleware/exclude_slurm_envs bash
+```
+
+The `--overlap` option allows this shell to share the node with other jobs. The middleware enables nested MPI process spawning from within Slurm jobs.
+
+You may compile C++ extensions in the interactive shell:
+
+```bash
+cd ../..
+export CCACHE_DIR=$(realpath cpp/.ccache)
+python3 scripts/build_wheel.py --cuda_architectures native --no-venv --skip_building_wheel -G Ninja --use_ccache --clean
+```
+
 **Step 3:** Run benchmarks to generate profiles. Run the following command on the controller node, where `NODES` ≤ the number of allocated nodes:
 
 ```bash
+# Set autotune cache path
+export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache
+
 # Run DeepSeek-R1 NVFP4 with wide ep: uses MNNVL A2A if applicable
-SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
+NODES=4 NP=16 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
 
 # Run with TRTLLMGen
-SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend TRTLLM
+NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend TRTLLM
 
 # Run with DeepEPLowLatency
-SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
+NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
 
-# You can run 4-GPU and 8-GPU tasks without reallocate the slurm job
-SLURM_JOB_ID=$SLURM_JOB_ID NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml
-SLURM_JOB_ID=$SLURM_JOB_ID NODES=2 NP=8 ./slurm_launch.sh ./run.sh config_gtx.yaml
+# You can run 4-GPU and 8-GPU tasks without reallocating the slurm job
+NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml
+NODES=2 NP=8 ./slurm_launch.sh ./run.sh config_gen.yaml
 ```
 
 ### Batched run
@@ -131,9 +153,11 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --moe-backend WID
 Run the following command in the container:
 
 ```bash
+# Parse the profile in the default directory
 python3 parse.py --world-size 4
 
-# Specify the location of the .nsys-rep file
+# Specify the file path
+python3 parse.py --file-path profiles/report_np4_rank0.nsys-rep
 python3 parse.py --profile-dir ./profiles --world-size 4 --rank 0
 
 # Parse a specific module. The module must appear exactly once in each run.
@@ -145,6 +169,15 @@ You will receive three reports, each containing kernel timing statistics grouped
 2. A CSV report at `profiles/report_np4_rank0.csv`
 3. An HTML report at `profiles/report_np4_rank0.html`
 
+## Developer utilities
+
+1. Reduce startup time when debugging a model
+   1. Disable autotuner: add `--no-enable-autotuner` option
+   2. Disable nsys profile: set `PROFILE=0` environment variable
+2. Capture more information
+   1. Enable GPU metrics: set `GPU_METRICS=1` environment variable
+   2. Enable backtrace: set `BACKTRACE=1` environment variable
+
 ## Trouble shooting
 
 1. Error `fp8 blockscale gemm only support Hopper` on Blackwell.
diff --git a/examples/layer_wise_benchmarks/middleware/exclude_slurm_envs b/examples/layer_wise_benchmarks/middleware/exclude_slurm_envs
new file mode 100755
index 0000000000..1b0b504d6c
--- /dev/null
+++ b/examples/layer_wise_benchmarks/middleware/exclude_slurm_envs
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+set -euo pipefail
+
+# Clear slurm envs
+unset $(env | awk -F'=' '{print $1}' | (grep -E "SLURM_|SLURMD_|slurm_|MPI_|PMIX_" || true))
+
+"$@"
diff --git a/examples/layer_wise_benchmarks/mpi_launch.sh b/examples/layer_wise_benchmarks/mpi_launch.sh
index 24f7643beb..44f6d28dcf 100755
--- a/examples/layer_wise_benchmarks/mpi_launch.sh
+++ b/examples/layer_wise_benchmarks/mpi_launch.sh
@@ -3,8 +3,12 @@
 set -euo pipefail
 
 # Clear slurm envs
-unset $(env | grep -i slurm | awk -F'=' '{print $1}')
-unset $(env | grep MPI | awk -F'=' '{print $1}')
+unset $(env | awk -F'=' '{print $1}' | (grep -E "SLURM_|SLURMD_|slurm_|MPI_|PMIX_" || true))
+
+extra_args=
+if [ -v TLLM_AUTOTUNER_CACHE_PATH ]; then
+    extra_args+="-x TLLM_AUTOTUNER_CACHE_PATH"
+fi
 
 set -x
-mpirun --allow-run-as-root --np ${NP} "$@"
+mpirun --allow-run-as-root --np ${NP} $extra_args "$@"
diff --git a/examples/layer_wise_benchmarks/parse.py b/examples/layer_wise_benchmarks/parse.py
index 9253a9115d..3ebd479922 100644
--- a/examples/layer_wise_benchmarks/parse.py
+++ b/examples/layer_wise_benchmarks/parse.py
@@ -5,6 +5,7 @@ import json
 import re
 import sqlite3
 import subprocess
+import sys
 from collections import defaultdict
 from pathlib import Path
 
@@ -14,8 +15,9 @@ import pandas as pd
 
 # Parse cmdline
 parser = argparse.ArgumentParser()
+parser.add_argument("--file-path", type=str)
 parser.add_argument("--profile-dir", type=str, default="profiles")
-parser.add_argument("--world-size", "--np", type=int, required=True)
+parser.add_argument("--world-size", "--np", type=int)
 parser.add_argument("--rank", type=int, default=0)
 parser.add_argument("--warmup-times", type=int)
 parser.add_argument("--module", type=str)
@@ -27,6 +29,8 @@ group.add_argument(
 )
 parser.set_defaults(error_on_unknown_kernel=False)
 args = parser.parse_args()
+if (args.file_path is None) == (args.world_size is None):
+    parser.error("Please specify exactly one of --file-path and --world-size.")
 print(args)
 
 
@@ -89,11 +93,20 @@ def shortest_common_supersequence(a, b):
     return res
 
 
-profile_dir = Path(args.profile_dir)
-nsys_rep_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.nsys-rep"
-sqlite_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.sqlite"
-csv_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.csv"
-html_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.html"
+if args.file_path is not None:
+    nsys_rep_file_path = Path(args.file_path)
+    if 
not nsys_rep_file_path.name.endswith(".nsys-rep"): + raise ValueError("Expect a .nsys-rep file") +else: + profile_dir = Path(args.profile_dir) + nsys_rep_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.nsys-rep" +sqlite_file_path = nsys_rep_file_path.parent / ( + nsys_rep_file_path.name[: -len(".nsys-rep")] + ".sqlite" +) +csv_file_path = nsys_rep_file_path.parent / (nsys_rep_file_path.name[: -len(".nsys-rep")] + ".csv") +html_file_path = nsys_rep_file_path.parent / ( + nsys_rep_file_path.name[: -len(".nsys-rep")] + ".html" +) lazy_convert_sqlite(nsys_rep_file_path, sqlite_file_path) conn = sqlite3.connect(f"file:{sqlite_file_path}?mode=ro", uri=True) @@ -196,7 +209,8 @@ if "CUDA_GRAPH_NODE_EVENTS" in tables: R.start AS runtime_start, R.end AS runtime_end, CGE2.start AS capture_start, CGE2.end AS capture_end FROM ({unified_subquery}) AS unified - JOIN CUPTI_ACTIVITY_KIND_RUNTIME AS R ON unified.correlationId = R.correlationId + JOIN CUPTI_ACTIVITY_KIND_RUNTIME AS R ON unified.graphNodeId IS NOT NULL AND + unified.correlationId = R.correlationId LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE1 ON unified.graphNodeId = CGE1.graphNodeId AND CGE1.originalGraphNodeId IS NOT NULL LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE2 ON CGE1.originalGraphNodeId = CGE2.graphNodeId""" @@ -318,6 +332,7 @@ parser_keywords = [ ("routingInitExpertCounts", "routingInitExpertCounts"), ("routingIndicesCluster", "routingIndicesClusterKernel"), ("routingIndicesCoop", "routingIndicesCoopKernel"), + ("router_gemm", "router_gemm_kernel"), ("bmm_4_44_32", "bmm_E2m1_E2m1E2m1_Fp32_t"), ("finalize", "finalize::finalizeKernel"), ("bmm_16_44_32", "bmm_Bfloat16_E2m1E2m1_Fp32_"), @@ -381,7 +396,7 @@ def parse_kernel_name(demangledName): if all(keyword in name for keyword in src): return dst if name not in warned_names: - print(f"Unknown kernel name: {name}") + print(f"Unknown kernel name: {name}", file=sys.stderr) warned_names.add(name) if args.error_on_unknown_kernel: raise NotImplementedError(f"Unknown kernel name: {name}") @@ -491,11 +506,33 @@ with csv_file_path.open("w", newline="") as f: for row in csv_data: csv_writer.writerow(row) js_header_config = [{"name": problem["text"]} for problem in problem_set] +js_header_config = [] +for problem in problem_set: + innermost_children = js_header_config + for k, msg_prefix in [ + ("batch_size", "b="), + ("seq_len_q", "q="), + ("seq_len_kv_cache", "past="), + ]: + if len(run_args[k + "_list"]) > 1: + if len(innermost_children) == 0 or problem["spec"][k] != innermost_children[-1][k]: + innermost_children.append( + { + "name": msg_prefix + str(problem["spec"][k]), + "children": [], + k: problem["spec"][k], + } + ) + innermost_children = innermost_children[-1]["children"] + innermost_children.append({"name": problem["text"]}) loader = jinja2.FileSystemLoader(Path(__file__).parent) template = jinja2.Environment(loader=loader).get_template("template.html") with html_file_path.open("w") as f: configText = ( - "Run:\n" + json.dumps(run_args, indent=4) + "\n\nParse:\n" + json.dumps(args.__dict__) + "Run:\n" + + json.dumps(run_args, indent=4) + + "\n\nParse:\n" + + json.dumps(args.__dict__, indent=4) ) f.write(template.render(headerConfig=js_header_config, rawData=js_data, configText=configText)) diff --git a/examples/layer_wise_benchmarks/run.py b/examples/layer_wise_benchmarks/run.py index c1e3ab5133..1dc61fc9dc 100644 --- a/examples/layer_wise_benchmarks/run.py +++ b/examples/layer_wise_benchmarks/run.py @@ -2,6 +2,7 @@ import argparse import itertools import json import os 
+from unittest import mock import numpy as np import nvtx @@ -9,6 +10,8 @@ import torch import yaml from tensorrt_llm._torch.autotuner import AutoTuner, autotune +from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE +from tensorrt_llm._torch.modules.fused_moe.interface import AlltoallMethodType from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size from tensorrt_llm.logger import logger @@ -108,17 +111,6 @@ if args.enable_attention_dp is None: args.enable_attention_dp = False if args.max_num_tokens is None: args.max_num_tokens = args.max_batch_size * max(args.seq_len_q_list) - if args.run_type == "GEN": - ctx_batch_size = max(1, max(20480, args.max_num_tokens) // max(args.seq_len_kv_cache_list)) - args.max_num_tokens = max( - args.max_num_tokens, ctx_batch_size * max(args.seq_len_kv_cache_list) - ) -else: - if args.run_type == "GEN": - ctx_batch_size = max(1, args.max_num_tokens // max(args.seq_len_kv_cache_list)) - assert args.max_num_tokens >= ctx_batch_size * max(args.seq_len_kv_cache_list), ( - "Max_num_tokens is too small to prefill KV cache" - ) if args.use_low_precision_moe_combine is None: args.use_low_precision_moe_combine = False if args.enable_autotuner is None: @@ -169,17 +161,75 @@ runner = Runner( ) logger.info("Layer-wise benchmarks: Create runner ... Done") +# Autotune +run_pack = runner.create_run_pack( + args.run_type, + batch_size=max(args.batch_size_list), + request_id_begin=0, + seq_len_q=max(args.seq_len_q_list), + seq_len_kv_cache=args.seq_len_kv_cache_list[0], + kv_cache_manager=kv_cache_manager, + attn_workspace=attn_workspace, +) +if args.enable_autotuner: + cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None + with autotune(cache_path=cache_path): + run_pack() +else: + run_pack() + +# Prefill KV cache +if args.run_type == "GEN": + logger.info("Layer-wise benchmarks: Create runner for prefill") + ctx_seq_len_q = max(args.seq_len_kv_cache_list) + ctx_batch_size = min( + args.max_batch_size, + max(1, 20480 // ctx_seq_len_q), + ) + ctx_attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8) + with mock.patch.object( + CutlassFusedMoE, "select_alltoall_method_type", return_value=AlltoallMethodType.NotEnabled + ): + ctx_runner = Runner( + args.model, + mapping, + moe_backend="CUTLASS", + layer_indices=args.layer_indices, + scaled_from=args.scaled_from, + max_seq_len=args.max_seq_len, + max_num_tokens=ctx_batch_size * ctx_seq_len_q, + moe_max_num_tokens=16384, + use_low_precision_moe_combine=args.use_low_precision_moe_combine, + use_cuda_graph=False, + ) + logger.info("Layer-wise benchmarks: Create runner for prefill ... Done") + + logger.info("Layer-wise benchmarks: Prefill KV cache") + assert ctx_batch_size <= args.max_batch_size + assert ctx_seq_len_q + 0 <= args.max_seq_len + num_requests = max(args.batch_size_list) + for request_id_begin in range(0, num_requests, ctx_batch_size): + run_pack = ctx_runner.create_run_pack( + "CTX", + batch_size=min(ctx_batch_size, num_requests - request_id_begin), + request_id_begin=request_id_begin, + seq_len_q=ctx_seq_len_q, + seq_len_kv_cache=0, + kv_cache_manager=kv_cache_manager, + attn_workspace=ctx_attn_workspace, + ) + with ctx_runner.replace_routing_method_ctx( + balance_method=BalanceMethod.Balanced, balance_ratio=None + ): + run_pack(check=True) + del ctx_runner + del ctx_attn_workspace + logger.info("Layer-wise benchmarks: Prefill KV cache ... 
Done") + # Warm up -for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [ - [ - True, - max(args.batch_size_list), - max(args.seq_len_q_list), - args.seq_len_kv_cache_list[0], - args.balance_ratio_list[0], - ], +logger.info("Layer-wise benchmarks: Warmup") +for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [ *itertools.product( - [False], args.batch_size_list, args.seq_len_q_list, args.seq_len_kv_cache_list, @@ -203,34 +253,10 @@ for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [ ): capture_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(capture_stream): - if autotune_flag: - if args.enable_autotuner: - cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None - with autotune(cache_path=cache_path): - run_pack() - if args.run_type == "GEN": - logger.info("Layer-wise benchmarks: Prefill KV cache") - ctx_seq_len_q = max(args.seq_len_kv_cache_list) - assert ctx_batch_size <= args.max_batch_size - assert ctx_seq_len_q + 0 <= args.max_seq_len - assert ctx_batch_size * ctx_seq_len_q <= args.max_num_tokens - max_batch_size = max(args.batch_size_list) - for request_id_begin in range(0, max_batch_size, ctx_batch_size): - ctx_run_pack = runner.create_run_pack( - "CTX", - batch_size=min(ctx_batch_size, max_batch_size - request_id_begin), - request_id_begin=request_id_begin, - seq_len_q=ctx_seq_len_q, - seq_len_kv_cache=0, - kv_cache_manager=kv_cache_manager, - attn_workspace=attn_workspace, - ) - ctx_run_pack(check=True) - logger.info("Layer-wise benchmarks: Prefill KV cache ... Done") - else: - run_pack(check=True) + run_pack(check=True) torch.cuda.current_stream().wait_stream(capture_stream) torch.cuda.synchronize() +logger.info("Layer-wise benchmarks: Warmup ... Done") events = [ torch.cuda.Event(enable_timing=True) for _ in range(args.warmup_times + args.run_times + 1) @@ -270,7 +296,7 @@ for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in itertools.product( with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"): run_pack() - balance_ratio_str = "" if balance_ratio is None else f" balance={balance_ratio:.2g}" + balance_ratio_str = "" if balance_ratio is None else f" balance={balance_ratio:.2g}" nvtx_message = f"b={batch_size} s={seq_len_q} past={seq_len_kv_cache}{balance_ratio_str} NP{world_size}" for i in range(args.warmup_times + args.run_times): events[i].record() diff --git a/examples/layer_wise_benchmarks/run.sh b/examples/layer_wise_benchmarks/run.sh index a96265d8f3..4f20394d84 100755 --- a/examples/layer_wise_benchmarks/run.sh +++ b/examples/layer_wise_benchmarks/run.sh @@ -40,7 +40,9 @@ else PROFILE_CMD= fi +SCRIPT_PATH=$(realpath --relative-to="$(pwd)" "$(dirname -- "$0")"/run.py) + set -x -$PROFILE_CMD bash -c \ - "python3 -u run.py \"\$@\" 2>&1 | tee \"$PROFILE_DIR/report_np${WORLD_SIZE}_rank${RANK}.log\"" \ - bash "$@" +$PROFILE_CMD bash -o pipefail -c \ + "python3 -u \"\$1\" \"\${@:3}\" 2>&1 | tee \"\$2/report_np${WORLD_SIZE}_rank${RANK}.log\"" \ + bash "$SCRIPT_PATH" "$PROFILE_DIR" "$@" diff --git a/examples/layer_wise_benchmarks/slurm_init_containers.sh b/examples/layer_wise_benchmarks/slurm_init_containers.sh index 08e77c5623..7e97505bc1 100755 --- a/examples/layer_wise_benchmarks/slurm_init_containers.sh +++ b/examples/layer_wise_benchmarks/slurm_init_containers.sh @@ -4,23 +4,23 @@ set -euo pipefail # CONTAINER_IMAGE= CONTAINER_NAME=${CONTAINER_NAME:-layer_wise_benchmarks} -CONTAINER_MOUNTS=$(realpath "$(pwd)/../.."):$(realpath "$(pwd)/../..") 
+TRTLLM_ROOT=$(realpath "$(dirname -- "$0")"/../..) +CONTAINER_MOUNTS=$TRTLLM_ROOT:$TRTLLM_ROOT -if [ "${SLURM_JOB_ID:-}" == "" ]; then +if [ -z "${SLURM_JOB_ID:-}" ]; then echo "Please set SLURM_JOB_ID" exit 1 fi NODES=$(squeue -j $SLURM_JOB_ID -h -o "%D") -if [ "${CONTAINER_IMAGE:-}" == "" ]; then +if [ -z "${CONTAINER_IMAGE:-}" ]; then # Read Docker image from current_image_tags.properties - source ../../jenkins/current_image_tags.properties MACHINE="$(uname -m)" if [ "$MACHINE" == "x86_64" ]; then - DOCKER_IMAGE=$LLM_DOCKER_IMAGE + DOCKER_IMAGE=$(source "$TRTLLM_ROOT/jenkins/current_image_tags.properties" && echo $LLM_DOCKER_IMAGE) elif [ "$MACHINE" == "aarch64" ]; then - DOCKER_IMAGE=$LLM_SBSA_DOCKER_IMAGE + DOCKER_IMAGE=$(source "$TRTLLM_ROOT/jenkins/current_image_tags.properties" && echo $LLM_SBSA_DOCKER_IMAGE) else echo "Unsupported machine hardware name \"$MACHINE\"" exit 1 @@ -34,7 +34,7 @@ if [ "${CONTAINER_IMAGE:-}" == "" ]; then SQSH_FILE_NAME=$(echo "$DOCKER_IMAGE" | awk -F'#' '{print $2}' | awk -F':' '{gsub(/\//,"+",$1); print $1"+"$2".sqsh"}') - CONTAINER_IMAGE="../../enroot/$SQSH_FILE_NAME" + CONTAINER_IMAGE="$TRTLLM_ROOT/enroot/$SQSH_FILE_NAME" if [ ! -f "$CONTAINER_IMAGE" ]; then echo "Container image file $CONTAINER_IMAGE does not exist, importing ..." srun -N 1 --pty enroot import -o "$CONTAINER_IMAGE" "docker://$DOCKER_IMAGE" @@ -50,6 +50,7 @@ srun -N "$NODES" \ --container-name "$CONTAINER_NAME" \ --container-mounts "$CONTAINER_MOUNTS" \ --container-workdir "$WORKDIR" \ -bash -c "pip install -U packaging && - pip install -r ../../requirements.txt --no-build-isolation && - pip install -e ../.." +bash -c "cd \"\$1\" && + pip install -U packaging && + pip install -r requirements.txt --no-build-isolation && + pip install -e ." bash "$TRTLLM_ROOT" diff --git a/examples/layer_wise_benchmarks/slurm_launch.sh b/examples/layer_wise_benchmarks/slurm_launch.sh index 36fe0a9a00..64857b2877 100755 --- a/examples/layer_wise_benchmarks/slurm_launch.sh +++ b/examples/layer_wise_benchmarks/slurm_launch.sh @@ -3,9 +3,10 @@ set -euo pipefail CONTAINER_NAME=${CONTAINER_NAME:-layer_wise_benchmarks} -CONTAINER_MOUNTS=$(realpath "$(pwd)/../.."):$(realpath "$(pwd)/../..") +TRTLLM_ROOT=$(realpath "$(dirname -- "$0")"/../..) 
+CONTAINER_MOUNTS=$TRTLLM_ROOT:$TRTLLM_ROOT -if [ "${SLURM_JOB_ID:-}" == "" ]; then +if [ -z "${SLURM_JOB_ID:-}" ]; then echo "Please set SLURM_JOB_ID" exit 1 fi diff --git a/examples/layer_wise_benchmarks/slurm_query_container_name.sh b/examples/layer_wise_benchmarks/slurm_query_container_name.sh index a7f6b1ba81..6d3192da0b 100755 --- a/examples/layer_wise_benchmarks/slurm_query_container_name.sh +++ b/examples/layer_wise_benchmarks/slurm_query_container_name.sh @@ -2,6 +2,11 @@ set -euo pipefail +if [ -z "${SLURM_JOB_ID:-}" ]; then + echo "Please set SLURM_JOB_ID" + exit 1 +fi + prefix="pyxis_${SLURM_JOB_ID}_" matches=$(printf "%s\n" "$(srun -N 1 enroot list)" | grep "^${prefix}" || true) count=$(printf "%s\n" "$matches" | wc -l) diff --git a/examples/layer_wise_benchmarks/template.html b/examples/layer_wise_benchmarks/template.html index 6f190ed664..ac35ee2ce1 100644 --- a/examples/layer_wise_benchmarks/template.html +++ b/examples/layer_wise_benchmarks/template.html @@ -313,6 +313,7 @@ const columnLabels = []; let columnTotals = []; let totalNode = null; // Special node for the "Total" row + let numCols = null; const fmt = (num) => num.toFixed(1); @@ -332,7 +333,7 @@ const topLevelTimes = rawData.map(aggregate); // Calculate Totals - const numCols = topLevelTimes[0].length; + numCols = topLevelTimes[0].length; columnTotals = new Array(numCols).fill(0); for (let c = 0; c < numCols; c++) { columnTotals[c] = topLevelTimes.reduce((sum, row) => sum + row[c], 0); @@ -420,7 +421,8 @@ const total = columnTotals[colIdx] || 1; const pct = (val / total) * 100; - html += ` + html += `
${val ? fmt(val) : ""}
`; @@ -717,6 +719,7 @@ }); window.onload = () => { + document.getElementById('kernelTable').style["min-width"] = Math.max(1000, (220 + 60 * numCols)) + "px"; renderHeader(); document.getElementById('tableBody').innerHTML = renderTableBody(rawData, 0); renderTotalRow(); @@ -741,7 +744,6 @@ for (let node of rawData) { maxDepth = Math.max(maxDepth, getDepth(node)); } - console.log(maxDepth); const container = document.getElementById('level-buttons'); if (container) { diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/mark_utils.py b/tensorrt_llm/tools/layer_wise_benchmarks/mark_utils.py index 72625d1059..7ebde93e08 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/mark_utils.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/mark_utils.py @@ -1,6 +1,6 @@ import nvtx -from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate +from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate, Deepseekv3MoE from tensorrt_llm._torch.models.modeling_qwen3_next import ( Qwen3NextGatedDeltaNet, Qwen3NextSparseMoeBlock, @@ -12,6 +12,7 @@ from tensorrt_llm._torch.modules.gated_mlp import GatedMLP def mark_ranges(): DeepseekV3Gate.forward = nvtx.annotate("DeepseekV3Gate")(DeepseekV3Gate.forward) + Deepseekv3MoE.forward = nvtx.annotate("Deepseekv3MoE")(Deepseekv3MoE.forward) Qwen3NextGatedDeltaNet.forward = nvtx.annotate("Qwen3NextGatedDeltaNet")( Qwen3NextGatedDeltaNet.forward ) diff --git a/tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py b/tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py index 578221a436..93d6d84e11 100644 --- a/tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py +++ b/tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py @@ -1,5 +1,6 @@ import contextlib import functools +import itertools import os import unittest.mock import weakref @@ -23,6 +24,7 @@ from tensorrt_llm._torch.utils import get_model_extra_attrs, model_extra_attrs from tensorrt_llm._utils import local_mpi_size, mpi_rank, mpi_world_size, torch_dtype_to_binding from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.bindings.internal.batch_manager import CacheType +from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig @@ -38,24 +40,16 @@ def round_up(a, b): def get_balanced_selection_no_cache( - num_tokens, top_k, num_experts, dtype, device, world_size, rank + num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size ): - # First, each sender selects target rank - target_rank_before_mod = torch.arange(num_tokens * world_size * top_k).view( - num_tokens, world_size, top_k + token_id = torch.arange(dp_rank * num_tokens * top_k, (dp_rank + 1) * num_tokens * top_k).view( + num_tokens, top_k ) - target_rank_before_mod += top_k * torch.arange(num_tokens).view( - num_tokens, 1, 1 - ) # Shift `top_k` ranks for the next token on each rank, to balance network traffic - target_rank = target_rank_before_mod % world_size - # Second, each receiver selects target expert - target_expert = torch.empty_like(target_rank) - for reciever_rank in range(world_size): - mask = target_rank == reciever_rank - experts_per_rank = num_experts // world_size - local_expert = torch.arange(num_tokens * top_k) % experts_per_rank - target_expert[mask] = (reciever_rank * experts_per_rank) + local_expert - token_selected_experts = target_expert[:, rank].sort(dim=-1).values + experts_per_rank = num_experts // ep_size + token_selected_experts = (token_id % ep_size) * experts_per_rank + ( + 
token_id // ep_size + ) % experts_per_rank + token_selected_experts = token_selected_experts.sort(dim=-1).values return token_selected_experts.contiguous().to(dtype=dtype, device=device) @@ -64,47 +58,62 @@ get_balanced_selection = functools.cache(get_balanced_selection_no_cache) def test_get_balanced_selection(): dtype = torch.long - for num_tokens in range(1, 33): - for num_experts in range(1, 65): - print(f"{num_tokens=} {num_experts=}") - for top_k in range(1, min(11, num_experts)): - for world_size in range(1, 65): - if num_experts % world_size == 0: - tokens_per_expert = torch.zeros(num_experts) - for rank in range(world_size): - token_selected_experts = get_balanced_selection_no_cache( - num_tokens, top_k, num_experts, dtype, "cpu", world_size, rank + for num_tokens, num_experts, enable_attention_dp in itertools.product( + range(1, 35), range(1, 35), [False, True] + ): + print(f"{num_tokens=} {num_experts=} {enable_attention_dp=}") + for top_k in range(1, min(10, num_experts) + 1): + for world_size in range(1, 35): + dp_size = world_size if enable_attention_dp else 1 + ep_size = world_size + if num_experts % ep_size == 0: + tokens_per_expert = torch.zeros(num_experts) + for dp_rank in range(dp_size): + token_selected_experts = get_balanced_selection_no_cache( + num_tokens, top_k, num_experts, dtype, "cpu", dp_size, dp_rank, ep_size + ) + sorted_selection = token_selected_experts.sort(dim=-1).values + if (sorted_selection[:, :-1] == sorted_selection[:, 1:]).any(): + raise ValueError(f"duplicated experts on rank {dp_rank}") + experts_per_rank = num_experts // ep_size + tokens_per_rank = ( + (token_selected_experts // experts_per_rank) + .view(-1) + .bincount(minlength=ep_size) + ) + if tokens_per_rank.max() - tokens_per_rank.min() > 1: + raise ValueError(f"tokens sent from rank {dp_rank} is not balanced") + unique_tokens_per_rank = ( + ( + torch.arange(ep_size).view(ep_size, 1, 1) + == token_selected_experts // experts_per_rank ) - sorted_selection = token_selected_experts.sort(dim=-1).values - if (sorted_selection[:, :-1] == sorted_selection[:, 1:]).any(): - raise ValueError(f"duplicated experts on rank {rank}") - experts_per_rank = num_experts // world_size - tokens_per_rank = ( - (token_selected_experts // experts_per_rank) - .view(-1) - .bincount(minlength=world_size) + .any(dim=2) + .sum(dim=1) + ) + if unique_tokens_per_rank.max() - unique_tokens_per_rank.min() > 1: + raise ValueError( + f"tokens sent from rank {dp_rank} is not balanced after removing duplicates" ) - if tokens_per_rank.max() - tokens_per_rank.min() > 1: - raise ValueError(f"tokens sent from rank {rank} is not balanced") - tokens_per_expert += token_selected_experts.view(-1).bincount( - minlength=num_experts - ) - if tokens_per_expert.max() - tokens_per_expert.min() > 1: - raise ValueError("tokens per expert is not balanced") + tokens_per_expert += token_selected_experts.view(-1).bincount( + minlength=num_experts + ) + if tokens_per_expert.max() - tokens_per_expert.min() > 1: + raise ValueError("tokens per expert is not balanced") -def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank): +def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size): num_tokens, top_k = imbalanced_experts.shape dtype = imbalanced_experts.dtype device = imbalanced_experts.device balanced_experts = get_balanced_selection_no_cache( - num_tokens, top_k, num_experts, dtype, device, world_size, rank + num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, 
ep_size ) if balance_ratio == 0.0: num_balanced_tokens = 0 else: # Activate all experts - min_num_balanced_tokens = min(num_tokens, ceil_div(num_experts, world_size * top_k)) + min_num_balanced_tokens = min(num_tokens, ceil_div(num_experts, dp_size * top_k)) num_balanced_tokens = min_num_balanced_tokens + round( (num_tokens - min_num_balanced_tokens) * balance_ratio ) @@ -116,57 +125,78 @@ def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_si @functools.cache def get_all_to_one_selection( - num_tokens, top_k, num_experts, balance_ratio, dtype, device, world_size, rank + num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size ): - experts_per_rank = num_experts // world_size + experts_per_rank = num_experts // ep_size if top_k > experts_per_rank: raise ValueError( "Cannot send all tokens to a single rank because `top_k > experts_per_rank`" ) imbalanced_experts = ( torch.arange( - rank * num_tokens * top_k, (rank + 1) * num_tokens * top_k, dtype=dtype, device=device + dp_rank * num_tokens * top_k, + (dp_rank + 1) * num_tokens * top_k, + dtype=dtype, + device=device, ).view(num_tokens, top_k) % experts_per_rank ) imbalanced_experts = imbalanced_experts.sort(dim=-1).values - return apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank) + return apply_balance_ratio( + imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size + ) @functools.cache def get_balanced_rank_imbalanced_expert_selection( - num_tokens, top_k, num_experts, balance_ratio, dtype, device, world_size, rank + num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size ): - experts_per_rank = num_experts // world_size - active_experts_per_rank = ceil_div(top_k, world_size) - # Select expert from [0, active_experts_per_rank * world_size), - # then scale to [0, experts_per_rank * world_size) + experts_per_rank = num_experts // ep_size + active_experts_per_rank = ceil_div(top_k, ep_size) + # Select expert from [0, active_experts_per_rank * ep_size), + # then scale to [0, experts_per_rank * ep_size) narrow_experts = get_balanced_selection_no_cache( - num_tokens, top_k, active_experts_per_rank * world_size, dtype, device, world_size, rank + num_tokens, + top_k, + active_experts_per_rank * ep_size, + dtype, + device, + dp_size, + dp_rank, + ep_size, ) imbalanced_experts = ( narrow_experts // active_experts_per_rank * experts_per_rank + narrow_experts % active_experts_per_rank ) - return apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank) + return apply_balance_ratio( + imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size + ) def make_balanced_routing_method( - apply_method_orig, num_experts, balance_method, balance_ratio, world_size, rank + moe_module, + apply_method_orig, + num_experts, + balance_method, + balance_ratio, + dp_size, + dp_rank, + ep_size, ): def balanced_routing_method(router_logits): token_selected_experts, token_final_scales = apply_method_orig(router_logits) - if balance_method == BalanceMethod.NotModified: - pass - elif balance_method == BalanceMethod.Balanced: + assert moe_module._routing_results_replaced_at in [None, "make_balanced_routing_method"] + if balance_method == BalanceMethod.Balanced: token_selected_experts = get_balanced_selection( token_selected_experts.shape[0], token_selected_experts.shape[1], num_experts, token_selected_experts.dtype, token_selected_experts.device, - world_size, - rank, + dp_size, + dp_rank, + 
ep_size, ) elif balance_method == BalanceMethod.ImbalancedRanks: token_selected_experts = get_all_to_one_selection( @@ -176,8 +206,9 @@ def make_balanced_routing_method( balance_ratio, token_selected_experts.dtype, token_selected_experts.device, - world_size, - rank, + dp_size, + dp_rank, + ep_size, ) elif balance_method == BalanceMethod.ImbalancedExperts: token_selected_experts = get_balanced_rank_imbalanced_expert_selection( @@ -187,16 +218,127 @@ def make_balanced_routing_method( balance_ratio, token_selected_experts.dtype, token_selected_experts.device, - world_size, - rank, + dp_size, + dp_rank, + ep_size, ) else: raise NotImplementedError(f"Not support balance_method {balance_method}") + moe_module._routing_results_replaced_at = "make_balanced_routing_method" return token_selected_experts, token_final_scales return balanced_routing_method +@functools.cache +def get_token_final_scales(shape, device): + return torch.full(shape, 1.0 / shape[-1], dtype=torch.bfloat16, device=device) + + +def make_balanced_run_moe( + moe_module, + run_moe_orig, + top_k, + num_experts, + balance_method, + balance_ratio, + dp_size, + dp_rank, + ep_size, +): + def balanced_run_moe( + x, token_selected_experts, token_final_scales, x_sf, router_logits, do_finalize, moe_output + ): + if moe_module._routing_results_replaced_at is not None: + return run_moe_orig( + x, + token_selected_experts, + token_final_scales, + x_sf, + router_logits, + do_finalize, + moe_output, + ) + logger.warning_once( + 'Layer-wise benchmarks: Specifying routing results of "TRTLLM" MoE backend in TEP cases leads to different' + " execution path around the topk kernel", + key="replace_routing_method_ctx_trtllm_tp", + ) + if balance_method == BalanceMethod.Balanced: + token_selected_experts = get_balanced_selection( + x.shape[0], + top_k, + num_experts, + torch.int32, + x.device, + dp_size, + dp_rank, + ep_size, + ) + elif balance_method == BalanceMethod.ImbalancedRanks: + token_selected_experts = get_all_to_one_selection( + x.shape[0], + top_k, + num_experts, + balance_ratio, + torch.int32, + x.device, + dp_size, + dp_rank, + ep_size, + ) + elif balance_method == BalanceMethod.ImbalancedExperts: + token_selected_experts = get_balanced_rank_imbalanced_expert_selection( + x.shape[0], + top_k, + num_experts, + balance_ratio, + torch.int32, + x.device, + dp_size, + dp_rank, + ep_size, + ) + else: + raise NotImplementedError(f"Not support balance_method {balance_method}") + token_final_scales = get_token_final_scales( + token_selected_experts.shape, token_selected_experts.device + ) + router_logits = None + final_hidden_states = run_moe_orig( + x, + token_selected_experts, + token_final_scales, + x_sf, + router_logits, + do_finalize, + moe_output, + ) + if not do_finalize: + final_hidden_states = ( + final_hidden_states[0], + token_final_scales, # WAR for TRTLLMGenFusedMoE bug that it returns wrong `token_final_scales` + final_hidden_states[2], + ) + moe_module._routing_results_replaced_at = "make_balanced_run_moe" + return final_hidden_states + + return balanced_run_moe + + +def make_forward_impl_check(moe_module, forward_impl_orig): + def forward_impl(*args, **kwargs): + moe_module._routing_results_replaced_at = None + res = forward_impl_orig(*args, **kwargs) + assert moe_module._routing_results_replaced_at is not None, ( + "Routing results are not replaced" + ) + del moe_module._routing_results_replaced_at + return res + + return forward_impl + + class RunnerMixin(ABC): @staticmethod @abstractmethod @@ -398,8 +540,9 @@ class 
RunnerMixin(ABC): @contextlib.contextmanager def replace_routing_method_ctx(self, balance_method: BalanceMethod, balance_ratio: float): if balance_method == BalanceMethod.NotModified: - pass - elif self.model_config.moe_backend not in [ + yield + return + if self.model_config.moe_backend not in [ "CUTEDSL", "CUTLASS", "DEEPGEMM", @@ -410,29 +553,59 @@ class RunnerMixin(ABC): f'Not support replace routing method for moe_backend "{self.model_config.moe_backend}",' f' please set balance_method to "NotModified"' ) - elif ( - self.model_config.moe_backend == "TRTLLM" - and not self.model_config.mapping.enable_attention_dp - ): - raise NotImplementedError( - 'Not support replace routing method for moe_backend "TRTLLM" with attention TP,' - ' please set balance_method to "NotModified"' + original_methods = [] + dp_rank = self.model_config.mapping.rank // ( + self.model_config.mapping.world_size // self.model_config.mapping.dp_size + ) + for layer in self.layers: + moe_module = layer.mlp.experts + + # Replace `routing_method.apply` for normal cases + apply_method_orig = moe_module.routing_method.apply + moe_module.routing_method.apply = make_balanced_routing_method( + moe_module, + apply_method_orig, + moe_module.num_experts, + balance_method, + balance_ratio, + self.model_config.mapping.dp_size, + dp_rank, + self.model_config.mapping.moe_ep_size, ) - apply_methods_orig = [layer.mlp.experts.routing_method.apply for layer in self.layers] - try: - for layer, apply_method_orig in zip(self.layers, apply_methods_orig): - layer.mlp.experts.routing_method.apply = make_balanced_routing_method( - apply_method_orig, + + # Replace `run_moe` for TRTLLMGenFusedMoE TEP because it does not call `routing_method.apply` + if isinstance(moe_module, TRTLLMGenFusedMoE): + run_moe_orig = moe_module.run_moe + moe_module.run_moe = make_balanced_run_moe( + moe_module, + run_moe_orig, + layer.mlp.experts.routing_method.top_k, layer.mlp.experts.num_experts, balance_method, balance_ratio, - layer.mlp.experts.ep_size, - layer.mlp.experts.ep_rank, + self.model_config.mapping.dp_size, + dp_rank, + self.model_config.mapping.moe_ep_size, ) + else: + run_moe_orig = None + + # Replace `forward_impl` to ensure that routing results are replaced + forward_impl_orig = moe_module.forward_impl + moe_module.forward_impl = make_forward_impl_check(moe_module, forward_impl_orig) + + original_methods.append((apply_method_orig, run_moe_orig, forward_impl_orig)) + try: yield finally: - for layer, apply_method_orig in zip(self.layers, apply_methods_orig): - layer.mlp.experts.routing_method.apply = apply_method_orig + for layer, (apply_method_orig, run_moe_orig, forward_impl_orig) in zip( + self.layers, original_methods + ): + moe_module = layer.mlp.experts + moe_module.routing_method.apply = apply_method_orig + if isinstance(moe_module, TRTLLMGenFusedMoE): + moe_module.run_moe = run_moe_orig + moe_module.forward_impl = forward_impl_orig @staticmethod def create_kv_cache_manager( @@ -467,7 +640,7 @@ class RunnerMixin(ABC): kv_cache_manager = kv_cache_manager_cls( kv_cache_config, CacheType.SELFKONLY, - num_layers=sum(layer_mask), + num_layers=num_layers, num_kv_heads=1, head_dim=model_config.pretrained_config.kv_lora_rank + model_config.pretrained_config.qk_rope_head_dim, diff --git a/tests/unittest/tools/test_layer_wise_benchmarks.py b/tests/unittest/tools/test_layer_wise_benchmarks.py index a78e11484b..b5f7e58a90 100644 --- a/tests/unittest/tools/test_layer_wise_benchmarks.py +++ b/tests/unittest/tools/test_layer_wise_benchmarks.py @@ -48,7 
+48,6 @@ def test_deepseek_r1_ctx_tep(llm_root, world_size): model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2", "--no-enable-attention-dp", "--moe-backend=TRTLLM", - "--balance-method=NotModified", ], cwd=llm_root / "examples" / "layer_wise_benchmarks", env={ @@ -139,7 +138,6 @@ def test_qwen3_next_gen_tep(llm_root, world_size): "--layer-indices=6,7", "--no-enable-attention-dp", "--moe-backend=TRTLLM", - "--balance-method=NotModified", ], cwd=llm_root / "examples" / "layer_wise_benchmarks", env={
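For reference, the core of the new TEP balance support is the slot-to-expert mapping in `get_balanced_selection_no_cache` (see the `runner_utils.py` diff above): consecutive token slots are assigned round-robin across EP ranks first, then across each rank's local experts, so per-expert token counts differ by at most one. The standalone sketch below reproduces that formula outside the runner; the wrapper name `balanced_selection` and the example sizes in the `__main__` block are illustrative assumptions, not part of the patch.

```python
# Standalone sketch (illustrative) of the balanced expert selection added in
# tensorrt_llm/tools/layer_wise_benchmarks/runner_utils.py.
import torch


def balanced_selection(num_tokens: int, top_k: int, num_experts: int,
                       dp_rank: int, ep_size: int) -> torch.Tensor:
    # Consecutive slot ids, offset per attention-DP rank so ranks do not collide.
    token_id = torch.arange(
        dp_rank * num_tokens * top_k, (dp_rank + 1) * num_tokens * top_k
    ).view(num_tokens, top_k)
    experts_per_rank = num_experts // ep_size
    # Round-robin over EP ranks first, then over each rank's local experts.
    experts = (token_id % ep_size) * experts_per_rank + (token_id // ep_size) % experts_per_rank
    # Keep each token's expert list in ascending order, as the patched helper does.
    return experts.sort(dim=-1).values


if __name__ == "__main__":
    # Example sizes chosen for illustration only.
    num_tokens, top_k, num_experts, dp_size, ep_size = 8, 8, 256, 4, 4
    counts = torch.zeros(num_experts, dtype=torch.long)
    for dp_rank in range(dp_size):
        counts += balanced_selection(
            num_tokens, top_k, num_experts, dp_rank, ep_size
        ).view(-1).bincount(minlength=num_experts)
    # Per-expert load differs by at most one token across all DP ranks.
    print(counts.max().item() - counts.min().item())  # 0 for these sizes
```

The trailing sort mirrors the patched helper, which returns each token's selected experts sorted; the load check above is the same property the patch's `test_get_balanced_selection` verifies across rank and expert counts.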