[TRTLLM-10062][feat] Enable MTP for Nemotron Super (#10754)

Signed-off-by: qgai <qgai@nvidia.com>
sunnyqgg 2026-01-27 00:23:26 +08:00 committed by GitHub
parent 43b8a5561c
commit ff0dd6076e
17 changed files with 2241 additions and 310 deletions

View File

@ -675,6 +675,11 @@ private:
{
continue;
}
// If tileSizeQ < mNumHeadsQPerKv, the tileSizeQ / mNumHeadsQPerKv quotient truncates to 0 and later causes a division by zero, so skip this tile size.
if (tileSizeQ < params.mNumHeadsQPerKv)
{
continue;
}
// Update the tileSizeQ.
selectKernelParamsCopy.mTileSizeQ = tileSizeQ;
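
The guard above lives in the C++ kernel selector, but the hazard it avoids is easy to restate in a few lines of Python (illustrative names only): when tileSizeQ is smaller than mNumHeadsQPerKv, the integer quotient is 0 and any later division by that quotient fails.

# Hypothetical sketch of the guarded failure mode; names loosely mirror the C++ fields.
def q_tiles_per_kv_group(tile_size_q: int, num_heads_q_per_kv: int) -> int:
    quotient = tile_size_q // num_heads_q_per_kv  # 0 whenever tile_size_q < num_heads_q_per_kv
    return 128 // quotient                        # ZeroDivisionError when quotient == 0

tile_size_q, num_heads_q_per_kv = 8, 16
if tile_size_q < num_heads_q_per_kv:
    pass  # skip this candidate tile size, as the guard above does with `continue`
else:
    q_tiles_per_kv_group(tile_size_q, num_heads_q_per_kv)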

View File

@ -144,8 +144,15 @@ def _triton_cached_ssm(
num_seq = num_prefill + num_decode
num_total_tokens = num_prefill_tokens + num_decode
y_prefill = None
y_decode = None
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
[bs, num_heads, head_dim],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
preallocated_ssm_out_p = preallocated_ssm_out[:num_prefill_tokens]
preallocated_ssm_out_d = preallocated_ssm_out[num_prefill_tokens:num_total_tokens]
# Prefill: concatenate tokens at the front and run combined scan
if num_prefill > 0:
@ -165,7 +172,7 @@ def _triton_cached_ssm(
chunk_indices = None
chunk_offsets = None
y_prefill, varlen_states = mamba_chunk_scan_combined(
varlen_states = mamba_chunk_scan_combined(
hs_prefill,
dt_prefill,
A,
@ -184,11 +191,12 @@ def _triton_cached_ssm(
dt_limit=(time_step_limit[0], time_step_limit[1]),
return_final_states=False,
return_varlen_states=True,
mamba_ssm_cache_dtype=ssm_state_cache.dtype,
out=preallocated_ssm_out_p.unsqueeze(0),
state_dtype=ssm_state_cache.dtype,
)
ssm_state_cache.index_copy_(
0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype)
0, slot_idx[:num_prefill].long(), varlen_states.to(ssm_state_cache.dtype)
)
# Decode: batch single-token updates via selective_state_update
@ -205,7 +213,7 @@ def _triton_cached_ssm(
A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size)
D_full = D[..., None].expand(num_heads, head_dim)
y_decode = selective_state_update(
selective_state_update(
ssm_state_cache,
x_decode,
dt_hp,
@ -217,19 +225,16 @@ def _triton_cached_ssm(
dt_bias=dt_bias_hp,
dt_softplus=True,
state_batch_indices=slot_idx_decode,
) # [nd, H, D]
out=preallocated_ssm_out_d,
)
# Dispatch return logic
if num_prefill > 0 and num_decode > 0:
y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format)
y_flat = y.view(bs, *y.shape[2:])
y_flat[:num_prefill_tokens].copy_(y_prefill[0])
y_flat[num_prefill_tokens:num_total_tokens].copy_(y_decode)
return y
elif num_prefill > 0:
return y_prefill[0].view(b, s, num_heads, head_dim).to(hidden_states.dtype)
elif num_decode > 0:
return y_decode.view(b, s, num_heads, head_dim).to(hidden_states.dtype)
# Return the preallocated output reshaped to original dimensions
if num_total_tokens > 0:
return (
preallocated_ssm_out[:num_total_tokens]
.view(b, s, num_heads, head_dim)
.to(hidden_states.dtype)
)
else:
return torch.empty_like(hidden_states)
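
The rewrite above replaces separate y_prefill / y_decode tensors plus a merge copy with one preallocated buffer whose prefill and decode slices are written in place by the kernels. A standalone sketch of that pattern, with invented shapes and plain PyTorch ops standing in for the Triton kernels:

import torch

def preallocated_split_write(num_prefill_tokens: int, num_decode_tokens: int,
                             num_heads: int = 4, head_dim: int = 8) -> torch.Tensor:
    total = num_prefill_tokens + num_decode_tokens
    out = torch.empty(total, num_heads, head_dim)   # single allocation for both branches
    out_p = out[:num_prefill_tokens]                # view into `out`, no copy
    out_d = out[num_prefill_tokens:total]           # view into `out`, no copy
    if num_prefill_tokens > 0:
        out_p.copy_(torch.randn(num_prefill_tokens, num_heads, head_dim))  # stand-in for the prefill scan writing `out=`
    if num_decode_tokens > 0:
        out_d.copy_(torch.randn(num_decode_tokens, num_heads, head_dim))   # stand-in for the decode update writing `out=`
    return out                                      # already merged; no torch.cat / copy_ epilogue

assert preallocated_split_write(5, 3).shape == (8, 4, 8)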

View File

@ -1,5 +1,6 @@
import torch
import tensorrt_llm.logger as logger
from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import \
HfWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_mapper
@ -55,6 +56,16 @@ class NemotronHHfWeightMapper(HfWeightMapper):
if "embeddings" in key:
key = key.replace("embeddings", "embed_tokens")
# MTP layers are stored as mtp.layers.0.xxx (sublayer 0, Attention) and mtp.layers.1.xxx (sublayer 1, MoE)
if "mtp.layers." in key:
import re
match = re.match(r'mtp\.layers\.(\d+)\.(.*)', key)
if match:
sublayer_idx, rest = match.groups()
key = f"model.layers.{config.num_hidden_layers}.layers.{sublayer_idx}.{rest}"
else:
logger.error(f"Failed to match MTP pattern for: {name}")
if "A_log" in key:
key = key.replace("A_log", "A")
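
A self-contained sketch of the MTP key remapping the mapper performs above: checkpoint keys of the form mtp.layers.<sublayer>.<rest> are relocated under model.layers.<num_hidden_layers>.layers.<sublayer>.<rest>, i.e. the MTP sublayers are appended after the base decoder layers. The helper name and example key below are hypothetical; the regex and target format mirror the diff.

import re

def remap_mtp_key(key: str, num_hidden_layers: int) -> str:
    match = re.match(r"mtp\.layers\.(\d+)\.(.*)", key)
    if match is None:
        return key  # not an MTP weight; leave unchanged
    sublayer_idx, rest = match.groups()
    return f"model.layers.{num_hidden_layers}.layers.{sublayer_idx}.{rest}"

# With, say, 80 base decoder layers, sublayer 1 (the MoE sublayer) lands at index 80:
assert remap_mtp_key("mtp.layers.1.mixer.up_proj.weight", 80) == \
    "model.layers.80.layers.1.mixer.up_proj.weight"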

View File

@ -14,7 +14,7 @@
# limitations under the License.
import re
from typing import Dict, Optional
from typing import Dict, List, Optional
import torch
from torch import nn
@ -37,9 +37,11 @@ from ..modules.mamba.mamba2_mixer import Mamba2Mixer
from ..modules.mlp import MLP
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..speculative import SpecMetadata
from ..utils import AuxStreamType, EventType
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
register_auto_model)
from .modeling_deepseekv3 import DeepseekV3MTPHead
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import DecoderModel, register_auto_model
class NemotronHConfig(PretrainedConfig):
@ -347,13 +349,17 @@ class NemotronHLayer(DecoderLayer):
position_ids: torch.IntTensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.mixer(hidden_states, attn_metadata, **kwargs)
hidden_states = self.mixer(hidden_states,
attn_metadata,
spec_metadata=spec_metadata,
**kwargs)
hidden_states = torch.add(hidden_states, residual)
return hidden_states
@ -405,6 +411,7 @@ class NemotronHModel(DecoderModel):
layer_type,
aux_stream_dict=self.aux_stream_dict))
self.layers = nn.ModuleList(layers)
self.num_hidden_layers = config.num_hidden_layers
# final norm
self.norm_f = RMSNorm(
@ -421,6 +428,7 @@ class NemotronHModel(DecoderModel):
input_ids: Optional[torch.IntTensor] = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
@ -439,10 +447,11 @@ class NemotronHModel(DecoderModel):
hidden_states = inputs_embeds
for layer in self.layers:
for layer in self.layers[:self.num_hidden_layers]:
hidden_states = layer(position_ids,
hidden_states,
attn_metadata,
spec_metadata=spec_metadata,
mamba_metadata=self.mamba_metadata)
hidden_states = self.norm_f(hidden_states)
@ -451,8 +460,8 @@ class NemotronHModel(DecoderModel):
@register_auto_model("NemotronHForCausalLM")
class NemotronHForCausalLM(DecoderModelForCausalLM[NemotronHModel,
NemotronHConfig]):
class NemotronHForCausalLM(SpecDecOneEngineForCausalLM[NemotronHModel,
NemotronHConfig]):
def __init__(
self,
@ -477,15 +486,286 @@ class NemotronHForCausalLM(DecoderModelForCausalLM[NemotronHModel,
]
super().__init__(
NemotronHModel(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size,
model=NemotronHModel(model_config),
model_config=model_config,
)
self.model_nextn = 0
if model_config.spec_config is not None and model_config.spec_config.spec_dec_mode.is_mtp_one_model(
):
model_nextn = model_config.spec_config.num_nextn_predict_layers
ckpt_nextn = self.config.num_nextn_predict_layers
self.num_hidden_layers = self.config.num_hidden_layers
assert ckpt_nextn > 0, "There are no MTP modules in the checkpoint."
if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
pass
else:
# modify the QuantConfig to support duplicated mtp layers
if model_config.quant_config.exclude_modules is not None:
extend_exclude_modules = []
for model_mtp_idx in range(
self.num_hidden_layers,
self.num_hidden_layers + model_nextn):
ckpt_mtp_idx = (model_mtp_idx - self.num_hidden_layers
) % ckpt_nextn + self.num_hidden_layers
model_prefix = f"model.layers.{model_mtp_idx}"
ckpt_prefix = f"model.layers.{ckpt_mtp_idx}"
for exclude_module in model_config.quant_config.exclude_modules:
if ckpt_prefix in exclude_module and model_prefix not in exclude_module:
extend_exclude_modules.append(
exclude_module.replace(
ckpt_prefix, model_prefix))
self.model_config.quant_config.exclude_modules.extend(
extend_exclude_modules)
self.model.layers.extend(self.draft_model.mtp_layers)
self.epilogue.extend(self.draft_model.mtp_layers)
self.epilogue.append(self.spec_worker)
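
The loop above extends quant_config.exclude_modules so that exclusions written against the checkpoint's single MTP layer index also cover every duplicated in-model MTP layer index. A hedged, standalone sketch of that mapping (the index arithmetic follows the diff; the example inputs are invented):

def extend_mtp_exclusions(exclude_modules, num_hidden_layers, model_nextn, ckpt_nextn):
    extended = []
    for model_mtp_idx in range(num_hidden_layers, num_hidden_layers + model_nextn):
        # Map each in-model MTP index back onto the checkpoint's MTP index.
        ckpt_mtp_idx = (model_mtp_idx - num_hidden_layers) % ckpt_nextn + num_hidden_layers
        model_prefix = f"model.layers.{model_mtp_idx}"
        ckpt_prefix = f"model.layers.{ckpt_mtp_idx}"
        for module in exclude_modules:
            if ckpt_prefix in module and model_prefix not in module:
                extended.append(module.replace(ckpt_prefix, model_prefix))
    return exclude_modules + extended

# Checkpoint stores one MTP layer (hypothetically at index 46); the model instantiates three.
mods = extend_mtp_exclusions(["model.layers.46.eh_proj"], num_hidden_layers=46,
                             model_nextn=3, ckpt_nextn=1)
# -> the original entry plus copies for model.layers.47 and model.layers.48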
def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
new_weights = weight_mapper.preprocess_weights(weights)
super().load_weights(weights=new_weights, weight_mapper=weight_mapper)
class NemotronHMTPDecoderLayer(NemotronHLayer):
def __init__(
self,
model_config: ModelConfig[NemotronHConfig],
layer_idx: int,
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
has_start_projections: bool,
has_end_norm: bool,
layer_type: str,
) -> None:
super().__init__(
model_config=model_config,
layer_idx=layer_idx,
layer_type=layer_type,
aux_stream_dict=aux_stream_dict,
)
config = model_config.pretrained_config
self.model_config = model_config
self.has_start_projections = has_start_projections
self.has_end_norm = has_end_norm
if has_start_projections:
self.enorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)
self.hnorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)
if model_config.mapping.enable_attention_dp:
self.eh_proj = Linear(
in_features=config.hidden_size * 2,
out_features=config.hidden_size,
bias=False,
dtype=config.torch_dtype,
quant_config=model_config.quant_config,
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
)
else:
self.eh_proj = Linear(
in_features=config.hidden_size * 2,
out_features=config.hidden_size,
bias=False,
dtype=config.torch_dtype,
tensor_parallel_mode=TensorParallelMode.ROW,
mapping=model_config.mapping,
quant_config=model_config.quant_config,
reduce_output=True,
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
)
if has_end_norm:
self.final_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)
def forward(
self,
inputs_embeds: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None = None,
attn_metadata: Optional[AttentionMetadata] = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if self.has_start_projections:
assert inputs_embeds is not None
inputs_embeds_normed = self.enorm(inputs_embeds)
previous_hidden_states_normed = self.hnorm(hidden_states)
# Fuse via concatenation and linear projection
fused = torch.cat(
[inputs_embeds_normed, previous_hidden_states_normed], dim=-1)
# Split fused hidden_states columnwise based on TP
mapping = self.model_config.mapping
if mapping.tp_size > 1 and not mapping.enable_attention_dp:
fused = torch.chunk(fused, mapping.tp_size,
dim=-1)[mapping.tp_rank]
hidden_states = self.eh_proj(fused)
residual = None # Start fresh after fusion
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(
hidden_states=hidden_states,
attn_metadata=attn_metadata,
)
if self.has_end_norm:
if residual is not None:
hidden_states = hidden_states + residual
residual = None
hidden_states = self.final_layernorm(hidden_states)
return hidden_states, residual
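
When has_start_projections is set, the forward above normalizes the token embedding and the previous hidden state separately, concatenates them, keeps only this rank's column chunk of the fused tensor, and projects back to hidden_size. A minimal single-process sketch of that fusion, with nn.LayerNorm and nn.Linear standing in for the TRT-LLM RMSNorm and row-parallel Linear (a real row-parallel layer would all-reduce the partial outputs):

import torch
import torch.nn as nn

hidden_size, tp_size, tp_rank = 64, 2, 0
enorm = nn.LayerNorm(hidden_size)   # stand-in for RMSNorm over the embedding
hnorm = nn.LayerNorm(hidden_size)   # stand-in for RMSNorm over the previous hidden state
# Row-parallel stand-in: each rank owns 2*hidden_size/tp_size input columns of eh_proj.
eh_proj_local = nn.Linear(2 * hidden_size // tp_size, hidden_size, bias=False)

inputs_embeds = torch.randn(4, hidden_size)
prev_hidden = torch.randn(4, hidden_size)

fused = torch.cat([enorm(inputs_embeds), hnorm(prev_hidden)], dim=-1)  # (4, 2*hidden_size)
fused_local = torch.chunk(fused, tp_size, dim=-1)[tp_rank]             # this rank's columns
partial = eh_proj_local(fused_local)                                   # (4, hidden_size) partial sum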
class NemotronHMTP(nn.Module):
"""NemotronH MTP Layer - single MTP layer following DeepseekV3MTP pattern."""
def __init__(self,
model_config: ModelConfig[NemotronHConfig],
layer_idx: int,
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
is_separate_draft_engine: bool = False,
prefix: str = ""):
super().__init__()
config = model_config.pretrained_config
self.model_config = model_config
self.config = config
self.layer_idx = layer_idx
# Pattern configuration
self.pattern_str = config.mtp_hybrid_override_pattern
self.pattern_len = len(self.pattern_str)
assert self.pattern_len > 0
# Build pattern-based layers
self.layers = nn.ModuleDict()
for i in range(self.pattern_len):
step_rel_idx = i % self.pattern_len
char = self.pattern_str[step_rel_idx]
is_start_of_step = step_rel_idx == 0
is_end_of_step = step_rel_idx == self.pattern_len - 1
sublayer_quant_config = self._get_mtp_sublayer_quant_config(
model_config, self.layer_idx)
# Create a temporary model_config with the override quant_config
sublayer_model_config = ModelConfig(
pretrained_config=model_config.pretrained_config,
mapping=model_config.mapping,
quant_config=sublayer_quant_config,
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
)
self.layers[str(i)] = NemotronHMTPDecoderLayer(
model_config=sublayer_model_config,
layer_idx=self.layer_idx,
aux_stream_dict=aux_stream_dict,
has_start_projections=is_start_of_step,
has_end_norm=is_end_of_step,
layer_type=char,
)
# Add shared_head for MTP, following DeepseekV3MTP pattern
self.shared_head = DeepseekV3MTPHead(model_config)
def _get_mtp_sublayer_quant_config(
self, model_config: ModelConfig[NemotronHConfig], layer_idx: int):
"""
Get quantization config for MTP sublayer.
The MTP layer in the nvfp4 checkpoint is unquantized. Because the TRTLLM
moe_backend only supports fp8/fp4 quantization, we need to override
the quant_config for the MTP layer.
"""
from tensorrt_llm.models.modeling_utils import QuantConfig
quant_config = model_config.quant_config
# MTP layers are always unquantized, force quant_algo=None
if quant_config is None:
return None
return QuantConfig(
quant_algo=None,
kv_cache_quant_algo=quant_config.kv_cache_quant_algo,
)
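
As the docstring above explains, the MTP sublayers in the NVFP4 checkpoint are unquantized, so their QuantConfig drops quant_algo while preserving the KV-cache setting. A minimal sketch of that override with a stand-in dataclass (not the real tensorrt_llm QuantConfig):

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeQuantConfig:  # stand-in for tensorrt_llm.models.modeling_utils.QuantConfig
    quant_algo: Optional[str] = None
    kv_cache_quant_algo: Optional[str] = None

def mtp_sublayer_quant_config(base: Optional[FakeQuantConfig]) -> Optional[FakeQuantConfig]:
    if base is None:
        return None
    # Keep KV-cache quantization; drop weight/activation quantization for the MTP sublayer.
    return FakeQuantConfig(quant_algo=None, kv_cache_quant_algo=base.kv_cache_quant_algo)

base = FakeQuantConfig(quant_algo="NVFP4", kv_cache_quant_algo="FP8")
assert mtp_sublayer_quant_config(base).quant_algo is None
assert mtp_sublayer_quant_config(base).kv_cache_quant_algo == "FP8"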
def forward(
self,
input_ids: torch.IntTensor,
position_ids: torch.IntTensor,
hidden_states: torch.Tensor,
embed_tokens: Embedding,
attn_metadata: AttentionMetadata,
all_rank_num_tokens: Optional[List[int]] = None,
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> torch.Tensor:
inputs_embeds = embed_tokens(input_ids)
residual = None
for i in range(self.pattern_len):
layer = self.layers[str(i)]
hidden_states, residual = layer(
inputs_embeds=inputs_embeds,
positions=position_ids,
hidden_states=hidden_states,
residual=residual,
attn_metadata=attn_metadata,
)
return hidden_states
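
The MTP module builds one sublayer per character of mtp_hybrid_override_pattern, marking the first with has_start_projections and the last with has_end_norm, and its forward threads (hidden_states, residual) through them in order. A toy sketch of that pattern-driven construction, with a deliberately simplified stand-in sublayer (the real residual handling is fused into the norms):

import torch
import torch.nn as nn

class ToySublayer(nn.Module):
    # Loose stand-in for NemotronHMTPDecoderLayer: norm + mixer, optional final norm.
    def __init__(self, hidden_size, has_start_projections, has_end_norm):
        super().__init__()
        self.has_start_projections = has_start_projections
        self.has_end_norm = has_end_norm
        self.norm = nn.LayerNorm(hidden_size)
        self.mixer = nn.Linear(hidden_size, hidden_size)
        self.final_norm = nn.LayerNorm(hidden_size) if has_end_norm else None

    def forward(self, hidden_states, residual):
        if residual is None:
            residual = hidden_states
        hidden_states = self.mixer(self.norm(hidden_states))
        if self.has_end_norm:
            return self.final_norm(hidden_states + residual), None
        return hidden_states, residual

pattern = "M*-"  # illustrative pattern string; one sublayer per character
layers = nn.ModuleList(
    ToySublayer(64, has_start_projections=(i == 0), has_end_norm=(i == len(pattern) - 1))
    for i in range(len(pattern)))

hidden, residual = torch.randn(2, 64), None
for layer in layers:
    hidden, residual = layer(hidden, residual)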
AutoConfig.register(NemotronHConfig.model_type, NemotronHConfig)

View File

@ -706,6 +706,9 @@ class MTPForCausalLM(nn.Module):
case "exaone_moe":
from .modeling_exaone_moe import ExaoneMoeMTP
mtp_layer = ExaoneMoeMTP
case "nemotron_h":
from .modeling_nemotron_h import NemotronHMTP
mtp_layer = NemotronHMTP
case _:
raise ValueError(
f"Model type {model_type} not supported for MTP")
@ -751,6 +754,12 @@ class MTPDraftModel(nn.Module):
from .modeling_exaone_moe import ExaoneMoeMTP
mtp_layer = ExaoneMoeMTP(model_config, layer_idx, aux_stream_dict)
elif model_type == "nemotron_h":
from .modeling_nemotron_h import NemotronHMTP
mtp_layer = NemotronHMTP(model_config,
layer_idx,
aux_stream_dict,
is_separate_draft_engine=False)
else:
raise ValueError(
f"MTPDraftModel does not support model_type: {model_type}")

File diff suppressed because it is too large

View File

@ -24,8 +24,11 @@ from tensorrt_llm.mapping import Mapping
from ...attention_backend import AttentionMetadata
from ...model_config import ModelConfig
from ...speculative import SpecMetadata
from ..linear import Linear, TensorParallelMode
from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .causal_conv1d_triton import \
causal_conv1d_update as causal_conv1d_update_triton
from .layernorm_gated import RMSNorm as RMSNormGated
from .selective_state_update import selective_state_update
from .ssd_combined import mamba_chunk_scan_combined
@ -82,6 +85,8 @@ class Mamba2Mixer(nn.Module):
self.tp_d_inner = d_inner // tp_size
self.tp_nheads = nheads // tp_size
self.tp_ngroups = n_groups // tp_size
self.num_heads = nheads
self.tp_size = tp_size
self.layer_idx = layer_idx
self.d_conv = d_conv
@ -167,6 +172,7 @@ class Mamba2Mixer(nn.Module):
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_metadata: Mamba2Metadata,
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> torch.Tensor:
@ -175,6 +181,7 @@ class Mamba2Mixer(nn.Module):
num_decodes = attn_metadata.seq_lens.shape[0] - num_prefills
num_prefill_tokens = attn_metadata.num_ctx_tokens
num_decode_tokens = attn_metadata.num_tokens - num_prefill_tokens
num_actual_tokens = attn_metadata.num_tokens
seqlen_split_size = [num_prefill_tokens, num_decode_tokens]
batch_split_size = [num_prefills, num_decodes]
@ -183,10 +190,10 @@ class Mamba2Mixer(nn.Module):
state_indices_p, state_indices_d = torch.split(state_indices,
batch_split_size)
conv_states = attn_metadata.kv_cache_manager.get_conv_states(
self.layer_idx)
ssm_states = attn_metadata.kv_cache_manager.get_ssm_states(
layer_cache = attn_metadata.kv_cache_manager.mamba_layer_cache(
self.layer_idx)
conv_states = layer_cache.conv
ssm_states = layer_cache.temporal
# in_proj
zxbcdt = self.in_proj(hidden_states)
@ -199,7 +206,21 @@ class Mamba2Mixer(nn.Module):
xbc_p, xbc_d = torch.split(xbc, seqlen_split_size, dim=0)
dt_p, dt_d = torch.split(dt, seqlen_split_size, dim=0)
out = []
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
[
zxbcdt.shape[0],
(self.num_heads * self.head_dim) // self.tp_size,
],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decode_tokens],
dim=0,
)
if num_prefills > 0:
@ -239,7 +260,7 @@ class Mamba2Mixer(nn.Module):
has_initial_states[:, None, None, None],
ssm_states[state_indices_p], 0)
y, current_ssm_states = mamba_chunk_scan_combined(
current_ssm_states = mamba_chunk_scan_combined(
x_p,
dt_p,
self.A,
@ -247,30 +268,63 @@ class Mamba2Mixer(nn.Module):
C_p,
chunk_size=self.chunk_size,
D=self.D,
z=z_p,
z=None,
dt_bias=self.dt_bias,
initial_states=initial_states,
chunk_indices=mamba_metadata.chunk_indices,
chunk_offsets=mamba_metadata.chunk_offsets,
dt_softplus=self.delta_softplus,
dt_limit=(0.0, float("inf")),
cu_seqlens=cu_seqlens,
seq_idx=seq_idx,
return_varlen_states=True,
return_final_states=False,
mamba_ssm_cache_dtype=self._mamba_ssm_cache_dtype,
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
self.head_dim),
state_dtype=self._mamba_ssm_cache_dtype,
)
out.append(rearrange(y, "b l h p -> (b l) (h p)"))
# copy new ssm state
ssm_states[state_indices_p] = current_ssm_states
if num_decodes > 0:
xbc_d = causal_conv1d_update(xbc_d,
conv_states,
self.conv1d.weight,
self.conv1d.bias,
activation="silu",
conv_state_indices=state_indices_d)
is_target_verify = attn_metadata.kv_cache_manager.is_speculative(
) and spec_metadata is not None
if is_target_verify:
# TODO: support dynamic speculation; current_draft_len will be added later [TRTLLM-10319]
draft_token_num = spec_metadata.max_draft_len + 1
intermediate_conv_states = layer_cache.intermediate_conv_window
self.intermediate_state_indices = torch.arange(
num_decodes,
dtype=torch.int32,
device=state_indices_d.device)
# Reshape for batch processing
xbc_d_reshaped = xbc_d.view(num_decodes, draft_token_num,
-1).transpose(1, 2)
# TODO: support tree structure [TRTLLM-10320]
xbc_d_processed = causal_conv1d_update_triton(
xbc_d_reshaped,
conv_states,
self.conv1d.weight,
self.conv1d.bias,
activation="silu",
conv_state_indices=state_indices_d[:num_decodes],
intermediate_conv_window=intermediate_conv_states,
intermediate_state_indices=self.intermediate_state_indices,
)
xbc_d = xbc_d_processed.transpose(1, 2).view(
num_decode_tokens, -1)
else:
xbc_d = causal_conv1d_update(xbc_d,
conv_states,
self.conv1d.weight,
self.conv1d.bias,
activation="silu",
conv_state_indices=state_indices_d)
x_d, B_d, C_d = torch.split(
xbc_d,
@ -292,29 +346,64 @@ class Mamba2Mixer(nn.Module):
n=self.d_state).to(dtype=torch.float32)
dt_bias = repeat(self.dt_bias, "h -> h p", p=self.head_dim)
D = repeat(self.D, "h -> h p", p=self.head_dim)
if is_target_verify:
intermediate_ssm_states = layer_cache.intermediate_ssm
selective_state_update(
ssm_states,
x_d.view(
num_decodes,
draft_token_num,
self.num_heads // self.tp_size,
self.head_dim,
),
dt_d.view(
num_decodes,
draft_token_num,
self.num_heads // self.tp_size,
self.head_dim,
),
A,
B_d.view(num_decodes, draft_token_num, self.tp_ngroups, -1),
C_d.view(num_decodes, draft_token_num, self.tp_ngroups, -1),
D,
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_d[:num_decodes],
out=preallocated_ssm_out_d.view(
num_decodes,
draft_token_num,
self.num_heads // self.tp_size,
self.head_dim,
),
disable_state_update=True,
intermediate_states_buffer=intermediate_ssm_states,
cache_steps=draft_token_num,
intermediate_state_indices=self.intermediate_state_indices,
)
y = selective_state_update(
ssm_states,
x_d,
dt_d,
A,
B_d,
C_d,
D,
z=z_d,
dt_bias=dt_bias,
dt_softplus=self.delta_softplus,
state_batch_indices=state_indices_d,
)
else:
out.append(rearrange(y, "b h p -> b (h p)"))
out = torch.cat(out, dim=0)
selective_state_update(
ssm_states,
x_d,
dt_d,
A,
B_d,
C_d,
D,
z=None,
dt_bias=dt_bias,
dt_softplus=self.delta_softplus,
state_batch_indices=state_indices_d,
out=preallocated_ssm_out_d.view(num_decodes, -1,
self.head_dim),
)
# norm
out = self.norm(out)
hidden_states = self.norm(preallocated_ssm_out, z[:num_actual_tokens])
# out_proj
out = self.out_proj(out)
out = self.out_proj(hidden_states)
return out
return out[:num_actual_tokens]
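
In the is_target_verify branch above, each decode request carries max_draft_len + 1 tokens, so the flat decode inputs are reshaped to (num_decodes, draft_token_num, ...) for the multi-token conv/SSM updates and flattened back afterwards. A standalone sketch of just that reshape bookkeeping, with invented sizes:

import torch

num_decodes, max_draft_len, channels = 3, 2, 16
draft_token_num = max_draft_len + 1          # drafted tokens plus the newly sampled token
num_decode_tokens = num_decodes * draft_token_num

xbc_d = torch.randn(num_decode_tokens, channels)                     # flat per-token layout
xbc_batched = xbc_d.view(num_decodes, draft_token_num, channels).transpose(1, 2)
# -> (num_decodes, channels, draft_token_num), the layout the multi-token conv update consumes

# ... causal_conv1d_update_triton / selective_state_update would run here ...

xbc_flat_again = xbc_batched.transpose(1, 2).reshape(num_decode_tokens, channels)
assert torch.equal(xbc_flat_again, xbc_d)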

View File

@ -1,7 +1,4 @@
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
# Copyright (c) 2024, Tri Dao, Albert Gu.
#
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -15,6 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/mamba/ops/mamba_ssm.py
# SPDX-FileCopyrightText: Copyright contributors to the sglang project
#
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
import torch
import triton
@ -35,7 +38,19 @@ from .softplus import softplus
})
@triton.heuristics(
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
@triton.heuristics({
"CACHE_INTERMEDIATE_STATES":
lambda args: args["intermediate_states_buffer"] is not None
})
@triton.heuristics({
"HAS_EAGLE_TREE_CUSTOM_ATTN_MASK":
lambda args: args["retrieve_parent_token_ptr"] is not None
})
@triton.heuristics({
"HAS_INTERMEDIATE_STATE_INDICES":
lambda args: args["intermediate_state_indices_ptr"] is not None
})
@triton.jit(do_not_specialize=["T"])
def _selective_scan_update_kernel(
# Pointers to matrices
state_ptr,
@ -50,8 +65,13 @@ def _selective_scan_update_kernel(
out_ptr,
state_batch_indices_ptr,
pad_slot_id,
intermediate_states_buffer,
cache_steps,
retrieve_parent_token_ptr,
intermediate_state_indices_ptr,
# Matrix dimensions
batch,
T,
nheads,
dim,
dstate,
@ -62,9 +82,11 @@ def _selective_scan_update_kernel(
stride_state_dim,
stride_state_dstate,
stride_x_batch,
stride_x_T,
stride_x_head,
stride_x_dim,
stride_dt_batch,
stride_dt_T,
stride_dt_head,
stride_dt_dim,
stride_dt_bias_head,
@ -73,19 +95,25 @@ def _selective_scan_update_kernel(
stride_A_dim,
stride_A_dstate,
stride_B_batch,
stride_B_T,
stride_B_group,
stride_B_dstate,
stride_C_batch,
stride_C_T,
stride_C_group,
stride_C_dstate,
stride_D_head,
stride_D_dim,
stride_z_batch,
stride_z_T,
stride_z_head,
stride_z_dim,
stride_out_batch,
stride_out_T,
stride_out_head,
stride_out_dim,
stride_retrieve_parent_token_batch,
stride_retrieve_parent_token_T,
# Meta-parameters
DT_SOFTPLUS: tl.constexpr,
TIE_HDIM: tl.constexpr,
@ -94,6 +122,10 @@ def _selective_scan_update_kernel(
HAS_D: tl.constexpr,
HAS_Z: tl.constexpr,
HAS_STATE_BATCH_INDICES: tl.constexpr,
DISABLE_STATE_UPDATE: tl.constexpr,
CACHE_INTERMEDIATE_STATES: tl.constexpr,
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr,
HAS_INTERMEDIATE_STATE_INDICES: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
@ -105,7 +137,7 @@ def _selective_scan_update_kernel(
# is the same as the batch id.
if HAS_STATE_BATCH_INDICES:
state_batch_indices_ptr += pid_b
state_batch_idx = tl.load(state_batch_indices_ptr)
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
else:
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
@ -127,91 +159,153 @@ def _selective_scan_update_kernel(
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
offs_n[None, :] * stride_state_dstate)
x_ptrs = x_ptr + offs_m * stride_x_dim
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= state_batch_idx != pad_slot_id
state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
if HAS_D:
D_ptr += pid_h * stride_D_head
A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim +
offs_n[None, :] * stride_A_dstate)
B_ptrs = B_ptr + offs_n * stride_B_dstate
C_ptrs = C_ptr + offs_n * stride_C_dstate
if HAS_D:
D_ptrs = D_ptr + offs_m * stride_D_dim
if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
state = tl.load(state_ptrs, mask=mask, other=0.0)
A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[
None, :] * stride_A_dstate
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
if DT_SOFTPLUS:
dt = tl.where(dt <= 20.0, softplus(dt), dt)
A = tl.load(A_ptrs,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0).to(tl.float32)
dA = tl.exp(A * dt[:, None])
else:
dt = tl.load(dt_ptr).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptr).to(tl.float32)
if DT_SOFTPLUS:
dt = tl.where(dt <= 20.0, softplus(dt), dt)
A = tl.load(A_ptr).to(tl.float32)
dA = tl.exp(A * dt) # scalar, not a matrix
cache_idx = -1
if CACHE_INTERMEDIATE_STATES:
if HAS_INTERMEDIATE_STATE_INDICES:
intermediate_state_idx = tl.load(intermediate_state_indices_ptr +
pid_b).to(tl.int64)
cache_idx = intermediate_state_idx
elif HAS_STATE_BATCH_INDICES:
cache_idx = state_batch_idx
else:
cache_idx = pid_b
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
if HAS_D:
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_Z:
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
current_step_idx = 0
for _ in range(T):
if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
if current_step_idx != 0 and cache_idx >= 0:
parent_ptr = (retrieve_parent_token_ptr +
pid_b * stride_retrieve_parent_token_batch +
current_step_idx * stride_retrieve_parent_token_T)
parent_step_idx = tl.load(parent_ptr).to(tl.int32)
if not TIE_HDIM:
dB = B[None, :] * dt[:, None]
else:
dB = B * dt # vector of size (dstate,)
state = state * dA + dB * x[:, None]
if parent_step_idx >= 0 and parent_step_idx < T:
step_offset = parent_step_idx * nheads * dim * dstate
cache_ptr = (
intermediate_states_buffer +
cache_idx * cache_steps * nheads * dim * dstate +
step_offset + pid_h * dim * dstate +
offs_m[:, None] * dstate + offs_n[None, :])
state = tl.load(cache_ptr, mask=mask,
other=0.0).to(tl.float32)
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
tl.store(state_ptrs, state, mask=mask)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
if HAS_Z:
out *= z * tl.sigmoid(z)
tl.store(out_ptrs, out, mask=offs_m < dim)
x_ptrs = x_ptr + offs_m * stride_x_dim
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
B_ptrs = B_ptr + offs_n * stride_B_dstate
C_ptrs = C_ptr + offs_n * stride_C_dstate
if HAS_Z:
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(
A_ptrs,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0,
).to(tl.float32)
dA = tl.exp(A * dt[:, None])
else:
dt = tl.load(dt_ptr).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptr).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(A_ptr).to(tl.float32)
dA = tl.exp(A * dt) # scalar, not a matrix
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
if HAS_D:
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_Z:
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
state = state * dA + dB * x[:, None]
if CACHE_INTERMEDIATE_STATES:
if HAS_STATE_BATCH_INDICES:
if state_batch_idx != pad_slot_id:
cache_ptr_base = (
intermediate_states_buffer +
cache_idx * cache_steps * nheads * dim * dstate +
current_step_idx * nheads * dim * dstate +
pid_h * dim * dstate)
cache_ptrs = cache_ptr_base + (offs_m[:, None] * dstate +
offs_n[None, :])
tl.store(cache_ptrs,
state.to(cache_ptrs.dtype.element_ty),
mask=mask)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
if HAS_Z:
out *= z * tl.sigmoid(z)
tl.store(out_ptrs, out, mask=offs_m < dim)
current_step_idx += 1
x_ptr += stride_x_T
dt_ptr += stride_dt_T
B_ptr += stride_B_T
C_ptr += stride_C_T
out_ptr += stride_out_T
if HAS_Z:
z_ptr += stride_z_T
if not DISABLE_STATE_UPDATE:
tl.store(state_ptrs, state.to(state_ptrs.dtype.element_ty), mask=mask)
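
A hedged PyTorch reference for what the extended kernel computes for one (sequence, head) program when T > 1: per step, the state advances as state * exp(A*dt) + (dt*B) * x, the output is C·state plus the D skip and optional SiLU(z) gate, intermediate states may be cached per step, and the final write-back can be skipped. This mirrors the kernel's math only at a high level; tie_hdim, padding slots, and the EAGLE-tree parent lookup are omitted.

import torch
import torch.nn.functional as F

def selective_update_reference(state, x, dt, A, B, C, D=None, z=None,
                               dt_bias=None, dt_softplus=True,
                               cache_intermediate=False, write_back=True):
    # state: (dim, dstate); x, dt: (T, dim); A: (dim, dstate); B, C: (T, dstate); D, dt_bias: (dim,)
    T = x.shape[0]
    outs, cached = [], []
    state = state.clone().float()
    for t in range(T):
        dt_t = dt[t] + (dt_bias if dt_bias is not None else 0.0)
        if dt_softplus:
            dt_t = F.softplus(dt_t)
        dA = torch.exp(A * dt_t[:, None])            # (dim, dstate)
        dB = B[t][None, :] * dt_t[:, None]           # (dim, dstate)
        state = state * dA + dB * x[t][:, None]
        if cache_intermediate:
            cached.append(state.clone())             # analogue of intermediate_states_buffer
        out = (state * C[t][None, :]).sum(-1)        # (dim,)
        if D is not None:
            out = out + x[t] * D
        if z is not None:
            out = out * (z[t] * torch.sigmoid(z[t])) # SiLU gate
        outs.append(out)
    final_state = state if write_back else None      # analogue of DISABLE_STATE_UPDATE
    return torch.stack(outs), (torch.stack(cached) if cached else None), final_state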
def selective_state_update(state,
x,
dt,
A,
B,
C,
D=None,
z=None,
dt_bias=None,
dt_softplus=False,
state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID):
def selective_state_update(
state,
x,
dt,
A,
B,
C,
D=None,
z=None,
dt_bias=None,
dt_softplus=False,
state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID,
out=None,
disable_state_update=False,
intermediate_states_buffer=None,
cache_steps=None,
retrieve_parent_token=None,
intermediate_state_indices=None,
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
x: (batch, dim) or (batch, nheads, dim)
x: (batch, dim) or (batch, nheads, dim) for single-token or (batch, T, nheads, dim) for multi-token
dt: (batch, dim) or (batch, nheads, dim)
A: (dim, dstate) or (nheads, dim, dstate)
B: (batch, dstate) or (batch, ngroups, dstate)
B: (batch, dstate) or (batch, ngroups, dstate) for single-token or (batch, T, ngroups, dstate) for multi-token
C: (batch, dstate) or (batch, ngroups, dstate)
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
@ -222,38 +316,58 @@ def selective_state_update(state,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
Return:
out: (batch, dim) or (batch, nheads, dim)
out: Preallocated SSM output tensor; assumed to have the same shape as x.
Updated in place.
disable_state_update: If True, do not write the final state back to `state` (used during the speculative-decoding verify pass)
intermediate_states_buffer: Buffer to cache intermediate states
cache_steps: Total number of steps in the buffer
retrieve_parent_token: (batch, T) tensor of parent token indices for EAGLE tree attention
intermediate_state_indices: (batch,) tensor of indices for intermediate_states_buffer operations.
If provided, uses these indices instead of state_batch_indices for the buffer.
"""
has_heads = state.dim() > 3
if state.dim() == 3:
state = state.unsqueeze(1)
if x.dim() == 2:
x = x.unsqueeze(1)
if x.dim() == 3:
x = x.unsqueeze(1)
if dt.dim() == 2:
dt = dt.unsqueeze(1)
if dt.dim() == 3:
dt = dt.unsqueeze(1)
if A.dim() == 2:
A = A.unsqueeze(0)
if B.dim() == 2:
B = B.unsqueeze(1)
if B.dim() == 3:
B = B.unsqueeze(1)
if C.dim() == 2:
C = C.unsqueeze(1)
if C.dim() == 3:
C = C.unsqueeze(1)
if D is not None and D.dim() == 1:
D = D.unsqueeze(0)
if z is not None and z.dim() == 2:
z = z.unsqueeze(1)
if z is not None:
if z.dim() == 2:
z = z.unsqueeze(1)
if z.dim() == 3:
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
if out.dim() == 2:
out = out.unsqueeze(1)
if out.dim() == 3:
out = out.unsqueeze(1)
_, nheads, dim, dstate = state.shape
batch = x.shape[0]
batch, T, _, _ = x.shape
assert x.shape == (batch, nheads, dim)
assert x.shape == (batch, T, nheads, dim)
assert dt.shape == x.shape
assert A.shape == (nheads, dim, dstate)
ngroups = B.shape[1]
ngroups = B.shape[2]
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
assert B.shape == (batch, ngroups, dstate)
assert B.shape == (batch, T, ngroups, dstate)
assert C.shape == B.shape
if D is not None:
assert D.shape == (nheads, dim)
@ -263,10 +377,11 @@ def selective_state_update(state,
assert dt_bias.shape == (nheads, dim)
if state_batch_indices is not None:
assert state_batch_indices.shape == (batch, )
out = torch.empty_like(x)
assert out.shape == x.shape
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
z_strides = (z.stride(0), z.stride(1),
z.stride(2)) if z is not None else (0, 0, 0)
z_strides = ((z.stride(0), z.stride(1), z.stride(2),
z.stride(3)) if z is not None else (0, 0, 0, 0))
# We don't want autotune since it will overwrite the state
# We instead tune by hand.
BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else
@ -275,6 +390,12 @@ def selective_state_update(state,
((4, 4) if dstate <= 128 else ((4, 8))))))
tie_hdim = (A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0
and dt_bias.stride(-1) == 0)
retrieve_parent_token_strides = ((retrieve_parent_token.stride(0),
retrieve_parent_token.stride(1))
if retrieve_parent_token is not None else
(0, 0))
with torch.cuda.device(x.device.index):
_selective_scan_update_kernel[grid](
state,
@ -289,7 +410,12 @@ def selective_state_update(state,
out,
state_batch_indices,
pad_slot_id,
intermediate_states_buffer,
cache_steps if cache_steps is not None else 0,
retrieve_parent_token,
intermediate_state_indices,
batch,
T,
nheads,
dim,
dstate,
@ -301,9 +427,11 @@ def selective_state_update(state,
x.stride(0),
x.stride(1),
x.stride(2),
x.stride(3),
dt.stride(0),
dt.stride(1),
dt.stride(2),
dt.stride(3),
*(dt_bias.stride(0),
dt_bias.stride(1)) if dt_bias is not None else 0,
A.stride(0),
@ -312,21 +440,25 @@ def selective_state_update(state,
B.stride(0),
B.stride(1),
B.stride(2),
B.stride(3),
C.stride(0),
C.stride(1),
C.stride(2),
C.stride(3),
*(D.stride(0), D.stride(1)) if D is not None else 0,
z_strides[0],
z_strides[1],
z_strides[2],
z_strides[3],
out.stride(0),
out.stride(1),
out.stride(2),
out.stride(3),
retrieve_parent_token_strides[0],
retrieve_parent_token_strides[1],
dt_softplus,
tie_hdim,
BLOCK_SIZE_M,
DISABLE_STATE_UPDATE=disable_state_update,
num_warps=num_warps,
)
if not has_heads:
out = out.squeeze(1)
return out

View File

@ -234,6 +234,7 @@ def _chunk_scan_fwd_kernel(
# M-block offsets and prev states
# - logic in next block may override these if there is an active offset
offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
prev_states_ptr = (states_ptr + pid_b * stride_states_batch +
c_idx * stride_states_chunk + pid_h * stride_states_head)
prev_states_hdim = stride_states_hdim
@ -269,11 +270,12 @@ def _chunk_scan_fwd_kernel(
):
# - replace prev_states_ptr with init_states
prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head
prev_states_ptr = (initstates_ptr +
seq_idx_m * stride_init_states_batch +
pid_h * stride_init_states_head)
prev_states_hdim = stride_init_states_hdim # override strides
prev_states_dstate = stride_init_states_dstate
offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
mask=offs_m < chunk_size,
@ -289,7 +291,7 @@ def _chunk_scan_fwd_kernel(
c_idx_n = tl.load(
chunk_indices_ptr + (pid_c + 1),
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
other=-1 # to trigger different chunk
other=-1, # to trigger different chunk
)
# - there are things to consider
@ -304,9 +306,11 @@ def _chunk_scan_fwd_kernel(
if (c_idx == c_idx_n) or c_off > 0:
# get the next offset
c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1),
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
other=chunk_size)
c_off_n = tl.load(
chunk_offsets_ptr + (pid_c + 1),
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
other=chunk_size,
)
# in this case, adjust down the chunk_size_limit
if c_idx == c_idx_n:
@ -319,8 +323,9 @@ def _chunk_scan_fwd_kernel(
# i.e. the same for all blocks)
dA_cs_m_boundary = tl.load(
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
mask=(((c_off - 1) > -1) and (c_off < chunk_size)),
other=0.0).to(tl.float32)
mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
other=0.0,
).to(tl.float32)
if HAS_SEQ_IDX:
# - handle seq idx when HAS_INITSTATES==False
@ -416,8 +421,7 @@ def _chunk_scan_fwd_kernel(
other=0.0).to(tl.float32)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
# cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
cb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_k[None, :]), 0.0))
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k,
other=0.0).to(tl.float32)
cb *= dt_k
@ -494,18 +498,21 @@ def _chunk_scan_fwd_kernel(
)
def _chunk_scan_fwd(cb,
x,
dt,
dA_cumsum,
C,
states,
D=None,
z=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
initial_states=None):
def _chunk_scan_fwd(
cb,
x,
dt,
dA_cumsum,
C,
states,
D=None,
z=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
initial_states=None,
out=None,
):
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
_, _, ngroups, dstate = C.shape
@ -527,34 +534,17 @@ def _chunk_scan_fwd(cb,
# with initial states, we need to take care of how
# seq_idx crosses the boundaries
assert batch == 1, "chunk scan only supports initial states with batch 1"
if initial_states.shape[0] == 1:
# in this case there is no point in using initial states
initial_states = None
else:
assert chunk_indices is not None and chunk_offsets is not None, \
(
"chunk_indices and chunk_offsets should have been set"
)
assert (chunk_indices is not None and chunk_offsets is not None
), "chunk_indices and chunk_offsets should have been set"
else:
chunk_indices, chunk_offsets = None, None
else:
chunk_indices, chunk_offsets = None, None
# Allocates output.
out = torch.empty(batch,
seqlen,
nheads,
headdim,
device=x.device,
dtype=x.dtype)
assert out.shape == x.shape
if z is not None:
out_x = torch.empty(batch,
seqlen,
nheads,
headdim,
device=x.device,
dtype=x.dtype)
out_x = torch.empty_like(x)
assert out_x.stride() == out.stride()
else:
out_x = None
@ -625,10 +615,12 @@ def _chunk_scan_fwd(cb,
states.stride(2),
states.stride(3),
states.stride(4),
*((initial_states.stride(0), initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3)) if initial_states is not None else
(0, 0, 0, 0)),
*((
initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3),
) if initial_states is not None else (0, 0, 0, 0)),
D.stride(0) if D is not None else 0,
True,
D is not None,
@ -639,4 +631,4 @@ def _chunk_scan_fwd(cb,
IS_TRITON_22=TRITON_22,
HAS_INITSTATES=initial_states is not None,
)
return out, out_x
return out_x

View File

@ -16,8 +16,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from einops import rearrange
@ -28,6 +26,10 @@ from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
from .ssd_state_passing import _state_passing_fwd
def is_int_pow_2(n):
return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
def _mamba_chunk_scan_combined_fwd(
x,
dt,
@ -45,13 +47,14 @@ def _mamba_chunk_scan_combined_fwd(
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
mamba_ssm_cache_dtype=None,
state_dtype=None,
out=None,
):
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
batch, seqlen, nheads, headdim = x.shape
_, _, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
assert x.shape == (batch, seqlen, nheads, headdim)
assert dt.shape == (batch, seqlen, nheads)
assert A.shape == (nheads, )
assert C.shape == B.shape
@ -77,8 +80,12 @@ def _mamba_chunk_scan_combined_fwd(
if cu_seqlens is None:
assert initial_states.shape == (batch, nheads, headdim, dstate)
else:
assert initial_states.shape == (len(cu_seqlens) - 1, nheads,
headdim, dstate)
assert initial_states.shape == (
len(cu_seqlens) - 1,
nheads,
headdim,
dstate,
)
# This function executes 5 sub-functions for computing mamba
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
@ -125,13 +132,12 @@ def _mamba_chunk_scan_combined_fwd(
if initial_states is not None else None),
seq_idx=seq_idx,
chunk_size=chunk_size,
out_dtype=mamba_ssm_cache_dtype or C.dtype,
out_dtype=state_dtype if state_dtype is not None else C.dtype,
is_cont_batched=cu_seqlens is not None,
chunk_offsets=chunk_offsets)
states, final_states = [
rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states]
]
chunk_offsets=chunk_offsets,
)
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states])
# 4. Compute batched matrix multiply for C_j^T B_i terms
CB = _bmm_chunk_fwd(C,
@ -150,20 +156,23 @@ def _mamba_chunk_scan_combined_fwd(
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
# a seq_idx change, in which case we take states information from
# init_states.
out, out_x = _chunk_scan_fwd(CB,
x,
dt,
dA_cumsum,
C,
states,
D=D,
z=z,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
initial_states=initial_states)
out_x = _chunk_scan_fwd(
CB,
x,
dt,
dA_cumsum,
C,
states,
D=D,
z=z,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
initial_states=initial_states,
out=out,
)
if cu_seqlens is None:
return out, out_x, dt, dA_cumsum, states, final_states
return out_x, dt, dA_cumsum, states, final_states
else:
assert (
batch == 1
@ -177,29 +186,31 @@ def _mamba_chunk_scan_combined_fwd(
states.squeeze(0),
initial_states=initial_states,
)
return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
return out_x, dt, dA_cumsum, states, final_states, varlen_states
def mamba_chunk_scan_combined(
x,
dt,
A,
B,
C,
chunk_size,
D=None,
z=None,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
return_final_states=False,
return_varlen_states=False,
mamba_ssm_cache_dtype: Optional[torch.dtype] = None):
x,
dt,
A,
B,
C,
chunk_size,
D=None,
z=None,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
out=None,
return_final_states=False,
return_varlen_states=False,
state_dtype=None,
):
"""
Argument:
x: (batch, seqlen, nheads, headdim)
@ -215,38 +226,42 @@ def mamba_chunk_scan_combined(
seq_idx: (batch, seqlen)
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
dt_softplus: Whether to apply softplus to dt
mamba_ssm_cache_dtype: torch.dtype, default to None
Return:
out: (batch, seqlen, nheads, headdim)
out: Preallocated output tensor
state_dtype: The data type of the ssm state
"""
if not return_varlen_states:
cu_seqlens = None
else:
assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
D=D,
z=z,
dt_bias=dt_bias,
initial_states=initial_states,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
mamba_ssm_cache_dtype=mamba_ssm_cache_dtype)
assert (cu_seqlens is not None
), "cu_seqlens must be provided if return_varlen_states is True"
out_x, dt_out, dA_cumsum, states, final_states, *rest = (
_mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
D=D,
z=z,
dt_bias=dt_bias,
initial_states=initial_states,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
out=out,
state_dtype=state_dtype,
))
if not return_varlen_states:
return out if not return_final_states else (out, final_states)
if not return_final_states:
return
else:
return final_states
else:
varlen_states = rest[0]
return (out,
varlen_states) if not return_final_states else (out,
final_states,
varlen_states)
return ((varlen_states) if not return_final_states else
(final_states, varlen_states))
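
With this change mamba_chunk_scan_combined no longer returns the scan output; callers pass a preallocated out tensor that is filled in place and read only the requested state tensors from the return value. A hedged usage sketch of the new convention (shapes invented, CUDA assumed, and the absolute import path is a guess based on the relative imports in this diff):

import torch
# Assumed import path; adjust to wherever ssd_combined lives in your checkout.
from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined

batch, seqlen, nheads, headdim, ngroups, dstate, chunk_size = 1, 32, 4, 16, 1, 8, 8
dev, dtype = "cuda", torch.float16

x = torch.randn(batch, seqlen, nheads, headdim, device=dev, dtype=dtype)
dt = torch.rand(batch, seqlen, nheads, device=dev, dtype=dtype)
A = -torch.rand(nheads, device=dev, dtype=torch.float32)
B = torch.randn(batch, seqlen, ngroups, dstate, device=dev, dtype=dtype)
C = torch.randn(batch, seqlen, ngroups, dstate, device=dev, dtype=dtype)

out = torch.empty_like(x)  # preallocated; the kernel writes the scan output here
cu_seqlens = torch.tensor([0, seqlen], dtype=torch.int32, device=dev)
seq_idx = torch.zeros(batch, seqlen, dtype=torch.int32, device=dev)

varlen_states = mamba_chunk_scan_combined(
    x, dt, A, B, C, chunk_size,
    cu_seqlens=cu_seqlens, seq_idx=seq_idx,
    return_final_states=False, return_varlen_states=True,
    out=out, state_dtype=torch.float32,
)
# `out` now holds the scan output; `varlen_states` holds per-sequence final SSM states.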

View File

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import torch
@ -22,11 +23,45 @@ from tensorrt_llm._torch.pyexecutor.resource_manager import (
BaseResourceManager, CacheTypeCpp, DataType, KVCacheManager, get_pp_layers)
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
if TYPE_CHECKING:
from tensorrt_llm._torch.attention_backend.interface import \
AttentionMetadata
GB = 1 << 30
def get_tensor_size_bytes(tensor):
"""Calculate tensor size in bytes."""
if isinstance(tensor, torch.Tensor):
return tensor.element_size() * tensor.nelement()
elif isinstance(tensor, list):
return sum(get_tensor_size_bytes(t) for t in tensor)
return 0
class MambaCacheManager(BaseResourceManager):
@dataclass(frozen=True, kw_only=True)
class State:
"""Base state container for Mamba cache."""
conv: torch.Tensor
temporal: torch.Tensor
def at_layer_idx(self, layer: int):
kwargs = {}
for k, v in vars(self).items():
kwargs[k] = v[layer]
return type(self)(**kwargs)
@dataclass(frozen=True, kw_only=True)
class SpeculativeState(State):
"""Speculative state with intermediate states for draft tokens."""
intermediate_ssm: torch.Tensor
intermediate_conv_window: torch.Tensor
def __init__(
self,
d_state: int,
@ -36,13 +71,17 @@ class MambaCacheManager(BaseResourceManager):
head_dim: int,
num_layers: int,
max_batch_size: int,
spec_state_size: int,
mapping: Mapping,
dtype: torch.dtype,
ssm_cache_dtype: torch.dtype,
layer_mask: Optional[List[bool]] = None,
speculative_num_draft_tokens: Optional[int] = None,
) -> None:
self.mamba_ssm_cache_dtype = ssm_cache_dtype
self.speculative_num_draft_tokens = speculative_num_draft_tokens
self.spec_state_size = spec_state_size
# get tp size
tp_size = mapping.tp_size if not mapping.enable_attention_dp else 1
@ -74,31 +113,69 @@ class MambaCacheManager(BaseResourceManager):
for offset, idx in enumerate(pp_layers)
}
# mamba conv states
self.conv_states = torch.empty(
size=[
num_local_layers,
max_batch_size,
conv_dim,
d_conv - 1,
],
conv_state_shape = (conv_dim, d_conv - 1)
ssm_state_shape = (nheads, head_dim, d_state)
# create mamba conv and ssm states
conv_states = torch.empty(
size=(num_local_layers, max_batch_size) + conv_state_shape,
dtype=dtype,
device=device,
)
# mamba ssm states
self.ssm_states = torch.empty(
size=[
num_local_layers,
max_batch_size,
nheads,
head_dim,
d_state,
],
ssm_states = torch.empty(
size=(num_local_layers, max_batch_size) + ssm_state_shape,
dtype=self.mamba_ssm_cache_dtype,
device=device,
)
# create state container
if speculative_num_draft_tokens is not None:
# Cache intermediate SSM states per draft token (including the newly sampled token) during the target-model verification phase
intermediate_ssm_states = torch.zeros(
size=(num_local_layers, self.spec_state_size,
speculative_num_draft_tokens + 1) + ssm_state_shape,
dtype=self.mamba_ssm_cache_dtype,
device=device,
)
# Cache intermediate conv windows per draft token (including the newly sampled token) during the target-model verification phase
intermediate_conv_window_cache = torch.zeros(
size=(num_local_layers, self.spec_state_size,
speculative_num_draft_tokens + 1) + conv_state_shape,
dtype=dtype,
device=device,
)
self.mamba_cache = self.SpeculativeState(
conv=conv_states,
temporal=ssm_states,
intermediate_ssm=intermediate_ssm_states,
intermediate_conv_window=intermediate_conv_window_cache,
)
logger.info(
f"Mamba Cache is allocated. "
f"max_mamba_cache_size: {max_batch_size}, "
f"conv_state size: {get_tensor_size_bytes(conv_states) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(ssm_states) / GB:.2f}GB, "
f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_states) / GB:.2f}GB, "
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB"
)
else:
self.mamba_cache = self.State(
conv=conv_states,
temporal=ssm_states,
)
logger.info(
f"Mamba Cache is allocated. "
f"max_mamba_cache_size: {max_batch_size}, "
f"conv_state size: {get_tensor_size_bytes(conv_states) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(ssm_states) / GB:.2f}GB"
)
# mamba cache available blocks
self.mamba_cache_free_blocks = [i for i in range(max_batch_size)]
@ -111,6 +188,10 @@ class MambaCacheManager(BaseResourceManager):
dtype=torch.int32)
# save mamba state indices for requests
self.state_indices_list: List[int] = []
# save intermediate state indices for requests
self.intermediate_state_indices = torch.arange(max_batch_size,
dtype=torch.int32,
device=device)
def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
self.state_indices_list.clear()
@ -184,22 +265,85 @@ class MambaCacheManager(BaseResourceManager):
def get_conv_states(self, layer_idx: int) -> torch.Tensor:
layer_offset = self.mamba_layer_offsets[layer_idx]
return self.conv_states[layer_offset]
return self.mamba_cache.at_layer_idx(layer_offset).conv
def get_ssm_states(self, layer_idx: int) -> torch.Tensor:
layer_offset = self.mamba_layer_offsets[layer_idx]
return self.ssm_states[layer_offset]
return self.mamba_cache.at_layer_idx(layer_offset).temporal
def get_mamba_ssm_cache_dtype(self) -> torch.dtype:
return self.mamba_ssm_cache_dtype
def get_intermediate_ssm_states(self,
layer_idx: int) -> Optional[torch.Tensor]:
if not isinstance(self.mamba_cache, self.SpeculativeState):
return None
layer_offset = self.mamba_layer_offsets[layer_idx]
return self.mamba_cache.at_layer_idx(layer_offset).intermediate_ssm
def get_intermediate_conv_states(self,
layer_idx: int) -> Optional[torch.Tensor]:
"""Get intermediate conv states for speculative decoding."""
if not isinstance(self.mamba_cache, self.SpeculativeState):
return None
layer_offset = self.mamba_layer_offsets[layer_idx]
return self.mamba_cache.at_layer_idx(
layer_offset).intermediate_conv_window
def is_speculative(self) -> bool:
return isinstance(self.mamba_cache, self.SpeculativeState)
def mamba_layer_cache(self,
layer_idx: int) -> Union[State, SpeculativeState]:
layer_offset = self.mamba_layer_offsets[layer_idx]
return self.mamba_cache.at_layer_idx(layer_offset)
def shutdown(self):
# release tensor memory, keeping python references as tensors
self.conv_states = torch.tensor([])
self.ssm_states = torch.tensor([])
"""Release tensor memory."""
# Clear state indices
self.state_indices = torch.tensor([])
# Clear mamba cache states
if isinstance(self.mamba_cache, self.SpeculativeState):
self.mamba_cache = self.SpeculativeState(
conv=torch.tensor([]),
temporal=torch.tensor([]),
intermediate_ssm=torch.tensor([]),
intermediate_conv_window=torch.tensor([]),
)
else:
self.mamba_cache = self.State(
conv=torch.tensor([]),
temporal=torch.tensor([]),
)
torch.cuda.empty_cache()
@torch.compile(options={"max-autotune": True})
def update_mamba_states(self, attn_metadata: "AttentionMetadata",
num_accepted_tokens: torch.Tensor):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
num_accepted_draft_tokens = num_accepted_tokens[
num_contexts:num_contexts + num_gens] - 1
state_indices_d = self.state_indices[num_contexts:num_contexts +
num_gens]
conv_states = self.mamba_cache.conv
ssm_states = self.mamba_cache.temporal
intermediate_state_cache = self.mamba_cache.intermediate_ssm
intermediate_conv_window_cache = self.mamba_cache.intermediate_conv_window
src_state_indices = self.intermediate_state_indices[:num_gens]
accepted_ssm_state = intermediate_state_cache[:, src_state_indices,
num_accepted_draft_tokens]
ssm_states[:, state_indices_d, :] = accepted_ssm_state
accepted_conv_state = intermediate_conv_window_cache[:,
src_state_indices,
num_accepted_draft_tokens]
conv_states[:, state_indices_d, :] = accepted_conv_state
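
A small sketch of the gather update_mamba_states performs: for each generation request, the verifier accepted num_accepted_tokens[i] tokens, so the per-step intermediate state cached at step num_accepted_tokens[i] - 1 becomes the new committed conv/SSM state for that request's cache slot. Plain tensors with invented sizes stand in for the cache manager's buffers.

import torch

num_layers, max_slots, num_gens, draft_plus_one = 2, 8, 3, 4
state_shape = (5, 6)  # stand-in for (nheads, head_dim, d_state)

ssm_states = torch.zeros(num_layers, max_slots, *state_shape)
intermediate_ssm = torch.randn(num_layers, max_slots, draft_plus_one, *state_shape)

slot_of_request = torch.tensor([4, 1, 7])      # committed cache slot per generation request
intermediate_idx = torch.arange(num_gens)      # row in the intermediate buffer per request
num_accepted_tokens = torch.tensor([1, 3, 2])  # counts include the newly sampled token
accepted_step = num_accepted_tokens - 1        # index of the last accepted per-step state

# Advanced indexing picks one cached state per (layer, request) pair.
accepted = intermediate_ssm[:, intermediate_idx, accepted_step]  # (num_layers, num_gens, 5, 6)
ssm_states[:, slot_of_request] = accepted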
class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
@ -249,10 +393,13 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
mamba_head_dim,
mamba_num_layers,
max_batch_size,
max_batch_size,
mapping,
mamba_cache_dtype,
mamba_ssm_cache_dtype,
mamba_layer_mask,
speculative_num_draft_tokens=spec_config.max_draft_len
if spec_config is not None else None,
)
# initialize kv cache manager
@ -285,3 +432,15 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
def shutdown(self):
MambaCacheManager.shutdown(self)
KVCacheManager.shutdown(self)
def update_resources(self,
scheduled_batch: ScheduledRequests,
attn_metadata: "AttentionMetadata" = None,
kv_cache_dtype_byte_size: float = None):
KVCacheManager.update_resources(self, scheduled_batch, attn_metadata,
kv_cache_dtype_byte_size)
def update_mamba_states(self, attn_metadata: "AttentionMetadata",
num_accepted_tokens: torch.Tensor):
MambaCacheManager.update_mamba_states(self, attn_metadata,
num_accepted_tokens)

View File

@ -1469,10 +1469,12 @@ class ResourceManager:
resource_manager.prepare_resources(scheduled_batch)
@nvtx_range("update_resources")
def update_resources(self,
scheduled_batch: ScheduledRequests,
attn_metadata: Optional["AttentionMetadata"] = None,
kv_cache_dtype_byte_size: Optional[float] = None):
def update_resources(
self,
scheduled_batch: ScheduledRequests,
attn_metadata: Optional["AttentionMetadata"] = None,
kv_cache_dtype_byte_size: Optional[float] = None,
):
for _, resource_manager in self.resource_managers.items():
if hasattr(resource_manager, "update_resources"):
if isinstance(resource_manager, KVCacheManager):

View File

@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, List, Optional
import torch
import torch.nn.functional as F
from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \
MambaHybridCacheManager
from tensorrt_llm.mapping import Mapping
from ..attention_backend import AttentionMetadata
@ -1132,6 +1134,7 @@ class MTPEagleWorker(MTPWorker):
super().__init__(spec_config, model_config)
self.model_config = model_config
self.mtp_num_modules = spec_config.num_nextn_predict_layers
self._is_mamba_hybrid_cache = None
@torch.compile(options={"max-autotune": True})
def update_draft_tokens(self, next_draft_tokens, new_draft_token,
@ -1164,6 +1167,14 @@ class MTPEagleWorker(MTPWorker):
accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
input_ids, logits, spec_metadata, attn_metadata)
if self._is_mamba_hybrid_cache is None:
self._is_mamba_hybrid_cache = isinstance(
attn_metadata.kv_cache_manager, MambaHybridCacheManager)
if num_gens > 0 and self._is_mamba_hybrid_cache:
attn_metadata.kv_cache_manager.update_mamba_states(
attn_metadata=attn_metadata,
num_accepted_tokens=num_accepted_tokens)
# Save the old attn_metadata and spec_metadata
self._prepare_attn_metadata_for_spec_dec(attn_metadata)

View File

@ -348,6 +348,11 @@ nvidia/Nemotron-Super-V3:
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
accuracy: 80.85
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
mtp_enabled: true
num_nextn_predict_layers: 3
accuracy: 80.85
nvidia/Nemotron-3-Nano:
- accuracy: 69.37
- quant_algo: FP8

View File

@ -382,6 +382,11 @@ nvidia/Nemotron-Super-V3:
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
accuracy: 77.56
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
mtp_enabled: true
num_nextn_predict_layers: 3
accuracy: 77.56
nvidia/Nemotron-3-Nano:
- accuracy: 73.85
- quant_algo: FP8

View File

@ -5800,6 +5800,42 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
task.evaluate(llm,
extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
@pytest.mark.skip(reason="Skip MTP test due to no model path file in CI")
@skip_pre_blackwell
@pytest.mark.skip_less_mpi_world_size(8)
def test_nvfp4_8gpus_mtp(self):
# Test MTP (Multi-Token Prediction) accuracy with nvfp4-fp8kv model.
# This test uses MTP with max_draft_len=3 and one_model mode.
mtp_config = MTPDecodingConfig(
num_nextn_predict_layers=3,
mtp_eagle_one_model=True,
)
model_path = f"{llm_models_root()}/nemotron-super-sft-repeated-mtp-iter-0010600-nvfp4-fp8kv"
with LLM(
model_path,
kv_cache_config=KvCacheConfig(
enable_block_reuse=False,
mamba_ssm_cache_dtype="float16",
free_gpu_memory_fraction=0.5,
),
max_batch_size=128,
tensor_parallel_size=8,
moe_expert_parallel_size=8,
pipeline_parallel_size=1,
enable_attention_dp=False,
cuda_graph_config=CudaGraphConfig(max_batch_size=32,
enable_padding=True),
disable_overlap_scheduler=False,
moe_config=MoeConfig(backend="CUTLASS"),
decoding_config=mtp_config,
) as llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm,
extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm,
extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
@skip_pre_hopper
class TestMiniMaxM2(LlmapiAccuracyTestHarness):

View File

@ -214,7 +214,8 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate,
assert remove_padding
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
cu_seqlens, chunk_size)
out, ssm_state = mamba_chunk_scan_combined(
out = torch.empty_like(x)
ssm_state = mamba_chunk_scan_combined(
x,
dt,
A,
@ -232,6 +233,7 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate,
dt_softplus=delta_softplus,
return_final_states=not remove_padding,
return_varlen_states=remove_padding,
out=out,
)
if (ssm_state.shape[0] > 1 and ssm_state.dtype == torch.float32
@ -257,7 +259,8 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate,
else:
state_batch_indices = None
y = selective_state_update(
y = torch.empty_like(x)
selective_state_update(
state,
x,
dt,
@ -269,6 +272,7 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate,
dt_bias=dt_bias,
dt_softplus=delta_softplus,
state_batch_indices=state_batch_indices,
out=y,
)
outputs = (y, state[state_batch_indices]
if state_batch_indices is not None else state)
@ -432,7 +436,8 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens):
z = torch.randn_like(x)
## full seqlen computation
out_ref, state_ref = mamba_chunk_scan_combined(
out_ref = torch.empty_like(x)
state_ref = mamba_chunk_scan_combined(
x,
dt,
A,
@ -447,6 +452,7 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens):
dt_softplus=delta_softplus,
return_final_states=False,
return_varlen_states=True,
out=out_ref,
)
## chunked seqlen computation
@ -478,7 +484,8 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens):
z_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(z, i)
# yapf: enable
partial_out, partial_state = mamba_chunk_scan_combined(
partial_out = torch.empty_like(x_chunked)
partial_state = mamba_chunk_scan_combined(
x_chunked,
dt_chunked,
A,
@ -493,6 +500,7 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens):
dt_softplus=delta_softplus,
return_final_states=False,
return_varlen_states=True,
out=partial_out,
)
# remaining chunk
@ -542,7 +550,8 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens):
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
remaining_chunked_cu_seqlens, mamba_chunk_size)
out_chunked, state_chunked = mamba_chunk_scan_combined(
out_chunked = torch.empty_like(remaining_x_chunked)
state_chunked = mamba_chunk_scan_combined(
remaining_x_chunked,
remaining_dt_chunked,
A,
@ -560,6 +569,7 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens):
dt_softplus=delta_softplus,
return_final_states=False,
return_varlen_states=True,
out=out_chunked,
)
out = concat_batch_f(partial_out, out_chunked)