[TRTLLM-10453][feat] Update mamba decode kernel to flashinfer (#10757)

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Authored by Wanli Jiang on 2026-01-27 13:04:40 +08:00, committed by GitHub
parent da43a28b01
commit 4a206351bb
6 changed files with 53 additions and 24 deletions


@@ -5261,7 +5261,7 @@ For more information, please refer to <http://unlicense.org>
- `Tracker`: https://github.com/tox-dev/py-filelock/issues
-## flashinfer-python (0.6.1)
+## flashinfer-python (0.6.2)
### Licenses
License: `Apache-2.0`


@@ -97,6 +97,11 @@ def add_llm_args(parser):
default=False,
action='store_true')
parser.add_argument("--tokens_per_block", type=int, default=32)
+parser.add_argument('--mamba_ssm_cache_dtype',
+type=str,
+default='bfloat16',
+choices=['auto', 'float16', 'bfloat16', 'float32'],
+help='Data type for Mamba SSM cache.')
parser.add_argument('--log_kv_cache_events',
default=False,
action='store_true')
@@ -205,6 +210,7 @@ def setup_llm(args, **kwargs):
free_gpu_memory_fraction=args.kv_cache_fraction,
dtype=args.kv_cache_dtype,
tokens_per_block=args.tokens_per_block,
+mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype,
event_buffer_max_size=1024 if args.log_kv_cache_events else 0)
spec_decode_algo = args.spec_decode_algo.upper(
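
Taken together, the two hunks above thread the new `--mamba_ssm_cache_dtype` flag from the example's argument parser into the cache configuration built in `setup_llm`, so the Mamba SSM state cache precision can be pinned independently of the attention KV cache. A minimal sketch of the equivalent LLM API usage, assuming these kwargs feed a `KvCacheConfig` as in the standard quickstart flow (the model path is a placeholder):

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Assumed mapping of the CLI flags above onto KvCacheConfig fields.
kv_cache_config = KvCacheConfig(
    free_gpu_memory_fraction=0.9,
    dtype="auto",                      # attention KV cache dtype (--kv_cache_dtype)
    tokens_per_block=32,               # --tokens_per_block
    mamba_ssm_cache_dtype="bfloat16",  # new flag: --mamba_ssm_cache_dtype
)
llm = LLM(model="/path/to/hybrid-mamba-model", kv_cache_config=kv_cache_config)
```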


@@ -53,7 +53,7 @@ ordered-set
peft
patchelf
einops
-flashinfer-python~=0.6.1
+flashinfer-python~=0.6.2
opencv-python-headless
xgrammar==0.1.25
llguidance==0.7.29


@@ -55,7 +55,7 @@ dependencies = [
"peft (>=0.18.1,<0.19.0)",
"patchelf (>=0.17.2.4,<0.18.0.0)",
"einops (>=0.8.2,<0.9.0)",
"flashinfer-python (>=0.6.1,<0.7.0)",
"flashinfer-python (>=0.6.2,<0.7.0)",
"xgrammar (==0.1.25)",
"llguidance (==0.7.29)",
"jsonschema (>=4.26.0,<5.0.0)",


@@ -265,7 +265,7 @@ class Deepseekv3RoutingImpl:
)
self.is_fused = False
elif (num_experts > 512 or (self.top_k > 8 and self.top_k != 22)
-or self.topk_group == 1):
+or (self.topk_group == 1 and self.top_k != 22)):
# We have special implementation for n_group == 1, top_k == 22 and num_experts == 512 for Nemotron Super v3.
if self.is_fused:
warnings.warn(
@@ -273,19 +273,24 @@
)
self.is_fused = False
-        if self.n_group == 1 and self.topk_group == 1:
-            scores, scores_with_bias = self.get_scores(logits,
-                                                       e_score_correction_bias)
-            _, topk_indices = torch.topk(scores_with_bias, k=self.top_k, dim=1)
-            topk_values = torch.gather(scores, dim=1,
-                                       index=topk_indices).type_as(scores)
+        if not self.is_fused:
+            # Short path for n_group == 1 and topk_group == 1.
+            if self.n_group == 1 and self.topk_group == 1:
+                scores, scores_with_bias = self.get_scores(
+                    logits, e_score_correction_bias)
+                _, topk_indices = torch.topk(scores_with_bias,
+                                             k=self.top_k,
+                                             dim=1)
+                topk_values = torch.gather(scores, dim=1,
+                                           index=topk_indices).type_as(scores)
-            # Normalize and scale.
-            topk_values_sum = torch.sum(topk_values, dim=-1,
-                                        keepdim=True) + 1e-20
-            topk_values = topk_values / topk_values_sum * self.routed_scaling_factor
-            return topk_values, topk_indices
-        elif not self.is_fused:
+                # Normalize and scale.
+                topk_values_sum = torch.sum(topk_values, dim=-1,
+                                            keepdim=True) + 1e-20
+                topk_values = topk_values / topk_values_sum * self.routed_scaling_factor
+                return topk_values, topk_indices
# General case with pytorch implementation.
scores, scores_with_bias = self.get_scores(logits,
e_score_correction_bias)
scores_shape = list(scores_with_bias.shape)
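
For context, the short path selects experts with the bias-corrected scores but weights them with the original scores, then renormalizes and applies the routed scaling factor; the new `if not self.is_fused:` guard keeps this Python path from running when the fused kernel handles routing. A self-contained sketch of that short path, with `get_scores` approximated by a sigmoid gate plus the correction bias (an assumption; the real helper lives elsewhere in this class):

```python
import torch

def route_short_path(logits: torch.Tensor,
                     e_score_correction_bias: torch.Tensor,
                     top_k: int,
                     routed_scaling_factor: float):
    # Approximation of get_scores(): per-expert gate plus a load-balancing bias.
    scores = torch.sigmoid(logits)
    scores_with_bias = scores + e_score_correction_bias
    # Experts are chosen by the biased scores...
    _, topk_indices = torch.topk(scores_with_bias, k=top_k, dim=1)
    # ...but tokens are weighted by the unbiased scores.
    topk_values = torch.gather(scores, dim=1, index=topk_indices).type_as(scores)
    # Normalize the selected weights and apply the routed scaling factor.
    topk_values_sum = torch.sum(topk_values, dim=-1, keepdim=True) + 1e-20
    topk_values = topk_values / topk_values_sum * routed_scaling_factor
    return topk_values, topk_indices

# Example: 4 tokens routed over 16 experts, top-4.
logits = torch.randn(4, 16)
bias = torch.zeros(16)
values, indices = route_short_path(logits, bias, top_k=4, routed_scaling_factor=2.5)
```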


@@ -17,9 +17,11 @@ from typing import Optional
import torch
from einops import rearrange, repeat
+from flashinfer.mamba import selective_state_update as selective_state_update_fi
from torch import nn
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from ...attention_backend import AttentionMetadata
@@ -30,7 +32,8 @@ from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .causal_conv1d_triton import \
causal_conv1d_update as causal_conv1d_update_triton
from .layernorm_gated import RMSNorm as RMSNormGated
-from .selective_state_update import selective_state_update
+from .selective_state_update import \
+    selective_state_update as selective_state_update_native
from .ssd_combined import mamba_chunk_scan_combined
@@ -132,6 +135,24 @@ class Mamba2Mixer(nn.Module):
dtype=torch.float32,
requires_grad=False))
+        # Choose between flashinfer and native implementation. (default to flashinfer)
+        self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype
+        supported_head_dim_in_flashinfer = [64, 128]
+        if head_dim in supported_head_dim_in_flashinfer:
+            logger.info_once(
+                "Using flashinfer for selective state update for no MTP",
+                key="selective_state_update_no_mtp")
+            self.selective_state_update_func_no_mtp = selective_state_update_fi
+        else:
+            logger.info_once(
+                "Using native for selective state update for no MTP",
+                key="selective_state_update_no_mtp")
+            self.selective_state_update_func_no_mtp = selective_state_update_native
+        # TODO: support MTP selective state update in flashinfer.
+        logger.info_once("Using native for selective state update for MTP",
+                         key="selective_state_update_mtp")
+        self.selective_state_update_func_mtp = selective_state_update_native
# D
self.D = nn.Parameter(
torch.empty(self.tp_nheads,
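
The constructor now resolves the decode-time state-update kernel once: flashinfer's `selective_state_update` is used for the regular decode path when the head dimension is one it supports (64 or 128 per the check above), while the native implementation remains the fallback and is always used for the MTP/target-verify path. A minimal sketch of that dispatch, assuming only the import shown in this diff; the helper and its fallback argument are illustrative, not part of the source:

```python
from flashinfer.mamba import selective_state_update as selective_state_update_fi

# Head dimensions the flashinfer decode kernel is selected for in this diff.
_FLASHINFER_HEAD_DIMS = (64, 128)

def pick_selective_state_update(head_dim: int, native_impl):
    """Return the flashinfer kernel when head_dim is supported, else the native one."""
    if head_dim in _FLASHINFER_HEAD_DIMS:
        return selective_state_update_fi
    return native_impl
```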
@@ -165,8 +186,6 @@
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy)
-self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype
def forward(
self,
hidden_states: torch.Tensor,
@@ -335,7 +354,8 @@
],
dim=-1,
)
# Need to keep the same dtype as self.dt_bias and self.D to avoid garbage outputs.
dt_d = dt_d.to(dtype=torch.float32)
x_d = rearrange(x_d, "b (h p) -> b h p", p=self.head_dim)
dt_d = repeat(dt_d, "b h -> b h p", p=self.head_dim)
B_d = rearrange(B_d, "b (g n) -> b g n", g=self.tp_ngroups)
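
The new cast matters because `dt_bias` and `D` are kept in float32 (the parameter definitions earlier in this diff use `dtype=torch.float32`), while the decode-path `dt` arrives in the activation dtype; per the comment above, mixing the two produces garbage outputs. A small illustration with hypothetical shapes:

```python
import torch
from einops import repeat

nheads, head_dim = 8, 64
dt_bias = torch.zeros(nheads, dtype=torch.float32)   # mirrors the float32 nn.Parameter
dt = torch.rand(2, nheads).to(torch.bfloat16)        # per-token dt in activation dtype

dt = dt.to(dtype=torch.float32)                      # the cast added in this hunk
dt = repeat(dt, "b h -> b h p", p=head_dim)          # broadcast across the head dim
assert dt.dtype == dt_bias.dtype == torch.float32
```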
@@ -348,7 +368,7 @@
D = repeat(self.D, "h -> h p", p=self.head_dim)
if is_target_verify:
intermediate_ssm_states = layer_cache.intermediate_ssm
-selective_state_update(
+self.selective_state_update_func_mtp(
ssm_states,
x_d.view(
num_decodes,
@@ -381,10 +401,8 @@
cache_steps=draft_token_num,
intermediate_state_indices=self.intermediate_state_indices,
)
else:
-selective_state_update(
+self.selective_state_update_func_no_mtp(
ssm_states,
x_d,
dt_d,