[None][feat] Add layer wise benchmarks (#8777)

Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
Tailing Yuan 2025-10-30 20:29:34 +08:00 committed by GitHub
parent f666ad2f6b
commit f9c7786dc8
14 changed files with 928 additions and 0 deletions

.gitignore vendored
View File

@@ -82,4 +82,5 @@ compile_commands.json
.devcontainer/docker-compose.override.yml
# Enroot sqsh files
enroot/sw-tensorrt-docker+*.sqsh
enroot/tensorrt_llm.devel.sqsh

View File

@@ -0,0 +1,95 @@
# Layer-wise Benchmarks
## Generate profiles
### Run with MPI
**Step 1:** Start a container using Docker, Enroot, or another container runtime. Please refer to `../../jenkins/current_image_tags.properties` for the Docker image URI.
**Step 2:** In the container, install `tensorrt_llm`:
```bash
pip install -e ../..
```
**Step 3:** In the container, run benchmarks and generate profiles:
```bash
# Run DeepSeek-R1
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml
# Run DeepSeek-V3.2-Exp
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --moe-backend DEEPGEMM
# Run DeepSeek-V3.2-Exp with 32k context length
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --max-num-tokens $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --batch-size 1 --seq-len-q 32769
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --model deepseek-ai/DeepSeek-V3.2-Exp --tokens-per-block 64 --max-seq-len $((32768 + 1024 + 4)) --moe-backend DEEPGEMM --seq-len-kv-cache 32769
# Run with attention TP
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --no-enable-attention-dp
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --no-enable-attention-dp
# Run with attention TP and TRTLLMGen
NP=4 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --no-enable-attention-dp --moe-backend TRTLLM
NP=4 TRTLLM_ENABLE_PDL=1 ./mpi_launch.sh ./run_single.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM
# Run with MTP3
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --batch-size 32 --seq-len-q 4
# Run 4 layers
NP=4 ./mpi_launch.sh ./run_single.sh config_ctx.yaml --layer-indices 5,6,7,8
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --layer-indices 5,6,7,8
# Scale DEP=16 MNNVL to 4 GPUs: reduces the number of experts; uses MNNVL A2A if applicable
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --moe-backend WIDEEP
# Scale TEP=16 to 4 GPUs: reduces the number of attention heads and experts
NP=4 ./mpi_launch.sh ./run_single.sh config_gen.yaml --scaled-from 16 --no-enable-attention-dp
# Run with DeepEP A2A
NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_ctx.yaml --moe-backend WIDEEP
NP=4 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEP ./mpi_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP
```
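Each rank writes an Nsight Systems report to `profiles/run_single/run_single_ep<WORLD_SIZE>_rank<RANK>.nsys-rep` (see `./run_single.sh`). Set `PROFILE=0` to skip profiling, or `GPU_METRICS=1` to additionally sample GPU metrics.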
### Run with Slurm
> Tip: If you already have a running job with the environment installed, skip steps 1 and 2 and go straight to step 3. In this case, the job must have been started with `--container-name aaa`; if the container name is not "layer_wise_benchmarks", run `export CONTAINER_NAME=aaa`.
**Step 1:** On the controller node, allocate one or more nodes and record the `SLURM_JOB_ID`:
```bash
SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh)
```
Please fill in the variables in `./slurm_alloc.sh` (`ACCOUNT`, `PARTITION`, and optional extra `salloc` arguments).
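The in-file assignments for `ACCOUNT`, `PARTITION`, and `EXTRA_ARGS` are commented out, so the values can also be exported from the shell instead of editing the script. A minimal sketch with hypothetical values:
```bash
# Hypothetical cluster-specific values; replace with your own account and partition
export ACCOUNT=my_account
export PARTITION=my_partition
export EXTRA_ARGS="--gres gpu:4"   # use EXTRA_ARGS="" if no extra salloc arguments are needed
SLURM_JOB_ID=$(NODES=4 TIME=02:00:00 ./slurm_alloc.sh)
```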
**Step 2:** Start a container and install `tensorrt_llm`. Run the following command on the controller node:
```bash
SLURM_JOB_ID=$SLURM_JOB_ID ./slurm_init_containers.sh
```
The script uses the image recorded in `../../jenkins/current_image_tags.properties`; the image is downloaded to `../../enroot/` only once.
**Step 3:** Run benchmarks and generate profiles. Run the following commands on the controller node, where `NODES` is &le; the number of allocated nodes:
```bash
# Run DeepSeek-R1 with wide EP: uses MNNVL A2A if applicable
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 ./slurm_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP
# Run with attention TP and TRTLLMGen
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_ENABLE_PDL=1 ./slurm_launch.sh ./run_single.sh config_gen.yaml --no-enable-attention-dp --moe-backend TRTLLM
# Run with DeepEPLowLatency
SLURM_JOB_ID=$SLURM_JOB_ID NODES=4 NP=16 TRTLLM_FORCE_ALLTOALL_METHOD=DeepEPLowLatency ./slurm_launch.sh ./run_single.sh config_gen.yaml --moe-backend WIDEEP
# You can run 4-GPU and 8-GPU tasks without reallocating the Slurm job
SLURM_JOB_ID=$SLURM_JOB_ID NODES=1 NP=4 ./slurm_launch.sh ./run_single.sh config_ctx.yaml
SLURM_JOB_ID=$SLURM_JOB_ID NODES=2 NP=8 ./slurm_launch.sh ./run_single.sh config_ctx.yaml
```
## Parse profiles
Coming soon.

View File

@@ -0,0 +1,21 @@
model: nvidia/DeepSeek-R1-0528-FP4-v2
layer_indices: [5]
run_type: CTX
scaled_from: null
# KV cache related args
tokens_per_block: 32
max_seq_len: 9220 # 8192 + 1024 + 4
enable_attention_dp: true
# Model init args
max_num_tokens: 20480
moe_backend: CUTLASS
use_cuda_graph: false
# Per iteration args
batch_size: 1
seq_len_q: 8193
seq_len_kv_cache: 0
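# balance_method options: NotModified, Balanced, ImbalancedRanks, ImbalancedExperts (see BalanceMethod in deepseekv3_runner.py)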
balance_method: Balanced
balance_ratio: null

View File

@@ -0,0 +1,21 @@
model: nvidia/DeepSeek-R1-0528-FP4-v2
layer_indices: [5]
run_type: GEN
scaled_from: null
# KV cache related args
tokens_per_block: 32
max_seq_len: 9220 # 8192 + 1024 + 4
enable_attention_dp: true
# Model init args
max_num_tokens: 4096 # MTP3 as max
moe_backend: CUTLASS
use_cuda_graph: true
# Per iteration args
batch_size: 128
seq_len_q: 1 # Set to (1 + MTP)
seq_len_kv_cache: 8193
balance_method: Balanced
balance_ratio: null

View File

@@ -0,0 +1,10 @@
#!/bin/bash
set -euo pipefail
# Clear inherited Slurm and MPI envs
unset $(env | grep -i slurm | awk -F'=' '{print $1}')
unset $(env | grep MPI | awk -F'=' '{print $1}')
set -x
mpirun --allow-run-as-root --np ${NP} "$@"

View File

@@ -0,0 +1,159 @@
import argparse
import numpy as np
import nvtx
import torch
import yaml
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.tools.layer_wise_benchmarks.deepseekv3_runner import (
BalanceMethod, DeepSeekV3Runner)
def comma_separated_ints(s):
return [int(x) for x in s.split(",")]
# Parse cmdline
parser = argparse.ArgumentParser()
parser.add_argument("config_path", type=str)
parser.add_argument("--model", type=str, help="Pretrained model name or path")
parser.add_argument(
"--layer-indices",
type=comma_separated_ints,
help="Comma separated indices of layers, should be a contiguous range")
parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"])
parser.add_argument("--scaled-from", type=int)
# KV cache related args
parser.add_argument("--tokens-per-block", type=int)
parser.add_argument("--max-seq-len", type=int)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--enable-attention-dp",
action="store_true",
dest="enable_attention_dp")
group.add_argument("--no-enable-attention-dp",
action="store_false",
dest="enable_attention_dp")
parser.set_defaults(enable_attention_dp=None)
# Model init args
parser.add_argument("--max-num-tokens", type=int)
parser.add_argument("--moe-backend", type=str)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--use-cuda-graph",
action="store_true",
dest="use_cuda_graph")
group.add_argument("--no-use-cuda-graph",
action="store_false",
dest="use_cuda_graph")
parser.set_defaults(use_cuda_graph=None)
# Per iteration args
parser.add_argument("--batch-size", type=int)
parser.add_argument("--seq-len-q", type=int)
parser.add_argument("--seq-len-kv-cache", type=int)
parser.add_argument("--balance-method", type=str)
parser.add_argument("--balance-ratio", type=float)
args = parser.parse_args()
with open(args.config_path) as f:
config = yaml.safe_load(f)
del args.config_path
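# CLI flags take precedence; any arg left unset falls back to the value in the YAML config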
for k, v in vars(args).items():
if v is None:
setattr(args, k, config[k])
print(args)
# MPI args
rank = mpi_rank()
world_size = mpi_world_size()
local_rank = local_mpi_rank()
torch.cuda.set_device(local_rank)
# Create KV cache manager
mapping = DeepSeekV3Runner.create_mapping(
enable_attention_dp=args.enable_attention_dp)
max_batch_size = 2048
kv_cache_manager = DeepSeekV3Runner.create_kv_cache_manager(
args.model,
mapping,
tokens_per_block=args.tokens_per_block,
max_batch_size=max_batch_size,
max_seq_len=args.max_seq_len,
layer_indices=args.layer_indices)
attn_workspace = torch.empty((0, ), device="cuda", dtype=torch.int8)
# Create other global objects
AutoTuner.get().clear_cache()
capture_stream = torch.cuda.Stream()
# Create Runner
runner = DeepSeekV3Runner(args.model,
mapping,
moe_backend=args.moe_backend,
layer_indices=args.layer_indices,
scaled_from=args.scaled_from,
max_seq_len=args.max_seq_len,
max_num_tokens=args.max_num_tokens,
use_cuda_graph=args.use_cuda_graph)
# Warm up
assert args.batch_size <= max_batch_size
assert args.seq_len_q + args.seq_len_kv_cache <= args.max_seq_len
run_pack = runner.create_run_pack(args.run_type,
batch_size=args.batch_size,
seq_len_q=args.seq_len_q,
seq_len_kv_cache=args.seq_len_kv_cache,
kv_cache_manager=kv_cache_manager,
attn_workspace=attn_workspace)
runner.replace_routing_method(balance_method=BalanceMethod[args.balance_method],
balance_ratio=args.balance_ratio)
capture_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(capture_stream):
run_pack()
with autotune():
run_pack()
torch.cuda.current_stream().wait_stream(capture_stream)
torch.cuda.synchronize()
# Profile: capture graph and replay it
torch.cuda.cudart().cudaProfilerStart()
if args.use_cuda_graph:
with with_multi_stream(True):
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g,
stream=capture_stream,
capture_error_mode="global"):
run_pack()
warmup_times = 20
run_times = 100
events = [
torch.cuda.Event(enable_timing=True)
for _ in range(warmup_times + run_times + 1)
]
for i in range(warmup_times + run_times):
events[i].record()
with nvtx.annotate(
f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"):
if args.use_cuda_graph:
g.replay()
else:
run_pack()
events[-1].record()
torch.cuda.synchronize()
# Print statistics
# Print before `cudaProfilerStop` to ensure messages are included in the profile
time_list = [
start.elapsed_time(stop) for start, stop in zip(events, events[1:])
]
time_list = time_list[warmup_times:]
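# CUDA events report elapsed time in milliseconds; convert to microseconds for printing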
print(f"[RANK {rank}]"
f" min {np.min(time_list) * 1000:.1f}"
f" max {np.max(time_list) * 1000:.1f}"
f" mean {np.mean(time_list) * 1000:.1f}"
f" median {np.median(time_list) * 1000:.1f}"
f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
f" (us)")
torch.cuda.cudart().cudaProfilerStop()

View File

@@ -0,0 +1,37 @@
#!/bin/bash
set -euo pipefail
if [ -v OMPI_COMM_WORLD_SIZE ]; then
export WORLD_SIZE=$OMPI_COMM_WORLD_SIZE
export RANK=$OMPI_COMM_WORLD_RANK
export LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
export NODE_RANK=$OMPI_COMM_WORLD_NODE_RANK
fi
if [ "$RANK" -eq 0 ]; then
export TLLM_LOG_LEVEL=INFO
fi
PROFILE=${PROFILE:-1}
GPU_METRICS=${GPU_METRICS:-0}
if [ "$PROFILE" -eq 1 ]; then
PROFILE_FOLDER=profiles/run_single
mkdir -p ${PROFILE_FOLDER}
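# nsys captures only the range between cudaProfilerStart and cudaProfilerStop in run_single.py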
PROFILE_CMD="nsys profile
-t cuda,nvtx -s none
--cpuctxsw none --cuda-event-trace false
--cuda-graph-trace node
-c cudaProfilerApi --capture-range-end stop
-o ${PROFILE_FOLDER}/run_single_ep${WORLD_SIZE}_rank${RANK}.nsys-rep
--force-overwrite true"
if [ "$GPU_METRICS" -eq 1 ]; then
PROFILE_CMD+=" --gpu-metrics-devices $LOCAL_RANK
--gpu-metrics-frequency 10000"
fi
else
PROFILE_CMD=
fi
set -x
$PROFILE_CMD python3 -u run_single.py "$@"

View File

@@ -0,0 +1,20 @@
#!/bin/bash
set -euo pipefail
# ACCOUNT=
# PARTITION=
# EXTRA_ARGS="--gres gpu:4"
TIME=${TIME:-01:00:00}
set -x
salloc -A "$ACCOUNT" \
-p "$PARTITION" \
-N "$NODES" \
--segment "$NODES" \
$EXTRA_ARGS \
-t "$TIME" \
--no-shell \
2>&1 \
| tee >(cat >&2) \
| awk '/Granted job allocation/ {print $NF}'

View File

@@ -0,0 +1,53 @@
#!/bin/bash
set -euo pipefail
# CONTAINER_IMAGE=
CONTAINER_MOUNTS=$(realpath "$(pwd)/../.."):$(realpath "$(pwd)/../..")
if [ "${SLURM_JOB_ID:-}" == "" ]; then
echo "Please set SLURM_JOB_ID"
exit 1
fi
NODES=$(squeue -j $SLURM_JOB_ID -h -o "%D")
if [ "${CONTAINER_IMAGE:-}" == "" ]; then
# Read Docker image from current_image_tags.properties
source ../../jenkins/current_image_tags.properties
MACHINE="$(uname -m)"
if [ "$MACHINE" == "x86_64" ]; then
DOCKER_IMAGE=$LLM_DOCKER_IMAGE
elif [ "$MACHINE" == "aarch64" ]; then
DOCKER_IMAGE=$LLM_SBSA_DOCKER_IMAGE
else
echo "Unsupported machine hardware name \"$MACHINE\""
fi
# Change "urm.nvidia.com/sw-tensorrt-docker/..." to "urm.nvidia.com#sw-tensorrt-docker/..." to bypass credentials
DOCKER_IMAGE="${DOCKER_IMAGE/\//#}"
echo "CONTAINER_IMAGE was not set, using Docker image $DOCKER_IMAGE"
# Import to .sqsh file
SQSH_FILE_NAME=$(echo "$DOCKER_IMAGE" |
awk -F'#' '{print $2}' |
awk -F':' '{gsub(/\//,"+",$1); print $1"+"$2".sqsh"}')
CONTAINER_IMAGE="../../enroot/$SQSH_FILE_NAME"
if [ ! -f "$CONTAINER_IMAGE" ]; then
echo "Container image file $CONTAINER_IMAGE does not exist, importing ..."
srun -N 1 --pty enroot import -o "$CONTAINER_IMAGE" "docker://$DOCKER_IMAGE"
fi
fi
WORKDIR=$(realpath "$(pwd)")
set -x
srun -N "$NODES" \
--ntasks-per-node 1 \
--container-image "$CONTAINER_IMAGE" \
--container-name "layer_wise_benchmarks" \
--container-mounts "$CONTAINER_MOUNTS" \
--container-workdir "$WORKDIR" \
bash -c "pip install -U packaging &&
pip install -r ../../requirements.txt --no-build-isolation &&
pip install -e ../.."

View File

@@ -0,0 +1,22 @@
#!/bin/bash
set -euo pipefail
CONTAINER_NAME=${CONTAINER_NAME:-layer_wise_benchmarks}
CONTAINER_MOUNTS=$(realpath "$(pwd)/../.."):$(realpath "$(pwd)/../..")
if [ "${SLURM_JOB_ID:-}" == "" ]; then
echo "Please set SLURM_JOB_ID"
exit 1
fi
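# NP should be a multiple of NODES; each node runs NP / NODES ranks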
WORKDIR=$(realpath "$(pwd)")
set -x
srun --mpi=pmix \
-N "$NODES" \
--ntasks-per-node $(($NP / $NODES)) \
--container-name "$CONTAINER_NAME" \
--container-mounts "$CONTAINER_MOUNTS" \
--container-workdir "$WORKDIR" \
"$@"

View File

@@ -0,0 +1,413 @@
import functools
import os
import weakref
from enum import IntEnum
from typing import List, Optional
import torch
import tensorrt_llm._torch.models.modeling_deepseekv3
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_deepseekv3 import (
DeepseekV3DecoderLayer, DeepseekV3Gate)
from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import WideEPMoE
from tensorrt_llm._torch.modules.linear import Linear, WeightMode
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._torch.pyexecutor._util import get_kv_cache_manager_cls
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._torch.utils import (AuxStreamType, get_model_extra_attrs,
model_extra_attrs)
from tensorrt_llm._utils import (local_mpi_size, mpi_rank, mpi_world_size,
torch_dtype_to_binding)
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
class BalanceMethod(IntEnum):
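# NotModified: keep the model's own routing. Balanced: spread tokens evenly across all experts.
# ImbalancedRanks: route tokens to a single rank's experts. ImbalancedExperts: keep ranks balanced
# but activate only a few experts per rank. balance_ratio sets the fraction of tokens kept balanced.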
NotModified = 1
Balanced = 2
ImbalancedRanks = 3
ImbalancedExperts = 4
def ceil_div(a, b):
return (a + b - 1) // b
def round_up(a, b):
return ceil_div(a, b) * b
class RoutingMethod(DeepseekV3Gate):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.world_size = mpi_world_size()
self.rank = mpi_rank()
self.balance_method = None
self.balance_ratio = None
def apply(self, router_logits) -> (torch.Tensor, torch.Tensor):
token_selected_experts, token_final_scales = super().apply(
router_logits)
num_experts = self.weight.shape[0]
if self.balance_method == BalanceMethod.NotModified:
pass
elif self.balance_method == BalanceMethod.Balanced:
token_selected_experts = RoutingMethod.get_balanced_selection(
token_selected_experts.shape[0],
token_selected_experts.shape[1], num_experts,
token_selected_experts.dtype, self.world_size, self.rank)
elif self.balance_method == BalanceMethod.ImbalancedRanks:
token_selected_experts = RoutingMethod.get_all_to_one_selection(
token_selected_experts.shape[0],
token_selected_experts.shape[1], num_experts,
self.balance_ratio, token_selected_experts.dtype,
self.world_size, self.rank)
elif self.balance_method == BalanceMethod.ImbalancedExperts:
token_selected_experts = RoutingMethod.get_balanced_rank_imbalanced_expert_selection(
token_selected_experts.shape[0],
token_selected_experts.shape[1], num_experts,
self.balance_ratio, token_selected_experts.dtype,
self.world_size, self.rank)
else:
raise NotImplementedError(
f"Not support balance_method {self.balance_method}")
return token_selected_experts, token_final_scales
@functools.cache
@staticmethod
def get_balanced_selection(num_tokens, top_k, num_experts, dtype,
world_size, rank):
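# Deterministic selection that spreads each rank's tokens evenly across all experts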
a = torch.arange(num_tokens * world_size * top_k,
dtype=dtype,
device="cuda").view(num_tokens, world_size,
top_k)[:, rank]
experts = (a * (num_experts // world_size + 1) + a // num_experts *
(num_experts // world_size)) % num_experts
return experts.contiguous()
@staticmethod
def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio,
world_size, rank):
num_tokens, top_k = imbalanced_experts.shape
dtype = imbalanced_experts.dtype
balanced_experts = RoutingMethod.get_balanced_selection(
num_tokens, top_k, num_experts, dtype, world_size, rank)
num_balanced_tokens = round(num_tokens * balance_ratio)
if balance_ratio != 0:
# Activate all experts
num_balanced_tokens = max(num_balanced_tokens,
ceil_div(num_experts, world_size * top_k))
mixed_experts = balanced_experts.clone()
mixed_experts[num_balanced_tokens:] = imbalanced_experts[
num_balanced_tokens:]
return mixed_experts
@functools.cache
@staticmethod
def get_all_to_one_selection(num_tokens, top_k, num_experts, balance_ratio,
dtype, world_size, rank):
assert num_experts // world_size >= top_k
imbalanced_experts = torch.arange(
num_tokens * top_k, dtype=dtype, device="cuda").view(
num_tokens, top_k) % (num_experts // world_size)
return RoutingMethod.apply_balance_ratio(imbalanced_experts,
num_experts, balance_ratio,
world_size, rank)
@functools.cache
@staticmethod
def get_balanced_rank_imbalanced_expert_selection(num_tokens, top_k,
num_experts,
balance_ratio, dtype,
world_size, rank):
experts_per_rank = num_experts // world_size
activate_experts_per_rank = ceil_div(top_k, world_size)
a = torch.arange(num_tokens * top_k, dtype=dtype,
device="cuda").view(num_tokens, top_k)
narrow_experts = a % (activate_experts_per_rank * world_size)
imbalanced_experts = narrow_experts * experts_per_rank % num_experts + narrow_experts // world_size % experts_per_rank
return RoutingMethod.apply_balance_ratio(imbalanced_experts,
num_experts, balance_ratio,
world_size, rank)
class DeepSeekV3Runner:
def __init__(self, pretrained_model_name_or_path: str, mapping: Mapping, *,
moe_backend: str, layer_indices: List[int],
scaled_from: Optional[int], max_seq_len: int,
max_num_tokens: int, use_cuda_graph: bool):
# Temporarily replace the gate class
gate_cls_orig = tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate
tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate = RoutingMethod
self.model_config = ModelConfig.from_pretrained(
pretrained_model_name_or_path,
mapping=mapping,
enable_min_latency=False,
use_cuda_graph=use_cuda_graph,
force_dynamic_quantization=False,
spec_config=None,
sparse_attention_config=None, # To be loaded from config
max_num_tokens=max_num_tokens,
max_seq_len=max_seq_len,
moe_max_num_tokens=None,
moe_load_balancer=None,
lora_config=None,
allreduce_strategy=AllReduceStrategy.AUTO,
mm_encoder_only=False,
attn_backend="TRTLLM",
moe_backend=moe_backend,
moe_disable_finalize_fusion=False,
use_low_precision_moe_combine=False,
skip_create_weights_in_init=True,
)
pretrained_config = self.model_config.pretrained_config
if scaled_from is not None:
# To run the problem size of $B$ GPUs on $A$ GPUs, we need:
# (1) Attention: If TP, reduce the number of attention heads; If DP, nothing to change.
# (2) MoE: If EP, reduce the number of experts; If TP, reduce head size.
# Maintain the result of AllToAll method selection because it is affected by EP size.
if not mapping.enable_attention_dp:
if hasattr(pretrained_config, "index_n_heads"):
raise NotImplementedError(
"Not support Indexer TP for weak scaling")
pretrained_config.num_attention_heads = pretrained_config.num_attention_heads // scaled_from * mapping.tp_size
pretrained_config.num_key_value_heads = pretrained_config.num_key_value_heads // scaled_from * mapping.tp_size
if mapping.moe_ep_size != mapping.world_size:
raise NotImplementedError("Not support MoE TP for weak scaling")
pretrained_config.n_routed_experts = pretrained_config.n_routed_experts // scaled_from * mapping.moe_ep_size
select_alltoall_method_type_orig = WideEPMoE.select_alltoall_method_type
def select_alltoall_method_type(cls: type, mapping: Mapping,
top_k: int, *args, **kwargs):
# Replace the condition `mapping.moe_ep_size <= top_k` with `scaled_from <= top_k`
# by replacing `top_k` with `fake_top_k`
if scaled_from <= top_k:
fake_top_k = mapping.moe_ep_size + 1
else:
fake_top_k = mapping.moe_ep_size - 1
assert (mapping.moe_ep_size <= fake_top_k) == (scaled_from
<= top_k)
return select_alltoall_method_type_orig(mapping, fake_top_k,
*args, **kwargs)
WideEPMoE.select_alltoall_method_type = select_alltoall_method_type
aux_stream_list = [torch.cuda.Stream() for _ in range(2)]
aux_stream_dict = {
AuxStreamType.Attention: aux_stream_list[0],
AuxStreamType.MoeShared: aux_stream_list[0],
AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
}
layers = [
DeepseekV3DecoderLayer(
model_config=self.model_config,
layer_idx=layer_idx,
aux_stream_dict=aux_stream_dict,
) for layer_idx in layer_indices
]
next_layer_layernorm = RMSNorm(
hidden_size=pretrained_config.hidden_size,
eps=pretrained_config.rms_norm_eps,
dtype=pretrained_config.torch_dtype)
# apply_quant_config_exclude_modules
# Please refer to tensorrt_llm/_torch/models/modeling_utils.py
quant_config = self.model_config.quant_config
new_quant_config = QuantConfig(
kv_cache_quant_algo=quant_config.kv_cache_quant_algo)
for layer in layers:
for name, module in layer.named_modules():
name = f"model.layers.{layer.layer_idx}.{name}"
candidates = [name]
if isinstance(module, Linear):
weight_mode = module.weights_loading_config.weight_mode
if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
# sometimes gate and up proj are not packed in the checkpoint,
# but they still share the same exclusion rule
candidates += [
name.replace('gate_up_proj', 'gate_proj'),
name.replace('gate_up_proj', 'up_proj')
]
elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
# sometimes q_proj, k_proj and v_proj are not packed in the checkpoint,
# but they still share the same exclusion rule
candidates += [
name.replace('qkv_proj', 'q_proj'),
name.replace('qkv_proj', 'k_proj'),
name.replace('qkv_proj', 'v_proj')
]
is_excluded = any(
quant_config.is_module_excluded_from_quantization(n)
for n in candidates)
if is_excluded and getattr(module, "quant_config",
None) is not None:
module.quant_config = new_quant_config
for name, module in layer.named_modules():
if callable(getattr(module, "create_weights", None)):
module.create_weights()
layer.cuda()
for name, module in layer.named_modules():
if hasattr(module, 'post_load_weights') and not getattr(
module, '_weights_removed', False):
module.post_load_weights()
next_layer_layernorm.cuda()
for layer, next_layer in zip(layers[:-1], layers[1:]):
layer.next_layer_layernorm = next_layer.input_layernorm
layers[-1].next_layer_layernorm = next_layer_layernorm
self.layers = layers
if scaled_from is not None:
WideEPMoE.select_alltoall_method_type = select_alltoall_method_type_orig
tensorrt_llm._torch.models.modeling_deepseekv3.DeepseekV3Gate = gate_cls_orig
def create_run_pack(self,
run_type: str,
batch_size: int,
seq_len_q: int,
seq_len_kv_cache: int,
kv_cache_manager: Optional[KVCacheManager] = None,
attn_workspace: Optional[torch.Tensor] = None):
if self.model_config.moe_backend == "TRTLLM" and os.getenv(
"TRTLLM_ENABLE_PDL") != "1":
raise ValueError(
"Suggest to set TRTLLM_ENABLE_PDL=1 when moe_backend is TRTLLM")
world_size = mpi_world_size()
AttentionCls = get_attention_backend(
self.model_config.attn_backend,
self.model_config.sparse_attention_config)
attn_metadata = AttentionCls.Metadata(
seq_lens=torch.tensor([seq_len_q] * batch_size, dtype=torch.int),
request_ids=list(range(batch_size)),
max_num_requests=kv_cache_manager.max_batch_size,
num_contexts={
"CTX": batch_size,
"GEN": 0
}[run_type],
prompt_lens=[{
"CTX": seq_len_q,
"GEN": seq_len_kv_cache
}[run_type]] * batch_size,
max_num_tokens=batch_size * seq_len_q,
kv_cache_manager=kv_cache_manager,
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=[seq_len_kv_cache] * batch_size,
),
workspace=attn_workspace,
mapping=self.model_config.mapping,
sparse_attention_config=self.model_config.sparse_attention_config,
)
attn_metadata.all_rank_num_tokens = [batch_size * seq_len_q
] * world_size
attn_metadata.prepare()
with model_extra_attrs(self.model_config.extra_attrs):
get_model_extra_attrs()["attention_metadata"] = weakref.ref(
attn_metadata)
hidden_size = self.model_config.pretrained_config.hidden_size
position_ids = torch.tensor([
list(range(seq_len_kv_cache, seq_len_kv_cache + seq_len_q)) *
batch_size
],
dtype=torch.int32,
device="cuda")
hidden_states = torch.rand((batch_size * seq_len_q, hidden_size),
dtype=torch.bfloat16,
device="cuda")
residual = torch.rand((batch_size * seq_len_q, hidden_size),
dtype=torch.bfloat16,
device="cuda")
def run_pack():
output = hidden_states, residual
with model_extra_attrs(self.model_config.extra_attrs):
with torch.inference_mode():
for layer in self.layers:
output = layer(position_ids, output[0], attn_metadata,
output[1])
return output
return run_pack
def replace_routing_method(self, balance_method: BalanceMethod,
balance_ratio: float):
if self.model_config.moe_backend not in [
"CUTLASS", "DEEPGEMM", "TRTLLM", "WIDEEP"
]:
raise NotImplementedError(
f"Not support replace routing method for moe_backend \"{self.model_config.moe_backend}\","
f" please set balance_method to \"NotModified\"")
for layer in self.layers:
layer.mlp.gate.balance_method = balance_method
layer.mlp.gate.balance_ratio = balance_ratio
@staticmethod
def create_kv_cache_manager(pretrained_model_name_or_path, mapping,
tokens_per_block, max_batch_size, max_seq_len,
layer_indices):
# Please refer to `tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` for `tokens_per_block`
model_config = ModelConfig.from_pretrained(
pretrained_model_name_or_path)
if model_config.enable_flash_mla:
assert tokens_per_block == 64
# Please refer to `tensorrt_llm/_torch/pyexecutor/_util.py` for `kv_cache_manager`
kv_cache_manager_cls = get_kv_cache_manager_cls(model_config)
kv_cache_manager = kv_cache_manager_cls(
KvCacheConfig(
max_tokens=max_batch_size *
round_up(max_seq_len, tokens_per_block),
enable_block_reuse=False,
),
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELFKONLY,
num_layers=len(layer_indices),
num_kv_heads=1,
head_dim=model_config.pretrained_config.kv_lora_rank +
model_config.pretrained_config.qk_rope_head_dim,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
mapping=mapping,
dtype=torch_dtype_to_binding({
None: torch.bfloat16,
"FP8": torch.float8_e4m3fn
}[model_config.quant_config.kv_cache_quant_algo]),
sparse_attn_config=model_config.sparse_attention_config,
)
kv_cache_manager.layer_offsets = {
layer_idx: i
for i, layer_idx in enumerate(layer_indices)
}
kv_cache_manager.add_dummy_requests(list(range(max_batch_size)),
[max_seq_len] * max_batch_size)
return kv_cache_manager
@staticmethod
def create_mapping(enable_attention_dp: bool):
world_size = mpi_world_size()
rank = mpi_rank()
mapping = Mapping(
world_size=world_size,
rank=rank,
gpus_per_node=local_mpi_size(),
cp_size=1,
tp_size=world_size,
pp_size=1,
moe_cluster_size=1,
moe_tp_size=1,
moe_ep_size=world_size,
attn_tp_size=world_size,
attn_cp_size=1,
enable_attention_dp=enable_attention_dp,
)
return mapping

View File

@@ -18,6 +18,8 @@ l0_dgx_b200:
- unittest/_torch/multi_gpu_modeling -k "deepseek"
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]
- unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_ctx_tep
- unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_gen_scaled_from_16_dep
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]

View File

@@ -0,0 +1,74 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
from defs.trt_test_alternative import check_call
from utils.cpp_paths import llm_root # noqa: F401
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="needs 4 GPUs to run this test")
def test_deepseek_r1_ctx_tep(llm_root):
check_call([
"./mpi_launch.sh",
"./run_single.sh",
"config_ctx.yaml",
"--no-enable-attention-dp",
"--moe-backend=TRTLLM",
],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={
**os.environ,
"NP": "4",
"TRTLLM_ENABLE_PDL": "1",
})
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="needs 4 GPUs to run this test")
def test_deepseek_v32_ctx_dep(llm_root):
check_call([
"./mpi_launch.sh",
"./run_single.sh",
"config_ctx.yaml",
"--model=deepseek-ai/DeepSeek-V3.2-Exp",
"--tokens-per-block=64",
"--moe-backend=DEEPGEMM",
],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={
**os.environ,
"NP": "4",
})
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="needs 4 GPUs to run this test")
def test_deepseek_r1_gen_scaled_from_16_dep(llm_root):
check_call([
"./mpi_launch.sh",
"./run_single.sh",
"config_gen.yaml",
"--scaled-from=16",
"--moe-backend=WIDEEP",
"--layer-indices=5,6",
],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
env={
**os.environ,
"NP": "4",
})