Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-04 18:21:52 +08:00
[TRTLLM-10062][feat] Enable MTP for Nemotron Super (#10754)
Signed-off-by: qgai <qgai@nvidia.com>
parent 43b8a5561c
commit ff0dd6076e
@ -675,6 +675,11 @@ private:
{
    continue;
}
// If tileSizeQ < mNumHeadsQPerKv, this will result in 0, causing division by zero.
if (tileSizeQ < params.mNumHeadsQPerKv)
{
    continue;
}

// Update the tileSizeQ.
selectKernelParamsCopy.mTileSizeQ = tileSizeQ;

@ -144,8 +144,15 @@ def _triton_cached_ssm(
    num_seq = num_prefill + num_decode
    num_total_tokens = num_prefill_tokens + num_decode

    y_prefill = None
    y_decode = None
    # Preallocate output tensor to avoid memcpy cost for merging prefill
    # and decode outputs
    preallocated_ssm_out = torch.empty(
        [bs, num_heads, head_dim],
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    preallocated_ssm_out_p = preallocated_ssm_out[:num_prefill_tokens]
    preallocated_ssm_out_d = preallocated_ssm_out[num_prefill_tokens:num_total_tokens]

    # Prefill: concatenate tokens at the front and run combined scan
    if num_prefill > 0:
@ -165,7 +172,7 @@ def _triton_cached_ssm(
|
||||
chunk_indices = None
|
||||
chunk_offsets = None
|
||||
|
||||
y_prefill, varlen_states = mamba_chunk_scan_combined(
|
||||
varlen_states = mamba_chunk_scan_combined(
|
||||
hs_prefill,
|
||||
dt_prefill,
|
||||
A,
|
||||
@ -184,11 +191,12 @@ def _triton_cached_ssm(
|
||||
dt_limit=(time_step_limit[0], time_step_limit[1]),
|
||||
return_final_states=False,
|
||||
return_varlen_states=True,
|
||||
mamba_ssm_cache_dtype=ssm_state_cache.dtype,
|
||||
out=preallocated_ssm_out_p.unsqueeze(0),
|
||||
state_dtype=ssm_state_cache.dtype,
|
||||
)
|
||||
|
||||
ssm_state_cache.index_copy_(
|
||||
0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype)
|
||||
0, slot_idx[:num_prefill].long(), varlen_states.to(ssm_state_cache.dtype)
|
||||
)
|
||||
|
||||
# Decode: batch single-token updates via selective_state_update
|
||||
@ -205,7 +213,7 @@ def _triton_cached_ssm(
|
||||
A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size)
|
||||
D_full = D[..., None].expand(num_heads, head_dim)
|
||||
|
||||
y_decode = selective_state_update(
|
||||
selective_state_update(
|
||||
ssm_state_cache,
|
||||
x_decode,
|
||||
dt_hp,
|
||||
@ -217,19 +225,16 @@ def _triton_cached_ssm(
|
||||
dt_bias=dt_bias_hp,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=slot_idx_decode,
|
||||
) # [nd, H, D]
|
||||
out=preallocated_ssm_out_d,
|
||||
)
|
||||
|
||||
# Dispatch return logic
|
||||
if num_prefill > 0 and num_decode > 0:
|
||||
y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format)
|
||||
y_flat = y.view(bs, *y.shape[2:])
|
||||
y_flat[:num_prefill_tokens].copy_(y_prefill[0])
|
||||
y_flat[num_prefill_tokens:num_total_tokens].copy_(y_decode)
|
||||
return y
|
||||
elif num_prefill > 0:
|
||||
return y_prefill[0].view(b, s, num_heads, head_dim).to(hidden_states.dtype)
|
||||
elif num_decode > 0:
|
||||
return y_decode.view(b, s, num_heads, head_dim).to(hidden_states.dtype)
|
||||
# Return the preallocated output reshaped to original dimensions
|
||||
if num_total_tokens > 0:
|
||||
return (
|
||||
preallocated_ssm_out[:num_total_tokens]
|
||||
.view(b, s, num_heads, head_dim)
|
||||
.to(hidden_states.dtype)
|
||||
)
|
||||
else:
|
||||
return torch.empty_like(hidden_states)
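
For reference, here is a minimal, hedged sketch of the preallocation pattern this change adopts (run_prefill and run_decode are placeholder callbacks, not functions from this codebase): prefill and decode kernels write into disjoint slices of a single output buffer, so the previous y_prefill/y_decode merge-and-copy path is no longer needed.

    import torch

    def split_scan_sketch(hidden_states, num_prefill_tokens, num_decode_tokens,
                          num_heads, head_dim, run_prefill, run_decode):
        # One buffer for all tokens; prefill and decode views alias disjoint slices.
        total = num_prefill_tokens + num_decode_tokens
        out = torch.empty([total, num_heads, head_dim],
                          dtype=hidden_states.dtype, device=hidden_states.device)
        out_p = out[:num_prefill_tokens]       # filled in place by the chunked scan
        out_d = out[num_prefill_tokens:total]  # filled in place by the decode update
        if num_prefill_tokens > 0:
            run_prefill(out=out_p.unsqueeze(0))  # kernels accept a preallocated out=
        if num_decode_tokens > 0:
            run_decode(out=out_d)
        # No concatenation or copy: the merged result is the buffer itself.
        return out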
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
import torch

import tensorrt_llm.logger as logger
from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import \
    HfWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_mapper
@ -55,6 +56,16 @@ class NemotronHHfWeightMapper(HfWeightMapper):
if "embeddings" in key:
    key = key.replace("embeddings", "embed_tokens")

# MTP layers are stored as mtp.layers.0.xxx (sublayer 0, Attention) and mtp.layers.1.xxx (sublayer 1, MoE)
if "mtp.layers." in key:
    import re
    match = re.match(r'mtp\.layers\.(\d+)\.(.*)', key)
    if match:
        sublayer_idx, rest = match.groups()
        key = f"model.layers.{config.num_hidden_layers}.layers.{sublayer_idx}.{rest}"
    else:
        logger.error(f"Failed to match MTP pattern for: {name}")

if "A_log" in key:
    key = key.replace("A_log", "A")

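
To illustrate the key remapping above, a small self-contained sketch (the sample key and the value 56 for num_hidden_layers are made up for the example): checkpoint keys of the form mtp.layers.<sublayer>.<rest> are folded into the main layer list one position past the last decoder layer.

    import re

    def remap_mtp_key(key: str, num_hidden_layers: int) -> str:
        # mtp.layers.<sublayer>.<rest> -> model.layers.<num_hidden_layers>.layers.<sublayer>.<rest>
        match = re.match(r'mtp\.layers\.(\d+)\.(.*)', key)
        if not match:
            return key
        sublayer_idx, rest = match.groups()
        return f"model.layers.{num_hidden_layers}.layers.{sublayer_idx}.{rest}"

    # Hypothetical example: with 56 decoder layers, the MTP attention sublayer lands at layer 56.
    assert remap_mtp_key("mtp.layers.0.mixer.in_proj.weight", 56) == \
        "model.layers.56.layers.0.mixer.in_proj.weight"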
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -37,9 +37,11 @@ from ..modules.mamba.mamba2_mixer import Mamba2Mixer
|
||||
from ..modules.mlp import MLP
|
||||
from ..modules.multi_stream_utils import maybe_execute_in_parallel
|
||||
from ..modules.rms_norm import RMSNorm
|
||||
from ..speculative import SpecMetadata
|
||||
from ..utils import AuxStreamType, EventType
|
||||
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
|
||||
register_auto_model)
|
||||
from .modeling_deepseekv3 import DeepseekV3MTPHead
|
||||
from .modeling_speculative import SpecDecOneEngineForCausalLM
|
||||
from .modeling_utils import DecoderModel, register_auto_model
|
||||
|
||||
|
||||
class NemotronHConfig(PretrainedConfig):
|
||||
@ -347,13 +349,17 @@ class NemotronHLayer(DecoderLayer):
|
||||
position_ids: torch.IntTensor,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.mixer(hidden_states, attn_metadata, **kwargs)
|
||||
hidden_states = self.mixer(hidden_states,
|
||||
attn_metadata,
|
||||
spec_metadata=spec_metadata,
|
||||
**kwargs)
|
||||
hidden_states = torch.add(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
@ -405,6 +411,7 @@ class NemotronHModel(DecoderModel):
|
||||
layer_type,
|
||||
aux_stream_dict=self.aux_stream_dict))
|
||||
self.layers = nn.ModuleList(layers)
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
|
||||
# final norm
|
||||
self.norm_f = RMSNorm(
|
||||
@ -421,6 +428,7 @@ class NemotronHModel(DecoderModel):
|
||||
input_ids: Optional[torch.IntTensor] = None,
|
||||
position_ids: Optional[torch.IntTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
@ -439,10 +447,11 @@ class NemotronHModel(DecoderModel):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for layer in self.layers:
|
||||
for layer in self.layers[:self.num_hidden_layers]:
|
||||
hidden_states = layer(position_ids,
|
||||
hidden_states,
|
||||
attn_metadata,
|
||||
spec_metadata=spec_metadata,
|
||||
mamba_metadata=self.mamba_metadata)
|
||||
|
||||
hidden_states = self.norm_f(hidden_states)
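
A brief, hedged illustration of why the layer loop is now sliced (the layer counts below are invented): the MTP sublayers get appended to self.layers so their weights load through the normal path, and the base forward must stop at num_hidden_layers so it never executes them.

    import torch.nn as nn

    layers = nn.ModuleList([nn.Identity() for _ in range(56)])  # decoder layers
    layers.extend([nn.Identity(), nn.Identity()])               # appended MTP sublayers
    num_hidden_layers = 56
    # The base model forward runs only the decoder layers; MTP sublayers are driven by the spec worker.
    active = list(layers[:num_hidden_layers])
    assert len(active) == num_hidden_layers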
|
||||
@ -451,8 +460,8 @@ class NemotronHModel(DecoderModel):
|
||||
|
||||
|
||||
@register_auto_model("NemotronHForCausalLM")
|
||||
class NemotronHForCausalLM(DecoderModelForCausalLM[NemotronHModel,
|
||||
NemotronHConfig]):
|
||||
class NemotronHForCausalLM(SpecDecOneEngineForCausalLM[NemotronHModel,
|
||||
NemotronHConfig]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -477,15 +486,286 @@ class NemotronHForCausalLM(DecoderModelForCausalLM[NemotronHModel,
|
||||
]
|
||||
|
||||
super().__init__(
|
||||
NemotronHModel(model_config),
|
||||
config=model_config,
|
||||
hidden_size=model_config.pretrained_config.hidden_size,
|
||||
vocab_size=model_config.pretrained_config.vocab_size,
|
||||
model=NemotronHModel(model_config),
|
||||
model_config=model_config,
|
||||
)
|
||||
self.model_nextn = 0
|
||||
if model_config.spec_config is not None and model_config.spec_config.spec_dec_mode.is_mtp_one_model(
|
||||
):
|
||||
model_nextn = model_config.spec_config.num_nextn_predict_layers
|
||||
ckpt_nextn = self.config.num_nextn_predict_layers
|
||||
self.num_hidden_layers = self.config.num_hidden_layers
|
||||
assert ckpt_nextn > 0, "There are no MTP modules in the checkpoint."
|
||||
if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
|
||||
pass
|
||||
else:
|
||||
# modify the QuantConfig to support duplicated mtp layers
|
||||
if model_config.quant_config.exclude_modules is not None:
|
||||
extend_exclude_modules = []
|
||||
for model_mtp_idx in range(
|
||||
self.num_hidden_layers,
|
||||
self.num_hidden_layers + model_nextn):
|
||||
ckpt_mtp_idx = (model_mtp_idx - self.num_hidden_layers
|
||||
) % ckpt_nextn + self.num_hidden_layers
|
||||
model_prefix = f"model.layers.{model_mtp_idx}"
|
||||
ckpt_prefix = f"model.layers.{ckpt_mtp_idx}"
|
||||
for exclude_module in model_config.quant_config.exclude_modules:
|
||||
if ckpt_prefix in exclude_module and model_prefix not in exclude_module:
|
||||
extend_exclude_modules.append(
|
||||
exclude_module.replace(
|
||||
ckpt_prefix, model_prefix))
|
||||
self.model_config.quant_config.exclude_modules.extend(
|
||||
extend_exclude_modules)
|
||||
self.model.layers.extend(self.draft_model.mtp_layers)
|
||||
self.epilogue.extend(self.draft_model.mtp_layers)
|
||||
self.epilogue.append(self.spec_worker)
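
To make the exclude-list duplication above concrete, a hedged, self-contained sketch (the module names and layer counts are invented): any exclude entry that names the checkpoint's MTP layer index is cloned for each duplicated in-model MTP layer index.

    def extend_mtp_excludes(exclude_modules, num_hidden_layers, ckpt_nextn, model_nextn):
        # For each duplicated MTP layer in the model, clone any exclude entry that
        # referenced the corresponding checkpoint MTP layer.
        extra = []
        for model_mtp_idx in range(num_hidden_layers, num_hidden_layers + model_nextn):
            ckpt_mtp_idx = (model_mtp_idx - num_hidden_layers) % ckpt_nextn + num_hidden_layers
            model_prefix = f"model.layers.{model_mtp_idx}"
            ckpt_prefix = f"model.layers.{ckpt_mtp_idx}"
            for m in exclude_modules:
                if ckpt_prefix in m and model_prefix not in m:
                    extra.append(m.replace(ckpt_prefix, model_prefix))
        return exclude_modules + extra

    # Hypothetical: 56 decoder layers, 1 MTP layer in the checkpoint, 3 in the model.
    excludes = ["model.layers.56.mixer.in_proj"]
    print(extend_mtp_excludes(excludes, 56, 1, 3))
    # -> the original entry plus copies for model.layers.57 and model.layers.58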
|
||||
|
||||
def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
|
||||
new_weights = weight_mapper.preprocess_weights(weights)
|
||||
super().load_weights(weights=new_weights, weight_mapper=weight_mapper)
|
||||
|
||||
|
||||
class NemotronHMTPDecoderLayer(NemotronHLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig[NemotronHConfig],
|
||||
layer_idx: int,
|
||||
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
|
||||
has_start_projections: bool,
|
||||
has_end_norm: bool,
|
||||
layer_type: str,
|
||||
) -> None:
|
||||
|
||||
super().__init__(
|
||||
model_config=model_config,
|
||||
layer_idx=layer_idx,
|
||||
layer_type=layer_type,
|
||||
aux_stream_dict=aux_stream_dict,
|
||||
)
|
||||
self.model_nextn = 0
|
||||
if model_config.spec_config is not None and model_config.spec_config.spec_dec_mode.is_mtp_one_model(
|
||||
):
|
||||
model_nextn = model_config.spec_config.num_nextn_predict_layers
|
||||
ckpt_nextn = self.config.num_nextn_predict_layers
|
||||
self.num_hidden_layers = self.config.num_hidden_layers
|
||||
assert ckpt_nextn > 0, "There are no MTP modules in the checkpoint."
|
||||
if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
|
||||
pass
|
||||
else:
|
||||
# modify the QuantConfig to support duplicated mtp layers
|
||||
if model_config.quant_config.exclude_modules is not None:
|
||||
extend_exclude_modules = []
|
||||
for model_mtp_idx in range(
|
||||
self.num_hidden_layers,
|
||||
self.num_hidden_layers + model_nextn):
|
||||
ckpt_mtp_idx = (model_mtp_idx - self.num_hidden_layers
|
||||
) % ckpt_nextn + self.num_hidden_layers
|
||||
model_prefix = f"model.layers.{model_mtp_idx}"
|
||||
ckpt_prefix = f"model.layers.{ckpt_mtp_idx}"
|
||||
for exclude_module in model_config.quant_config.exclude_modules:
|
||||
if ckpt_prefix in exclude_module and model_prefix not in exclude_module:
|
||||
extend_exclude_modules.append(
|
||||
exclude_module.replace(
|
||||
ckpt_prefix, model_prefix))
|
||||
self.model_config.quant_config.exclude_modules.extend(
|
||||
extend_exclude_modules)
|
||||
self.model.layers.extend(self.draft_model.mtp_layers)
|
||||
self.epilogue.extend(self.draft_model.mtp_layers)
|
||||
self.epilogue.append(self.spec_worker)
|
||||
|
||||
config = model_config.pretrained_config
|
||||
self.model_config = model_config
|
||||
self.has_start_projections = has_start_projections
|
||||
self.has_end_norm = has_end_norm
|
||||
|
||||
if has_start_projections:
|
||||
self.enorm = RMSNorm(
|
||||
hidden_size=config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
dtype=config.torch_dtype,
|
||||
)
|
||||
self.hnorm = RMSNorm(
|
||||
hidden_size=config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
dtype=config.torch_dtype,
|
||||
)
|
||||
|
||||
if model_config.mapping.enable_attention_dp:
|
||||
self.eh_proj = Linear(
|
||||
in_features=config.hidden_size * 2,
|
||||
out_features=config.hidden_size,
|
||||
bias=False,
|
||||
dtype=config.torch_dtype,
|
||||
quant_config=model_config.quant_config,
|
||||
skip_create_weights_in_init=model_config.
|
||||
skip_create_weights_in_init,
|
||||
)
|
||||
else:
|
||||
self.eh_proj = Linear(
|
||||
in_features=config.hidden_size * 2,
|
||||
out_features=config.hidden_size,
|
||||
bias=False,
|
||||
dtype=config.torch_dtype,
|
||||
tensor_parallel_mode=TensorParallelMode.ROW,
|
||||
mapping=model_config.mapping,
|
||||
quant_config=model_config.quant_config,
|
||||
reduce_output=True,
|
||||
skip_create_weights_in_init=model_config.
|
||||
skip_create_weights_in_init,
|
||||
)
|
||||
|
||||
if has_end_norm:
|
||||
self.final_layernorm = RMSNorm(
|
||||
hidden_size=config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
dtype=config.torch_dtype,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
|
||||
if self.has_start_projections:
|
||||
assert inputs_embeds is not None
|
||||
inputs_embeds_normed = self.enorm(inputs_embeds)
|
||||
previous_hidden_states_normed = self.hnorm(hidden_states)
|
||||
|
||||
# Fuse via concatenation and linear projection
|
||||
fused = torch.cat(
|
||||
[inputs_embeds_normed, previous_hidden_states_normed], dim=-1)
|
||||
|
||||
# Split fused hidden_states columnwise based on TP
|
||||
mapping = self.model_config.mapping
|
||||
if mapping.tp_size > 1 and not mapping.enable_attention_dp:
|
||||
fused = torch.chunk(fused, mapping.tp_size,
|
||||
dim=-1)[mapping.tp_rank]
|
||||
|
||||
hidden_states = self.eh_proj(fused)
|
||||
residual = None # Start fresh after fusion
|
||||
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
hidden_states = self.mixer(
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
|
||||
new_weights = weight_mapper.preprocess_weights(weights)
|
||||
super().load_weights(new_weights, weight_mapper)
|
||||
if self.has_end_norm:
|
||||
if residual is not None:
|
||||
hidden_states = hidden_states + residual
|
||||
residual = None
|
||||
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class NemotronHMTP(nn.Module):
|
||||
"""NemotronH MTP Layer - single MTP layer following DeepseekV3MTP pattern."""
|
||||
|
||||
def __init__(self,
|
||||
model_config: ModelConfig[NemotronHConfig],
|
||||
layer_idx: int,
|
||||
aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream],
|
||||
is_separate_draft_engine: bool = False,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
config = model_config.pretrained_config
|
||||
self.model_config = model_config
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
# Pattern configuration
|
||||
self.pattern_str = config.mtp_hybrid_override_pattern
|
||||
self.pattern_len = len(self.pattern_str)
|
||||
assert self.pattern_len > 0
|
||||
|
||||
# Build pattern-based layers
|
||||
self.layers = nn.ModuleDict()
|
||||
|
||||
for i in range(self.pattern_len):
|
||||
step_rel_idx = i % self.pattern_len
|
||||
|
||||
char = self.pattern_str[step_rel_idx]
|
||||
|
||||
is_start_of_step = step_rel_idx == 0
|
||||
is_end_of_step = step_rel_idx == self.pattern_len - 1
|
||||
|
||||
sublayer_quant_config = self._get_mtp_sublayer_quant_config(
|
||||
model_config, self.layer_idx)
|
||||
|
||||
# Create a temporary model_config with the override quant_config
|
||||
sublayer_model_config = ModelConfig(
|
||||
pretrained_config=model_config.pretrained_config,
|
||||
mapping=model_config.mapping,
|
||||
quant_config=sublayer_quant_config,
|
||||
skip_create_weights_in_init=model_config.
|
||||
skip_create_weights_in_init,
|
||||
)
|
||||
|
||||
self.layers[str(i)] = NemotronHMTPDecoderLayer(
|
||||
model_config=sublayer_model_config,
|
||||
layer_idx=self.layer_idx,
|
||||
aux_stream_dict=aux_stream_dict,
|
||||
has_start_projections=is_start_of_step,
|
||||
has_end_norm=is_end_of_step,
|
||||
layer_type=char,
|
||||
)
|
||||
|
||||
# Add shared_head for MTP, following DeepseekV3MTP pattern
|
||||
self.shared_head = DeepseekV3MTPHead(model_config)
|
||||
|
||||
def _get_mtp_sublayer_quant_config(
|
||||
self, model_config: ModelConfig[NemotronHConfig], layer_idx: int):
|
||||
"""
|
||||
Get quantization config for MTP sublayer.
|
||||
The MTP layer in the nvfp4 checkpoint is unquantized. Because the TRTLLM
|
||||
moe_backend only supports fp8/fp4 quantization, we need to override
|
||||
the quant_config for the MTP layer.
|
||||
"""
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
quant_config = model_config.quant_config
|
||||
# MTP layers are always unquantized, force quant_algo=None
|
||||
if quant_config is None:
|
||||
return None
|
||||
return QuantConfig(
|
||||
quant_algo=None,
|
||||
kv_cache_quant_algo=quant_config.kv_cache_quant_algo,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.IntTensor,
|
||||
position_ids: torch.IntTensor,
|
||||
hidden_states: torch.Tensor,
|
||||
embed_tokens: Embedding,
|
||||
attn_metadata: AttentionMetadata,
|
||||
all_rank_num_tokens: Optional[List[int]] = None,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(self.pattern_len):
|
||||
layer = self.layers[str(i)]
|
||||
hidden_states, residual = layer(
|
||||
inputs_embeds=inputs_embeds,
|
||||
positions=position_ids,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
AutoConfig.register(NemotronHConfig.model_type, NemotronHConfig)
|
||||
|
||||
@ -706,6 +706,9 @@ class MTPForCausalLM(nn.Module):
|
||||
case "exaone_moe":
|
||||
from .modeling_exaone_moe import ExaoneMoeMTP
|
||||
mtp_layer = ExaoneMoeMTP
|
||||
case "nemotron_h":
|
||||
from .modeling_nemotron_h import NemotronHMTP
|
||||
mtp_layer = NemotronHMTP
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Model type {model_type} not supported for MTP")
|
||||
@ -751,6 +754,12 @@ class MTPDraftModel(nn.Module):
|
||||
from .modeling_exaone_moe import ExaoneMoeMTP
|
||||
mtp_layer = ExaoneMoeMTP(model_config, layer_idx, aux_stream_dict)
|
||||
|
||||
elif model_type == "nemotron_h":
|
||||
from .modeling_nemotron_h import NemotronHMTP
|
||||
mtp_layer = NemotronHMTP(model_config,
|
||||
layer_idx,
|
||||
aux_stream_dict,
|
||||
is_separate_draft_engine=False)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MTPDraftModel does not support model_type: {model_type}")
|
||||
|
||||
tensorrt_llm/_torch/modules/mamba/causal_conv1d_triton.py: new file, 1165 lines (file diff suppressed because it is too large)
@ -24,8 +24,11 @@ from tensorrt_llm.mapping import Mapping
|
||||
|
||||
from ...attention_backend import AttentionMetadata
|
||||
from ...model_config import ModelConfig
|
||||
from ...speculative import SpecMetadata
|
||||
from ..linear import Linear, TensorParallelMode
|
||||
from .causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||
from .causal_conv1d_triton import \
|
||||
causal_conv1d_update as causal_conv1d_update_triton
|
||||
from .layernorm_gated import RMSNorm as RMSNormGated
|
||||
from .selective_state_update import selective_state_update
|
||||
from .ssd_combined import mamba_chunk_scan_combined
|
||||
@ -82,6 +85,8 @@ class Mamba2Mixer(nn.Module):
|
||||
self.tp_d_inner = d_inner // tp_size
|
||||
self.tp_nheads = nheads // tp_size
|
||||
self.tp_ngroups = n_groups // tp_size
|
||||
self.num_heads = nheads
|
||||
self.tp_size = tp_size
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.d_conv = d_conv
|
||||
@ -167,6 +172,7 @@ class Mamba2Mixer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_metadata: Mamba2Metadata,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
@ -175,6 +181,7 @@ class Mamba2Mixer(nn.Module):
|
||||
num_decodes = attn_metadata.seq_lens.shape[0] - num_prefills
|
||||
num_prefill_tokens = attn_metadata.num_ctx_tokens
|
||||
num_decode_tokens = attn_metadata.num_tokens - num_prefill_tokens
|
||||
num_actual_tokens = attn_metadata.num_tokens
|
||||
seqlen_split_size = [num_prefill_tokens, num_decode_tokens]
|
||||
batch_split_size = [num_prefills, num_decodes]
|
||||
|
||||
@ -183,10 +190,10 @@ class Mamba2Mixer(nn.Module):
|
||||
|
||||
state_indices_p, state_indices_d = torch.split(state_indices,
|
||||
batch_split_size)
|
||||
conv_states = attn_metadata.kv_cache_manager.get_conv_states(
|
||||
self.layer_idx)
|
||||
ssm_states = attn_metadata.kv_cache_manager.get_ssm_states(
|
||||
layer_cache = attn_metadata.kv_cache_manager.mamba_layer_cache(
|
||||
self.layer_idx)
|
||||
conv_states = layer_cache.conv
|
||||
ssm_states = layer_cache.temporal
|
||||
|
||||
# in_proj
|
||||
zxbcdt = self.in_proj(hidden_states)
|
||||
@ -199,7 +206,21 @@ class Mamba2Mixer(nn.Module):
|
||||
xbc_p, xbc_d = torch.split(xbc, seqlen_split_size, dim=0)
|
||||
dt_p, dt_d = torch.split(dt, seqlen_split_size, dim=0)
|
||||
|
||||
out = []
|
||||
# Preallocate output tensor to avoid memcpy cost for merging prefill
|
||||
# and decode outputs
|
||||
preallocated_ssm_out = torch.empty(
|
||||
[
|
||||
zxbcdt.shape[0],
|
||||
(self.num_heads * self.head_dim) // self.tp_size,
|
||||
],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
|
||||
preallocated_ssm_out,
|
||||
[num_prefill_tokens, num_decode_tokens],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if num_prefills > 0:
|
||||
|
||||
@ -239,7 +260,7 @@ class Mamba2Mixer(nn.Module):
|
||||
has_initial_states[:, None, None, None],
|
||||
ssm_states[state_indices_p], 0)
|
||||
|
||||
y, current_ssm_states = mamba_chunk_scan_combined(
|
||||
current_ssm_states = mamba_chunk_scan_combined(
|
||||
x_p,
|
||||
dt_p,
|
||||
self.A,
|
||||
@ -247,30 +268,63 @@ class Mamba2Mixer(nn.Module):
|
||||
C_p,
|
||||
chunk_size=self.chunk_size,
|
||||
D=self.D,
|
||||
z=z_p,
|
||||
z=None,
|
||||
dt_bias=self.dt_bias,
|
||||
initial_states=initial_states,
|
||||
chunk_indices=mamba_metadata.chunk_indices,
|
||||
chunk_offsets=mamba_metadata.chunk_offsets,
|
||||
dt_softplus=self.delta_softplus,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
cu_seqlens=cu_seqlens,
|
||||
seq_idx=seq_idx,
|
||||
return_varlen_states=True,
|
||||
return_final_states=False,
|
||||
mamba_ssm_cache_dtype=self._mamba_ssm_cache_dtype,
|
||||
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
|
||||
self.head_dim),
|
||||
state_dtype=self._mamba_ssm_cache_dtype,
|
||||
)
|
||||
out.append(rearrange(y, "b l h p -> (b l) (h p)"))
|
||||
|
||||
# copy new ssm state
|
||||
ssm_states[state_indices_p] = current_ssm_states
|
||||
|
||||
if num_decodes > 0:
|
||||
xbc_d = causal_conv1d_update(xbc_d,
|
||||
conv_states,
|
||||
self.conv1d.weight,
|
||||
self.conv1d.bias,
|
||||
activation="silu",
|
||||
conv_state_indices=state_indices_d)
|
||||
is_target_verify = attn_metadata.kv_cache_manager.is_speculative(
|
||||
) and spec_metadata is not None
|
||||
if is_target_verify:
|
||||
# TODO: support dynamic speculation, will add current_draft_len later [TRTLLM-10319]
|
||||
draft_token_num = spec_metadata.max_draft_len + 1
|
||||
intermediate_conv_states = layer_cache.intermediate_conv_window
|
||||
|
||||
self.intermediate_state_indices = torch.arange(
|
||||
num_decodes,
|
||||
dtype=torch.int32,
|
||||
device=state_indices_d.device)
|
||||
|
||||
# Reshape for batch processing
|
||||
xbc_d_reshaped = xbc_d.view(num_decodes, draft_token_num,
|
||||
-1).transpose(1, 2)
|
||||
# TODO:support tree structure [TRTLLM-10320]
|
||||
xbc_d_processed = causal_conv1d_update_triton(
|
||||
xbc_d_reshaped,
|
||||
conv_states,
|
||||
self.conv1d.weight,
|
||||
self.conv1d.bias,
|
||||
activation="silu",
|
||||
conv_state_indices=state_indices_d[:num_decodes],
|
||||
intermediate_conv_window=intermediate_conv_states,
|
||||
intermediate_state_indices=self.intermediate_state_indices,
|
||||
)
|
||||
|
||||
xbc_d = xbc_d_processed.transpose(1, 2).view(
|
||||
num_decode_tokens, -1)
|
||||
|
||||
else:
|
||||
xbc_d = causal_conv1d_update(xbc_d,
|
||||
conv_states,
|
||||
self.conv1d.weight,
|
||||
self.conv1d.bias,
|
||||
activation="silu",
|
||||
conv_state_indices=state_indices_d)
|
||||
|
||||
x_d, B_d, C_d = torch.split(
|
||||
xbc_d,
|
||||
@ -292,29 +346,64 @@ class Mamba2Mixer(nn.Module):
|
||||
n=self.d_state).to(dtype=torch.float32)
|
||||
dt_bias = repeat(self.dt_bias, "h -> h p", p=self.head_dim)
|
||||
D = repeat(self.D, "h -> h p", p=self.head_dim)
|
||||
if is_target_verify:
|
||||
intermediate_ssm_states = layer_cache.intermediate_ssm
|
||||
selective_state_update(
|
||||
ssm_states,
|
||||
x_d.view(
|
||||
num_decodes,
|
||||
draft_token_num,
|
||||
self.num_heads // self.tp_size,
|
||||
self.head_dim,
|
||||
),
|
||||
dt_d.view(
|
||||
num_decodes,
|
||||
draft_token_num,
|
||||
self.num_heads // self.tp_size,
|
||||
self.head_dim,
|
||||
),
|
||||
A,
|
||||
B_d.view(num_decodes, draft_token_num, self.tp_ngroups, -1),
|
||||
C_d.view(num_decodes, draft_token_num, self.tp_ngroups, -1),
|
||||
D,
|
||||
z=None,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices_d[:num_decodes],
|
||||
out=preallocated_ssm_out_d.view(
|
||||
num_decodes,
|
||||
draft_token_num,
|
||||
self.num_heads // self.tp_size,
|
||||
self.head_dim,
|
||||
),
|
||||
disable_state_update=True,
|
||||
intermediate_states_buffer=intermediate_ssm_states,
|
||||
cache_steps=draft_token_num,
|
||||
intermediate_state_indices=self.intermediate_state_indices,
|
||||
)
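
A hedged sketch of how the intermediate SSM state buffer used above could be laid out; the shape is inferred from the kernel's flat offset arithmetic (cache_idx * cache_steps * nheads * dim * dstate + step * nheads * dim * dstate + ...), and all sizes below are placeholders. One snapshot is kept per draft step so accepted prefixes can be replayed while disable_state_update=True keeps the persistent cache untouched.

    import torch

    # Assumed layout: [num_slots, cache_steps, nheads, head_dim, d_state]. Sizes are illustrative only.
    num_slots, draft_token_num, nheads, head_dim, d_state = 8, 4, 16, 64, 128
    intermediate_ssm = torch.zeros(num_slots, draft_token_num, nheads, head_dim, d_state)

    # Reading back the state snapshot cached at draft step t for request slot s, head h:
    s, t, h = 0, 2, 3
    state_snapshot = intermediate_ssm[s, t, h]  # [head_dim, d_state]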
|
||||
|
||||
y = selective_state_update(
|
||||
ssm_states,
|
||||
x_d,
|
||||
dt_d,
|
||||
A,
|
||||
B_d,
|
||||
C_d,
|
||||
D,
|
||||
z=z_d,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=self.delta_softplus,
|
||||
state_batch_indices=state_indices_d,
|
||||
)
|
||||
else:
|
||||
|
||||
out.append(rearrange(y, "b h p -> b (h p)"))
|
||||
|
||||
out = torch.cat(out, dim=0)
|
||||
selective_state_update(
|
||||
ssm_states,
|
||||
x_d,
|
||||
dt_d,
|
||||
A,
|
||||
B_d,
|
||||
C_d,
|
||||
D,
|
||||
z=None,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=self.delta_softplus,
|
||||
state_batch_indices=state_indices_d,
|
||||
out=preallocated_ssm_out_d.view(num_decodes, -1,
|
||||
self.head_dim),
|
||||
)
|
||||
|
||||
# norm
|
||||
out = self.norm(out)
|
||||
hidden_states = self.norm(preallocated_ssm_out, z[:num_actual_tokens])
|
||||
|
||||
# out_proj
|
||||
out = self.out_proj(out)
|
||||
out = self.out_proj(hidden_states)
|
||||
|
||||
return out
|
||||
return out[:num_actual_tokens]
|
||||
|
||||
@ -1,7 +1,4 @@
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -15,6 +12,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/mamba/ops/mamba_ssm.py
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the sglang project
|
||||
#
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@ -35,7 +38,19 @@ from .softplus import softplus
|
||||
})
|
||||
@triton.heuristics(
|
||||
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
|
||||
@triton.jit
|
||||
@triton.heuristics({
|
||||
"CACHE_INTERMEDIATE_STATES":
|
||||
lambda args: args["intermediate_states_buffer"] is not None
|
||||
})
|
||||
@triton.heuristics({
|
||||
"HAS_EAGLE_TREE_CUSTOM_ATTN_MASK":
|
||||
lambda args: args["retrieve_parent_token_ptr"] is not None
|
||||
})
|
||||
@triton.heuristics({
|
||||
"HAS_INTERMEDIATE_STATE_INDICES":
|
||||
lambda args: args["intermediate_state_indices_ptr"] is not None
|
||||
})
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def _selective_scan_update_kernel(
|
||||
# Pointers to matrices
|
||||
state_ptr,
|
||||
@ -50,8 +65,13 @@ def _selective_scan_update_kernel(
|
||||
out_ptr,
|
||||
state_batch_indices_ptr,
|
||||
pad_slot_id,
|
||||
intermediate_states_buffer,
|
||||
cache_steps,
|
||||
retrieve_parent_token_ptr,
|
||||
intermediate_state_indices_ptr,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
T,
|
||||
nheads,
|
||||
dim,
|
||||
dstate,
|
||||
@ -62,9 +82,11 @@ def _selective_scan_update_kernel(
|
||||
stride_state_dim,
|
||||
stride_state_dstate,
|
||||
stride_x_batch,
|
||||
stride_x_T,
|
||||
stride_x_head,
|
||||
stride_x_dim,
|
||||
stride_dt_batch,
|
||||
stride_dt_T,
|
||||
stride_dt_head,
|
||||
stride_dt_dim,
|
||||
stride_dt_bias_head,
|
||||
@ -73,19 +95,25 @@ def _selective_scan_update_kernel(
|
||||
stride_A_dim,
|
||||
stride_A_dstate,
|
||||
stride_B_batch,
|
||||
stride_B_T,
|
||||
stride_B_group,
|
||||
stride_B_dstate,
|
||||
stride_C_batch,
|
||||
stride_C_T,
|
||||
stride_C_group,
|
||||
stride_C_dstate,
|
||||
stride_D_head,
|
||||
stride_D_dim,
|
||||
stride_z_batch,
|
||||
stride_z_T,
|
||||
stride_z_head,
|
||||
stride_z_dim,
|
||||
stride_out_batch,
|
||||
stride_out_T,
|
||||
stride_out_head,
|
||||
stride_out_dim,
|
||||
stride_retrieve_parent_token_batch,
|
||||
stride_retrieve_parent_token_T,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
TIE_HDIM: tl.constexpr,
|
||||
@ -94,6 +122,10 @@ def _selective_scan_update_kernel(
|
||||
HAS_D: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
HAS_STATE_BATCH_INDICES: tl.constexpr,
|
||||
DISABLE_STATE_UPDATE: tl.constexpr,
|
||||
CACHE_INTERMEDIATE_STATES: tl.constexpr,
|
||||
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr,
|
||||
HAS_INTERMEDIATE_STATE_INDICES: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
@ -105,7 +137,7 @@ def _selective_scan_update_kernel(
|
||||
# is the same as the batch id.
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
state_batch_indices_ptr += pid_b
|
||||
state_batch_idx = tl.load(state_batch_indices_ptr)
|
||||
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
|
||||
state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
|
||||
else:
|
||||
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
@ -127,91 +159,153 @@ def _selective_scan_update_kernel(
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
||||
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
|
||||
offs_n[None, :] * stride_state_dstate)
|
||||
x_ptrs = x_ptr + offs_m * stride_x_dim
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
||||
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= state_batch_idx != pad_slot_id
|
||||
state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32)
|
||||
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
|
||||
if HAS_D:
|
||||
D_ptr += pid_h * stride_D_head
|
||||
A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim +
|
||||
offs_n[None, :] * stride_A_dstate)
|
||||
B_ptrs = B_ptr + offs_n * stride_B_dstate
|
||||
C_ptrs = C_ptr + offs_n * stride_C_dstate
|
||||
if HAS_D:
|
||||
D_ptrs = D_ptr + offs_m * stride_D_dim
|
||||
if HAS_Z:
|
||||
z_ptrs = z_ptr + offs_m * stride_z_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= (state_batch_idx != pad_slot_id)
|
||||
state = tl.load(state_ptrs, mask=mask, other=0.0)
|
||||
A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[
|
||||
None, :] * stride_A_dstate
|
||||
|
||||
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if not TIE_HDIM:
|
||||
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,
|
||||
other=0.0).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = tl.where(dt <= 20.0, softplus(dt), dt)
|
||||
A = tl.load(A_ptrs,
|
||||
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
|
||||
other=0.0).to(tl.float32)
|
||||
dA = tl.exp(A * dt[:, None])
|
||||
else:
|
||||
dt = tl.load(dt_ptr).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptr).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = tl.where(dt <= 20.0, softplus(dt), dt)
|
||||
A = tl.load(A_ptr).to(tl.float32)
|
||||
dA = tl.exp(A * dt) # scalar, not a matrix
|
||||
cache_idx = -1
|
||||
if CACHE_INTERMEDIATE_STATES:
|
||||
if HAS_INTERMEDIATE_STATE_INDICES:
|
||||
intermediate_state_idx = tl.load(intermediate_state_indices_ptr +
|
||||
pid_b).to(tl.int64)
|
||||
cache_idx = intermediate_state_idx
|
||||
elif HAS_STATE_BATCH_INDICES:
|
||||
cache_idx = state_batch_idx
|
||||
else:
|
||||
cache_idx = pid_b
|
||||
|
||||
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
if HAS_D:
|
||||
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_Z:
|
||||
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
current_step_idx = 0
|
||||
for _ in range(T):
|
||||
if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
|
||||
if current_step_idx != 0 and cache_idx >= 0:
|
||||
parent_ptr = (retrieve_parent_token_ptr +
|
||||
pid_b * stride_retrieve_parent_token_batch +
|
||||
current_step_idx * stride_retrieve_parent_token_T)
|
||||
parent_step_idx = tl.load(parent_ptr).to(tl.int32)
|
||||
|
||||
if not TIE_HDIM:
|
||||
dB = B[None, :] * dt[:, None]
|
||||
else:
|
||||
dB = B * dt # vector of size (dstate,)
|
||||
state = state * dA + dB * x[:, None]
|
||||
if parent_step_idx >= 0 and parent_step_idx < T:
|
||||
step_offset = parent_step_idx * nheads * dim * dstate
|
||||
cache_ptr = (
|
||||
intermediate_states_buffer +
|
||||
cache_idx * cache_steps * nheads * dim * dstate +
|
||||
step_offset + pid_h * dim * dstate +
|
||||
offs_m[:, None] * dstate + offs_n[None, :])
|
||||
state = tl.load(cache_ptr, mask=mask,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= (state_batch_idx != pad_slot_id)
|
||||
tl.store(state_ptrs, state, mask=mask)
|
||||
out = tl.sum(state * C[None, :], axis=1)
|
||||
if HAS_D:
|
||||
out += x * D
|
||||
if HAS_Z:
|
||||
out *= z * tl.sigmoid(z)
|
||||
tl.store(out_ptrs, out, mask=offs_m < dim)
|
||||
x_ptrs = x_ptr + offs_m * stride_x_dim
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
||||
B_ptrs = B_ptr + offs_n * stride_B_dstate
|
||||
C_ptrs = C_ptr + offs_n * stride_C_dstate
|
||||
if HAS_Z:
|
||||
z_ptrs = z_ptr + offs_m * stride_z_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
|
||||
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if not TIE_HDIM:
|
||||
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,
|
||||
other=0.0).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(
|
||||
A_ptrs,
|
||||
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
dA = tl.exp(A * dt[:, None])
|
||||
else:
|
||||
dt = tl.load(dt_ptr).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptr).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(A_ptr).to(tl.float32)
|
||||
dA = tl.exp(A * dt) # scalar, not a matrix
|
||||
|
||||
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
if HAS_D:
|
||||
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_Z:
|
||||
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
|
||||
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
|
||||
state = state * dA + dB * x[:, None]
|
||||
|
||||
if CACHE_INTERMEDIATE_STATES:
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
if state_batch_idx != pad_slot_id:
|
||||
cache_ptr_base = (
|
||||
intermediate_states_buffer +
|
||||
cache_idx * cache_steps * nheads * dim * dstate +
|
||||
current_step_idx * nheads * dim * dstate +
|
||||
pid_h * dim * dstate)
|
||||
cache_ptrs = cache_ptr_base + (offs_m[:, None] * dstate +
|
||||
offs_n[None, :])
|
||||
tl.store(cache_ptrs,
|
||||
state.to(cache_ptrs.dtype.element_ty),
|
||||
mask=mask)
|
||||
|
||||
out = tl.sum(state * C[None, :], axis=1)
|
||||
if HAS_D:
|
||||
out += x * D
|
||||
if HAS_Z:
|
||||
out *= z * tl.sigmoid(z)
|
||||
tl.store(out_ptrs, out, mask=offs_m < dim)
|
||||
|
||||
current_step_idx += 1
|
||||
|
||||
x_ptr += stride_x_T
|
||||
dt_ptr += stride_dt_T
|
||||
B_ptr += stride_B_T
|
||||
C_ptr += stride_C_T
|
||||
out_ptr += stride_out_T
|
||||
if HAS_Z:
|
||||
z_ptr += stride_z_T
|
||||
|
||||
if not DISABLE_STATE_UPDATE:
|
||||
tl.store(state_ptrs, state.to(state_ptrs.dtype.element_ty), mask=mask)
|
||||
|
||||
|
||||
def selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
state_batch_indices=None,
|
||||
pad_slot_id=PAD_SLOT_ID):
|
||||
def selective_state_update(
|
||||
state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
state_batch_indices=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=None,
|
||||
disable_state_update=False,
|
||||
intermediate_states_buffer=None,
|
||||
cache_steps=None,
|
||||
retrieve_parent_token=None,
|
||||
intermediate_state_indices=None,
|
||||
):
|
||||
"""
|
||||
Argument:
|
||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
x: (batch, dim) or (batch, nheads, dim)
|
||||
x: (batch, dim) or (batch, nheads, dim) for single-token or (batch, T, nheads, dim) for multi-token
|
||||
dt: (batch, dim) or (batch, nheads, dim)
|
||||
A: (dim, dstate) or (nheads, dim, dstate)
|
||||
B: (batch, dstate) or (batch, ngroups, dstate)
|
||||
B: (batch, dstate) or (batch, ngroups, dstate) for single-token or (batch, T, ngroups, dstate) for multi-token
|
||||
C: (batch, dstate) or (batch, ngroups, dstate)
|
||||
D: (dim,) or (nheads, dim)
|
||||
z: (batch, dim) or (batch, nheads, dim)
|
||||
@ -222,38 +316,58 @@ def selective_state_update(state,
|
||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
Return:
|
||||
out: (batch, dim) or (batch, nheads, dim)
|
||||
out: Preallocated ssm output tensor. Assume same shape as x.
|
||||
In-place updated.
|
||||
disable_state_update: If True, don't write back to state (for speculative verify)
|
||||
intermediate_states_buffer: Buffer to cache intermediate states
|
||||
cache_steps: Total number of steps in the buffer
|
||||
retrieve_parent_token: (batch, T) tensor of parent token indices for EAGLE tree attention
|
||||
intermediate_state_indices: (batch,) tensor of indices for intermediate_states_buffer operations.
|
||||
If provided, uses these indices instead of state_batch_indices for the buffer.
|
||||
"""
|
||||
has_heads = state.dim() > 3
|
||||
if state.dim() == 3:
|
||||
state = state.unsqueeze(1)
|
||||
if x.dim() == 2:
|
||||
x = x.unsqueeze(1)
|
||||
if x.dim() == 3:
|
||||
x = x.unsqueeze(1)
|
||||
if dt.dim() == 2:
|
||||
dt = dt.unsqueeze(1)
|
||||
if dt.dim() == 3:
|
||||
dt = dt.unsqueeze(1)
|
||||
if A.dim() == 2:
|
||||
A = A.unsqueeze(0)
|
||||
if B.dim() == 2:
|
||||
B = B.unsqueeze(1)
|
||||
if B.dim() == 3:
|
||||
B = B.unsqueeze(1)
|
||||
if C.dim() == 2:
|
||||
C = C.unsqueeze(1)
|
||||
if C.dim() == 3:
|
||||
C = C.unsqueeze(1)
|
||||
if D is not None and D.dim() == 1:
|
||||
D = D.unsqueeze(0)
|
||||
if z is not None and z.dim() == 2:
|
||||
z = z.unsqueeze(1)
|
||||
if z is not None:
|
||||
if z.dim() == 2:
|
||||
z = z.unsqueeze(1)
|
||||
if z.dim() == 3:
|
||||
z = z.unsqueeze(1)
|
||||
if dt_bias is not None and dt_bias.dim() == 1:
|
||||
dt_bias = dt_bias.unsqueeze(0)
|
||||
if out.dim() == 2:
|
||||
out = out.unsqueeze(1)
|
||||
if out.dim() == 3:
|
||||
out = out.unsqueeze(1)
|
||||
|
||||
_, nheads, dim, dstate = state.shape
|
||||
batch = x.shape[0]
|
||||
batch, T, _, _ = x.shape
|
||||
|
||||
assert x.shape == (batch, nheads, dim)
|
||||
assert x.shape == (batch, T, nheads, dim)
|
||||
assert dt.shape == x.shape
|
||||
assert A.shape == (nheads, dim, dstate)
|
||||
ngroups = B.shape[1]
|
||||
ngroups = B.shape[2]
|
||||
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
||||
assert B.shape == (batch, ngroups, dstate)
|
||||
assert B.shape == (batch, T, ngroups, dstate)
|
||||
assert C.shape == B.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, dim)
|
||||
@ -263,10 +377,11 @@ def selective_state_update(state,
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
if state_batch_indices is not None:
|
||||
assert state_batch_indices.shape == (batch, )
|
||||
out = torch.empty_like(x)
|
||||
assert out.shape == x.shape
|
||||
|
||||
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
|
||||
z_strides = (z.stride(0), z.stride(1),
|
||||
z.stride(2)) if z is not None else (0, 0, 0)
|
||||
z_strides = ((z.stride(0), z.stride(1), z.stride(2),
|
||||
z.stride(3)) if z is not None else (0, 0, 0, 0))
|
||||
# We don't want autotune since it will overwrite the state
|
||||
# We instead tune by hand.
|
||||
BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else
|
||||
@ -275,6 +390,12 @@ def selective_state_update(state,
|
||||
((4, 4) if dstate <= 128 else ((4, 8))))))
|
||||
tie_hdim = (A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0
|
||||
and dt_bias.stride(-1) == 0)
|
||||
|
||||
retrieve_parent_token_strides = ((retrieve_parent_token.stride(0),
|
||||
retrieve_parent_token.stride(1))
|
||||
if retrieve_parent_token is not None else
|
||||
(0, 0))
|
||||
|
||||
with torch.cuda.device(x.device.index):
|
||||
_selective_scan_update_kernel[grid](
|
||||
state,
|
||||
@ -289,7 +410,12 @@ def selective_state_update(state,
|
||||
out,
|
||||
state_batch_indices,
|
||||
pad_slot_id,
|
||||
intermediate_states_buffer,
|
||||
cache_steps if cache_steps is not None else 0,
|
||||
retrieve_parent_token,
|
||||
intermediate_state_indices,
|
||||
batch,
|
||||
T,
|
||||
nheads,
|
||||
dim,
|
||||
dstate,
|
||||
@ -301,9 +427,11 @@ def selective_state_update(state,
|
||||
x.stride(0),
|
||||
x.stride(1),
|
||||
x.stride(2),
|
||||
x.stride(3),
|
||||
dt.stride(0),
|
||||
dt.stride(1),
|
||||
dt.stride(2),
|
||||
dt.stride(3),
|
||||
*(dt_bias.stride(0),
|
||||
dt_bias.stride(1)) if dt_bias is not None else 0,
|
||||
A.stride(0),
|
||||
@ -312,21 +440,25 @@ def selective_state_update(state,
|
||||
B.stride(0),
|
||||
B.stride(1),
|
||||
B.stride(2),
|
||||
B.stride(3),
|
||||
C.stride(0),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
C.stride(3),
|
||||
*(D.stride(0), D.stride(1)) if D is not None else 0,
|
||||
z_strides[0],
|
||||
z_strides[1],
|
||||
z_strides[2],
|
||||
z_strides[3],
|
||||
out.stride(0),
|
||||
out.stride(1),
|
||||
out.stride(2),
|
||||
out.stride(3),
|
||||
retrieve_parent_token_strides[0],
|
||||
retrieve_parent_token_strides[1],
|
||||
dt_softplus,
|
||||
tie_hdim,
|
||||
BLOCK_SIZE_M,
|
||||
DISABLE_STATE_UPDATE=disable_state_update,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
if not has_heads:
|
||||
out = out.squeeze(1)
|
||||
return out
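
A hedged usage sketch of the extended selective_state_update signature shown in this diff (tensor sizes are illustrative, the import path is assumed from the module layout above, and a CUDA device is required for the Triton kernel): per-request draft tokens are folded into a T dimension, out is preallocated, and disable_state_update=True plus an intermediate_states_buffer caches per-step states without committing them.

    import torch
    # Assumed import path based on this diff's module layout.
    from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update

    # Illustrative sizes only.
    num_decodes, T, nheads, head_dim, d_state, ngroups = 4, 3, 8, 64, 128, 1
    dev = "cuda"
    state = torch.zeros(num_decodes, nheads, head_dim, d_state, device=dev)
    x = torch.randn(num_decodes, T, nheads, head_dim, device=dev)
    dt = torch.rand(num_decodes, T, nheads, head_dim, device=dev)
    dt_bias = torch.rand(nheads, head_dim, device=dev)
    A = -torch.rand(nheads, head_dim, d_state, device=dev)
    B = torch.randn(num_decodes, T, ngroups, d_state, device=dev)
    C = torch.randn(num_decodes, T, ngroups, d_state, device=dev)
    out = torch.empty_like(x)  # preallocated; written in place
    slots = torch.arange(num_decodes, dtype=torch.int32, device=dev)
    buf = torch.zeros(num_decodes, T, nheads, head_dim, d_state, device=dev)

    selective_state_update(state, x, dt, A, B, C,
                           dt_bias=dt_bias, dt_softplus=True,
                           state_batch_indices=slots,
                           out=out,
                           disable_state_update=True,   # verify path: do not commit state
                           intermediate_states_buffer=buf,
                           cache_steps=T)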
|
||||
|
||||
@ -234,6 +234,7 @@ def _chunk_scan_fwd_kernel(
|
||||
|
||||
# M-block offsets and prev states
|
||||
# - logic in next block may override these if there is an active offset
|
||||
offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
|
||||
prev_states_ptr = (states_ptr + pid_b * stride_states_batch +
|
||||
c_idx * stride_states_chunk + pid_h * stride_states_head)
|
||||
prev_states_hdim = stride_states_hdim
|
||||
@ -269,11 +270,12 @@ def _chunk_scan_fwd_kernel(
|
||||
):
|
||||
|
||||
# - replace prev_states_ptr with init_states
|
||||
prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head
|
||||
prev_states_ptr = (initstates_ptr +
|
||||
seq_idx_m * stride_init_states_batch +
|
||||
pid_h * stride_init_states_head)
|
||||
prev_states_hdim = stride_init_states_hdim # override strides
|
||||
prev_states_dstate = stride_init_states_dstate
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
|
||||
mask=offs_m < chunk_size,
|
||||
@ -289,7 +291,7 @@ def _chunk_scan_fwd_kernel(
|
||||
c_idx_n = tl.load(
|
||||
chunk_indices_ptr + (pid_c + 1),
|
||||
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
|
||||
other=-1 # to trigger different chunk
|
||||
other=-1, # to trigger different chunk
|
||||
)
|
||||
|
||||
# - there are things to consider
|
||||
@ -304,9 +306,11 @@ def _chunk_scan_fwd_kernel(
|
||||
if (c_idx == c_idx_n) or c_off > 0:
|
||||
|
||||
# get the next offset
|
||||
c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1),
|
||||
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
|
||||
other=chunk_size)
|
||||
c_off_n = tl.load(
|
||||
chunk_offsets_ptr + (pid_c + 1),
|
||||
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
|
||||
other=chunk_size,
|
||||
)
|
||||
|
||||
# in this case, adjust down the chunk_size_limit
|
||||
if c_idx == c_idx_n:
|
||||
@ -319,8 +323,9 @@ def _chunk_scan_fwd_kernel(
|
||||
# i.e. the same for all blocks)
|
||||
dA_cs_m_boundary = tl.load(
|
||||
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
|
||||
mask=(((c_off - 1) > -1) and (c_off < chunk_size)),
|
||||
other=0.0).to(tl.float32)
|
||||
mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
if HAS_SEQ_IDX:
|
||||
# - handle seq idx when HAS_INITSTATES==False
|
||||
@ -416,8 +421,7 @@ def _chunk_scan_fwd_kernel(
|
||||
other=0.0).to(tl.float32)
|
||||
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
|
||||
# So we don't need masking wrt seq_idx here.
|
||||
# cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
|
||||
cb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_k[None, :]), 0.0))
|
||||
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k,
|
||||
other=0.0).to(tl.float32)
|
||||
cb *= dt_k
|
||||
@ -494,18 +498,21 @@ def _chunk_scan_fwd_kernel(
|
||||
)
|
||||
|
||||
|
||||
def _chunk_scan_fwd(cb,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
C,
|
||||
states,
|
||||
D=None,
|
||||
z=None,
|
||||
seq_idx=None,
|
||||
chunk_indices=None,
|
||||
chunk_offsets=None,
|
||||
initial_states=None):
|
||||
def _chunk_scan_fwd(
|
||||
cb,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
C,
|
||||
states,
|
||||
D=None,
|
||||
z=None,
|
||||
seq_idx=None,
|
||||
chunk_indices=None,
|
||||
chunk_offsets=None,
|
||||
initial_states=None,
|
||||
out=None,
|
||||
):
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, nchunks, chunk_size = dt.shape
|
||||
_, _, ngroups, dstate = C.shape
|
||||
@ -527,34 +534,17 @@ def _chunk_scan_fwd(cb,
|
||||
# with initial states, we need to take care of how
|
||||
# seq_idx crosses the boundaries
|
||||
assert batch == 1, "chunk scan only supports initial states with batch 1"
|
||||
|
||||
if initial_states.shape[0] == 1:
|
||||
# in this case there is no point in using initial states
|
||||
initial_states = None
|
||||
else:
|
||||
assert chunk_indices is not None and chunk_offsets is not None, \
|
||||
(
|
||||
"chunk_indices and chunk_offsets should have been set"
|
||||
)
|
||||
assert (chunk_indices is not None and chunk_offsets is not None
|
||||
), "chunk_indices and chunk_offsets should have been set"
|
||||
else:
|
||||
chunk_indices, chunk_offsets = None, None
|
||||
else:
|
||||
chunk_indices, chunk_offsets = None, None
|
||||
|
||||
# Allocates output.
|
||||
out = torch.empty(batch,
|
||||
seqlen,
|
||||
nheads,
|
||||
headdim,
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
assert out.shape == x.shape
|
||||
|
||||
if z is not None:
|
||||
out_x = torch.empty(batch,
|
||||
seqlen,
|
||||
nheads,
|
||||
headdim,
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
out_x = torch.empty_like(x)
|
||||
assert out_x.stride() == out.stride()
|
||||
else:
|
||||
out_x = None
|
||||
@ -625,10 +615,12 @@ def _chunk_scan_fwd(cb,
|
||||
states.stride(2),
|
||||
states.stride(3),
|
||||
states.stride(4),
|
||||
*((initial_states.stride(0), initial_states.stride(1),
|
||||
initial_states.stride(2),
|
||||
initial_states.stride(3)) if initial_states is not None else
|
||||
(0, 0, 0, 0)),
|
||||
*((
|
||||
initial_states.stride(0),
|
||||
initial_states.stride(1),
|
||||
initial_states.stride(2),
|
||||
initial_states.stride(3),
|
||||
) if initial_states is not None else (0, 0, 0, 0)),
|
||||
D.stride(0) if D is not None else 0,
|
||||
True,
|
||||
D is not None,
|
||||
@ -639,4 +631,4 @@ def _chunk_scan_fwd(cb,
|
||||
IS_TRITON_22=TRITON_22,
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
)
|
||||
return out, out_x
|
||||
return out_x
|
||||
|
||||
@ -16,8 +16,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
@ -28,6 +26,10 @@ from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
|
||||
from .ssd_state_passing import _state_passing_fwd
|
||||
|
||||
|
||||
def is_int_pow_2(n):
|
||||
return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
|
||||
|
||||
|
||||
def _mamba_chunk_scan_combined_fwd(
|
||||
x,
|
||||
dt,
|
||||
@ -45,13 +47,14 @@ def _mamba_chunk_scan_combined_fwd(
|
||||
cu_seqlens=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
mamba_ssm_cache_dtype=None,
|
||||
state_dtype=None,
|
||||
out=None,
|
||||
):
|
||||
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (batch, seqlen, ngroups, dstate)
|
||||
assert x.shape == (batch, seqlen, nheads, headdim)
|
||||
assert dt.shape == (batch, seqlen, nheads)
|
||||
assert A.shape == (nheads, )
|
||||
assert C.shape == B.shape
|
||||
@ -77,8 +80,12 @@ def _mamba_chunk_scan_combined_fwd(
|
||||
if cu_seqlens is None:
|
||||
assert initial_states.shape == (batch, nheads, headdim, dstate)
|
||||
else:
|
||||
assert initial_states.shape == (len(cu_seqlens) - 1, nheads,
|
||||
headdim, dstate)
|
||||
assert initial_states.shape == (
|
||||
len(cu_seqlens) - 1,
|
||||
nheads,
|
||||
headdim,
|
||||
dstate,
|
||||
)
|
||||
|
||||
# This function executes 5 sub-functions for computing mamba
|
||||
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
|
||||
@ -125,13 +132,12 @@ def _mamba_chunk_scan_combined_fwd(
|
||||
if initial_states is not None else None),
|
||||
seq_idx=seq_idx,
|
||||
chunk_size=chunk_size,
|
||||
out_dtype=mamba_ssm_cache_dtype or C.dtype,
|
||||
out_dtype=state_dtype if state_dtype is not None else C.dtype,
|
is_cont_batched=cu_seqlens is not None,
chunk_offsets=chunk_offsets)
states, final_states = [
rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states]
]
chunk_offsets=chunk_offsets,
)
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states])

# 4. Compute batched matrix multiply for C_j^T B_i terms
CB = _bmm_chunk_fwd(C,
@ -150,20 +156,23 @@ def _mamba_chunk_scan_combined_fwd(
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
# a seq_idx change, in which case we take states information from
# init_states.
out, out_x = _chunk_scan_fwd(CB,
x,
dt,
dA_cumsum,
C,
states,
D=D,
z=z,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
initial_states=initial_states)
out_x = _chunk_scan_fwd(
CB,
x,
dt,
dA_cumsum,
C,
states,
D=D,
z=z,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
initial_states=initial_states,
out=out,
)
if cu_seqlens is None:
return out, out_x, dt, dA_cumsum, states, final_states
return out_x, dt, dA_cumsum, states, final_states
else:
assert (
batch == 1
@ -177,29 +186,31 @@ def _mamba_chunk_scan_combined_fwd(
states.squeeze(0),
initial_states=initial_states,
)
return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
return out_x, dt, dA_cumsum, states, final_states, varlen_states


def mamba_chunk_scan_combined(
x,
dt,
A,
B,
C,
chunk_size,
D=None,
z=None,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
return_final_states=False,
return_varlen_states=False,
mamba_ssm_cache_dtype: Optional[torch.dtype] = None):
x,
dt,
A,
B,
C,
chunk_size,
D=None,
z=None,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
out=None,
return_final_states=False,
return_varlen_states=False,
state_dtype=None,
):
"""
Argument:
x: (batch, seqlen, nheads, headdim)
@ -215,38 +226,42 @@ def mamba_chunk_scan_combined(
seq_idx: (batch, seqlen)
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
dt_softplus: Whether to apply softplus to dt
mamba_ssm_cache_dtype: torch.dtype, default to None
Return:
out: (batch, seqlen, nheads, headdim)
out: Preallocated output tensor
state_dtype: The data type of the ssm state
"""

if not return_varlen_states:
cu_seqlens = None
else:
assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
D=D,
z=z,
dt_bias=dt_bias,
initial_states=initial_states,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
mamba_ssm_cache_dtype=mamba_ssm_cache_dtype)
assert (cu_seqlens is not None
), "cu_seqlens must be provided if return_varlen_states is True"
out_x, dt_out, dA_cumsum, states, final_states, *rest = (
_mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
D=D,
z=z,
dt_bias=dt_bias,
initial_states=initial_states,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
out=out,
state_dtype=state_dtype,
))
if not return_varlen_states:
return out if not return_final_states else (out, final_states)
if not return_final_states:
return
else:
return final_states
else:
varlen_states = rest[0]
return (out,
varlen_states) if not return_final_states else (out,
final_states,
varlen_states)
return ((varlen_states) if not return_final_states else
(final_states, varlen_states))
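Taken together, the hunks above switch mamba_chunk_scan_combined from allocating and returning its output to writing into a caller-owned buffer passed via out=, and replace the mamba_ssm_cache_dtype argument with state_dtype. A minimal sketch of the new calling convention follows; the import path, shapes, and dtypes are illustrative assumptions (the updated tests further down use the same pattern), and the kernel itself requires a CUDA device.

import torch

# Assumed import path for the Triton scan kernel; adjust to the local module layout.
from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined

batch, seqlen, nheads, headdim = 1, 256, 8, 64
ngroups, dstate, chunk_size = 1, 128, 256
device, dtype = "cuda", torch.bfloat16

x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
dt = torch.rand(batch, seqlen, nheads, device=device, dtype=torch.float32)
A = -torch.rand(nheads, device=device, dtype=torch.float32)
B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)

# The caller now owns the output buffer; the kernel writes into `out` in place
# instead of allocating and returning a new tensor.
out = torch.empty_like(x)
final_states = mamba_chunk_scan_combined(
    x, dt, A, B, C, chunk_size,
    dt_softplus=True,
    out=out,
    return_final_states=True,
    state_dtype=torch.float32,  # renamed knob for the SSM state dtype, per the updated docstring
)
# `out` holds the scan output; only the state tensor is returned.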
@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Union
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import torch

@ -22,11 +23,45 @@ from tensorrt_llm._torch.pyexecutor.resource_manager import (
BaseResourceManager, CacheTypeCpp, DataType, KVCacheManager, get_pp_layers)
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

if TYPE_CHECKING:
from tensorrt_llm._torch.attention_backend.interface import \
AttentionMetadata

GB = 1 << 30


def get_tensor_size_bytes(tensor):
"""Calculate tensor size in bytes."""
if isinstance(tensor, torch.Tensor):
return tensor.element_size() * tensor.nelement()
elif isinstance(tensor, list):
return sum(get_tensor_size_bytes(t) for t in tensor)
return 0
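get_tensor_size_bytes is used by the allocation path below to report cache sizes in GB; it handles both a single tensor and a list of tensors. A small CPU-only usage sketch with illustrative shapes:

import torch

GB = 1 << 30

def get_tensor_size_bytes(tensor):
    """Same helper as in the hunk above: element_size * nelement, lists summed recursively."""
    if isinstance(tensor, torch.Tensor):
        return tensor.element_size() * tensor.nelement()
    elif isinstance(tensor, list):
        return sum(get_tensor_size_bytes(t) for t in tensor)
    return 0

# Illustrative cache shapes, much smaller than a real deployment.
conv = torch.empty(4, 8, 1024, 3, dtype=torch.bfloat16)
ssm = torch.empty(4, 8, 64, 64, 128, dtype=torch.float32)
print(f"conv_state size: {get_tensor_size_bytes(conv) / GB:.2f}GB")
print(f"ssm_state size: {get_tensor_size_bytes([ssm]) / GB:.2f}GB")  # a list is summed recursively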
class MambaCacheManager(BaseResourceManager):

@dataclass(frozen=True, kw_only=True)
class State:
"""Base state container for Mamba cache."""
conv: torch.Tensor
temporal: torch.Tensor

def at_layer_idx(self, layer: int):
kwargs = {}
for k, v in vars(self).items():
kwargs[k] = v[layer]
return type(self)(**kwargs)

@dataclass(frozen=True, kw_only=True)
class SpeculativeState(State):
"""Speculative state with intermediate states for draft tokens."""
intermediate_ssm: torch.Tensor
intermediate_conv_window: torch.Tensor
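State.at_layer_idx rebuilds the (possibly subclassed) frozen container with every field sliced at the given layer index, which is what the per-layer accessors further down rely on. A standalone sketch of the same pattern, with the class redefined at top level and illustrative shapes:

from dataclasses import dataclass

import torch

@dataclass(frozen=True, kw_only=True)
class State:
    conv: torch.Tensor
    temporal: torch.Tensor

    def at_layer_idx(self, layer: int):
        # Slice every field along the leading (layer) dimension and rebuild
        # the same frozen container type.
        return type(self)(**{k: v[layer] for k, v in vars(self).items()})

num_layers, max_batch = 4, 8
cache = State(
    conv=torch.zeros(num_layers, max_batch, 16, 3),        # illustrative shapes
    temporal=torch.zeros(num_layers, max_batch, 8, 4, 32),
)
layer0 = cache.at_layer_idx(0)
assert layer0.conv.shape == (max_batch, 16, 3)
assert layer0.temporal.shape == (max_batch, 8, 4, 32)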
def __init__(
self,
d_state: int,
@ -36,13 +71,17 @@ class MambaCacheManager(BaseResourceManager):
head_dim: int,
num_layers: int,
max_batch_size: int,
spec_state_size: int,
mapping: Mapping,
dtype: torch.dtype,
ssm_cache_dtype: torch.dtype,
layer_mask: Optional[List[bool]] = None,
speculative_num_draft_tokens: Optional[int] = None,
) -> None:

self.mamba_ssm_cache_dtype = ssm_cache_dtype
self.speculative_num_draft_tokens = speculative_num_draft_tokens
self.spec_state_size = spec_state_size

# get tp size
tp_size = mapping.tp_size if not mapping.enable_attention_dp else 1
@ -74,31 +113,69 @@ class MambaCacheManager(BaseResourceManager):
for offset, idx in enumerate(pp_layers)
}

# mamba conv states
self.conv_states = torch.empty(
size=[
num_local_layers,
max_batch_size,
conv_dim,
d_conv - 1,
],
conv_state_shape = (conv_dim, d_conv - 1)
ssm_state_shape = (nheads, head_dim, d_state)

# create mamba conv and ssm states
conv_states = torch.empty(
size=(num_local_layers, max_batch_size) + conv_state_shape,
dtype=dtype,
device=device,
)

# mamba ssm states
self.ssm_states = torch.empty(
size=[
num_local_layers,
max_batch_size,
nheads,
head_dim,
d_state,
],
ssm_states = torch.empty(
size=(num_local_layers, max_batch_size) + ssm_state_shape,
dtype=self.mamba_ssm_cache_dtype,
device=device,
)

# create state container
if speculative_num_draft_tokens is not None:

# Cache intermediate SSM states per draft token(include new sampled token) during target model verification phase
intermediate_ssm_states = torch.zeros(
size=(num_local_layers, self.spec_state_size,
speculative_num_draft_tokens + 1) + ssm_state_shape,
dtype=self.mamba_ssm_cache_dtype,
device=device,
)

# Cache intermediate conv windows per draft token(include new sampled token) during target model verification phase
intermediate_conv_window_cache = torch.zeros(
size=(num_local_layers, self.spec_state_size,
speculative_num_draft_tokens + 1) + conv_state_shape,
dtype=dtype,
device=device,
)

self.mamba_cache = self.SpeculativeState(
conv=conv_states,
temporal=ssm_states,
intermediate_ssm=intermediate_ssm_states,
intermediate_conv_window=intermediate_conv_window_cache,
)

logger.info(
f"Mamba Cache is allocated. "
f"max_mamba_cache_size: {max_batch_size}, "
f"conv_state size: {get_tensor_size_bytes(conv_states) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(ssm_states) / GB:.2f}GB, "
f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_states) / GB:.2f}GB, "
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB"
)
else:
self.mamba_cache = self.State(
conv=conv_states,
temporal=ssm_states,
)

logger.info(
f"Mamba Cache is allocated. "
f"max_mamba_cache_size: {max_batch_size}, "
f"conv_state size: {get_tensor_size_bytes(conv_states) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(ssm_states) / GB:.2f}GB"
)

# mamba cache available blocks
self.mamba_cache_free_blocks = [i for i in range(max_batch_size)]

@ -111,6 +188,10 @@ class MambaCacheManager(BaseResourceManager):
dtype=torch.int32)
# save mamba state indices for requests
self.state_indices_list: List[int] = []
# save intermediate state indices for requests
self.intermediate_state_indices = torch.arange(max_batch_size,
dtype=torch.int32,
device=device)

def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
self.state_indices_list.clear()
@ -184,22 +265,85 @@ class MambaCacheManager(BaseResourceManager):

def get_conv_states(self, layer_idx: int) -> torch.Tensor:
layer_offset = self.mamba_layer_offsets[layer_idx]
return self.conv_states[layer_offset]
return self.mamba_cache.at_layer_idx(layer_offset).conv

def get_ssm_states(self, layer_idx: int) -> torch.Tensor:
layer_offset = self.mamba_layer_offsets[layer_idx]
return self.ssm_states[layer_offset]
return self.mamba_cache.at_layer_idx(layer_offset).temporal

def get_mamba_ssm_cache_dtype(self) -> torch.dtype:
return self.mamba_ssm_cache_dtype
def get_intermediate_ssm_states(self,
layer_idx: int) -> Optional[torch.Tensor]:
if not isinstance(self.mamba_cache, self.SpeculativeState):
return None
layer_offset = self.mamba_layer_offsets[layer_idx]
return self.mamba_cache.at_layer_idx(layer_offset).intermediate_ssm

def get_intermediate_conv_states(self,
layer_idx: int) -> Optional[torch.Tensor]:
"""Get intermediate conv states for speculative decoding."""
if not isinstance(self.mamba_cache, self.SpeculativeState):
return None
layer_offset = self.mamba_layer_offsets[layer_idx]
return self.mamba_cache.at_layer_idx(
layer_offset).intermediate_conv_window

def is_speculative(self) -> bool:
return isinstance(self.mamba_cache, self.SpeculativeState)

def mamba_layer_cache(self,
layer_idx: int) -> Union[State, SpeculativeState]:
layer_offset = self.mamba_layer_offsets[layer_idx]
return self.mamba_cache.at_layer_idx(layer_offset)

def shutdown(self):
# release tensor memory, keeping python references as tensors
self.conv_states = torch.tensor([])
self.ssm_states = torch.tensor([])
"""Release tensor memory."""
# Clear state indices
self.state_indices = torch.tensor([])

# Clear mamba cache states
if isinstance(self.mamba_cache, self.SpeculativeState):
self.mamba_cache = self.SpeculativeState(
conv=torch.tensor([]),
temporal=torch.tensor([]),
intermediate_ssm=torch.tensor([]),
intermediate_conv_window=torch.tensor([]),
)
else:
self.mamba_cache = self.State(
conv=torch.tensor([]),
temporal=torch.tensor([]),
)

torch.cuda.empty_cache()

@torch.compile(options={"max-autotune": True})
def update_mamba_states(self, attn_metadata: "AttentionMetadata",
num_accepted_tokens: torch.Tensor):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
num_accepted_draft_tokens = num_accepted_tokens[
num_contexts:num_contexts + num_gens] - 1
state_indices_d = self.state_indices[num_contexts:num_contexts +
num_gens]

conv_states = self.mamba_cache.conv
ssm_states = self.mamba_cache.temporal

intermediate_state_cache = self.mamba_cache.intermediate_ssm
intermediate_conv_window_cache = self.mamba_cache.intermediate_conv_window

src_state_indices = self.intermediate_state_indices[:num_gens]

accepted_ssm_state = intermediate_state_cache[:, src_state_indices,
num_accepted_draft_tokens]
ssm_states[:, state_indices_d, :] = accepted_ssm_state

accepted_conv_state = intermediate_conv_window_cache[:,
src_state_indices,
num_accepted_draft_tokens]
conv_states[:, state_indices_d, :] = accepted_conv_state
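update_mamba_states copies, for each generation request, the intermediate state recorded at its last accepted draft position back into the persistent per-slot caches, using one advanced-indexing gather per cache. A CPU-only sketch of that indexing with illustrative shapes and hypothetical slot assignments:

import torch

layers, slots, draft_plus_one, heads, head_dim, d_state = 2, 8, 4, 3, 5, 7
num_gens = 3

# Persistent per-slot SSM cache and the per-draft-token intermediate cache.
ssm_states = torch.zeros(layers, slots, heads, head_dim, d_state)
intermediate = torch.randn(layers, slots, draft_plus_one, heads, head_dim, d_state)

# One accepted-token count per generation request (1 == only the sampled token).
num_accepted_tokens = torch.tensor([1, 3, 2])
num_accepted_draft_tokens = num_accepted_tokens - 1        # index along the draft axis
state_indices_d = torch.tensor([5, 0, 2])                  # cache slot per request (hypothetical)
src_state_indices = torch.arange(num_gens)                 # row in the intermediate cache

# Advanced indexing pairs src_state_indices[i] with num_accepted_draft_tokens[i],
# selecting one accepted state per request for every layer: (layers, num_gens, ...).
accepted = intermediate[:, src_state_indices, num_accepted_draft_tokens]
ssm_states[:, state_indices_d, :] = accepted

# Request 1 accepted two draft tokens, so slot 0 now holds intermediate[:, 1, 2].
assert torch.equal(ssm_states[:, 0], intermediate[:, 1, 2])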
class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):

@ -249,10 +393,13 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
mamba_head_dim,
mamba_num_layers,
max_batch_size,
max_batch_size,
mapping,
mamba_cache_dtype,
mamba_ssm_cache_dtype,
mamba_layer_mask,
speculative_num_draft_tokens=spec_config.max_draft_len
if spec_config is not None else None,
)

# initialize kv cache manager
@ -285,3 +432,15 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
def shutdown(self):
MambaCacheManager.shutdown(self)
KVCacheManager.shutdown(self)

def update_resources(self,
scheduled_batch: ScheduledRequests,
attn_metadata: "AttentionMetadata" = None,
kv_cache_dtype_byte_size: float = None):
KVCacheManager.update_resources(self, scheduled_batch, attn_metadata,
kv_cache_dtype_byte_size)

def update_mamba_states(self, attn_metadata: "AttentionMetadata",
num_accepted_tokens: torch.Tensor):
MambaCacheManager.update_mamba_states(self, attn_metadata,
num_accepted_tokens)

@ -1469,10 +1469,12 @@ class ResourceManager:
resource_manager.prepare_resources(scheduled_batch)

@nvtx_range("update_resources")
def update_resources(self,
scheduled_batch: ScheduledRequests,
attn_metadata: Optional["AttentionMetadata"] = None,
kv_cache_dtype_byte_size: Optional[float] = None):
def update_resources(
self,
scheduled_batch: ScheduledRequests,
attn_metadata: Optional["AttentionMetadata"] = None,
kv_cache_dtype_byte_size: Optional[float] = None,
):
for _, resource_manager in self.resource_managers.items():
if hasattr(resource_manager, "update_resources"):
if isinstance(resource_manager, KVCacheManager):

@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, List, Optional
import torch
import torch.nn.functional as F

from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \
MambaHybridCacheManager
from tensorrt_llm.mapping import Mapping

from ..attention_backend import AttentionMetadata
@ -1132,6 +1134,7 @@ class MTPEagleWorker(MTPWorker):
super().__init__(spec_config, model_config)
self.model_config = model_config
self.mtp_num_modules = spec_config.num_nextn_predict_layers
self._is_mamba_hybrid_cache = None

@torch.compile(options={"max-autotune": True})
def update_draft_tokens(self, next_draft_tokens, new_draft_token,
@ -1164,6 +1167,14 @@ class MTPEagleWorker(MTPWorker):
accepted_tokens, num_accepted_tokens = self.sample_and_accept_draft_tokens(
input_ids, logits, spec_metadata, attn_metadata)

if self._is_mamba_hybrid_cache is None:
self._is_mamba_hybrid_cache = isinstance(
attn_metadata.kv_cache_manager, MambaHybridCacheManager)
if num_gens > 0 and self._is_mamba_hybrid_cache:
attn_metadata.kv_cache_manager.update_mamba_states(
attn_metadata=attn_metadata,
num_accepted_tokens=num_accepted_tokens)

# Save the old attn_metadata and spec_metadata
self._prepare_attn_metadata_for_spec_dec(attn_metadata)
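The hunk above gates the Mamba state rollback on the cache-manager type and caches the isinstance check after the first step. A sketch of the same control flow, where worker, attn_metadata, num_accepted_tokens, and num_gens are stand-ins for the MTPEagleWorker instance and its per-step inputs:

from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import \
    MambaHybridCacheManager

def maybe_update_mamba_states(worker, attn_metadata, num_accepted_tokens, num_gens):
    """Roll the Mamba states forward only for hybrid caches and only when
    generation requests exist; only the attributes used in the hunk above are assumed."""
    if worker._is_mamba_hybrid_cache is None:
        # The isinstance check is cached so it runs once per worker, not per step.
        worker._is_mamba_hybrid_cache = isinstance(
            attn_metadata.kv_cache_manager, MambaHybridCacheManager)
    if num_gens > 0 and worker._is_mamba_hybrid_cache:
        attn_metadata.kv_cache_manager.update_mamba_states(
            attn_metadata=attn_metadata,
            num_accepted_tokens=num_accepted_tokens)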
@ -348,6 +348,11 @@ nvidia/Nemotron-Super-V3:
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
accuracy: 80.85
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
mtp_enabled: true
num_nextn_predict_layers: 3
accuracy: 80.85
nvidia/Nemotron-3-Nano:
- accuracy: 69.37
- quant_algo: FP8

@ -382,6 +382,11 @@ nvidia/Nemotron-Super-V3:
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
accuracy: 77.56
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
mtp_enabled: true
num_nextn_predict_layers: 3
accuracy: 77.56
nvidia/Nemotron-3-Nano:
- accuracy: 73.85
- quant_algo: FP8

@ -5800,6 +5800,42 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
task.evaluate(llm,
extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)

@pytest.mark.skip(reason="Skip MTP test due to no model path file in CI")
@skip_pre_blackwell
@pytest.mark.skip_less_mpi_world_size(8)
def test_nvfp4_8gpus_mtp(self):
# Test MTP (Multi-Token Prediction) accuracy with nvfp4-fp8kv model.
# This test uses MTP with max_draft_len=3 and one_model mode.
mtp_config = MTPDecodingConfig(
num_nextn_predict_layers=3,
mtp_eagle_one_model=True,
)
model_path = f"{llm_models_root()}/nemotron-super-sft-repeated-mtp-iter-0010600-nvfp4-fp8kv"
with LLM(
model_path,
kv_cache_config=KvCacheConfig(
enable_block_reuse=False,
mamba_ssm_cache_dtype="float16",
free_gpu_memory_fraction=0.5,
),
max_batch_size=128,
tensor_parallel_size=8,
moe_expert_parallel_size=8,
pipeline_parallel_size=1,
enable_attention_dp=False,
cuda_graph_config=CudaGraphConfig(max_batch_size=32,
enable_padding=True),
disable_overlap_scheduler=False,
moe_config=MoeConfig(backend="CUTLASS"),
decoding_config=mtp_config,
) as llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm,
extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm,
extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)


@skip_pre_hopper
class TestMiniMaxM2(LlmapiAccuracyTestHarness):

@ -214,7 +214,8 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate,
assert remove_padding
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
cu_seqlens, chunk_size)
out, ssm_state = mamba_chunk_scan_combined(
out = torch.empty_like(x)
ssm_state = mamba_chunk_scan_combined(
x,
dt,
A,
@ -232,6 +233,7 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate,
dt_softplus=delta_softplus,
return_final_states=not remove_padding,
return_varlen_states=remove_padding,
out=out,
)

if (ssm_state.shape[0] > 1 and ssm_state.dtype == torch.float32
@ -257,7 +259,8 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate,
else:
state_batch_indices = None

y = selective_state_update(
y = torch.empty_like(x)
selective_state_update(
state,
x,
dt,
@ -269,6 +272,7 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate,
dt_bias=dt_bias,
dt_softplus=delta_softplus,
state_batch_indices=state_batch_indices,
out=y,
)
outputs = (y, state[state_batch_indices]
if state_batch_indices is not None else state)
@ -432,7 +436,8 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens):
z = torch.randn_like(x)

## full seqlen computation
out_ref, state_ref = mamba_chunk_scan_combined(
out_ref = torch.empty_like(x)
state_ref = mamba_chunk_scan_combined(
x,
dt,
A,
@ -447,6 +452,7 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens):
dt_softplus=delta_softplus,
return_final_states=False,
return_varlen_states=True,
out=out_ref,
)

## chunked seqlen computation
@ -478,7 +484,8 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens):
z_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(z, i)
# yapf: enable

partial_out, partial_state = mamba_chunk_scan_combined(
partial_out = torch.empty_like(x_chunked)
partial_state = mamba_chunk_scan_combined(
x_chunked,
dt_chunked,
A,
@ -493,6 +500,7 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens):
dt_softplus=delta_softplus,
return_final_states=False,
return_varlen_states=True,
out=partial_out,
)

# remaining chunk
@ -542,7 +550,8 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens):
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
remaining_chunked_cu_seqlens, mamba_chunk_size)

out_chunked, state_chunked = mamba_chunk_scan_combined(
out_chunked = torch.empty_like(remaining_x_chunked)
state_chunked = mamba_chunk_scan_combined(
remaining_x_chunked,
remaining_dt_chunked,
A,
@ -560,6 +569,7 @@ def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens):
dt_softplus=delta_softplus,
return_final_states=False,
return_varlen_states=True,
out=out_chunked,
)
out = concat_batch_f(partial_out, out_chunked)