[None][feat] Add weights initialization and context phase parser to layer-wise benchmarks (#9667)

Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
Tailing Yuan 2025-12-04 13:41:15 +08:00, committed by GitHub
parent 87e0c8a749
commit 4eed648e22
11 changed files with 230 additions and 63 deletions

View File

@ -135,9 +135,10 @@ python3 parse.py --world-size 4
# Specify the location of the .nsys-rep file
python3 parse.py --profile-dir ./profiles --world-size 4 --rank 0
```
It can parse only GEN phase profiles for now.
```
# Parse a specific module. The module must appear exactly once in each run.
python3 parse.py --world-size 4 --module MoE
```
You will receive three reports, each containing kernel timing statistics grouped by module:
1. A printed report on stdout

View File

@ -5,7 +5,6 @@ scaled_from: null
# KV cache related args
tokens_per_block: 32
max_seq_len: 9220 # 8192 + 1024 + 4
enable_attention_dp: true
# Model init args

View File

@ -5,7 +5,6 @@ scaled_from: null
# KV cache related args
tokens_per_block: 32
max_seq_len: 9220 # 8192 + 1024 + 4
enable_attention_dp: true
# Model init args

View File

@ -5,6 +5,7 @@ import json
import re
import sqlite3
import subprocess
from collections import defaultdict
from pathlib import Path
import jinja2
@ -14,15 +15,17 @@ import pandas as pd
# Parse cmdline
parser = argparse.ArgumentParser()
parser.add_argument("--profile-dir", type=str, default="profiles")
parser.add_argument("--world-size", "--np", type=int)
parser.add_argument("--world-size", "--np", type=int, required=True)
parser.add_argument("--rank", type=int, default=0)
parser.add_argument("--warmup-times", type=int)
group = parser.add_mutually_exclusive_group(required=False)
parser.add_argument("--module", type=str)
parser.add_argument("--query", type=str)
group = parser.add_mutually_exclusive_group()
group.add_argument("--error-on-unknown-kernel", action="store_true", dest="error_on_unknown_kernel")
group.add_argument(
"--no-error-on-unknown-kernel", action="store_false", dest="error_on_unknown_kernel"
)
parser.set_defaults(error_on_unknown_kernel=None)
parser.set_defaults(error_on_unknown_kernel=False)
args = parser.parse_args()
print(args)
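The paired `--error-on-unknown-kernel` / `--no-error-on-unknown-kernel` flags (now defaulting to `False` instead of `None`) follow the mutually exclusive boolean-flag pattern used throughout these scripts. A minimal, self-contained sketch of the pattern, with an illustrative flag name:

```
import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
# Both flags write to the same destination, so passing both is rejected.
group.add_argument("--enable-foo", action="store_true", dest="enable_foo")
group.add_argument("--no-enable-foo", action="store_false", dest="enable_foo")
# A None default lets later code detect "not given" and pick its own default;
# parse.py instead defaults to False so the value is always a bool.
parser.set_defaults(enable_foo=None)

print(parser.parse_args([]))                   # Namespace(enable_foo=None)
print(parser.parse_args(["--enable-foo"]))     # Namespace(enable_foo=True)
print(parser.parse_args(["--no-enable-foo"]))  # Namespace(enable_foo=False)
```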
@ -123,6 +126,7 @@ for start, text in df.itertuples(index=False):
"runs": [],
"runs_end": [],
"ranges": [],
"range_in_module": [],
}
)
@ -145,6 +149,28 @@ for start, end, text in df.itertuples(index=False):
else:
problem_set[problem_id]["ranges"].append((start, end, text))
# Mark each range that matches `args.module`, and store the result in
# `problem["range_in_module"]`; every run must contain exactly one matching range
for problem in problem_set:
if args.module is not None:
problem["range_in_module"] = [False] * len(problem["ranges"])
run_ids = [bisect.bisect(problem["runs"], start) - 1 for start, _, _ in problem["ranges"]]
run2ranges = defaultdict(list)
for i, run_id in enumerate(run_ids):
run2ranges[run_id].append(i)
for run_id, ranges in run2ranges.items():
ranges = sorted(ranges, key=lambda i: problem["ranges"][i][0])
num_matches = 0
for range_id in ranges:
if problem["ranges"][range_id][2] == args.module:
problem["range_in_module"][range_id] = True
num_matches += 1
if num_matches != 1:
raise ValueError(
f'Module "{args.module}" appears {num_matches} times'
f' in run {run_id + 1} of "{problem["text"]}"'
)
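`bisect.bisect(runs, start) - 1` maps a timestamp to the run whose start most recently precedes it; an id of `-1`, or a timestamp past the corresponding `runs_end`, means the event fell outside every run. A small sketch with made-up timestamps:

```
import bisect

runs = [100, 200, 300]      # run start timestamps, sorted
runs_end = [150, 250, 350]  # matching run end timestamps

for start in [90, 120, 160, 210]:
    run_id = bisect.bisect(runs, start) - 1  # last run starting at or before `start`
    in_run = run_id >= 0 and start < runs_end[run_id]
    print(start, run_id, in_run)
# 90 -> run -1 (before the first run); 160 -> run 0 but past its end, so not in a run
```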
query = """SELECT name FROM sqlite_master WHERE type = ?"""
df = pd.read_sql_query(query, conn, params=("table",))
tables = df["name"].tolist()
@ -160,12 +186,20 @@ if "CUPTI_ACTIVITY_KIND_MEMSET" in tables:
FROM CUPTI_ACTIVITY_KIND_MEMSET AS T3"""
query = f"""SELECT unified.start, unified.end, unified.demangledName,
R.start AS runtime_start, R.end AS runtime_end,
CGE2.start AS capture_start, CGE2.end AS capture_end
R.start AS capture_start, R.end AS capture_end
FROM ({unified_subquery}) AS unified
JOIN CUPTI_ACTIVITY_KIND_RUNTIME AS R ON unified.correlationId = R.correlationId
LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE1 ON unified.graphNodeId = CGE1.graphNodeId AND
CGE1.originalGraphNodeId IS NOT NULL
LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE2 ON CGE1.originalGraphNodeId = CGE2.graphNodeId"""
WHERE unified.graphNodeId IS NULL"""
if "CUDA_GRAPH_NODE_EVENTS" in tables:
query += f""" UNION ALL
SELECT unified.start, unified.end, unified.demangledName,
R.start AS runtime_start, R.end AS runtime_end,
CGE2.start AS capture_start, CGE2.end AS capture_end
FROM ({unified_subquery}) AS unified
JOIN CUPTI_ACTIVITY_KIND_RUNTIME AS R ON unified.correlationId = R.correlationId
LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE1 ON unified.graphNodeId = CGE1.graphNodeId AND
CGE1.originalGraphNodeId IS NOT NULL
LEFT JOIN CUDA_GRAPH_NODE_EVENTS AS CGE2 ON CGE1.originalGraphNodeId = CGE2.graphNodeId"""
df = pd.read_sql_query(query, conn)
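The query is now split in two: eagerly launched kernels (`graphNodeId IS NULL`) take their capture window straight from the runtime row, and the graph-node join is appended only when `CUDA_GRAPH_NODE_EVENTS` exists, so profiles captured without CUDA graphs still parse. A self-contained sketch of the join shape against a simplified, made-up schema:

```
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE KERNEL (start INT, end INT, correlationId INT, graphNodeId INT);
CREATE TABLE RUNTIME (start INT, end INT, correlationId INT);
CREATE TABLE GRAPH_EVENTS (graphNodeId INT, originalGraphNodeId INT, start INT, end INT);
-- One eager kernel (no graph node) and one replayed from a captured CUDA graph.
INSERT INTO KERNEL VALUES (10, 20, 1, NULL), (30, 40, 2, 7);
INSERT INTO RUNTIME VALUES (9, 11, 1), (29, 31, 2);
INSERT INTO GRAPH_EVENTS VALUES (7, 5, NULL, NULL), (5, NULL, 1, 2);
""")
rows = conn.execute("""
SELECT K.start, K.end, R.start AS capture_start, R.end AS capture_end
FROM KERNEL K JOIN RUNTIME R ON K.correlationId = R.correlationId
WHERE K.graphNodeId IS NULL
UNION ALL
SELECT K.start, K.end, G2.start AS capture_start, G2.end AS capture_end
FROM KERNEL K
JOIN RUNTIME R ON K.correlationId = R.correlationId
LEFT JOIN GRAPH_EVENTS G1
  ON K.graphNodeId = G1.graphNodeId AND G1.originalGraphNodeId IS NOT NULL
LEFT JOIN GRAPH_EVENTS G2 ON G1.originalGraphNodeId = G2.graphNodeId
WHERE K.graphNodeId IS NOT NULL
""").fetchall()
print(rows)  # [(10, 20, 9, 11), (30, 40, 1, 2)]
# Eager kernel: capture window = its own runtime call.
# Graph kernel: capture window = the original capture-time node's events.
```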
kernel_list = []
for (
@ -178,33 +212,35 @@ for (
capture_end,
) in df.itertuples(index=False):
problem_id = bisect.bisect(problem_start, start) - 1
run_id = bisect.bisect(problem_set[problem_id]["runs"], runtime_start) - 1
problem = problem_set[problem_id]
run_id = bisect.bisect(problem["runs"], runtime_start) - 1
if (
run_id == -1
or run_id == len(problem_set[problem_id]["runs"])
or runtime_start >= problem_set[problem_id]["runs_end"][run_id]
or run_id == len(problem["runs"])
or runtime_start >= problem["runs_end"][run_id]
):
run_id = -1
ranges = [
text
for range_start, range_end, text in problem_set[problem_id]["ranges"]
i
for i, (range_start, range_end, text) in enumerate(problem["ranges"])
if capture_start >= range_start and capture_end <= range_end
]
kernel_list.append(
(
problem_id,
run_id,
ranges,
start,
end,
demangledName,
runtime_start,
runtime_end,
capture_start,
capture_end,
if args.module is None or any(problem["range_in_module"][i] for i in ranges):
range_names = [problem["ranges"][i][2] for i in ranges]
kernel_list.append(
(
problem_id,
run_id,
range_names,
start,
end,
demangledName,
runtime_start,
runtime_end,
capture_start,
capture_end,
)
)
)
# TODO: Parse CTX phases
query = "SELECT * FROM StringIds"
df = pd.read_sql_query(query, conn)
@ -267,6 +303,7 @@ parser_keywords = [
("memsetExpertIds", "memsetExpertIdsDevice"),
("blockSum", "blockExpertPrefixSumKernel"),
("globalSum", "globalExpertPrefixSumKernel"),
("globalSumLarge", "globalExpertPrefixSumLargeKernel"),
("mergePrefix", "mergeExpertPrefixSumKernel"),
("fusedBuildExpertMaps", "fusedBuildExpertMapsSortFirstTokenKernel"),
("swiglu", "silu_and_mul_kernel"),
@ -288,6 +325,7 @@ parser_keywords = [
("per_token_quant", "_per_token_quant_and_transform_kernel"),
("triton_fused_layer_norm", "triton_per_fused__to_copy_native_layer_norm_0"),
("flashinferRoPE", "flashinfer::BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel<"),
("flashinferRoPE", "flashinfer::BatchQKApplyRotaryPosIdsCosSinCacheKernel<"),
("fp8_blockscale_gemm", "tensorrt_llm::kernels::fp8_blockscale_gemm"),
("triton_fused_mul_squeeze", "triton_poi_fused_mul_squeeze_0"),
("indexerKCacheScatter", "tensorrt_llm::kernels::indexerKCacheScatterUnifiedKernel"),
@ -308,12 +346,17 @@ parser_keywords = [
("softmax_warp_forward", "softmax_warp_forward<"),
("torchSigmoid", "at::native::sigmoid_kernel_cuda"),
("torchMul", "at::native::binary_internal::MulFunctor<"),
("computeSeqAndPaddingOffsets", "tensorrt_llm::kernels::computeSeqAndPaddingOffsets<"),
("applyBiasRopeUpdateKVCache", "tensorrt_llm::kernels::applyBiasRopeUpdateKVCacheV2<"),
("routingIndicesHistogramScores", "routingRenormalize::routingIndicesHistogramScoresKernel<"),
("routingIndicesHistogram", "routingIndicesHistogramKernel<"),
("routingIndicesOffsets", "routingIndicesOffsetsKernel<"),
("torchReduceSum", ["at::native::reduce_kernel<", "at::native::sum_functor<"]),
("CuteDSLMoePermute", "cute_dsl::moePermuteKernel"),
(
"CuteDSLGemm",
["cute_dsl_kernels", "blockscaled_gemm_persistent"],
),
(
"CuteDSLGroupedGemmSwiglu",
["cute_dsl_kernels", "blockscaled_contiguous_grouped_gemm_swiglu_fusion"],
@ -342,7 +385,11 @@ def parse_kernel_name(demangledName):
warned_names.add(name)
if args.error_on_unknown_kernel:
raise NotImplementedError(f"Unknown kernel name: {name}")
return name[:30]
if "<" in name:
name = name[: name.index("<")]
if "(" in name:
name = name[: name.index("(")]
return name
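Instead of a blind 30-character cut, an unknown name is now truncated at its first template bracket or parameter list, which keeps the bare function name intact. For example:

```
def shorten(name):
    # Same truncation as parse_kernel_name's fallback path above.
    if "<" in name:
        name = name[: name.index("<")]
    if "(" in name:
        name = name[: name.index("(")]
    return name

print(shorten("softmax_warp_forward<float, 8>(float*, int)"))  # softmax_warp_forward
```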
converted_seqs = []
@ -390,6 +437,9 @@ for problem_id, converted_seq in enumerate(converted_seqs):
merged_data[cur][problem_id] = t
cur += 1
print("Run args:")
print(run_args)
print("Problem set:")
for problem in problem_set:
print(
@ -401,7 +451,7 @@ stack = []
csv_data = [["", *[problem["text"] for problem in problem_set]]]
js_data = []
js_stack = [js_data]
max_title_len = max((len(title) - 1) * 3 + len(title[-1]) for title in merged_title)
max_title_len = max((len(title) - 1) * 3 + len(title[-1][:40]) for title in merged_title)
for title, time_data in zip(merged_title, merged_data):
while stack != list(title[: len(stack)]):
level_title = stack[-1]
@ -418,11 +468,13 @@ for title, time_data in zip(merged_title, merged_data):
stack.append(level_title)
level = len(stack)
print("|--" * (level - 1) + level_title)
csv_data.append(["|--" * (level - 1) + level_title])
csv_data.append(["|--" * (level - 1) + level_title] + [""] * len(problem_set))
js_stack.append([])
level = len(stack) + 1
print(
"|--" * (level - 1) + title[-1] + " " * (max_title_len - (level - 1) * 3 - len(title[-1])),
"|--" * (level - 1)
+ title[-1][:40]
+ " " * (max_title_len - (level - 1) * 3 - len(title[-1][:40])),
*[f"{x / 1000:-6.1f}" for x in time_data],
)
csv_data.append(["|--" * (level - 1) + title[-1], *[f"{x / 1000:.1f}" for x in time_data]])
@ -442,8 +494,21 @@ js_header_config = [{"name": problem["text"]} for problem in problem_set]
loader = jinja2.FileSystemLoader(Path(__file__).parent)
template = jinja2.Environment(loader=loader).get_template("template.html")
with html_file_path.open("w") as f:
f.write(
template.render(
headerConfig=js_header_config, rawData=js_data, runArgs=json.dumps(run_args, indent=4)
)
configText = (
"Run:\n" + json.dumps(run_args, indent=4) + "\n\nParse:\n" + json.dumps(args.__dict__)
)
f.write(template.render(headerConfig=js_header_config, rawData=js_data, configText=configText))
if args.query is not None:
print("Query:")
for query in args.query.split(","):
query = query.strip()
query_matched = [0.0] * len(problem_set)
for title, time_data in zip(merged_title, merged_data):
if query in ".".join(title):
for i, x in enumerate(time_data):
query_matched[i] += x
print(
query + " " * (max_title_len - len(query)),
*[f"{x / 1000:-6.1f}" for x in query_matched],
)
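`--query` takes a comma-separated list of substrings; each one sums the timings of every row whose dotted title path contains it. An invocation might look like the following, mirroring the `check_call` pattern in the tests below (the module names are illustrative):

```
from subprocess import check_call

# Hypothetical example: report total time for rows matching "MoE" and "mla".
check_call([
    "python3", "parse.py",
    "--profile-dir", "./profiles",
    "--world-size", "4",
    "--query", "MoE,mla",
])
```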

View File

@ -1,6 +1,7 @@
import argparse
import itertools
import json
import os
import numpy as np
import nvtx
@ -10,6 +11,7 @@ import yaml
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.logger import logger
from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls, mark_ranges
@ -36,7 +38,7 @@ parser.add_argument("--scaled-from", type=int)
parser.add_argument("--max-batch-size", type=int)
parser.add_argument("--tokens-per-block", type=int)
parser.add_argument("--max-seq-len", type=int)
group = parser.add_mutually_exclusive_group(required=False)
group = parser.add_mutually_exclusive_group()
group.add_argument("--enable-attention-dp", action="store_true", dest="enable_attention_dp")
group.add_argument("--no-enable-attention-dp", action="store_false", dest="enable_attention_dp")
parser.set_defaults(enable_attention_dp=None)
@ -44,7 +46,21 @@ parser.set_defaults(enable_attention_dp=None)
parser.add_argument("--max-num-tokens", type=int)
parser.add_argument("--moe-backend", type=str)
parser.add_argument("--moe-max-num-tokens", type=int)
group = parser.add_mutually_exclusive_group(required=False)
group = parser.add_mutually_exclusive_group()
group.add_argument(
"--use-low-precision-moe-combine", action="store_true", dest="use_low_precision_moe_combine"
)
group.add_argument(
"--no-use-low-precision-moe-combine",
action="store_false",
dest="use_low_precision_moe_combine",
)
parser.set_defaults(use_low_precision_moe_combine=None)
group = parser.add_mutually_exclusive_group()
group.add_argument("--enable-autotuner", action="store_true", dest="enable_autotuner")
group.add_argument("--no-enable-autotuner", action="store_false", dest="enable_autotuner")
parser.set_defaults(enable_autotuner=None)
group = parser.add_mutually_exclusive_group()
group.add_argument("--use-cuda-graph", action="store_true", dest="use_cuda_graph")
group.add_argument("--no-use-cuda-graph", action="store_false", dest="use_cuda_graph")
parser.set_defaults(use_cuda_graph=None)
@ -86,8 +102,29 @@ if config:
# Set default values
if args.max_batch_size is None:
args.max_batch_size = max(args.batch_size_list)
if args.max_seq_len is None:
args.max_seq_len = max(args.seq_len_q_list) + max(args.seq_len_kv_cache_list)
if args.enable_attention_dp is None:
args.enable_attention_dp = False
if args.max_num_tokens is None:
args.max_num_tokens = args.max_batch_size * max(args.seq_len_q_list)
if args.run_type == "GEN":
ctx_batch_size = max(1, max(20480, args.max_num_tokens) // max(args.seq_len_kv_cache_list))
args.max_num_tokens = max(
args.max_num_tokens, ctx_batch_size * max(args.seq_len_kv_cache_list)
)
else:
if args.run_type == "GEN":
ctx_batch_size = max(1, args.max_num_tokens // max(args.seq_len_kv_cache_list))
assert args.max_num_tokens >= ctx_batch_size * max(args.seq_len_kv_cache_list), (
"Max_num_tokens is too small to prefill KV cache"
)
if args.use_low_precision_moe_combine is None:
args.use_low_precision_moe_combine = False
if args.enable_autotuner is None:
args.enable_autotuner = True
if args.use_cuda_graph is None:
args.use_cuda_graph = False
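For GEN runs, `max_num_tokens` is raised (or, when given explicitly, checked) so that at least one full-length context batch can prefill the KV cache before decoding is benchmarked. A worked example of the default path, with illustrative sizes:

```
# Illustrative sizes: batch 32, decoding 1 token over an 8192-token KV cache.
max_batch_size = 32
seq_len_q_list = [1]
seq_len_kv_cache_list = [8192]

max_num_tokens = max_batch_size * max(seq_len_q_list)  # 32
ctx_batch_size = max(1, max(20480, max_num_tokens) // max(seq_len_kv_cache_list))  # 2
max_num_tokens = max(max_num_tokens, ctx_batch_size * max(seq_len_kv_cache_list))
print(ctx_batch_size, max_num_tokens)  # 2 16384
```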
print(args)
# MPI args
@ -97,7 +134,7 @@ local_rank = local_mpi_rank()
torch.cuda.set_device(local_rank)
# Create KV cache manager
mark_ranges()
logger.info("Layer-wise benchmarks: Create KV cache manager")
Runner = get_runner_cls(args.model)
mapping = Runner.create_mapping(enable_attention_dp=args.enable_attention_dp)
kv_cache_manager = Runner.create_kv_cache_manager(
@ -109,12 +146,15 @@ kv_cache_manager = Runner.create_kv_cache_manager(
layer_indices=args.layer_indices,
)
attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
logger.info("Layer-wise benchmarks: Create KV cache manager ... Done")
# Create other global objects
AutoTuner.get().clear_cache()
capture_stream = torch.cuda.Stream()
mark_ranges()
# Create Runner
# Create runner
logger.info("Layer-wise benchmarks: Create runner")
runner = Runner(
args.model,
mapping,
@ -124,8 +164,10 @@ runner = Runner(
max_seq_len=args.max_seq_len,
max_num_tokens=args.max_num_tokens,
moe_max_num_tokens=args.moe_max_num_tokens,
use_low_precision_moe_combine=args.use_low_precision_moe_combine,
use_cuda_graph=args.use_cuda_graph,
)
logger.info("Layer-wise benchmarks: Create runner ... Done")
# Warm up
for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
@ -146,9 +188,11 @@ for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
]:
assert batch_size <= args.max_batch_size
assert seq_len_q + seq_len_kv_cache <= args.max_seq_len
assert batch_size * seq_len_q <= args.max_num_tokens
run_pack = runner.create_run_pack(
args.run_type,
batch_size=batch_size,
request_id_begin=0,
seq_len_q=seq_len_q,
seq_len_kv_cache=seq_len_kv_cache,
kv_cache_manager=kv_cache_manager,
@ -160,9 +204,31 @@ for autotune_flag, batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in [
capture_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(capture_stream):
if autotune_flag:
with autotune():
run_pack()
run_pack()
if args.enable_autotuner:
cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
with autotune(cache_path=cache_path, rank=rank):
run_pack()
if args.run_type == "GEN":
logger.info("Layer-wise benchmarks: Prefill KV cache")
ctx_seq_len_q = max(args.seq_len_kv_cache_list)
assert ctx_batch_size <= args.max_batch_size
assert ctx_seq_len_q + 0 <= args.max_seq_len  # seq_len_kv_cache is 0 during prefill
assert ctx_batch_size * ctx_seq_len_q <= args.max_num_tokens
max_batch_size = max(args.batch_size_list)
for request_id_begin in range(0, max_batch_size, ctx_batch_size):
ctx_run_pack = runner.create_run_pack(
"CTX",
batch_size=min(ctx_batch_size, max_batch_size - request_id_begin),
request_id_begin=request_id_begin,
seq_len_q=ctx_seq_len_q,
seq_len_kv_cache=0,
kv_cache_manager=kv_cache_manager,
attn_workspace=attn_workspace,
)
ctx_run_pack(check=True)
logger.info("Layer-wise benchmarks: Prefill KV cache ... Done")
else:
run_pack(check=True)
torch.cuda.current_stream().wait_stream(capture_stream)
torch.cuda.synchronize()
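Prefill walks the request ids in `ctx_batch_size` chunks, shrinking the last chunk so exactly `max_batch_size` requests get their KV cache filled; `request_id_begin` keeps the ids aligned with the requests the GEN runs will later use. The chunking in isolation:

```
max_batch_size = 5
ctx_batch_size = 2

for request_id_begin in range(0, max_batch_size, ctx_batch_size):
    batch_size = min(ctx_batch_size, max_batch_size - request_id_begin)
    # create_run_pack builds request_ids = range(request_id_begin,
    # request_id_begin + batch_size), so the chunks cover ids 0..max_batch_size-1.
    print(list(range(request_id_begin, request_id_begin + batch_size)))
# [0, 1] / [2, 3] / [4]
```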
@ -189,6 +255,7 @@ for batch_size, seq_len_q, seq_len_kv_cache, balance_ratio in itertools.product(
run_pack = runner.create_run_pack(
args.run_type,
batch_size=batch_size,
request_id_begin=0,
seq_len_q=seq_len_q,
seq_len_kv_cache=seq_len_kv_cache,
kv_cache_manager=kv_cache_manager,

View File

@ -17,15 +17,21 @@ PROFILE_DIR=${PROFILE_DIR:-profiles}
mkdir -p ${PROFILE_DIR}
PROFILE=${PROFILE:-1}
BACKTRACE=${BACKTRACE:-0}
GPU_METRICS=${GPU_METRICS:-0}
if [ "$PROFILE" -eq 1 ]; then
PROFILE_CMD="nsys profile
-t cuda,nvtx -s none
-t cuda,nvtx
--cpuctxsw none --cuda-event-trace false
--cuda-graph-trace node
-c cudaProfilerApi --capture-range-end stop
-o ${PROFILE_DIR}/report_np${WORLD_SIZE}_rank${RANK}.nsys-rep
--force-overwrite true"
if [ "$BACKTRACE" -eq 1 ]; then
PROFILE_CMD+=" --python-backtrace=cuda --cudabacktrace all"
else
PROFILE_CMD+=" -s none"
fi
if [ "$GPU_METRICS" -eq 1 ]; then
PROFILE_CMD+=" --gpu-metrics-devices $LOCAL_RANK
--gpu-metrics-frequency 10000"

View File

@ -272,7 +272,7 @@
<div class="config-panel">
<details>
<summary>🔧 Configuration (Click to expand)</summary>
<pre>{{ runArgs }}</pre>
<pre>{{ configText }}</pre>
</details>
</div>
@ -412,7 +412,7 @@
// Name Col
html += `<td onclick="toggleFold('${uniqueId}', event)" style="padding-left: ${10 + level * 20}px">
<div class="cell-content">${toggleIcon} ${node.name}</div>
<div class="cell-content" title="${node.name}">${toggleIcon} ${node.name}</div>
</td>`;
// Data Cols
@ -747,15 +747,15 @@
if (container) {
container.innerHTML = '';
for (let i = 1; i < maxDepth; i++) {
const btn = document.createElement('button');
btn.innerText = 'Level ' + i;
btn.onclick = function() { expandToLevel(i); };
// Buttons inherit the page's existing CSS; no inline styles are needed.
container.appendChild(btn);
// Separate consecutive buttons with a space.
container.appendChild(document.createTextNode(' '));
}
}
}

View File

@ -5,6 +5,7 @@ import torch
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3DecoderLayer
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._torch.pyexecutor.model_loader import initialize_dummy_weights
from tensorrt_llm._torch.utils import AuxStreamType
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping
@ -29,8 +30,10 @@ class DeepSeekV3Runner(RunnerMixin, RunnerBase):
max_seq_len: int,
max_num_tokens: int,
moe_max_num_tokens: int,
use_low_precision_moe_combine: bool,
use_cuda_graph: bool,
):
super().__init__()
self.model_config = ModelConfig.from_pretrained(
pretrained_model_name_or_path,
mapping=mapping,
@ -49,7 +52,7 @@ class DeepSeekV3Runner(RunnerMixin, RunnerBase):
attn_backend="TRTLLM",
moe_backend=moe_backend,
moe_disable_finalize_fusion=False,
use_low_precision_moe_combine=False,
use_low_precision_moe_combine=use_low_precision_moe_combine,
skip_create_weights_in_init=True,
)
pretrained_config = self.model_config.pretrained_config
@ -83,12 +86,14 @@ class DeepSeekV3Runner(RunnerMixin, RunnerBase):
if callable(getattr(module, "create_weights", None)):
module.create_weights()
layer.cuda()
initialize_dummy_weights(layer)
for module in layer.modules():
if hasattr(module, "post_load_weights") and not getattr(
module, "_weights_removed", False
):
module.post_load_weights()
next_layer_layernorm.cuda()
initialize_dummy_weights(next_layer_layernorm)
for layer, next_layer in zip(layers[:-1], layers[1:]):
layer.next_layer_layernorm = next_layer.input_layernorm
layers[-1].next_layer_layernorm = next_layer_layernorm
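The load path is now: create the weights, move the layer to the GPU, fill it with random dummy weights, then fire each module's `post_load_weights` hook. A condensed sketch of that sequence; `initialize_dummy_weights_sketch` is a stand-in for the real helper imported above, whose internals are not shown in this diff:

```
import torch

def initialize_dummy_weights_sketch(module):
    # Stand-in for tensorrt_llm's initialize_dummy_weights: fill parameters
    # with small random values instead of checkpoint weights, so benchmark
    # kernels see finite, mostly non-zero data.
    with torch.no_grad():
        for param in module.parameters():
            param.uniform_(-1e-3, 1e-3)

layer = torch.nn.Linear(8, 8)
if torch.cuda.is_available():
    layer.cuda()
initialize_dummy_weights_sketch(layer)
# The runners then invoke each submodule's post_load_weights() hook, if any.
for module in layer.modules():
    if hasattr(module, "post_load_weights"):
        module.post_load_weights()
```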

View File

@ -5,6 +5,7 @@ import torch
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_qwen3_next import ALL_DECODER_LAYER_TYPES
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._torch.pyexecutor.model_loader import initialize_dummy_weights
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping
@ -28,8 +29,10 @@ class Qwen3NextRunner(RunnerMixin, RunnerBase):
max_seq_len: int,
max_num_tokens: int,
moe_max_num_tokens: int,
use_low_precision_moe_combine: bool,
use_cuda_graph: bool,
):
super().__init__()
self.model_config = ModelConfig.from_pretrained(
pretrained_model_name_or_path,
mapping=mapping,
@ -48,7 +51,7 @@ class Qwen3NextRunner(RunnerMixin, RunnerBase):
attn_backend="TRTLLM",
moe_backend=moe_backend,
moe_disable_finalize_fusion=False,
use_low_precision_moe_combine=False,
use_low_precision_moe_combine=use_low_precision_moe_combine,
skip_create_weights_in_init=True,
)
pretrained_config = self.model_config.pretrained_config
@ -77,12 +80,14 @@ class Qwen3NextRunner(RunnerMixin, RunnerBase):
if callable(getattr(module, "create_weights", None)):
module.create_weights()
layer.cuda()
initialize_dummy_weights(layer)
for module in layer.modules():
if hasattr(module, "post_load_weights") and not getattr(
module, "_weights_removed", False
):
module.post_load_weights()
next_layer_layernorm.cuda()
initialize_dummy_weights(next_layer_layernorm)
for layer, next_layer in zip(layers[:-1], layers[1:]):
layer.next_layer_layernorm = next_layer.input_layernorm
layers[-1].next_layer_layernorm = next_layer_layernorm

View File

@ -316,7 +316,9 @@ class RunnerMixin(ABC):
def create_run_pack(
self,
run_type: str,
*,
batch_size: int,
request_id_begin: int,
seq_len_q: int,
seq_len_kv_cache: int,
kv_cache_manager: KVCacheManager,
@ -330,7 +332,7 @@ class RunnerMixin(ABC):
)
attn_metadata = AttentionCls.Metadata(
seq_lens=torch.tensor([seq_len_q] * batch_size, dtype=torch.int),
request_ids=list(range(batch_size)),
request_ids=list(range(request_id_begin, request_id_begin + batch_size)),
max_num_requests=kv_cache_manager.max_batch_size,
num_contexts={
"CTX": batch_size,
@ -355,8 +357,6 @@ class RunnerMixin(ABC):
)
attn_metadata.all_rank_num_tokens = [batch_size * seq_len_q] * world_size
attn_metadata.prepare()
with model_extra_attrs(self.model_config.extra_attrs):
get_model_extra_attrs()["attention_metadata"] = weakref.ref(attn_metadata)
hidden_size = self.model_config.pretrained_config.hidden_size
position_ids = torch.tensor(
[list(range(seq_len_kv_cache, seq_len_kv_cache + seq_len_q)) * batch_size],
@ -377,12 +377,20 @@ class RunnerMixin(ABC):
mamba_metadata.prepare(attn_metadata)
kwargs["mamba_metadata"] = mamba_metadata
def run_pack():
def run_pack(*, check=False):
output = hidden_states, residual
with model_extra_attrs(self.model_config.extra_attrs):
get_model_extra_attrs()["attention_metadata"] = weakref.ref(attn_metadata)
with torch.inference_mode():
for layer in self.layers:
output = layer(position_ids, output[0], attn_metadata, output[1], **kwargs)
if check:
if output[0].isnan().any():
raise ValueError("Has nan, please fix weights initialization")
if output[0].isinf().any():
raise ValueError("Has inf, please fix weights initialization")
if (output[0] == 0).sum() > 0.5 * output[0].numel():
raise ValueError("Too many zeros, please fix weights initialization")
return output
return run_pack
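The new `check=True` pass validates that the dummy-initialized weights produce sane activations: no NaN, no inf, and not mostly zeros. The same three checks on a toy tensor:

```
import torch

def check_output(t):
    # Mirrors run_pack's checks above.
    if t.isnan().any():
        raise ValueError("Has nan, please fix weights initialization")
    if t.isinf().any():
        raise ValueError("Has inf, please fix weights initialization")
    if (t == 0).sum() > 0.5 * t.numel():
        raise ValueError("Too many zeros, please fix weights initialization")

check_output(torch.randn(4, 4))    # passes
# check_output(torch.zeros(4, 4))  # would raise: too many zeros
```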

View File

@ -27,6 +27,10 @@ def test_deepseek_r1_ctx_dep(llm_root, world_size):
"PROFILE_DIR": profile_dir,
},
)
check_call(
["python3", "parse.py", "--profile-dir", profile_dir, f"--world-size={world_size}"],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
)
@pytest.mark.parametrize("world_size", [1, 4])
@ -54,6 +58,10 @@ def test_deepseek_r1_ctx_tep(llm_root, world_size):
"TRTLLM_ENABLE_PDL": "1",
},
)
check_call(
["python3", "parse.py", "--profile-dir", profile_dir, f"--world-size={world_size}"],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
)
@pytest.mark.parametrize("world_size", [1, 4])
@ -79,6 +87,10 @@ def test_deepseek_v32_ctx_dep(llm_root, world_size):
"PROFILE_DIR": profile_dir,
},
)
check_call(
["python3", "parse.py", "--profile-dir", profile_dir, f"--world-size={world_size}"],
cwd=llm_root / "examples" / "layer_wise_benchmarks",
)
@pytest.mark.parametrize("world_size", [4])