Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-04 18:21:52 +08:00
[None][feat] AutoDeploy: Flashinfer kernels bringup (#10867)
Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>
This commit is contained in: parent 0ad87895f5, commit e033929221
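This commit adds a flashinfer-backed SSM backend ("flashinfer_ssm") to AutoDeploy next to the existing "triton_ssm" one. A minimal sketch of how a user opts into it, mirroring the TestNemotronH accuracy test further down in this diff (the surrounding kwargs plumbing is illustrative, not the only entry point):

# Sketch: select the new SSM backend through AutoDeploy's transforms config.
kwargs = {
    "transforms": {
        "insert_cached_ssm_attention": {
            "backend": "flashinfer_ssm",  # or "triton_ssm"
            # flashinfer's selective_state_update needs a dense, contiguous state
            # cache; the accuracy test pins the mamba cache dtype to bfloat16.
            "cache_config": {"mamba_dtype": "bfloat16"},
        },
    },
}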
@@ -113,7 +113,7 @@ class _FlashInferPlanner:
        self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
            self.workspace_buffer,
            self.kv_layout,
            backend="fa2",
            backend="fa2" if torch.cuda.get_device_capability(0) == (9, 0) else "auto",
        )
        self.decode_wrapper = self._init_decode_wrapper()

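For context, the changed wrapper argument amounts to a device-capability check (a sketch using the standard torch.cuda API; the wrapper construction itself is unchanged):

import torch

# Pin flashinfer's prefill wrapper to the FlashAttention-2 backend on Hopper
# (compute capability 9.0) and let flashinfer auto-select on other GPUs.
backend = "fa2" if torch.cuda.get_device_capability(0) == (9, 0) else "auto"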
@@ -0,0 +1,246 @@
from typing import List

import torch
from torch.fx import Node

from .....llmapi.llm_args import KvCacheConfig
from ..attention_interface import (
    AttentionRegistry,
    MHACallable,
    ResourceHandler,
    ResourceHandlerDict,
    SequenceInfo,
)
from .mamba_backend_common import (
    BaseBackendSSM,
    _flatten_ssm_inputs,
    _prepare_ssm_decode_inputs,
    _run_ssm_prefill,
)


@torch.library.custom_op("auto_deploy::flashinfer_cached_ssm", mutates_args={})
def _flashinfer_cached_ssm(
    # INPUTS (dense but may be flattened across sequences)
    hidden_states: torch.Tensor,  # [b, s, num_heads, head_dim]
    A: torch.Tensor,  # [num_heads]
    B: torch.Tensor,  # [b, s, n_groups, ssm_state_size]
    C: torch.Tensor,  # [b, s, n_groups, ssm_state_size]
    D: torch.Tensor,  # [num_heads]
    dt: torch.Tensor,  # [b, s, num_heads]
    dt_bias: torch.Tensor,  # [num_heads]
    # STANDARD METADATA
    batch_info_host: torch.Tensor,
    cu_seqlen: torch.Tensor,
    slot_idx: torch.Tensor,
    use_initial_states: torch.Tensor,
    # EXTRA METADATA
    chunk_indices: torch.Tensor,  # [num_logical_chunks]
    chunk_offsets: torch.Tensor,  # [num_logical_chunks]
    seq_idx_prefill: torch.Tensor,  # [1, num_prefill_tokens]
    # CACHES
    ssm_state_cache: torch.Tensor,  # [max_batch_size, num_heads, head_dim, ssm_state_size]
    # CONSTANTS
    time_step_limit: List[float],
    chunk_size: int,
) -> torch.Tensor:
    """Flattened cached SSM op that respects slot-indexed state caches.

    Prefill tokens go through the shared varlen chunk scan (_run_ssm_prefill);
    decode tokens are handled by flashinfer's mamba selective_state_update kernel.
    """
    b, s, num_heads, head_dim, bs, hs_flat, B_flat, C_flat, dt_flat = _flatten_ssm_inputs(
        hidden_states, B, C, dt
    )
    ssm_state_size = B.shape[3]
    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
    num_seq = num_prefill + num_decode
    num_total_tokens = num_prefill_tokens + num_decode

    # Preallocate output tensor to avoid memcpy cost for merging prefill
    # and decode outputs
    preallocated_ssm_out = torch.empty(
        [bs, num_heads, head_dim],
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    preallocated_ssm_out_p = preallocated_ssm_out[:num_prefill_tokens]

    num_prefill, num_prefill_tokens, num_total_tokens, num_seq = _run_ssm_prefill(
        hs_flat,
        B_flat,
        C_flat,
        dt_flat,
        A,
        D,
        dt_bias,
        batch_info_host,
        cu_seqlen,
        slot_idx,
        use_initial_states,
        chunk_indices,
        chunk_offsets,
        seq_idx_prefill,
        ssm_state_cache,
        time_step_limit,
        chunk_size,
        preallocated_ssm_out_p.unsqueeze(0),
    )

    num_decode = num_total_tokens - num_prefill_tokens
    decode_inputs = _prepare_ssm_decode_inputs(
        hs_flat,
        B_flat,
        C_flat,
        dt_flat,
        A,
        D,
        dt_bias,
        slot_idx,
        num_prefill,
        num_prefill_tokens,
        num_seq,
        num_total_tokens,
        num_heads,
        head_dim,
        ssm_state_size,
    )

    y_decode = None
    if decode_inputs is not None:
        (
            slot_idx_decode,
            x_decode,
            B_decode,
            C_decode,
            dt_hp,
            dt_bias_hp,
            A_full,
            D_full,
        ) = decode_inputs

        import flashinfer

        slot_idx_decode_i32 = slot_idx_decode.to(torch.int32)
        y_decode = flashinfer.mamba.selective_state_update(
            ssm_state_cache,
            x_decode,
            dt_hp,
            A_full,
            B_decode,
            C_decode,
            D=D_full,
            z=None,
            dt_bias=dt_bias_hp,
            dt_softplus=True,
            state_batch_indices=slot_idx_decode_i32,
        )
        preallocated_ssm_out[num_prefill_tokens:num_total_tokens].copy_(y_decode)

    if num_total_tokens > 0:
        return (
            preallocated_ssm_out[:num_total_tokens]
            .view(b, s, num_heads, head_dim)
            .to(hidden_states.dtype)
        )
    else:
        return torch.empty_like(hidden_states)


@_flashinfer_cached_ssm.register_fake
def _flashinfer_cached_ssm_fake(
    # INPUTS (dense but may be flattened across sequences)
    hidden_states: torch.Tensor,  # [b, s, num_heads, head_dim]
    A: torch.Tensor,  # [num_heads]
    B: torch.Tensor,  # [b, s, n_groups, ssm_state_size]
    C: torch.Tensor,  # [b, s, n_groups, ssm_state_size]
    D: torch.Tensor,  # [num_heads]
    dt: torch.Tensor,  # [b, s, num_heads]
    dt_bias: torch.Tensor,  # [num_heads]
    # STANDARD METADATA
    batch_info_host: torch.Tensor,
    cu_seqlen: torch.Tensor,
    slot_idx: torch.Tensor,
    use_initial_states: torch.Tensor,
    # EXTRA METADATA
    chunk_indices: torch.Tensor,  # [num_logical_chunks]
    chunk_offsets: torch.Tensor,  # [num_logical_chunks]
    seq_idx_prefill: torch.Tensor,  # [1, num_prefill_tokens]
    # CACHES
    ssm_state_cache: torch.Tensor,  # [max_batch_size, num_heads, head_dim, ssm_state_size]
    # CONSTANTS
    time_step_limit: List[float],
    chunk_size: int,
):
    # Return a correctly-shaped tensor for tracing with fake tensors
    return torch.empty_like(
        hidden_states,
        memory_format=torch.contiguous_format,
        dtype=hidden_states.dtype,
    )


# Flashinfer's selective_state_update kernel only supports these head dimensions
FLASHINFER_SUPPORTED_HEAD_DIMS = [64, 128]


class FlashInferStateResourceHandler(ResourceHandler):
    """Handler for flashinfer SSM state resources.

    Unlike the default StateResourceHandler, which uses byte-level pooling (resulting
    in non-contiguous strided views), this handler allocates a separate contiguous
    buffer. This is required because flashinfer's selective_state_update kernel
    requires the entire state tensor to be contiguous.
    """

    def __init__(self, *state_shape: int, dtype: torch.dtype) -> None:
        self.state_shape = state_shape
        self.dtype = dtype

    def allocate(self, sequence_info: SequenceInfo) -> torch.Tensor:
        """Allocate a contiguous state buffer for flashinfer."""
        return torch.empty(
            sequence_info.max_num_state_slots,
            *self.state_shape,
            device=sequence_info.device,
            dtype=self.dtype,
        )


@AttentionRegistry.register("flashinfer_ssm")
class FlashinferBackendSSM(BaseBackendSSM):
    @classmethod
    def get_cached_attention_op(cls) -> MHACallable:
        return torch.ops.auto_deploy.flashinfer_cached_ssm.default

    @classmethod
    def get_cache_initializers(
        cls, source_attn_node: Node, cache_config: KvCacheConfig
    ) -> ResourceHandlerDict:
        """Get cache initializers using FlashInferStateResourceHandler.

        We use a custom handler that allocates contiguous buffers directly,
        instead of the default StateResourceHandler which creates non-contiguous
        views from a shared byte buffer. This is required because flashinfer's
        selective_state_update kernel requires contiguous state tensors.
        """
        # Shapes from fake tensors
        hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
        B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]

        num_heads = hs_fake.shape[-2]
        head_dim = hs_fake.shape[-1]

        # Validate head_dim is supported by flashinfer
        if head_dim not in FLASHINFER_SUPPORTED_HEAD_DIMS:
            raise ValueError(
                f"Flashinfer SSM backend only supports head_dim in {FLASHINFER_SUPPORTED_HEAD_DIMS}, "
                f"but got head_dim={head_dim}. Consider using 'triton_ssm' backend instead."
            )

        if B_fake.ndim >= 4:
            ssm_state_size = B_fake.shape[-1]
        else:
            ssm_state_size = max(1, B_fake.shape[-1])

        # Extract ssm_state_dtype from cache_config or hs_fake
        ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype)

        return {
            "ssm_state_cache": FlashInferStateResourceHandler(
                num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype
            )
        }
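A quick illustration of the contiguity requirement that motivates FlashInferStateResourceHandler; the pooled layout below is illustrative only, not the actual byte-pool implementation behind StateResourceHandler:

import torch

# A dedicated allocation, like FlashInferStateResourceHandler.allocate returns, is contiguous.
state = torch.empty(4, 8, 64, 128, device="cuda", dtype=torch.bfloat16)
assert state.is_contiguous()

# A per-resource view carved out of a larger shared buffer strides over the pool
# dimension and loses contiguity, which flashinfer's selective_state_update rejects.
pool = torch.empty(4, 2, 8, 64, 128, device="cuda", dtype=torch.bfloat16)
pooled_view = pool[:, 0]
assert not pooled_view.is_contiguous()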
@@ -0,0 +1,291 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple

import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node

from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chunk_indices_offsets
from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined

from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
    AttentionDescriptor,
    AttentionLayout,
    Constant,
    PrepareMetadataCallable,
    ResourceHandlerDict,
    StateResourceHandler,
)


@torch.library.custom_op("auto_deploy::mamba_ssm_prepare_metadata", mutates_args=())
def _mamba_ssm_prepare_metadata(
    # INPUTS
    position_ids: torch.Tensor,
    batch_info_host: torch.Tensor,
    seq_len: torch.Tensor,
    cu_seqlen: torch.Tensor,
    # EXTRA METADATA PROVIDED BY THE DESCRIPTOR
    chunk_size: int,
) -> List[torch.Tensor]:
    """Prepare metadata for cached SSM transform.

    Returns a tuple of (chunk_indices, chunk_offsets, seq_idx_prefill).
    """
    device = cu_seqlen.device
    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()

    if num_prefill > 0:
        chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
            cu_seqlen[: num_prefill + 1], chunk_size
        )
        seq_idx_prefill = torch.repeat_interleave(
            torch.arange(num_prefill, device=device, dtype=torch.int32), seq_len[:num_prefill]
        ).view(1, -1)
    else:
        chunk_indices = torch.empty(0, dtype=torch.int32, device=device)
        chunk_offsets = torch.empty(0, dtype=torch.int32, device=device)
        seq_idx_prefill = torch.empty(1, 0, dtype=torch.int32, device=device)

    return (chunk_indices, chunk_offsets, seq_idx_prefill)


@_mamba_ssm_prepare_metadata.register_fake
def _mamba_ssm_prepare_metadata_fake(
    # INPUTS
    position_ids: torch.Tensor,
    batch_info_host: torch.Tensor,
    seq_len: torch.Tensor,
    cu_seqlen: torch.Tensor,
    # EXTRA METADATA PROVIDED BY THE DESCRIPTOR
    chunk_size: int,
):
    b, s = position_ids.shape[:2]
    num_tokens = b * s
    device = cu_seqlen.device
    dtype = torch.int32
    if s > 1:
        # NOTE: this is only an upper bound for the shape in this case...
        return (
            torch.empty(num_tokens, dtype=dtype, device=device),  # chunk_indices
            torch.empty(num_tokens, dtype=dtype, device=device),  # chunk_offsets
            torch.empty(1, num_tokens, dtype=dtype, device=device),  # seq_idx_prefill
        )
    else:
        return (
            torch.empty(0, dtype=dtype, device=device),  # chunk_indices
            torch.empty(0, dtype=dtype, device=device),  # chunk_offsets
            torch.empty(1, 0, dtype=dtype, device=device),  # seq_idx_prefill
        )

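A small worked example of the prefill metadata above, assuming two prefill requests of lengths 3 and 2: seq_idx_prefill tags every flattened prefill token with its sequence index so the varlen chunk scan can keep sequences separate.

import torch

seq_len = torch.tensor([3, 2], dtype=torch.int32)
num_prefill = 2
seq_idx_prefill = torch.repeat_interleave(
    torch.arange(num_prefill, dtype=torch.int32), seq_len[:num_prefill]
).view(1, -1)
# tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)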
def _flatten_ssm_inputs(
    hidden_states: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    dt: torch.Tensor,
) -> Tuple[int, int, int, int, int, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Flatten [b, s, ...] SSM inputs into token-major [b * s, ...] tensors."""
    b, s, num_heads, head_dim = hidden_states.shape
    bs = b * s
    hs_flat = hidden_states.reshape(bs, *hidden_states.shape[2:])  # [bs, H, D]
    B_flat = B.reshape(bs, *B.shape[2:])  # [bs, G, N]
    C_flat = C.reshape(bs, *C.shape[2:])  # [bs, G, N]
    dt_flat = dt.reshape(bs, dt.shape[2])  # [bs, H]
    return b, s, num_heads, head_dim, bs, hs_flat, B_flat, C_flat, dt_flat


def _run_ssm_prefill(
    hs_flat: torch.Tensor,
    B_flat: torch.Tensor,
    C_flat: torch.Tensor,
    dt_flat: torch.Tensor,
    A: torch.Tensor,
    D: torch.Tensor,
    dt_bias: torch.Tensor,
    batch_info_host: torch.Tensor,
    cu_seqlen: torch.Tensor,
    slot_idx: torch.Tensor,
    use_initial_states: torch.Tensor,
    chunk_indices: torch.Tensor,
    chunk_offsets: torch.Tensor,
    seq_idx_prefill: torch.Tensor,
    ssm_state_cache: torch.Tensor,
    time_step_limit: List[float],
    chunk_size: int,
    out: Optional[torch.Tensor] = None,
) -> Tuple[int, int, int, int]:
    """Run the varlen combined chunk scan over prefill tokens and write final states per slot."""
    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
    num_seq = num_prefill + num_decode
    num_total_tokens = num_prefill_tokens + num_decode

    if num_prefill <= 0:
        return num_prefill, num_prefill_tokens, num_total_tokens, num_seq

    hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, H, D]
    B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
    C_prefill = C_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
    dt_prefill = dt_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, H]

    initial_states = None
    if torch.any(use_initial_states[:num_prefill]):
        initial_states = torch.where(
            use_initial_states[:num_prefill, None, None, None],
            ssm_state_cache[slot_idx[:num_prefill]],
            0,
        )
    else:
        chunk_indices = None
        chunk_offsets = None

    varlen_states = mamba_chunk_scan_combined(
        hs_prefill,
        dt_prefill,
        A,
        B_prefill,
        C_prefill,
        chunk_size=chunk_size,
        D=D,
        z=None,
        dt_bias=dt_bias,
        initial_states=initial_states,
        seq_idx=seq_idx_prefill,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        cu_seqlens=cu_seqlen[: num_prefill + 1],
        dt_softplus=True,
        dt_limit=(time_step_limit[0], time_step_limit[1]),
        return_final_states=False,
        return_varlen_states=True,
        state_dtype=ssm_state_cache.dtype,
        out=out,
    )

    ssm_state_cache.index_copy_(
        0, slot_idx[:num_prefill].long(), varlen_states.to(ssm_state_cache.dtype)
    )
    return num_prefill, num_prefill_tokens, num_total_tokens, num_seq


def _prepare_ssm_decode_inputs(
    hs_flat: torch.Tensor,
    B_flat: torch.Tensor,
    C_flat: torch.Tensor,
    dt_flat: torch.Tensor,
    A: torch.Tensor,
    D: torch.Tensor,
    dt_bias: torch.Tensor,
    slot_idx: torch.Tensor,
    num_prefill: int,
    num_prefill_tokens: int,
    num_seq: int,
    num_total_tokens: int,
    num_heads: int,
    head_dim: int,
    ssm_state_size: int,
) -> Optional[
    Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]
]:
    """Slice out decode tokens and broadcast per-head params; returns None if there are no decode tokens."""
    num_decode = num_total_tokens - num_prefill_tokens
    if num_decode <= 0:
        return None

    slot_idx_decode = slot_idx[num_prefill:num_seq]
    x_decode = hs_flat[num_prefill_tokens:num_total_tokens]  # [nd, H, D]
    B_decode = B_flat[num_prefill_tokens:num_total_tokens]  # [nd, G, N]
    C_decode = C_flat[num_prefill_tokens:num_total_tokens]  # [nd, G, N]
    dt_decode = dt_flat[num_prefill_tokens:num_total_tokens]  # [nd, H]

    dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
    dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
    A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size)
    D_full = D[..., None].expand(num_heads, head_dim)

    return slot_idx_decode, x_decode, B_decode, C_decode, dt_hp, dt_bias_hp, A_full, D_full

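Shape sketch for the decode-path broadcasts above (illustrative sizes): per-head scalars are expanded to per-(head, dim) views so they match what selective_state_update consumes, without materializing new memory.

import torch

num_decode, num_heads, head_dim, ssm_state_size = 2, 8, 64, 128
dt_decode = torch.rand(num_decode, num_heads)  # [nd, H]
A, D, dt_bias = (torch.rand(num_heads) for _ in range(3))

dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)            # [nd, H, D]
dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)              # [H, D]
A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size)  # [H, D, N]
D_full = D[..., None].expand(num_heads, head_dim)                        # [H, D]
assert dt_hp.shape == (num_decode, num_heads, head_dim)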
class BaseBackendSSM(AttentionDescriptor):
    @classmethod
    def get_attention_layout(cls) -> AttentionLayout:
        # Hidden states follow [b, s, n, d]
        return "bsnd"

    @classmethod
    def get_num_qkv_args(cls) -> int:
        # torch_ssm_transform signature has 7 node/state arguments
        return 7

    @classmethod
    def get_source_attention_op(cls) -> OpOverloadPacket:
        # Keep source op unchanged (used for uncached pre-export)
        return torch.ops.auto_deploy.torch_ssm

    @classmethod
    def get_standard_metadata_args(cls) -> List[str]:
        return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]

    @classmethod
    def get_prepare_extra_metadata_info(
        cls, any_source_attn_node: Node
    ) -> Tuple[PrepareMetadataCallable, int, List[Constant]]:
        return (
            torch.ops.auto_deploy.mamba_ssm_prepare_metadata.default,
            3,  # chunk_indices, chunk_offsets, seq_idx_prefill
            extract_op_args(any_source_attn_node, "chunk_size"),
        )

    @classmethod
    def get_cache_initializers(
        cls, source_attn_node: Node, cache_config: KvCacheConfig
    ) -> ResourceHandlerDict:
        # Shapes from fake tensors
        hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
        B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]

        num_heads = hs_fake.shape[-2]
        head_dim = hs_fake.shape[-1]

        if B_fake.ndim >= 4:
            ssm_state_size = B_fake.shape[-1]
        else:
            ssm_state_size = max(1, B_fake.shape[-1])

        # extract ssm_state_dtype from cache_config or hs_fake
        ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype)

        return {
            "ssm_state_cache": StateResourceHandler(
                num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype
            )
        }

    @classmethod
    def get_constants(cls, source_attn_node: Node) -> List[Constant]:
        time_step_limit, chunk_size = extract_op_args(
            source_attn_node, "time_step_limit", "chunk_size"
        )
        return [time_step_limit, chunk_size]
@@ -13,92 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple
from typing import List

import torch
from torch._ops import OpOverloadPacket
from torch.fx import Node

# Triton kernels
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import cu_seqlens_to_chunk_indices_offsets
from tensorrt_llm._torch.modules.mamba.selective_state_update import selective_state_update
from tensorrt_llm._torch.modules.mamba.ssd_combined import mamba_chunk_scan_combined

from .....llmapi.llm_args import KvCacheConfig
from ...utils.node_utils import extract_op_args
from ..attention_interface import (
    AttentionDescriptor,
    AttentionLayout,
    AttentionRegistry,
    Constant,
    MHACallable,
    PrepareMetadataCallable,
    ResourceHandlerDict,
    StateResourceHandler,
from ..attention_interface import AttentionRegistry, MHACallable
from .mamba_backend_common import (
    BaseBackendSSM,
    _flatten_ssm_inputs,
    _prepare_ssm_decode_inputs,
    _run_ssm_prefill,
)


@torch.library.custom_op("auto_deploy::triton_ssm_prepare_metadata", mutates_args=())
def _triton_ssm_prepare_metadata(
    # INPUTS
    position_ids: torch.Tensor,
    batch_info_host: torch.Tensor,
    seq_len: torch.Tensor,
    cu_seqlen: torch.Tensor,
    # EXTRA METADATA PROVIDED BY THE DESCRIPTOR
    chunk_size: int,
) -> List[torch.Tensor]:
    """Prepare metadata for cached SSM transform.

    Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
    """
    device = cu_seqlen.device
    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()

    if num_prefill > 0:
        chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
            cu_seqlen[: num_prefill + 1], chunk_size
        )
        seq_idx_prefill = torch.repeat_interleave(
            torch.arange(num_prefill, device=device, dtype=torch.int32), seq_len[:num_prefill]
        ).view(1, -1)
    else:
        chunk_indices = torch.empty(0, dtype=torch.int32, device=device)
        chunk_offsets = torch.empty(0, dtype=torch.int32, device=device)
        seq_idx_prefill = torch.empty(1, 0, dtype=torch.int32, device=device)

    return (chunk_indices, chunk_offsets, seq_idx_prefill)


@_triton_ssm_prepare_metadata.register_fake
def _triton_ssm_prepare_metadata_fake(
    # INPUTS
    position_ids: torch.Tensor,
    batch_info_host: torch.Tensor,
    seq_len: torch.Tensor,
    cu_seqlen: torch.Tensor,
    # EXTRA METADATA PROVIDED BY THE DESCRIPTOR
    chunk_size: int,
):
    b, s = position_ids.shape[:2]
    num_tokens = b * s
    device = cu_seqlen.device
    dtype = torch.int32
    if s > 1:
        # NOTE: this is only an upper bound for the shape in this case...
        return (
            torch.empty(num_tokens, dtype=dtype, device=device),  # chunk_indices
            torch.empty(num_tokens, dtype=dtype, device=device),  # chunk_offsets
            torch.empty(1, num_tokens, dtype=dtype, device=device),  # seq_idx_prefill
        )
    else:
        return (
            torch.empty(0, dtype=dtype, device=device),  # chunk_indices
            torch.empty(0, dtype=dtype, device=device),  # chunk_offsets
            torch.empty(1, 0, dtype=dtype, device=device),  # seq_idx_prefill
        )


@torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={})
def _triton_cached_ssm(
    # INPUTS (dense but may be flattened across sequences)
@@ -124,26 +53,10 @@ def _triton_cached_ssm(
    time_step_limit: List[float],
    chunk_size: int,
) -> torch.Tensor:
    """Flattened cached SSM transform op that respects slot-indexed state caches.

    Split mixed batches into prefill (seq_len>1) and decode (seq_len==1):
    - Prefill: run one varlen combined scan over concatenated prefill tokens and update final states per slot.
    - Decode: batch single-token updates with selective_state_update and update states per slot.
    """
    b, s, num_heads, head_dim = hidden_states.shape
    # Flatten tokens for indexing/scatter
    bs = b * s
    hs_flat = hidden_states.reshape(bs, *hidden_states.shape[2:])  # [bs, H, D]
    B_flat = B.reshape(bs, *B.shape[2:])  # [bs, G, N]
    C_flat = C.reshape(bs, *C.shape[2:])  # [bs, G, N]
    dt_flat = dt.reshape(bs, dt.shape[2])  # [bs, H]

    b, s, num_heads, head_dim, bs, hs_flat, B_flat, C_flat, dt_flat = _flatten_ssm_inputs(
        hidden_states, B, C, dt
    )
    ssm_state_size = B.shape[3]

    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
    num_seq = num_prefill + num_decode
    num_total_tokens = num_prefill_tokens + num_decode

    # Preallocate output tensor to avoid memcpy cost for merging prefill
    # and decode outputs
    preallocated_ssm_out = torch.empty(
@@ -151,68 +64,63 @@ def _triton_cached_ssm(
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
    num_seq = num_prefill + num_decode
    num_total_tokens = num_prefill_tokens + num_decode
    preallocated_ssm_out_p = preallocated_ssm_out[:num_prefill_tokens]
    preallocated_ssm_out_d = preallocated_ssm_out[num_prefill_tokens:num_total_tokens]

    # Prefill: concatenate tokens at the front and run combined scan
    if num_prefill > 0:
        hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, H, D]
        B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
        C_prefill = C_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, G, N]
        dt_prefill = dt_flat[:num_prefill_tokens].unsqueeze(0)  # [1, S_p, H]
    num_prefill, num_prefill_tokens, num_total_tokens, num_seq = _run_ssm_prefill(
        hs_flat,
        B_flat,
        C_flat,
        dt_flat,
        A,
        D,
        dt_bias,
        batch_info_host,
        cu_seqlen,
        slot_idx,
        use_initial_states,
        chunk_indices,
        chunk_offsets,
        seq_idx_prefill,
        ssm_state_cache,
        time_step_limit,
        chunk_size,
        preallocated_ssm_out_p.unsqueeze(0),
    )

        initial_states = None
        if torch.any(use_initial_states[:num_prefill]):
            initial_states = torch.where(
                use_initial_states[:num_prefill, None, None, None],
                ssm_state_cache[slot_idx[:num_prefill]],
                0,
            )
        else:
            chunk_indices = None
            chunk_offsets = None

        varlen_states = mamba_chunk_scan_combined(
            hs_prefill,
            dt_prefill,
            A,
            B_prefill,
            C_prefill,
            chunk_size=chunk_size,
            D=D,
            z=None,
            dt_bias=dt_bias,
            initial_states=initial_states,
            seq_idx=seq_idx_prefill,
            chunk_indices=chunk_indices,
            chunk_offsets=chunk_offsets,
            cu_seqlens=cu_seqlen[: num_prefill + 1],
            dt_softplus=True,
            dt_limit=(time_step_limit[0], time_step_limit[1]),
            return_final_states=False,
            return_varlen_states=True,
            out=preallocated_ssm_out_p.unsqueeze(0),
            state_dtype=ssm_state_cache.dtype,
        )

        ssm_state_cache.index_copy_(
            0, slot_idx[:num_prefill].long(), varlen_states.to(ssm_state_cache.dtype)
        )

    # Decode: batch single-token updates via selective_state_update
    if num_decode > 0:
        slot_idx_decode = slot_idx[num_prefill:num_seq]

        x_decode = hs_flat[num_prefill_tokens:num_total_tokens]  # [nd, H, D]
        B_decode = B_flat[num_prefill_tokens:num_total_tokens]  # [nd, G, N]
        C_decode = C_flat[num_prefill_tokens:num_total_tokens]  # [nd, G, N]
        dt_decode = dt_flat[num_prefill_tokens:num_total_tokens]  # [nd, H]

        dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
        dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
        A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size)
        D_full = D[..., None].expand(num_heads, head_dim)
    num_decode = num_total_tokens - num_prefill_tokens
    decode_inputs = _prepare_ssm_decode_inputs(
        hs_flat,
        B_flat,
        C_flat,
        dt_flat,
        A,
        D,
        dt_bias,
        slot_idx,
        num_prefill,
        num_prefill_tokens,
        num_seq,
        num_total_tokens,
        num_heads,
        head_dim,
        ssm_state_size,
    )

    if decode_inputs is not None:
        (
            slot_idx_decode,
            x_decode,
            B_decode,
            C_decode,
            dt_hp,
            dt_bias_hp,
            A_full,
            D_full,
        ) = decode_inputs
        selective_state_update(
            ssm_state_cache,
            x_decode,
@@ -228,7 +136,6 @@ def _triton_cached_ssm(
            out=preallocated_ssm_out_d,
        )

    # Return the preallocated output reshaped to original dimensions
    if num_total_tokens > 0:
        return (
            preallocated_ssm_out[:num_total_tokens]
@@ -273,68 +180,7 @@ def _triton_cached_ssm_fake(


@AttentionRegistry.register("triton_ssm")
class TritonBackendSSM(AttentionDescriptor):
    @classmethod
    def get_attention_layout(cls) -> AttentionLayout:
        # Hidden states follow [b, s, n, d]
        return "bsnd"

    @classmethod
    def get_num_qkv_args(cls) -> int:
        # torch_ssm_transform signature has 7 node/state arguments
        return 7

    @classmethod
    def get_source_attention_op(cls) -> OpOverloadPacket:
        # Keep source op unchanged (used for uncached pre-export)
        return torch.ops.auto_deploy.torch_ssm

class TritonBackendSSM(BaseBackendSSM):
    @classmethod
    def get_cached_attention_op(cls) -> MHACallable:
        return torch.ops.auto_deploy.triton_cached_ssm.default

    @classmethod
    def get_standard_metadata_args(cls) -> List[str]:
        return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]

    @classmethod
    def get_prepare_extra_metadata_info(
        cls, any_source_attn_node: Node
    ) -> Tuple[PrepareMetadataCallable, int, List[Constant]]:
        return (
            torch.ops.auto_deploy.triton_ssm_prepare_metadata.default,
            3,  # chunk_indices, chunk_offsets, seq_idx_prefill
            extract_op_args(any_source_attn_node, "chunk_size"),
        )

    @classmethod
    def get_cache_initializers(
        cls, source_attn_node: Node, cache_config: KvCacheConfig
    ) -> ResourceHandlerDict:
        # Shapes from fake tensors
        hs_fake: torch.Tensor = source_attn_node.args[0].meta["val"]
        B_fake: torch.Tensor = source_attn_node.args[2].meta["val"]

        num_heads = hs_fake.shape[-2]
        head_dim = hs_fake.shape[-1]

        if B_fake.ndim >= 4:
            ssm_state_size = B_fake.shape[-1]
        else:
            ssm_state_size = max(1, B_fake.shape[-1])

        # extract ssm_state_dtype from cache_config or hs_fake
        ssm_state_dtype = cls.resolve_cache_dtype(cache_config.mamba_ssm_cache_dtype, hs_fake.dtype)

        return {
            "ssm_state_cache": StateResourceHandler(
                num_heads, head_dim, ssm_state_size, dtype=ssm_state_dtype
            )
        }

    @classmethod
    def get_constants(cls, source_attn_node: Node) -> List[Constant]:
        time_step_limit, chunk_size = extract_op_args(
            source_attn_node, "time_step_limit", "chunk_size"
        )
        return [time_step_limit, chunk_size]

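Layout sketch for a mixed batch as the cached SSM ops above expect it, assuming two prefill requests of lengths 3 and 2 plus two single-token decode requests; flattened tokens are ordered prefill-first, then decode, which is what the slicing in _run_ssm_prefill and _prepare_ssm_decode_inputs relies on.

import torch

# batch_info_host convention (see the unit test below): [num_prefill, num_prefill_tokens, num_decode]
batch_info_host = torch.tensor([2, 5, 2], dtype=torch.int32)
# The first num_prefill + 1 entries of cu_seqlen give the prefill sequence boundaries.
cu_seqlen = torch.tensor([0, 3, 5], dtype=torch.int32)

num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
num_total_tokens = num_prefill_tokens + num_decode  # 7 flattened tokens
# tokens[:5]  -> prefill tokens of the two variable-length sequences
# tokens[5:7] -> one token per decode request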
@@ -143,8 +143,14 @@ class TestNemotronH(LlmapiAccuracyTestHarness):

    @pytest.mark.skip_less_device_memory(32000)
    @pytest.mark.parametrize("enable_chunked_prefill", [False, True])
    def test_auto_dtype(self, enable_chunked_prefill):
    @pytest.mark.parametrize("ssm_backend", ["triton_ssm", "flashinfer_ssm"])
    def test_auto_dtype(self, enable_chunked_prefill, ssm_backend):
        kwargs = self.get_default_kwargs(enable_chunked_prefill)
        kwargs.setdefault("transforms", {})
        insert_ssm_cfg = {"backend": ssm_backend}
        if ssm_backend == "flashinfer_ssm":
            insert_ssm_cfg["cache_config"] = {"mamba_dtype": "bfloat16"}
        kwargs["transforms"]["insert_cached_ssm_attention"] = insert_ssm_cfg
        sampling_params = self.get_default_sampling_params()
        with AutoDeployLLM(model=self.MODEL_PATH,
                           tokenizer=self.MODEL_PATH,

@@ -434,8 +434,9 @@ l0_h100:
      - unittest/_torch/auto_deploy/unit/singlegpu
      - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-1]
      - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[True-1]
      - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[False]
      - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[True]
      - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[triton_ssm-False]
      - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[flashinfer_ssm-False]
      - accuracy/test_llm_api_autodeploy.py::TestNemotronH::test_auto_dtype[triton_ssm-True]
      - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_fp8[1]
      - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16[1]
      - examples/test_ad_speculative_decoding.py::test_autodeploy_spec_dec_output[draft_target]

@@ -0,0 +1,103 @@
import pytest
import torch

import tensorrt_llm._torch.auto_deploy  # noqa: F401
from tests.unittest._torch.auto_deploy.unit.singlegpu.custom_ops.test_triton_mamba_cached_op import (
    _random_params,
)


@pytest.fixture
def mamba_env():
    device = "cuda"
    dtype = torch.bfloat16
    atol = 1e-3
    rtol = 1e-3
    torch.manual_seed(42)
    torch.cuda.empty_cache()
    return {"device": device, "dtype": dtype, "atol": atol, "rtol": rtol}


def test_flashinfer_decode_matches_triton(mamba_env):
    device = mamba_env["device"]
    dtype = mamba_env["dtype"]
    atol = mamba_env["atol"]
    rtol = mamba_env["rtol"]

    batch, seq = 2, 1
    num_heads, head_dim = 2, 64
    n_groups, ssm_state_size = 2, 64
    (hidden_states, A, B, C, D, dt, dt_bias, time_step_limit, chunk_size) = _random_params(
        device, dtype, batch, seq, num_heads, head_dim, n_groups, ssm_state_size
    )

    max_batch_size = 4
    slot_idx = torch.tensor([0, 2], device=device, dtype=torch.int32)
    ssm_state_cache_triton = torch.randn(
        max_batch_size, num_heads, head_dim, ssm_state_size, device=device, dtype=dtype
    )
    ssm_state_cache_flashinfer = ssm_state_cache_triton.clone()

    # batch_info_host: [num_prefill, num_prefill_tokens, num_decode]
    batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32)
    cu_seqlen = torch.zeros(batch + 1, device=device, dtype=torch.int32)
    use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool)

    y_triton = torch.ops.auto_deploy.triton_cached_ssm(
        hidden_states,
        A,
        B,
        C,
        D,
        dt,
        dt_bias,
        # STANDARD METADATA
        batch_info_host,
        cu_seqlen,
        slot_idx,
        use_initial_states,
        # EXTRA METADATA
        None,  # chunk indices
        None,  # chunk offsets
        None,  # seq_idx_prefill
        # CACHES
        ssm_state_cache_triton,
        # CONSTANTS
        time_step_limit,
        chunk_size,
    )

    y_flashinfer = torch.ops.auto_deploy.flashinfer_cached_ssm(
        hidden_states,
        A,
        B,
        C,
        D,
        dt,
        dt_bias,
        # STANDARD METADATA
        batch_info_host,
        cu_seqlen,
        slot_idx,
        use_initial_states,
        # EXTRA METADATA
        None,  # chunk indices
        None,  # chunk offsets
        None,  # seq_idx_prefill
        # CACHES
        ssm_state_cache_flashinfer,
        # CONSTANTS
        time_step_limit,
        chunk_size,
    )

    assert y_triton.shape == hidden_states.shape
    assert y_flashinfer.shape == hidden_states.shape
    assert torch.isfinite(y_flashinfer).all()
    assert torch.allclose(y_flashinfer, y_triton.to(y_flashinfer.dtype), atol=atol, rtol=rtol)

    after_triton = ssm_state_cache_triton.index_select(0, slot_idx)
    after_flashinfer = ssm_state_cache_flashinfer.index_select(0, slot_idx)
    assert torch.allclose(
        after_flashinfer.to(after_triton.dtype), after_triton, atol=atol, rtol=rtol
    )
@@ -991,14 +991,5 @@ def test_stack_nvfp4_moe_weights_transform_relu2(hidden_size, intermediate_size)
    # Run the transformed graph
    transformed_output = gm(x, selected_experts, routing_weights)

    # Get the registered parameters after transform
    fc1_act_scale = getattr(gm, "nvfp4_moe_w3_w1_input_scale_stacked_0", None)
    fc1_alpha = getattr(gm, "nvfp4_moe_w1_alpha_stacked_0", None)
    if fc1_act_scale is not None:
        print(f"fc1_act_scale (after transform): {fc1_act_scale}, shape: {fc1_act_scale.shape}")
    if fc1_alpha is not None:
        print(f"fc1_alpha (after transform): {fc1_alpha}, shape: {fc1_alpha.shape}")

    # Should be close for FP4 quantization (gated MLP may have slightly larger diff due to alpha handling)
    tol = 1e-3
    torch.testing.assert_close(ref_output, transformed_output, rtol=tol, atol=tol)

@@ -184,6 +184,7 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
                "transforms": {
                    "insert_cached_attention": {"backend": "flashinfer"},
                    "compile_model": {"backend": "torch-simple"},
                    "insert_cached_ssm_attention": {"backend": "triton_ssm"},
                },
            },
        ),
@@ -192,6 +193,7 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
            {
                "transforms": {
                    "multi_stream_moe": {"stage": "compile", "enabled": True},
                    "insert_cached_ssm_attention": {"backend": "triton_ssm"},
                    # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/9878
                    "compile_model": {"backend": "torch-cudagraph"},
                },