diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h index b8b316ab3a..681e9e0685 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h @@ -675,6 +675,11 @@ private: { continue; } + // If tileSizeQ < mNumHeadsQPerKv, this will result in 0, causing division by zero. + if (tileSizeQ < params.mNumHeadsQPerKv) + { + continue; + } // Update the tileSizeQ. selectKernelParamsCopy.mTileSizeQ = tileSizeQ; diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py index d119ddeece..6df74b3c2a 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py @@ -144,8 +144,15 @@ def _triton_cached_ssm( num_seq = num_prefill + num_decode num_total_tokens = num_prefill_tokens + num_decode - y_prefill = None - y_decode = None + # Preallocate output tensor to avoid memcpy cost for merging prefill + # and decode outputs + preallocated_ssm_out = torch.empty( + [bs, num_heads, head_dim], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + preallocated_ssm_out_p = preallocated_ssm_out[:num_prefill_tokens] + preallocated_ssm_out_d = preallocated_ssm_out[num_prefill_tokens:num_total_tokens] # Prefill: concatenate tokens at the front and run combined scan if num_prefill > 0: @@ -165,7 +172,7 @@ def _triton_cached_ssm( chunk_indices = None chunk_offsets = None - y_prefill, varlen_states = mamba_chunk_scan_combined( + varlen_states = mamba_chunk_scan_combined( hs_prefill, dt_prefill, A, @@ -184,11 +191,12 @@ def _triton_cached_ssm( dt_limit=(time_step_limit[0], time_step_limit[1]), return_final_states=False, return_varlen_states=True, - mamba_ssm_cache_dtype=ssm_state_cache.dtype, + out=preallocated_ssm_out_p.unsqueeze(0), + state_dtype=ssm_state_cache.dtype, ) ssm_state_cache.index_copy_( - 0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype) + 0, slot_idx[:num_prefill].long(), varlen_states.to(ssm_state_cache.dtype) ) # Decode: batch single-token updates via selective_state_update @@ -205,7 +213,7 @@ def _triton_cached_ssm( A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size) D_full = D[..., None].expand(num_heads, head_dim) - y_decode = selective_state_update( + selective_state_update( ssm_state_cache, x_decode, dt_hp, @@ -217,19 +225,16 @@ def _triton_cached_ssm( dt_bias=dt_bias_hp, dt_softplus=True, state_batch_indices=slot_idx_decode, - ) # [nd, H, D] + out=preallocated_ssm_out_d, + ) - # Dispatch return logic - if num_prefill > 0 and num_decode > 0: - y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format) - y_flat = y.view(bs, *y.shape[2:]) - y_flat[:num_prefill_tokens].copy_(y_prefill[0]) - y_flat[num_prefill_tokens:num_total_tokens].copy_(y_decode) - return y - elif num_prefill > 0: - return y_prefill[0].view(b, s, num_heads, head_dim).to(hidden_states.dtype) - elif num_decode > 0: - return y_decode.view(b, s, num_heads, head_dim).to(hidden_states.dtype) + # Return the preallocated output reshaped to original dimensions + if num_total_tokens > 0: + return ( + preallocated_ssm_out[:num_total_tokens] + .view(b, s, num_heads, head_dim) + .to(hidden_states.dtype) + ) else: return torch.empty_like(hidden_states) diff --git 
a/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py index dfa84ceaee..25c5764b94 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py @@ -1,5 +1,6 @@ import torch +import tensorrt_llm.logger as logger from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import \ HfWeightMapper from tensorrt_llm._torch.models.modeling_utils import register_mapper @@ -55,6 +56,16 @@ class NemotronHHfWeightMapper(HfWeightMapper): if "embeddings" in key: key = key.replace("embeddings", "embed_tokens") + # MTP layers are stored as mtp.layers.0.xxx (sublayer 0, Attention) and mtp.layers.1.xxx (sublayer 1, MoE) + if "mtp.layers." in key: + import re + match = re.match(r'mtp\.layers\.(\d+)\.(.*)', key) + if match: + sublayer_idx, rest = match.groups() + key = f"model.layers.{config.num_hidden_layers}.layers.{sublayer_idx}.{rest}" + else: + logger.error(f"Failed to match MTP pattern for: {name}") + if "A_log" in key: key = key.replace("A_log", "A") diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_h.py b/tensorrt_llm/_torch/models/modeling_nemotron_h.py index 6df300cc69..249cab1547 100644 --- a/tensorrt_llm/_torch/models/modeling_nemotron_h.py +++ b/tensorrt_llm/_torch/models/modeling_nemotron_h.py @@ -14,7 +14,7 @@ # limitations under the License. import re -from typing import Dict, Optional +from typing import Dict, List, Optional import torch from torch import nn @@ -37,9 +37,11 @@ from ..modules.mamba.mamba2_mixer import Mamba2Mixer from ..modules.mlp import MLP from ..modules.multi_stream_utils import maybe_execute_in_parallel from ..modules.rms_norm import RMSNorm +from ..speculative import SpecMetadata from ..utils import AuxStreamType, EventType -from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, - register_auto_model) +from .modeling_deepseekv3 import DeepseekV3MTPHead +from .modeling_speculative import SpecDecOneEngineForCausalLM +from .modeling_utils import DecoderModel, register_auto_model class NemotronHConfig(PretrainedConfig): @@ -347,13 +349,17 @@ class NemotronHLayer(DecoderLayer): position_ids: torch.IntTensor, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, + spec_metadata: Optional[SpecMetadata] = None, **kwargs, ) -> torch.Tensor: residual = hidden_states hidden_states = self.norm(hidden_states) - hidden_states = self.mixer(hidden_states, attn_metadata, **kwargs) + hidden_states = self.mixer(hidden_states, + attn_metadata, + spec_metadata=spec_metadata, + **kwargs) hidden_states = torch.add(hidden_states, residual) return hidden_states @@ -405,6 +411,7 @@ class NemotronHModel(DecoderModel): layer_type, aux_stream_dict=self.aux_stream_dict)) self.layers = nn.ModuleList(layers) + self.num_hidden_layers = config.num_hidden_layers # final norm self.norm_f = RMSNorm( @@ -421,6 +428,7 @@ class NemotronHModel(DecoderModel): input_ids: Optional[torch.IntTensor] = None, position_ids: Optional[torch.IntTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, + spec_metadata: Optional[SpecMetadata] = None, **kwargs, ) -> torch.Tensor: if (input_ids is None) ^ (inputs_embeds is not None): @@ -439,10 +447,11 @@ class NemotronHModel(DecoderModel): hidden_states = inputs_embeds - for layer in self.layers: + for layer in self.layers[:self.num_hidden_layers]: hidden_states = layer(position_ids, hidden_states, attn_metadata, + spec_metadata=spec_metadata, 
                                   mamba_metadata=self.mamba_metadata)

         hidden_states = self.norm_f(hidden_states)
@@ -451,8 +460,8 @@ class NemotronHModel(DecoderModel):


 @register_auto_model("NemotronHForCausalLM")
-class NemotronHForCausalLM(DecoderModelForCausalLM[NemotronHModel,
-                                                   NemotronHConfig]):
+class NemotronHForCausalLM(SpecDecOneEngineForCausalLM[NemotronHModel,
+                                                       NemotronHConfig]):

     def __init__(
         self,
@@ -477,15 +486,286 @@ class NemotronHForCausalLM(DecoderModelForCausalLM[NemotronHModel,
         ]

         super().__init__(
-            NemotronHModel(model_config),
-            config=model_config,
-            hidden_size=model_config.pretrained_config.hidden_size,
-            vocab_size=model_config.pretrained_config.vocab_size,
+            model=NemotronHModel(model_config),
+            model_config=model_config,
+        )
+        self.model_nextn = 0
+        if model_config.spec_config is not None and model_config.spec_config.spec_dec_mode.is_mtp_one_model(
+        ):
+            model_nextn = model_config.spec_config.num_nextn_predict_layers
+            ckpt_nextn = self.config.num_nextn_predict_layers
+            self.num_hidden_layers = self.config.num_hidden_layers
+            assert ckpt_nextn > 0, "There are no MTP modules in the checkpoint."
+            if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
+                pass
+            else:
+                # modify the QuantConfig to support duplicated mtp layers
+                if model_config.quant_config.exclude_modules is not None:
+                    extend_exclude_modules = []
+                    for model_mtp_idx in range(
+                            self.num_hidden_layers,
+                            self.num_hidden_layers + model_nextn):
+                        ckpt_mtp_idx = (model_mtp_idx - self.num_hidden_layers
+                                        ) % ckpt_nextn + self.num_hidden_layers
+                        model_prefix = f"model.layers.{model_mtp_idx}"
+                        ckpt_prefix = f"model.layers.{ckpt_mtp_idx}"
+                        for exclude_module in model_config.quant_config.exclude_modules:
+                            if ckpt_prefix in exclude_module and model_prefix not in exclude_module:
+                                extend_exclude_modules.append(
+                                    exclude_module.replace(
+                                        ckpt_prefix, model_prefix))
+                    self.model_config.quant_config.exclude_modules.extend(
+                        extend_exclude_modules)
+            self.model.layers.extend(self.draft_model.mtp_layers)
+            self.epilogue.extend(self.draft_model.mtp_layers)
+            self.epilogue.append(self.spec_worker)
+
+    def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
+        new_weights = weight_mapper.preprocess_weights(weights)
+        super().load_weights(weights=new_weights, weight_mapper=weight_mapper)
+
+
+class NemotronHMTPDecoderLayer(NemotronHLayer):
+
+    def __init__(
+        self,
+        model_config: ModelConfig[NemotronHConfig],
+        layer_idx: int,
+        aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
+        has_start_projections: bool,
+        has_end_norm: bool,
+        layer_type: str,
+    ) -> None:
+
+        super().__init__(
+            model_config=model_config,
+            layer_idx=layer_idx,
+            layer_type=layer_type,
+            aux_stream_dict=aux_stream_dict,
+        )
+
+        config = model_config.pretrained_config
+        self.model_config = model_config
+        self.has_start_projections = has_start_projections
+        self.has_end_norm = has_end_norm
+
+        if has_start_projections:
+            self.enorm = RMSNorm(
+                hidden_size=config.hidden_size,
+                eps=config.rms_norm_eps,
+                dtype=config.torch_dtype,
+            )
+            self.hnorm = RMSNorm(
+                hidden_size=config.hidden_size,
+                eps=config.rms_norm_eps,
+                dtype=config.torch_dtype,
+            )
+
+            if model_config.mapping.enable_attention_dp:
+                self.eh_proj = Linear(
+                    in_features=config.hidden_size * 2,
+                    out_features=config.hidden_size,
+                    bias=False,
+                    dtype=config.torch_dtype,
+                    quant_config=model_config.quant_config,
+                    skip_create_weights_in_init=model_config.
+                    skip_create_weights_in_init,
+                )
+            else:
+                self.eh_proj = Linear(
+                    in_features=config.hidden_size * 2,
+                    out_features=config.hidden_size,
+                    bias=False,
+                    dtype=config.torch_dtype,
+                    tensor_parallel_mode=TensorParallelMode.ROW,
+                    mapping=model_config.mapping,
+                    quant_config=model_config.quant_config,
+                    reduce_output=True,
+                    skip_create_weights_in_init=model_config.
+ skip_create_weights_in_init, + ) + + if has_end_norm: + self.final_layernorm = RMSNorm( + hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype, + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + + if self.has_start_projections: + assert inputs_embeds is not None + inputs_embeds_normed = self.enorm(inputs_embeds) + previous_hidden_states_normed = self.hnorm(hidden_states) + + # Fuse via concatenation and linear projection + fused = torch.cat( + [inputs_embeds_normed, previous_hidden_states_normed], dim=-1) + + # Split fused hidden_states columnwise based on TP + mapping = self.model_config.mapping + if mapping.tp_size > 1 and not mapping.enable_attention_dp: + fused = torch.chunk(fused, mapping.tp_size, + dim=-1)[mapping.tp_rank] + + hidden_states = self.eh_proj(fused) + residual = None # Start fresh after fusion + + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer( + hidden_states=hidden_states, + attn_metadata=attn_metadata, ) - def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper): - new_weights = weight_mapper.preprocess_weights(weights) - super().load_weights(new_weights, weight_mapper) + if self.has_end_norm: + if residual is not None: + hidden_states = hidden_states + residual + residual = None + + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, residual + + +class NemotronHMTP(nn.Module): + """NemotronH MTP Layer - single MTP layer following DeepseekV3MTP pattern.""" + + def __init__(self, + model_config: ModelConfig[NemotronHConfig], + layer_idx: int, + aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], + is_separate_draft_engine: bool = False, + prefix: str = ""): + super().__init__() + config = model_config.pretrained_config + self.model_config = model_config + self.config = config + self.layer_idx = layer_idx + + # Pattern configuration + self.pattern_str = config.mtp_hybrid_override_pattern + self.pattern_len = len(self.pattern_str) + assert self.pattern_len > 0 + + # Build pattern-based layers + self.layers = nn.ModuleDict() + + for i in range(self.pattern_len): + step_rel_idx = i % self.pattern_len + + char = self.pattern_str[step_rel_idx] + + is_start_of_step = step_rel_idx == 0 + is_end_of_step = step_rel_idx == self.pattern_len - 1 + + sublayer_quant_config = self._get_mtp_sublayer_quant_config( + model_config, self.layer_idx) + + # Create a temporary model_config with the override quant_config + sublayer_model_config = ModelConfig( + pretrained_config=model_config.pretrained_config, + mapping=model_config.mapping, + quant_config=sublayer_quant_config, + skip_create_weights_in_init=model_config. 
+ skip_create_weights_in_init, + ) + + self.layers[str(i)] = NemotronHMTPDecoderLayer( + model_config=sublayer_model_config, + layer_idx=self.layer_idx, + aux_stream_dict=aux_stream_dict, + has_start_projections=is_start_of_step, + has_end_norm=is_end_of_step, + layer_type=char, + ) + + # Add shared_head for MTP, following DeepseekV3MTP pattern + self.shared_head = DeepseekV3MTPHead(model_config) + + def _get_mtp_sublayer_quant_config( + self, model_config: ModelConfig[NemotronHConfig], layer_idx: int): + """ + Get quantization config for MTP sublayer. + The MTP layer in the nvfp4 checkpoint is unquantized. Because the TRTLLM + moe_backend only supports fp8/fp4 quantization, we need to override + the quant_config for the MTP layer. + """ + from tensorrt_llm.models.modeling_utils import QuantConfig + quant_config = model_config.quant_config + # MTP layers are always unquantized, force quant_algo=None + if quant_config is None: + return None + return QuantConfig( + quant_algo=None, + kv_cache_quant_algo=quant_config.kv_cache_quant_algo, + ) + + def forward( + self, + input_ids: torch.IntTensor, + position_ids: torch.IntTensor, + hidden_states: torch.Tensor, + embed_tokens: Embedding, + attn_metadata: AttentionMetadata, + all_rank_num_tokens: Optional[List[int]] = None, + spec_metadata: Optional[SpecMetadata] = None, + **kwargs, + ) -> torch.Tensor: + inputs_embeds = embed_tokens(input_ids) + residual = None + for i in range(self.pattern_len): + layer = self.layers[str(i)] + hidden_states, residual = layer( + inputs_embeds=inputs_embeds, + positions=position_ids, + hidden_states=hidden_states, + residual=residual, + attn_metadata=attn_metadata, + ) + return hidden_states AutoConfig.register(NemotronHConfig.model_type, NemotronHConfig) diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index 6792d06393..1849f1acf2 100755 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -706,6 +706,9 @@ class MTPForCausalLM(nn.Module): case "exaone_moe": from .modeling_exaone_moe import ExaoneMoeMTP mtp_layer = ExaoneMoeMTP + case "nemotron_h": + from .modeling_nemotron_h import NemotronHMTP + mtp_layer = NemotronHMTP case _: raise ValueError( f"Model type {model_type} not supported for MTP") @@ -751,6 +754,12 @@ class MTPDraftModel(nn.Module): from .modeling_exaone_moe import ExaoneMoeMTP mtp_layer = ExaoneMoeMTP(model_config, layer_idx, aux_stream_dict) + elif model_type == "nemotron_h": + from .modeling_nemotron_h import NemotronHMTP + mtp_layer = NemotronHMTP(model_config, + layer_idx, + aux_stream_dict, + is_separate_draft_engine=False) else: raise ValueError( f"MTPDraftModel does not support model_type: {model_type}") diff --git a/tensorrt_llm/_torch/modules/mamba/causal_conv1d_triton.py b/tensorrt_llm/_torch/modules/mamba/causal_conv1d_triton.py new file mode 100644 index 0000000000..5a3924425e --- /dev/null +++ b/tensorrt_llm/_torch/modules/mamba/causal_conv1d_triton.py @@ -0,0 +1,1165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py +# and https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +# -*- coding: utf-8 -*- + +from typing import List, Optional, Union + +import torch +import triton +import triton.language as tl + +PAD_SLOT_ID = -1 + + +@triton.jit() +def _causal_conv1d_fwd_kernel( # continuous batching + # Pointers to matrices + x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences + w_ptr, # (dim, width) + bias_ptr, + initial_states_ptr, # conv_states_ptr + cache_indices_ptr, # conv_state_indices_ptr + has_initial_states_ptr, + query_start_loc_ptr, + o_ptr, # (dim, seqlen) - actually pointing to x_ptr + # Matrix dimensions + dim: tl.constexpr, + seqlen: tl.int32, # cu_seqlen + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, # stride to get to next sequence, + stride_x_dim: tl.constexpr, # stride to get to next feature-value, + stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index) + stride_w_dim: tl.constexpr, # stride to get to next dim-axis value + stride_w_width: tl.constexpr, # stride to get to next width-axis value + stride_istate_seq: tl.constexpr, + stride_istate_dim: tl.constexpr, + stride_istate_token: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + HAS_INITIAL_STATES: tl.constexpr, + HAS_CACHE: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + NP2_STATELEN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + conv_states_ptr = initial_states_ptr + conv_state_indices_ptr = cache_indices_ptr + stride_conv_state_seq = stride_istate_seq + stride_conv_state_dim = stride_istate_dim + stride_conv_state_tok = stride_istate_token + state_len = KERNEL_WIDTH - 1 # can be passed via argument if it's not the same as this value + + # one program handles one chunk in a single sequence + # rather than mixing sequences - to make updating initial_states across sequences efficiently + + # single-sequence id + idx_seq = tl.program_id(0) + chunk_offset = tl.program_id(1) + + # BLOCK_N elements along the feature-dimension (channel) + idx_feats = tl.program_id(2) * BLOCK_N + tl.arange(0, BLOCK_N) + + if idx_seq == pad_slot_id: + return + + sequence_start_index = tl.load(query_start_loc_ptr + idx_seq) + sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1) + # find the actual sequence length + seqlen = sequence_end_index - sequence_start_index + + token_offset = BLOCK_M * chunk_offset + segment_len = min(BLOCK_M, seqlen - token_offset) + + if segment_len <= 0: + return + + # base of the sequence + x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,] + + if 
IS_CONTINUOUS_BATCHING: + # cache_idx + conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(tl.int64) + else: + # cache_idx + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + conv_states_base = ( + conv_states_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] + + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + + # Does 2 things: + # 1. READ prior-block init-state data - [done by every Triton programs] + # 2. update conv_state with new data [only by the Triton program handles chunk_offset=0] + if chunk_offset == 0: + # read from conv_states + load_init_state = False + if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES + load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) + if load_init_state: + # load from conv_states + prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + else: + # prior-tokens are zeros + if KERNEL_WIDTH >= 2: # STRATEGY1 + # first chunk and does not have prior-token, so just set to 0 + col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 3: # STRATEGY1 + col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 4: # STRATEGY1 + col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 5: # STRATEGY1 + col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + + # STEP 2: + # here prepare data for updating conv_state + if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache) + # just read from 'x' + # copy 'x' data to conv_state + # load only 'x' data (and set 0 before 'x' if seqlen < state_len) + idx_tokens_last = (seqlen - state_len) + tl.arange(0, NP2_STATELEN) # [BLOCK_M] + x_ptrs = ( + x_ptr + + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None] + + (idx_feats * stride_x_dim)[None, :] + ) # [BLOCK_M,BLOCK_N,] + mask_x = ( + (idx_tokens_last >= 0)[:, None] + & (idx_tokens_last < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + idx_tokens_conv = tl.arange(0, 
NP2_STATELEN) # [BLOCK_M] + conv_states_ptrs_target = ( + conv_states_base[None, :] + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) + + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + # tl.debug_barrier() # NOTE: use this due to bug in Triton compiler + tl.store(conv_states_ptrs_target, new_conv_state, mask) + + else: + if load_init_state: + # update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x' + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + conv_states_ptrs_source = ( + conv_states_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens_conv + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + + x_ptrs = ( + x_base[None, :] + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + + # tl.debug_barrier() # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load + new_conv_state = tl.where( + mask, conv_state, loaded_x + ) # BUG in 'tl.where' which requires a barrier before this + conv_states_ptrs_target = ( + conv_states_base + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + else: # load_init_state == False + # update conv_state by shifting left, BUT + # set cols prior to 'x' as zeros + cols from 'x' + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + VAL = state_len - seqlen + + x_ptrs = ( + x_base[None, :] + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + + conv_states_ptrs_target = ( + conv_states_base + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + + else: # chunk_offset > 0 + # read prior-token data from `x` + load_init_state = True + prior_tokens = x_base + (token_offset - 1) * stride_x_token + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + 
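Per cache line, conv_states holds the last state_len = KERNEL_WIDTH - 1 input columns of each sequence, either taken directly from the tail of x or obtained by shifting the previous state left and appending the new tokens. A minimal host-side PyTorch sketch of that update rule, for a single (dim, seqlen) sequence (illustrative only; reference_conv_state_update is not part of the patch):

import torch

def reference_conv_state_update(conv_state, x, has_initial_state):
    # conv_state: (dim, state_len), x: (dim, seqlen) for one sequence.
    state_len = conv_state.shape[-1]
    seqlen = x.shape[-1]
    if not has_initial_state:
        conv_state = torch.zeros_like(conv_state)
    if seqlen >= state_len:
        # Long sequence: the cache is simply the last state_len tokens of x.
        return x[:, seqlen - state_len:].clone()
    # Short sequence: shift the old state left by seqlen and append x.
    return torch.cat([conv_state[:, seqlen:], x], dim=-1)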
conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 5: + # ruff: noqa: F841 + conv_states_ptrs = prior_tokens # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + x_base_1d = x_base + token_offset * stride_x_token # starting of chunk + + # PRE-LOAD WEIGHTS + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + mask_x_1d = idx_feats < dim + for idx_token in range(segment_len): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < segment_len) & (idx_feats < dim) # token-index # feature-index + o_ptrs = ( + o_ptr + + (sequence_start_index + token_offset + idx_token) * stride_o_token + + (idx_feats * stride_o_dim) + ) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Union[torch.Tensor, None], + conv_states: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens_cpu: List[int], + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, + validate_data=False, + **kwargs, +): + """support varlen + continuous batching when x is 2D tensor + + x: (dim,cu_seq_len) + cu_seq_len = total tokens of all seqs in that batch + sequences are concatenated from left to right for varlen + 
weight: (dim, width) + conv_states: (...,dim,width - 1) itype + updated inplace if provided + [it use `cache_indices` to get the index to the cache of conv_state for that sequence + + conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True + and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x' + ] + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. + if + x = [5, 1, 1, 1] <- continuous batching (batch=4) + then + query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is + the ending index of the last sequence + [length(query_start_loc)-1 == batch] + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + seq_lens_cpu: (batch) int32 + The sequence lengths of the sequences in the batch + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + [single boolean for each sequence in the batch: True or False] + bias: (dim,) + activation: either None or "silu" or "swish" or True + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + + out: same shape as `x` + """ + if isinstance(activation, bool) and activation: + activation = "silu" + + out = torch.empty_like(x) + + is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1) + dim, cu_seqlen = x.shape + _, width = weight.shape + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + stride_x_seq = 0 + stride_x_dim = x.stride(0) + stride_x_token = x.stride(1) + stride_w_dim = weight.stride(0) + stride_w_width = weight.stride(1) + stride_istate_seq = 0 + stride_istate_dim = 0 + stride_istate_token = 0 + num_cache_lines = 0 + if conv_states is not None: + # extensions to support vLLM: + # 1. conv_states is used to replaced initial_states + # 2. conv_states serve as a cache with num cache lines can be larger than batch size + # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] + # 4. 
computation can be skipped if cache_indices[idx] == pad_slot_id + num_cache_lines = conv_states.size(0) + assert ( + num_cache_lines == conv_states.shape[0] + and dim == conv_states.shape[1] + and width - 1 <= conv_states.shape[2] + ) + stride_istate_seq = conv_states.stride(0) + stride_istate_dim = conv_states.stride(1) + stride_istate_token = conv_states.stride(2) + # assert stride_istate_dim == 1 + if out.dim() == 2: + stride_o_seq = 0 + stride_o_dim = out.stride(0) + stride_o_token = out.stride(1) + else: + stride_o_seq = out.stride(0) + stride_o_dim = out.stride(1) + stride_o_token = out.stride(2) + + if validate_data: + assert x.dim() == 2 + assert query_start_loc is not None + assert query_start_loc.dim() == 1 + assert x.stride(0) == 1 or x.stride(1) == 1 + padded_batch = query_start_loc.size(0) - 1 + if bias is not None: + assert bias.dim() == 1 + assert dim == bias.size(0) + if cache_indices is not None: + assert cache_indices.dim() == 1 + assert padded_batch == cache_indices.size(0) + if has_initial_state is not None: + assert has_initial_state.size() == (padded_batch,) + assert conv_states is not None, ( + "ERROR: `has_initial_state` is used, which needs also `conv_states`" + ) + assert weight.stride(1) == 1 + assert (dim, width) == weight.shape + assert is_channel_last, "Need to run in channel-last layout" + + def grid(META): + max_seq_len = max(seq_lens_cpu) + return ( + len(seq_lens_cpu), # batch_size + (max_seq_len + META["BLOCK_M"] - 1) // META["BLOCK_M"], + triton.cdiv(dim, META["BLOCK_N"]), + ) + + _causal_conv1d_fwd_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_states, + cache_indices, + has_initial_state, + query_start_loc, + out, + # Matrix dimensions + dim, + cu_seqlen, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + HAS_INITIAL_STATES=has_initial_state is not None, + HAS_CACHE=conv_states is not None, + IS_CONTINUOUS_BATCHING=cache_indices is not None, + USE_PAD_SLOT=pad_slot_id is not None, + NP2_STATELEN=np2_statelen, + # launch_cooperative_grid=True + BLOCK_M=8, + BLOCK_N=256, + num_stages=2, + ) + return out + + +# HAS_EAGLE_TREE_CUSTOM_ATTN_MASK is added to support eagle tree attention mask +# retrieve_next_token_ptr: [N, NP2_T], retrieve_next_sibling_ptr: [N, NP2_T] +# e.g. 
for a sequence of length 4, the eagle tree attention structure is: +# retrieve_next_token=[1, 3, -1, -1] -> retrieve_next_token[i]: the 1st child token of token i +# retrieve_next_sibling=[-1, 2, -1, -1] -> retrieve_next_sibling[i]: the 1st tree sibling token of token i +# retrieve_parent_token=[n/a, 0, 0, 1] -> retrieve_parent_token[i]: the parent token of token i +# Tree: +# 0 +# / \ +# 1 2 +# / +# 3 +# When calculating token 3's convolution, it should conv to token 1 (parent) and token 0 (grand-parent) +# When calculating token 2's convolution, it should conv to token 0 (parent) +# This kernel is a fused kernel which will also produce retrieve_parent_token based on retrieve_next_token & retrieve_next_sibling +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + cache_seqlens_ptr, # circular buffer + conv_state_indices_ptr, + num_accepted_tokens_ptr, + intermediate_conv_window_ptr, + intermediate_state_indices_ptr, + retrieve_next_token_ptr, + retrieve_next_sibling_ptr, + retrieve_parent_token_ptr, + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_inter_seq: tl.constexpr, + stride_inter_step: tl.constexpr, + stride_inter_dim: tl.constexpr, + stride_inter_win: tl.constexpr, + stride_intermediate_state_indices: tl.constexpr, + stride_retrieve_next_token_seq: tl.constexpr, + stride_retrieve_next_token_token: tl.constexpr, + stride_retrieve_next_sibling_seq: tl.constexpr, + stride_retrieve_next_sibling_token: tl.constexpr, + stride_retrieve_parent_token_seq: tl.constexpr, + stride_retrieve_parent_token_token: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + NP2_SEQLEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, + SAVE_INTERMEDIATE: tl.constexpr, + HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if IS_CONTINUOUS_BATCHING: + # mask = idx_seq < batch + conv_state_batch_coord = tl.load( + conv_state_indices_ptr + idx_seq * stride_state_indices + ).to(tl.int64) + if SAVE_INTERMEDIATE: + intermediate_state_batch_coord = tl.load( + intermediate_state_indices_ptr + idx_seq * stride_intermediate_state_indices + ).to(tl.int64) + else: + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. 
+ # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. + # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq) - 1 + else: + conv_state_token_offset = 0 + + # STEP 1: READ init_state data + conv_states_base = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # The conv_state updates works in a sliding window manner, + # at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. + conv_state_ptrs_source = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim) # [BLOCK_N] + + x_ptrs = x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] + & (idx_tokens - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + # tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + conv_state_base = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] + conv_state_ptrs_target = ( + conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: + idx_tokens = tl.arange(0, NP2_SEQLEN) # [BLOCK_M] + # Update parent mapping for all tokens at once using vectorized operations + mask_retrieve = idx_tokens < seqlen + retrieve_next_token_base = ( + retrieve_next_token_ptr + + 
(idx_seq * stride_retrieve_next_token_seq) + + idx_tokens * stride_retrieve_next_token_token + ) + retrieve_next_tokens = tl.load(retrieve_next_token_base, mask_retrieve) + retrieve_next_sibling_base = ( + retrieve_next_sibling_ptr + + (idx_seq * stride_retrieve_next_sibling_seq) + + idx_tokens * stride_retrieve_next_sibling_token + ) + retrieve_next_siblings = tl.load(retrieve_next_sibling_base, mask_retrieve) + parent_idx_tokens = tl.zeros((NP2_SEQLEN,), dtype=tl.int32) + + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.static_range(seqlen): + acc = acc_preload + + if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: + # set the parent index of the next token in the eagle tree + # next token's parent is the current token + retrieve_next_token_idx = tl.sum( + tl.where(idx_tokens == idx_token, retrieve_next_tokens, 0) + ) + if retrieve_next_token_idx != -1: # pad slot id + parent_idx_tokens = tl.where( + idx_tokens == retrieve_next_token_idx, + idx_token, + parent_idx_tokens, + ) + # next token's parent is the parent of the current token + retrieve_sibling_token_idx = tl.sum( + tl.where(idx_tokens == idx_token, retrieve_next_siblings, 0) + ) + if retrieve_sibling_token_idx != -1: # pad slot id + parent_idx_token = tl.sum(tl.where(idx_tokens == idx_token, parent_idx_tokens, 0)) + parent_idx_tokens = tl.where( + idx_tokens == retrieve_sibling_token_idx, + parent_idx_token, + parent_idx_tokens, + ) + # tl.device_print("am", parent_idx_tokens) + + _idx_token = idx_token + x_ptrs_1d = x_base_1d + _idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + # convolution operation: itself * wcol[-1] + parent * wcol[-2] + grand-parent * wcol[-3] + ... + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 0: + matrix_w = w_col1 + else: + matrix_w = w_col0 + elif KERNEL_WIDTH == 3: + if j == 0: + matrix_w = w_col2 + elif j == 1: + matrix_w = w_col1 + else: + matrix_w = w_col0 + elif KERNEL_WIDTH == 4: + if j == 0: + matrix_w = w_col3 + elif j == 1: + matrix_w = w_col2 + elif j == 2: + matrix_w = w_col1 + else: + matrix_w = w_col0 + + if SAVE_INTERMEDIATE: + # Save the window state after consuming this token + # Layout: [seq(cache line), step, dim, win(K-1)] + base_ptr = ( + intermediate_conv_window_ptr + + intermediate_state_batch_coord * stride_inter_seq + + idx_token * stride_inter_step + + idx_feats * stride_inter_dim + ) + + # store itself in KERNEL_WIDTH-2 slot, parent in KERNEL_WIDTH-3 slot, grand-parent in KERNEL_WIDTH-4 slot, ... 
+ if KERNEL_WIDTH - j - 2 >= 0: + tl.store( + base_ptr + (KERNEL_WIDTH - j - 2) * stride_inter_win, + matrix_x, + mask=mask_w, + ) + + acc += matrix_x * matrix_w + + # move to parent for next iteration + if _idx_token > 0: + _idx_token = tl.sum(tl.where(idx_tokens == _idx_token, parent_idx_tokens, 0)) + x_ptrs_1d = x_base_1d + _idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + else: + # no parent within the current chunk, load from prev conv state: col[-1] (idx 0's parent), col[-2] (idx 0's grand parent), ... + if KERNEL_WIDTH == 2: + if _idx_token == 0: + matrix_x = col0 + elif KERNEL_WIDTH == 3: + if _idx_token == 0: + matrix_x = col1 + else: + matrix_x = col0 + elif KERNEL_WIDTH == 4: + if _idx_token == 0: + matrix_x = col2 + elif _idx_token == -1: + matrix_x = col1 + else: + matrix_x = col0 + _idx_token = _idx_token - 1 + else: + matrix_w = w_col0 + matrix_x = col0 + + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SAVE_INTERMEDIATE: + # Save the window state after consuming this token + # Layout: [seq(cache line), step, dim, win(K-1)] + base_ptr = ( + intermediate_conv_window_ptr + + intermediate_state_batch_coord * stride_inter_seq + + idx_token * stride_inter_step + + idx_feats * stride_inter_dim + ) + if KERNEL_WIDTH >= 2: + tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w) + if KERNEL_WIDTH >= 3: + tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w) + if KERNEL_WIDTH >= 4: + tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w) + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < seqlen) & (idx_feats < dim) # token-index # feature-index + o_ptrs = ( + o_ptr + + (idx_seq) * stride_o_seq + + idx_token * stride_o_token + + (idx_feats * stride_o_dim) + ) + + tl.store(o_ptrs, acc, mask=mask_1d) + + # fuse: store calculated retrieve_parent_token to tensor + if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: + tl.store( + retrieve_parent_token_ptr + + idx_seq * stride_retrieve_parent_token_seq + + idx_tokens * stride_retrieve_parent_token_token, + parent_idx_tokens, + mask=mask_retrieve, + ) + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Union[bool, str, None] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + intermediate_conv_window: Optional[torch.Tensor] = None, + intermediate_state_indices: Optional[torch.Tensor] = None, + retrieve_next_token: Optional[torch.Tensor] = None, + 
retrieve_next_sibling: Optional[torch.Tensor] = None, + retrieve_parent_token: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=False, +): + """ + x: (batch, dim) or (batch, dim, seqlen) + [shape=2: single token prediction] + [shape=3: single or multiple tokens prediction] + conv_state: (..., dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) + """ + if validate_data: + assert cache_seqlens is None # not implemented yet - ok for vLLM + assert pad_slot_id is not None + assert x.stride(1) == 1 + if isinstance(activation, bool): + activation = "silu" if activation is True else None + elif activation is not None: + assert activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + _, width = weight.shape + # conv_state: (..., dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + if validate_data: + assert dim == weight.size(0) + assert conv_state.stride(-2) == 1, ( + f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" + ) + assert state_len >= width - 1 + # when above happens, we don't shift-left to keep any records in conv_state + assert dim == conv_state.size(1) + if conv_state_indices is None: + assert conv_state.size(0) >= batch + else: + assert (batch,) == conv_state_indices.shape + assert intermediate_state_indices is not None + assert (batch,) == intermediate_state_indices.shape + + assert num_cache_lines >= batch + assert weight.stride(1) == 1 # Need this + assert cache_seqlens is None # not needed for vLLM - circular buffer + + # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' + out = torch.empty_like(x) + stride_w_dim, stride_w_width = weight.stride() + + stride_x_seq, stride_x_dim, stride_x_token = x.stride() # X (batch, dim, seqlen) + + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride() + stride_state_indices = conv_state_indices.stride(0) if conv_state_indices is not None else 0 + stride_intermediate_state_indices = ( + intermediate_state_indices.stride(0) if intermediate_state_indices is not None else 0 + ) + if num_accepted_tokens is not None: + state_len = width - 1 + (seqlen - 1) # effective state_len needed + else: + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + np2_seqlen = triton.next_power_of_2(seqlen) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + # prepare intermediate buffer strides if provided + if intermediate_conv_window is not None: + 
stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = ( + intermediate_conv_window.stride(0), + intermediate_conv_window.stride(1), + intermediate_conv_window.stride(2), + intermediate_conv_window.stride(3), + ) + else: + stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0 + + # prepare retrieve next token buffer strides if provided + if retrieve_next_token is not None: + stride_retrieve_next_token_seq, stride_retrieve_next_token_token = ( + retrieve_next_token.stride(0), + retrieve_next_token.stride(1), + ) + else: + stride_retrieve_next_token_seq = stride_retrieve_next_token_token = 0 + + # prepare retrieve next sibling buffer strides if provided + if retrieve_next_sibling is not None: + stride_retrieve_next_sibling_seq, stride_retrieve_next_sibling_token = ( + retrieve_next_sibling.stride(0), + retrieve_next_sibling.stride(1), + ) + else: + stride_retrieve_next_sibling_seq = stride_retrieve_next_sibling_token = 0 + + # prepare retrieve parent token buffer strides if provided + if retrieve_parent_token is not None: + stride_retrieve_parent_token_seq, stride_retrieve_parent_token_token = ( + retrieve_parent_token.stride(0), + retrieve_parent_token.stride(1), + ) + else: + stride_retrieve_parent_token_seq = stride_retrieve_parent_token_token = 0 + + # come here + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + cache_seqlens, + conv_state_indices, + num_accepted_tokens, + intermediate_conv_window if intermediate_conv_window is not None else x, + intermediate_state_indices, + retrieve_next_token, + retrieve_next_sibling, + retrieve_parent_token, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_inter_seq, + stride_inter_step, + stride_inter_dim, + stride_inter_win, + stride_intermediate_state_indices, + stride_retrieve_next_token_seq, + stride_retrieve_next_token_token, + stride_retrieve_next_sibling_seq, + stride_retrieve_next_sibling_token, + stride_retrieve_parent_token_seq, + stride_retrieve_parent_token_token, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, + NP2_STATELEN=np2_statelen, + NP2_SEQLEN=np2_seqlen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=256, + SAVE_INTERMEDIATE=intermediate_conv_window is not None, + HAS_EAGLE_TREE_CUSTOM_ATTN_MASK=retrieve_next_token is not None, + ) + if unsqueeze: + out = out.squeeze(-1) + return out diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py index 44ab0e51a4..9c8f1d8cc1 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py @@ -24,8 +24,11 @@ from tensorrt_llm.mapping import Mapping from ...attention_backend import AttentionMetadata from ...model_config import ModelConfig +from ...speculative import SpecMetadata from ..linear import Linear, TensorParallelMode from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_triton import \ + causal_conv1d_update as causal_conv1d_update_triton from .layernorm_gated import 
RMSNorm as RMSNormGated from .selective_state_update import selective_state_update from .ssd_combined import mamba_chunk_scan_combined @@ -82,6 +85,8 @@ class Mamba2Mixer(nn.Module): self.tp_d_inner = d_inner // tp_size self.tp_nheads = nheads // tp_size self.tp_ngroups = n_groups // tp_size + self.num_heads = nheads + self.tp_size = tp_size self.layer_idx = layer_idx self.d_conv = d_conv @@ -167,6 +172,7 @@ class Mamba2Mixer(nn.Module): hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, mamba_metadata: Mamba2Metadata, + spec_metadata: Optional[SpecMetadata] = None, **kwargs, ) -> torch.Tensor: @@ -175,6 +181,7 @@ class Mamba2Mixer(nn.Module): num_decodes = attn_metadata.seq_lens.shape[0] - num_prefills num_prefill_tokens = attn_metadata.num_ctx_tokens num_decode_tokens = attn_metadata.num_tokens - num_prefill_tokens + num_actual_tokens = attn_metadata.num_tokens seqlen_split_size = [num_prefill_tokens, num_decode_tokens] batch_split_size = [num_prefills, num_decodes] @@ -183,10 +190,10 @@ class Mamba2Mixer(nn.Module): state_indices_p, state_indices_d = torch.split(state_indices, batch_split_size) - conv_states = attn_metadata.kv_cache_manager.get_conv_states( - self.layer_idx) - ssm_states = attn_metadata.kv_cache_manager.get_ssm_states( + layer_cache = attn_metadata.kv_cache_manager.mamba_layer_cache( self.layer_idx) + conv_states = layer_cache.conv + ssm_states = layer_cache.temporal # in_proj zxbcdt = self.in_proj(hidden_states) @@ -199,7 +206,21 @@ class Mamba2Mixer(nn.Module): xbc_p, xbc_d = torch.split(xbc, seqlen_split_size, dim=0) dt_p, dt_d = torch.split(dt, seqlen_split_size, dim=0) - out = [] + # Preallocate output tensor to avoid memcpy cost for merging prefill + # and decode outputs + preallocated_ssm_out = torch.empty( + [ + zxbcdt.shape[0], + (self.num_heads * self.head_dim) // self.tp_size, + ], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( + preallocated_ssm_out, + [num_prefill_tokens, num_decode_tokens], + dim=0, + ) if num_prefills > 0: @@ -239,7 +260,7 @@ class Mamba2Mixer(nn.Module): has_initial_states[:, None, None, None], ssm_states[state_indices_p], 0) - y, current_ssm_states = mamba_chunk_scan_combined( + current_ssm_states = mamba_chunk_scan_combined( x_p, dt_p, self.A, @@ -247,30 +268,63 @@ class Mamba2Mixer(nn.Module): C_p, chunk_size=self.chunk_size, D=self.D, - z=z_p, + z=None, dt_bias=self.dt_bias, initial_states=initial_states, chunk_indices=mamba_metadata.chunk_indices, chunk_offsets=mamba_metadata.chunk_offsets, dt_softplus=self.delta_softplus, + dt_limit=(0.0, float("inf")), cu_seqlens=cu_seqlens, seq_idx=seq_idx, return_varlen_states=True, return_final_states=False, - mamba_ssm_cache_dtype=self._mamba_ssm_cache_dtype, + out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, + self.head_dim), + state_dtype=self._mamba_ssm_cache_dtype, ) - out.append(rearrange(y, "b l h p -> (b l) (h p)")) # copy new ssm state ssm_states[state_indices_p] = current_ssm_states if num_decodes > 0: - xbc_d = causal_conv1d_update(xbc_d, - conv_states, - self.conv1d.weight, - self.conv1d.bias, - activation="silu", - conv_state_indices=state_indices_d) + is_target_verify = attn_metadata.kv_cache_manager.is_speculative( + ) and spec_metadata is not None + if is_target_verify: + # TODO: support dynamic speculation, will add current_draft_len later [TRTLLM-10319] + draft_token_num = spec_metadata.max_draft_len + 1 + intermediate_conv_states = layer_cache.intermediate_conv_window + + 
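# Target-model verification path: each decode request carries draft_token_num
# tokens, so the flat (num_decode_tokens, dim) activations are viewed as
# (num_decodes, dim, draft_token_num) for the Triton conv update below, which also
# records a per-step conv window in the intermediate buffer addressed by
# intermediate_state_indices; the window of the accepted token is committed to the
# main conv state afterwards in update_mamba_states.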
self.intermediate_state_indices = torch.arange( + num_decodes, + dtype=torch.int32, + device=state_indices_d.device) + + # Reshape for batch processing + xbc_d_reshaped = xbc_d.view(num_decodes, draft_token_num, + -1).transpose(1, 2) + # TODO:support tree structure [TRTLLM-10320] + xbc_d_processed = causal_conv1d_update_triton( + xbc_d_reshaped, + conv_states, + self.conv1d.weight, + self.conv1d.bias, + activation="silu", + conv_state_indices=state_indices_d[:num_decodes], + intermediate_conv_window=intermediate_conv_states, + intermediate_state_indices=self.intermediate_state_indices, + ) + + xbc_d = xbc_d_processed.transpose(1, 2).view( + num_decode_tokens, -1) + + else: + xbc_d = causal_conv1d_update(xbc_d, + conv_states, + self.conv1d.weight, + self.conv1d.bias, + activation="silu", + conv_state_indices=state_indices_d) x_d, B_d, C_d = torch.split( xbc_d, @@ -292,29 +346,64 @@ class Mamba2Mixer(nn.Module): n=self.d_state).to(dtype=torch.float32) dt_bias = repeat(self.dt_bias, "h -> h p", p=self.head_dim) D = repeat(self.D, "h -> h p", p=self.head_dim) + if is_target_verify: + intermediate_ssm_states = layer_cache.intermediate_ssm + selective_state_update( + ssm_states, + x_d.view( + num_decodes, + draft_token_num, + self.num_heads // self.tp_size, + self.head_dim, + ), + dt_d.view( + num_decodes, + draft_token_num, + self.num_heads // self.tp_size, + self.head_dim, + ), + A, + B_d.view(num_decodes, draft_token_num, self.tp_ngroups, -1), + C_d.view(num_decodes, draft_token_num, self.tp_ngroups, -1), + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices_d[:num_decodes], + out=preallocated_ssm_out_d.view( + num_decodes, + draft_token_num, + self.num_heads // self.tp_size, + self.head_dim, + ), + disable_state_update=True, + intermediate_states_buffer=intermediate_ssm_states, + cache_steps=draft_token_num, + intermediate_state_indices=self.intermediate_state_indices, + ) - y = selective_state_update( - ssm_states, - x_d, - dt_d, - A, - B_d, - C_d, - D, - z=z_d, - dt_bias=dt_bias, - dt_softplus=self.delta_softplus, - state_batch_indices=state_indices_d, - ) + else: - out.append(rearrange(y, "b h p -> b (h p)")) - - out = torch.cat(out, dim=0) + selective_state_update( + ssm_states, + x_d, + dt_d, + A, + B_d, + C_d, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=self.delta_softplus, + state_batch_indices=state_indices_d, + out=preallocated_ssm_out_d.view(num_decodes, -1, + self.head_dim), + ) # norm - out = self.norm(out) + hidden_states = self.norm(preallocated_ssm_out, z[:num_actual_tokens]) # out_proj - out = self.out_proj(out) + out = self.out_proj(hidden_states) - return out + return out[:num_actual_tokens] diff --git a/tensorrt_llm/_torch/modules/mamba/selective_state_update.py b/tensorrt_llm/_torch/modules/mamba/selective_state_update.py index 56a8586ce1..a1ba2dfa0a 100644 --- a/tensorrt_llm/_torch/modules/mamba/selective_state_update.py +++ b/tensorrt_llm/_torch/modules/mamba/selective_state_update.py @@ -1,7 +1,4 @@ -# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py -# Copyright (c) 2024, Tri Dao, Albert Gu. -# -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,6 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# +# Adapted from: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +# SPDX-FileCopyrightText: Copyright contributors to the sglang project +# +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py import torch import triton @@ -35,7 +38,19 @@ from .softplus import softplus }) @triton.heuristics( {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) -@triton.jit +@triton.heuristics({ + "CACHE_INTERMEDIATE_STATES": + lambda args: args["intermediate_states_buffer"] is not None +}) +@triton.heuristics({ + "HAS_EAGLE_TREE_CUSTOM_ATTN_MASK": + lambda args: args["retrieve_parent_token_ptr"] is not None +}) +@triton.heuristics({ + "HAS_INTERMEDIATE_STATE_INDICES": + lambda args: args["intermediate_state_indices_ptr"] is not None +}) +@triton.jit(do_not_specialize=["T"]) def _selective_scan_update_kernel( # Pointers to matrices state_ptr, @@ -50,8 +65,13 @@ def _selective_scan_update_kernel( out_ptr, state_batch_indices_ptr, pad_slot_id, + intermediate_states_buffer, + cache_steps, + retrieve_parent_token_ptr, + intermediate_state_indices_ptr, # Matrix dimensions batch, + T, nheads, dim, dstate, @@ -62,9 +82,11 @@ def _selective_scan_update_kernel( stride_state_dim, stride_state_dstate, stride_x_batch, + stride_x_T, stride_x_head, stride_x_dim, stride_dt_batch, + stride_dt_T, stride_dt_head, stride_dt_dim, stride_dt_bias_head, @@ -73,19 +95,25 @@ def _selective_scan_update_kernel( stride_A_dim, stride_A_dstate, stride_B_batch, + stride_B_T, stride_B_group, stride_B_dstate, stride_C_batch, + stride_C_T, stride_C_group, stride_C_dstate, stride_D_head, stride_D_dim, stride_z_batch, + stride_z_T, stride_z_head, stride_z_dim, stride_out_batch, + stride_out_T, stride_out_head, stride_out_dim, + stride_retrieve_parent_token_batch, + stride_retrieve_parent_token_T, # Meta-parameters DT_SOFTPLUS: tl.constexpr, TIE_HDIM: tl.constexpr, @@ -94,6 +122,10 @@ def _selective_scan_update_kernel( HAS_D: tl.constexpr, HAS_Z: tl.constexpr, HAS_STATE_BATCH_INDICES: tl.constexpr, + DISABLE_STATE_UPDATE: tl.constexpr, + CACHE_INTERMEDIATE_STATES: tl.constexpr, + HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr, + HAS_INTERMEDIATE_STATE_INDICES: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, ): pid_m = tl.program_id(axis=0) @@ -105,7 +137,7 @@ def _selective_scan_update_kernel( # is the same as the batch id. 
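# The slot index is widened to int64 before the pointer arithmetic below so that
# state_batch_idx * stride_state_batch cannot overflow 32 bits for large state caches.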
if HAS_STATE_BATCH_INDICES: state_batch_indices_ptr += pid_b - state_batch_idx = tl.load(state_batch_indices_ptr) + state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head else: state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head @@ -127,91 +159,153 @@ def _selective_scan_update_kernel( offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) - x_ptrs = x_ptr + offs_m * stride_x_dim - dt_ptrs = dt_ptr + offs_m * stride_dt_dim + + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= state_batch_idx != pad_slot_id + state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32) + if HAS_DT_BIAS: dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim if HAS_D: D_ptr += pid_h * stride_D_head - A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + - offs_n[None, :] * stride_A_dstate) - B_ptrs = B_ptr + offs_n * stride_B_dstate - C_ptrs = C_ptr + offs_n * stride_C_dstate - if HAS_D: D_ptrs = D_ptr + offs_m * stride_D_dim - if HAS_Z: - z_ptrs = z_ptr + offs_m * stride_z_dim - out_ptrs = out_ptr + offs_m * stride_out_dim - mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) - if HAS_STATE_BATCH_INDICES: - mask &= (state_batch_idx != pad_slot_id) - state = tl.load(state_ptrs, mask=mask, other=0.0) + A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[ + None, :] * stride_A_dstate - x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if not TIE_HDIM: - dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if HAS_DT_BIAS: - dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, - other=0.0).to(tl.float32) - if DT_SOFTPLUS: - dt = tl.where(dt <= 20.0, softplus(dt), dt) - A = tl.load(A_ptrs, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), - other=0.0).to(tl.float32) - dA = tl.exp(A * dt[:, None]) - else: - dt = tl.load(dt_ptr).to(tl.float32) - if HAS_DT_BIAS: - dt += tl.load(dt_bias_ptr).to(tl.float32) - if DT_SOFTPLUS: - dt = tl.where(dt <= 20.0, softplus(dt), dt) - A = tl.load(A_ptr).to(tl.float32) - dA = tl.exp(A * dt) # scalar, not a matrix + cache_idx = -1 + if CACHE_INTERMEDIATE_STATES: + if HAS_INTERMEDIATE_STATE_INDICES: + intermediate_state_idx = tl.load(intermediate_state_indices_ptr + + pid_b).to(tl.int64) + cache_idx = intermediate_state_idx + elif HAS_STATE_BATCH_INDICES: + cache_idx = state_batch_idx + else: + cache_idx = pid_b - B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) - C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) - if HAS_D: - D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if HAS_Z: - z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + current_step_idx = 0 + for _ in range(T): + if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: + if current_step_idx != 0 and cache_idx >= 0: + parent_ptr = (retrieve_parent_token_ptr + + pid_b * stride_retrieve_parent_token_batch + + current_step_idx * stride_retrieve_parent_token_T) + parent_step_idx = tl.load(parent_ptr).to(tl.int32) - if not TIE_HDIM: - dB = B[None, :] * dt[:, None] - else: - dB = B * dt # vector of size (dstate,) - state = state * dA + dB * x[:, None] + if parent_step_idx >= 0 and parent_step_idx < T: + step_offset = parent_step_idx * nheads * dim * dstate + cache_ptr = ( + intermediate_states_buffer + + cache_idx * cache_steps * nheads * dim * dstate + + step_offset + pid_h * dim * 
dstate + + offs_m[:, None] * dstate + offs_n[None, :]) + state = tl.load(cache_ptr, mask=mask, + other=0.0).to(tl.float32) - mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) - if HAS_STATE_BATCH_INDICES: - mask &= (state_batch_idx != pad_slot_id) - tl.store(state_ptrs, state, mask=mask) - out = tl.sum(state * C[None, :], axis=1) - if HAS_D: - out += x * D - if HAS_Z: - out *= z * tl.sigmoid(z) - tl.store(out_ptrs, out, mask=offs_m < dim) + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load( + A_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) # scalar, not a matrix + + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt + state = state * dA + dB * x[:, None] + + if CACHE_INTERMEDIATE_STATES: + if HAS_STATE_BATCH_INDICES: + if state_batch_idx != pad_slot_id: + cache_ptr_base = ( + intermediate_states_buffer + + cache_idx * cache_steps * nheads * dim * dstate + + current_step_idx * nheads * dim * dstate + + pid_h * dim * dstate) + cache_ptrs = cache_ptr_base + (offs_m[:, None] * dstate + + offs_n[None, :]) + tl.store(cache_ptrs, + state.to(cache_ptrs.dtype.element_ty), + mask=mask) + + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + current_step_idx += 1 + + x_ptr += stride_x_T + dt_ptr += stride_dt_T + B_ptr += stride_B_T + C_ptr += stride_C_T + out_ptr += stride_out_T + if HAS_Z: + z_ptr += stride_z_T + + if not DISABLE_STATE_UPDATE: + tl.store(state_ptrs, state.to(state_ptrs.dtype.element_ty), mask=mask) -def selective_state_update(state, - x, - dt, - A, - B, - C, - D=None, - z=None, - dt_bias=None, - dt_softplus=False, - state_batch_indices=None, - pad_slot_id=PAD_SLOT_ID): +def selective_state_update( + state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False, + state_batch_indices=None, + pad_slot_id=PAD_SLOT_ID, + out=None, + disable_state_update=False, + intermediate_states_buffer=None, + cache_steps=None, + retrieve_parent_token=None, + intermediate_state_indices=None, +): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) - x: (batch, dim) or (batch, nheads, dim) + x: (batch, dim) or (batch, nheads, dim) for single-token or (batch, T, nheads, dim) for multi-token dt: (batch, dim) or (batch, nheads, dim) A: (dim, dstate) or (nheads, dim, dstate) - B: (batch, dstate) or (batch, ngroups, dstate) + B: (batch, dstate) or 
(batch, ngroups, dstate) for single-token or (batch, T, ngroups, dstate) for multi-token C: (batch, dstate) or (batch, ngroups, dstate) D: (dim,) or (nheads, dim) z: (batch, dim) or (batch, nheads, dim) @@ -222,38 +316,58 @@ def selective_state_update(state, for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 - Return: - out: (batch, dim) or (batch, nheads, dim) + out: Preallocated ssm output tensor. Assume same shape as x. + In-place updated. + disable_state_update: If True, don't write back to state (for speculative verify) + intermediate_states_buffer: Buffer to cache intermediate states + cache_steps: Total number of steps in the buffer + retrieve_parent_token: (batch, T) tensor of parent token indices for EAGLE tree attention + intermediate_state_indices: (batch,) tensor of indices for intermediate_states_buffer operations. + If provided, uses these indices instead of state_batch_indices for the buffer. """ - has_heads = state.dim() > 3 if state.dim() == 3: state = state.unsqueeze(1) if x.dim() == 2: x = x.unsqueeze(1) + if x.dim() == 3: + x = x.unsqueeze(1) if dt.dim() == 2: dt = dt.unsqueeze(1) + if dt.dim() == 3: + dt = dt.unsqueeze(1) if A.dim() == 2: A = A.unsqueeze(0) if B.dim() == 2: B = B.unsqueeze(1) + if B.dim() == 3: + B = B.unsqueeze(1) if C.dim() == 2: C = C.unsqueeze(1) + if C.dim() == 3: + C = C.unsqueeze(1) if D is not None and D.dim() == 1: D = D.unsqueeze(0) - if z is not None and z.dim() == 2: - z = z.unsqueeze(1) + if z is not None: + if z.dim() == 2: + z = z.unsqueeze(1) + if z.dim() == 3: + z = z.unsqueeze(1) if dt_bias is not None and dt_bias.dim() == 1: dt_bias = dt_bias.unsqueeze(0) + if out.dim() == 2: + out = out.unsqueeze(1) + if out.dim() == 3: + out = out.unsqueeze(1) _, nheads, dim, dstate = state.shape - batch = x.shape[0] + batch, T, _, _ = x.shape - assert x.shape == (batch, nheads, dim) + assert x.shape == (batch, T, nheads, dim) assert dt.shape == x.shape assert A.shape == (nheads, dim, dstate) - ngroups = B.shape[1] + ngroups = B.shape[2] assert nheads % ngroups == 0, "nheads must be divisible by ngroups" - assert B.shape == (batch, ngroups, dstate) + assert B.shape == (batch, T, ngroups, dstate) assert C.shape == B.shape if D is not None: assert D.shape == (nheads, dim) @@ -263,10 +377,11 @@ def selective_state_update(state, assert dt_bias.shape == (nheads, dim) if state_batch_indices is not None: assert state_batch_indices.shape == (batch, ) - out = torch.empty_like(x) + assert out.shape == x.shape + grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads) - z_strides = (z.stride(0), z.stride(1), - z.stride(2)) if z is not None else (0, 0, 0) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), + z.stride(3)) if z is not None else (0, 0, 0, 0)) # We don't want autotune since it will overwrite the state # We instead tune by hand. 
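# Illustrative call for the multi-token verify path (a sketch only; the tensor
# names below are placeholders, not taken from this file):
#   selective_state_update(
#       ssm_states,                           # (num_slots, nheads, dim, dstate)
#       x.view(bs, T, nheads, dim),
#       dt.view(bs, T, nheads, dim),
#       A,
#       B.view(bs, T, ngroups, dstate),
#       C.view(bs, T, ngroups, dstate),
#       D,
#       dt_bias=dt_bias,
#       dt_softplus=True,
#       state_batch_indices=slot_idx,         # (bs,)
#       out=out.view(bs, T, nheads, dim),     # preallocated, written in place
#       disable_state_update=True,            # keep the committed state untouched
#       intermediate_states_buffer=buf,       # (num_slots, cache_steps, nheads, dim, dstate)
#       cache_steps=T,
#       intermediate_state_indices=inter_idx, # (bs,) slots within the intermediate buffer
#   )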
BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else @@ -275,6 +390,12 @@ def selective_state_update(state, ((4, 4) if dstate <= 128 else ((4, 8)))))) tie_hdim = (A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0) + + retrieve_parent_token_strides = ((retrieve_parent_token.stride(0), + retrieve_parent_token.stride(1)) + if retrieve_parent_token is not None else + (0, 0)) + with torch.cuda.device(x.device.index): _selective_scan_update_kernel[grid]( state, @@ -289,7 +410,12 @@ def selective_state_update(state, out, state_batch_indices, pad_slot_id, + intermediate_states_buffer, + cache_steps if cache_steps is not None else 0, + retrieve_parent_token, + intermediate_state_indices, batch, + T, nheads, dim, dstate, @@ -301,9 +427,11 @@ def selective_state_update(state, x.stride(0), x.stride(1), x.stride(2), + x.stride(3), dt.stride(0), dt.stride(1), dt.stride(2), + dt.stride(3), *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0, A.stride(0), @@ -312,21 +440,25 @@ def selective_state_update(state, B.stride(0), B.stride(1), B.stride(2), + B.stride(3), C.stride(0), C.stride(1), C.stride(2), + C.stride(3), *(D.stride(0), D.stride(1)) if D is not None else 0, z_strides[0], z_strides[1], z_strides[2], + z_strides[3], out.stride(0), out.stride(1), out.stride(2), + out.stride(3), + retrieve_parent_token_strides[0], + retrieve_parent_token_strides[1], dt_softplus, tie_hdim, BLOCK_SIZE_M, + DISABLE_STATE_UPDATE=disable_state_update, num_warps=num_warps, ) - if not has_heads: - out = out.squeeze(1) - return out diff --git a/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py b/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py index 23b55d8811..c41e6b47f0 100644 --- a/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py +++ b/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py @@ -234,6 +234,7 @@ def _chunk_scan_fwd_kernel( # M-block offsets and prev states # - logic in next block may override these if there is an active offset + offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) prev_states_ptr = (states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head) prev_states_hdim = stride_states_hdim @@ -269,11 +270,12 @@ def _chunk_scan_fwd_kernel( ): # - replace prev_states_ptr with init_states - prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_ptr = (initstates_ptr + + seq_idx_m * stride_init_states_batch + + pid_h * stride_init_states_head) prev_states_hdim = stride_init_states_hdim # override strides prev_states_dstate = stride_init_states_dstate - offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, @@ -289,7 +291,7 @@ def _chunk_scan_fwd_kernel( c_idx_n = tl.load( chunk_indices_ptr + (pid_c + 1), mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, - other=-1 # to trigger different chunk + other=-1, # to trigger different chunk ) # - there are things to consider @@ -304,9 +306,11 @@ def _chunk_scan_fwd_kernel( if (c_idx == c_idx_n) or c_off > 0: # get the next offset - c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1), - mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, - other=chunk_size) + c_off_n = tl.load( + chunk_offsets_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=chunk_size, + ) # in this case, adjust down the 
chunk_size_limit if c_idx == c_idx_n: @@ -319,8 +323,9 @@ def _chunk_scan_fwd_kernel( # i.e. the same for all blocks) dA_cs_m_boundary = tl.load( dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, - mask=(((c_off - 1) > -1) and (c_off < chunk_size)), - other=0.0).to(tl.float32) + mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)), + other=0.0, + ).to(tl.float32) if HAS_SEQ_IDX: # - handle seq idx when HAS_INITSTATES==False @@ -416,8 +421,7 @@ def _chunk_scan_fwd_kernel( other=0.0).to(tl.float32) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. - # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :])) - cb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_k[None, :]), 0.0)) + cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) cb *= dt_k @@ -494,18 +498,21 @@ def _chunk_scan_fwd_kernel( ) -def _chunk_scan_fwd(cb, - x, - dt, - dA_cumsum, - C, - states, - D=None, - z=None, - seq_idx=None, - chunk_indices=None, - chunk_offsets=None, - initial_states=None): +def _chunk_scan_fwd( + cb, + x, + dt, + dA_cumsum, + C, + states, + D=None, + z=None, + seq_idx=None, + chunk_indices=None, + chunk_offsets=None, + initial_states=None, + out=None, +): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape _, _, ngroups, dstate = C.shape @@ -527,34 +534,17 @@ def _chunk_scan_fwd(cb, # with initial states, we need to take care of how # seq_idx crosses the boundaries assert batch == 1, "chunk scan only supports initial states with batch 1" - - if initial_states.shape[0] == 1: - # no in this case no point to use initial states - initial_states = None - else: - assert chunk_indices is not None and chunk_offsets is not None, \ - ( - "chunk_indices and chunk_offsets should have been set" - ) + assert (chunk_indices is not None and chunk_offsets is not None + ), "chunk_indices and chunk_offsets should have been set" else: chunk_indices, chunk_offsets = None, None else: chunk_indices, chunk_offsets = None, None - # Allocates output. - out = torch.empty(batch, - seqlen, - nheads, - headdim, - device=x.device, - dtype=x.dtype) + assert out.shape == x.shape + if z is not None: - out_x = torch.empty(batch, - seqlen, - nheads, - headdim, - device=x.device, - dtype=x.dtype) + out_x = torch.empty_like(x) assert out_x.stride() == out.stride() else: out_x = None @@ -625,10 +615,12 @@ def _chunk_scan_fwd(cb, states.stride(2), states.stride(3), states.stride(4), - *((initial_states.stride(0), initial_states.stride(1), - initial_states.stride(2), - initial_states.stride(3)) if initial_states is not None else - (0, 0, 0, 0)), + *(( + initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3), + ) if initial_states is not None else (0, 0, 0, 0)), D.stride(0) if D is not None else 0, True, D is not None, @@ -639,4 +631,4 @@ def _chunk_scan_fwd(cb, IS_TRITON_22=TRITON_22, HAS_INITSTATES=initial_states is not None, ) - return out, out_x + return out_x diff --git a/tensorrt_llm/_torch/modules/mamba/ssd_combined.py b/tensorrt_llm/_torch/modules/mamba/ssd_combined.py index 8edbe902bd..a5916657c3 100644 --- a/tensorrt_llm/_torch/modules/mamba/ssd_combined.py +++ b/tensorrt_llm/_torch/modules/mamba/ssd_combined.py @@ -16,8 +16,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional - import torch from einops import rearrange @@ -28,6 +26,10 @@ from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, from .ssd_state_passing import _state_passing_fwd +def is_int_pow_2(n): + return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 + + def _mamba_chunk_scan_combined_fwd( x, dt, @@ -45,13 +47,14 @@ def _mamba_chunk_scan_combined_fwd( cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), - mamba_ssm_cache_dtype=None, + state_dtype=None, + out=None, ): + assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape assert nheads % ngroups == 0 assert B.shape == (batch, seqlen, ngroups, dstate) - assert x.shape == (batch, seqlen, nheads, headdim) assert dt.shape == (batch, seqlen, nheads) assert A.shape == (nheads, ) assert C.shape == B.shape @@ -77,8 +80,12 @@ def _mamba_chunk_scan_combined_fwd( if cu_seqlens is None: assert initial_states.shape == (batch, nheads, headdim, dstate) else: - assert initial_states.shape == (len(cu_seqlens) - 1, nheads, - headdim, dstate) + assert initial_states.shape == ( + len(cu_seqlens) - 1, + nheads, + headdim, + dstate, + ) # This function executes 5 sub-functions for computing mamba # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ @@ -125,13 +132,12 @@ def _mamba_chunk_scan_combined_fwd( if initial_states is not None else None), seq_idx=seq_idx, chunk_size=chunk_size, - out_dtype=mamba_ssm_cache_dtype or C.dtype, + out_dtype=state_dtype if state_dtype is not None else C.dtype, is_cont_batched=cu_seqlens is not None, - chunk_offsets=chunk_offsets) - states, final_states = [ - rearrange(t, "... (p n) -> ... p n", n=dstate) - for t in [states, final_states] - ] + chunk_offsets=chunk_offsets, + ) + states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) + for t in [states, final_states]) # 4. Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, @@ -150,20 +156,23 @@ def _mamba_chunk_scan_combined_fwd( # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had # a seq_idx change, in which case we take states information from # init_states. 
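# - _chunk_scan_fwd now writes directly into the caller-provided `out` buffer
#   (same shape as x) instead of allocating its own output, and only returns the
#   pre-gate `out_x` tensor (None when z is not supplied).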
- out, out_x = _chunk_scan_fwd(CB, - x, - dt, - dA_cumsum, - C, - states, - D=D, - z=z, - seq_idx=seq_idx, - chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets, - initial_states=initial_states) + out_x = _chunk_scan_fwd( + CB, + x, + dt, + dA_cumsum, + C, + states, + D=D, + z=z, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + initial_states=initial_states, + out=out, + ) if cu_seqlens is None: - return out, out_x, dt, dA_cumsum, states, final_states + return out_x, dt, dA_cumsum, states, final_states else: assert ( batch == 1 @@ -177,29 +186,31 @@ def _mamba_chunk_scan_combined_fwd( states.squeeze(0), initial_states=initial_states, ) - return out, out_x, dt, dA_cumsum, states, final_states, varlen_states + return out_x, dt, dA_cumsum, states, final_states, varlen_states def mamba_chunk_scan_combined( - x, - dt, - A, - B, - C, - chunk_size, - D=None, - z=None, - dt_bias=None, - initial_states=None, - seq_idx=None, - chunk_indices=None, - chunk_offsets=None, - cu_seqlens=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - return_final_states=False, - return_varlen_states=False, - mamba_ssm_cache_dtype: Optional[torch.dtype] = None): + x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + chunk_indices=None, + chunk_offsets=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + out=None, + return_final_states=False, + return_varlen_states=False, + state_dtype=None, +): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -215,38 +226,42 @@ def mamba_chunk_scan_combined( seq_idx: (batch, seqlen) cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True dt_softplus: Whether to apply softplus to dt - mamba_ssm_cache_dtype: torch.dtype, default to None - Return: - out: (batch, seqlen, nheads, headdim) + out: Preallocated output tensor + state_dtype: The data type of the ssm state """ if not return_varlen_states: cu_seqlens = None else: - assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" - out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd( - x, - dt, - A, - B, - C, - chunk_size, - D=D, - z=z, - dt_bias=dt_bias, - initial_states=initial_states, - seq_idx=seq_idx, - chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets, - cu_seqlens=cu_seqlens, - dt_softplus=dt_softplus, - dt_limit=dt_limit, - mamba_ssm_cache_dtype=mamba_ssm_cache_dtype) + assert (cu_seqlens is not None + ), "cu_seqlens must be provided if return_varlen_states is True" + out_x, dt_out, dA_cumsum, states, final_states, *rest = ( + _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=D, + z=z, + dt_bias=dt_bias, + initial_states=initial_states, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + cu_seqlens=cu_seqlens, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + out=out, + state_dtype=state_dtype, + )) if not return_varlen_states: - return out if not return_final_states else (out, final_states) + if not return_final_states: + return + else: + return final_states else: varlen_states = rest[0] - return (out, - varlen_states) if not return_final_states else (out, - final_states, - varlen_states) + return ((varlen_states) if not return_final_states else + (final_states, varlen_states)) diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py index 
55b20937f9..ff0f0330e7 100644 --- a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Union +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Union import torch @@ -22,11 +23,45 @@ from tensorrt_llm._torch.pyexecutor.resource_manager import ( BaseResourceManager, CacheTypeCpp, DataType, KVCacheManager, get_pp_layers) from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests from tensorrt_llm.llmapi.llm_args import KvCacheConfig +from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping +if TYPE_CHECKING: + from tensorrt_llm._torch.attention_backend.interface import \ + AttentionMetadata + +GB = 1 << 30 + + +def get_tensor_size_bytes(tensor): + """Calculate tensor size in bytes.""" + if isinstance(tensor, torch.Tensor): + return tensor.element_size() * tensor.nelement() + elif isinstance(tensor, list): + return sum(get_tensor_size_bytes(t) for t in tensor) + return 0 + class MambaCacheManager(BaseResourceManager): + @dataclass(frozen=True, kw_only=True) + class State: + """Base state container for Mamba cache.""" + conv: torch.Tensor + temporal: torch.Tensor + + def at_layer_idx(self, layer: int): + kwargs = {} + for k, v in vars(self).items(): + kwargs[k] = v[layer] + return type(self)(**kwargs) + + @dataclass(frozen=True, kw_only=True) + class SpeculativeState(State): + """Speculative state with intermediate states for draft tokens.""" + intermediate_ssm: torch.Tensor + intermediate_conv_window: torch.Tensor + def __init__( self, d_state: int, @@ -36,13 +71,17 @@ class MambaCacheManager(BaseResourceManager): head_dim: int, num_layers: int, max_batch_size: int, + spec_state_size: int, mapping: Mapping, dtype: torch.dtype, ssm_cache_dtype: torch.dtype, layer_mask: Optional[List[bool]] = None, + speculative_num_draft_tokens: Optional[int] = None, ) -> None: self.mamba_ssm_cache_dtype = ssm_cache_dtype + self.speculative_num_draft_tokens = speculative_num_draft_tokens + self.spec_state_size = spec_state_size # get tp size tp_size = mapping.tp_size if not mapping.enable_attention_dp else 1 @@ -74,31 +113,69 @@ class MambaCacheManager(BaseResourceManager): for offset, idx in enumerate(pp_layers) } - # mamba conv states - self.conv_states = torch.empty( - size=[ - num_local_layers, - max_batch_size, - conv_dim, - d_conv - 1, - ], + conv_state_shape = (conv_dim, d_conv - 1) + ssm_state_shape = (nheads, head_dim, d_state) + + # create mamba conv and ssm states + conv_states = torch.empty( + size=(num_local_layers, max_batch_size) + conv_state_shape, dtype=dtype, device=device, ) - # mamba ssm states - self.ssm_states = torch.empty( - size=[ - num_local_layers, - max_batch_size, - nheads, - head_dim, - d_state, - ], + ssm_states = torch.empty( + size=(num_local_layers, max_batch_size) + ssm_state_shape, dtype=self.mamba_ssm_cache_dtype, device=device, ) + # create state container + if speculative_num_draft_tokens is not None: + + # Cache intermediate SSM states per draft token(include new sampled token) during target model verification phase + intermediate_ssm_states = torch.zeros( + size=(num_local_layers, self.spec_state_size, + speculative_num_draft_tokens + 1) + ssm_state_shape, + dtype=self.mamba_ssm_cache_dtype, + device=device, + ) + + # Cache intermediate conv windows per draft token(include 
new sampled token) during target model verification phase + intermediate_conv_window_cache = torch.zeros( + size=(num_local_layers, self.spec_state_size, + speculative_num_draft_tokens + 1) + conv_state_shape, + dtype=dtype, + device=device, + ) + + self.mamba_cache = self.SpeculativeState( + conv=conv_states, + temporal=ssm_states, + intermediate_ssm=intermediate_ssm_states, + intermediate_conv_window=intermediate_conv_window_cache, + ) + + logger.info( + f"Mamba Cache is allocated. " + f"max_mamba_cache_size: {max_batch_size}, " + f"conv_state size: {get_tensor_size_bytes(conv_states) / GB:.2f}GB, " + f"ssm_state size: {get_tensor_size_bytes(ssm_states) / GB:.2f}GB, " + f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_states) / GB:.2f}GB, " + f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB" + ) + else: + self.mamba_cache = self.State( + conv=conv_states, + temporal=ssm_states, + ) + + logger.info( + f"Mamba Cache is allocated. " + f"max_mamba_cache_size: {max_batch_size}, " + f"conv_state size: {get_tensor_size_bytes(conv_states) / GB:.2f}GB, " + f"ssm_state size: {get_tensor_size_bytes(ssm_states) / GB:.2f}GB" + ) + # mamba cache available blocks self.mamba_cache_free_blocks = [i for i in range(max_batch_size)] @@ -111,6 +188,10 @@ class MambaCacheManager(BaseResourceManager): dtype=torch.int32) # save mamba state indices for requests self.state_indices_list: List[int] = [] + # save intermediate state indices for requests + self.intermediate_state_indices = torch.arange(max_batch_size, + dtype=torch.int32, + device=device) def _prepare_mamba_cache_blocks(self, request_ids: List[int]): self.state_indices_list.clear() @@ -184,22 +265,85 @@ class MambaCacheManager(BaseResourceManager): def get_conv_states(self, layer_idx: int) -> torch.Tensor: layer_offset = self.mamba_layer_offsets[layer_idx] - return self.conv_states[layer_offset] + return self.mamba_cache.at_layer_idx(layer_offset).conv def get_ssm_states(self, layer_idx: int) -> torch.Tensor: layer_offset = self.mamba_layer_offsets[layer_idx] - return self.ssm_states[layer_offset] + return self.mamba_cache.at_layer_idx(layer_offset).temporal - def get_mamba_ssm_cache_dtype(self) -> torch.dtype: - return self.mamba_ssm_cache_dtype + def get_intermediate_ssm_states(self, + layer_idx: int) -> Optional[torch.Tensor]: + if not isinstance(self.mamba_cache, self.SpeculativeState): + return None + layer_offset = self.mamba_layer_offsets[layer_idx] + return self.mamba_cache.at_layer_idx(layer_offset).intermediate_ssm + + def get_intermediate_conv_states(self, + layer_idx: int) -> Optional[torch.Tensor]: + """Get intermediate conv states for speculative decoding.""" + if not isinstance(self.mamba_cache, self.SpeculativeState): + return None + layer_offset = self.mamba_layer_offsets[layer_idx] + return self.mamba_cache.at_layer_idx( + layer_offset).intermediate_conv_window + + def is_speculative(self) -> bool: + return isinstance(self.mamba_cache, self.SpeculativeState) + + def mamba_layer_cache(self, + layer_idx: int) -> Union[State, SpeculativeState]: + layer_offset = self.mamba_layer_offsets[layer_idx] + return self.mamba_cache.at_layer_idx(layer_offset) def shutdown(self): - # release tensor memory, keeping python references as tensors - self.conv_states = torch.tensor([]) - self.ssm_states = torch.tensor([]) + """Release tensor memory.""" + # Clear state indices self.state_indices = torch.tensor([]) + + # Clear mamba cache states + if 
isinstance(self.mamba_cache, self.SpeculativeState): + self.mamba_cache = self.SpeculativeState( + conv=torch.tensor([]), + temporal=torch.tensor([]), + intermediate_ssm=torch.tensor([]), + intermediate_conv_window=torch.tensor([]), + ) + else: + self.mamba_cache = self.State( + conv=torch.tensor([]), + temporal=torch.tensor([]), + ) + torch.cuda.empty_cache() + @torch.compile(options={"max-autotune": True}) + def update_mamba_states(self, attn_metadata: "AttentionMetadata", + num_accepted_tokens: torch.Tensor): + batch_size = attn_metadata.num_seqs + num_contexts = attn_metadata.num_contexts + num_gens = batch_size - num_contexts + num_accepted_draft_tokens = num_accepted_tokens[ + num_contexts:num_contexts + num_gens] - 1 + state_indices_d = self.state_indices[num_contexts:num_contexts + + num_gens] + + conv_states = self.mamba_cache.conv + ssm_states = self.mamba_cache.temporal + + intermediate_state_cache = self.mamba_cache.intermediate_ssm + intermediate_conv_window_cache = self.mamba_cache.intermediate_conv_window + + src_state_indices = self.intermediate_state_indices[:num_gens] + + accepted_ssm_state = intermediate_state_cache[:, src_state_indices, + num_accepted_draft_tokens] + ssm_states[:, state_indices_d, :] = accepted_ssm_state + + accepted_conv_state = intermediate_conv_window_cache[:, + src_state_indices, + num_accepted_draft_tokens] + conv_states[:, state_indices_d, :] = accepted_conv_state + class MambaHybridCacheManager(KVCacheManager, MambaCacheManager): @@ -249,10 +393,13 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager): mamba_head_dim, mamba_num_layers, max_batch_size, + max_batch_size, mapping, mamba_cache_dtype, mamba_ssm_cache_dtype, mamba_layer_mask, + speculative_num_draft_tokens=spec_config.max_draft_len + if spec_config is not None else None, ) # initialize kv cache manager @@ -285,3 +432,15 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager): def shutdown(self): MambaCacheManager.shutdown(self) KVCacheManager.shutdown(self) + + def update_resources(self, + scheduled_batch: ScheduledRequests, + attn_metadata: "AttentionMetadata" = None, + kv_cache_dtype_byte_size: float = None): + KVCacheManager.update_resources(self, scheduled_batch, attn_metadata, + kv_cache_dtype_byte_size) + + def update_mamba_states(self, attn_metadata: "AttentionMetadata", + num_accepted_tokens: torch.Tensor): + MambaCacheManager.update_mamba_states(self, attn_metadata, + num_accepted_tokens) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index bfdfb39af7..a55e7d3402 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -1469,10 +1469,12 @@ class ResourceManager: resource_manager.prepare_resources(scheduled_batch) @nvtx_range("update_resources") - def update_resources(self, - scheduled_batch: ScheduledRequests, - attn_metadata: Optional["AttentionMetadata"] = None, - kv_cache_dtype_byte_size: Optional[float] = None): + def update_resources( + self, + scheduled_batch: ScheduledRequests, + attn_metadata: Optional["AttentionMetadata"] = None, + kv_cache_dtype_byte_size: Optional[float] = None, + ): for _, resource_manager in self.resource_managers.items(): if hasattr(resource_manager, "update_resources"): if isinstance(resource_manager, KVCacheManager): diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index f2388b9851..845455d229 100644 --- 
a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, List, Optional import torch import torch.nn.functional as F +from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \ + MambaHybridCacheManager from tensorrt_llm.mapping import Mapping from ..attention_backend import AttentionMetadata @@ -1132,6 +1134,7 @@ class MTPEagleWorker(MTPWorker): super().__init__(spec_config, model_config) self.model_config = model_config self.mtp_num_modules = spec_config.num_nextn_predict_layers + self._is_mamba_hybrid_cache = None @torch.compile(options={"max-autotune": True}) def update_draft_tokens(self, next_draft_tokens, new_draft_token, @@ -1164,6 +1167,14 @@ class MTPEagleWorker(MTPWorker): accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens( input_ids, logits, spec_metadata, attn_metadata) + if self._is_mamba_hybrid_cache is None: + self._is_mamba_hybrid_cache = isinstance( + attn_metadata.kv_cache_manager, MambaHybridCacheManager) + if num_gens > 0 and self._is_mamba_hybrid_cache: + attn_metadata.kv_cache_manager.update_mamba_states( + attn_metadata=attn_metadata, + num_accepted_tokens=num_accepted_tokens) + # Save the old attn_metadata and spec_metadata self._prepare_attn_metadata_for_spec_dec(attn_metadata) diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index 70c3f18757..7cc38a742c 100644 --- a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -348,6 +348,11 @@ nvidia/Nemotron-Super-V3: - quant_algo: NVFP4 kv_cache_quant_algo: FP8 accuracy: 80.85 + - quant_algo: NVFP4 + kv_cache_quant_algo: FP8 + mtp_enabled: true + num_nextn_predict_layers: 3 + accuracy: 80.85 nvidia/Nemotron-3-Nano: - accuracy: 69.37 - quant_algo: FP8 diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml index 479ecca029..51198b62fe 100644 --- a/tests/integration/defs/accuracy/references/mmlu.yaml +++ b/tests/integration/defs/accuracy/references/mmlu.yaml @@ -382,6 +382,11 @@ nvidia/Nemotron-Super-V3: - quant_algo: NVFP4 kv_cache_quant_algo: FP8 accuracy: 77.56 + - quant_algo: NVFP4 + kv_cache_quant_algo: FP8 + mtp_enabled: true + num_nextn_predict_layers: 3 + accuracy: 77.56 nvidia/Nemotron-3-Nano: - accuracy: 73.85 - quant_algo: FP8 diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index f6d0f42304..1d2f80ac09 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -5800,6 +5800,42 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness): task.evaluate(llm, extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) + @pytest.mark.skip(reason="Skip MTP test due to no model path file in CI") + @skip_pre_blackwell + @pytest.mark.skip_less_mpi_world_size(8) + def test_nvfp4_8gpus_mtp(self): + # Test MTP (Multi-Token Prediction) accuracy with nvfp4-fp8kv model. + # This test uses MTP with max_draft_len=3 and one_model mode. 
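# num_nextn_predict_layers=3 gives max_draft_len=3, and mtp_eagle_one_model=True
# runs the MTP draft layers inside the target engine (one-engine speculative
# decoding), so this test exercises the Mamba hybrid-cache verification path above.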
+ mtp_config = MTPDecodingConfig( + num_nextn_predict_layers=3, + mtp_eagle_one_model=True, + ) + model_path = f"{llm_models_root()}/nemotron-super-sft-repeated-mtp-iter-0010600-nvfp4-fp8kv" + with LLM( + model_path, + kv_cache_config=KvCacheConfig( + enable_block_reuse=False, + mamba_ssm_cache_dtype="float16", + free_gpu_memory_fraction=0.5, + ), + max_batch_size=128, + tensor_parallel_size=8, + moe_expert_parallel_size=8, + pipeline_parallel_size=1, + enable_attention_dp=False, + cuda_graph_config=CudaGraphConfig(max_batch_size=32, + enable_padding=True), + disable_overlap_scheduler=False, + moe_config=MoeConfig(backend="CUTLASS"), + decoding_config=mtp_config, + ) as llm: + task = MMLU(self.MODEL_NAME) + task.evaluate(llm, + extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm, + extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS) + @skip_pre_hopper class TestMiniMaxM2(LlmapiAccuracyTestHarness): diff --git a/tests/unittest/_torch/thop/parallel/test_mamba2_chunk_ss_update.py b/tests/unittest/_torch/thop/parallel/test_mamba2_chunk_ss_update.py index 17c28f75fa..a018c80889 100644 --- a/tests/unittest/_torch/thop/parallel/test_mamba2_chunk_ss_update.py +++ b/tests/unittest/_torch/thop/parallel/test_mamba2_chunk_ss_update.py @@ -214,7 +214,8 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, assert remove_padding chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets( cu_seqlens, chunk_size) - out, ssm_state = mamba_chunk_scan_combined( + out = torch.empty_like(x) + ssm_state = mamba_chunk_scan_combined( x, dt, A, @@ -232,6 +233,7 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, dt_softplus=delta_softplus, return_final_states=not remove_padding, return_varlen_states=remove_padding, + out=out, ) if (ssm_state.shape[0] > 1 and ssm_state.dtype == torch.float32 @@ -257,7 +259,8 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, else: state_batch_indices = None - y = selective_state_update( + y = torch.empty_like(x) + selective_state_update( state, x, dt, @@ -269,6 +272,7 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, dt_bias=dt_bias, dt_softplus=delta_softplus, state_batch_indices=state_batch_indices, + out=y, ) outputs = (y, state[state_batch_indices] if state_batch_indices is not None else state) @@ -432,7 +436,8 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens): z = torch.randn_like(x) ## full seqlen computation - out_ref, state_ref = mamba_chunk_scan_combined( + out_ref = torch.empty_like(x) + state_ref = mamba_chunk_scan_combined( x, dt, A, @@ -447,6 +452,7 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens): dt_softplus=delta_softplus, return_final_states=False, return_varlen_states=True, + out=out_ref, ) ## chunked seqlen computation @@ -478,7 +484,8 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens): z_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] 
= chunk_f(z, i) # yapf: enable - partial_out, partial_state = mamba_chunk_scan_combined( + partial_out = torch.empty_like(x_chunked) + partial_state = mamba_chunk_scan_combined( x_chunked, dt_chunked, A, @@ -493,6 +500,7 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens): dt_softplus=delta_softplus, return_final_states=False, return_varlen_states=True, + out=partial_out, ) # remaining chunk @@ -542,7 +550,8 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens): chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets( remaining_chunked_cu_seqlens, mamba_chunk_size) - out_chunked, state_chunked = mamba_chunk_scan_combined( + out_chunked = torch.empty_like(remaining_x_chunked) + state_chunked = mamba_chunk_scan_combined( remaining_x_chunked, remaining_dt_chunked, A, @@ -560,6 +569,7 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens): dt_softplus=delta_softplus, return_final_states=False, return_varlen_states=True, + out=out_chunked, ) out = concat_batch_f(partial_out, out_chunked)
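For reference, a minimal sketch of the updated mamba_chunk_scan_combined calling convention that these tests exercise: the caller now owns the output buffer and the function returns only the requested states. All shapes, dtypes, and tensor names below are illustrative assumptions, not code from this change.

import torch

from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined

# Assumed toy sizes: one packed batch of two sequences of 64 tokens each.
batch, seqlen, nheads, headdim, ngroups, dstate, chunk_size = 1, 128, 8, 64, 1, 128, 64
x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
dt = torch.rand(batch, seqlen, nheads, device="cuda", dtype=torch.bfloat16)
A = -torch.rand(nheads, device="cuda", dtype=torch.float32)
B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=torch.bfloat16)
C = torch.randn_like(B)
cu_seqlens = torch.tensor([0, 64, 128], device="cuda", dtype=torch.int32)
seq_idx = torch.repeat_interleave(
    torch.arange(2, device="cuda", dtype=torch.int32), 64).unsqueeze(0)  # (batch, seqlen)

# Preallocate the output; the scan fills it in place and returns the varlen states.
out = torch.empty_like(x)
varlen_states = mamba_chunk_scan_combined(
    x, dt, A, B, C, chunk_size,
    cu_seqlens=cu_seqlens,
    seq_idx=seq_idx,
    dt_softplus=True,
    return_final_states=False,
    return_varlen_states=True,
    out=out,
)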