[None][feat] Autodeploy add triton configs and optimize mamba prefill (#9083)

Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Suyog Gupta 2025-11-13 19:15:43 -08:00 committed by GitHub
parent 3c950910a0
commit d12cb9436d
23 changed files with 615 additions and 90 deletions

View File

@ -1,3 +1,7 @@
Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Portions of this project are under the following copyright:
- Copyright contributors to the vLLM project
Apache License
Version 2.0, January 2004

View File

@ -134,6 +134,7 @@ package_data += [
"_torch/auto_deploy/config/*.yaml",
# Include CUDA source for fused MoE align extension so runtime JIT can find it in wheels
'_torch/auto_deploy/custom_ops/fused_moe/moe_align_kernel.cu',
'_torch/auto_deploy/custom_ops/fused_moe/triton_fused_moe_configs/*'
]

View File

@ -175,7 +175,7 @@ class CapturedGraph(nn.Module):
# retrieve output from buffer, cut to batch size, and unflatten
bs = args_batched[0].shape[0]
out_flat = [o_b[:bs].detach().clone() for o_b in self._out_buffer_flat]
out_flat = [o_b[:bs] for o_b in self._out_buffer_flat]
return self._out_spec.unflatten(out_flat)

View File

@ -116,6 +116,7 @@ class SequenceInfo:
page_size: int = 0,
max_num_tokens: Optional[int] = None,
vocab_size_padded: Optional[int] = None,
chunk_size: Optional[int] = None,
):
"""Initialize the SequenceInfo object.
@ -142,7 +143,10 @@ class SequenceInfo:
self.max_batch_size = max_batch_size
self.page_size = page_size if page_size > 0 else max_seq_len
self.vocab_size_padded = vocab_size_padded
self.chunk_size = chunk_size
# Chunk size is an input to a custom op, so we need to set a default value if it is not provided.
if self.chunk_size is None:
self.chunk_size = 128
# NOTE (lucaslie): WAR to address issue when using flashinfer attention with
# (max_batch_size, max_seq_len) input in trtllm runtime.
# see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
@ -193,7 +197,7 @@ class SequenceInfo:
"input_pos": torch.empty(self.max_batch_size, dtype=torch.int),
"cache_loc": torch.empty(max_num_cache_loc_assignments, dtype=torch.int),
"pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int),
"slot_idx": torch.empty(self.max_batch_size, dtype=torch.int),
"slot_idx": torch.empty(self.max_batch_size, dtype=torch.long),
# OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
"_gather_idx": torch.empty(self.max_num_tokens, dtype=torch.int),
}
@ -203,7 +207,9 @@ class SequenceInfo:
# NOTE: order of keys is relevant here!
self._uncached_arg_names = ("input_ids", "position_ids")
self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq", "slot_idx")
self._cached_constants = ("page_size",)
# page_size is the size of attention KV-cache pages.
# chunk_size is used in mamba prefill kernels to split the context into chunks.
self._cached_constants = ("page_size", "chunk_size")
############################################################################################
# EXTRA TENSOR FIELDS ######################################################################

View File

@ -162,6 +162,7 @@ def prepare_flashinfer_metadata(
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
chunk_size: int,
) -> List[torch.Tensor]:
"""Prepare metadata for flashinfer attention.
@ -213,7 +214,7 @@ def prepare_flashinfer_metadata(
# As SequenceInfo._get_sanitized_num_sequences could break in fake mode
@prepare_flashinfer_metadata.register_fake
def prepare_flashinfer_metadata_fake(
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
):
seq_len = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device)

View File

@ -0,0 +1,147 @@
{
"triton_version": "3.5.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}

View File

@ -0,0 +1,147 @@
{
"triton_version": "3.5.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 5
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2
}
}
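
The two JSON files above are per-device tuned tables for the fused MoE Triton kernel: "triton_version" records the Triton version used during tuning, and every other key is a token count M mapped to a launch configuration (block sizes, group size, warps, stages). A minimal sketch of how such a file is consumed, mirroring the loader added in the next file (the file name here is hypothetical):

import json

with open("E=8,N=2048,device_name=NVIDIA_H100.json") as f:  # hypothetical tuned-config file
    tuned = json.load(f)
tuned.pop("triton_version", None)                    # metadata, not a batch-size entry
configs = {int(m): cfg for m, cfg in tuned.items()}  # keys become integer token counts M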

View File

@ -4,13 +4,18 @@ Triton implementation of the Fused MOE ops. Inspired by vLLM's triton MOE implem
from __future__ import annotations
from typing import Tuple
import functools
import json
import os
from typing import Any, Tuple
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from ...utils.logger import ad_logger
@triton.jit
def _write_zeros_to_output(
@ -283,7 +288,90 @@ def fused_mlp_moe_kernel_w8a8(
tl.store(c_ptrs, accumulator, mask=c_mask)
def _default_kernel_config(M: int, E: int, N: int, K: int, top_k: int) -> dict:
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name(
E: int, N: int, dtype: str | None, block_shape: list[int] | None = None
) -> str:
device_name = torch.cuda.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}"
block_shape_selector = (
"" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
).replace(" ", "")
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
# Adapted from: https://github.com/vllm-project/vllm/blob/b4fda58a2d0e458e0186e4caa4354b3d07153c70/vllm/model_executor/layers/fused_moe/fused_moe.py#L828
@functools.lru_cache
def get_moe_configs(
E: int,
N: int,
dtype: str | None,
block_n: int | None = None,
block_k: int | None = None,
) -> dict[int, Any] | None:
"""
Return optimized configurations for the fused MoE kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the fused_moe kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
block_shape = [block_n, block_k] if block_n and block_k else None
json_file_name = get_config_file_name(E, N, dtype, block_shape)
config_file_paths = []
# note that we prioritize user defined config
# user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
user_defined_config_folder = "."
if user_defined_config_folder is not None:
user_defined_config_file_path = os.path.join(user_defined_config_folder, json_file_name)
config_file_paths.append(user_defined_config_file_path)
ad_folder = "triton_fused_moe_configs"
default_config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), ad_folder, json_file_name
)
config_file_paths.append(default_config_file_path)
for config_file_path in config_file_paths:
if os.path.exists(config_file_path):
with open(config_file_path) as f:
ad_logger.info("Using configuration from %s for MoE layer.", config_file_path)
# If a configuration has been found, return it
tuned_config = json.load(f)
# Delete triton_version from tuned_config
tuned_config.pop("triton_version", None)
return {int(key): val for key, val in tuned_config.items()}
# If no optimized configuration is available, we will use the default
# configuration
ad_logger.warning(
("Using default MoE config. Performance might be sub-optimal! Config file not found at %s"),
config_file_paths,
)
return None
def _get_kernel_config(
M: int, E: int, N: int, dtype: str | None, block_shape: list[int] | None = None
) -> dict:
configs = get_moe_configs(E, N, dtype=None)
if configs:
# If an optimal configuration map has been found, look up the
# optimal config (closest batch size)
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = _default_kernel_config(M, E)
return config
def _default_kernel_config(M: int, E: int) -> dict:
if M <= E:
return {
"BLOCK_SIZE_M": 16,
@ -461,7 +549,7 @@ def _fused_moe_mlp_relu2(
E, inter_size, _ = w_up.shape
top_k = topk_ids.shape[1]
config = _default_kernel_config(M, E, inter_size, H, top_k)
config = _get_kernel_config(M, E, inter_size, H, top_k)
sorted_token_ids, expert_ids, num_tokens_post_padded = _pack_routed_tokens(
topk_ids, M, E, top_k, config["BLOCK_SIZE_M"]
)
@ -598,7 +686,7 @@ def triton_quant_fp8_moe(
M, H = x2d.shape
E, inter_size, _ = w1_q.shape
top_k = topk_ids.shape[1]
config = _default_kernel_config(M, E, inter_size, H, top_k)
config = _get_kernel_config(M, E, inter_size, H, top_k)
sorted_token_ids, expert_ids, num_tokens_post_padded = _pack_routed_tokens(
topk_ids, M, E, top_k, config["BLOCK_SIZE_M"]
)
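
For reference, a small worked example of the lookup path added above (sizes and device name are hypothetical): get_config_file_name(E=8, N=2048, dtype=None) produces a name like "E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json", and _get_kernel_config then picks the tuned entry whose key is closest to the current token count M:

configs = {1: "cfg_1", 32: "cfg_32", 48: "cfg_48", 64: "cfg_64"}  # stand-in configs
M = 44                                                            # tokens in this step
chosen = min(configs.keys(), key=lambda x: abs(x - M))            # closest key -> 48
assert configs[chosen] == "cfg_48"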

View File

@ -61,6 +61,7 @@ def cuda_causal_conv_prepare_metadata(
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
chunk_size: int,
) -> List[torch.Tensor]:
"""Prepare metadata for cached causal conv (CUDA backend).
@ -75,13 +76,13 @@ def cuda_causal_conv_prepare_metadata(
slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)
# This is only used during prefill to determine if we should use the initial states from the cache.
use_initial_states = input_pos > 0
use_initial_states = input_pos[:num_seq] > 0
return (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states)
@cuda_causal_conv_prepare_metadata.register_fake
def cuda_causal_conv_prepare_metadata_fake(
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
):
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
num_seq = len(seq_len_sanitized)
@ -182,7 +183,7 @@ def _cuda_cached_causal_conv1d(
# Scatter outputs back to y
y_prefill = y_varlen.transpose(0, 1) # [total_prefill_tokens, C_out]
y_flat[:total_prefill_tokens].copy_(y_prefill.to(y_flat.dtype))
y_flat[:total_prefill_tokens].copy_(y_prefill)
# DECODE: batch update for single-token sequences
if num_decode > 0:
@ -208,7 +209,7 @@ def _cuda_cached_causal_conv1d(
)
# Custom op must not return an alias of any input; return a fresh tensor
return y.contiguous().clone()
return y
@_cuda_cached_causal_conv1d.register_fake
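
A brief illustration of the input_pos[:num_seq] fix above (values are made up): input_pos is a buffer sized for max_batch_size, so slicing it to the active sequences keeps stale entries from inactive slots out of the use_initial_states mask:

import torch

input_pos = torch.tensor([5, 0, 7, 3])        # buffer sized for max_batch_size = 4
num_seq = 2                                   # only two sequences are active this step
use_initial_states = input_pos[:num_seq] > 0  # tensor([True, False]); slots 2-3 ignored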

View File

@ -147,6 +147,7 @@ def torch_causal_conv_prepare_metadata(
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
chunk_size: int,
) -> List[torch.Tensor]:
"""Prepare metadata for cached causal conv.

View File

@ -120,6 +120,7 @@ def _torch_ssm_prepare_metadata(
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
chunk_size: int,
) -> List[torch.Tensor]:
"""Prepare metadata for cached SSM transform.
@ -143,7 +144,7 @@ def _torch_ssm_prepare_metadata(
@_torch_ssm_prepare_metadata.register_fake
def _torch_ssm_prepare_metadata_fake(
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
):
# Use the same sanitization logic to determine sizes in fake mode
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)

View File

@ -24,6 +24,135 @@ from ..attention_interface import (
)
@torch.library.custom_op("auto_deploy::triton_ssm_prepare_metadata", mutates_args=())
def _triton_ssm_prepare_metadata(
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
chunk_size: int,
) -> List[torch.Tensor]:
"""Prepare metadata for cached SSM transform.
Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states,
cu_seqlens, chunk_indices, chunk_offsets, seq_idx_prefill, batch_info_tensor).
"""
# Determine number of active sequences and compute seq_start boundaries
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
num_seq = len(seq_len_sanitized)
seq_start = torch.zeros_like(seq_len_sanitized)
if num_seq > 1:
seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0)
# Truncate slot indices to match active sequences
slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)
# TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch
# reference implementation to support chunked prefill.
use_initial_states = input_pos[:num_seq] > 0
device = position_ids.device
chunk_indices = torch.zeros(num_seq, dtype=torch.int32, device=device)
chunk_offsets = torch.zeros(num_seq, dtype=torch.int32, device=device)
cu_seqlens = torch.zeros(num_seq + 1, dtype=torch.int32, device=device)
_, s = position_ids.shape[:2]
if s > 1:
# only compute chunk indices and offsets for prefill.
prefill_mask = seq_len_sanitized > 1
num_prefill = int(prefill_mask.sum().item())
num_prefill_tokens = int(seq_len_sanitized[:num_prefill].sum().item())
num_decode = num_seq - num_prefill
cu_seqlens = torch.cat(
[
torch.zeros(1, dtype=torch.int32, device=device),
torch.cumsum(seq_len_sanitized[:num_prefill].to(torch.int32), dim=0),
],
dim=0,
)
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
seq_idx_prefill = torch.repeat_interleave(
torch.arange(num_prefill, device=device, dtype=torch.int32),
seq_len_sanitized[:num_prefill],
).view(1, -1)
else:
num_prefill = 0
num_prefill_tokens = 0
num_decode = num_seq
seq_idx_prefill = torch.empty(1, 0, dtype=torch.int32, device=device)
batch_info_tensor = torch.tensor(
[num_prefill, num_prefill_tokens, num_decode], dtype=torch.int32
) # host tensor
return (
seq_len_sanitized,
seq_start,
slot_idx_sanitized,
use_initial_states,
cu_seqlens,
chunk_indices,
chunk_offsets,
seq_idx_prefill,
batch_info_tensor,
)
@_triton_ssm_prepare_metadata.register_fake
def _triton_ssm_prepare_metadata_fake(
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
):
# Use the same sanitization logic to determine sizes in fake mode
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
num_seq = len(seq_len_sanitized)
device = slot_idx.device
# Always-correct shapes
seq_len_fake = torch.empty_like(seq_len_sanitized)
seq_start_fake = torch.empty_like(seq_len_sanitized)
slot_idx_fake = torch.empty(num_seq, dtype=torch.long, device=device)
use_initial_states_fake = torch.empty(num_seq, dtype=torch.bool, device=device)
cu_seqlens_fake = torch.empty(num_seq + 1, dtype=torch.int32, device=device)
# Token-dependent shapes (prefill vs decode)
_, s = position_ids.shape[:2]
if s > 1:
prefill_mask = seq_len_sanitized > 1
num_prefill = int(prefill_mask.sum().item())
num_prefill_tokens = int(seq_len_sanitized[:num_prefill].sum().item())
cu_seqlens_runtime = torch.cat(
[
torch.zeros(1, dtype=torch.int32, device=device),
torch.cumsum(seq_len_sanitized[:num_prefill].to(torch.int32), dim=0),
],
dim=0,
)
chunk_indices_rt, chunk_offsets_rt = cu_seqlens_to_chunk_indices_offsets(
cu_seqlens_runtime, chunk_size
)
chunk_indices_fake = torch.empty_like(chunk_indices_rt)
chunk_offsets_fake = torch.empty_like(chunk_offsets_rt)
seq_idx_prefill_fake = torch.empty(1, num_prefill_tokens, dtype=torch.int32, device=device)
else:
chunk_indices_fake = torch.empty(0, dtype=torch.int32, device=device)
chunk_offsets_fake = torch.empty(0, dtype=torch.int32, device=device)
seq_idx_prefill_fake = torch.empty(1, 0, dtype=torch.int32, device=device)
batch_info_tensor_fake = torch.empty(3, dtype=torch.int32)
return (
seq_len_fake,
seq_start_fake,
slot_idx_fake,
use_initial_states_fake,
cu_seqlens_fake,
chunk_indices_fake,
chunk_offsets_fake,
seq_idx_prefill_fake,
batch_info_tensor_fake,
)
@torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={})
def _triton_cached_ssm(
# INPUTS (dense but may be flattened across sequences)
@ -39,6 +168,11 @@ def _triton_cached_ssm(
seq_start: torch.Tensor, # [num_seq]
slot_idx: torch.Tensor, # [num_seq]
use_initial_states: torch.Tensor, # [num_seq]
cu_seqlens: torch.Tensor, # [num_seq + 1]
chunk_indices: torch.Tensor, # [num_seq + 1]
chunk_offsets: torch.Tensor, # [num_seq + 1]
seq_idx_prefill: torch.Tensor, # [1, num_prefill_tokens]
batch_info_tensor: torch.Tensor, # [3]
# CACHES
ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size]
# CONSTANTS
@ -51,11 +185,9 @@ def _triton_cached_ssm(
- Prefill: run one varlen combined scan over concatenated prefill tokens and update final states per slot.
- Decode: batch single-token updates with selective_state_update and update states per slot.
"""
b, s = hidden_states.shape[:2]
num_seq = seq_len.shape[0]
b, s, num_heads, head_dim = hidden_states.shape
# Flatten tokens for indexing/scatter
bs = b * s
device = hidden_states.device
hs_flat = hidden_states.reshape(bs, *hidden_states.shape[2:]) # [bs, H, D]
B_flat = B.reshape(bs, *B.shape[2:]) # [bs, G, N]
C_flat = C.reshape(bs, *C.shape[2:]) # [bs, G, N]
@ -64,48 +196,28 @@ def _triton_cached_ssm(
y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format)
y_flat = y.view(bs, *y.shape[2:])
num_heads = hidden_states.shape[2]
head_dim = hidden_states.shape[3]
ssm_state_size = B.shape[3]
if s == 1:
num_prefill = 0
num_decode = num_seq
else:
prefill_mask = seq_len > 1
num_prefill = int(prefill_mask.sum().item())
num_decode = num_seq - num_prefill
num_prefill, num_prefill_tokens, num_decode = batch_info_tensor.tolist()
# Prefill: concatenate tokens at the front and run combined scan
if num_prefill > 0:
seq_len_prefill = seq_len[:num_prefill].to(torch.int32)
total_prefill_tokens = int(seq_len_prefill.sum().item())
hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H, D]
B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
C_prefill = C_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
dt_prefill = dt_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H]
hs_prefill = hs_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, H, D]
B_prefill = B_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
C_prefill = C_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
dt_prefill = dt_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, H]
cu_seqlens = torch.cat(
[
torch.zeros(1, dtype=torch.int32, device=device),
torch.cumsum(seq_len_prefill, dim=0),
],
dim=0,
)
seq_ids = torch.arange(num_prefill, device=device, dtype=torch.int32)
seq_idx_prefill = torch.repeat_interleave(seq_ids, seq_len_prefill).view(1, -1)
initial_states = chunk_indices = chunk_offsets = None
initial_states = None
if torch.any(use_initial_states[:num_prefill]):
initial_states = torch.where(
use_initial_states[:num_prefill, None, None, None],
ssm_state_cache[slot_idx[:num_prefill]],
0,
)
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
cu_seqlens, chunk_size
)
else:
chunk_indices = None
chunk_offsets = None
y_prefill, varlen_states = mamba_chunk_scan_combined(
hs_prefill,
dt_prefill,
@ -128,20 +240,19 @@ def _triton_cached_ssm(
mamba_ssm_cache_dtype=ssm_state_cache.dtype,
)
y_flat[:total_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
y_flat[:num_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
ssm_state_cache.index_copy_(
0, slot_idx[:num_prefill].to(torch.long), varlen_states.to(ssm_state_cache.dtype)
0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype)
)
# Decode: batch single-token updates via selective_state_update
if num_decode > 0:
total_prefill_tokens = 0 if num_prefill == 0 else int(seq_len[:num_prefill].sum().item())
slot_idx_decode = slot_idx[num_prefill:].to(torch.long)
slot_idx_decode = slot_idx[num_prefill:]
x_decode = hs_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, H, D]
B_decode = B_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, G, N]
C_decode = C_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, G, N]
dt_decode = dt_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, H]
x_decode = hs_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, H, D]
B_decode = B_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, G, N]
C_decode = C_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, G, N]
dt_decode = dt_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, H]
dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
@ -162,9 +273,7 @@ def _triton_cached_ssm(
state_batch_indices=slot_idx_decode,
) # [nd, H, D]
y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
y_dec.to(y_flat.dtype)
)
y_flat[num_prefill_tokens : num_prefill_tokens + num_decode].copy_(y_dec.to(y_flat.dtype))
return y
@ -184,6 +293,11 @@ def _triton_cached_ssm_fake(
seq_start: torch.Tensor, # [num_seq]
slot_idx: torch.Tensor, # [num_seq]
use_initial_states: torch.Tensor, # [num_seq]
cu_seqlens: torch.Tensor, # [num_seq + 1]
chunk_indices: torch.Tensor, # [num_seq + 1]
chunk_offsets: torch.Tensor, # [num_seq + 1]
seq_idx_prefill: torch.Tensor, # [1, num_prefill_tokens]
batch_info_tensor: torch.Tensor, # [3]
# CACHES
ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size]
# CONSTANTS
@ -226,8 +340,9 @@ class TritonBackendSSM(AttentionDescriptor):
@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
# Returns (seq_len, seq_start, slot_idx, use_initial_states)
return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4
# Returns: seq_len, seq_start, slot_idx, use_initial_states,
# cu_seqlens, chunk_indices, chunk_offsets, seq_idx_prefill, batch_info_tensor
return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 9
@classmethod
def get_cache_initializers(
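
To make the new prefill metadata concrete, a small worked example (lengths are hypothetical) for two prefill sequences of 3 and 2 tokens and no decode requests, following the code added above:

import torch

seq_len = torch.tensor([3, 2], dtype=torch.int32)
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), torch.cumsum(seq_len, dim=0)])
# cu_seqlens -> [0, 3, 5]
seq_idx_prefill = torch.repeat_interleave(
    torch.arange(len(seq_len), dtype=torch.int32), seq_len
).view(1, -1)
# seq_idx_prefill -> [[0, 0, 0, 1, 1]]
batch_info_tensor = torch.tensor([2, 5, 0], dtype=torch.int32)  # [num_prefill, num_prefill_tokens, num_decode]

chunk_indices and chunk_offsets are then derived from cu_seqlens and chunk_size via cu_seqlens_to_chunk_indices_offsets, so the chunked scan can respect sequence boundaries.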

View File

@ -182,6 +182,7 @@ def prepare_fused_mla_metadata(
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
chunk_size: int,
) -> List[torch.Tensor]:
num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
seq_start = torch.zeros_like(seq_len[:num_seq])
@ -196,7 +197,7 @@ def prepare_fused_mla_metadata(
@prepare_fused_mla_metadata.register_fake
def prepare_fused_mla_metadata_fake(
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
):
return (
torch.empty_like(seq_len),

View File

@ -363,6 +363,7 @@ def torch_backend_prepare_metadata(
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
chunk_size: int,
) -> List[torch.Tensor]:
"""Prepare metadata for torch backend attention (similar to triton backend)."""
num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)

View File

@ -291,6 +291,7 @@ def prepare_fused_mha_metadata(
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
chunk_size: int,
) -> List[torch.Tensor]:
# TODO: maybe use slot_idx instead of pages_per_seq??
num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
@ -308,7 +309,7 @@ def prepare_fused_mha_metadata(
# SequenceInfo._get_sanitized_num_sequences could break in fake mode
@prepare_fused_mha_metadata.register_fake
def prepare_fused_mha_metadata_fake(
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
):
num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
return (

View File

@ -194,6 +194,11 @@ class ModelFactory(ABC):
"""Returns the sharding config for this model."""
return self._sharding_config
@property
def chunk_size(self) -> Optional[int]:
"""Returns the chunk size for this model."""
return None
def get_cache_config(self) -> CacheConfig:
"""Return the cache configuration for the model.

View File

@ -137,6 +137,13 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
model_config, _ = self._get_model_config()
return getattr(model_config, "vocab_size", None)
@property
def chunk_size(self) -> Optional[int]:
"""Returns the chunk size for this model."""
model_config, _ = self._get_model_config()
# chunk_size is an input to a custom op, so it cannot be None. We set it to a default value of 128.
return getattr(model_config, "chunk_size", 128)
def _recursive_update_config(
self, config: PretrainedConfig, update_dict: Dict[str, Any]
) -> Tuple[PretrainedConfig, Dict[str, Any]]:
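
As a quick sanity check on the chunk_size fallback above (the config class here is a hypothetical stand-in): a model config that does not define chunk_size, e.g. a plain transformer, resolves to the default of 128:

class _PlainConfig:  # hypothetical stand-in for a PretrainedConfig without chunk_size
    pass

assert getattr(_PlainConfig(), "chunk_size", 128) == 128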

View File

@ -121,8 +121,8 @@ class ADEngine(ModelEngine):
page_size=attn_page_size,
max_num_tokens=max_num_tokens,
vocab_size_padded=factory.vocab_size_padded,
chunk_size=factory.chunk_size,
)
# TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__,
# ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm.
@ -167,7 +167,6 @@ class ADEngine(ModelEngine):
# build model
self.model = get_inference_model(self.cache_seq_interface)
# start fresh with fixed seed
torch.manual_seed(42)
@ -324,7 +323,6 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
torch.cuda.set_device(rank)
port = mpi_dist.broadcast(dist.get_free_port()) # use MPI broadcast to pick a free port
dist.initialize_or_skip(rank, world_size, port)
# some config
assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"

View File

@ -186,15 +186,9 @@ def test_prepare_metadata_cuda(conv_env):
pages_per_seq = torch.ones(b, device=device, dtype=torch.int32)
slot_idx = torch.tensor([2, 0, 1, 3], device=device, dtype=torch.int32)
page_size = 128
chunk_size = 128
out = torch.ops.auto_deploy.cuda_causal_conv_prepare_metadata(
position_ids,
seq_len,
input_pos,
cache_loc,
pages_per_seq,
slot_idx,
page_size,
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
)
assert len(out) == 4
seq_len_s, seq_start, slot_s, use_initial_states = out

View File

@ -479,7 +479,7 @@ class TestTorchBackendAttention:
# Test metadata preparation
result = torch.ops.auto_deploy.torch_cached_attention_prepare_metadata(
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, 128
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, 128, 128
)
# Verify result structure

View File

@ -177,15 +177,10 @@ def test_prepare_metadata(conv_env):
pages_per_seq = torch.ones(b, device=device, dtype=torch.int32)
slot_idx = torch.tensor([2, 0, 1, 3], device=device, dtype=torch.int32)
page_size = 128
chunk_size = 128
out = torch.ops.auto_deploy.torch_causal_conv_prepare_metadata(
position_ids,
seq_len,
input_pos,
cache_loc,
pages_per_seq,
slot_idx,
page_size,
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
)
assert len(out) == 4
seq_len_s, seq_start, slot_s, use_initial_states = out

View File

@ -191,15 +191,9 @@ def test_prepare_metadata(mamba_env):
pages_per_seq = torch.ones(b, device=device, dtype=torch.int32)
slot_idx = torch.tensor([2, 0, 1, 3], device=device, dtype=torch.int32)
page_size = 128
chunk_size = 128
out = torch.ops.auto_deploy.torch_ssm_prepare_metadata(
position_ids,
seq_len,
input_pos,
cache_loc,
pages_per_seq,
slot_idx,
page_size,
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
)
# Returns a list of tensors from custom op API
assert len(out) == 4

View File

@ -114,7 +114,7 @@ def test_triton_context_flattened_and_state_writeback(mamba_env):
)
max_batch_size = 2
slot_idx = torch.tensor([1], device=device, dtype=torch.int32)
slot_idx = torch.tensor([1], device=device, dtype=torch.long)
ssm_state_cache_torch = torch.randn(
max_batch_size, num_heads, head_dim, ssm_state_size, device=device, dtype=dtype
)
@ -123,6 +123,18 @@ def test_triton_context_flattened_and_state_writeback(mamba_env):
seq_len = torch.tensor(lens, device=device, dtype=torch.int32)
seq_start = torch.tensor([0, lens[0]], device=device, dtype=torch.int32)
use_initial_states = torch.tensor([0] * batch, device=device).to(torch.bool)
cu_seqlens = torch.cat(
[
torch.zeros(1, dtype=torch.int32, device=device),
torch.cumsum(seq_len.to(torch.int32), dim=0),
],
dim=0,
)
seq_idx_prefill = torch.repeat_interleave(
torch.arange(len(lens), device=device, dtype=torch.int32),
seq_len,
).view(1, -1)
batch_info_tensor = torch.tensor([len(lens), sum(lens), 0], dtype=torch.int32)
# Torch reference
y_torch = torch.ops.auto_deploy.torch_cached_ssm(
hidden_states,
@ -154,6 +166,11 @@ def test_triton_context_flattened_and_state_writeback(mamba_env):
seq_start,
slot_idx,
use_initial_states,
cu_seqlens,
None, # chunk indices
None, # chunk offsets
seq_idx_prefill,
batch_info_tensor,
ssm_state_cache_triton,
time_step_limit,
chunk_size,