[None][feat] Layer-wise benchmarks: support TEP balance, polish slurm scripts (#10237)

Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
This commit is contained in:
Tailing Yuan 2026-01-05 11:23:04 +08:00 committed by GitHub
parent aaf80be0f3
commit a7fe043b13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 473 additions and 181 deletions

1
.gitignore vendored
View File

@ -77,6 +77,7 @@ cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmha_v2_cu/
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.cpp cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.cpp
.devcontainer/.env .devcontainer/.env
/examples/layer_wise_benchmarks/autotuner_cache/
/examples/layer_wise_benchmarks/profiles/ /examples/layer_wise_benchmarks/profiles/
# User config files # User config files

View File

@ -15,6 +15,9 @@ pip install -e ../..
**Step 3:** In the container, run benchmarks and generate profiles: **Step 3:** In the container, run benchmarks and generate profiles:
```bash ```bash
# Set the autotuner cache path
export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache
# Run DeepSeek-R1 NVFP4 # Run DeepSeek-R1 NVFP4
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml
@ -24,16 +27,16 @@ NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
# Run DeepSeek-V3.2-Exp with 32k context length # Run DeepSeek-V3.2-Exp with 32k context length
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769 NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --seq-len-kv-cache 32769 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --seq-len-kv-cache 32769
# Run with attention TP # Run with attention TP
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --no-enable-attention-dp NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --no-enable-attention-dp
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --no-enable-attention-dp NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --no-enable-attention-dp
# Run with attention TP and TRTLLMGen # Run with attention TP and TRTLLMGen
NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM
NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM
# Run with MTP3 # Run with MTP3
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --batch-size 32 --seq-len-q 4 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --batch-size 32 --seq-len-q 4
@ -48,7 +51,7 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --moe-backend WID
# Scale TEP=16 to 4 GPUs: reduce the number of attention heads and experts # Scale TEP=16 to 4 GPUs: reduce the number of attention heads and experts
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp
# Run Qwen3-Next (balanced routing is not implemented) # Run Qwen3-Next
NP=2 ./mpi_launch.sh ./run.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 4 NP=2 ./mpi_launch.sh ./run.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 4
NP=2 ./mpi_launch.sh ./run.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 512 NP=2 ./mpi_launch.sh ./run.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 512
@ -69,13 +72,13 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --balance-method ImbalancedExperts
### Run with Slurm ### Run with Slurm
> Tips: > Tips:
> 1. If you have a running Slurm job, please skip step 1 and go straight to step 2 and 3. > 1. If you have a running Slurm job, you can set the `SLURM_JOB_ID` environment variable (`export SLURM_JOB_ID=aaa`) and skip step 1.
> 2. Further, if you have installed `tensorrt_llm` in the Slurm job, you can also skip step 2 and run step 3 with `export CONTAINER_NAME=aaa` specified. If you don't know the container name, run `export CONTAINER_NAME=$(SLURM_JOB_ID=$SLURM_JOB_ID ./slurm_query_container_name.sh)` to get it. > 2. Further, if you have installed `tensorrt_llm` in the Slurm job, you can also skip step 2. Just run step 3 with `export CONTAINER_NAME=aaa` specified. If you don't know the container name, run `export CONTAINER_NAME=$(./slurm_query_container_name.sh)` to get it.
**Step 1:** On the controller node, allocate one or multiple nodes, and record the `SLURM_JOB_ID`: **Step 1:** On the controller node, allocate one or multiple nodes, and export the `SLURM_JOB_ID`:
```bash ```bash
SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh) export SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh)
``` ```
Please fill in the variables in `./slurm_alloc.sh`. Please fill in the variables in `./slurm_alloc.sh`.
@ -83,28 +86,47 @@ Please fill the variables in `./slurm_alloc.sh`.
**Step 2:** Start a container and install `tensorrt_llm`. Run the following command on the controller node: **Step 2:** Start a container and install `tensorrt_llm`. Run the following command on the controller node:
```bash ```bash
SLURM_JOB_ID=$SLURM_JOB_ID ./slurm_init_containers.sh ./slurm_init_containers.sh
``` ```
It uses the image recorded in `../../jenkins/current_image_tags.properties`. The image will be downloaded to `../../enroot/` only once. It uses the image recorded in `../../jenkins/current_image_tags.properties`. The image will be downloaded to `../../enroot/` only once.
> Tip: If you want to change the image, there is no need to reallocate the Slurm job. Just start another container by running step 2 with `export CONTAINER_NAME=aaa`, and step 3 will run in the container specified by the `CONTAINER_NAME` environment variable. > Tip: If you want to change the image, there is no need to reallocate the Slurm job. Just start another container by running step 2 with `export CONTAINER_NAME=aaa`, and step 3 will run in the container specified by the `CONTAINER_NAME` environment variable.
**(Optional) Get an interactive shell**
```bash
NODES=1 NP=1 ./slurm_launch.sh --overlap --pty middleware/exclude_slurm_envs bash
```
The `--overlap` option allows this shell to share the node with other jobs. The `middleware/exclude_slurm_envs` wrapper clears inherited Slurm and MPI environment variables so that nested MPI processes can be spawned from within the Slurm job.
You may compile C++ extensions in the interactive shell:
```bash
cd ../..
export CCACHE_DIR=$(realpath cpp/.ccache)
python3 scripts/build_wheel.py --cuda_architectures native --no-venv --skip_building_wheel -G Ninja --use_ccache --clean
```
**Step 3:** Run benchmarks to generate profiles. Run the following command on the controller node, where `NODES` &le; the number of allocated nodes: **Step 3:** Run benchmarks to generate profiles. Run the following command on the controller node, where `NODES` &le; the number of allocated nodes:
```bash ```bash
# Set the autotuner cache path
export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache
# Run DeepSeek-R1 NVFP4 with wide ep: uses MNNVL A2A if applicable # Run DeepSeek-R1 NVFP4 with wide ep: uses MNNVL A2A if applicable
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP NODES=4 NP=16 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
# Run with TRTLLMGen # Run with TRTLLMGen
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend TRTLLM NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend TRTLLM
# Run with DeepEPLowLatency # Run with DeepEPLowLatency
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
# You can run 4-GPU and 8-GPU tasks without reallocate the slurm job # You can run 4-GPU and 8-GPU tasks without reallocating the slurm job
SLURM_JOB_ID=$SLURM_JOB_ID NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml
SLURM_JOB_ID=$SLURM_JOB_ID NODES=2 NP=8 ./slurm_launch.sh ./run.sh config_gtx.yaml NODES=2 NP=8 ./slurm_launch.sh ./run.sh config_gen.yaml
``` ```
### Batched run ### Batched run
@ -131,9 +153,11 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --moe-backend WID
Run the following command in the container: Run the following command in the container:
```bash ```bash
# Parse the profile in the default directory
python3 parse.py --world-size 4 python3 parse.py --world-size 4
# Specify the location of the .nsys-rep file # Specify the path of the .nsys-rep file
python3 parse.py --file-path profiles/report_np4_rank0.nsys-rep
python3 parse.py --profile-dir ./profiles --world-size 4 --rank 0 python3 parse.py --profile-dir ./profiles --world-size 4 --rank 0
# Parse a specific module. The module must appear exactly once in each run. # Parse a specific module. The module must appear exactly once in each run.
@ -145,6 +169,15 @@ You will receive three reports, each containing kernel timing statistics grouped
2. A CSV report at `profiles/report_np4_rank0.csv` 2. A CSV report at `profiles/report_np4_rank0.csv`
3. An HTML report at `profiles/report_np4_rank0.html` 3. An HTML report at `profiles/report_np4_rank0.html`
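The three report paths above are all derived from the `.nsys-rep` file name, which is why `--file-path` alone is enough. A minimal sketch of that derivation (the helper name `derive_report_paths` is illustrative, not part of `parse.py`):
```python
from pathlib import Path

def derive_report_paths(nsys_rep_file_path: str):
    """Derive the .sqlite/.csv/.html output paths next to a .nsys-rep file."""
    rep = Path(nsys_rep_file_path)
    if not rep.name.endswith(".nsys-rep"):
        raise ValueError("Expect a .nsys-rep file")
    stem = rep.name[: -len(".nsys-rep")]
    return {
        "sqlite": rep.parent / f"{stem}.sqlite",
        "csv": rep.parent / f"{stem}.csv",
        "html": rep.parent / f"{stem}.html",
    }

print(derive_report_paths("profiles/report_np4_rank0.nsys-rep"))
```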
## Developer utilities
1. Reduce startup time when debugging a model
   1. Disable the autotuner: add the `--no-enable-autotuner` option
   2. Disable the nsys profile: set the `PROFILE=0` environment variable
2. Capture more information
   1. Enable GPU metrics: set the `GPU_METRICS=1` environment variable
   2. Enable backtraces: set the `BACKTRACE=1` environment variable
## Troubleshooting ## Troubleshooting
1. Error `fp8 blockscale gemm only support Hopper` on Blackwell. 1. Error `fp8 blockscale gemm only support Hopper` on Blackwell.

View File

@ -0,0 +1,8 @@
#!/bin/bash
set -euo pipefail
# Clear slurm envs
unset $(env | awk -F'=' '{print $1}' | (grep -E "SLURM_|SLURMD_|slurm_|MPI_|PMIX_" || true))
"$@"

View File

@ -3,8 +3,12 @@
set -euo pipefail set -euo pipefail
# Clear slurm envs # Clear slurm envs
unset $(env | grep -i slurm | awk -F'=' '{print $1}') unset $(env | awk -F'=' '{print $1}' | (grep -E "SLURM_|SLURMD_|slurm_|MPI_|PMIX_" || true))
unset $(env | grep MPI | awk -F'=' '{print $1}')
extra_args=
if [ -v TLLM_AUTOTUNER_CACHE_PATH ]; then
extra_args+="-x TLLM_AUTOTUNER_CACHE_PATH"
fi
set -x set -x
mpirun --allow-run-as-root --np ${NP} "$@" mpirun --allow-run-as-root --np ${NP} $extra_args "$@"

View File

@ -5,6 +5,7 @@ import json
import re import re
import sqlite3 import sqlite3
import subprocess import subprocess
import sys
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
@ -14,8 +15,9 @@ import pandas as pd
# Parse cmdline # Parse cmdline
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--file-path", type=str)
parser.add_argument("--profile-dir", type=str, default="profiles") parser.add_argument("--profile-dir", type=str, default="profiles")
parser.add_argument("--world-size", "--np", type=int, required=True) parser.add_argument("--world-size", "--np", type=int)
parser.add_argument("--rank", type=int, default=0) parser.add_argument("--rank", type=int, default=0)
parser.add_argument("--warmup-times", type=int) parser.add_argument("--warmup-times", type=int)
parser.add_argument("--module", type=str) parser.add_argument("--module", type=str)
@ -27,6 +29,8 @@ group.add_argument(
) )
parser.set_defaults(error_on_unknown_kernel=False) parser.set_defaults(error_on_unknown_kernel=False)
args = parser.parse_args() args = parser.parse_args()
if (args.file_path is None) == (args.world_size is None):
parser.error("Please specify exactly one of --file-path and --world-size.")
print(args) print(args)
@ -89,11 +93,20 @@ def shortest_common_supersequence(a, b):
return res return res
profile_dir = Path(args.profile_dir) if args.file_path is not None:
nsys_rep_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.nsys-rep" nsys_rep_file_path = Path(args.file_path)
sqlite_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.sqlite" if not nsys_rep_file_path.name.endswith(".nsys-rep"):
csv_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.csv" raise ValueError("Expect a .nsys-rep file")
html_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.html" else:
profile_dir = Path(args.profile_dir)
nsys_rep_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.nsys-rep"
sqlite_file_path = nsys_rep_file_path.parent / (
nsys_rep_file_path.name[: -len(".nsys-rep")] + ".sqlite"
)
csv_file_path = nsys_rep_file_path.parent / (nsys_rep_file_path.name[: -len(".nsys-rep")] + ".csv")
html_file_path = nsys_rep_file_path.parent / (
nsys_rep_file_path.name[: -len(".nsys-rep")] + ".html"
)
lazy_convert_sqlite(nsys_rep_file_path, sqlite_file_path) lazy_convert_sqlite(nsys_rep_file_path, sqlite_file_path)
conn = sqlite3.connect(f"file:{sqlite_file_path}?mode=ro", uri=True) conn = sqlite3.connect(f"file:{sqlite_file_path}?mode=ro", uri=True)
@ -196,7 +209,8 @@ if "CUDA_GRAPH_NODE_EVENTS" in tables:
R.start AS runtime_start, R.end AS runtime_end, R.start AS runtime_start, R.end AS runtime_end,
CGE2.start AS capture_start, CGE2.end AS capture_end CGE2.start AS capture_start, CGE2.end AS capture_end
FROM ({unified_subquery}) AS unified FROM ({unified_subquery}) AS unified
JOIN CUPTI_ACTIVITY_KIND_RUNTIME AS R ON unified.correlationId = R.correlationId JOIN CUPTI_ACTIVITY_KIND_RUNTIME AS R ON unified.graphNodeId IS NOT NULL AND
unified.correlationId = R.correlationId
LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE1 ON unified.graphNodeId = CGE1.graphNodeId AND LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE1 ON unified.graphNodeId = CGE1.graphNodeId AND
CGE1.originalGraphNodeId IS NOT NULL CGE1.originalGraphNodeId IS NOT NULL
LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE2 ON CGE1.originalGraphNodeId = CGE2.graphNodeId""" LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE2 ON CGE1.originalGraphNodeId = CGE2.graphNodeId"""
@ -318,6 +332,7 @@ parser_keywords = [
("routingInitExpertCounts", "routingInitExpertCounts"), ("routingInitExpertCounts", "routingInitExpertCounts"),
("routingIndicesCluster", "routingIndicesClusterKernel"), ("routingIndicesCluster", "routingIndicesClusterKernel"),
("routingIndicesCoop", "routingIndicesCoopKernel"), ("routingIndicesCoop", "routingIndicesCoopKernel"),
("router_gemm", "router_gemm_kernel"),
("bmm_4_44_32", "bmm_E2m1_E2m1E2m1_Fp32_t"), ("bmm_4_44_32", "bmm_E2m1_E2m1E2m1_Fp32_t"),
("finalize", "finalize::finalizeKernel"), ("finalize", "finalize::finalizeKernel"),
("bmm_16_44_32", "bmm_Bfloat16_E2m1E2m1_Fp32_"), ("bmm_16_44_32", "bmm_Bfloat16_E2m1E2m1_Fp32_"),
@ -381,7 +396,7 @@ def parse_kernel_name(demangledName):
if all(keyword in name for keyword in src): if all(keyword in name for keyword in src):
return dst return dst
if name not in warned_names: if name not in warned_names:
print(f"Unknown kernel name: {name}") print(f"Unknown kernel name: {name}", file=sys.stderr)
warned_names.add(name) warned_names.add(name)
if args.error_on_unknown_kernel: if args.error_on_unknown_kernel:
raise NotImplementedError(f"Unknown kernel name: {name}") raise NotImplementedError(f"Unknown kernel name: {name}")
@ -491,11 +506,33 @@ with csv_file_path.open("w", newline="") as f:
for row in csv_data: for row in csv_data:
csv_writer.writerow(row) csv_writer.writerow(row)
js_header_config = [{"name": problem["text"]} for problem in problem_set] js_header_config = [{"name": problem["text"]} for problem in problem_set]
js_header_config = []
for problem in problem_set:
innermost_children = js_header_config
for k, msg_prefix in [
("batch_size", "b="),
("seq_len_q", "q="),
("seq_len_kv_cache", "past="),
]:
if len(run_args[k + "_list"]) > 1:
if len(innermost_children) == 0 or problem["spec"][k] != innermost_children[-1][k]:
innermost_children.append(
{
"name": msg_prefix + str(problem["spec"][k]),
"children": [],
k: problem["spec"][k],
}
)
innermost_children = innermost_children[-1]["children"]
innermost_children.append({"name": problem["text"]})
loader = jinja2.FileSystemLoader(Path(__file__).parent) loader = jinja2.FileSystemLoader(Path(__file__).parent)
template = jinja2.Environment(loader=loader).get_template("template.html") template = jinja2.Environment(loader=loader).get_template("template.html")
with html_file_path.open("w") as f: with html_file_path.open("w") as f:
configText = ( configText = (
"Run:\n" + json.dumps(run_args, indent=4) + "\n\nParse:\n" + json.dumps(args.__dict__) "Run:\n"
+ json.dumps(run_args, indent=4)
+ "\n\nParse:\n"
+ json.dumps(args.__dict__, indent=4)
) )
f.write(template.render(headerConfig=js_header_config, rawData=js_data, configText=configText)) f.write(template.render(headerConfig=js_header_config, rawData=js_data, configText=configText))
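The nested `js_header_config` built above groups report columns by batch size, query length, and past KV length, adding a level only when that dimension actually varies across problems. A self-contained sketch of the same grouping idea, with made-up problem data and key names:
```python
import json

# Illustrative input; parse.py builds the same shape from problem_set and run_args.
problems = [
    {"text": "b=1 q=1024 past=0", "spec": {"batch_size": 1, "seq_len_q": 1024}},
    {"text": "b=1 q=2048 past=0", "spec": {"batch_size": 1, "seq_len_q": 2048}},
    {"text": "b=4 q=1024 past=0", "spec": {"batch_size": 4, "seq_len_q": 1024}},
]
# Only dimensions that take more than one value get their own header level.
varying = {"batch_size": "b=", "seq_len_q": "q="}

header_config = []
for problem in problems:
    children = header_config
    for key, prefix in varying.items():
        value = problem["spec"][key]
        # Start a new group when this dimension's value changes.
        if not children or children[-1].get(key) != value:
            children.append({"name": prefix + str(value), "children": [], key: value})
        children = children[-1]["children"]
    children.append({"name": problem["text"]})

print(json.dumps(header_config, indent=2))
```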

View File

@ -2,6 +2,7 @@ import argparse
import itertools import itertools
import json import json
import os import os
from unittest import mock
import numpy as np import numpy as np
import nvtx import nvtx
@ -9,6 +10,8 @@ import torch
import yaml import yaml
from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
from tensorrt_llm._torch.modules.fused_moe.interface import AlltoallMethodType
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.logger import logger from tensorrt_llm.logger import logger
@ -108,17 +111,6 @@ if args.enable_attention_dp is None:
args.enable_attention_dp = False args.enable_attention_dp = False
if args.max_num_tokens is None: if args.max_num_tokens is None:
args.max_num_tokens = args.max_batch_size * max(args.seq_len_q_list) args.max_num_tokens = args.max_batch_size * max(args.seq_len_q_list)
if args.run_type == "GEN":
ctx_batch_size = max(1, max(20480, args.max_num_tokens) // max(args.seq_len_kv_cache_list))
args.max_num_tokens = max(
args.max_num_tokens, ctx_batch_size * max(args.seq_len_kv_cache_list)
)
else:
if args.run_type == "GEN":
ctx_batch_size = max(1, args.max_num_tokens // max(args.seq_len_kv_cache_list))
assert args.max_num_tokens >= ctx_batch_size * max(args.seq_len_kv_cache_list), (
"Max_num_tokens is too small to prefill KV cache"
)
if args.use_low_precision_moe_combine is None: if args.use_low_precision_moe_combine is None:
args.use_low_precision_moe_combine = False args.use_low_precision_moe_combine = False
if args.enable_autotuner is None: if args.enable_autotuner is None:
@ -169,17 +161,75 @@ runner = Runner(
) )
logger.info("Layer-wise benchmarks: Create runner ... Done") logger.info("Layer-wise benchmarks: Create runner ... Done")
# Autotune
run_pack = runner.create_run_pack(
args.run_type,
batch_size=max(args.batch_size_list),
request_id_begin=0,
seq_len_q=max(args.seq_len_q_list),
seq_len_kv_cache=args.seq_len_kv_cache_list[0],
kv_cache_manager=kv_cache_manager,
attn_workspace=attn_workspace,
)
if args.enable_autotuner:
cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
with autotune(cache_path=cache_path):
run_pack()
else:
run_pack()
# Prefill KV cache
if args.run_type == "GEN":
logger.info("Layer-wise benchmarks: Create runner for prefill")
ctx_seq_len_q = max(args.seq_len_kv_cache_list)
ctx_batch_size = min(
args.max_batch_size,
max(1, 20480 // ctx_seq_len_q),
)
ctx_attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
with mock.patch.object(
CutlassFusedMoE, "select_alltoall_method_type", return_value=AlltoallMethodType.NotEnabled
):
ctx_runner = Runner(
args.model,
mapping,
moe_backend="CUTLASS",
layer_indices=args.layer_indices,
scaled_from=args.scaled_from,
max_seq_len=args.max_seq_len,
max_num_tokens=ctx_batch_size * ctx_seq_len_q,
moe_max_num_tokens=16384,
use_low_precision_moe_combine=args.use_low_precision_moe_combine,
use_cuda_graph=False,
)
logger.info("Layer-wise benchmarks: Create runner for prefill ... Done")
logger.info("Layer-wise benchmarks: Prefill KV cache")
assert ctx_batch_size <= args.max_batch_size
assert ctx_seq_len_q + 0 <= args.max_seq_len
num_requests = max(args.batch_size_list)
for request_id_begin in range(0, num_requests, ctx_batch_size):
run_pack = ctx_runner.create_run_pack(
"CTX",
batch_size=min(ctx_batch_size, num_requests - request_id_begin),
request_id_begin=request_id_begin,
seq_len_q=ctx_seq_len_q,
seq_len_kv_cache=0,
kv_cache_manager=kv_cache_manager,
attn_workspace=ctx_attn_workspace,
)
with ctx_runner.replace_routing_method_ctx(
balance_method=BalanceMethod.Balanced, balance_ratio=None
):
run_pack(check=True)
del ctx_runner
del ctx_attn_workspace
logger.info("Layer-wise benchmarks: Prefill KV cache ... Done")
# Warm up # Warm up
for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [ logger.info("Layer-wise benchmarks: Warmup")
[ for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
True,
max(args.batch_size_list),
max(args.seq_len_q_list),
args.seq_len_kv_cache_list[0],
args.balance_ratio_list[0],
],
*itertools.product( *itertools.product(
[False],
args.batch_size_list, args.batch_size_list,
args.seq_len_q_list, args.seq_len_q_list,
args.seq_len_kv_cache_list, args.seq_len_kv_cache_list,
@ -203,34 +253,10 @@ for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
): ):
capture_stream.wait_stream(torch.cuda.current_stream()) capture_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(capture_stream): with torch.cuda.stream(capture_stream):
if autotune_flag: run_pack(check=True)
if args.enable_autotuner:
cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
with autotune(cache_path=cache_path):
run_pack()
if args.run_type == "GEN":
logger.info("Layer-wise benchmarks: Prefill KV cache")
ctx_seq_len_q = max(args.seq_len_kv_cache_list)
assert ctx_batch_size <= args.max_batch_size
assert ctx_seq_len_q + 0 <= args.max_seq_len
assert ctx_batch_size * ctx_seq_len_q <= args.max_num_tokens
max_batch_size = max(args.batch_size_list)
for request_id_begin in range(0, max_batch_size, ctx_batch_size):
ctx_run_pack = runner.create_run_pack(
"CTX",
batch_size=min(ctx_batch_size, max_batch_size - request_id_begin),
request_id_begin=request_id_begin,
seq_len_q=ctx_seq_len_q,
seq_len_kv_cache=0,
kv_cache_manager=kv_cache_manager,
attn_workspace=attn_workspace,
)
ctx_run_pack(check=True)
logger.info("Layer-wise benchmarks: Prefill KV cache ... Done")
else:
run_pack(check=True)
torch.cuda.current_stream().wait_stream(capture_stream) torch.cuda.current_stream().wait_stream(capture_stream)
torch.cuda.synchronize() torch.cuda.synchronize()
logger.info("Layer-wise benchmarks: Warmup ... Done")
events = [ events = [
torch.cuda.Event(enable_timing=True) for _ in range(args.warmup_times + args.run_times + 1) torch.cuda.Event(enable_timing=True) for _ in range(args.warmup_times + args.run_times + 1)
@ -270,7 +296,7 @@ for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in itertools.product(
with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"): with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"):
run_pack() run_pack()
balance_ratio_str = "" if balance_ratio is None else f" balance={balance_ratio:.2g}" balance_ratio_str = "" if balance_ratio is None else f" balance={balance_ratio:.2g}"
nvtx_message = f"b={batch_size} s={seq_len_q} past={seq_len_kv_cache}{balance_ratio_str} NP{world_size}" nvtx_message = f"b={batch_size} s={seq_len_q} past={seq_len_kv_cache}{balance_ratio_str} NP{world_size}"
for i in range(args.warmup_times + args.run_times): for i in range(args.warmup_times + args.run_times):
events[i].record() events[i].record()
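The prefill path above builds a separate CUTLASS runner inside `mock.patch.object(...)` so that alltoall method selection is forced to `NotEnabled` only while that runner is being constructed. The toy sketch below shows the general pattern with stand-in classes; it does not use the real TensorRT-LLM modules, which need GPUs:
```python
from enum import Enum
from unittest import mock

class AlltoallMethodType(Enum):
    NotEnabled = 0
    MNNVL = 1

class ToyFusedMoE:
    def select_alltoall_method_type(self):
        return AlltoallMethodType.MNNVL  # normal selection

# Inside the `with` block, every instance sees the overridden behavior.
with mock.patch.object(
    ToyFusedMoE, "select_alltoall_method_type", return_value=AlltoallMethodType.NotEnabled
):
    assert ToyFusedMoE().select_alltoall_method_type() == AlltoallMethodType.NotEnabled

# The original method is restored automatically when the block exits.
assert ToyFusedMoE().select_alltoall_method_type() == AlltoallMethodType.MNNVL
```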

View File

@ -40,7 +40,9 @@ else
PROFILE_CMD= PROFILE_CMD=
fi fi
SCRIPT_PATH=$(realpath --relative-to="$(pwd)" "$(dirname -- "$0")"/run.py)
set -x set -x
$PROFILE_CMD bash -c \ $PROFILE_CMD bash -o pipefail -c \
"python3 -u run.py \"\$@\" 2>&1 | tee \"$PROFILE_DIR/report_np${WORLD_SIZE}_rank${RANK}.log\"" \ "python3 -u \"\$1\" \"\${@:3}\" 2>&1 | tee \"\$2/report_np${WORLD_SIZE}_rank${RANK}.log\"" \
bash "$@" bash "$SCRIPT_PATH" "$PROFILE_DIR" "$@"

View File

@ -4,23 +4,23 @@ set -euo pipefail
# CONTAINER_IMAGE= # CONTAINER_IMAGE=
CONTAINER_NAME=${CONTAINER_NAME:-layer_wise_benchmarks} CONTAINER_NAME=${CONTAINER_NAME:-layer_wise_benchmarks}
CONTAINER_MOUNTS=$(realpath "$(pwd)/../.."):$(realpath "$(pwd)/../..") TRTLLM_ROOT=$(realpath "$(dirname -- "$0")"/../..)
CONTAINER_MOUNTS=$TRTLLM_ROOT:$TRTLLM_ROOT
if [ "${SLURM_JOB_ID:-}" == "" ]; then if [ -z "${SLURM_JOB_ID:-}" ]; then
echo "Please set SLURM_JOB_ID" echo "Please set SLURM_JOB_ID"
exit 1 exit 1
fi fi
NODES=$(squeue -j $SLURM_JOB_ID -h -o "%D") NODES=$(squeue -j $SLURM_JOB_ID -h -o "%D")
if [ "${CONTAINER_IMAGE:-}" == "" ]; then if [ -z "${CONTAINER_IMAGE:-}" ]; then
# Read Docker image from current_image_tags.properties # Read Docker image from current_image_tags.properties
source ../../jenkins/current_image_tags.properties
MACHINE="$(uname -m)" MACHINE="$(uname -m)"
if [ "$MACHINE" == "x86_64" ]; then if [ "$MACHINE" == "x86_64" ]; then
DOCKER_IMAGE=$LLM_DOCKER_IMAGE DOCKER_IMAGE=$(source "$TRTLLM_ROOT/jenkins/current_image_tags.properties" && echo $LLM_DOCKER_IMAGE)
elif [ "$MACHINE" == "aarch64" ]; then elif [ "$MACHINE" == "aarch64" ]; then
DOCKER_IMAGE=$LLM_SBSA_DOCKER_IMAGE DOCKER_IMAGE=$(source "$TRTLLM_ROOT/jenkins/current_image_tags.properties" && echo $LLM_SBSA_DOCKER_IMAGE)
else else
echo "Unsupported machine hardware name \"$MACHINE\"" echo "Unsupported machine hardware name \"$MACHINE\""
exit 1 exit 1
@ -34,7 +34,7 @@ if [ "${CONTAINER_IMAGE:-}" == "" ]; then
SQSH_FILE_NAME=$(echo "$DOCKER_IMAGE" | SQSH_FILE_NAME=$(echo "$DOCKER_IMAGE" |
awk -F'#' '{print $2}' | awk -F'#' '{print $2}' |
awk -F':' '{gsub(/\//,"+",$1); print $1"+"$2".sqsh"}') awk -F':' '{gsub(/\//,"+",$1); print $1"+"$2".sqsh"}')
CONTAINER_IMAGE="../../enroot/$SQSH_FILE_NAME" CONTAINER_IMAGE="$TRTLLM_ROOT/enroot/$SQSH_FILE_NAME"
if [ ! -f "$CONTAINER_IMAGE" ]; then if [ ! -f "$CONTAINER_IMAGE" ]; then
echo "Container image file $CONTAINER_IMAGE does not exist, importing ..." echo "Container image file $CONTAINER_IMAGE does not exist, importing ..."
srun -N 1 --pty enroot import -o "$CONTAINER_IMAGE" "docker://$DOCKER_IMAGE" srun -N 1 --pty enroot import -o "$CONTAINER_IMAGE" "docker://$DOCKER_IMAGE"
@ -50,6 +50,7 @@ srun -N "$NODES" \
--container-name "$CONTAINER_NAME" \ --container-name "$CONTAINER_NAME" \
--container-mounts "$CONTAINER_MOUNTS" \ --container-mounts "$CONTAINER_MOUNTS" \
--container-workdir "$WORKDIR" \ --container-workdir "$WORKDIR" \
bash -c "pip install -U packaging && bash -c "cd \"\$1\" &&
pip install -r ../../requirements.txt --no-build-isolation && pip install -U packaging &&
pip install -e ../.." pip install -r requirements.txt --no-build-isolation &&
pip install -e ." bash "$TRTLLM_ROOT"

View File

@ -3,9 +3,10 @@
set -euo pipefail set -euo pipefail
CONTAINER_NAME=${CONTAINER_NAME:-layer_wise_benchmarks} CONTAINER_NAME=${CONTAINER_NAME:-layer_wise_benchmarks}
CONTAINER_MOUNTS=$(realpath "$(pwd)/../.."):$(realpath "$(pwd)/../..") TRTLLM_ROOT=$(realpath "$(dirname -- "$0")"/../..)
CONTAINER_MOUNTS=$TRTLLM_ROOT:$TRTLLM_ROOT
if [ "${SLURM_JOB_ID:-}" == "" ]; then if [ -z "${SLURM_JOB_ID:-}" ]; then
echo "Please set SLURM_JOB_ID" echo "Please set SLURM_JOB_ID"
exit 1 exit 1
fi fi

View File

@ -2,6 +2,11 @@
set -euo pipefail set -euo pipefail
if [ -z "${SLURM_JOB_ID:-}" ]; then
echo "Please set SLURM_JOB_ID"
exit 1
fi
prefix="pyxis_${SLURM_JOB_ID}_" prefix="pyxis_${SLURM_JOB_ID}_"
matches=$(printf "%s\n" "$(srun -N 1 enroot list)" | grep "^${prefix}" || true) matches=$(printf "%s\n" "$(srun -N 1 enroot list)" | grep "^${prefix}" || true)
count=$(printf "%s\n" "$matches" | wc -l) count=$(printf "%s\n" "$matches" | wc -l)

View File

@ -313,6 +313,7 @@
const columnLabels = []; const columnLabels = [];
let columnTotals = []; let columnTotals = [];
let totalNode = null; // Special node for the "Total" row let totalNode = null; // Special node for the "Total" row
let numCols = null;
const fmt = (num) => num.toFixed(1); const fmt = (num) => num.toFixed(1);
@ -332,7 +333,7 @@
const topLevelTimes = rawData.map(aggregate); const topLevelTimes = rawData.map(aggregate);
// Calculate Totals // Calculate Totals
const numCols = topLevelTimes[0].length; numCols = topLevelTimes[0].length;
columnTotals = new Array(numCols).fill(0); columnTotals = new Array(numCols).fill(0);
for (let c = 0; c < numCols; c++) { for (let c = 0; c < numCols; c++) {
columnTotals[c] = topLevelTimes.reduce((sum, row) => sum + row[c], 0); columnTotals[c] = topLevelTimes.reduce((sum, row) => sum + row[c], 0);
@ -420,7 +421,8 @@
const total = columnTotals[colIdx] || 1; const total = columnTotals[colIdx] || 1;
const pct = (val / total) * 100; const pct = (val / total) * 100;
html += `<td class="data-cell col-${colIdx}" onclick="handleCellClick('${uniqueId}', ${colIdx})"> html += `<td class="data-cell col-${colIdx}" ${val ? "title=\"" + fmt(val) + "\"" : ""}
onclick="handleCellClick('${uniqueId}', ${colIdx})">
<div class="progress-bar" style="width: ${pct}%;"></div> <div class="progress-bar" style="width: ${pct}%;"></div>
<div class="cell-content">${val ? fmt(val) : ""}</div> <div class="cell-content">${val ? fmt(val) : ""}</div>
</td>`; </td>`;
@ -717,6 +719,7 @@
}); });
window.onload = () => { window.onload = () => {
document.getElementById('kernelTable').style["min-width"] = Math.max(1000, (220 + 60 * numCols)) + "px";
renderHeader(); renderHeader();
document.getElementById('tableBody').innerHTML = renderTableBody(rawData, 0); document.getElementById('tableBody').innerHTML = renderTableBody(rawData, 0);
renderTotalRow(); renderTotalRow();
@ -741,7 +744,6 @@
for (let node of rawData) { for (let node of rawData) {
maxDepth = Math.max(maxDepth, getDepth(node)); maxDepth = Math.max(maxDepth, getDepth(node));
} }
console.log(maxDepth);
const container = document.getElementById('level-buttons'); const container = document.getElementById('level-buttons');
if (container) { if (container) {

View File

@ -1,6 +1,6 @@
import nvtx import nvtx
from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate, Deepseekv3MoE
from tensorrt_llm._torch.models.modeling_qwen3_next import ( from tensorrt_llm._torch.models.modeling_qwen3_next import (
Qwen3NextGatedDeltaNet, Qwen3NextGatedDeltaNet,
Qwen3NextSparseMoeBlock, Qwen3NextSparseMoeBlock,
@ -12,6 +12,7 @@ from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
def mark_ranges(): def mark_ranges():
DeepseekV3Gate.forward = nvtx.annotate("DeepseekV3Gate")(DeepseekV3Gate.forward) DeepseekV3Gate.forward = nvtx.annotate("DeepseekV3Gate")(DeepseekV3Gate.forward)
Deepseekv3MoE.forward = nvtx.annotate("Deepseekv3MoE")(Deepseekv3MoE.forward)
Qwen3NextGatedDeltaNet.forward = nvtx.annotate("Qwen3NextGatedDeltaNet")( Qwen3NextGatedDeltaNet.forward = nvtx.annotate("Qwen3NextGatedDeltaNet")(
Qwen3NextGatedDeltaNet.forward Qwen3NextGatedDeltaNet.forward
) )

View File

@ -1,5 +1,6 @@
import contextlib import contextlib
import functools import functools
import itertools
import os import os
import unittest.mock import unittest.mock
import weakref import weakref
@ -23,6 +24,7 @@ from tensorrt_llm._torch.utils import get_model_extra_attrs, model_extra_attrs
from tensorrt_llm._utils import local_mpi_size, mpi_rank, mpi_world_size, torch_dtype_to_binding from tensorrt_llm._utils import local_mpi_size, mpi_rank, mpi_world_size, torch_dtype_to_binding
from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.bindings.internal.batch_manager import CacheType from tensorrt_llm.bindings.internal.batch_manager import CacheType
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.models.modeling_utils import QuantConfig
@ -38,24 +40,16 @@ def round_up(a, b):
def get_balanced_selection_no_cache( def get_balanced_selection_no_cache(
num_tokens, top_k, num_experts, dtype, device, world_size, rank num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
): ):
# First, each sender selects target rank token_id = torch.arange(dp_rank * num_tokens * top_k, (dp_rank + 1) * num_tokens * top_k).view(
target_rank_before_mod = torch.arange(num_tokens * world_size * top_k).view( num_tokens, top_k
num_tokens, world_size, top_k
) )
target_rank_before_mod += top_k * torch.arange(num_tokens).view( experts_per_rank = num_experts // ep_size
num_tokens, 1, 1 token_selected_experts = (token_id % ep_size) * experts_per_rank + (
) # Shift `top_k` ranks for the next token on each rank, to balance network traffic token_id // ep_size
target_rank = target_rank_before_mod % world_size ) % experts_per_rank
# Second, each receiver selects target expert token_selected_experts = token_selected_experts.sort(dim=-1).values
target_expert = torch.empty_like(target_rank)
for reciever_rank in range(world_size):
mask = target_rank == reciever_rank
experts_per_rank = num_experts // world_size
local_expert = torch.arange(num_tokens * top_k) % experts_per_rank
target_expert[mask] = (reciever_rank * experts_per_rank) + local_expert
token_selected_experts = target_expert[:, rank].sort(dim=-1).values
return token_selected_experts.contiguous().to(dtype=dtype, device=device) return token_selected_experts.contiguous().to(dtype=dtype, device=device)
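The rewritten selection assigns each (token, slot) pair a global id, stripes consecutive ids across EP ranks first and across the experts within a rank second, then sorts per token. A standalone check of that formula with made-up sizes (CPU only, not importing the benchmark utils):
```python
import torch

def balanced_selection(num_tokens, top_k, num_experts, dp_rank, ep_size):
    # Global id per (token, slot) pair, offset by the DP rank.
    token_id = torch.arange(dp_rank * num_tokens * top_k,
                            (dp_rank + 1) * num_tokens * top_k).view(num_tokens, top_k)
    experts_per_rank = num_experts // ep_size
    # Stripe ids over EP ranks first, then over the experts within each rank.
    experts = (token_id % ep_size) * experts_per_rank + (token_id // ep_size) % experts_per_rank
    return experts.sort(dim=-1).values

num_tokens, top_k, num_experts, ep_size = 8, 4, 16, 4
counts = torch.zeros(num_experts, dtype=torch.long)
for dp_rank in range(ep_size):  # attention-DP case: one token group per rank
    sel = balanced_selection(num_tokens, top_k, num_experts, dp_rank, ep_size)
    counts += sel.view(-1).bincount(minlength=num_experts)
print(counts)  # every expert receives the same number of tokens (here 8)
```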
@ -64,47 +58,62 @@ get_balanced_selection = functools.cache(get_balanced_selection_no_cache)
def test_get_balanced_selection(): def test_get_balanced_selection():
dtype = torch.long dtype = torch.long
for num_tokens in range(1, 33): for num_tokens, num_experts, enable_attention_dp in itertools.product(
for num_experts in range(1, 65): range(1, 35), range(1, 35), [False, True]
print(f"{num_tokens=} {num_experts=}") ):
for top_k in range(1, min(11, num_experts)): print(f"{num_tokens=} {num_experts=} {enable_attention_dp=}")
for world_size in range(1, 65): for top_k in range(1, min(10, num_experts) + 1):
if num_experts % world_size == 0: for world_size in range(1, 35):
tokens_per_expert = torch.zeros(num_experts) dp_size = world_size if enable_attention_dp else 1
for rank in range(world_size): ep_size = world_size
token_selected_experts = get_balanced_selection_no_cache( if num_experts % ep_size == 0:
num_tokens, top_k, num_experts, dtype, "cpu", world_size, rank tokens_per_expert = torch.zeros(num_experts)
for dp_rank in range(dp_size):
token_selected_experts = get_balanced_selection_no_cache(
num_tokens, top_k, num_experts, dtype, "cpu", dp_size, dp_rank, ep_size
)
sorted_selection = token_selected_experts.sort(dim=-1).values
if (sorted_selection[:, :-1] == sorted_selection[:, 1:]).any():
raise ValueError(f"duplicated experts on rank {dp_rank}")
experts_per_rank = num_experts // ep_size
tokens_per_rank = (
(token_selected_experts // experts_per_rank)
.view(-1)
.bincount(minlength=ep_size)
)
if tokens_per_rank.max() - tokens_per_rank.min() > 1:
raise ValueError(f"tokens sent from rank {dp_rank} is not balanced")
unique_tokens_per_rank = (
(
torch.arange(ep_size).view(ep_size, 1, 1)
== token_selected_experts // experts_per_rank
) )
sorted_selection = token_selected_experts.sort(dim=-1).values .any(dim=2)
if (sorted_selection[:, :-1] == sorted_selection[:, 1:]).any(): .sum(dim=1)
raise ValueError(f"duplicated experts on rank {rank}") )
experts_per_rank = num_experts // world_size if unique_tokens_per_rank.max() - unique_tokens_per_rank.min() > 1:
tokens_per_rank = ( raise ValueError(
(token_selected_experts // experts_per_rank) f"tokens sent from rank {dp_rank} is not balanced after removing duplicates"
.view(-1)
.bincount(minlength=world_size)
) )
if tokens_per_rank.max() - tokens_per_rank.min() > 1: tokens_per_expert += token_selected_experts.view(-1).bincount(
raise ValueError(f"tokens sent from rank {rank} is not balanced") minlength=num_experts
tokens_per_expert += token_selected_experts.view(-1).bincount( )
minlength=num_experts if tokens_per_expert.max() - tokens_per_expert.min() > 1:
) raise ValueError("tokens per expert is not balanced")
if tokens_per_expert.max() - tokens_per_expert.min() > 1:
raise ValueError("tokens per expert is not balanced")
def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank): def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size):
num_tokens, top_k = imbalanced_experts.shape num_tokens, top_k = imbalanced_experts.shape
dtype = imbalanced_experts.dtype dtype = imbalanced_experts.dtype
device = imbalanced_experts.device device = imbalanced_experts.device
balanced_experts = get_balanced_selection_no_cache( balanced_experts = get_balanced_selection_no_cache(
num_tokens, top_k, num_experts, dtype, device, world_size, rank num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
) )
if balance_ratio == 0.0: if balance_ratio == 0.0:
num_balanced_tokens = 0 num_balanced_tokens = 0
else: else:
# Activate all experts # Activate all experts
min_num_balanced_tokens = min(num_tokens, ceil_div(num_experts, world_size * top_k)) min_num_balanced_tokens = min(num_tokens, ceil_div(num_experts, dp_size * top_k))
num_balanced_tokens = min_num_balanced_tokens + round( num_balanced_tokens = min_num_balanced_tokens + round(
(num_tokens - min_num_balanced_tokens) * balance_ratio (num_tokens - min_num_balanced_tokens) * balance_ratio
) )
@ -116,57 +125,78 @@ def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_si
@functools.cache @functools.cache
def get_all_to_one_selection( def get_all_to_one_selection(
num_tokens, top_k, num_experts, balance_ratio, dtype, device, world_size, rank num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size
): ):
experts_per_rank = num_experts // world_size experts_per_rank = num_experts // ep_size
if top_k > experts_per_rank: if top_k > experts_per_rank:
raise ValueError( raise ValueError(
"Cannot send all tokens to a single rank because `top_k > experts_per_rank`" "Cannot send all tokens to a single rank because `top_k > experts_per_rank`"
) )
imbalanced_experts = ( imbalanced_experts = (
torch.arange( torch.arange(
rank * num_tokens * top_k, (rank + 1) * num_tokens * top_k, dtype=dtype, device=device dp_rank * num_tokens * top_k,
(dp_rank + 1) * num_tokens * top_k,
dtype=dtype,
device=device,
).view(num_tokens, top_k) ).view(num_tokens, top_k)
% experts_per_rank % experts_per_rank
) )
imbalanced_experts = imbalanced_experts.sort(dim=-1).values imbalanced_experts = imbalanced_experts.sort(dim=-1).values
return apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank) return apply_balance_ratio(
imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size
)
@functools.cache @functools.cache
def get_balanced_rank_imbalanced_expert_selection( def get_balanced_rank_imbalanced_expert_selection(
num_tokens, top_k, num_experts, balance_ratio, dtype, device, world_size, rank num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size
): ):
experts_per_rank = num_experts // world_size experts_per_rank = num_experts // ep_size
active_experts_per_rank = ceil_div(top_k, world_size) active_experts_per_rank = ceil_div(top_k, ep_size)
# Select expert from [0, active_experts_per_rank * world_size), # Select expert from [0, active_experts_per_rank * ep_size),
# then scale to [0, experts_per_rank * world_size) # then scale to [0, experts_per_rank * ep_size)
narrow_experts = get_balanced_selection_no_cache( narrow_experts = get_balanced_selection_no_cache(
num_tokens, top_k, active_experts_per_rank * world_size, dtype, device, world_size, rank num_tokens,
top_k,
active_experts_per_rank * ep_size,
dtype,
device,
dp_size,
dp_rank,
ep_size,
) )
imbalanced_experts = ( imbalanced_experts = (
narrow_experts // active_experts_per_rank * experts_per_rank narrow_experts // active_experts_per_rank * experts_per_rank
+ narrow_experts % active_experts_per_rank + narrow_experts % active_experts_per_rank
) )
return apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank) return apply_balance_ratio(
imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size
)
def make_balanced_routing_method( def make_balanced_routing_method(
apply_method_orig, num_experts, balance_method, balance_ratio, world_size, rank moe_module,
apply_method_orig,
num_experts,
balance_method,
balance_ratio,
dp_size,
dp_rank,
ep_size,
): ):
def balanced_routing_method(router_logits): def balanced_routing_method(router_logits):
token_selected_experts, token_final_scales = apply_method_orig(router_logits) token_selected_experts, token_final_scales = apply_method_orig(router_logits)
if balance_method == BalanceMethod.NotModified: assert moe_module._routing_results_replaced_at in [None, "make_balanced_routing_method"]
pass if balance_method == BalanceMethod.Balanced:
elif balance_method == BalanceMethod.Balanced:
token_selected_experts = get_balanced_selection( token_selected_experts = get_balanced_selection(
token_selected_experts.shape[0], token_selected_experts.shape[0],
token_selected_experts.shape[1], token_selected_experts.shape[1],
num_experts, num_experts,
token_selected_experts.dtype, token_selected_experts.dtype,
token_selected_experts.device, token_selected_experts.device,
world_size, dp_size,
rank, dp_rank,
ep_size,
) )
elif balance_method == BalanceMethod.ImbalancedRanks: elif balance_method == BalanceMethod.ImbalancedRanks:
token_selected_experts = get_all_to_one_selection( token_selected_experts = get_all_to_one_selection(
@ -176,8 +206,9 @@ def make_balanced_routing_method(
balance_ratio, balance_ratio,
token_selected_experts.dtype, token_selected_experts.dtype,
token_selected_experts.device, token_selected_experts.device,
world_size, dp_size,
rank, dp_rank,
ep_size,
) )
elif balance_method == BalanceMethod.ImbalancedExperts: elif balance_method == BalanceMethod.ImbalancedExperts:
token_selected_experts = get_balanced_rank_imbalanced_expert_selection( token_selected_experts = get_balanced_rank_imbalanced_expert_selection(
@ -187,16 +218,127 @@ def make_balanced_routing_method(
balance_ratio, balance_ratio,
token_selected_experts.dtype, token_selected_experts.dtype,
token_selected_experts.device, token_selected_experts.device,
world_size, dp_size,
rank, dp_rank,
ep_size,
) )
else: else:
raise NotImplementedError(f"Not support balance_method {balance_method}") raise NotImplementedError(f"Not support balance_method {balance_method}")
moe_module._routing_results_replaced_at = "make_balanced_routing_method"
return token_selected_experts, token_final_scales return token_selected_experts, token_final_scales
return balanced_routing_method return balanced_routing_method
@functools.cache
def get_token_final_scales(shape, device):
return torch.full(shape, 1.0 / shape[-1], dtype=torch.bfloat16, device=device)
def make_balanced_run_moe(
moe_module,
run_moe_orig,
top_k,
num_experts,
balance_method,
balance_ratio,
dp_size,
dp_rank,
ep_size,
):
def balanced_run_moe(
x, token_selected_experts, token_final_scales, x_sf, router_logits, do_finalize, moe_output
):
if moe_module._routing_results_replaced_at is not None:
return run_moe_orig(
x,
token_selected_experts,
token_final_scales,
x_sf,
router_logits,
do_finalize,
moe_output,
)
logger.warning_once(
'Layer-wise benchmarks: Specifying routing results of "TRTLLM" MoE backend in TEP cases leads to a different'
" execution path around the topk kernel",
key="replace_routing_method_ctx_trtllm_tp",
)
if balance_method == BalanceMethod.Balanced:
token_selected_experts = get_balanced_selection(
x.shape[0],
top_k,
num_experts,
torch.int32,
x.device,
dp_size,
dp_rank,
ep_size,
)
elif balance_method == BalanceMethod.ImbalancedRanks:
token_selected_experts = get_all_to_one_selection(
x.shape[0],
top_k,
num_experts,
balance_ratio,
torch.int32,
x.device,
dp_size,
dp_rank,
ep_size,
)
elif balance_method == BalanceMethod.ImbalancedExperts:
token_selected_experts = get_balanced_rank_imbalanced_expert_selection(
x.shape[0],
top_k,
num_experts,
balance_ratio,
torch.int32,
x.device,
dp_size,
dp_rank,
ep_size,
)
else:
raise NotImplementedError(f"Not support balance_method {balance_method}")
token_final_scales = get_token_final_scales(
token_selected_experts.shape, token_selected_experts.device
)
router_logits = None
final_hidden_states = run_moe_orig(
x,
token_selected_experts,
token_final_scales,
x_sf,
router_logits,
do_finalize,
moe_output,
)
if not do_finalize:
final_hidden_states = (
final_hidden_states[0],
token_final_scales, # WAR for TRTLLMGenFusedMoE bug that it returns wrong `token_final_scales`
final_hidden_states[2],
)
moe_module._routing_results_replaced_at = "make_balanced_run_moe"
return final_hidden_states
return balanced_run_moe
def make_forward_impl_check(moe_module, forward_impl_orig):
def forward_impl(*args, **kwargs):
moe_module._routing_results_replaced_at = None
res = forward_impl_orig(*args, **kwargs)
assert moe_module._routing_results_replaced_at is not None, (
"Routing results are not replaced"
)
del moe_module._routing_results_replaced_at
return res
return forward_impl
class RunnerMixin(ABC): class RunnerMixin(ABC):
@staticmethod @staticmethod
@abstractmethod @abstractmethod
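The `_routing_results_replaced_at` guard above works by wrapping `forward_impl` so every forward pass asserts that one of the routing replacements actually ran. A toy sketch of the same wrap-and-assert pattern with illustrative names:
```python
class ToyMoE:
    def routing(self):
        return "original routing"
    def forward_impl(self):
        return self.routing()

moe = ToyMoE()

# Patch the routing hook so it records that it ran.
routing_orig = moe.routing
def patched_routing():
    moe._routing_results_replaced_at = "patched_routing"
    return routing_orig()
moe.routing = patched_routing

# Wrap forward_impl so it checks the flag after each call.
forward_impl_orig = moe.forward_impl
def checked_forward_impl():
    moe._routing_results_replaced_at = None
    res = forward_impl_orig()
    assert moe._routing_results_replaced_at is not None, "Routing results are not replaced"
    del moe._routing_results_replaced_at
    return res
moe.forward_impl = checked_forward_impl

print(moe.forward_impl())  # runs and passes the guard
```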
@ -398,8 +540,9 @@ class RunnerMixin(ABC):
@contextlib.contextmanager @contextlib.contextmanager
def replace_routing_method_ctx(self, balance_method: BalanceMethod, balance_ratio: float): def replace_routing_method_ctx(self, balance_method: BalanceMethod, balance_ratio: float):
if balance_method == BalanceMethod.NotModified: if balance_method == BalanceMethod.NotModified:
pass yield
elif self.model_config.moe_backend not in [ return
if self.model_config.moe_backend not in [
"CUTEDSL", "CUTEDSL",
"CUTLASS", "CUTLASS",
"DEEPGEMM", "DEEPGEMM",
@ -410,29 +553,59 @@ class RunnerMixin(ABC):
f'Not support replace routing method for moe_backend "{self.model_config.moe_backend}",' f'Not support replace routing method for moe_backend "{self.model_config.moe_backend}",'
f' please set balance_method to "NotModified"' f' please set balance_method to "NotModified"'
) )
elif ( original_methods = []
self.model_config.moe_backend == "TRTLLM" dp_rank = self.model_config.mapping.rank // (
and not self.model_config.mapping.enable_attention_dp self.model_config.mapping.world_size // self.model_config.mapping.dp_size
): )
raise NotImplementedError( for layer in self.layers:
'Not support replace routing method for moe_backend "TRTLLM" with attention TP,' moe_module = layer.mlp.experts
' please set balance_method to "NotModified"'
# Replace `routing_method.apply` for normal cases
apply_method_orig = moe_module.routing_method.apply
moe_module.routing_method.apply = make_balanced_routing_method(
moe_module,
apply_method_orig,
moe_module.num_experts,
balance_method,
balance_ratio,
self.model_config.mapping.dp_size,
dp_rank,
self.model_config.mapping.moe_ep_size,
) )
apply_methods_orig = [layer.mlp.experts.routing_method.apply for layer in self.layers]
try: # Replace `run_moe` for TRTLLMGenFusedMoE TEP because it does not call `routing_method.apply`
for layer, apply_method_orig in zip(self.layers, apply_methods_orig): if isinstance(moe_module, TRTLLMGenFusedMoE):
layer.mlp.experts.routing_method.apply = make_balanced_routing_method( run_moe_orig = moe_module.run_moe
apply_method_orig, moe_module.run_moe = make_balanced_run_moe(
moe_module,
run_moe_orig,
layer.mlp.experts.routing_method.top_k,
layer.mlp.experts.num_experts, layer.mlp.experts.num_experts,
balance_method, balance_method,
balance_ratio, balance_ratio,
layer.mlp.experts.ep_size, self.model_config.mapping.dp_size,
layer.mlp.experts.ep_rank, dp_rank,
self.model_config.mapping.moe_ep_size,
) )
else:
run_moe_orig = None
# Replace `forward_impl` to ensure that routing results are replaced
forward_impl_orig = moe_module.forward_impl
moe_module.forward_impl = make_forward_impl_check(moe_module, forward_impl_orig)
original_methods.append((apply_method_orig, run_moe_orig, forward_impl_orig))
try:
yield yield
finally: finally:
for layer, apply_method_orig in zip(self.layers, apply_methods_orig): for layer, (apply_method_orig, run_moe_orig, forward_impl_orig) in zip(
layer.mlp.experts.routing_method.apply = apply_method_orig self.layers, original_methods
):
moe_module = layer.mlp.experts
moe_module.routing_method.apply = apply_method_orig
if isinstance(moe_module, TRTLLMGenFusedMoE):
moe_module.run_moe = run_moe_orig
moe_module.forward_impl = forward_impl_orig
@staticmethod @staticmethod
def create_kv_cache_manager( def create_kv_cache_manager(
@ -467,7 +640,7 @@ class RunnerMixin(ABC):
kv_cache_manager = kv_cache_manager_cls( kv_cache_manager = kv_cache_manager_cls(
kv_cache_config, kv_cache_config,
CacheType.SELFKONLY, CacheType.SELFKONLY,
num_layers=sum(layer_mask), num_layers=num_layers,
num_kv_heads=1, num_kv_heads=1,
head_dim=model_config.pretrained_config.kv_lora_rank head_dim=model_config.pretrained_config.kv_lora_rank
+ model_config.pretrained_config.qk_rope_head_dim, + model_config.pretrained_config.qk_rope_head_dim,

View File

@ -48,7 +48,6 @@ def test_deepseek_r1_ctx_tep(llm_root, world_size):
model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2", model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2",
"--no-enable-attention-dp", "--no-enable-attention-dp",
"--moe-backend=TRTLLM", "--moe-backend=TRTLLM",
"--balance-method=NotModified",
], ],
cwd=llm_root / "examples" / "layer_wise_benchmarks", cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={ env={
@ -139,7 +138,6 @@ def test_qwen3_next_gen_tep(llm_root, world_size):
"--layer-indices=6,7", "--layer-indices=6,7",
"--no-enable-attention-dp", "--no-enable-attention-dp",
"--moe-backend=TRTLLM", "--moe-backend=TRTLLM",
"--balance-method=NotModified",
], ],
cwd=llm_root / "examples" / "layer_wise_benchmarks", cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={ env={