import contextlib
import functools
import itertools
import unittest.mock
import weakref
from enum import IntEnum
from typing import Optional

import torch

import tensorrt_llm._torch.model_config
import tensorrt_llm.bindings
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import PostInitCaller, skip_forward
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_trtllm_gen import TRTLLMGenFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import WideEPMoE
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
from tensorrt_llm._torch.pyexecutor._util import get_kv_cache_manager_cls
from tensorrt_llm._torch.pyexecutor.config_utils import (
    is_mla,
    is_nemotron_hybrid,
    is_qwen3_next,
    load_pretrained_config,
)
from tensorrt_llm._torch.pyexecutor.model_loader import (
    ModelLoader,
    _construct_checkpoint_loader,
    validate_and_set_kv_cache_quant,
    validate_and_set_mamba_ssm_cache_dtype,
)
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._torch.utils import get_model_extra_attrs, model_extra_attrs
from tensorrt_llm._utils import local_mpi_size, mpi_rank, mpi_world_size, torch_dtype_to_binding
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, MoeConfig, TorchLlmArgs
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping


class BalanceMethod(IntEnum):
    NotModified = 1
    Balanced = 2
    ImbalancedRanks = 3
    ImbalancedExperts = 4
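
# Summary of the balance methods (comment added for readability; descriptive
# only, based on the selection helpers below):
#   NotModified       - keep the model's real routing results untouched.
#   Balanced          - replace routing with a round-robin pattern that spreads
#                       tokens evenly across EP ranks and experts.
#   ImbalancedRanks   - route roughly a `1 - balance_ratio` fraction of tokens
#                       to the experts of a single EP rank (rank imbalance).
#   ImbalancedExperts - keep ranks balanced but concentrate tokens on a few
#                       experts per rank (expert imbalance).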
def ceil_div(a, b):
|
|
return (a + b - 1) // b
|
|
|
|
|
|
def round_up(a, b):
|
|
return ceil_div(a, b) * b
|
|
|
|
|
|
def get_balanced_selection_no_cache(
|
|
num_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
|
|
):
|
|
token_id = torch.arange(dp_rank * num_tokens * top_k, (dp_rank + 1) * num_tokens * top_k).view(
|
|
num_tokens, top_k
|
|
)
|
|
experts_per_rank = num_experts // ep_size
|
|
token_selected_experts = (token_id % ep_size) * experts_per_rank + (
|
|
token_id // ep_size
|
|
) % experts_per_rank
|
|
token_selected_experts = token_selected_experts.sort(dim=-1).values
|
|
return token_selected_experts.contiguous().to(dtype=dtype, device=device)
|
|
|
|
|
|
get_balanced_selection = functools.cache(get_balanced_selection_no_cache)
|
|
|
|
|
|
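
# Worked example (added for illustration; values chosen by hand, not from the
# original source): with num_tokens=2, top_k=2, num_experts=4, ep_size=2,
# dp_size=1, dp_rank=0, `token_id` is [[0, 1], [2, 3]] and experts_per_rank=2, so
#   (token_id % 2) * 2 + (token_id // 2) % 2  ->  [[0, 2], [1, 3]]
# i.e. each token picks one expert on EP rank 0 and one on EP rank 1, every
# expert receives exactly one token, and no token selects the same expert twice.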


def test_get_balanced_selection():
    dtype = torch.long
    for num_tokens, num_experts, enable_attention_dp in itertools.product(
        range(1, 35), range(1, 35), [False, True]
    ):
        print(f"{num_tokens=} {num_experts=} {enable_attention_dp=}")
        for top_k in range(1, min(10, num_experts) + 1):
            for world_size in range(1, 35):
                dp_size = world_size if enable_attention_dp else 1
                ep_size = world_size
                if num_experts % ep_size == 0:
                    tokens_per_expert = torch.zeros(num_experts)
                    for dp_rank in range(dp_size):
                        token_selected_experts = get_balanced_selection_no_cache(
                            num_tokens, top_k, num_experts, dtype, "cpu", dp_size, dp_rank, ep_size
                        )
                        sorted_selection = token_selected_experts.sort(dim=-1).values
                        if (sorted_selection[:, :-1] == sorted_selection[:, 1:]).any():
                            raise ValueError(f"duplicated experts on rank {dp_rank}")
                        experts_per_rank = num_experts // ep_size
                        tokens_per_rank = (
                            (token_selected_experts // experts_per_rank)
                            .view(-1)
                            .bincount(minlength=ep_size)
                        )
                        if tokens_per_rank.max() - tokens_per_rank.min() > 1:
                            raise ValueError(f"tokens sent from rank {dp_rank} are not balanced")
                        unique_tokens_per_rank = (
                            (
                                torch.arange(ep_size).view(ep_size, 1, 1)
                                == token_selected_experts // experts_per_rank
                            )
                            .any(dim=2)
                            .sum(dim=1)
                        )
                        if unique_tokens_per_rank.max() - unique_tokens_per_rank.min() > 1:
                            raise ValueError(
                                f"tokens sent from rank {dp_rank} are not balanced after removing duplicates"
                            )
                        tokens_per_expert += token_selected_experts.view(-1).bincount(
                            minlength=num_experts
                        )
                    if tokens_per_expert.max() - tokens_per_expert.min() > 1:
                        raise ValueError("tokens per expert are not balanced")


def get_num_balanced_tokens(num_tokens, top_k, num_experts, dp_size, balance_ratio):
    if balance_ratio == 0.0:
        return 0
    else:
        # Activate all experts
        min_num_balanced_tokens = min(num_tokens, ceil_div(num_experts, dp_size * top_k))
        return min_num_balanced_tokens + round(
            (num_tokens - min_num_balanced_tokens) * balance_ratio
        )
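
# Worked example (added for illustration; numbers chosen by hand): with
# num_tokens=8, top_k=2, num_experts=16, dp_size=2, balance_ratio=0.5,
#   min_num_balanced_tokens = min(8, ceil_div(16, 2 * 2)) = 4
#   result                  = 4 + round((8 - 4) * 0.5)    = 6
# so 6 of the 8 tokens on each DP rank follow the balanced pattern (enough to
# keep every expert active) and the remaining 2 follow the imbalanced one.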


@functools.cache
def get_all_to_one_selection(
    num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size
):
    num_balanced_tokens = get_num_balanced_tokens(
        num_tokens, top_k, num_experts, dp_size, balance_ratio
    )
    balanced_experts = get_balanced_selection_no_cache(
        num_balanced_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
    )
    num_imbalanced_tokens = num_tokens - num_balanced_tokens
    experts_per_rank = num_experts // ep_size
    if top_k > experts_per_rank:
        raise ValueError(
            "Cannot send all tokens to a single rank because `top_k > experts_per_rank`"
        )
    imbalanced_experts = (
        torch.arange(
            dp_rank * num_imbalanced_tokens * top_k,
            (dp_rank + 1) * num_imbalanced_tokens * top_k,
            dtype=dtype,
            device=device,
        ).view(num_imbalanced_tokens, top_k)
        % experts_per_rank
    )
    mixed_experts = torch.cat([balanced_experts, imbalanced_experts])
    return mixed_experts.sort(dim=-1).values
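
# Note (added for readability): the `% experts_per_rank` above maps the
# imbalanced portion onto expert IDs in [0, experts_per_rank), which all live
# on EP rank 0, so roughly a `1 - balance_ratio` fraction of the traffic piles
# onto a single rank.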


@functools.cache
def get_balanced_rank_imbalanced_expert_selection(
    num_tokens, top_k, num_experts, balance_ratio, dtype, device, dp_size, dp_rank, ep_size
):
    num_balanced_tokens = get_num_balanced_tokens(
        num_tokens, top_k, num_experts, dp_size, balance_ratio
    )
    balanced_experts = get_balanced_selection_no_cache(
        num_balanced_tokens, top_k, num_experts, dtype, device, dp_size, dp_rank, ep_size
    )
    num_imbalanced_tokens = num_tokens - num_balanced_tokens
    experts_per_rank = num_experts // ep_size
    active_experts_per_rank = ceil_div(top_k, ep_size)
    # Select expert from [0, active_experts_per_rank * ep_size),
    # then scale to [0, experts_per_rank * ep_size)
    narrow_experts = get_balanced_selection_no_cache(
        num_imbalanced_tokens,
        top_k,
        active_experts_per_rank * ep_size,
        dtype,
        device,
        dp_size,
        dp_rank,
        ep_size,
    )
    imbalanced_experts = (
        narrow_experts // active_experts_per_rank * experts_per_rank
        + narrow_experts % active_experts_per_rank
    )
    mixed_experts = torch.cat([balanced_experts, imbalanced_experts])
    return mixed_experts.sort(dim=-1).values
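
# Worked example (added for illustration; numbers chosen by hand): with top_k=2,
# ep_size=4, experts_per_rank=8, we get active_experts_per_rank = ceil_div(2, 4) = 1,
# so `narrow_experts` lies in [0, 4) and the rescaling maps it to {0, 8, 16, 24}:
# traffic stays balanced across the 4 EP ranks, but only the first of the 8
# experts on each rank receives tokens.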


def make_balanced_routing_method(
    moe_module,
    apply_method_orig,
    num_experts,
    balance_method,
    balance_ratio,
    dp_size,
    dp_rank,
    ep_size,
):
    def balanced_routing_method(router_logits):
        token_selected_experts, token_final_scales = apply_method_orig(router_logits)
        assert moe_module._routing_results_replaced_at in [None, "make_balanced_routing_method"]
        if balance_method == BalanceMethod.Balanced:
            token_selected_experts = get_balanced_selection(
                token_selected_experts.shape[0],
                token_selected_experts.shape[1],
                num_experts,
                token_selected_experts.dtype,
                token_selected_experts.device,
                dp_size,
                dp_rank,
                ep_size,
            )
        elif balance_method == BalanceMethod.ImbalancedRanks:
            token_selected_experts = get_all_to_one_selection(
                token_selected_experts.shape[0],
                token_selected_experts.shape[1],
                num_experts,
                balance_ratio,
                token_selected_experts.dtype,
                token_selected_experts.device,
                dp_size,
                dp_rank,
                ep_size,
            )
        elif balance_method == BalanceMethod.ImbalancedExperts:
            token_selected_experts = get_balanced_rank_imbalanced_expert_selection(
                token_selected_experts.shape[0],
                token_selected_experts.shape[1],
                num_experts,
                balance_ratio,
                token_selected_experts.dtype,
                token_selected_experts.device,
                dp_size,
                dp_rank,
                ep_size,
            )
        else:
            raise NotImplementedError(f"Unsupported balance_method {balance_method}")
        moe_module._routing_results_replaced_at = "make_balanced_routing_method"
        return token_selected_experts, token_final_scales

    return balanced_routing_method


@functools.cache
def get_token_final_scales(shape, device):
    return torch.full(shape, 1.0 / shape[-1], dtype=torch.bfloat16, device=device)
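
# Note (added for readability): the replacement scales are uniformly 1 / top_k,
# so the combined MoE output is a plain average of the selected experts.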


def make_balanced_run_moe(
    moe_module,
    run_moe_orig,
    top_k,
    num_experts,
    balance_method,
    balance_ratio,
    dp_size,
    dp_rank,
    ep_size,
):
    def balanced_run_moe(
        x, token_selected_experts, token_final_scales, x_sf, router_logits, do_finalize, moe_output
    ):
        if moe_module._routing_results_replaced_at is not None:
            return run_moe_orig(
                x,
                token_selected_experts,
                token_final_scales,
                x_sf,
                router_logits,
                do_finalize,
                moe_output,
            )
        logger.warning_once(
            'Layer-wise benchmarks: Specifying routing results of the "TRTLLM" MoE backend in TEP cases leads'
            " to a different execution path around the topk kernel",
            key="replace_routing_method_ctx_trtllm_tp",
        )
        if balance_method == BalanceMethod.Balanced:
            token_selected_experts = get_balanced_selection(
                x.shape[0],
                top_k,
                num_experts,
                torch.int32,
                x.device,
                dp_size,
                dp_rank,
                ep_size,
            )
        elif balance_method == BalanceMethod.ImbalancedRanks:
            token_selected_experts = get_all_to_one_selection(
                x.shape[0],
                top_k,
                num_experts,
                balance_ratio,
                torch.int32,
                x.device,
                dp_size,
                dp_rank,
                ep_size,
            )
        elif balance_method == BalanceMethod.ImbalancedExperts:
            token_selected_experts = get_balanced_rank_imbalanced_expert_selection(
                x.shape[0],
                top_k,
                num_experts,
                balance_ratio,
                torch.int32,
                x.device,
                dp_size,
                dp_rank,
                ep_size,
            )
        else:
            raise NotImplementedError(f"Unsupported balance_method {balance_method}")
        token_final_scales = get_token_final_scales(
            token_selected_experts.shape, token_selected_experts.device
        )
        router_logits = None
        final_hidden_states = run_moe_orig(
            x,
            token_selected_experts,
            token_final_scales,
            x_sf,
            router_logits,
            do_finalize,
            moe_output,
        )
        if not do_finalize:
            final_hidden_states = (
                final_hidden_states[0],
                token_final_scales,  # WAR for a TRTLLMGenFusedMoE bug where the returned `token_final_scales` is wrong
                final_hidden_states[2],
            )
        moe_module._routing_results_replaced_at = "make_balanced_run_moe"
        return final_hidden_states

    return balanced_run_moe
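
# How the sentinel works (comment added for readability): the wrapped
# `forward_impl` below resets `_routing_results_replaced_at` to None before each
# call; the patched `routing_method.apply` / `run_moe` set it when they replace
# the routing results; the assertion afterwards guarantees the replacement
# actually happened on this forward pass.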


def make_forward_impl_check(moe_module, forward_impl_orig):
    def forward_impl(*args, **kwargs):
        moe_module._routing_results_replaced_at = None
        res = forward_impl_orig(*args, **kwargs)
        assert moe_module._routing_results_replaced_at is not None, (
            "Routing results are not replaced"
        )
        del moe_module._routing_results_replaced_at
        return res

    return forward_impl


class Runner:
    def __init__(
        self,
        pretrained_model_name_or_path: str,
        mapping: Mapping,
        *,
        load_format: str,
        moe_backend: str,
        layer_indices: list[int],
        scaled_from: Optional[int],
        max_seq_len: int,
        max_num_tokens: int,
        moe_max_num_tokens: int,
        kv_cache_dtype,
        mamba_ssm_cache_dtype: str,
        use_low_precision_moe_combine: bool,
        use_cuda_graph: bool,
    ):
        super().__init__()

        checkpoint_loader = _construct_checkpoint_loader("pytorch", None, "HF")
        # Please refer to `tensorrt_llm/_torch/pyexecutor/model_loader.py` for effective args
        llm_args = TorchLlmArgs(
            model=pretrained_model_name_or_path,
            load_format=load_format,
            **({} if use_cuda_graph else {"cuda_graph_config": None}),
            moe_config=MoeConfig(
                backend=moe_backend,
                max_num_tokens=moe_max_num_tokens,
                disable_finalize_fusion=False,
                use_low_precision_moe_combine=use_low_precision_moe_combine,
            ),
            attn_backend="TRTLLM",
            kv_cache_config=KvCacheConfig(
                dtype=kv_cache_dtype, mamba_ssm_cache_dtype=mamba_ssm_cache_dtype
            ),
        )
        model_loader = ModelLoader(
            llm_args=llm_args,
            mapping=mapping,
            spec_config=None,
            sparse_attention_config=None,
            max_num_tokens=max_num_tokens,
            max_seq_len=max_seq_len,
        )

        with self.scaled_from_ctx(scaled_from, mapping), self.skip_unused_layers_ctx(layer_indices):
            model, _ = model_loader.load(
                checkpoint_dir=pretrained_model_name_or_path, checkpoint_loader=checkpoint_loader
            )

        self.layers = [model.model.layers[i] for i in layer_indices]
        self.model_config = model.model_config

    @staticmethod
    @contextlib.contextmanager
    def scaled_from_ctx(scaled_from, mapping):
        if scaled_from is None:
            yield
            return

        def make_load_pretrained_config(mapping, load_pretrained_config_orig):
            # To run the problem size of $B$ GPUs on $A$ GPUs, we need:
            # (1) Attention: if TP, reduce the number of attention heads; if DP, nothing to change.
            # (2) MoE: if EP, reduce the number of experts; if TP, reduce the head size.
            # Maintain the result of the AllToAll method selection because it is affected by the EP size.
            def load_pretrained_config(*args, **kwargs):
                pretrained_config = load_pretrained_config_orig(*args, **kwargs)
                if not mapping.enable_attention_dp:
                    if hasattr(pretrained_config, "index_n_heads"):
                        raise NotImplementedError("Indexer TP is not supported for weak scaling")
                    pretrained_config.num_attention_heads = (
                        pretrained_config.num_attention_heads // scaled_from * mapping.tp_size
                    )
                    pretrained_config.num_key_value_heads = (
                        pretrained_config.num_key_value_heads // scaled_from * mapping.tp_size
                    )
                if mapping.moe_ep_size != mapping.tp_size:
                    raise NotImplementedError("MoE TP is not supported for weak scaling")
                pretrained_config.n_routed_experts = (
                    pretrained_config.n_routed_experts // scaled_from * mapping.moe_ep_size
                )
                return pretrained_config

            return load_pretrained_config

        def make_select_alltoall_method_type(select_alltoall_method_type_orig):
            def select_alltoall_method_type(
                cls: type, mapping: Mapping, top_k: int, *args, **kwargs
            ):
                # Replace the condition `mapping.moe_ep_size <= top_k` with `scaled_from <= top_k`
                # by replacing `top_k` with `fake_top_k`
                if scaled_from <= top_k:
                    fake_top_k = mapping.moe_ep_size + 1
                else:
                    fake_top_k = mapping.moe_ep_size - 1
                assert (mapping.moe_ep_size <= fake_top_k) == (scaled_from <= top_k)
                return select_alltoall_method_type_orig(mapping, fake_top_k, *args, **kwargs)

            return select_alltoall_method_type

        def make_select_alltoall_method_type_2(select_alltoall_method_type_orig):
            def select_alltoall_method_type(self):
                # Replace the condition `mapping.moe_ep_size <= top_k` with `scaled_from <= top_k`
                # by replacing `top_k` with `fake_top_k`
                top_k = self.routing_method.experts_per_token
                if scaled_from <= top_k:
                    fake_top_k = mapping.moe_ep_size + 1
                else:
                    fake_top_k = mapping.moe_ep_size - 1
                assert (mapping.moe_ep_size <= fake_top_k) == (scaled_from <= top_k)
                with unittest.mock.patch.object(
                    self.routing_method.__class__,
                    "experts_per_token",
                    new_callable=unittest.mock.PropertyMock,
                ) as mock_top_k:
                    mock_top_k.return_value = fake_top_k
                    return select_alltoall_method_type_orig(self)

            return select_alltoall_method_type

        select_alltoall_method_type_cutlass = CutlassFusedMoE.select_alltoall_method_type
        select_alltoall_method_type_trtllm_gen = TRTLLMGenFusedMoE.select_alltoall_method_type
        select_alltoall_method_type_wide_ep = WideEPMoE.select_alltoall_method_type
        tensorrt_llm._torch.model_config.load_pretrained_config = make_load_pretrained_config(
            mapping, load_pretrained_config
        )
        CutlassFusedMoE.select_alltoall_method_type = make_select_alltoall_method_type_2(
            select_alltoall_method_type_cutlass
        )
        TRTLLMGenFusedMoE.select_alltoall_method_type = make_select_alltoall_method_type_2(
            select_alltoall_method_type_trtllm_gen
        )
        WideEPMoE.select_alltoall_method_type = make_select_alltoall_method_type(
            select_alltoall_method_type_wide_ep
        )
        try:
            yield
        finally:
            tensorrt_llm._torch.model_config.load_pretrained_config = load_pretrained_config
            CutlassFusedMoE.select_alltoall_method_type = select_alltoall_method_type_cutlass
            TRTLLMGenFusedMoE.select_alltoall_method_type = select_alltoall_method_type_trtllm_gen
            WideEPMoE.select_alltoall_method_type = select_alltoall_method_type_wide_ep

    @staticmethod
    @contextlib.contextmanager
    def skip_unused_layers_ctx(layer_indices):
        call_orig = PostInitCaller.__call__

        def call_new(cls, *args, **kwargs):
            model = call_orig(cls, *args, **kwargs)
            for module in (
                model.prologue + model.model.prologue + model.model.epilogue + model.epilogue
            ):
                skip_forward(module)
            num_hidden_layers = model.model_config.pretrained_config.num_hidden_layers
            if hasattr(model.model, "embed_tokens"):
                skip_forward(model.model.embed_tokens)
            for layer_idx in range(num_hidden_layers):
                layer = model.model.layers[layer_idx]
                if layer_idx not in layer_indices:
                    # keep next layer's input_layernorm's weights for fusion
                    skip_forward(
                        layer,
                        ignore_modules=[layer.input_layernorm]
                        if layer_idx - 1 in layer_indices
                        and hasattr(model.model.layers[layer_idx - 1], "next_layer_layernorm")
                        else None,
                    )
            if hasattr(model.model, "norm"):
                skip_forward(
                    model.model.norm,
                    ignore_modules=[model.model.norm]
                    if num_hidden_layers - 1 in layer_indices
                    else None,
                )
            return model

        PostInitCaller.__call__ = call_new
        try:
            yield
        finally:
            PostInitCaller.__call__ = call_orig

    def create_run_pack(
        self,
        run_type: str,
        *,
        batch_size: int,
        request_id_begin: int,
        seq_len_q: int,
        seq_len_kv_cache: int,
        kv_cache_manager: KVCacheManager,
        attn_workspace: Optional[torch.Tensor] = None,
    ):
        world_size = mpi_world_size()
        pretrained_config = self.model_config.pretrained_config
        AttentionCls = get_attention_backend(
            self.model_config.attn_backend, self.model_config.sparse_attention_config
        )
        attn_metadata = AttentionCls.Metadata(
            seq_lens=torch.tensor([seq_len_q] * batch_size, dtype=torch.int),
            request_ids=list(range(request_id_begin, request_id_begin + batch_size)),
            max_num_requests=kv_cache_manager.max_batch_size,
            num_contexts={
                "CTX": batch_size,
                "GEN": 0,
            }[run_type],
            prompt_lens=[
                {
                    "CTX": seq_len_q,
                    "GEN": seq_len_kv_cache,
                }[run_type]
            ]
            * batch_size,
            max_num_tokens=batch_size * seq_len_q,
            kv_cache_manager=kv_cache_manager,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=[seq_len_kv_cache] * batch_size,
            ),
            workspace=attn_workspace,
            mapping=self.model_config.mapping,
            sparse_attention_config=self.model_config.sparse_attention_config,
        )
        attn_metadata.all_rank_num_tokens = [batch_size * seq_len_q] * world_size
        attn_metadata.prepare()
        hidden_size = pretrained_config.hidden_size
        position_ids = torch.tensor(
            [list(range(seq_len_kv_cache, seq_len_kv_cache + seq_len_q)) * batch_size],
            dtype=torch.int32,
            device="cuda",
        )
        hidden_states = torch.rand(
            (batch_size * seq_len_q, hidden_size), dtype=torch.bfloat16, device="cuda"
        )
        residual = torch.rand(
            (batch_size * seq_len_q, hidden_size), dtype=torch.bfloat16, device="cuda"
        )
        kwargs = {}

        if is_nemotron_hybrid(pretrained_config) or is_qwen3_next(pretrained_config):
            # Please refer to `tensorrt_llm/_torch/models/modeling_qwen3_next.py` for the magic number chunk_size=128
            mamba_metadata = Mamba2Metadata(
                attn_metadata.max_num_requests,
                chunk_size=128
                if is_qwen3_next(pretrained_config)
                else pretrained_config.chunk_size,
            )
            mamba_metadata.prepare(attn_metadata)
            kwargs["mamba_metadata"] = mamba_metadata

        def run_pack(*, check=False):
            output = hidden_states, residual
            with model_extra_attrs(self.model_config.extra_attrs):
                get_model_extra_attrs()["attention_metadata"] = weakref.ref(attn_metadata)
                with torch.inference_mode():
                    # TODO: to be more general, we should call DecoderModel.forward
                    for layer in self.layers:
                        residual_fusion = hasattr(layer, "next_layer_layernorm")
                        if residual_fusion:
                            output = layer(
                                position_ids, output[0], attn_metadata, output[1], **kwargs
                            )
                        else:
                            output = layer(position_ids, output[0], attn_metadata, **kwargs), None
            if check:
                if output[0].isnan().any():
                    raise ValueError("Found NaN, please fix the weight initialization")
                if output[0].isinf().any():
                    raise ValueError("Found Inf, please fix the weight initialization")
                if (output[0] == 0).sum() > 0.5 * output[0].numel():
                    raise ValueError("Too many zeros, please fix the weight initialization")
            return output

        return run_pack

    @contextlib.contextmanager
    def replace_routing_method_ctx(self, balance_method: BalanceMethod, balance_ratio: float):
        if balance_method == BalanceMethod.NotModified:
            yield
            return
        if self.model_config.moe_backend not in [
            "CUTEDSL",
            "CUTLASS",
            "DEEPGEMM",
            "TRTLLM",
            "WIDEEP",
        ]:
            raise NotImplementedError(
                f'Replacing the routing method is not supported for moe_backend "{self.model_config.moe_backend}",'
                ' please set balance_method to "NotModified"'
            )
        original_methods = []
        dp_rank = (
            self.model_config.mapping.tp_rank
            if self.model_config.mapping.enable_attention_dp
            else 0
        )
        moe_modules = []
        for layer in self.layers:
            if layer.__class__.__name__ == "NemotronHLayer":
                if layer.layer_type == "E":
                    moe_modules.append(layer.mixer.experts)
            else:
                moe_modules.append(layer.mlp.experts)

        for moe_module in moe_modules:
            # Replace `routing_method.apply` for normal cases
            apply_method_orig = moe_module.routing_method.apply
            moe_module.routing_method.apply = make_balanced_routing_method(
                moe_module,
                apply_method_orig,
                moe_module.num_experts,
                balance_method,
                balance_ratio,
                self.model_config.mapping.dp_size,
                dp_rank,
                self.model_config.mapping.moe_ep_size,
            )

            # Replace `run_moe` for TRTLLMGenFusedMoE TEP because it does not call `routing_method.apply`
            if isinstance(moe_module, TRTLLMGenFusedMoE):
                run_moe_orig = moe_module.run_moe
                moe_module.run_moe = make_balanced_run_moe(
                    moe_module,
                    run_moe_orig,
                    moe_module.routing_method.top_k,
                    moe_module.num_experts,
                    balance_method,
                    balance_ratio,
                    self.model_config.mapping.dp_size,
                    dp_rank,
                    self.model_config.mapping.moe_ep_size,
                )
            else:
                run_moe_orig = None

            # Replace `forward_impl` to ensure that routing results are replaced
            forward_impl_orig = moe_module.forward_impl
            moe_module.forward_impl = make_forward_impl_check(moe_module, forward_impl_orig)

            original_methods.append((apply_method_orig, run_moe_orig, forward_impl_orig))
        try:
            yield
        finally:
            for moe_module, (apply_method_orig, run_moe_orig, forward_impl_orig) in zip(
                moe_modules, original_methods
            ):
                moe_module.routing_method.apply = apply_method_orig
                if isinstance(moe_module, TRTLLMGenFusedMoE):
                    moe_module.run_moe = run_moe_orig
                moe_module.forward_impl = forward_impl_orig

    @staticmethod
    def create_kv_cache_manager(
        pretrained_model_name_or_path,
        mapping,
        tokens_per_block,
        max_batch_size,
        max_seq_len,
        kv_cache_dtype,
        mamba_ssm_cache_dtype,
        layer_indices,
    ):
        # Please refer to `tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` for `tokens_per_block`
        model_config = ModelConfig.from_pretrained(pretrained_model_name_or_path)
        validate_and_set_kv_cache_quant(model_config, kv_cache_dtype)
        validate_and_set_mamba_ssm_cache_dtype(model_config, mamba_ssm_cache_dtype)
        if model_config.enable_flash_mla:
            assert tokens_per_block == 64

        # Please refer to `tensorrt_llm/_torch/pyexecutor/_util.py` for `kv_cache_manager`
        kv_cache_manager_cls = get_kv_cache_manager_cls(model_config)
        config = model_config.pretrained_config
        kv_cache_config = KvCacheConfig(
            max_tokens=max_batch_size * round_up(max_seq_len, tokens_per_block),
            enable_block_reuse=False,
        )
        kv_cache_dtype = {
            "FP8": tensorrt_llm.bindings.DataType.FP8,
            "NVFP4": tensorrt_llm.bindings.DataType.NVFP4,
            None: torch_dtype_to_binding(config.torch_dtype),
        }[model_config.quant_config.kv_cache_quant_algo]
        if is_mla(config):
            layer_mask = [i in layer_indices for i in range(config.num_hidden_layers)]
            num_layers = sum(layer_mask)
            kv_cache_manager = kv_cache_manager_cls(
                kv_cache_config,
                tensorrt_llm.bindings.internal.batch_manager.CacheType.SELFKONLY,
                num_layers=num_layers,
                num_kv_heads=1,
                head_dim=model_config.pretrained_config.kv_lora_rank
                + model_config.pretrained_config.qk_rope_head_dim,
                tokens_per_block=tokens_per_block,
                max_seq_len=max_seq_len,
                max_batch_size=max_batch_size,
                mapping=mapping,
                dtype=kv_cache_dtype,
                spec_config=None,
                layer_mask=layer_mask,
                sparse_attn_config=model_config.sparse_attention_config,
            )
        elif is_nemotron_hybrid(config):
            mamba_layer_mask = [
                i in layer_indices and char == "M"
                for i, char in enumerate(config.hybrid_override_pattern)
            ]
            layer_mask = [
                i in layer_indices and char == "*"
                for i, char in enumerate(config.hybrid_override_pattern)
            ]
            num_mamba_layers = sum(mamba_layer_mask)
            num_layers = sum(layer_mask)
            kv_cache_manager = kv_cache_manager_cls(
                # mamba cache parameters
                config.ssm_state_size,
                config.conv_kernel,
                config.mamba_num_heads,
                config.n_groups,
                config.mamba_head_dim,
                num_mamba_layers,
                mamba_layer_mask,
                config.torch_dtype,
                model_config.quant_config.mamba_ssm_cache_dtype,
                # kv cache parameters
                kv_cache_config,
                tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
                num_layers=num_layers,
                layer_mask=layer_mask,
                num_kv_heads=config.num_key_value_heads,
                head_dim=config.head_dim,
                tokens_per_block=tokens_per_block,
                max_seq_len=max_seq_len,
                max_batch_size=max_batch_size,
                mapping=mapping,
                dtype=kv_cache_dtype,
                spec_config=None,
            )
        elif is_qwen3_next(config):
            mamba_layer_mask = [
                i in layer_indices
                if i % config.full_attention_interval != config.full_attention_interval - 1
                else False
                for i in range(config.num_hidden_layers)
            ]
            layer_mask = [
                False
                if i % config.full_attention_interval != config.full_attention_interval - 1
                else i in layer_indices
                for i in range(config.num_hidden_layers)
            ]
            num_mamba_layers = sum(mamba_layer_mask)
            num_layers = sum(layer_mask)
            kv_cache_manager = kv_cache_manager_cls(
                # mamba cache parameters
                config.linear_key_head_dim,
                config.linear_conv_kernel_dim,
                config.linear_num_value_heads,
                config.linear_num_key_heads,
                config.linear_value_head_dim,
                num_mamba_layers,
                mamba_layer_mask,
                config.torch_dtype,
                model_config.quant_config.mamba_ssm_cache_dtype,
                # kv cache parameters
                kv_cache_config,
                tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
                num_layers=num_layers,
                layer_mask=layer_mask,
                num_kv_heads=config.num_key_value_heads,
                head_dim=config.head_dim,
                tokens_per_block=tokens_per_block,
                max_seq_len=max_seq_len,
                max_batch_size=max_batch_size,
                mapping=mapping,
                dtype=kv_cache_dtype,
                spec_config=None,
            )
        else:
            raise NotImplementedError("Unsupported config")
        kv_cache_manager.add_dummy_requests(
            list(range(max_batch_size)), [max_seq_len] * max_batch_size
        )
        return kv_cache_manager

    @staticmethod
    def create_mapping(enable_attention_dp: bool):
        world_size = mpi_world_size()
        rank = mpi_rank()
        mapping = Mapping(
            world_size=world_size,
            rank=rank,
            gpus_per_node=local_mpi_size(),
            cp_size=1,
            tp_size=world_size,
            pp_size=1,
            moe_cluster_size=1,
            moe_tp_size=1,
            moe_ep_size=world_size,
            attn_tp_size=world_size,
            attn_cp_size=1,
            enable_attention_dp=enable_attention_dp,
        )
        return mapping
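
# Rough usage sketch (added for illustration; not part of the original file, and
# the checkpoint path and argument values below are hypothetical):
#
#   mapping = Runner.create_mapping(enable_attention_dp=False)
#   runner = Runner(
#       "deepseek-ai/DeepSeek-V3",  # any supported checkpoint would do
#       mapping,
#       load_format="dummy",
#       moe_backend="CUTLASS",
#       layer_indices=[5],
#       scaled_from=None,
#       max_seq_len=8192,
#       max_num_tokens=8192,
#       moe_max_num_tokens=8192,
#       kv_cache_dtype="auto",
#       mamba_ssm_cache_dtype="auto",
#       use_low_precision_moe_combine=False,
#       use_cuda_graph=False,
#   )
#   kv_cache_manager = Runner.create_kv_cache_manager(
#       "deepseek-ai/DeepSeek-V3", mapping, tokens_per_block=64, max_batch_size=1,
#       max_seq_len=8192, kv_cache_dtype="auto", mamba_ssm_cache_dtype="auto",
#       layer_indices=[5],
#   )
#   run_pack = runner.create_run_pack(
#       "CTX", batch_size=1, request_id_begin=0, seq_len_q=1024,
#       seq_len_kv_cache=0, kv_cache_manager=kv_cache_manager,
#   )
#   with runner.replace_routing_method_ctx(BalanceMethod.Balanced, balance_ratio=1.0):
#       output = run_pack(check=True)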