diff --git a/ATTRIBUTIONS-Python.md b/ATTRIBUTIONS-Python.md index 0cde43ad1d..614efb21de 100644 --- a/ATTRIBUTIONS-Python.md +++ b/ATTRIBUTIONS-Python.md @@ -5261,7 +5261,7 @@ For more information, please refer to - `Tracker`: https://github.com/tox-dev/py-filelock/issues -## flashinfer-python (0.6.1) +## flashinfer-python (0.6.2) ### Licenses License: `Apache-2.0` diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 1ea69df77c..5c08aa4fd0 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -97,6 +97,11 @@ def add_llm_args(parser): default=False, action='store_true') parser.add_argument("--tokens_per_block", type=int, default=32) + parser.add_argument('--mamba_ssm_cache_dtype', + type=str, + default='bfloat16', + choices=['auto', 'float16', 'bfloat16', 'float32'], + help='Data type for Mamba SSM cache.') parser.add_argument('--log_kv_cache_events', default=False, action='store_true') @@ -205,6 +210,7 @@ def setup_llm(args, **kwargs): free_gpu_memory_fraction=args.kv_cache_fraction, dtype=args.kv_cache_dtype, tokens_per_block=args.tokens_per_block, + mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype, event_buffer_max_size=1024 if args.log_kv_cache_events else 0) spec_decode_algo = args.spec_decode_algo.upper( diff --git a/requirements.txt b/requirements.txt index da6c9a2107..ffe5b2d2de 100644 --- a/requirements.txt +++ b/requirements.txt @@ -53,7 +53,7 @@ ordered-set peft patchelf einops -flashinfer-python~=0.6.1 +flashinfer-python~=0.6.2 opencv-python-headless xgrammar==0.1.25 llguidance==0.7.29 diff --git a/security_scanning/pyproject.toml b/security_scanning/pyproject.toml index a282a47683..82126c88eb 100644 --- a/security_scanning/pyproject.toml +++ b/security_scanning/pyproject.toml @@ -55,7 +55,7 @@ dependencies = [ "peft (>=0.18.1,<0.19.0)", "patchelf (>=0.17.2.4,<0.18.0.0)", "einops (>=0.8.2,<0.9.0)", - "flashinfer-python (>=0.6.1,<0.7.0)", + "flashinfer-python (>=0.6.2,<0.7.0)", "xgrammar (==0.1.25)", "llguidance (==0.7.29)", "jsonschema (>=4.26.0,<5.0.0)", diff --git a/tensorrt_llm/_torch/modules/fused_moe/routing.py b/tensorrt_llm/_torch/modules/fused_moe/routing.py index be21b5716a..ebc3cdf03a 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/routing.py +++ b/tensorrt_llm/_torch/modules/fused_moe/routing.py @@ -265,7 +265,7 @@ class Deepseekv3RoutingImpl: ) self.is_fused = False elif (num_experts > 512 or (self.top_k > 8 and self.top_k != 22) - or self.topk_group == 1): + or (self.topk_group == 1 and self.top_k != 22)): # We have special implementation for n_group == 1, top_k == 22 and num_experts == 512 for Nemotron Super v3. if self.is_fused: warnings.warn( @@ -273,19 +273,24 @@ class Deepseekv3RoutingImpl: ) self.is_fused = False - if self.n_group == 1 and self.topk_group == 1: - scores, scores_with_bias = self.get_scores(logits, - e_score_correction_bias) - _, topk_indices = torch.topk(scores_with_bias, k=self.top_k, dim=1) - topk_values = torch.gather(scores, dim=1, - index=topk_indices).type_as(scores) + if not self.is_fused: + # Short path for n_group == 1 and topk_group == 1. + if self.n_group == 1 and self.topk_group == 1: + scores, scores_with_bias = self.get_scores( + logits, e_score_correction_bias) + _, topk_indices = torch.topk(scores_with_bias, + k=self.top_k, + dim=1) + topk_values = torch.gather(scores, dim=1, + index=topk_indices).type_as(scores) - # Normalize and scale. 
- topk_values_sum = torch.sum(topk_values, dim=-1, - keepdim=True) + 1e-20 - topk_values = topk_values / topk_values_sum * self.routed_scaling_factor - return topk_values, topk_indices - elif not self.is_fused: + # Normalize and scale. + topk_values_sum = torch.sum(topk_values, dim=-1, + keepdim=True) + 1e-20 + topk_values = topk_values / topk_values_sum * self.routed_scaling_factor + return topk_values, topk_indices + + # General case with pytorch implementation. scores, scores_with_bias = self.get_scores(logits, e_score_correction_bias) scores_shape = list(scores_with_bias.shape) diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py index 9c8f1d8cc1..192e304419 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py @@ -17,9 +17,11 @@ from typing import Optional import torch from einops import rearrange, repeat +from flashinfer.mamba import selective_state_update as selective_state_update_fi from torch import nn from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata +from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from ...attention_backend import AttentionMetadata @@ -30,7 +32,8 @@ from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update from .causal_conv1d_triton import \ causal_conv1d_update as causal_conv1d_update_triton from .layernorm_gated import RMSNorm as RMSNormGated -from .selective_state_update import selective_state_update +from .selective_state_update import \ + selective_state_update as selective_state_update_native from .ssd_combined import mamba_chunk_scan_combined @@ -132,6 +135,24 @@ class Mamba2Mixer(nn.Module): dtype=torch.float32, requires_grad=False)) + # Choose between flashinfer and native implementation. (default to flashinfer) + self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype + supported_head_dim_in_flashinfer = [64, 128] + if head_dim in supported_head_dim_in_flashinfer: + logger.info_once( + "Using flashinfer for selective state update for no MTP", + key="selective_state_update_no_mtp") + self.selective_state_update_func_no_mtp = selective_state_update_fi + else: + logger.info_once( + "Using native for selective state update for no MTP", + key="selective_state_update_no_mtp") + self.selective_state_update_func_no_mtp = selective_state_update_native + # TODO: support MTP selective state update in flashinfer. + logger.info_once("Using native for selective state update for MTP", + key="selective_state_update_mtp") + self.selective_state_update_func_mtp = selective_state_update_native + # D self.D = nn.Parameter( torch.empty(self.tp_nheads, @@ -165,8 +186,6 @@ class Mamba2Mixer(nn.Module): skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy) - self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype - def forward( self, hidden_states: torch.Tensor, @@ -335,7 +354,8 @@ class Mamba2Mixer(nn.Module): ], dim=-1, ) - + # Need to keep the same dtype as self.dt_bias and self.D to avoid garbage outputs. 
+ dt_d = dt_d.to(dtype=torch.float32) x_d = rearrange(x_d, "b (h p) -> b h p", p=self.head_dim) dt_d = repeat(dt_d, "b h -> b h p", p=self.head_dim) B_d = rearrange(B_d, "b (g n) -> b g n", g=self.tp_ngroups) @@ -348,7 +368,7 @@ class Mamba2Mixer(nn.Module): D = repeat(self.D, "h -> h p", p=self.head_dim) if is_target_verify: intermediate_ssm_states = layer_cache.intermediate_ssm - selective_state_update( + self.selective_state_update_func_mtp( ssm_states, x_d.view( num_decodes, @@ -381,10 +401,8 @@ class Mamba2Mixer(nn.Module): cache_steps=draft_token_num, intermediate_state_indices=self.intermediate_state_indices, ) - else: - - selective_state_update( + self.selective_state_update_func_no_mtp( ssm_states, x_d, dt_d,
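
Reviewer note on the user-facing part of this change: below is a minimal usage sketch of the new Mamba SSM cache dtype knob through the LLM API, assuming `mamba_ssm_cache_dtype` is accepted by `KvCacheConfig` in the same way as the surrounding kwargs in `setup_llm` (the model id and prompt are placeholders, not part of this diff). The CLI equivalent is passing `--mamba_ssm_cache_dtype float32` to `examples/llm-api/quickstart_advanced.py`.

```python
# Hypothetical usage sketch for the new Mamba SSM cache dtype knob.
# Assumptions: KvCacheConfig accepts mamba_ssm_cache_dtype alongside the
# kwargs already used in setup_llm; the model id below is a placeholder.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

kv_cache_config = KvCacheConfig(
    free_gpu_memory_fraction=0.8,
    dtype="auto",                      # KV cache dtype (unchanged by this change)
    tokens_per_block=32,
    mamba_ssm_cache_dtype="float32",   # new: dtype of the Mamba SSM state cache
)

llm = LLM(model="<path-or-hf-id-of-a-mamba-hybrid-model>",  # placeholder
          kv_cache_config=kv_cache_config)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)
```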
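
For the `routing.py` hunk, here is a standalone sketch of the short path that the change moves under `if not self.is_fused:` (n_group == 1, topk_group == 1). It assumes `get_scores` applies a sigmoid and adds `e_score_correction_bias`, as in the DeepSeek-V3 routing recipe; `top_k=22` and the scaling factor are illustrative values, not values asserted by this diff.

```python
# Standalone sketch of the n_group == 1 / topk_group == 1 routing short path.
# Assumption: get_scores == sigmoid(logits) + bias (DeepSeek-V3 style gating).
import torch

def routing_short_path(logits, e_score_correction_bias, top_k, routed_scaling_factor):
    scores = torch.sigmoid(logits)
    scores_with_bias = scores + e_score_correction_bias
    # Select experts on the biased scores, but weight them with the unbiased scores.
    _, topk_indices = torch.topk(scores_with_bias, k=top_k, dim=1)
    topk_values = torch.gather(scores, dim=1, index=topk_indices).type_as(scores)
    # Normalize and scale, matching the epsilon used in the diff.
    topk_values_sum = torch.sum(topk_values, dim=-1, keepdim=True) + 1e-20
    topk_values = topk_values / topk_values_sum * routed_scaling_factor
    return topk_values, topk_indices

logits = torch.randn(4, 512)              # 4 tokens, 512 experts
bias = torch.zeros(512)                   # illustrative zero correction bias
values, indices = routing_short_path(logits, bias, top_k=22, routed_scaling_factor=2.5)
print(values.shape, indices.shape)        # torch.Size([4, 22]) torch.Size([4, 22])
```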