[None][feat] Layer-wise benchmarks: support TEP balance, polish slurm scripts (#10237)

Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
Tailing Yuan 2026-01-05 11:23:04 +08:00 committed by GitHub
parent aaf80be0f3
commit a7fe043b13
GPG Key ID: B5690EEEBB952194
14 changed files with 473 additions and 181 deletions

1
.gitignore vendored
View File

@ -77,6 +77,7 @@ cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmha_v2_cu/
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.cpp
.devcontainer/.env
/examples/layer_wise_benchmarks/autotuner_cache/
/examples/layer_wise_benchmarks/profiles/
# User config files

View File

@ -15,6 +15,9 @@ pip install -e ../..
**Step 3:** In the container, run benchmarks and generate profiles:
```bash
# Set autotune cache path
export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache
# Run DeepSeek-R1 NVFP4
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml
@ -24,16 +27,16 @@ NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
# Run DeepSeek-V3.2-Exp with 32k context length
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --seq-len-kv-cache 32769
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --seq-len-kv-cache 32769
# Run with attention TP
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --no-enable-attention-dp
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --no-enable-attention-dp
# Run with attention TP and TRTLLMGen
NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified
NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified
NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM
NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM
# Run with MTP3
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --batch-size 32 --seq-len-q 4
@ -48,7 +51,7 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --moe-backend WID
# Scale TEP=16 to 4 GPUs: reduce the number of attention heads and experts
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp
# Run Qwen3-Next (balanced routing is not implemented)
# Run Qwen3-Next
NP=2 ./mpi_launch.sh ./run.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 4
NP=2 ./mpi_launch.sh ./run.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 512
@ -69,13 +72,13 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --balance-method ImbalancedExperts
### Run with Slurm
> Tips:
> 1. If you have a running Slurm job, please skip step 1 and go straight to step 2 and 3.
> 2. Further, if you have installed `tensorrt_llm` in the Slurm job, you can also skip step 2 and run step 3 with `export CONTAINER_NAME=aaa` specified. If you don't know the container name, run `export CONTAINER_NAME=$(SLURM_JOB_ID=$SLURM_JOB_ID ./slurm_query_container_name.sh)` to get it.
> 1. If you have a running Slurm job, you can set the environment variable `export SLURM_JOB_ID=aaa` and skip step 1.
> 2. Furthermore, if you have already installed `tensorrt_llm` in the Slurm job, you can also skip step 2: just run step 3 with `export CONTAINER_NAME=aaa` specified. If you don't know the container name, run `export CONTAINER_NAME=$(./slurm_query_container_name.sh)` to get it.
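For example, when a Slurm job and a container already exist, the whole flow reduces to roughly the following (illustrative only; the job ID is a placeholder):
```bash
# Reuse an existing allocation and container, then go straight to step 3
export SLURM_JOB_ID=123456
export CONTAINER_NAME=$(./slurm_query_container_name.sh)
NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml
```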
**Step 1:** On the controller node, allocate one or multiple nodes, and record the `SLURM_JOB_ID`:
**Step 1:** On the controller node, allocate one or multiple nodes, and export the `SLURM_JOB_ID`:
```bash
SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh)
export SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh)
```
Please fill in the variables in `./slurm_alloc.sh`.
@ -83,28 +86,47 @@ Please fill the variables in `./slurm_alloc.sh`.
**Step 2:** Start a container and install `tensorrt_llm`. Run the following command on the controller node:
```bash
SLURM_JOB_ID=$SLURM_JOB_ID ./slurm_init_containers.sh
./slurm_init_containers.sh
```
It uses the image recorded in `../../jenkins/current_image_tags.properties`. The image will be downloaded to `../../enroot/` only once.
> Tips: If you want to change the image, there is no need to reallocate the Slurm job. Just start another container by running step 2 with `export CONTAINER_NAME=aaa`, and step 3 will run in the container specified by the `CONTAINER_NAME` env.
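For example, a second container with a different image can be started like this (a sketch; the container name and image path are placeholders):
```bash
# Start another container in the same Slurm job, then run step 3 inside it
export CONTAINER_NAME=layer_wise_benchmarks_custom
CONTAINER_IMAGE=/path/to/custom_image.sqsh ./slurm_init_containers.sh
NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml
```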
**(Optional) Get an interactive shell**
```bash
NODES=1 NP=1 ./slurm_launch.sh --overlap --pty middleware/exclude_slurm_envs bash
```
The `--overlap` option allows this shell to share the node with other jobs. The middleware enables nested MPI process spawning from within Slurm jobs.
You may compile C++ extensions in the interactive shell:
```bash
cd ../..
export CCACHE_DIR=$(realpath cpp/.ccache)
python3 scripts/build_wheel.py --cuda_architectures native --no-venv --skip_building_wheel -G Ninja --use_ccache --clean
```
**Step 3:** Run benchmarks to generate profiles. Run the following command on the controller node, where `NODES` &le; the number of allocated nodes:
```bash
# Set autotune cache path
export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache
# Run DeepSeek-R1 NVFP4 with wide ep: uses MNNVL A2A if applicable
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
NODES=4 NP=16 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
# Run with TRTLLMGen
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend TRTLLM
NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend TRTLLM
# Run with DeepEPLowLatency
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
# You can run 4-GPU and 8-GPU tasks without reallocate the slurm job
SLURM_JOB_ID=$SLURM_JOB_ID NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml
SLURM_JOB_ID=$SLURM_JOB_ID NODES=2 NP=8 ./slurm_launch.sh ./run.sh config_gtx.yaml
# You can run 4-GPU and 8-GPU tasks without reallocating the slurm job
NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml
NODES=2 NP=8 ./slurm_launch.sh ./run.sh config_gen.yaml
```
### Batched run
@ -131,9 +153,11 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --moe-backend WID
Run the following command in the container:
```bash
# Parse the profile at the default directory
python3 parse.py --world-size 4
# Specify the location of the .nsys-rep file
# Specify the file path
python3 parse.py --file-path profiles/report_np4_rank0.nsys-rep
python3 parse.py --profile-dir ./profiles --world-size 4 --rank 0
# Parse a specific module. The module must appear exactly once in each run.
@ -145,6 +169,15 @@ You will receive three reports, each containing kernel timing statistics grouped
2. A CSV report at `profiles/report_np4_rank0.csv`
3. An HTML report at `profiles/report_np4_rank0.html`
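For a quick look without opening the HTML report, the CSV can be inspected directly, e.g. with pandas (which `parse.py` itself already depends on):
```bash
python3 -c "import pandas as pd; print(pd.read_csv('profiles/report_np4_rank0.csv').head())"
```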
## Developer utilities
1. Reduce startup time when debugging a model
    1. Disable the autotuner: add the `--no-enable-autotuner` option
    2. Disable nsys profiling: set the `PROFILE=0` environment variable
2. Capture more information
    1. Enable GPU metrics: set the `GPU_METRICS=1` environment variable
    2. Enable backtraces: set the `BACKTRACE=1` environment variable
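For example, these options can be combined with the `mpi_launch.sh` invocations shown above (illustrative only):
```bash
# Faster debug loop: skip autotuning and nsys profiling
NP=4 ./mpi_launch.sh -x PROFILE=0 ./run.sh config_gen.yaml --no-enable-autotuner
# Richer profile: keep nsys, add GPU metrics and backtraces
NP=4 ./mpi_launch.sh -x GPU_METRICS=1 -x BACKTRACE=1 ./run.sh config_gen.yaml
```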
## Troubleshooting
1. Error `fp8 blockscale gemm only support Hopper` on Blackwell.

View File

@ -0,0 +1,8 @@
#!/bin/bash
set -euo pipefail
# Clear slurm envs
unset $(env | awk -F'=' '{print $1}' | (grep -E "SLURM_|SLURMD_|slurm_|MPI_|PMIX_" || true))
"$@"

View File

@ -3,8 +3,12 @@
set -euo pipefail
# Clear slurm envs
unset $(env | grep -i slurm | awk -F'=' '{print $1}')
unset $(env | grep MPI | awk -F'=' '{print $1}')
unset $(env | awk -F'=' '{print $1}' | (grep -E "SLURM_|SLURMD_|slurm_|MPI_|PMIX_" || true))
extra_args=
if [ -v TLLM_AUTOTUNER_CACHE_PATH ]; then
    extra_args+="-x TLLM_AUTOTUNER_CACHE_PATH"
fi
set -x
mpirun --allow-run-as-root --np ${NP} "$@"
mpirun --allow-run-as-root --np ${NP} $extra_args "$@"

View File

@ -5,6 +5,7 @@ import json
import re
import sqlite3
import subprocess
import sys
from collections import defaultdict
from pathlib import Path
@ -14,8 +15,9 @@ import pandas as pd
# Parse cmdline
parser = argparse.ArgumentParser()
parser.add_argument("--file-path", type=str)
parser.add_argument("--profile-dir", type=str, default="profiles")
parser.add_argument("--world-size", "--np", type=int, required=True)
parser.add_argument("--world-size", "--np", type=int)
parser.add_argument("--rank", type=int, default=0)
parser.add_argument("--warmup-times", type=int)
parser.add_argument("--module", type=str)
@ -27,6 +29,8 @@ group.add_argument(
)
parser.set_defaults(error_on_unknown_kernel=False)
args = parser.parse_args()
if (args.file_path is None) == (args.world_size is None):
    parser.error("Please specify exactly one of --file-path and --world-size.")
print(args)
@ -89,11 +93,20 @@ def shortest_common_supersequence(a, b):
return res
profile_dir = Path(args.profile_dir)
nsys_rep_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.nsys-rep"
sqlite_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.sqlite"
csv_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.csv"
html_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.html"
if args.file_path is not None:
    nsys_rep_file_path = Path(args.file_path)
    if not nsys_rep_file_path.name.endswith(".nsys-rep"):
        raise ValueError("Expect a .nsys-rep file")
else:
    profile_dir = Path(args.profile_dir)
    nsys_rep_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.nsys-rep"
sqlite_file_path = nsys_rep_file_path.parent / (
    nsys_rep_file_path.name[: -len(".nsys-rep")] + ".sqlite"
)
csv_file_path = nsys_rep_file_path.parent / (nsys_rep_file_path.name[: -len(".nsys-rep")] + ".csv")
html_file_path = nsys_rep_file_path.parent / (
    nsys_rep_file_path.name[: -len(".nsys-rep")] + ".html"
)
lazy_convert_sqlite(nsys_rep_file_path, sqlite_file_path)
conn = sqlite3.connect(f"file:{sqlite_file_path}?mode=ro", uri=True)
@ -196,7 +209,8 @@ if "CUDA_GRAPH_NODE_EVENTS" in tables:
R.start AS runtime_start, R.end AS runtime_end,
CGE2.start AS capture_start, CGE2.end AS capture_end
FROM ({unified_subquery}) AS unified
JOIN CUPTI_ACTIVITY_KIND_RUNTIME AS R ON unified.correlationId = R.correlationId
JOIN CUPTI_ACTIVITY_KIND_RUNTIME AS R ON unified.graphNodeId IS NOT NULL AND
unified.correlationId = R.correlationId
LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE1 ON unified.graphNodeId = CGE1.graphNodeId AND
CGE1.originalGraphNodeId IS NOT NULL
LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE2 ON CGE1.originalGraphNodeId = CGE2.graphNodeId"""
@ -318,6 +332,7 @@ parser_keywords = [
("routingInitExpertCounts", "routingInitExpertCounts"),
("routingIndicesCluster", "routingIndicesClusterKernel"),
("routingIndicesCoop", "routingIndicesCoopKernel"),
("router_gemm", "router_gemm_kernel"),
("bmm_4_44_32", "bmm_E2m1_E2m1E2m1_Fp32_t"),
("finalize", "finalize::finalizeKernel"),
("bmm_16_44_32", "bmm_Bfloat16_E2m1E2m1_Fp32_"),
@ -381,7 +396,7 @@ def parse_kernel_name(demangledName):
if all(keyword in name for keyword in src):
return dst
if name not in warned_names:
print(f"Unknown kernel name: {name}")
print(f"Unknown kernel name: {name}", file=sys.stderr)
warned_names.add(name)
if args.error_on_unknown_kernel:
raise NotImplementedError(f"Unknown kernel name: {name}")
@ -491,11 +506,33 @@ with csv_file_path.open("w", newline="") as f:
for row in csv_data:
csv_writer.writerow(row)
js_header_config = [{"name": problem["text"]} for problem in problem_set]
js_header_config = []
for problem in problem_set:
    innermost_children = js_header_config
    for k, msg_prefix in [
        ("batch_size", "b="),
        ("seq_len_q", "q="),
        ("seq_len_kv_cache", "past="),
    ]:
        if len(run_args[k + "_list"]) > 1:
            if len(innermost_children) == 0 or problem["spec"][k] != innermost_children[-1][k]:
                innermost_children.append(
                    {
                        "name": msg_prefix + str(problem["spec"][k]),
                        "children": [],
                        k: problem["spec"][k],
                    }
                )
            innermost_children = innermost_children[-1]["children"]
    innermost_children.append({"name": problem["text"]})
loader = jinja2.FileSystemLoader(Path(__file__).parent)
template = jinja2.Environment(loader=loader).get_template("template.html")
with html_file_path.open("w") as f:
configText = (
"Run:\n" + json.dumps(run_args, indent=4) + "\n\nParse:\n" + json.dumps(args.__dict__)
"Run:\n"
+ json.dumps(run_args, indent=4)
+ "\n\nParse:\n"
+ json.dumps(args.__dict__, indent=4)
)
f.write(template.render(headerConfig=js_header_config, rawData=js_data, configText=configText))

View File

@ -2,6 +2,7 @@ import argparse
import itertools
import json
import os
from unittest import mock
import numpy as np
import nvtx
@ -9,6 +10,8 @@ import torch
import yaml
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
from tensorrt_llm._torch.modules.fused_moe.interface import AlltoallMethodType
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.logger import logger
@ -108,17 +111,6 @@ if args.enable_attention_dp is None:
args.enable_attention_dp = False
if args.max_num_tokens is None:
args.max_num_tokens = args.max_batch_size * max(args.seq_len_q_list)
if args.run_type == "GEN":
ctx_batch_size = max(1, max(20480, args.max_num_tokens) // max(args.seq_len_kv_cache_list))
args.max_num_tokens = max(
args.max_num_tokens, ctx_batch_size * max(args.seq_len_kv_cache_list)
)
else:
if args.run_type == "GEN":
ctx_batch_size = max(1, args.max_num_tokens // max(args.seq_len_kv_cache_list))
assert args.max_num_tokens >= ctx_batch_size * max(args.seq_len_kv_cache_list), (
"Max_num_tokens is too small to prefill KV cache"
)
if args.use_low_precision_moe_combine is None:
args.use_low_precision_moe_combine = False
if args.enable_autotuner is None:
@ -169,17 +161,75 @@ runner = Runner(
)
logger.info("Layer-wise benchmarks: Create runner ... Done")
# Autotune
run_pack = runner.create_run_pack(
    args.run_type,
    batch_size=max(args.batch_size_list),
    request_id_begin=0,
    seq_len_q=max(args.seq_len_q_list),
    seq_len_kv_cache=args.seq_len_kv_cache_list[0],
    kv_cache_manager=kv_cache_manager,
    attn_workspace=attn_workspace,
)
if args.enable_autotuner:
    cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
    with autotune(cache_path=cache_path):
        run_pack()
else:
    run_pack()
# Prefill KV cache
if args.run_type == "GEN":
logger.info("Layer-wise benchmarks: Create runner for prefill")
ctx_seq_len_q = max(args.seq_len_kv_cache_list)
ctx_batch_size = min(
args.max_batch_size,
max(1, 20480 // ctx_seq_len_q),
)
ctx_attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
with mock.patch.object(
CutlassFusedMoE, "select_alltoall_method_type", return_value=AlltoallMethodType.NotEnabled
):
ctx_runner = Runner(
args.model,
mapping,
moe_backend="CUTLASS",
layer_indices=args.layer_indices,
scaled_from=args.scaled_from,
max_seq_len=args.max_seq_len,
max_num_tokens=ctx_batch_size * ctx_seq_len_q,
moe_max_num_tokens=16384,
use_low_precision_moe_combine=args.use_low_precision_moe_combine,
use_cuda_graph=False,
)
logger.info("Layer-wise benchmarks: Create runner for prefill ... Done")
logger.info("Layer-wise benchmarks: Prefill KV cache")
assert ctx_batch_size <= args.max_batch_size
assert ctx_seq_len_q + 0 <= args.max_seq_len
num_requests = max(args.batch_size_list)
for request_id_begin in range(0, num_requests, ctx_batch_size):
run_pack = ctx_runner.create_run_pack(
"CTX",
batch_size=min(ctx_batch_size, num_requests - request_id_begin),
request_id_begin=request_id_begin,
seq_len_q=ctx_seq_len_q,
seq_len_kv_cache=0,
kv_cache_manager=kv_cache_manager,
attn_workspace=ctx_attn_workspace,
)
with ctx_runner.replace_routing_method_ctx(
balance_method=BalanceMethod.Balanced, balance_ratio=None
):
run_pack(check=True)
del ctx_runner
del ctx_attn_workspace
logger.info("Layer-wise benchmarks: Prefill KV cache ... Done")
# Warm up
for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
[
True,
max(args.batch_size_list),
max(args.seq_len_q_list),
args.seq_len_kv_cache_list[0],
args.balance_ratio_list[0],
],
logger.info("Layer-wise benchmarks: Warmup")
for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
*itertools.product(
[False],
args.batch_size_list,
args.seq_len_q_list,
args.seq_len_kv_cache_list,
@ -203,34 +253,10 @@ for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
):
capture_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(capture_stream):
if autotune_flag:
if args.enable_autotuner:
cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
with autotune(cache_path=cache_path):
run_pack()
if args.run_type == "GEN":
logger.info("Layer-wise benchmarks: Prefill KV cache")
ctx_seq_len_q = max(args.seq_len_kv_cache_list)
assert ctx_batch_size <= args.max_batch_size
assert ctx_seq_len_q + 0 <= args.max_seq_len
assert ctx_batch_size * ctx_seq_len_q <= args.max_num_tokens
max_batch_size = max(args.batch_size_list)
for request_id_begin in range(0, max_batch_size, ctx_batch_size):
ctx_run_pack = runner.create_run_pack(
"CTX",
batch_size=min(ctx_batch_size, max_batch_size - request_id_begin),
request_id_begin=request_id_begin,
seq_len_q=ctx_seq_len_q,
seq_len_kv_cache=0,
kv_cache_manager=kv_cache_manager,
attn_workspace=attn_workspace,
)
ctx_run_pack(check=True)
logger.info("Layer-wise benchmarks: Prefill KV cache ... Done")
else:
run_pack(check=True)
run_pack(check=True)
torch.cuda.current_stream().wait_stream(capture_stream)
torch.cuda.synchronize()
logger.info("Layer-wise benchmarks: Warmup ... Done")
events = [
torch.cuda.Event(enable_timing=True) for _ in range(args.warmup_times + args.run_times + 1)
@ -270,7 +296,7 @@ for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in itertools.product(
with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"):
run_pack()
balance_ratio_str = "" if balance_ratio is None else f" balance={balance_ratio:.2g}"
balance_ratio_str = "" if balance_ratio is None else f" balance={balance_ratio:.2g}"
nvtx_message = f"b={batch_size} s={seq_len_q} past={seq_len_kv_cache}{balance_ratio_str} NP{world_size}"
for i in range(args.warmup_times + args.run_times):
events[i].record()

View File

@ -40,7 +40,9 @@ else
PROFILE_CMD=
fi
SCRIPT_PATH=$(realpath --relative-to="$(pwd)" "$(dirname -- "$0")"/run.py)
set -x
$PROFILE_CMD bash -c \
"python3 -u run.py \"\$@\" 2>&1 | tee \"$PROFILE_DIR/report_np${WORLD_SIZE}_rank${RANK}.log\"" \
bash "$@"
$PROFILE_CMD bash -o pipefail -c \
"python3 -u \"\$1\" \"\${@:3}\" 2>&1 | tee \"\$2/report_np${WORLD_SIZE}_rank${RANK}.log\"" \
bash "$SCRIPT_PATH" "$PROFILE_DIR" "$@"

View File

@ -4,23 +4,23 @@ set -euo pipefail
# CONTAINER_IMAGE=
CONTAINER_NAME=${CONTAINER_NAME:-layer_wise_benchmarks}
CONTAINER_MOUNTS=$(realpath "$(pwd)/../.."):$(realpath "$(pwd)/../..")
TRTLLM_ROOT=$(realpath "$(dirname -- "$0")"/../..)
CONTAINER_MOUNTS=$TRTLLM_ROOT:$TRTLLM_ROOT
if [ "${SLURM_JOB_ID:-}" == "" ]; then
if [ -z "${SLURM_JOB_ID:-}" ]; then
echo "Please set SLURM_JOB_ID"
exit 1
fi
NODES=$(squeue -j $SLURM_JOB_ID -h -o "%D")
if [ "${CONTAINER_IMAGE:-}" == "" ]; then
if [ -z "${CONTAINER_IMAGE:-}" ]; then
# Read Docker image from current_image_tags.properties
source ../../jenkins/current_image_tags.properties
MACHINE="$(uname -m)"
if [ "$MACHINE" == "x86_64" ]; then
DOCKER_IMAGE=$LLM_DOCKER_IMAGE
DOCKER_IMAGE=$(source "$TRTLLM_ROOT/jenkins/current_image_tags.properties" && echo $LLM_DOCKER_IMAGE)
elif [ "$MACHINE" == "aarch64" ]; then
DOCKER_IMAGE=$LLM_SBSA_DOCKER_IMAGE
DOCKER_IMAGE=$(source "$TRTLLM_ROOT/jenkins/current_image_tags.properties" && echo $LLM_SBSA_DOCKER_IMAGE)
else
echo "Unsupported machine hardware name \"$MACHINE\""
exit 1
@ -34,7 +34,7 @@ if [ "${CONTAINER_IMAGE:-}" == "" ]; then
SQSH_FILE_NAME=$(echo "$DOCKER_IMAGE" |
awk -F'#' '{print $2}' |
awk -F':' '{gsub(/\//,"+",$1); print $1"+"$2".sqsh"}')
CONTAINER_IMAGE="../../enroot/$SQSH_FILE_NAME"
CONTAINER_IMAGE="$TRTLLM_ROOT/enroot/$SQSH_FILE_NAME"
if [ ! -f "$CONTAINER_IMAGE" ]; then
echo "Container image file $CONTAINER_IMAGE does not exist, importing ..."
srun -N 1 --pty enroot import -o "$CONTAINER_IMAGE" "docker://$DOCKER_IMAGE"
@ -50,6 +50,7 @@ srun -N "$NODES" \
--container-name "$CONTAINER_NAME" \
--container-mounts "$CONTAINER_MOUNTS" \
--container-workdir "$WORKDIR" \
bash -c "pip install -U packaging &&
pip install -r ../../requirements.txt --no-build-isolation &&
pip install -e ../.."
bash -c "cd \"\$1\" &&
pip install -U packaging &&
pip install -r requirements.txt --no-build-isolation &&
pip install -e ." bash "$TRTLLM_ROOT"

View File

@ -3,9 +3,10 @@
set -euo pipefail
CONTAINER_NAME=${CONTAINER_NAME:-layer_wise_benchmarks}
CONTAINER_MOUNTS=$(realpath "$(pwd)/../.."):$(realpath "$(pwd)/../..")
TRTLLM_ROOT=$(realpath "$(dirname -- "$0")"/../..)
CONTAINER_MOUNTS=$TRTLLM_ROOT:$TRTLLM_ROOT
if [ "${SLURM_JOB_ID:-}" == "" ]; then
if [ -z "${SLURM_JOB_ID:-}" ]; then
echo "Please set SLURM_JOB_ID"
exit 1
fi

View File

@ -2,6 +2,11 @@
set -euo pipefail
if [ -z "${SLURM_JOB_ID:-}" ]; then
echo "Please set SLURM_JOB_ID"
exit 1
fi
prefix="pyxis_${SLURM_JOB_ID}_"
matches=$(printf "%s\n" "$(srun -N 1 enroot list)" | grep "^${prefix}" || true)
count=$(printf "%s\n" "$matches" | wc -l)

View File

@ -313,6 +313,7 @@
const columnLabels = [];
let columnTotals = [];
let totalNode = null; // Special node for the "Total" row
let numCols = null;
const fmt = (num) => num.toFixed(1);
@ -332,7 +333,7 @@
const topLevelTimes = rawData.map(aggregate);
// Calculate Totals
const numCols = topLevelTimes[0].length;
numCols = topLevelTimes[0].length;
columnTotals = new Array(numCols).fill(0);
for (let c = 0; c < numCols; c++) {
columnTotals[c] = topLevelTimes.reduce((sum, row) => sum + row[c], 0);
@ -420,7 +421,8 @@
const total = columnTotals[colIdx] || 1;
const pct = (val / total) * 100;
html += `<td class="data-cell col-${colIdx}" onclick="handleCellClick('${uniqueId}', ${colIdx})">
html += `<td class="data-cell col-${colIdx}" ${val ? "title=\"" + fmt(val) + "\"" : ""}
onclick="handleCellClick('${uniqueId}', ${colIdx})">
<div class="progress-bar" style="width: ${pct}%;"></div>
<div class="cell-content">${val ? fmt(val) : ""}</div>
</td>`;
@ -717,6 +719,7 @@
});
window.onload = () => {
document.getElementById('kernelTable').style["min-width"] = Math.max(1000, (220 + 60 * numCols)) + "px";
renderHeader();
document.getElementById('tableBody').innerHTML = renderTableBody(rawData, 0);
renderTotalRow();
@ -741,7 +744,6 @@
for (let node of rawData) {
maxDepth = Math.max(maxDepth, getDepth(node));
}
console.log(maxDepth);
const container = document.getElementById('level-buttons');
if (container) {

View File

@ -1,6 +1,6 @@
import nvtx
from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate
from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate, Deepseekv3MoE
from tensorrt_llm._torch.models.modeling_qwen3_next import (
Qwen3NextGatedDeltaNet,
Qwen3NextSparseMoeBlock,
@ -12,6 +12,7 @@ from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
def mark_ranges():
DeepseekV3Gate.forward = nvtx.annotate("DeepseekV3Gate")(DeepseekV3Gate.forward)
Deepseekv3MoE.forward = nvtx.annotate("Deepseekv3MoE")(Deepseekv3MoE.forward)
Qwen3NextGatedDeltaNet.forward = nvtx.annotate("Qwen3NextGatedDeltaNet")(
Qwen3NextGatedDeltaNet.forward
)

View File

@ -1,5 +1,6 @@
import contextlib
import functools
import itertools
import os
import unittest.mock
import weakref
@ -23,6 +24,7 @@ from tensorrt_llm._torch.utils import get_model_extra_attrs, model_extra_attrs
from tensorrt_llm._utils import local_mpi_size, mpi_rank, mpi_world_size, torch_dtype_to_binding
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.bindings.internal.batch_manager import CacheType
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
@ -38,24 +40,16 @@ def round_up(a, b):
def get_balanced_selection_no_cache(
num_tokens, top_k, num_experts, dtype, device, world_size, rank
num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
):
# First, each sender selects target rank
target_rank_before_mod = torch.arange(num_tokens * world_size * top_k).view(
num_tokens, world_size, top_k
token_id = torch.arange(dp_rank * num_tokens * top_k, (dp_rank + 1) * num_tokens * top_k).view(
num_tokens, top_k
)
target_rank_before_mod += top_k * torch.arange(num_tokens).view(
num_tokens, 1, 1
) # Shift `top_k` ranks for the next token on each rank, to balance network traffic
target_rank = target_rank_before_mod % world_size
# Second, each receiver selects target expert
target_expert = torch.empty_like(target_rank)
for reciever_rank in range(world_size):
mask = target_rank == reciever_rank
experts_per_rank = num_experts // world_size
local_expert = torch.arange(num_tokens * top_k) % experts_per_rank
target_expert[mask] = (reciever_rank * experts_per_rank) + local_expert
token_selected_experts = target_expert[:, rank].sort(dim=-1).values
experts_per_rank = num_experts // ep_size
token_selected_experts = (token_id % ep_size) * experts_per_rank + (
token_id // ep_size
) % experts_per_rank
token_selected_experts = token_selected_experts.sort(dim=-1).values
return token_selected_experts.contiguous().to(dtype=dtype, device=device)
@ -64,47 +58,62 @@ get_balanced_selection = functools.cache(get_balanced_selection_no_cache)
def test_get_balanced_selection():
dtype = torch.long
for num_tokens in range(1, 33):
for num_experts in range(1, 65):
print(f"{num_tokens=} {num_experts=}")
for top_k in range(1, min(11, num_experts)):
for world_size in range(1, 65):
if num_experts % world_size == 0:
tokens_per_expert = torch.zeros(num_experts)
for rank in range(world_size):
token_selected_experts = get_balanced_selection_no_cache(
num_tokens, top_k, num_experts, dtype, "cpu", world_size, rank
for num_tokens, num_experts, enable_attention_dp in itertools.product(
range(1, 35), range(1, 35), [False, True]
):
print(f"{num_tokens=} {num_experts=} {enable_attention_dp=}")
for top_k in range(1, min(10, num_experts) + 1):
for world_size in range(1, 35):
dp_size = world_size if enable_attention_dp else 1
ep_size = world_size
if num_experts % ep_size == 0:
tokens_per_expert = torch.zeros(num_experts)
for dp_rank in range(dp_size):
token_selected_experts = get_balanced_selection_no_cache(
num_tokens, top_k, num_experts, dtype, "cpu", dp_size, dp_rank, ep_size
)
sorted_selection = token_selected_experts.sort(dim=-1).values
if (sorted_selection[:, :-1] == sorted_selection[:, 1:]).any():
raise ValueError(f"duplicated experts on rank {dp_rank}")
experts_per_rank = num_experts // ep_size
tokens_per_rank = (
(token_selected_experts // experts_per_rank)
.view(-1)
.bincount(minlength=ep_size)
)
if tokens_per_rank.max() - tokens_per_rank.min() > 1:
raise ValueError(f"tokens sent from rank {dp_rank} is not balanced")
unique_tokens_per_rank = (
(
torch.arange(ep_size).view(ep_size, 1, 1)
== token_selected_experts // experts_per_rank
)
sorted_selection = token_selected_experts.sort(dim=-1).values
if (sorted_selection[:, :-1] == sorted_selection[:, 1:]).any():
raise ValueError(f"duplicated experts on rank {rank}")
experts_per_rank = num_experts // world_size
tokens_per_rank = (
(token_selected_experts // experts_per_rank)
.view(-1)
.bincount(minlength=world_size)
.any(dim=2)
.sum(dim=1)
)
if unique_tokens_per_rank.max() - unique_tokens_per_rank.min() > 1:
raise ValueError(
f"tokens sent from rank {dp_rank} is not balanced after removing duplicates"
)
if tokens_per_rank.max() - tokens_per_rank.min() > 1:
raise ValueError(f"tokens sent from rank {rank} is not balanced")
tokens_per_expert += token_selected_experts.view(-1).bincount(
minlength=num_experts
)
if tokens_per_expert.max() - tokens_per_expert.min() > 1:
raise ValueError("tokens per expert is not balanced")
tokens_per_expert += token_selected_experts.view(-1).bincount(
minlength=num_experts
)
if tokens_per_expert.max() - tokens_per_expert.min() > 1:
raise ValueError("tokens per expert is not balanced")
def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank):
def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size):
num_tokens, top_k = imbalanced_experts.shape
dtype = imbalanced_experts.dtype
device = imbalanced_experts.device
balanced_experts = get_balanced_selection_no_cache(
num_tokens, top_k, num_experts, dtype, device, world_size, rank
num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
)
if balance_ratio == 0.0:
num_balanced_tokens = 0
else:
# Activate all experts
min_num_balanced_tokens = min(num_tokens, ceil_div(num_experts, world_size * top_k))
min_num_balanced_tokens = min(num_tokens, ceil_div(num_experts, dp_size * top_k))
num_balanced_tokens = min_num_balanced_tokens + round(
(num_tokens - min_num_balanced_tokens) * balance_ratio
)
@ -116,57 +125,78 @@ def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_si
@functools.cache
def get_all_to_one_selection(
num_tokens, top_k, num_experts, balance_ratio, dtype, device, world_size, rank
num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size
):
experts_per_rank = num_experts // world_size
experts_per_rank = num_experts // ep_size
if top_k > experts_per_rank:
raise ValueError(
"Cannot send all tokens to a single rank because `top_k > experts_per_rank`"
)
imbalanced_experts = (
torch.arange(
rank * num_tokens * top_k, (rank + 1) * num_tokens * top_k, dtype=dtype, device=device
dp_rank * num_tokens * top_k,
(dp_rank + 1) * num_tokens * top_k,
dtype=dtype,
device=device,
).view(num_tokens, top_k)
% experts_per_rank
)
imbalanced_experts = imbalanced_experts.sort(dim=-1).values
return apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank)
return apply_balance_ratio(
imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size
)
@functools.cache
def get_balanced_rank_imbalanced_expert_selection(
num_tokens, top_k, num_experts, balance_ratio, dtype, device, world_size, rank
num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size
):
experts_per_rank = num_experts // world_size
active_experts_per_rank = ceil_div(top_k, world_size)
# Select expert from [0, active_experts_per_rank * world_size),
# then scale to [0, experts_per_rank * world_size)
experts_per_rank = num_experts // ep_size
active_experts_per_rank = ceil_div(top_k, ep_size)
# Select expert from [0, active_experts_per_rank * ep_size),
# then scale to [0, experts_per_rank * ep_size)
narrow_experts = get_balanced_selection_no_cache(
num_tokens, top_k, active_experts_per_rank * world_size, dtype, device, world_size, rank
num_tokens,
top_k,
active_experts_per_rank * ep_size,
dtype,
device,
dp_size,
dp_rank,
ep_size,
)
imbalanced_experts = (
narrow_experts // active_experts_per_rank * experts_per_rank
+ narrow_experts % active_experts_per_rank
)
return apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank)
return apply_balance_ratio(
imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size
)
def make_balanced_routing_method(
apply_method_orig, num_experts, balance_method, balance_ratio, world_size, rank
moe_module,
apply_method_orig,
num_experts,
balance_method,
balance_ratio,
dp_size,
dp_rank,
ep_size,
):
def balanced_routing_method(router_logits):
token_selected_experts, token_final_scales = apply_method_orig(router_logits)
if balance_method == BalanceMethod.NotModified:
pass
elif balance_method == BalanceMethod.Balanced:
assert moe_module._routing_results_replaced_at in [None, "make_balanced_routing_method"]
if balance_method == BalanceMethod.Balanced:
token_selected_experts = get_balanced_selection(
token_selected_experts.shape[0],
token_selected_experts.shape[1],
num_experts,
token_selected_experts.dtype,
token_selected_experts.device,
world_size,
rank,
dp_size,
dp_rank,
ep_size,
)
elif balance_method == BalanceMethod.ImbalancedRanks:
token_selected_experts = get_all_to_one_selection(
@ -176,8 +206,9 @@ def make_balanced_routing_method(
balance_ratio,
token_selected_experts.dtype,
token_selected_experts.device,
world_size,
rank,
dp_size,
dp_rank,
ep_size,
)
elif balance_method == BalanceMethod.ImbalancedExperts:
token_selected_experts = get_balanced_rank_imbalanced_expert_selection(
@ -187,16 +218,127 @@ def make_balanced_routing_method(
balance_ratio,
token_selected_experts.dtype,
token_selected_experts.device,
world_size,
rank,
dp_size,
dp_rank,
ep_size,
)
else:
raise NotImplementedError(f"Not support balance_method {balance_method}")
moe_module._routing_results_replaced_at = "make_balanced_routing_method"
return token_selected_experts, token_final_scales
return balanced_routing_method
@functools.cache
def get_token_final_scales(shape, device):
    return torch.full(shape, 1.0 / shape[-1], dtype=torch.bfloat16, device=device)


def make_balanced_run_moe(
    moe_module,
    run_moe_orig,
    top_k,
    num_experts,
    balance_method,
    balance_ratio,
    dp_size,
    dp_rank,
    ep_size,
):
    def balanced_run_moe(
        x, token_selected_experts, token_final_scales, x_sf, router_logits, do_finalize, moe_output
    ):
        if moe_module._routing_results_replaced_at is not None:
            return run_moe_orig(
                x,
                token_selected_experts,
                token_final_scales,
                x_sf,
                router_logits,
                do_finalize,
                moe_output,
            )
        logger.warning_once(
            'Layer-wise benchmarks: Specifying routing results of "TRTLLM" MoE backend in TEP cases leads to different'
            " execution path around the topk kernel",
            key="replace_routing_method_ctx_trtllm_tp",
        )
        if balance_method == BalanceMethod.Balanced:
            token_selected_experts = get_balanced_selection(
                x.shape[0],
                top_k,
                num_experts,
                torch.int32,
                x.device,
                dp_size,
                dp_rank,
                ep_size,
            )
        elif balance_method == BalanceMethod.ImbalancedRanks:
            token_selected_experts = get_all_to_one_selection(
                x.shape[0],
                top_k,
                num_experts,
                balance_ratio,
                torch.int32,
                x.device,
                dp_size,
                dp_rank,
                ep_size,
            )
        elif balance_method == BalanceMethod.ImbalancedExperts:
            token_selected_experts = get_balanced_rank_imbalanced_expert_selection(
                x.shape[0],
                top_k,
                num_experts,
                balance_ratio,
                torch.int32,
                x.device,
                dp_size,
                dp_rank,
                ep_size,
            )
        else:
            raise NotImplementedError(f"Not support balance_method {balance_method}")
        token_final_scales = get_token_final_scales(
            token_selected_experts.shape, token_selected_experts.device
        )
        router_logits = None
        final_hidden_states = run_moe_orig(
            x,
            token_selected_experts,
            token_final_scales,
            x_sf,
            router_logits,
            do_finalize,
            moe_output,
        )
        if not do_finalize:
            final_hidden_states = (
                final_hidden_states[0],
                token_final_scales,  # WAR for TRTLLMGenFusedMoE bug that it returns wrong `token_final_scales`
                final_hidden_states[2],
            )
        moe_module._routing_results_replaced_at = "make_balanced_run_moe"
        return final_hidden_states

    return balanced_run_moe


def make_forward_impl_check(moe_module, forward_impl_orig):
    def forward_impl(*args, **kwargs):
        moe_module._routing_results_replaced_at = None
        res = forward_impl_orig(*args, **kwargs)
        assert moe_module._routing_results_replaced_at is not None, (
            "Routing results are not replaced"
        )
        del moe_module._routing_results_replaced_at
        return res

    return forward_impl
class RunnerMixin(ABC):
@staticmethod
@abstractmethod
@ -398,8 +540,9 @@ class RunnerMixin(ABC):
@contextlib.contextmanager
def replace_routing_method_ctx(self, balance_method: BalanceMethod, balance_ratio: float):
if balance_method == BalanceMethod.NotModified:
pass
elif self.model_config.moe_backend not in [
yield
return
if self.model_config.moe_backend not in [
"CUTEDSL",
"CUTLASS",
"DEEPGEMM",
@ -410,29 +553,59 @@ class RunnerMixin(ABC):
f'Not support replace routing method for moe_backend "{self.model_config.moe_backend}",'
f' please set balance_method to "NotModified"'
)
elif (
self.model_config.moe_backend == "TRTLLM"
and not self.model_config.mapping.enable_attention_dp
):
raise NotImplementedError(
'Not support replace routing method for moe_backend "TRTLLM" with attention TP,'
' please set balance_method to "NotModified"'
original_methods = []
dp_rank = self.model_config.mapping.rank // (
self.model_config.mapping.world_size // self.model_config.mapping.dp_size
)
for layer in self.layers:
moe_module = layer.mlp.experts
# Replace `routing_method.apply` for normal cases
apply_method_orig = moe_module.routing_method.apply
moe_module.routing_method.apply = make_balanced_routing_method(
moe_module,
apply_method_orig,
moe_module.num_experts,
balance_method,
balance_ratio,
self.model_config.mapping.dp_size,
dp_rank,
self.model_config.mapping.moe_ep_size,
)
apply_methods_orig = [layer.mlp.experts.routing_method.apply for layer in self.layers]
try:
for layer, apply_method_orig in zip(self.layers, apply_methods_orig):
layer.mlp.experts.routing_method.apply = make_balanced_routing_method(
apply_method_orig,
# Replace `run_moe` for TRTLLMGenFusedMoE TEP because it does not call `routing_method.apply`
if isinstance(moe_module, TRTLLMGenFusedMoE):
run_moe_orig = moe_module.run_moe
moe_module.run_moe = make_balanced_run_moe(
moe_module,
run_moe_orig,
layer.mlp.experts.routing_method.top_k,
layer.mlp.experts.num_experts,
balance_method,
balance_ratio,
layer.mlp.experts.ep_size,
layer.mlp.experts.ep_rank,
self.model_config.mapping.dp_size,
dp_rank,
self.model_config.mapping.moe_ep_size,
)
else:
run_moe_orig = None
# Replace `forward_impl` to ensure that routing results are replaced
forward_impl_orig = moe_module.forward_impl
moe_module.forward_impl = make_forward_impl_check(moe_module, forward_impl_orig)
original_methods.append((apply_method_orig, run_moe_orig, forward_impl_orig))
try:
yield
finally:
for layer, apply_method_orig in zip(self.layers, apply_methods_orig):
layer.mlp.experts.routing_method.apply = apply_method_orig
for layer, (apply_method_orig, run_moe_orig, forward_impl_orig) in zip(
self.layers, original_methods
):
moe_module = layer.mlp.experts
moe_module.routing_method.apply = apply_method_orig
if isinstance(moe_module, TRTLLMGenFusedMoE):
moe_module.run_moe = run_moe_orig
moe_module.forward_impl = forward_impl_orig
@staticmethod
def create_kv_cache_manager(
@ -467,7 +640,7 @@ class RunnerMixin(ABC):
kv_cache_manager = kv_cache_manager_cls(
kv_cache_config,
CacheType.SELFKONLY,
num_layers=sum(layer_mask),
num_layers=num_layers,
num_kv_heads=1,
head_dim=model_config.pretrained_config.kv_lora_rank
+ model_config.pretrained_config.qk_rope_head_dim,

View File

@ -48,7 +48,6 @@ def test_deepseek_r1_ctx_tep(llm_root, world_size):
model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2",
"--no-enable-attention-dp",
"--moe-backend=TRTLLM",
"--balance-method=NotModified",
],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={
@ -139,7 +138,6 @@ def test_qwen3_next_gen_tep(llm_root, world_size):
"--layer-indices=6,7",
"--no-enable-attention-dp",
"--moe-backend=TRTLLM",
"--balance-method=NotModified",
],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={