Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Commit a7fe043b13 (parent aaf80be0f3)
[None][feat] Layer-wise benchmarks: support TEP balance, polish slurm scripts (#10237)
Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
.gitignore (vendored), 1 line added

@@ -77,6 +77,7 @@ cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmha_v2_cu/
 cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.h
 cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_cubin.cpp
 .devcontainer/.env
+/examples/layer_wise_benchmarks/autotuner_cache/
 /examples/layer_wise_benchmarks/profiles/
 
 # User config files

@@ -15,6 +15,9 @@ pip install -e ../..
 **Step 3:** In the container, run benchmarks and generate profiles:
 
 ```bash
+# Set autotune cache path
+export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache
+
 # Run DeepSeek-R1 NVFP4
 NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml
 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml
@@ -24,16 +27,16 @@ NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-
 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
 
 # Run DeepSeek-V3.2-Exp with 32k context length
-NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
-NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --seq-len-kv-cache 32769
+NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
+NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM --seq-len-kv-cache 32769
 
 # Run with attention TP
 NP=4 ./mpi_launch.sh ./run.sh config_ctx.yaml --no-enable-attention-dp
 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --no-enable-attention-dp
 
 # Run with attention TP and TRTLLMGen
-NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified
-NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM --balance-method NotModified
+NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM
+NP=4 ./mpi_launch.sh -x TRTLLM_ENABLE_PDL=1 ./run.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM
 
 # Run with MTP3
 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --batch-size 32 --seq-len-q 4
@@ -48,7 +51,7 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --moe-backend WID
 # Scale TEP=16 to 4 GPUs: reduce the number of attention heads and experts
 NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp
 
-# Run Qwen3-Next (balanced routing is not implemented)
+# Run Qwen3-Next
 NP=2 ./mpi_launch.sh ./run.sh config_ctx.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 4
 NP=2 ./mpi_launch.sh ./run.sh config_gen.yaml --model Qwen/Qwen3-Next-80B-A3B-Instruct --layer-indices 6,7 --no-enable-attention-dp --batch-size 512
 
@@ -69,13 +72,13 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --balance-method ImbalancedExperts
 ### Run with Slurm
 
 > Tips:
-> 1. If you have a running Slurm job, please skip step 1 and go straight to step 2 and 3.
-> 2. Further, if you have installed `tensorrt_llm` in the Slurm job, you can also skip step 2 and run step 3 with `export CONTAINER_NAME=aaa` specified. If you don't know the container name, run `export CONTAINER_NAME=$(SLURM_JOB_ID=$SLURM_JOB_ID ./slurm_query_container_name.sh)` to get it.
+> 1. If you have a running Slurm job, you can set environment variable `export SLURM_JOB_ID=aaa` and skip step 1.
+> 2. Further, if you have installed `tensorrt_llm` in the Slurm job, you can also skip step 2. Just run step 3 with `export CONTAINER_NAME=aaa` specified. If you don't know the container name, run `export CONTAINER_NAME=$(./slurm_query_container_name.sh)` to get it.
 
-**Step 1:** On the controller node, allocate one or multiple nodes, and record the `SLURM_JOB_ID`:
+**Step 1:** On the controller node, allocate one or multiple nodes, and export the `SLURM_JOB_ID`:
 
 ```bash
-SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh)
+export SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh)
 ```
 
 Please fill the variables in `./slurm_alloc.sh`.
@@ -83,28 +86,47 @@ Please fill the variables in `./slurm_alloc.sh`.
 **Step 2:** Start a container and install `tensorrt_llm`. Run the following command on the controller node:
 
 ```bash
-SLURM_JOB_ID=$SLURM_JOB_ID ./slurm_init_containers.sh
+./slurm_init_containers.sh
 ```
 
 It uses the image recorded in `../../jenkins/current_image_tags.properties`. The image will be downloaded to `../../enroot/` for once.
 
 > Tips: If you want to change the image, no need to reallocate Slurm jobs. Just start another container by running step 2 with `export CONTAINER_NAME=aaa`, and step 3 will run in the container specified by the `CONTAINER_NAME` env.
 
+**(Optional) Get an interactive shell**
+
+```bash
+NODES=1 NP=1 ./slurm_launch.sh --overlap --pty middleware/exclude_slurm_envs bash
+```
+
+The `--overlap` option allows this shell to share the node with other jobs. The middleware enables nested MPI process spawning from within Slurm jobs.
+
+You may compile C++ extensions in the interactive shell:
+
+```bash
+cd ../..
+export CCACHE_DIR=$(realpath cpp/.ccache)
+python3 scripts/build_wheel.py --cuda_architectures native --no-venv --skip_building_wheel -G Ninja --use_ccache --clean
+```
+
 **Step 3:** Run benchmarks to generate profiles. Run the following command on the controller node, where `NODES` ≤ the number of allocated nodes:
 
 ```bash
+# Set autotune cache path
+export TLLM_AUTOTUNER_CACHE_PATH=autotuner_cache/cache
+
 # Run DeepSeek-R1 NVFP4 with wide ep: uses MNNVL A2A if applicable
-SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
+NODES=4 NP=16 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
 
 # Run with TRTLLMGen
-SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend TRTLLM
+NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend TRTLLM
 
 # Run with DeepEPLowLatency
-SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
+NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run.sh config_gen.yaml --moe-backend WIDEEP
 
-# You can run 4-GPU and 8-GPU tasks without reallocate the slurm job
-SLURM_JOB_ID=$SLURM_JOB_ID NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml
-SLURM_JOB_ID=$SLURM_JOB_ID NODES=2 NP=8 ./slurm_launch.sh ./run.sh config_gtx.yaml
+# You can run 4-GPU and 8-GPU tasks without reallocating the slurm job
+NODES=1 NP=4 ./slurm_launch.sh ./run.sh config_ctx.yaml
+NODES=2 NP=8 ./slurm_launch.sh ./run.sh config_gen.yaml
 ```
 
 ### Batched run
@@ -131,9 +153,11 @@ NP=4 ./mpi_launch.sh ./run.sh config_gen.yaml --scaled-from 16 --moe-backend WID
 Run the following command in the container:
 
 ```bash
+# Parse the profile at the default directory
 python3 parse.py --world-size 4
 
-# Specify the location of the .nsys-rep file
+# Specify the file path
+python3 parse.py --file-path profiles/report_np4_rank0.nsys-rep
 python3 parse.py --profile-dir ./profiles --world-size 4 --rank 0
 
 # Parse a specific module. The module must appear exactly once in each run.
@@ -145,6 +169,15 @@ You will receive three reports, each containing kernel timing statistics grouped
 2. A CSV report at `profiles/report_np4_rank0.csv`
 3. An HTML report at `profiles/report_np4_rank0.html`
 
+## Developer utilities
+
+1. Less startup time when debug a model
+   1. Disable autotuner: add `--no-enable-autotuner` option
+   2. Disable nsys profile: set `PROFILE=0` environment variable
+2. Capture more information
+   1. Enable GPU metrics: set `GPU_METRICS=1` environment variable
+   2. Enable backtrace: set `BACKTRACE=1` environment variable
+
 ## Trouble shooting
 
 1. Error `fp8 blockscale gemm only support Hopper` on Blackwell.

examples/layer_wise_benchmarks/middleware/exclude_slurm_envs (new executable file, 8 lines)

@@ -0,0 +1,8 @@
+#!/bin/bash
+
+set -euo pipefail
+
+# Clear slurm envs
+unset $(env | awk -F'=' '{print $1}' | (grep -E "SLURM_|SLURMD_|slurm_|MPI_|PMIX_" || true))
+
+"$@"

@@ -3,8 +3,12 @@
 set -euo pipefail
 
 # Clear slurm envs
-unset $(env | grep -i slurm | awk -F'=' '{print $1}')
-unset $(env | grep MPI | awk -F'=' '{print $1}')
+unset $(env | awk -F'=' '{print $1}' | (grep -E "SLURM_|SLURMD_|slurm_|MPI_|PMIX_" || true))
 
+extra_args=
+if [ -v TLLM_AUTOTUNER_CACHE_PATH ]; then
+    extra_args+="-x TLLM_AUTOTUNER_CACHE_PATH"
+fi
+
 set -x
-mpirun --allow-run-as-root --np ${NP} "$@"
+mpirun --allow-run-as-root --np ${NP} $extra_args "$@"

@@ -5,6 +5,7 @@ import json
 import re
 import sqlite3
 import subprocess
+import sys
 from collections import defaultdict
 from pathlib import Path
 
@@ -14,8 +15,9 @@ import pandas as pd
 
 # Parse cmdline
 parser = argparse.ArgumentParser()
+parser.add_argument("--file-path", type=str)
 parser.add_argument("--profile-dir", type=str, default="profiles")
-parser.add_argument("--world-size", "--np", type=int, required=True)
+parser.add_argument("--world-size", "--np", type=int)
 parser.add_argument("--rank", type=int, default=0)
 parser.add_argument("--warmup-times", type=int)
 parser.add_argument("--module", type=str)
@@ -27,6 +29,8 @@ group.add_argument(
 )
 parser.set_defaults(error_on_unknown_kernel=False)
 args = parser.parse_args()
+if (args.file_path is None) == (args.world_size is None):
+    parser.error("Please specify exactly one of --file-path and --world-size.")
 print(args)
 
 
@@ -89,11 +93,20 @@ def shortest_common_supersequence(a, b):
     return res
 
 
-profile_dir = Path(args.profile_dir)
-nsys_rep_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.nsys-rep"
-sqlite_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.sqlite"
-csv_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.csv"
-html_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.html"
+if args.file_path is not None:
+    nsys_rep_file_path = Path(args.file_path)
+    if not nsys_rep_file_path.name.endswith(".nsys-rep"):
+        raise ValueError("Expect a .nsys-rep file")
+else:
+    profile_dir = Path(args.profile_dir)
+    nsys_rep_file_path = profile_dir / f"report_np{args.world_size}_rank{args.rank}.nsys-rep"
+sqlite_file_path = nsys_rep_file_path.parent / (
+    nsys_rep_file_path.name[: -len(".nsys-rep")] + ".sqlite"
+)
+csv_file_path = nsys_rep_file_path.parent / (nsys_rep_file_path.name[: -len(".nsys-rep")] + ".csv")
+html_file_path = nsys_rep_file_path.parent / (
+    nsys_rep_file_path.name[: -len(".nsys-rep")] + ".html"
+)
 lazy_convert_sqlite(nsys_rep_file_path, sqlite_file_path)
 
 conn = sqlite3.connect(f"file:{sqlite_file_path}?mode=ro", uri=True)
@@ -196,7 +209,8 @@ if "CUDA_GRAPH_NODE_EVENTS" in tables:
                R.start AS runtime_start, R.end AS runtime_end,
                CGE2.start AS capture_start, CGE2.end AS capture_end
         FROM ({unified_subquery}) AS unified
-        JOIN CUPTI_ACTIVITY_KIND_RUNTIME AS R ON unified.correlationId = R.correlationId
+        JOIN CUPTI_ACTIVITY_KIND_RUNTIME AS R ON unified.graphNodeId IS NOT NULL AND
+                                                 unified.correlationId = R.correlationId
         LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE1 ON unified.graphNodeId = CGE1.graphNodeId AND
                                                     CGE1.originalGraphNodeId IS NOT NULL
         LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE2 ON CGE1.originalGraphNodeId = CGE2.graphNodeId"""
@@ -318,6 +332,7 @@ parser_keywords = [
     ("routingInitExpertCounts", "routingInitExpertCounts"),
     ("routingIndicesCluster", "routingIndicesClusterKernel"),
     ("routingIndicesCoop", "routingIndicesCoopKernel"),
+    ("router_gemm", "router_gemm_kernel"),
     ("bmm_4_44_32", "bmm_E2m1_E2m1E2m1_Fp32_t"),
     ("finalize", "finalize::finalizeKernel"),
     ("bmm_16_44_32", "bmm_Bfloat16_E2m1E2m1_Fp32_"),
@@ -381,7 +396,7 @@ def parse_kernel_name(demangledName):
         if all(keyword in name for keyword in src):
             return dst
     if name not in warned_names:
-        print(f"Unknown kernel name: {name}")
+        print(f"Unknown kernel name: {name}", file=sys.stderr)
         warned_names.add(name)
     if args.error_on_unknown_kernel:
         raise NotImplementedError(f"Unknown kernel name: {name}")
@@ -491,11 +506,33 @@ with csv_file_path.open("w", newline="") as f:
     for row in csv_data:
         csv_writer.writerow(row)
-js_header_config = [{"name": problem["text"]} for problem in problem_set]
+js_header_config = []
+for problem in problem_set:
+    innermost_children = js_header_config
+    for k, msg_prefix in [
+        ("batch_size", "b="),
+        ("seq_len_q", "q="),
+        ("seq_len_kv_cache", "past="),
+    ]:
+        if len(run_args[k + "_list"]) > 1:
+            if len(innermost_children) == 0 or problem["spec"][k] != innermost_children[-1][k]:
+                innermost_children.append(
+                    {
+                        "name": msg_prefix + str(problem["spec"][k]),
+                        "children": [],
+                        k: problem["spec"][k],
+                    }
+                )
+            innermost_children = innermost_children[-1]["children"]
+    innermost_children.append({"name": problem["text"]})
 loader = jinja2.FileSystemLoader(Path(__file__).parent)
 template = jinja2.Environment(loader=loader).get_template("template.html")
 with html_file_path.open("w") as f:
     configText = (
-        "Run:\n" + json.dumps(run_args, indent=4) + "\n\nParse:\n" + json.dumps(args.__dict__)
+        "Run:\n"
+        + json.dumps(run_args, indent=4)
+        + "\n\nParse:\n"
+        + json.dumps(args.__dict__, indent=4)
     )
     f.write(template.render(headerConfig=js_header_config, rawData=js_data, configText=configText))
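
As a quick orientation for the new `--file-path` handling above, here is a minimal sketch of how the sibling output paths are derived from a `.nsys-rep` report (the report name matches the examples in the README; the snippet itself is not part of the patch):

```python
from pathlib import Path

# Minimal sketch of the sibling-path derivation used above (illustration only).
rep = Path("profiles/report_np4_rank0.nsys-rep")
stem = rep.name[: -len(".nsys-rep")]

print(rep.parent / (stem + ".sqlite"))  # profiles/report_np4_rank0.sqlite
print(rep.parent / (stem + ".csv"))     # profiles/report_np4_rank0.csv
print(rep.parent / (stem + ".html"))    # profiles/report_np4_rank0.html
```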

@@ -2,6 +2,7 @@ import argparse
 import itertools
 import json
 import os
+from unittest import mock
 
 import numpy as np
 import nvtx
@@ -9,6 +10,8 @@ import torch
 import yaml
 
 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
+from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
+from tensorrt_llm._torch.modules.fused_moe.interface import AlltoallMethodType
 from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
 from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
 from tensorrt_llm.logger import logger
@@ -108,17 +111,6 @@ if args.enable_attention_dp is None:
     args.enable_attention_dp = False
 if args.max_num_tokens is None:
     args.max_num_tokens = args.max_batch_size * max(args.seq_len_q_list)
-    if args.run_type == "GEN":
-        ctx_batch_size = max(1, max(20480, args.max_num_tokens) // max(args.seq_len_kv_cache_list))
-        args.max_num_tokens = max(
-            args.max_num_tokens, ctx_batch_size * max(args.seq_len_kv_cache_list)
-        )
-else:
-    if args.run_type == "GEN":
-        ctx_batch_size = max(1, args.max_num_tokens // max(args.seq_len_kv_cache_list))
-        assert args.max_num_tokens >= ctx_batch_size * max(args.seq_len_kv_cache_list), (
-            "Max_num_tokens is too small to prefill KV cache"
-        )
 if args.use_low_precision_moe_combine is None:
     args.use_low_precision_moe_combine = False
 if args.enable_autotuner is None:
@@ -169,17 +161,75 @@ runner = Runner(
 )
 logger.info("Layer-wise benchmarks: Create runner ... Done")
 
+# Autotune
+run_pack = runner.create_run_pack(
+    args.run_type,
+    batch_size=max(args.batch_size_list),
+    request_id_begin=0,
+    seq_len_q=max(args.seq_len_q_list),
+    seq_len_kv_cache=args.seq_len_kv_cache_list[0],
+    kv_cache_manager=kv_cache_manager,
+    attn_workspace=attn_workspace,
+)
+if args.enable_autotuner:
+    cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
+    with autotune(cache_path=cache_path):
+        run_pack()
+else:
+    run_pack()
+
+# Prefill KV cache
+if args.run_type == "GEN":
+    logger.info("Layer-wise benchmarks: Create runner for prefill")
+    ctx_seq_len_q = max(args.seq_len_kv_cache_list)
+    ctx_batch_size = min(
+        args.max_batch_size,
+        max(1, 20480 // ctx_seq_len_q),
+    )
+    ctx_attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
+    with mock.patch.object(
+        CutlassFusedMoE, "select_alltoall_method_type", return_value=AlltoallMethodType.NotEnabled
+    ):
+        ctx_runner = Runner(
+            args.model,
+            mapping,
+            moe_backend="CUTLASS",
+            layer_indices=args.layer_indices,
+            scaled_from=args.scaled_from,
+            max_seq_len=args.max_seq_len,
+            max_num_tokens=ctx_batch_size * ctx_seq_len_q,
+            moe_max_num_tokens=16384,
+            use_low_precision_moe_combine=args.use_low_precision_moe_combine,
+            use_cuda_graph=False,
+        )
+    logger.info("Layer-wise benchmarks: Create runner for prefill ... Done")
+
+    logger.info("Layer-wise benchmarks: Prefill KV cache")
+    assert ctx_batch_size <= args.max_batch_size
+    assert ctx_seq_len_q + 0 <= args.max_seq_len
+    num_requests = max(args.batch_size_list)
+    for request_id_begin in range(0, num_requests, ctx_batch_size):
+        run_pack = ctx_runner.create_run_pack(
+            "CTX",
+            batch_size=min(ctx_batch_size, num_requests - request_id_begin),
+            request_id_begin=request_id_begin,
+            seq_len_q=ctx_seq_len_q,
+            seq_len_kv_cache=0,
+            kv_cache_manager=kv_cache_manager,
+            attn_workspace=ctx_attn_workspace,
+        )
+        with ctx_runner.replace_routing_method_ctx(
+            balance_method=BalanceMethod.Balanced, balance_ratio=None
+        ):
+            run_pack(check=True)
+    del ctx_runner
+    del ctx_attn_workspace
+    logger.info("Layer-wise benchmarks: Prefill KV cache ... Done")
+
 # Warm up
-for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
-    [
-        True,
-        max(args.batch_size_list),
-        max(args.seq_len_q_list),
-        args.seq_len_kv_cache_list[0],
-        args.balance_ratio_list[0],
-    ],
+logger.info("Layer-wise benchmarks: Warmup")
+for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
     *itertools.product(
-        [False],
         args.batch_size_list,
         args.seq_len_q_list,
         args.seq_len_kv_cache_list,
@@ -203,34 +253,10 @@ for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
 ):
     capture_stream.wait_stream(torch.cuda.current_stream())
     with torch.cuda.stream(capture_stream):
-        if autotune_flag:
-            if args.enable_autotuner:
-                cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
-                with autotune(cache_path=cache_path):
-                    run_pack()
-            if args.run_type == "GEN":
-                logger.info("Layer-wise benchmarks: Prefill KV cache")
-                ctx_seq_len_q = max(args.seq_len_kv_cache_list)
-                assert ctx_batch_size <= args.max_batch_size
-                assert ctx_seq_len_q + 0 <= args.max_seq_len
-                assert ctx_batch_size * ctx_seq_len_q <= args.max_num_tokens
-                max_batch_size = max(args.batch_size_list)
-                for request_id_begin in range(0, max_batch_size, ctx_batch_size):
-                    ctx_run_pack = runner.create_run_pack(
-                        "CTX",
-                        batch_size=min(ctx_batch_size, max_batch_size - request_id_begin),
-                        request_id_begin=request_id_begin,
-                        seq_len_q=ctx_seq_len_q,
-                        seq_len_kv_cache=0,
-                        kv_cache_manager=kv_cache_manager,
-                        attn_workspace=attn_workspace,
-                    )
-                    ctx_run_pack(check=True)
-                logger.info("Layer-wise benchmarks: Prefill KV cache ... Done")
-        else:
-            run_pack(check=True)
+        run_pack(check=True)
     torch.cuda.current_stream().wait_stream(capture_stream)
     torch.cuda.synchronize()
+logger.info("Layer-wise benchmarks: Warmup ... Done")
 
 events = [
     torch.cuda.Event(enable_timing=True) for _ in range(args.warmup_times + args.run_times + 1)
@@ -270,7 +296,7 @@ for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in itertools.product(
     with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"):
         run_pack()
 
     balance_ratio_str = "" if balance_ratio is None else f" balance={balance_ratio:.2g}"
     nvtx_message = f"b={batch_size} s={seq_len_q} past={seq_len_kv_cache}{balance_ratio_str} NP{world_size}"
     for i in range(args.warmup_times + args.run_times):
         events[i].record()
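
The prefill sizing added above caps each context pass at roughly 20480 tokens and walks over the requests in chunks; a small arithmetic sketch with invented numbers (not taken from the patch):

```python
# Illustrative arithmetic for the prefill batching above (all values invented).
max_batch_size = 64
seq_len_kv_cache_list = [1024]
batch_size_list = [32]

ctx_seq_len_q = max(seq_len_kv_cache_list)                             # 1024
ctx_batch_size = min(max_batch_size, max(1, 20480 // ctx_seq_len_q))   # min(64, 20) = 20

num_requests = max(batch_size_list)                                    # 32
chunks = [
    min(ctx_batch_size, num_requests - start)
    for start in range(0, num_requests, ctx_batch_size)
]
print(ctx_batch_size, chunks)  # 20 [20, 12]
```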

@@ -40,7 +40,9 @@ else
     PROFILE_CMD=
 fi
 
+SCRIPT_PATH=$(realpath --relative-to="$(pwd)" "$(dirname -- "$0")"/run.py)
+
 set -x
-$PROFILE_CMD bash -c \
-    "python3 -u run.py \"\$@\" 2>&1 | tee \"$PROFILE_DIR/report_np${WORLD_SIZE}_rank${RANK}.log\"" \
-    bash "$@"
+$PROFILE_CMD bash -o pipefail -c \
+    "python3 -u \"\$1\" \"\${@:3}\" 2>&1 | tee \"\$2/report_np${WORLD_SIZE}_rank${RANK}.log\"" \
+    bash "$SCRIPT_PATH" "$PROFILE_DIR" "$@"

@@ -4,23 +4,23 @@ set -euo pipefail
 
 # CONTAINER_IMAGE=
 CONTAINER_NAME=${CONTAINER_NAME:-layer_wise_benchmarks}
-CONTAINER_MOUNTS=$(realpath "$(pwd)/../.."):$(realpath "$(pwd)/../..")
+TRTLLM_ROOT=$(realpath "$(dirname -- "$0")"/../..)
+CONTAINER_MOUNTS=$TRTLLM_ROOT:$TRTLLM_ROOT
 
-if [ "${SLURM_JOB_ID:-}" == "" ]; then
+if [ -z "${SLURM_JOB_ID:-}" ]; then
     echo "Please set SLURM_JOB_ID"
     exit 1
 fi
 
 NODES=$(squeue -j $SLURM_JOB_ID -h -o "%D")
 
-if [ "${CONTAINER_IMAGE:-}" == "" ]; then
+if [ -z "${CONTAINER_IMAGE:-}" ]; then
     # Read Docker image from current_image_tags.properties
-    source ../../jenkins/current_image_tags.properties
     MACHINE="$(uname -m)"
     if [ "$MACHINE" == "x86_64" ]; then
-        DOCKER_IMAGE=$LLM_DOCKER_IMAGE
+        DOCKER_IMAGE=$(source "$TRTLLM_ROOT/jenkins/current_image_tags.properties" && echo $LLM_DOCKER_IMAGE)
     elif [ "$MACHINE" == "aarch64" ]; then
-        DOCKER_IMAGE=$LLM_SBSA_DOCKER_IMAGE
+        DOCKER_IMAGE=$(source "$TRTLLM_ROOT/jenkins/current_image_tags.properties" && echo $LLM_SBSA_DOCKER_IMAGE)
     else
         echo "Unsupported machine hardware name \"$MACHINE\""
         exit 1
@@ -34,7 +34,7 @@ if [ "${CONTAINER_IMAGE:-}" == "" ]; then
     SQSH_FILE_NAME=$(echo "$DOCKER_IMAGE" |
         awk -F'#' '{print $2}' |
         awk -F':' '{gsub(/\//,"+",$1); print $1"+"$2".sqsh"}')
-    CONTAINER_IMAGE="../../enroot/$SQSH_FILE_NAME"
+    CONTAINER_IMAGE="$TRTLLM_ROOT/enroot/$SQSH_FILE_NAME"
    if [ ! -f "$CONTAINER_IMAGE" ]; then
        echo "Container image file $CONTAINER_IMAGE does not exist, importing ..."
        srun -N 1 --pty enroot import -o "$CONTAINER_IMAGE" "docker://$DOCKER_IMAGE"
@@ -50,6 +50,7 @@ srun -N "$NODES" \
     --container-name "$CONTAINER_NAME" \
     --container-mounts "$CONTAINER_MOUNTS" \
     --container-workdir "$WORKDIR" \
-    bash -c "pip install -U packaging &&
-             pip install -r ../../requirements.txt --no-build-isolation &&
-             pip install -e ../.."
+    bash -c "cd \"\$1\" &&
+             pip install -U packaging &&
+             pip install -r requirements.txt --no-build-isolation &&
+             pip install -e ." bash "$TRTLLM_ROOT"

@@ -3,9 +3,10 @@
 set -euo pipefail
 
 CONTAINER_NAME=${CONTAINER_NAME:-layer_wise_benchmarks}
-CONTAINER_MOUNTS=$(realpath "$(pwd)/../.."):$(realpath "$(pwd)/../..")
+TRTLLM_ROOT=$(realpath "$(dirname -- "$0")"/../..)
+CONTAINER_MOUNTS=$TRTLLM_ROOT:$TRTLLM_ROOT
 
-if [ "${SLURM_JOB_ID:-}" == "" ]; then
+if [ -z "${SLURM_JOB_ID:-}" ]; then
     echo "Please set SLURM_JOB_ID"
     exit 1
 fi

@@ -2,6 +2,11 @@
 
 set -euo pipefail
 
+if [ -z "${SLURM_JOB_ID:-}" ]; then
+    echo "Please set SLURM_JOB_ID"
+    exit 1
+fi
+
 prefix="pyxis_${SLURM_JOB_ID}_"
 matches=$(printf "%s\n" "$(srun -N 1 enroot list)" | grep "^${prefix}" || true)
 count=$(printf "%s\n" "$matches" | wc -l)

@@ -313,6 +313,7 @@
         const columnLabels = [];
         let columnTotals = [];
         let totalNode = null; // Special node for the "Total" row
+        let numCols = null;
 
         const fmt = (num) => num.toFixed(1);
 
@@ -332,7 +333,7 @@
             const topLevelTimes = rawData.map(aggregate);
 
             // Calculate Totals
-            const numCols = topLevelTimes[0].length;
+            numCols = topLevelTimes[0].length;
             columnTotals = new Array(numCols).fill(0);
             for (let c = 0; c < numCols; c++) {
                 columnTotals[c] = topLevelTimes.reduce((sum, row) => sum + row[c], 0);
@@ -420,7 +421,8 @@
                 const total = columnTotals[colIdx] || 1;
                 const pct = (val / total) * 100;
 
-                html += `<td class="data-cell col-${colIdx}" onclick="handleCellClick('${uniqueId}', ${colIdx})">
+                html += `<td class="data-cell col-${colIdx}" ${val ? "title=\"" + fmt(val) + "\"" : ""}
+                    onclick="handleCellClick('${uniqueId}', ${colIdx})">
                     <div class="progress-bar" style="width: ${pct}%;"></div>
                     <div class="cell-content">${val ? fmt(val) : ""}</div>
                 </td>`;
@@ -717,6 +719,7 @@
         });
 
         window.onload = () => {
+            document.getElementById('kernelTable').style["min-width"] = Math.max(1000, (220 + 60 * numCols)) + "px";
             renderHeader();
             document.getElementById('tableBody').innerHTML = renderTableBody(rawData, 0);
             renderTotalRow();
@@ -741,7 +744,6 @@
             for (let node of rawData) {
                 maxDepth = Math.max(maxDepth, getDepth(node));
             }
-            console.log(maxDepth);
 
             const container = document.getElementById('level-buttons');
             if (container) {

@@ -1,6 +1,6 @@
 import nvtx
 
-from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate
+from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate, Deepseekv3MoE
 from tensorrt_llm._torch.models.modeling_qwen3_next import (
     Qwen3NextGatedDeltaNet,
     Qwen3NextSparseMoeBlock,
@@ -12,6 +12,7 @@ from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
 
 def mark_ranges():
     DeepseekV3Gate.forward = nvtx.annotate("DeepseekV3Gate")(DeepseekV3Gate.forward)
+    Deepseekv3MoE.forward = nvtx.annotate("Deepseekv3MoE")(Deepseekv3MoE.forward)
     Qwen3NextGatedDeltaNet.forward = nvtx.annotate("Qwen3NextGatedDeltaNet")(
         Qwen3NextGatedDeltaNet.forward
     )
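
The `mark_ranges` change above simply wraps one more `forward` method in an NVTX range. A tiny standalone illustration of that wrapping pattern (the function below is a stand-in, not TensorRT-LLM code):

```python
import nvtx

# Stand-in callable; in the patch the wrapped object is Deepseekv3MoE.forward.
def forward(x):
    return x * 2

# nvtx.annotate(...) used as a decorator: each call is recorded as an NVTX range.
forward = nvtx.annotate("Deepseekv3MoE")(forward)
print(forward(3))  # 6
```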

@@ -1,5 +1,6 @@
 import contextlib
 import functools
+import itertools
 import os
 import unittest.mock
 import weakref
@@ -23,6 +24,7 @@ from tensorrt_llm._torch.utils import get_model_extra_attrs, model_extra_attrs
 from tensorrt_llm._utils import local_mpi_size, mpi_rank, mpi_world_size, torch_dtype_to_binding
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.bindings.internal.batch_manager import CacheType
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
 
@@ -38,24 +40,16 @@ def round_up(a, b):
 
 
 def get_balanced_selection_no_cache(
-    num_tokens, top_k, num_experts, dtype, device, world_size, rank
+    num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
 ):
-    # First, each sender selects target rank
-    target_rank_before_mod = torch.arange(num_tokens * world_size * top_k).view(
-        num_tokens, world_size, top_k
+    token_id = torch.arange(dp_rank * num_tokens * top_k, (dp_rank + 1) * num_tokens * top_k).view(
+        num_tokens, top_k
     )
-    target_rank_before_mod += top_k * torch.arange(num_tokens).view(
-        num_tokens, 1, 1
-    )  # Shift `top_k` ranks for the next token on each rank, to balance network traffic
-    target_rank = target_rank_before_mod % world_size
-    # Second, each receiver selects target expert
-    target_expert = torch.empty_like(target_rank)
-    for reciever_rank in range(world_size):
-        mask = target_rank == reciever_rank
-        experts_per_rank = num_experts // world_size
-        local_expert = torch.arange(num_tokens * top_k) % experts_per_rank
-        target_expert[mask] = (reciever_rank * experts_per_rank) + local_expert
-    token_selected_experts = target_expert[:, rank].sort(dim=-1).values
+    experts_per_rank = num_experts // ep_size
+    token_selected_experts = (token_id % ep_size) * experts_per_rank + (
+        token_id // ep_size
+    ) % experts_per_rank
+    token_selected_experts = token_selected_experts.sort(dim=-1).values
     return token_selected_experts.contiguous().to(dtype=dtype, device=device)
 
 
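A small worked example of the new balanced selection (numbers chosen only for illustration, not taken from the patch): with `dp_rank=0`, 4 tokens, `top_k=2`, 8 experts, and `ep_size=4`, each expert is picked exactly once and every token's two picks land on two different ranks.

```python
import torch

# Worked example of the formula in get_balanced_selection_no_cache above
# (illustrative values; the real code lives in the benchmark runner).
num_tokens, top_k, num_experts, dp_rank, ep_size = 4, 2, 8, 0, 4

token_id = torch.arange(
    dp_rank * num_tokens * top_k, (dp_rank + 1) * num_tokens * top_k
).view(num_tokens, top_k)
experts_per_rank = num_experts // ep_size
experts = (token_id % ep_size) * experts_per_rank + (token_id // ep_size) % experts_per_rank
experts = experts.sort(dim=-1).values

print(experts)                                           # tensor([[0, 2], [4, 6], [1, 3], [5, 7]])
print(experts.view(-1).bincount(minlength=num_experts))  # every expert selected exactly once
```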
@@ -64,47 +58,62 @@ get_balanced_selection = functools.cache(get_balanced_selection_no_cache)
 
 def test_get_balanced_selection():
     dtype = torch.long
-    for num_tokens in range(1, 33):
-        for num_experts in range(1, 65):
-            print(f"{num_tokens=} {num_experts=}")
-            for top_k in range(1, min(11, num_experts)):
-                for world_size in range(1, 65):
-                    if num_experts % world_size == 0:
-                        tokens_per_expert = torch.zeros(num_experts)
-                        for rank in range(world_size):
-                            token_selected_experts = get_balanced_selection_no_cache(
-                                num_tokens, top_k, num_experts, dtype, "cpu", world_size, rank
-                            )
-                            sorted_selection = token_selected_experts.sort(dim=-1).values
-                            if (sorted_selection[:, :-1] == sorted_selection[:, 1:]).any():
-                                raise ValueError(f"duplicated experts on rank {rank}")
-                            experts_per_rank = num_experts // world_size
-                            tokens_per_rank = (
-                                (token_selected_experts // experts_per_rank)
-                                .view(-1)
-                                .bincount(minlength=world_size)
-                            )
-                            if tokens_per_rank.max() - tokens_per_rank.min() > 1:
-                                raise ValueError(f"tokens sent from rank {rank} is not balanced")
-                            tokens_per_expert += token_selected_experts.view(-1).bincount(
-                                minlength=num_experts
-                            )
-                        if tokens_per_expert.max() - tokens_per_expert.min() > 1:
-                            raise ValueError("tokens per expert is not balanced")
+    for num_tokens, num_experts, enable_attention_dp in itertools.product(
+        range(1, 35), range(1, 35), [False, True]
+    ):
+        print(f"{num_tokens=} {num_experts=} {enable_attention_dp=}")
+        for top_k in range(1, min(10, num_experts) + 1):
+            for world_size in range(1, 35):
+                dp_size = world_size if enable_attention_dp else 1
+                ep_size = world_size
+                if num_experts % ep_size == 0:
+                    tokens_per_expert = torch.zeros(num_experts)
+                    for dp_rank in range(dp_size):
+                        token_selected_experts = get_balanced_selection_no_cache(
+                            num_tokens, top_k, num_experts, dtype, "cpu", dp_size, dp_rank, ep_size
+                        )
+                        sorted_selection = token_selected_experts.sort(dim=-1).values
+                        if (sorted_selection[:, :-1] == sorted_selection[:, 1:]).any():
+                            raise ValueError(f"duplicated experts on rank {dp_rank}")
+                        experts_per_rank = num_experts // ep_size
+                        tokens_per_rank = (
+                            (token_selected_experts // experts_per_rank)
+                            .view(-1)
+                            .bincount(minlength=ep_size)
+                        )
+                        if tokens_per_rank.max() - tokens_per_rank.min() > 1:
+                            raise ValueError(f"tokens sent from rank {dp_rank} is not balanced")
+                        unique_tokens_per_rank = (
+                            (
+                                torch.arange(ep_size).view(ep_size, 1, 1)
+                                == token_selected_experts // experts_per_rank
+                            )
+                            .any(dim=2)
+                            .sum(dim=1)
+                        )
+                        if unique_tokens_per_rank.max() - unique_tokens_per_rank.min() > 1:
+                            raise ValueError(
+                                f"tokens sent from rank {dp_rank} is not balanced after removing duplicates"
+                            )
+                        tokens_per_expert += token_selected_experts.view(-1).bincount(
+                            minlength=num_experts
+                        )
+                    if tokens_per_expert.max() - tokens_per_expert.min() > 1:
+                        raise ValueError("tokens per expert is not balanced")
 
 
-def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank):
+def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size):
     num_tokens, top_k = imbalanced_experts.shape
     dtype = imbalanced_experts.dtype
     device = imbalanced_experts.device
     balanced_experts = get_balanced_selection_no_cache(
-        num_tokens, top_k, num_experts, dtype, device, world_size, rank
+        num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
     )
     if balance_ratio == 0.0:
         num_balanced_tokens = 0
     else:
         # Activate all experts
-        min_num_balanced_tokens = min(num_tokens, ceil_div(num_experts, world_size * top_k))
+        min_num_balanced_tokens = min(num_tokens, ceil_div(num_experts, dp_size * top_k))
         num_balanced_tokens = min_num_balanced_tokens + round(
             (num_tokens - min_num_balanced_tokens) * balance_ratio
         )
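For the `balance_ratio` interpolation in `apply_balance_ratio` above, a short arithmetic sketch (numbers invented): enough tokens are always balanced to activate every expert, and the ratio interpolates over the remaining tokens.

```python
# Arithmetic sketch of apply_balance_ratio's token split (values invented).
def ceil_div(a, b):  # assumed to behave like the helper used in the patch
    return -(-a // b)

num_tokens, num_experts, dp_size, top_k = 32, 256, 4, 8
balance_ratio = 0.5

min_num_balanced_tokens = min(num_tokens, ceil_div(num_experts, dp_size * top_k))  # 8
num_balanced_tokens = min_num_balanced_tokens + round(
    (num_tokens - min_num_balanced_tokens) * balance_ratio
)  # 8 + round(24 * 0.5) = 20
print(min_num_balanced_tokens, num_balanced_tokens)
```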
@@ -116,57 +125,78 @@ def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_si
 
 @functools.cache
 def get_all_to_one_selection(
-    num_tokens, top_k, num_experts, balance_ratio, dtype, device, world_size, rank
+    num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size
 ):
-    experts_per_rank = num_experts // world_size
+    experts_per_rank = num_experts // ep_size
     if top_k > experts_per_rank:
         raise ValueError(
             "Cannot send all tokens to a single rank because `top_k > experts_per_rank`"
         )
     imbalanced_experts = (
         torch.arange(
-            rank * num_tokens * top_k, (rank + 1) * num_tokens * top_k, dtype=dtype, device=device
+            dp_rank * num_tokens * top_k,
+            (dp_rank + 1) * num_tokens * top_k,
+            dtype=dtype,
+            device=device,
         ).view(num_tokens, top_k)
         % experts_per_rank
     )
     imbalanced_experts = imbalanced_experts.sort(dim=-1).values
-    return apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank)
+    return apply_balance_ratio(
+        imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size
+    )
 
 
 @functools.cache
 def get_balanced_rank_imbalanced_expert_selection(
-    num_tokens, top_k, num_experts, balance_ratio, dtype, device, world_size, rank
+    num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size
 ):
-    experts_per_rank = num_experts // world_size
-    active_experts_per_rank = ceil_div(top_k, world_size)
-    # Select expert from [0, active_experts_per_rank * world_size),
-    # then scale to [0, experts_per_rank * world_size)
+    experts_per_rank = num_experts // ep_size
+    active_experts_per_rank = ceil_div(top_k, ep_size)
+    # Select expert from [0, active_experts_per_rank * ep_size),
+    # then scale to [0, experts_per_rank * ep_size)
     narrow_experts = get_balanced_selection_no_cache(
-        num_tokens, top_k, active_experts_per_rank * world_size, dtype, device, world_size, rank
+        num_tokens,
+        top_k,
+        active_experts_per_rank * ep_size,
+        dtype,
+        device,
+        dp_size,
+        dp_rank,
+        ep_size,
     )
     imbalanced_experts = (
         narrow_experts // active_experts_per_rank * experts_per_rank
         + narrow_experts % active_experts_per_rank
     )
-    return apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_size, rank)
+    return apply_balance_ratio(
+        imbalanced_experts, num_experts, balance_ratio, dp_size, dp_rank, ep_size
+    )
 
 
 def make_balanced_routing_method(
-    apply_method_orig, num_experts, balance_method, balance_ratio, world_size, rank
+    moe_module,
+    apply_method_orig,
+    num_experts,
+    balance_method,
+    balance_ratio,
+    dp_size,
+    dp_rank,
+    ep_size,
 ):
     def balanced_routing_method(router_logits):
         token_selected_experts, token_final_scales = apply_method_orig(router_logits)
-        if balance_method == BalanceMethod.NotModified:
-            pass
-        elif balance_method == BalanceMethod.Balanced:
+        assert moe_module._routing_results_replaced_at in [None, "make_balanced_routing_method"]
+        if balance_method == BalanceMethod.Balanced:
             token_selected_experts = get_balanced_selection(
                 token_selected_experts.shape[0],
                 token_selected_experts.shape[1],
                 num_experts,
                 token_selected_experts.dtype,
                 token_selected_experts.device,
-                world_size,
-                rank,
+                dp_size,
+                dp_rank,
+                ep_size,
             )
         elif balance_method == BalanceMethod.ImbalancedRanks:
             token_selected_experts = get_all_to_one_selection(
@@ -176,8 +206,9 @@ def make_balanced_routing_method(
                 balance_ratio,
                 token_selected_experts.dtype,
                 token_selected_experts.device,
-                world_size,
-                rank,
+                dp_size,
+                dp_rank,
+                ep_size,
             )
         elif balance_method == BalanceMethod.ImbalancedExperts:
             token_selected_experts = get_balanced_rank_imbalanced_expert_selection(
@@ -187,16 +218,127 @@ def make_balanced_routing_method(
                 balance_ratio,
                 token_selected_experts.dtype,
                 token_selected_experts.device,
-                world_size,
-                rank,
+                dp_size,
+                dp_rank,
+                ep_size,
             )
         else:
             raise NotImplementedError(f"Not support balance_method {balance_method}")
+        moe_module._routing_results_replaced_at = "make_balanced_routing_method"
         return token_selected_experts, token_final_scales
 
     return balanced_routing_method
 
 
+@functools.cache
+def get_token_final_scales(shape, device):
+    return torch.full(shape, 1.0 / shape[-1], dtype=torch.bfloat16, device=device)
+
+
+def make_balanced_run_moe(
+    moe_module,
+    run_moe_orig,
+    top_k,
+    num_experts,
+    balance_method,
+    balance_ratio,
+    dp_size,
+    dp_rank,
+    ep_size,
+):
+    def balanced_run_moe(
+        x, token_selected_experts, token_final_scales, x_sf, router_logits, do_finalize, moe_output
+    ):
+        if moe_module._routing_results_replaced_at is not None:
+            return run_moe_orig(
+                x,
+                token_selected_experts,
+                token_final_scales,
+                x_sf,
+                router_logits,
+                do_finalize,
+                moe_output,
+            )
+        logger.warning_once(
+            'Layer-wise benchmarks: Specifying routing results of "TRTLLM" MoE backend in TEP cases leads to different'
+            " execution path around the topk kernel",
+            key="replace_routing_method_ctx_trtllm_tp",
+        )
+        if balance_method == BalanceMethod.Balanced:
+            token_selected_experts = get_balanced_selection(
+                x.shape[0],
+                top_k,
+                num_experts,
+                torch.int32,
+                x.device,
+                dp_size,
+                dp_rank,
+                ep_size,
+            )
+        elif balance_method == BalanceMethod.ImbalancedRanks:
+            token_selected_experts = get_all_to_one_selection(
+                x.shape[0],
+                top_k,
+                num_experts,
+                balance_ratio,
+                torch.int32,
+                x.device,
+                dp_size,
+                dp_rank,
+                ep_size,
+            )
+        elif balance_method == BalanceMethod.ImbalancedExperts:
+            token_selected_experts = get_balanced_rank_imbalanced_expert_selection(
+                x.shape[0],
+                top_k,
+                num_experts,
+                balance_ratio,
+                torch.int32,
+                x.device,
+                dp_size,
+                dp_rank,
+                ep_size,
+            )
+        else:
+            raise NotImplementedError(f"Not support balance_method {balance_method}")
+        token_final_scales = get_token_final_scales(
+            token_selected_experts.shape, token_selected_experts.device
+        )
+        router_logits = None
+        final_hidden_states = run_moe_orig(
+            x,
+            token_selected_experts,
+            token_final_scales,
+            x_sf,
+            router_logits,
+            do_finalize,
+            moe_output,
+        )
+        if not do_finalize:
+            final_hidden_states = (
+                final_hidden_states[0],
+                token_final_scales,  # WAR for TRTLLMGenFusedMoE bug that it returns wrong `token_final_scales`
+                final_hidden_states[2],
+            )
+        moe_module._routing_results_replaced_at = "make_balanced_run_moe"
+        return final_hidden_states
+
+    return balanced_run_moe
+
+
+def make_forward_impl_check(moe_module, forward_impl_orig):
+    def forward_impl(*args, **kwargs):
+        moe_module._routing_results_replaced_at = None
+        res = forward_impl_orig(*args, **kwargs)
+        assert moe_module._routing_results_replaced_at is not None, (
+            "Routing results are not replaced"
+        )
+        del moe_module._routing_results_replaced_at
+        return res
+
+    return forward_impl
+
+
 class RunnerMixin(ABC):
     @staticmethod
     @abstractmethod
@@ -398,8 +540,9 @@ class RunnerMixin(ABC):
     @contextlib.contextmanager
    def replace_routing_method_ctx(self, balance_method: BalanceMethod, balance_ratio: float):
        if balance_method == BalanceMethod.NotModified:
-            pass
-        elif self.model_config.moe_backend not in [
+            yield
+            return
+        if self.model_config.moe_backend not in [
            "CUTEDSL",
            "CUTLASS",
            "DEEPGEMM",
@@ -410,29 +553,59 @@ class RunnerMixin(ABC):
                 f'Not support replace routing method for moe_backend "{self.model_config.moe_backend}",'
                 f' please set balance_method to "NotModified"'
             )
-        elif (
-            self.model_config.moe_backend == "TRTLLM"
-            and not self.model_config.mapping.enable_attention_dp
-        ):
-            raise NotImplementedError(
-                'Not support replace routing method for moe_backend "TRTLLM" with attention TP,'
-                ' please set balance_method to "NotModified"'
-            )
-        apply_methods_orig = [layer.mlp.experts.routing_method.apply for layer in self.layers]
-        try:
-            for layer, apply_method_orig in zip(self.layers, apply_methods_orig):
-                layer.mlp.experts.routing_method.apply = make_balanced_routing_method(
-                    apply_method_orig,
-                    layer.mlp.experts.num_experts,
-                    balance_method,
-                    balance_ratio,
-                    layer.mlp.experts.ep_size,
-                    layer.mlp.experts.ep_rank,
-                )
+        original_methods = []
+        dp_rank = self.model_config.mapping.rank // (
+            self.model_config.mapping.world_size // self.model_config.mapping.dp_size
+        )
+        for layer in self.layers:
+            moe_module = layer.mlp.experts
+
+            # Replace `routing_method.apply` for normal cases
+            apply_method_orig = moe_module.routing_method.apply
+            moe_module.routing_method.apply = make_balanced_routing_method(
+                moe_module,
+                apply_method_orig,
+                moe_module.num_experts,
+                balance_method,
+                balance_ratio,
+                self.model_config.mapping.dp_size,
+                dp_rank,
+                self.model_config.mapping.moe_ep_size,
+            )
+
+            # Replace `run_moe` for TRTLLMGenFusedMoE TEP because it does not call `routing_method.apply`
+            if isinstance(moe_module, TRTLLMGenFusedMoE):
+                run_moe_orig = moe_module.run_moe
+                moe_module.run_moe = make_balanced_run_moe(
+                    moe_module,
+                    run_moe_orig,
+                    layer.mlp.experts.routing_method.top_k,
+                    layer.mlp.experts.num_experts,
+                    balance_method,
+                    balance_ratio,
+                    self.model_config.mapping.dp_size,
+                    dp_rank,
+                    self.model_config.mapping.moe_ep_size,
+                )
+            else:
+                run_moe_orig = None
+
+            # Replace `forward_impl` to ensure that routing results are replaced
+            forward_impl_orig = moe_module.forward_impl
+            moe_module.forward_impl = make_forward_impl_check(moe_module, forward_impl_orig)
+
+            original_methods.append((apply_method_orig, run_moe_orig, forward_impl_orig))
+        try:
             yield
         finally:
-            for layer, apply_method_orig in zip(self.layers, apply_methods_orig):
-                layer.mlp.experts.routing_method.apply = apply_method_orig
+            for layer, (apply_method_orig, run_moe_orig, forward_impl_orig) in zip(
+                self.layers, original_methods
+            ):
+                moe_module = layer.mlp.experts
+                moe_module.routing_method.apply = apply_method_orig
+                if isinstance(moe_module, TRTLLMGenFusedMoE):
+                    moe_module.run_moe = run_moe_orig
+                moe_module.forward_impl = forward_impl_orig
 
     @staticmethod
     def create_kv_cache_manager(
@@ -467,7 +640,7 @@ class RunnerMixin(ABC):
         kv_cache_manager = kv_cache_manager_cls(
             kv_cache_config,
             CacheType.SELFKONLY,
-            num_layers=sum(layer_mask),
+            num_layers=num_layers,
             num_kv_heads=1,
             head_dim=model_config.pretrained_config.kv_lora_rank
             + model_config.pretrained_config.qk_rope_head_dim,

@@ -48,7 +48,6 @@ def test_deepseek_r1_ctx_tep(llm_root, world_size):
             model_root / "DeepSeek-R1" / "DeepSeek-R1-0528-FP4-v2",
             "--no-enable-attention-dp",
             "--moe-backend=TRTLLM",
-            "--balance-method=NotModified",
         ],
         cwd=llm_root / "examples" / "layer_wise_benchmarks",
         env={
@@ -139,7 +138,6 @@ def test_qwen3_next_gen_tep(llm_root, world_size):
             "--layer-indices=6,7",
             "--no-enable-attention-dp",
             "--moe-backend=TRTLLM",
-            "--balance-method=NotModified",
        ],
        cwd=llm_root / "examples" / "layer_wise_benchmarks",
        env={