[TRTLLM-10453][feat] Update mamba decode kernel to flashinfer (#10757)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
parent da43a28b01
commit 4a206351bb
@@ -5261,7 +5261,7 @@ For more information, please refer to <http://unlicense.org>
 - `Tracker`: https://github.com/tox-dev/py-filelock/issues
 
 
-## flashinfer-python (0.6.1)
+## flashinfer-python (0.6.2)
 
 ### Licenses
 License: `Apache-2.0`
@@ -97,6 +97,11 @@ def add_llm_args(parser):
                         default=False,
                         action='store_true')
     parser.add_argument("--tokens_per_block", type=int, default=32)
+    parser.add_argument('--mamba_ssm_cache_dtype',
+                        type=str,
+                        default='bfloat16',
+                        choices=['auto', 'float16', 'bfloat16', 'float32'],
+                        help='Data type for Mamba SSM cache.')
     parser.add_argument('--log_kv_cache_events',
                         default=False,
                         action='store_true')
@@ -205,6 +210,7 @@ def setup_llm(args, **kwargs):
         free_gpu_memory_fraction=args.kv_cache_fraction,
         dtype=args.kv_cache_dtype,
         tokens_per_block=args.tokens_per_block,
+        mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype,
         event_buffer_max_size=1024 if args.log_kv_cache_events else 0)
 
     spec_decode_algo = args.spec_decode_algo.upper(
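The two hunks above wire the new option end to end: add_llm_args() exposes --mamba_ssm_cache_dtype, and setup_llm() forwards the parsed value alongside the other KV cache settings. A minimal sketch of that flow, assuming the keyword arguments in setup_llm() are fed to the LLM API's KV cache config (only the keyword arguments themselves appear in the diff; the config class is not shown here and is an assumption):

import argparse

# Trimmed-down, hypothetical version of the quickstart wiring shown above.
parser = argparse.ArgumentParser()
parser.add_argument("--tokens_per_block", type=int, default=32)
parser.add_argument('--mamba_ssm_cache_dtype',
                    type=str,
                    default='bfloat16',
                    choices=['auto', 'float16', 'bfloat16', 'float32'],
                    help='Data type for Mamba SSM cache.')
args = parser.parse_args(['--mamba_ssm_cache_dtype', 'float32'])

# Mirrors the setup_llm() change: the parsed value rides along with the rest of
# the KV cache settings (assumed to be keyword arguments of the KV cache config).
kv_cache_kwargs = dict(
    tokens_per_block=args.tokens_per_block,
    mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype,
)
print(kv_cache_kwargs)  # {'tokens_per_block': 32, 'mamba_ssm_cache_dtype': 'float32'}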
@@ -53,7 +53,7 @@ ordered-set
 peft
 patchelf
 einops
-flashinfer-python~=0.6.1
+flashinfer-python~=0.6.2
 opencv-python-headless
 xgrammar==0.1.25
 llguidance==0.7.29
@@ -55,7 +55,7 @@ dependencies = [
     "peft (>=0.18.1,<0.19.0)",
     "patchelf (>=0.17.2.4,<0.18.0.0)",
     "einops (>=0.8.2,<0.9.0)",
-    "flashinfer-python (>=0.6.1,<0.7.0)",
+    "flashinfer-python (>=0.6.2,<0.7.0)",
     "xgrammar (==0.1.25)",
     "llguidance (==0.7.29)",
     "jsonschema (>=4.26.0,<5.0.0)",
@@ -265,7 +265,7 @@ class Deepseekv3RoutingImpl:
                 )
                 self.is_fused = False
         elif (num_experts > 512 or (self.top_k > 8 and self.top_k != 22)
-              or self.topk_group == 1):
+              or (self.topk_group == 1 and self.top_k != 22)):
             # We have special implementation for n_group == 1, top_k == 22 and num_experts == 512 for Nemotron Super v3.
             if self.is_fused:
                 warnings.warn(
@@ -273,19 +273,24 @@ class Deepseekv3RoutingImpl:
                 )
                 self.is_fused = False
 
-        if self.n_group == 1 and self.topk_group == 1:
-            scores, scores_with_bias = self.get_scores(logits,
-                                                       e_score_correction_bias)
-            _, topk_indices = torch.topk(scores_with_bias, k=self.top_k, dim=1)
-            topk_values = torch.gather(scores, dim=1,
-                                       index=topk_indices).type_as(scores)
-            # Normalize and scale.
-            topk_values_sum = torch.sum(topk_values, dim=-1,
-                                        keepdim=True) + 1e-20
-            topk_values = topk_values / topk_values_sum * self.routed_scaling_factor
-            return topk_values, topk_indices
-        elif not self.is_fused:
+        if not self.is_fused:
+            # Short path for n_group == 1 and topk_group == 1.
+            if self.n_group == 1 and self.topk_group == 1:
+                scores, scores_with_bias = self.get_scores(
+                    logits, e_score_correction_bias)
+                _, topk_indices = torch.topk(scores_with_bias,
+                                             k=self.top_k,
+                                             dim=1)
+                topk_values = torch.gather(scores, dim=1,
+                                           index=topk_indices).type_as(scores)
+
+                # Normalize and scale.
+                topk_values_sum = torch.sum(topk_values, dim=-1,
+                                            keepdim=True) + 1e-20
+                topk_values = topk_values / topk_values_sum * self.routed_scaling_factor
+                return topk_values, topk_indices
+
             # General case with pytorch implementation.
             scores, scores_with_bias = self.get_scores(logits,
                                                        e_score_correction_bias)
             scores_shape = list(scores_with_bias.shape)
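For the n_group == 1, topk_group == 1 case, the routing above reduces to a plain top-k over bias-corrected scores followed by renormalization of the unbiased scores. A standalone sketch of that short path, assuming get_scores() returns a sigmoid gate plus an additive per-expert correction bias (get_scores itself is not part of this hunk, so that form is an assumption):

import torch

def route_single_group(logits, e_score_correction_bias, top_k, routed_scaling_factor):
    # Assumed scoring: sigmoid gate plus an additive correction bias, standing in
    # for the (scores, scores_with_bias) pair returned by get_scores().
    scores = torch.sigmoid(logits.float())
    scores_with_bias = scores + e_score_correction_bias

    # Experts are selected on the biased scores, but the returned weights come
    # from the unbiased scores, as in the short path above.
    _, topk_indices = torch.topk(scores_with_bias, k=top_k, dim=1)
    topk_values = torch.gather(scores, dim=1, index=topk_indices).type_as(scores)

    # Normalize and scale, matching the diff: sum plus epsilon, then routed_scaling_factor.
    topk_values_sum = torch.sum(topk_values, dim=-1, keepdim=True) + 1e-20
    topk_values = topk_values / topk_values_sum * routed_scaling_factor
    return topk_values, topk_indices

# Example: 4 tokens routed over 16 experts, top-4.
weights, experts = route_single_group(torch.randn(4, 16), torch.zeros(16),
                                      top_k=4, routed_scaling_factor=2.5)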
@@ -17,9 +17,11 @@ from typing import Optional
 
 import torch
 from einops import rearrange, repeat
+from flashinfer.mamba import selective_state_update as selective_state_update_fi
 from torch import nn
 
 from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 
 from ...attention_backend import AttentionMetadata
@@ -30,7 +32,8 @@ from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 from .causal_conv1d_triton import \
     causal_conv1d_update as causal_conv1d_update_triton
 from .layernorm_gated import RMSNorm as RMSNormGated
-from .selective_state_update import selective_state_update
+from .selective_state_update import \
+    selective_state_update as selective_state_update_native
 from .ssd_combined import mamba_chunk_scan_combined
 
 
@@ -132,6 +135,24 @@ class Mamba2Mixer(nn.Module):
                         dtype=torch.float32,
                         requires_grad=False))
 
+        # Choose between flashinfer and native implementation. (default to flashinfer)
+        self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype
+        supported_head_dim_in_flashinfer = [64, 128]
+        if head_dim in supported_head_dim_in_flashinfer:
+            logger.info_once(
+                "Using flashinfer for selective state update for no MTP",
+                key="selective_state_update_no_mtp")
+            self.selective_state_update_func_no_mtp = selective_state_update_fi
+        else:
+            logger.info_once(
+                "Using native for selective state update for no MTP",
+                key="selective_state_update_no_mtp")
+            self.selective_state_update_func_no_mtp = selective_state_update_native
+        # TODO: support MTP selective state update in flashinfer.
+        logger.info_once("Using native for selective state update for MTP",
+                         key="selective_state_update_mtp")
+        self.selective_state_update_func_mtp = selective_state_update_native
+
         # D
         self.D = nn.Parameter(
             torch.empty(self.tp_nheads,
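The dispatch above picks flashinfer's selective_state_update when head_dim is 64 or 128 and otherwise falls back to the local native kernel; both are expected to compute the standard Mamba2 single-token recurrence state = exp(dt * A) * state + dt * B * x, then y = C . state + D * x, optionally gated by silu(z). A simplified, unfused reference of that step (illustrative only; shapes follow the common selective_state_update convention, not either kernel's actual code):

import torch
import torch.nn.functional as F
from einops import repeat

def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None,
                               dt_bias=None, dt_softplus=False):
    # state: (batch, nheads, head_dim, d_state), updated in place
    # x, dt: (batch, nheads, head_dim); A: (nheads, head_dim, d_state)
    # B, C: (batch, ngroups, d_state); D: (nheads, head_dim); z: same shape as x
    nheads, ngroups = state.shape[1], B.shape[1]
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)
    dA = torch.exp(dt.unsqueeze(-1) * A)                   # discretized decay
    B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups)
    C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups)
    dBx = dt.unsqueeze(-1) * B.unsqueeze(-2) * x.unsqueeze(-1)
    state.copy_(state * dA + dBx)                          # h_t = dA * h_{t-1} + dt * B * x_t
    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
    if D is not None:
        out = out + x * D                                  # skip connection
    if z is not None:
        out = out * F.silu(z)                              # gating
    return out.to(x.dtype)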
@@ -165,8 +186,6 @@ class Mamba2Mixer(nn.Module):
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             allreduce_strategy=config.allreduce_strategy)
 
-        self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -335,7 +354,8 @@ class Mamba2Mixer(nn.Module):
                 ],
                 dim=-1,
             )
-
+            # Need to keep the same dtype as self.dt_bias and self.D to avoid garbage outputs.
+            dt_d = dt_d.to(dtype=torch.float32)
             x_d = rearrange(x_d, "b (h p) -> b h p", p=self.head_dim)
             dt_d = repeat(dt_d, "b h -> b h p", p=self.head_dim)
             B_d = rearrange(B_d, "b (g n) -> b g n", g=self.tp_ngroups)
@@ -348,7 +368,7 @@ class Mamba2Mixer(nn.Module):
             D = repeat(self.D, "h -> h p", p=self.head_dim)
             if is_target_verify:
                 intermediate_ssm_states = layer_cache.intermediate_ssm
-                selective_state_update(
+                self.selective_state_update_func_mtp(
                     ssm_states,
                     x_d.view(
                         num_decodes,
@@ -381,10 +401,8 @@ class Mamba2Mixer(nn.Module):
                     cache_steps=draft_token_num,
                     intermediate_state_indices=self.intermediate_state_indices,
                 )
-
             else:
-
-                selective_state_update(
+                self.selective_state_update_func_no_mtp(
                     ssm_states,
                     x_d,
                     dt_d,