Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[None][feat] Layer-wise benchmarks: support TEP balance, polish slurm scripts (#10237)
Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
This commit is contained in:
parent aaf80be0f3
commit a7fe043b13
.gitignore (vendored), 1 line changed:
@@ -77,6 +77,7 @@ cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmha_v2_cu/
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.cpp
.devcontainer/.env
/examples/layer_wise_benchmarks/autotuner_cache/
/examples/layer_wise_benchmarks/profiles/

# User config files

@@ -15,6 +15,9 @@ pip install -e ../..
**Step 3:** In the container, run benchmarks and generate profiles:

```bash
# Set autotune cache path
export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache

# Run DeepSeek-R1 NVFP4
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml
@@ -24,16 +27,16 @@ NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM

# Run DeepSeek-V3.2-Exp with 32k context length
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --seq-len-kv-cache 32769
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --seq-len-kv-cache 32769

# Run with attention TP
NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --no-enable-attention-dp
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --no-enable-attention-dp

# Run with attention TP and TRTLLMGen
NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified
NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified
NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM
NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM

# Run with MTP3
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --batch-size 32 --seq-len-q 4
@@ -48,7 +51,7 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --moe-backend WID
# Scale TEP=16 to 4 GPUs: reduce the number of attention heads and experts
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp

# Run Qwen3-Next (balanced routing is not implemented)
# Run Qwen3-Next
NP=2 ./mpi_launch.sh ./run.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 4
NP=2 ./mpi_launch.sh ./run.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 512

@@ -69,13 +72,13 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --balance-method ImbalancedExperts
### Run with Slurm

> Tips:
> 1. If you have a running Slurm job, please skip step 1 and go straight to step 2 and 3.
> 2. Further, if you have installed `tensorrt_llm` in the Slurm job, you can also skip step 2 and run step 3 with `export CONTAINER_NAME=aaa` specified. If you don't know the container name, run `export CONTAINER_NAME=$(SLURM_JOB_ID=$SLURM_JOB_ID ./slurm_query_container_name.sh)` to get it.
> 1. If you have a running Slurm job, you can set environment variable `export SLURM_JOB_ID=aaa` and skip step 1.
> 2. Further, if you have installed `tensorrt_llm` in the Slurm job, you can also skip step 2. Just run step 3 with `export CONTAINER_NAME=aaa` specified. If you don't know the container name, run `export CONTAINER_NAME=$(./slurm_query_container_name.sh)` to get it.

**Step 1:** On the controller node, allocate one or multiple nodes, and record the `SLURM_JOB_ID`:
**Step 1:** On the controller node, allocate one or multiple nodes, and export the `SLURM_JOB_ID`:

```bash
SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh)
export SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh)
```

Please fill the variables in `./slurm_alloc.sh`.
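Putting the two tips above together, a minimal sketch of reusing an already-allocated Slurm job (the job id `123456` is a placeholder):

```bash
# Reuse an existing allocation: skip step 1 (and step 2 if tensorrt_llm is already installed)
export SLURM_JOB_ID=123456                                 # placeholder job id
export CONTAINER_NAME=$(./slurm_query_container_name.sh)   # discover the running container
NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml    # go straight to step 3
```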
@@ -83,28 +86,47 @@ Please fill the variables in `./slurm_alloc.sh`.
**Step 2:** Start a container and install `tensorrt_llm`. Run the following command on the controller node:

```bash
SLURM_JOB_ID=$SLURM_JOB_ID ./slurm_init_containers.sh
./slurm_init_containers.sh
```

It uses the image recorded in `../../jenkins/current_image_tags.properties`. The image will be downloaded to `../../enroot/` once.

> Tips: If you want to change the image, there is no need to reallocate Slurm jobs. Just start another container by running step 2 with `export CONTAINER_NAME=aaa`, and step 3 will run in the container specified by the `CONTAINER_NAME` env.

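For example, a hedged sketch of that tip (the container name and image path are placeholders; `CONTAINER_IMAGE` is honored by `slurm_init_containers.sh` when set):

```bash
# Start a second container with a different image next to the existing one,
# then point step 3 at it via CONTAINER_NAME.
export CONTAINER_NAME=debug_image                          # placeholder name
CONTAINER_IMAGE=/path/to/other_image.sqsh ./slurm_init_containers.sh
NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml
```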
**(Optional) Get an interactive shell**

```bash
NODES=1 NP=1 ./slurm_launch.sh --overlap --pty middleware/exclude_slurm_envs bash
```

The `--overlap` option allows this shell to share the node with other jobs. The middleware enables nested MPI process spawning from within Slurm jobs.

You may compile C++ extensions in the interactive shell:

```bash
cd ../..
export CCACHE_DIR=$(realpath cpp/.ccache)
python3 scripts/build_wheel.py --cuda_architectures native --no-venv --skip_building_wheel -G Ninja --use_ccache --clean
```

**Step 3:** Run benchmarks to generate profiles. Run the following command on the controller node, where `NODES` ≤ the number of allocated nodes:

```bash
# Set autotune cache path
export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache

# Run DeepSeek-R1 NVFP4 with wide ep: uses MNNVL A2A if applicable
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
NODES=4 NP=16 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP

# Run with TRTLLMGen
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend TRTLLM
NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend TRTLLM

# Run with DeepEPLowLatency
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP

# You can run 4-GPU and 8-GPU tasks without reallocate the slurm job
SLURM_JOB_ID=$SLURM_JOB_ID NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml
SLURM_JOB_ID=$SLURM_JOB_ID NODES=2 NP=8 ./slurm_launch.sh ./run.sh config_gtx.yaml
# You can run 4-GPU and 8-GPU tasks without reallocating the slurm job
NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml
NODES=2 NP=8 ./slurm_launch.sh ./run.sh config_gen.yaml
```

### Batched run

@@ -131,9 +153,11 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --moe-backend WID
Run the following command in the container:

```bash
# Parse the profile at the default directory
python3 parse.py --world-size 4

# Specify the location of the .nsys-rep file
# Specify the file path
python3 parse.py --file-path profiles/report_np4_rank0.nsys-rep
python3 parse.py --profile-dir ./profiles --world-size 4 --rank 0

# Parse a specific module. The module must appear exactly once in each run.
@@ -145,6 +169,15 @@ You will receive three reports, each containing kernel timing statistics grouped
2. A CSV report at `profiles/report_np4_rank0.csv`
3. An HTML report at `profiles/report_np4_rank0.html`

## Developer utilities

1. Reduce startup time when debugging a model (combined sketch below)
   1. Disable autotuner: add the `--no-enable-autotuner` option
   2. Disable nsys profile: set the `PROFILE=0` environment variable
2. Capture more information
   1. Enable GPU metrics: set the `GPU_METRICS=1` environment variable
   2. Enable backtrace: set the `BACKTRACE=1` environment variable

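A combined sketch of these switches for a quick debug iteration (flag and variable names as listed above; the config choice is arbitrary, and depending on your launcher you may need to forward the variables to the ranks with `-x`, as the `TRTLLM_ENABLE_PDL` example earlier does):

```bash
# Fast, minimal-overhead debug run: no autotuner, no nsys capture
PROFILE=0 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --no-enable-autotuner

# Heavier capture: GPU metrics plus backtraces in the nsys profile
GPU_METRICS=1 BACKTRACE=1 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml
```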
## Troubleshooting

1. Error `fp8 blockscale gemm only support Hopper` on Blackwell.

examples/layer_wise_benchmarks/middleware/exclude_slurm_envs (new executable file, 8 lines)
@@ -0,0 +1,8 @@
#!/bin/bash

set -euo pipefail

# Clear slurm envs
unset $(env | awk -F'=' '{print $1}' | (grep -E "SLURM_|SLURMD_|slurm_|MPI_|PMIX_" || true))

"$@"
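A quick sanity check of the middleware (a sketch; run from `examples/layer_wise_benchmarks` inside a Slurm job): since the script unsets the listed variables before running its arguments, the child process should see none of them.

```bash
middleware/exclude_slurm_envs env | grep -E "SLURM_|PMIX_" || echo "no Slurm/MPI variables inherited"
```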
@@ -3,8 +3,12 @@
set -euo pipefail

# Clear slurm envs
unset $(env | grep -i slurm | awk -F'=' '{print $1}')
unset $(env | grep MPI | awk -F'=' '{print $1}')
unset $(env | awk -F'=' '{print $1}' | (grep -E "SLURM_|SLURMD_|slurm_|MPI_|PMIX_" || true))

extra_args=
if [ -v TLLM_AUTOTUNER_CACHE_PATH ]; then
    extra_args+="-x TLLM_AUTOTUNER_CACHE_PATH"
fi

set -x
mpirun --allow-run-as-root --np ${NP} "$@"
mpirun --allow-run-as-root --np ${NP} $extra_args "$@"

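This appears to be the launcher invoked as `./mpi_launch.sh` in the examples above: when `TLLM_AUTOTUNER_CACHE_PATH` is exported, it is now forwarded to every rank via `-x`. A usage sketch:

```bash
export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache
NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml
# effectively runs: mpirun --allow-run-as-root --np 4 -x TLLM_AUTOTUNER_CACHE_PATH ./run.sh config_gen.yaml
```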
@@ -5,6 +5,7 @@ import json
import re
import sqlite3
import subprocess
import sys
from collections import defaultdict
from pathlib import Path

@@ -14,8 +15,9 @@ import pandas as pd

# Parse cmdline
parser = argparse.ArgumentParser()
parser.add_argument("--file-path", type=str)
parser.add_argument("--profile-dir", type=str, default="profiles")
parser.add_argument("--world-size", "--np", type=int, required=True)
parser.add_argument("--world-size", "--np", type=int)
parser.add_argument("--rank", type=int, default=0)
parser.add_argument("--warmup-times", type=int)
parser.add_argument("--module", type=str)
@@ -27,6 +29,8 @@ group.add_argument(
)
parser.set_defaults(error_on_unknown_kernel=False)
args = parser.parse_args()
if (args.file_path is None) == (args.world_size is None):
    parser.error("Please specify exactly one of --file-path and --world-size.")
print(args)


@@ -89,11 +93,20 @@ def shortest_common_supersequence(a, b):
    return res


profile_dir = Path(args.profile_dir)
nsys_rep_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.nsys-rep"
sqlite_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.sqlite"
csv_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.csv"
html_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.html"
if args.file_path is not None:
    nsys_rep_file_path = Path(args.file_path)
    if not nsys_rep_file_path.name.endswith(".nsys-rep"):
        raise ValueError("Expect a .nsys-rep file")
else:
    profile_dir = Path(args.profile_dir)
    nsys_rep_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.nsys-rep"
sqlite_file_path = nsys_rep_file_path.parent / (
    nsys_rep_file_path.name[: -len(".nsys-rep")] + ".sqlite"
)
csv_file_path = nsys_rep_file_path.parent / (nsys_rep_file_path.name[: -len(".nsys-rep")] + ".csv")
html_file_path = nsys_rep_file_path.parent / (
    nsys_rep_file_path.name[: -len(".nsys-rep")] + ".html"
)
lazy_convert_sqlite(nsys_rep_file_path, sqlite_file_path)

conn = sqlite3.connect(f"file:{sqlite_file_path}?mode=ro", uri=True)
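In other words, the parser now accepts exactly one of two input modes, and the `.sqlite`/`.csv`/`.html` outputs are always written next to the `.nsys-rep` file. A sketch of the two invocations (the second path is a placeholder):

```bash
# Mode 1: derive the report name from --profile-dir/--world-size/--rank
python3 parse.py --world-size 4 --rank 0
# reads  profiles/report_np4_rank0.nsys-rep
# writes profiles/report_np4_rank0.{sqlite,csv,html}

# Mode 2: point at an arbitrary .nsys-rep file; outputs land in the same directory
python3 parse.py --file-path /some/dir/report_np4_rank0.nsys-rep
```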
@@ -196,7 +209,8 @@ if "CUDA_GRAPH_NODE_EVENTS" in tables:
            R.start AS runtime_start, R.end AS runtime_end,
            CGE2.start AS capture_start, CGE2.end AS capture_end
        FROM ({unified_subquery}) AS unified
        JOIN CUPTI_ACTIVITY_KIND_RUNTIME AS R ON unified.correlationId = R.correlationId
        JOIN CUPTI_ACTIVITY_KIND_RUNTIME AS R ON unified.graphNodeId IS NOT NULL AND
                                                 unified.correlationId = R.correlationId
        LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE1 ON unified.graphNodeId = CGE1.graphNodeId AND
                                                    CGE1.originalGraphNodeId IS NOT NULL
        LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE2 ON CGE1.originalGraphNodeId = CGE2.graphNodeId"""
@@ -318,6 +332,7 @@ parser_keywords = [
    ("routingInitExpertCounts", "routingInitExpertCounts"),
    ("routingIndicesCluster", "routingIndicesClusterKernel"),
    ("routingIndicesCoop", "routingIndicesCoopKernel"),
    ("router_gemm", "router_gemm_kernel"),
    ("bmm_4_44_32", "bmm_E2m1_E2m1E2m1_Fp32_t"),
    ("finalize", "finalize::finalizeKernel"),
    ("bmm_16_44_32", "bmm_Bfloat16_E2m1E2m1_Fp32_"),
@@ -381,7 +396,7 @@ def parse_kernel_name(demangledName):
        if all(keyword in name for keyword in src):
            return dst
    if name not in warned_names:
        print(f"Unknown kernel name: {name}")
        print(f"Unknown kernel name: {name}", file=sys.stderr)
        warned_names.add(name)
        if args.error_on_unknown_kernel:
            raise NotImplementedError(f"Unknown kernel name: {name}")
@@ -491,11 +506,33 @@ with csv_file_path.open("w", newline="") as f:
    for row in csv_data:
        csv_writer.writerow(row)
js_header_config = [{"name": problem["text"]} for problem in problem_set]
js_header_config = []
for problem in problem_set:
    innermost_children = js_header_config
    for k, msg_prefix in [
        ("batch_size", "b="),
        ("seq_len_q", "q="),
        ("seq_len_kv_cache", "past="),
    ]:
        if len(run_args[k + "_list"]) > 1:
            if len(innermost_children) == 0 or problem["spec"][k] != innermost_children[-1][k]:
                innermost_children.append(
                    {
                        "name": msg_prefix + str(problem["spec"][k]),
                        "children": [],
                        k: problem["spec"][k],
                    }
                )
            innermost_children = innermost_children[-1]["children"]
    innermost_children.append({"name": problem["text"]})
loader = jinja2.FileSystemLoader(Path(__file__).parent)
template = jinja2.Environment(loader=loader).get_template("template.html")
with html_file_path.open("w") as f:
    configText = (
        "Run:\n" + json.dumps(run_args, indent=4) + "\n\nParse:\n" + json.dumps(args.__dict__)
        "Run:\n"
        + json.dumps(run_args, indent=4)
        + "\n\nParse:\n"
        + json.dumps(args.__dict__, indent=4)
    )
    f.write(template.render(headerConfig=js_header_config, rawData=js_data, configText=configText))

@ -2,6 +2,7 @@ import argparse
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import nvtx
|
||||
@ -9,6 +10,8 @@ import torch
|
||||
import yaml
|
||||
|
||||
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
|
||||
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
|
||||
from tensorrt_llm._torch.modules.fused_moe.interface import AlltoallMethodType
|
||||
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
|
||||
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
|
||||
from tensorrt_llm.logger import logger
|
||||
@ -108,17 +111,6 @@ if args.enable_attention_dp is None:
|
||||
args.enable_attention_dp = False
|
||||
if args.max_num_tokens is None:
|
||||
args.max_num_tokens = args.max_batch_size * max(args.seq_len_q_list)
|
||||
if args.run_type == "GEN":
|
||||
ctx_batch_size = max(1, max(20480, args.max_num_tokens) // max(args.seq_len_kv_cache_list))
|
||||
args.max_num_tokens = max(
|
||||
args.max_num_tokens, ctx_batch_size * max(args.seq_len_kv_cache_list)
|
||||
)
|
||||
else:
|
||||
if args.run_type == "GEN":
|
||||
ctx_batch_size = max(1, args.max_num_tokens // max(args.seq_len_kv_cache_list))
|
||||
assert args.max_num_tokens >= ctx_batch_size * max(args.seq_len_kv_cache_list), (
|
||||
"Max_num_tokens is too small to prefill KV cache"
|
||||
)
|
||||
if args.use_low_precision_moe_combine is None:
|
||||
args.use_low_precision_moe_combine = False
|
||||
if args.enable_autotuner is None:
|
||||
@ -169,17 +161,75 @@ runner = Runner(
|
||||
)
|
||||
logger.info("Layer-wise benchmarks: Create runner ... Done")
|
||||
|
||||
# Autotune
|
||||
run_pack = runner.create_run_pack(
|
||||
args.run_type,
|
||||
batch_size=max(args.batch_size_list),
|
||||
request_id_begin=0,
|
||||
seq_len_q=max(args.seq_len_q_list),
|
||||
seq_len_kv_cache=args.seq_len_kv_cache_list[0],
|
||||
kv_cache_manager=kv_cache_manager,
|
||||
attn_workspace=attn_workspace,
|
||||
)
|
||||
if args.enable_autotuner:
|
||||
cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
|
||||
with autotune(cache_path=cache_path):
|
||||
run_pack()
|
||||
else:
|
||||
run_pack()
|
||||
|
||||
# Prefill KV cache
|
||||
if args.run_type == "GEN":
|
||||
logger.info("Layer-wise benchmarks: Create runner for prefill")
|
||||
ctx_seq_len_q = max(args.seq_len_kv_cache_list)
|
||||
ctx_batch_size = min(
|
||||
args.max_batch_size,
|
||||
max(1, 20480 // ctx_seq_len_q),
|
||||
)
|
||||
ctx_attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
|
||||
with mock.patch.object(
|
||||
CutlassFusedMoE, "select_alltoall_method_type", return_value=AlltoallMethodType.NotEnabled
|
||||
):
|
||||
ctx_runner = Runner(
|
||||
args.model,
|
||||
mapping,
|
||||
moe_backend="CUTLASS",
|
||||
layer_indices=args.layer_indices,
|
||||
scaled_from=args.scaled_from,
|
||||
max_seq_len=args.max_seq_len,
|
||||
max_num_tokens=ctx_batch_size * ctx_seq_len_q,
|
||||
moe_max_num_tokens=16384,
|
||||
use_low_precision_moe_combine=args.use_low_precision_moe_combine,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
logger.info("Layer-wise benchmarks: Create runner for prefill ... Done")
|
||||
|
||||
logger.info("Layer-wise benchmarks: Prefill KV cache")
|
||||
assert ctx_batch_size <= args.max_batch_size
|
||||
assert ctx_seq_len_q + 0 <= args.max_seq_len
|
||||
num_requests = max(args.batch_size_list)
|
||||
for request_id_begin in range(0, num_requests, ctx_batch_size):
|
||||
run_pack = ctx_runner.create_run_pack(
|
||||
"CTX",
|
||||
batch_size=min(ctx_batch_size, num_requests - request_id_begin),
|
||||
request_id_begin=request_id_begin,
|
||||
seq_len_q=ctx_seq_len_q,
|
||||
seq_len_kv_cache=0,
|
||||
kv_cache_manager=kv_cache_manager,
|
||||
attn_workspace=ctx_attn_workspace,
|
||||
)
|
||||
with ctx_runner.replace_routing_method_ctx(
|
||||
balance_method=BalanceMethod.Balanced, balance_ratio=None
|
||||
):
|
||||
run_pack(check=True)
|
||||
del ctx_runner
|
||||
del ctx_attn_workspace
|
||||
logger.info("Layer-wise benchmarks: Prefill KV cache ... Done")
|
||||
|
||||
# Warm up
|
||||
for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
|
||||
[
|
||||
True,
|
||||
max(args.batch_size_list),
|
||||
max(args.seq_len_q_list),
|
||||
args.seq_len_kv_cache_list[0],
|
||||
args.balance_ratio_list[0],
|
||||
],
|
||||
logger.info("Layer-wise benchmarks: Warmup")
|
||||
for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
|
||||
*itertools.product(
|
||||
[False],
|
||||
args.batch_size_list,
|
||||
args.seq_len_q_list,
|
||||
args.seq_len_kv_cache_list,
|
||||
@ -203,34 +253,10 @@ for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
|
||||
):
|
||||
capture_stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(capture_stream):
|
||||
if autotune_flag:
|
||||
if args.enable_autotuner:
|
||||
cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
|
||||
with autotune(cache_path=cache_path):
|
||||
run_pack()
|
||||
if args.run_type == "GEN":
|
||||
logger.info("Layer-wise benchmarks: Prefill KV cache")
|
||||
ctx_seq_len_q = max(args.seq_len_kv_cache_list)
|
||||
assert ctx_batch_size <= args.max_batch_size
|
||||
assert ctx_seq_len_q + 0 <= args.max_seq_len
|
||||
assert ctx_batch_size * ctx_seq_len_q <= args.max_num_tokens
|
||||
max_batch_size = max(args.batch_size_list)
|
||||
for request_id_begin in range(0, max_batch_size, ctx_batch_size):
|
||||
ctx_run_pack = runner.create_run_pack(
|
||||
"CTX",
|
||||
batch_size=min(ctx_batch_size, max_batch_size - request_id_begin),
|
||||
request_id_begin=request_id_begin,
|
||||
seq_len_q=ctx_seq_len_q,
|
||||
seq_len_kv_cache=0,
|
||||
kv_cache_manager=kv_cache_manager,
|
||||
attn_workspace=attn_workspace,
|
||||
)
|
||||
ctx_run_pack(check=True)
|
||||
logger.info("Layer-wise benchmarks: Prefill KV cache ... Done")
|
||||
else:
|
||||
run_pack(check=True)
|
||||
run_pack(check=True)
|
||||
torch.cuda.current_stream().wait_stream(capture_stream)
|
||||
torch.cuda.synchronize()
|
||||
logger.info("Layer-wise benchmarks: Warmup ... Done")
|
||||
|
||||
events = [
|
||||
torch.cuda.Event(enable_timing=True) for _ in range(args.warmup_times + args.run_times + 1)
|
||||
@ -270,7 +296,7 @@ for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in itertools.product(
|
||||
with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"):
|
||||
run_pack()
|
||||
|
||||
balance_ratio_str = "" if balance_ratio is None else f" balance={balance_ratio:.2g}"
|
||||
balance_ratio_str = "" if balance_ratio is None else f" balance={balance_ratio:.2g}"
|
||||
nvtx_message = f"b={batch_size} s={seq_len_q} past={seq_len_kv_cache}{balance_ratio_str} NP{world_size}"
|
||||
for i in range(args.warmup_times + args.run_times):
|
||||
events[i].record()
|
||||
|
||||
@@ -40,7 +40,9 @@ else
    PROFILE_CMD=
fi

SCRIPT_PATH=$(realpath --relative-to="$(pwd)" "$(dirname -- "$0")"/run.py)

set -x
$PROFILE_CMD bash -c \
    "python3 -u run.py \"\$@\" 2>&1 | tee \"$PROFILE_DIR/report_np${WORLD_SIZE}_rank${RANK}.log\"" \
    bash "$@"
$PROFILE_CMD bash -o pipefail -c \
    "python3 -u \"\$1\" \"\${@:3}\" 2>&1 | tee \"\$2/report_np${WORLD_SIZE}_rank${RANK}.log\"" \
    bash "$SCRIPT_PATH" "$PROFILE_DIR" "$@"

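A sketch of how the rewritten invocation binds its positional arguments (the values `profiles`, rank 0, and the trailing benchmark arguments are examples): inside the inner shell, `$1` is the `run.py` path, `$2` is the profile directory, and `${@:3}` are the user arguments, while `-o pipefail` lets a failing `python3` propagate its exit code through the `tee` pipeline.

```bash
# Roughly equivalent expansion for PROFILE_DIR=profiles, WORLD_SIZE=4, RANK=0:
bash -o pipefail -c \
    'python3 -u "$1" "${@:3}" 2>&1 | tee "$2/report_np4_rank0.log"' \
    bash ./run.py profiles config_gen.yaml --moe-backend TRTLLM
```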
@@ -4,23 +4,23 @@ set -euo pipefail

# CONTAINER_IMAGE=
CONTAINER_NAME=${CONTAINER_NAME:-layer_wise_benchmarks}
CONTAINER_MOUNTS=$(realpath "$(pwd)/../.."):$(realpath "$(pwd)/../..")
TRTLLM_ROOT=$(realpath "$(dirname -- "$0")"/../..)
CONTAINER_MOUNTS=$TRTLLM_ROOT:$TRTLLM_ROOT

if [ "${SLURM_JOB_ID:-}" == "" ]; then
if [ -z "${SLURM_JOB_ID:-}" ]; then
    echo "Please set SLURM_JOB_ID"
    exit 1
fi

NODES=$(squeue -j $SLURM_JOB_ID -h -o "%D")

if [ "${CONTAINER_IMAGE:-}" == "" ]; then
if [ -z "${CONTAINER_IMAGE:-}" ]; then
    # Read Docker image from current_image_tags.properties
    source ../../jenkins/current_image_tags.properties
    MACHINE="$(uname -m)"
    if [ "$MACHINE" == "x86_64" ]; then
        DOCKER_IMAGE=$LLM_DOCKER_IMAGE
        DOCKER_IMAGE=$(source "$TRTLLM_ROOT/jenkins/current_image_tags.properties" && echo $LLM_DOCKER_IMAGE)
    elif [ "$MACHINE" == "aarch64" ]; then
        DOCKER_IMAGE=$LLM_SBSA_DOCKER_IMAGE
        DOCKER_IMAGE=$(source "$TRTLLM_ROOT/jenkins/current_image_tags.properties" && echo $LLM_SBSA_DOCKER_IMAGE)
    else
        echo "Unsupported machine hardware name \"$MACHINE\""
        exit 1
@@ -34,7 +34,7 @@ if [ "${CONTAINER_IMAGE:-}" == "" ]; then
    SQSH_FILE_NAME=$(echo "$DOCKER_IMAGE" |
        awk -F'#' '{print $2}' |
        awk -F':' '{gsub(/\//,"+",$1); print $1"+"$2".sqsh"}')
    CONTAINER_IMAGE="../../enroot/$SQSH_FILE_NAME"
    CONTAINER_IMAGE="$TRTLLM_ROOT/enroot/$SQSH_FILE_NAME"
    if [ ! -f "$CONTAINER_IMAGE" ]; then
        echo "Container image file $CONTAINER_IMAGE does not exist, importing ..."
        srun -N 1 --pty enroot import -o "$CONTAINER_IMAGE" "docker://$DOCKER_IMAGE"
@@ -50,6 +50,7 @@ srun -N "$NODES" \
    --container-name "$CONTAINER_NAME" \
    --container-mounts "$CONTAINER_MOUNTS" \
    --container-workdir "$WORKDIR" \
    bash -c "pip install -U packaging &&
        pip install -r ../../requirements.txt --no-build-isolation &&
        pip install -e ../.."
    bash -c "cd \"\$1\" &&
        pip install -U packaging &&
        pip install -r requirements.txt --no-build-isolation &&
        pip install -e ." bash "$TRTLLM_ROOT"

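To illustrate the `awk` pipeline that derives the enroot `.sqsh` file name, here is a sketch with a made-up image reference (the registry and tag are placeholders):

```bash
DOCKER_IMAGE="urm.example.com/docker-registry#tensorrt-llm/release:1.0"   # placeholder
echo "$DOCKER_IMAGE" |
    awk -F'#' '{print $2}' |
    awk -F':' '{gsub(/\//,"+",$1); print $1"+"$2".sqsh"}'
# prints: tensorrt-llm+release+1.0.sqsh -> cached as $TRTLLM_ROOT/enroot/tensorrt-llm+release+1.0.sqsh
```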
@@ -3,9 +3,10 @@
set -euo pipefail

CONTAINER_NAME=${CONTAINER_NAME:-layer_wise_benchmarks}
CONTAINER_MOUNTS=$(realpath "$(pwd)/../.."):$(realpath "$(pwd)/../..")
TRTLLM_ROOT=$(realpath "$(dirname -- "$0")"/../..)
CONTAINER_MOUNTS=$TRTLLM_ROOT:$TRTLLM_ROOT

if [ "${SLURM_JOB_ID:-}" == "" ]; then
if [ -z "${SLURM_JOB_ID:-}" ]; then
    echo "Please set SLURM_JOB_ID"
    exit 1
fi

@@ -2,6 +2,11 @@

set -euo pipefail

if [ -z "${SLURM_JOB_ID:-}" ]; then
    echo "Please set SLURM_JOB_ID"
    exit 1
fi

prefix="pyxis_${SLURM_JOB_ID}_"
matches=$(printf "%s\n" "$(srun -N 1 enroot list)" | grep "^${prefix}" || true)
count=$(printf "%s\n" "$matches" | wc -l)

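This is the helper behind the README tip `export CONTAINER_NAME=$(./slurm_query_container_name.sh)`: pyxis registers containers as `pyxis_<job id>_<name>`, and the script looks them up for the current `SLURM_JOB_ID`. A usage sketch (the job id is a placeholder, and the exact printed name is assumed rather than shown in this hunk):

```bash
export SLURM_JOB_ID=123456                                 # placeholder
export CONTAINER_NAME=$(./slurm_query_container_name.sh)   # e.g. layer_wise_benchmarks
NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml
```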
@@ -313,6 +313,7 @@
    const columnLabels = [];
    let columnTotals = [];
    let totalNode = null; // Special node for the "Total" row
    let numCols = null;

    const fmt = (num) => num.toFixed(1);

@@ -332,7 +333,7 @@
    const topLevelTimes = rawData.map(aggregate);

    // Calculate Totals
    const numCols = topLevelTimes[0].length;
    numCols = topLevelTimes[0].length;
    columnTotals = new Array(numCols).fill(0);
    for (let c = 0; c < numCols; c++) {
        columnTotals[c] = topLevelTimes.reduce((sum, row) => sum + row[c], 0);
@@ -420,7 +421,8 @@
    const total = columnTotals[colIdx] || 1;
    const pct = (val / total) * 100;

    html += `<td class="data-cell col-${colIdx}" onclick="handleCellClick('${uniqueId}', ${colIdx})">
    html += `<td class="data-cell col-${colIdx}" ${val ? "title=\"" + fmt(val) + "\"" : ""}
        onclick="handleCellClick('${uniqueId}', ${colIdx})">
        <div class="progress-bar" style="width: ${pct}%;"></div>
        <div class="cell-content">${val ? fmt(val) : ""}</div>
    </td>`;
@@ -717,6 +719,7 @@
    });

    window.onload = () => {
        document.getElementById('kernelTable').style["min-width"] = Math.max(1000, (220 + 60 * numCols)) + "px";
        renderHeader();
        document.getElementById('tableBody').innerHTML = renderTableBody(rawData, 0);
        renderTotalRow();
@@ -741,7 +744,6 @@
    for (let node of rawData) {
        maxDepth = Math.max(maxDepth, getDepth(node));
    }
    console.log(maxDepth);

    const container = document.getElementById('level-buttons');
    if (container) {

@@ -1,6 +1,6 @@
import nvtx

from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate
from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate, Deepseekv3MoE
from tensorrt_llm._torch.models.modeling_qwen3_next import (
    Qwen3NextGatedDeltaNet,
    Qwen3NextSparseMoeBlock,
@@ -12,6 +12,7 @@ from tensorrt_llm._torch.modules.gated_mlp import GatedMLP

def mark_ranges():
    DeepseekV3Gate.forward = nvtx.annotate("DeepseekV3Gate")(DeepseekV3Gate.forward)
    Deepseekv3MoE.forward = nvtx.annotate("Deepseekv3MoE")(Deepseekv3MoE.forward)
    Qwen3NextGatedDeltaNet.forward = nvtx.annotate("Qwen3NextGatedDeltaNet")(
        Qwen3NextGatedDeltaNet.forward
    )

@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import functools
|
||||
import itertools
|
||||
import os
|
||||
import unittest.mock
|
||||
import weakref
|
||||
@ -23,6 +24,7 @@ from tensorrt_llm._torch.utils import get_model_extra_attrs, model_extra_attrs
|
||||
from tensorrt_llm._utils import local_mpi_size, mpi_rank, mpi_world_size, torch_dtype_to_binding
|
||||
from tensorrt_llm.bindings.executor import KvCacheConfig
|
||||
from tensorrt_llm.bindings.internal.batch_manager import CacheType
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
|
||||
@ -38,24 +40,16 @@ def round_up(a, b):
|
||||
|
||||
|
||||
def get_balanced_selection_no_cache(
|
||||
num_tokens, top_k, num_experts, dtype, device, world_size, rank
|
||||
num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
|
||||
):
|
||||
# First, each sender selects target rank
|
||||
target_rank_before_mod = torch.arange(num_tokens * world_size * top_k).view(
|
||||
num_tokens, world_size, top_k
|
||||
token_id = torch.arange(dp_rank * num_tokens * top_k, (dp_rank + 1) * num_tokens * top_k).view(
|
||||
num_tokens, top_k
|
||||
)
|
||||
target_rank_before_mod += top_k * torch.arange(num_tokens).view(
|
||||
num_tokens, 1, 1
|
||||
) # Shift `top_k` ranks for the next token on each rank, to balance network traffic
|
||||
target_rank = target_rank_before_mod % world_size
|
||||
# Second, each receiver selects target expert
|
||||
target_expert = torch.empty_like(target_rank)
|
||||
for reciever_rank in range(world_size):
|
||||
mask = target_rank == reciever_rank
|
||||
experts_per_rank = num_experts // world_size
|
||||
local_expert = torch.arange(num_tokens * top_k) % experts_per_rank
|
||||
target_expert[mask] = (reciever_rank * experts_per_rank) + local_expert
|
||||
token_selected_experts = target_expert[:, rank].sort(dim=-1).values
|
||||
experts_per_rank = num_experts // ep_size
|
||||
token_selected_experts = (token_id % ep_size) * experts_per_rank + (
|
||||
token_id // ep_size
|
||||
) % experts_per_rank
|
||||
token_selected_experts = token_selected_experts.sort(dim=-1).values
|
||||
return token_selected_experts.contiguous().to(dtype=dtype, device=device)
|
||||
|
||||
|
||||
@ -64,47 +58,62 @@ get_balanced_selection = functools.cache(get_balanced_selection_no_cache)
|
||||
|
||||
def test_get_balanced_selection():
|
||||
dtype = torch.long
|
||||
for num_tokens in range(1, 33):
|
||||
for num_experts in range(1, 65):
|
||||
print(f"{num_tokens=} {num_experts=}")
|
||||
for top_k in range(1, min(11, num_experts)):
|
||||
for world_size in range(1, 65):
|
||||
if num_experts % world_size == 0:
|
||||
tokens_per_expert = torch.zeros(num_experts)
|
||||
for rank in range(world_size):
|
||||
token_selected_experts = get_balanced_selection_no_cache(
|
||||
num_tokens, top_k, num_experts, dtype, "cpu", world_size, rank
|
||||
for num_tokens, num_experts, enable_attention_dp in itertools.product(
|
||||
range(1, 35), range(1, 35), [False, True]
|
||||
):
|
||||
print(f"{num_tokens=} {num_experts=} {enable_attention_dp=}")
|
||||
for top_k in range(1, min(10, num_experts) + 1):
|
||||
for world_size in range(1, 35):
|
||||
dp_size = world_size if enable_attention_dp else 1
|
||||
ep_size = world_size
|
||||
if num_experts % ep_size == 0:
|
||||
tokens_per_expert = torch.zeros(num_experts)
|
||||
for dp_rank in range(dp_size):
|
||||
token_selected_experts = get_balanced_selection_no_cache(
|
||||
num_tokens, top_k, num_experts, dtype, "cpu", dp_size, dp_rank, ep_size
|
||||
)
|
||||
sorted_selection = token_selected_experts.sort(dim=-1).values
|
||||
if (sorted_selection[:, :-1] == sorted_selection[:, 1:]).any():
|
||||
raise ValueError(f"duplicated experts on rank {dp_rank}")
|
||||
experts_per_rank = num_experts // ep_size
|
||||
tokens_per_rank = (
|
||||
(token_selected_experts // experts_per_rank)
|
||||
.view(-1)
|
||||
.bincount(minlength=ep_size)
|
||||
)
|
||||
if tokens_per_rank.max() - tokens_per_rank.min() > 1:
|
||||
raise ValueError(f"tokens sent from rank {dp_rank} is not balanced")
|
||||
unique_tokens_per_rank = (
|
||||
(
|
||||
torch.arange(ep_size).view(ep_size, 1, 1)
|
||||
== token_selected_experts // experts_per_rank
|
||||
)
|
||||
sorted_selection = token_selected_experts.sort(dim=-1).values
|
||||
if (sorted_selection[:, :-1] == sorted_selection[:, 1:]).any():
|
||||
raise ValueError(f"duplicated experts on rank {rank}")
|
||||
experts_per_rank = num_experts // world_size
|
||||
tokens_per_rank = (
|
||||
(token_selected_experts // experts_per_rank)
|
||||
.view(-1)
|
||||
.bincount(minlength=world_size)
|
||||
.any(dim=2)
|
||||
.sum(dim=1)
|
||||
)
|
||||
if unique_tokens_per_rank.max() - unique_tokens_per_rank.min() > 1:
|
||||
raise ValueError(
|
||||
f"tokens sent from rank {dp_rank} is not balanced after removing duplicates"
|
||||
)
|
||||
if tokens_per_rank.max() - tokens_per_rank.min() > 1:
|
||||
raise ValueError(f"tokens sent from rank {rank} is not balanced")
|
||||
tokens_per_expert += token_selected_experts.view(-1).bincount(
|
||||
minlength=num_experts
|
||||
)
|
||||
if tokens_per_expert.max() - tokens_per_expert.min() > 1:
|
||||
raise ValueError("tokens per expert is not balanced")
|
||||
tokens_per_expert += token_selected_experts.view(-1).bincount(
|
||||
minlength=num_experts
|
||||
)
|
||||
if tokens_per_expert.max() - tokens_per_expert.min() > 1:
|
||||
raise ValueError("tokens per expert is not balanced")
|
||||
|
||||
|
||||
def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank):
|
||||
def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size):
|
||||
num_tokens, top_k = imbalanced_experts.shape
|
||||
dtype = imbalanced_experts.dtype
|
||||
device = imbalanced_experts.device
|
||||
balanced_experts = get_balanced_selection_no_cache(
|
||||
num_tokens, top_k, num_experts, dtype, device, world_size, rank
|
||||
num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
|
||||
)
|
||||
if balance_ratio == 0.0:
|
||||
num_balanced_tokens = 0
|
||||
else:
|
||||
# Activate all experts
|
||||
min_num_balanced_tokens = min(num_tokens, ceil_div(num_experts, world_size * top_k))
|
||||
min_num_balanced_tokens = min(num_tokens, ceil_div(num_experts, dp_size * top_k))
|
||||
num_balanced_tokens = min_num_balanced_tokens + round(
|
||||
(num_tokens - min_num_balanced_tokens) * balance_ratio
|
||||
)
|
||||
@ -116,57 +125,78 @@ def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_si
|
||||
|
||||
@functools.cache
|
||||
def get_all_to_one_selection(
|
||||
num_tokens, top_k, num_experts, balance_ratio, dtype, device, world_size, rank
|
||||
num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size
|
||||
):
|
||||
experts_per_rank = num_experts // world_size
|
||||
experts_per_rank = num_experts // ep_size
|
||||
if top_k > experts_per_rank:
|
||||
raise ValueError(
|
||||
"Cannot send all tokens to a single rank because `top_k > experts_per_rank`"
|
||||
)
|
||||
imbalanced_experts = (
|
||||
torch.arange(
|
||||
rank * num_tokens * top_k, (rank + 1) * num_tokens * top_k, dtype=dtype, device=device
|
||||
dp_rank * num_tokens * top_k,
|
||||
(dp_rank + 1) * num_tokens * top_k,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
).view(num_tokens, top_k)
|
||||
% experts_per_rank
|
||||
)
|
||||
imbalanced_experts = imbalanced_experts.sort(dim=-1).values
|
||||
return apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank)
|
||||
return apply_balance_ratio(
|
||||
imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size
|
||||
)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_balanced_rank_imbalanced_expert_selection(
|
||||
num_tokens, top_k, num_experts, balance_ratio, dtype, device, world_size, rank
|
||||
num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size
|
||||
):
|
||||
experts_per_rank = num_experts // world_size
|
||||
active_experts_per_rank = ceil_div(top_k, world_size)
|
||||
# Select expert from [0, active_experts_per_rank * world_size),
|
||||
# then scale to [0, experts_per_rank * world_size)
|
||||
experts_per_rank = num_experts // ep_size
|
||||
active_experts_per_rank = ceil_div(top_k, ep_size)
|
||||
# Select expert from [0, active_experts_per_rank * ep_size),
|
||||
# then scale to [0, experts_per_rank * ep_size)
|
||||
narrow_experts = get_balanced_selection_no_cache(
|
||||
num_tokens, top_k, active_experts_per_rank * world_size, dtype, device, world_size, rank
|
||||
num_tokens,
|
||||
top_k,
|
||||
active_experts_per_rank * ep_size,
|
||||
dtype,
|
||||
device,
|
||||
dp_size,
|
||||
dp_rank,
|
||||
ep_size,
|
||||
)
|
||||
imbalanced_experts = (
|
||||
narrow_experts // active_experts_per_rank * experts_per_rank
|
||||
+ narrow_experts % active_experts_per_rank
|
||||
)
|
||||
return apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank)
|
||||
return apply_balance_ratio(
|
||||
imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size
|
||||
)
|
||||
|
||||
|
||||
def make_balanced_routing_method(
|
||||
apply_method_orig, num_experts, balance_method, balance_ratio, world_size, rank
|
||||
moe_module,
|
||||
apply_method_orig,
|
||||
num_experts,
|
||||
balance_method,
|
||||
balance_ratio,
|
||||
dp_size,
|
||||
dp_rank,
|
||||
ep_size,
|
||||
):
|
||||
def balanced_routing_method(router_logits):
|
||||
token_selected_experts, token_final_scales = apply_method_orig(router_logits)
|
||||
if balance_method == BalanceMethod.NotModified:
|
||||
pass
|
||||
elif balance_method == BalanceMethod.Balanced:
|
||||
assert moe_module._routing_results_replaced_at in [None, "make_balanced_routing_method"]
|
||||
if balance_method == BalanceMethod.Balanced:
|
||||
token_selected_experts = get_balanced_selection(
|
||||
token_selected_experts.shape[0],
|
||||
token_selected_experts.shape[1],
|
||||
num_experts,
|
||||
token_selected_experts.dtype,
|
||||
token_selected_experts.device,
|
||||
world_size,
|
||||
rank,
|
||||
dp_size,
|
||||
dp_rank,
|
||||
ep_size,
|
||||
)
|
||||
elif balance_method == BalanceMethod.ImbalancedRanks:
|
||||
token_selected_experts = get_all_to_one_selection(
|
||||
@ -176,8 +206,9 @@ def make_balanced_routing_method(
|
||||
balance_ratio,
|
||||
token_selected_experts.dtype,
|
||||
token_selected_experts.device,
|
||||
world_size,
|
||||
rank,
|
||||
dp_size,
|
||||
dp_rank,
|
||||
ep_size,
|
||||
)
|
||||
elif balance_method == BalanceMethod.ImbalancedExperts:
|
||||
token_selected_experts = get_balanced_rank_imbalanced_expert_selection(
|
||||
@ -187,16 +218,127 @@ def make_balanced_routing_method(
|
||||
balance_ratio,
|
||||
token_selected_experts.dtype,
|
||||
token_selected_experts.device,
|
||||
world_size,
|
||||
rank,
|
||||
dp_size,
|
||||
dp_rank,
|
||||
ep_size,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Not support balance_method {balance_method}")
|
||||
moe_module._routing_results_replaced_at = "make_balanced_routing_method"
|
||||
return token_selected_experts, token_final_scales
|
||||
|
||||
return balanced_routing_method
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_token_final_scales(shape, device):
|
||||
return torch.full(shape, 1.0 / shape[-1], dtype=torch.bfloat16, device=device)
|
||||
|
||||
|
||||
def make_balanced_run_moe(
|
||||
moe_module,
|
||||
run_moe_orig,
|
||||
top_k,
|
||||
num_experts,
|
||||
balance_method,
|
||||
balance_ratio,
|
||||
dp_size,
|
||||
dp_rank,
|
||||
ep_size,
|
||||
):
|
||||
def balanced_run_moe(
|
||||
x, token_selected_experts, token_final_scales, x_sf, router_logits, do_finalize, moe_output
|
||||
):
|
||||
if moe_module._routing_results_replaced_at is not None:
|
||||
return run_moe_orig(
|
||||
x,
|
||||
token_selected_experts,
|
||||
token_final_scales,
|
||||
x_sf,
|
||||
router_logits,
|
||||
do_finalize,
|
||||
moe_output,
|
||||
)
|
||||
logger.warning_once(
|
||||
'Layer-wise benchmarks: Specifying routing results of "TRTLLM" MoE backend in TEP cases leads to different'
|
||||
" execution path around the topk kernel",
|
||||
key="replace_routing_method_ctx_trtllm_tp",
|
||||
)
|
||||
if balance_method == BalanceMethod.Balanced:
|
||||
token_selected_experts = get_balanced_selection(
|
||||
x.shape[0],
|
||||
top_k,
|
||||
num_experts,
|
||||
torch.int32,
|
||||
x.device,
|
||||
dp_size,
|
||||
dp_rank,
|
||||
ep_size,
|
||||
)
|
||||
elif balance_method == BalanceMethod.ImbalancedRanks:
|
||||
token_selected_experts = get_all_to_one_selection(
|
||||
x.shape[0],
|
||||
top_k,
|
||||
num_experts,
|
||||
balance_ratio,
|
||||
torch.int32,
|
||||
x.device,
|
||||
dp_size,
|
||||
dp_rank,
|
||||
ep_size,
|
||||
)
|
||||
elif balance_method == BalanceMethod.ImbalancedExperts:
|
||||
token_selected_experts = get_balanced_rank_imbalanced_expert_selection(
|
||||
x.shape[0],
|
||||
top_k,
|
||||
num_experts,
|
||||
balance_ratio,
|
||||
torch.int32,
|
||||
x.device,
|
||||
dp_size,
|
||||
dp_rank,
|
||||
ep_size,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Not support balance_method {balance_method}")
|
||||
token_final_scales = get_token_final_scales(
|
||||
token_selected_experts.shape, token_selected_experts.device
|
||||
)
|
||||
router_logits = None
|
||||
final_hidden_states = run_moe_orig(
|
||||
x,
|
||||
token_selected_experts,
|
||||
token_final_scales,
|
||||
x_sf,
|
||||
router_logits,
|
||||
do_finalize,
|
||||
moe_output,
|
||||
)
|
||||
if not do_finalize:
|
||||
final_hidden_states = (
|
||||
final_hidden_states[0],
|
||||
token_final_scales, # WAR for TRTLLMGenFusedMoE bug that it returns wrong `token_final_scales`
|
||||
final_hidden_states[2],
|
||||
)
|
||||
moe_module._routing_results_replaced_at = "make_balanced_run_moe"
|
||||
return final_hidden_states
|
||||
|
||||
return balanced_run_moe
|
||||
|
||||
|
||||
def make_forward_impl_check(moe_module, forward_impl_orig):
|
||||
def forward_impl(*args, **kwargs):
|
||||
moe_module._routing_results_replaced_at = None
|
||||
res = forward_impl_orig(*args, **kwargs)
|
||||
assert moe_module._routing_results_replaced_at is not None, (
|
||||
"Routing results are not replaced"
|
||||
)
|
||||
del moe_module._routing_results_replaced_at
|
||||
return res
|
||||
|
||||
return forward_impl
|
||||
|
||||
|
||||
class RunnerMixin(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
@ -398,8 +540,9 @@ class RunnerMixin(ABC):
|
||||
@contextlib.contextmanager
|
||||
def replace_routing_method_ctx(self, balance_method: BalanceMethod, balance_ratio: float):
|
||||
if balance_method == BalanceMethod.NotModified:
|
||||
pass
|
||||
elif self.model_config.moe_backend not in [
|
||||
yield
|
||||
return
|
||||
if self.model_config.moe_backend not in [
|
||||
"CUTEDSL",
|
||||
"CUTLASS",
|
||||
"DEEPGEMM",
|
||||
@ -410,29 +553,59 @@ class RunnerMixin(ABC):
|
||||
f'Not support replace routing method for moe_backend "{self.model_config.moe_backend}",'
|
||||
f' please set balance_method to "NotModified"'
|
||||
)
|
||||
elif (
|
||||
self.model_config.moe_backend == "TRTLLM"
|
||||
and not self.model_config.mapping.enable_attention_dp
|
||||
):
|
||||
raise NotImplementedError(
|
||||
'Not support replace routing method for moe_backend "TRTLLM" with attention TP,'
|
||||
' please set balance_method to "NotModified"'
|
||||
original_methods = []
|
||||
dp_rank = self.model_config.mapping.rank // (
|
||||
self.model_config.mapping.world_size // self.model_config.mapping.dp_size
|
||||
)
|
||||
for layer in self.layers:
|
||||
moe_module = layer.mlp.experts
|
||||
|
||||
# Replace `routing_method.apply` for normal cases
|
||||
apply_method_orig = moe_module.routing_method.apply
|
||||
moe_module.routing_method.apply = make_balanced_routing_method(
|
||||
moe_module,
|
||||
apply_method_orig,
|
||||
moe_module.num_experts,
|
||||
balance_method,
|
||||
balance_ratio,
|
||||
self.model_config.mapping.dp_size,
|
||||
dp_rank,
|
||||
self.model_config.mapping.moe_ep_size,
|
||||
)
|
||||
apply_methods_orig = [layer.mlp.experts.routing_method.apply for layer in self.layers]
|
||||
try:
|
||||
for layer, apply_method_orig in zip(self.layers, apply_methods_orig):
|
||||
layer.mlp.experts.routing_method.apply = make_balanced_routing_method(
|
||||
apply_method_orig,
|
||||
|
||||
# Replace `run_moe` for TRTLLMGenFusedMoE TEP because it does not call `routing_method.apply`
|
||||
if isinstance(moe_module, TRTLLMGenFusedMoE):
|
||||
run_moe_orig = moe_module.run_moe
|
||||
moe_module.run_moe = make_balanced_run_moe(
|
||||
moe_module,
|
||||
run_moe_orig,
|
||||
layer.mlp.experts.routing_method.top_k,
|
||||
layer.mlp.experts.num_experts,
|
||||
balance_method,
|
||||
balance_ratio,
|
||||
layer.mlp.experts.ep_size,
|
||||
layer.mlp.experts.ep_rank,
|
||||
self.model_config.mapping.dp_size,
|
||||
dp_rank,
|
||||
self.model_config.mapping.moe_ep_size,
|
||||
)
|
||||
else:
|
||||
run_moe_orig = None
|
||||
|
||||
# Replace `forward_impl` to ensure that routing results are replaced
|
||||
forward_impl_orig = moe_module.forward_impl
|
||||
moe_module.forward_impl = make_forward_impl_check(moe_module, forward_impl_orig)
|
||||
|
||||
original_methods.append((apply_method_orig, run_moe_orig, forward_impl_orig))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for layer, apply_method_orig in zip(self.layers, apply_methods_orig):
|
||||
layer.mlp.experts.routing_method.apply = apply_method_orig
|
||||
for layer, (apply_method_orig, run_moe_orig, forward_impl_orig) in zip(
|
||||
self.layers, original_methods
|
||||
):
|
||||
moe_module = layer.mlp.experts
|
||||
moe_module.routing_method.apply = apply_method_orig
|
||||
if isinstance(moe_module, TRTLLMGenFusedMoE):
|
||||
moe_module.run_moe = run_moe_orig
|
||||
moe_module.forward_impl = forward_impl_orig
|
||||
|
||||
@staticmethod
|
||||
def create_kv_cache_manager(
|
||||
@ -467,7 +640,7 @@ class RunnerMixin(ABC):
|
||||
kv_cache_manager = kv_cache_manager_cls(
|
||||
kv_cache_config,
|
||||
CacheType.SELFKONLY,
|
||||
num_layers=sum(layer_mask),
|
||||
num_layers=num_layers,
|
||||
num_kv_heads=1,
|
||||
head_dim=model_config.pretrained_config.kv_lora_rank
|
||||
+ model_config.pretrained_config.qk_rope_head_dim,
|
||||
|
||||
@@ -48,7 +48,6 @@ def test_deepseek_r1_ctx_tep(llm_root, world_size):
            model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2",
            "--no-enable-attention-dp",
            "--moe-backend=TRTLLM",
            "--balance-method=NotModified",
        ],
        cwd=llm_root / "examples" / "layer_wise_benchmarks",
        env={
@@ -139,7 +138,6 @@ def test_qwen3_next_gen_tep(llm_root, world_size):
            "--layer-indices=6,7",
            "--no-enable-attention-dp",
            "--moe-backend=TRTLLM",
            "--balance-method=NotModified",
        ],
        cwd=llm_root / "examples" / "layer_wise_benchmarks",
        env={