Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[None][feat] Autodeploy add triton configs and optimize mamba prefill (#9083)
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Parent: 3c950910a0
Commit: d12cb9436d
LICENSE (+4)
@@ -1,3 +1,7 @@
Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

Portions of this project are under the following copyright:
- Copyright contributors to the vLLM project

Apache License
Version 2.0, January 2004
setup.py (+1)
@@ -134,6 +134,7 @@ package_data += [
    "_torch/auto_deploy/config/*.yaml",
    # Include CUDA source for fused MoE align extension so runtime JIT can find it in wheels
    '_torch/auto_deploy/custom_ops/fused_moe/moe_align_kernel.cu',
    '_torch/auto_deploy/custom_ops/fused_moe/triton_fused_moe_configs/*'
]
@@ -175,7 +175,7 @@ class CapturedGraph(nn.Module):

        # retrieve output from buffer, cut to batch size, and unflatten
        bs = args_batched[0].shape[0]
        out_flat = [o_b[:bs].detach().clone() for o_b in self._out_buffer_flat]
        out_flat = [o_b[:bs] for o_b in self._out_buffer_flat]
        return self._out_spec.unflatten(out_flat)
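For context, the change above returns views into the preallocated output buffer instead of detaching and cloning each tensor. A minimal sketch of the trade-off, assuming a caller that consumes or copies the output before the buffer is overwritten again; the buffer and helper names here are illustrative, not from the repository:

import torch

# Preallocated output buffer, as a graph runner might keep internally.
out_buffer = torch.zeros(8, 16)  # illustrative max-batch-size buffer

def read_output(bs: int, clone: bool) -> torch.Tensor:
    # clone=True returns an independent copy (extra allocation + copy per call);
    # clone=False returns a view that aliases the buffer and stays valid only
    # until the buffer is overwritten by the next replay.
    out = out_buffer[:bs]
    return out.detach().clone() if clone else out

view = read_output(4, clone=False)
copy = read_output(4, clone=True)
out_buffer.fill_(1.0)          # simulate the next replay overwriting the buffer
print(view[0, 0].item())       # 1.0 -> the view reflects the overwrite
print(copy[0, 0].item())       # 0.0 -> the clone kept the old value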
@@ -116,6 +116,7 @@ class SequenceInfo:
        page_size: int = 0,
        max_num_tokens: Optional[int] = None,
        vocab_size_padded: Optional[int] = None,
        chunk_size: Optional[int] = None,
    ):
        """Initialize the SequenceInfo object.

@@ -142,7 +143,10 @@ class SequenceInfo:
        self.max_batch_size = max_batch_size
        self.page_size = page_size if page_size > 0 else max_seq_len
        self.vocab_size_padded = vocab_size_padded

        self.chunk_size = chunk_size
        # Chunk size is an input to a custom op, so we need to set a default value if it is not provided.
        if self.chunk_size is None:
            self.chunk_size = 128
        # NOTE (lucaslie): WAR to address issue when using flashinfer attention with
        # (max_batch_size, max_seq_len) input in trtllm runtime.
        # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504

@@ -193,7 +197,7 @@ class SequenceInfo:
            "input_pos": torch.empty(self.max_batch_size, dtype=torch.int),
            "cache_loc": torch.empty(max_num_cache_loc_assignments, dtype=torch.int),
            "pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int),
            "slot_idx": torch.empty(self.max_batch_size, dtype=torch.int),
            "slot_idx": torch.empty(self.max_batch_size, dtype=torch.long),
            # OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
            "_gather_idx": torch.empty(self.max_num_tokens, dtype=torch.int),
        }

@@ -203,7 +207,9 @@ class SequenceInfo:
        # NOTE: order of keys is relevant here!
        self._uncached_arg_names = ("input_ids", "position_ids")
        self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq", "slot_idx")
        self._cached_constants = ("page_size",)
        # page_size is the size of attention kv-cache pages.
        # chunk_size is used in mamba prefill kernels to split the context into chunks.
        self._cached_constants = ("page_size", "chunk_size")
        ############################################################################################

        # EXTRA TENSOR FIELDS ######################################################################
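Because chunk_size joins _cached_constants, every registered prepare-metadata custom op now receives it as one extra trailing positional argument, which is why all the prepare_*_metadata signatures later in this diff grow a chunk_size parameter. A rough sketch of how such an argument list could be assembled, with hypothetical helper and variable names rather than the repository's actual call path:

from typing import Dict, Tuple

import torch

def build_prepare_metadata_args(
    cached: Dict[str, torch.Tensor],
    cached_names: Tuple[str, ...],
    constants: Dict[str, int],
    constant_names: Tuple[str, ...],
    position_ids: torch.Tensor,
) -> tuple:
    # Positional order: position_ids, cached tensors, then scalar constants.
    return (
        position_ids,
        *(cached[name] for name in cached_names),
        *(constants[name] for name in constant_names),
    )

# Illustrative call: page_size and chunk_size land at the end of the signature,
# matching e.g. prepare_flashinfer_metadata(..., page_size, chunk_size).
args = build_prepare_metadata_args(
    cached={k: torch.zeros(2, dtype=torch.int) for k in ("seq_len", "input_pos", "cache_loc", "pages_per_seq", "slot_idx")},
    cached_names=("seq_len", "input_pos", "cache_loc", "pages_per_seq", "slot_idx"),
    constants={"page_size": 64, "chunk_size": 128},
    constant_names=("page_size", "chunk_size"),
    position_ids=torch.zeros(1, 4, dtype=torch.int),
)
print(len(args))  # 8 positional arguments in this toy layout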
@@ -162,6 +162,7 @@ def prepare_flashinfer_metadata(
    pages_per_seq: torch.Tensor,
    slot_idx: torch.Tensor,
    page_size: int,
    chunk_size: int,
) -> List[torch.Tensor]:
    """Prepare metadata for flashinfer attention.

@@ -213,7 +214,7 @@ def prepare_flashinfer_metadata(
# As SequenceInfo._get_sanitized_num_sequences could break in fake mode
@prepare_flashinfer_metadata.register_fake
def prepare_flashinfer_metadata_fake(
    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
):
    seq_len = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
    qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device)
@@ -0,0 +1,147 @@
{
  "triton_version": "3.5.0",
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
  "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5},
  "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 5},
  "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "48": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "96": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
  "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
  "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 5},
  "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
  "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4}
}
@@ -0,0 +1,147 @@
{
  "triton_version": "3.5.0",
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
  "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "4": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 5},
  "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "16": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "24": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 5},
  "32": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "48": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 5},
  "96": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 5},
  "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 5},
  "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 5},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 2},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 2}
}
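These tuned tables are keyed by token count M; as the get_moe_configs docstring below states, the kernel uses the entry whose key is closest to the actual M. A small self-contained sketch of that lookup; the dictionary here is a stand-in for json.load on one of the files above, not repository code:

import torch  # only to match the surrounding code style; not strictly needed here

# Keys as they appear in the tuned tables above, plus the metadata entry.
tuned = {str(k): {"BLOCK_SIZE_M": 16} for k in (1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096)}
tuned["triton_version"] = "3.5.0"

def pick_batch_size_key(M: int, configs: dict) -> int:
    # Drop the metadata entry and pick the grid key closest to the actual token count M.
    grid = [int(k) for k in configs if k != "triton_version"]
    return min(grid, key=lambda bs: abs(bs - M))

print(pick_batch_size_key(300, tuned))   # -> 256 (closest grid point to 300)
print(pick_batch_size_key(5000, tuned))  # -> 4096 (larger than every key, so the largest wins)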
@@ -4,13 +4,18 @@ Triton implementation of the Fused MOE ops. Inspired by vLLM's triton MOE implem

from __future__ import annotations

from typing import Tuple
import functools
import json
import os
from typing import Any, Tuple

import torch
import torch.nn.functional as F
import triton
import triton.language as tl

from ...utils.logger import ad_logger


@triton.jit
def _write_zeros_to_output(

@@ -283,7 +288,90 @@ def fused_mlp_moe_kernel_w8a8(
    tl.store(c_ptrs, accumulator, mask=c_mask)


def _default_kernel_config(M: int, E: int, N: int, K: int, top_k: int) -> dict:
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name(
    E: int, N: int, dtype: str | None, block_shape: list[int] | None = None
) -> str:
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    block_shape_selector = (
        "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
    ).replace(" ", "")
    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"  # noqa: E501


# Adapted from: https://github.com/vllm-project/vllm/blob/b4fda58a2d0e458e0186e4caa4354b3d07153c70/vllm/model_executor/layers/fused_moe/fused_moe.py#L828
@functools.lru_cache
def get_moe_configs(
    E: int,
    N: int,
    dtype: str | None,
    block_n: int | None = None,
    block_k: int | None = None,
) -> dict[int, Any] | None:
    """
    Return optimized configurations for the fused MoE kernel.

    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the fused_moe kernel. To evaluate the
    kernel on a given batch size bs, the closest batch size in the grid should
    be picked and the associated configuration chosen to invoke the kernel.
    """

    # First look up if an optimized configuration is available in the configs
    # directory
    block_shape = [block_n, block_k] if block_n and block_k else None
    json_file_name = get_config_file_name(E, N, dtype, block_shape)

    config_file_paths = []

    # note that we prioritize user defined config
    # user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
    user_defined_config_folder = "."
    if user_defined_config_folder is not None:
        user_defined_config_file_path = os.path.join(user_defined_config_folder, json_file_name)
        config_file_paths.append(user_defined_config_file_path)

    ad_folder = "triton_fused_moe_configs"
    default_config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), ad_folder, json_file_name
    )
    config_file_paths.append(default_config_file_path)

    for config_file_path in config_file_paths:
        if os.path.exists(config_file_path):
            with open(config_file_path) as f:
                ad_logger.info("Using configuration from %s for MoE layer.", config_file_path)
                # If a configuration has been found, return it
                tuned_config = json.load(f)
                # Delete triton_version from tuned_config
                tuned_config.pop("triton_version", None)
                return {int(key): val for key, val in tuned_config.items()}

    # If no optimized configuration is available, we will use the default
    # configuration
    ad_logger.warning(
        ("Using default MoE config. Performance might be sub-optimal! Config file not found at %s"),
        config_file_paths,
    )
    return None


def _get_kernel_config(
    M: int, E: int, N: int, dtype: str | None, block_shape: list[int] | None = None
) -> dict:
    configs = get_moe_configs(E, N, dtype=None)
    if configs:
        # If an optimal configuration map has been found, look up the
        # optimal config (closest batch size)
        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
    else:
        # Else use the default config
        config = _default_kernel_config(M, E)
    return config


def _default_kernel_config(M: int, E: int) -> dict:
    if M <= E:
        return {
            "BLOCK_SIZE_M": 16,
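For illustration, get_config_file_name above produces a name of the following shape; the device string comes from torch.cuda.get_device_name(), so the H100 name shown here is purely an assumption used to demonstrate the format:

# Sketch of the file-name format with a hypothetical device name.
E, N = 8, 14336  # invented example values
device_name = "NVIDIA_H100_80GB_HBM3"  # assumed; the real value is torch.cuda.get_device_name().replace(" ", "_")
print(f"E={E},N={N},device_name={device_name}.json")
# -> E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3.json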
@@ -461,7 +549,7 @@ def _fused_moe_mlp_relu2(
    E, inter_size, _ = w_up.shape
    top_k = topk_ids.shape[1]

    config = _default_kernel_config(M, E, inter_size, H, top_k)
    config = _get_kernel_config(M, E, inter_size, H, top_k)
    sorted_token_ids, expert_ids, num_tokens_post_padded = _pack_routed_tokens(
        topk_ids, M, E, top_k, config["BLOCK_SIZE_M"]
    )

@@ -598,7 +686,7 @@ def triton_quant_fp8_moe(
    M, H = x2d.shape
    E, inter_size, _ = w1_q.shape
    top_k = topk_ids.shape[1]
    config = _default_kernel_config(M, E, inter_size, H, top_k)
    config = _get_kernel_config(M, E, inter_size, H, top_k)
    sorted_token_ids, expert_ids, num_tokens_post_padded = _pack_routed_tokens(
        topk_ids, M, E, top_k, config["BLOCK_SIZE_M"]
    )
@@ -61,6 +61,7 @@ def cuda_causal_conv_prepare_metadata(
    pages_per_seq: torch.Tensor,
    slot_idx: torch.Tensor,
    page_size: int,
    chunk_size: int,
) -> List[torch.Tensor]:
    """Prepare metadata for cached causal conv (CUDA backend).

@@ -75,13 +76,13 @@ def cuda_causal_conv_prepare_metadata(

    slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)
    # This is only used during prefill to determine if we should use the initial states from the cache.
    use_initial_states = input_pos > 0
    use_initial_states = input_pos[:num_seq] > 0
    return (seq_len_sanitized, seq_start, slot_idx_sanitized, use_initial_states)


@cuda_causal_conv_prepare_metadata.register_fake
def cuda_causal_conv_prepare_metadata_fake(
    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
):
    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
    num_seq = len(seq_len_sanitized)

@@ -182,7 +183,7 @@ def _cuda_cached_causal_conv1d(

    # Scatter outputs back to y
    y_prefill = y_varlen.transpose(0, 1)  # [total_prefill_tokens, C_out]
    y_flat[:total_prefill_tokens].copy_(y_prefill.to(y_flat.dtype))
    y_flat[:total_prefill_tokens].copy_(y_prefill)

    # DECODE: batch update for single-token sequences
    if num_decode > 0:

@@ -208,7 +209,7 @@ def _cuda_cached_causal_conv1d(
    )

    # Custom op must not return an alias of any input; return a fresh tensor
    return y.contiguous().clone()
    return y


@_cuda_cached_causal_conv1d.register_fake
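One of the fixes above slices input_pos to the number of active sequences before building use_initial_states; the metadata buffers are sized for the maximum batch, so comparing the full buffer would produce a mask of the wrong length that includes stale slots. A tiny torch-only sketch of the difference, with made-up sizes:

import torch

max_batch_size, num_seq = 8, 3
input_pos = torch.zeros(max_batch_size, dtype=torch.int)  # preallocated buffer
input_pos[:num_seq] = torch.tensor([0, 17, 5])            # only the first num_seq entries are live

bad = input_pos > 0             # shape [8]: includes padded/stale slots
good = input_pos[:num_seq] > 0  # shape [3]: one flag per active sequence
print(bad.shape, good.shape)    # torch.Size([8]) torch.Size([3])
print(good.tolist())            # [False, True, True]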
@@ -147,6 +147,7 @@ def torch_causal_conv_prepare_metadata(
    pages_per_seq: torch.Tensor,
    slot_idx: torch.Tensor,
    page_size: int,
    chunk_size: int,
) -> List[torch.Tensor]:
    """Prepare metadata for cached causal conv.


@@ -120,6 +120,7 @@ def _torch_ssm_prepare_metadata(
    pages_per_seq: torch.Tensor,
    slot_idx: torch.Tensor,
    page_size: int,
    chunk_size: int,
) -> List[torch.Tensor]:
    """Prepare metadata for cached SSM transform.

@@ -143,7 +144,7 @@ def _torch_ssm_prepare_metadata(

@_torch_ssm_prepare_metadata.register_fake
def _torch_ssm_prepare_metadata_fake(
    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
):
    # Use the same sanitization logic to determine sizes in fake mode
    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
@@ -24,6 +24,135 @@ from ..attention_interface import (
)


@torch.library.custom_op("auto_deploy::triton_ssm_prepare_metadata", mutates_args=())
def _triton_ssm_prepare_metadata(
    position_ids: torch.Tensor,
    seq_len: torch.Tensor,
    input_pos: torch.Tensor,
    cache_loc: torch.Tensor,
    pages_per_seq: torch.Tensor,
    slot_idx: torch.Tensor,
    page_size: int,
    chunk_size: int,
) -> List[torch.Tensor]:
    """Prepare metadata for cached SSM transform.

    Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
    """
    # Determine number of active sequences and compute seq_start boundaries
    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
    num_seq = len(seq_len_sanitized)

    seq_start = torch.zeros_like(seq_len_sanitized)
    if num_seq > 1:
        seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0)

    # Truncate slot indices to match active sequences
    slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)
    # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch
    # reference implementation to support chunked prefill.
    use_initial_states = input_pos[:num_seq] > 0

    device = position_ids.device

    chunk_indices = torch.zeros(num_seq, dtype=torch.int32, device=device)
    chunk_offsets = torch.zeros(num_seq, dtype=torch.int32, device=device)
    cu_seqlens = torch.zeros(num_seq + 1, dtype=torch.int32, device=device)
    _, s = position_ids.shape[:2]
    if s > 1:
        # only compute chunk indices and offsets for prefill.
        prefill_mask = seq_len_sanitized > 1
        num_prefill = int(prefill_mask.sum().item())
        num_prefill_tokens = int(seq_len_sanitized[:num_prefill].sum().item())
        num_decode = num_seq - num_prefill
        cu_seqlens = torch.cat(
            [
                torch.zeros(1, dtype=torch.int32, device=device),
                torch.cumsum(seq_len_sanitized[:num_prefill].to(torch.int32), dim=0),
            ],
            dim=0,
        )
        chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
        seq_idx_prefill = torch.repeat_interleave(
            torch.arange(num_prefill, device=device, dtype=torch.int32),
            seq_len_sanitized[:num_prefill],
        ).view(1, -1)
    else:
        num_prefill = 0
        num_prefill_tokens = 0
        num_decode = num_seq
        seq_idx_prefill = torch.empty(1, 0, dtype=torch.int32, device=device)
    batch_info_tensor = torch.tensor(
        [num_prefill, num_prefill_tokens, num_decode], dtype=torch.int32
    )  # host tensor

    return (
        seq_len_sanitized,
        seq_start,
        slot_idx_sanitized,
        use_initial_states,
        cu_seqlens,
        chunk_indices,
        chunk_offsets,
        seq_idx_prefill,
        batch_info_tensor,
    )


@_triton_ssm_prepare_metadata.register_fake
def _triton_ssm_prepare_metadata_fake(
    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
):
    # Use the same sanitization logic to determine sizes in fake mode
    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
    num_seq = len(seq_len_sanitized)
    device = slot_idx.device
    # Always-correct shapes
    seq_len_fake = torch.empty_like(seq_len_sanitized)
    seq_start_fake = torch.empty_like(seq_len_sanitized)
    slot_idx_fake = torch.empty(num_seq, dtype=torch.long, device=device)
    use_initial_states_fake = torch.empty(num_seq, dtype=torch.bool, device=device)
    cu_seqlens_fake = torch.empty(num_seq + 1, dtype=torch.int32, device=device)

    # Token-dependent shapes (prefill vs decode)
    _, s = position_ids.shape[:2]
    if s > 1:
        prefill_mask = seq_len_sanitized > 1
        num_prefill = int(prefill_mask.sum().item())
        num_prefill_tokens = int(seq_len_sanitized[:num_prefill].sum().item())
        cu_seqlens_runtime = torch.cat(
            [
                torch.zeros(1, dtype=torch.int32, device=device),
                torch.cumsum(seq_len_sanitized[:num_prefill].to(torch.int32), dim=0),
            ],
            dim=0,
        )
        chunk_indices_rt, chunk_offsets_rt = cu_seqlens_to_chunk_indices_offsets(
            cu_seqlens_runtime, chunk_size
        )
        chunk_indices_fake = torch.empty_like(chunk_indices_rt)
        chunk_offsets_fake = torch.empty_like(chunk_offsets_rt)
        seq_idx_prefill_fake = torch.empty(1, num_prefill_tokens, dtype=torch.int32, device=device)
    else:
        chunk_indices_fake = torch.empty(0, dtype=torch.int32, device=device)
        chunk_offsets_fake = torch.empty(0, dtype=torch.int32, device=device)
        seq_idx_prefill_fake = torch.empty(1, 0, dtype=torch.int32, device=device)

    batch_info_tensor_fake = torch.empty(3, dtype=torch.int32)

    return (
        seq_len_fake,
        seq_start_fake,
        slot_idx_fake,
        use_initial_states_fake,
        cu_seqlens_fake,
        chunk_indices_fake,
        chunk_offsets_fake,
        seq_idx_prefill_fake,
        batch_info_tensor_fake,
    )
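The prefill bookkeeping above hinges on cu_seqlens, a running prefix sum over the prefill sequence lengths; the chunk indices/offsets and the flat seq_idx_prefill map are derived from it. A small torch-only sketch of those inputs (cu_seqlens_to_chunk_indices_offsets itself is not reproduced here):

import torch

seq_len_prefill = torch.tensor([5, 3], dtype=torch.int32)  # two prefill sequences

# cu_seqlens: token boundaries of each sequence in the flattened layout.
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), torch.cumsum(seq_len_prefill, dim=0)])

# seq_idx_prefill: which sequence each flattened token belongs to.
seq_idx_prefill = torch.repeat_interleave(
    torch.arange(len(seq_len_prefill), dtype=torch.int32), seq_len_prefill
).view(1, -1)

print(cu_seqlens.tolist())       # [0, 5, 8]
print(seq_idx_prefill.tolist())  # [[0, 0, 0, 0, 0, 1, 1, 1]]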
@torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={})
|
||||
def _triton_cached_ssm(
|
||||
# INPUTS (dense but may be flattened across sequences)
|
||||
@ -39,6 +168,11 @@ def _triton_cached_ssm(
|
||||
seq_start: torch.Tensor, # [num_seq]
|
||||
slot_idx: torch.Tensor, # [num_seq]
|
||||
use_initial_states: torch.Tensor, # [num_seq]
|
||||
cu_seqlens: torch.Tensor, # [num_seq + 1]
|
||||
chunk_indices: torch.Tensor, # [num_seq + 1]
|
||||
chunk_offsets: torch.Tensor, # [num_seq + 1]
|
||||
seq_idx_prefill: torch.Tensor, # [1, num_prefill]
|
||||
batch_info_tensor: torch.Tensor, # [3]
|
||||
# CACHES
|
||||
ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size]
|
||||
# CONSTANTS
|
||||
@ -51,11 +185,9 @@ def _triton_cached_ssm(
|
||||
- Prefill: run one varlen combined scan over concatenated prefill tokens and update final states per slot.
|
||||
- Decode: batch single-token updates with selective_state_update and update states per slot.
|
||||
"""
|
||||
b, s = hidden_states.shape[:2]
|
||||
num_seq = seq_len.shape[0]
|
||||
b, s, num_heads, head_dim = hidden_states.shape
|
||||
# Flatten tokens for indexing/scatter
|
||||
bs = b * s
|
||||
device = hidden_states.device
|
||||
hs_flat = hidden_states.reshape(bs, *hidden_states.shape[2:]) # [bs, H, D]
|
||||
B_flat = B.reshape(bs, *B.shape[2:]) # [bs, G, N]
|
||||
C_flat = C.reshape(bs, *C.shape[2:]) # [bs, G, N]
|
||||
@ -64,48 +196,28 @@ def _triton_cached_ssm(
|
||||
y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format)
|
||||
y_flat = y.view(bs, *y.shape[2:])
|
||||
|
||||
num_heads = hidden_states.shape[2]
|
||||
head_dim = hidden_states.shape[3]
|
||||
ssm_state_size = B.shape[3]
|
||||
|
||||
if s == 1:
|
||||
num_prefill = 0
|
||||
num_decode = num_seq
|
||||
else:
|
||||
prefill_mask = seq_len > 1
|
||||
num_prefill = int(prefill_mask.sum().item())
|
||||
num_decode = num_seq - num_prefill
|
||||
num_prefill, num_prefill_tokens, num_decode = batch_info_tensor.tolist()
|
||||
|
||||
# Prefill: concatenate tokens at the front and run combined scan
|
||||
if num_prefill > 0:
|
||||
seq_len_prefill = seq_len[:num_prefill].to(torch.int32)
|
||||
total_prefill_tokens = int(seq_len_prefill.sum().item())
|
||||
hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H, D]
|
||||
B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
|
||||
C_prefill = C_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
|
||||
dt_prefill = dt_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H]
|
||||
|
||||
hs_prefill = hs_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, H, D]
|
||||
B_prefill = B_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
|
||||
C_prefill = C_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
|
||||
dt_prefill = dt_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, H]
|
||||
|
||||
cu_seqlens = torch.cat(
|
||||
[
|
||||
torch.zeros(1, dtype=torch.int32, device=device),
|
||||
torch.cumsum(seq_len_prefill, dim=0),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
seq_ids = torch.arange(num_prefill, device=device, dtype=torch.int32)
|
||||
seq_idx_prefill = torch.repeat_interleave(seq_ids, seq_len_prefill).view(1, -1)
|
||||
|
||||
initial_states = chunk_indices = chunk_offsets = None
|
||||
initial_states = None
|
||||
if torch.any(use_initial_states[:num_prefill]):
|
||||
initial_states = torch.where(
|
||||
use_initial_states[:num_prefill, None, None, None],
|
||||
ssm_state_cache[slot_idx[:num_prefill]],
|
||||
0,
|
||||
)
|
||||
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
|
||||
cu_seqlens, chunk_size
|
||||
)
|
||||
else:
|
||||
chunk_indices = None
|
||||
chunk_offsets = None
|
||||
|
||||
y_prefill, varlen_states = mamba_chunk_scan_combined(
|
||||
hs_prefill,
|
||||
dt_prefill,
|
||||
@ -128,20 +240,19 @@ def _triton_cached_ssm(
|
||||
mamba_ssm_cache_dtype=ssm_state_cache.dtype,
|
||||
)
|
||||
|
||||
y_flat[:total_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
|
||||
y_flat[:num_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
|
||||
ssm_state_cache.index_copy_(
|
||||
0, slot_idx[:num_prefill].to(torch.long), varlen_states.to(ssm_state_cache.dtype)
|
||||
0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype)
|
||||
)
|
||||
|
||||
# Decode: batch single-token updates via selective_state_update
|
||||
if num_decode > 0:
|
||||
total_prefill_tokens = 0 if num_prefill == 0 else int(seq_len[:num_prefill].sum().item())
|
||||
slot_idx_decode = slot_idx[num_prefill:].to(torch.long)
|
||||
slot_idx_decode = slot_idx[num_prefill:]
|
||||
|
||||
x_decode = hs_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, H, D]
|
||||
B_decode = B_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, G, N]
|
||||
C_decode = C_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, G, N]
|
||||
dt_decode = dt_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, H]
|
||||
x_decode = hs_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, H, D]
|
||||
B_decode = B_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, G, N]
|
||||
C_decode = C_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, G, N]
|
||||
dt_decode = dt_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, H]
|
||||
|
||||
dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
|
||||
dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
|
||||
@ -162,9 +273,7 @@ def _triton_cached_ssm(
|
||||
state_batch_indices=slot_idx_decode,
|
||||
) # [nd, H, D]
|
||||
|
||||
y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
|
||||
y_dec.to(y_flat.dtype)
|
||||
)
|
||||
y_flat[num_prefill_tokens : num_prefill_tokens + num_decode].copy_(y_dec.to(y_flat.dtype))
|
||||
|
||||
return y
|
||||
|
||||
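The optimized op above no longer recomputes prefill/decode counts from seq_len; it reads them from the host-side batch_info_tensor (so .tolist() does not trigger a device sync) and then slices the flattened token dimension, relying on the convention that all prefill tokens come first and each decode sequence contributes exactly one trailing token. A toy, torch-only sketch of that convention with invented sizes:

import torch

# Two prefill sequences (lengths 5 and 3) followed by two single-token decode sequences.
seq_len = torch.tensor([5, 3, 1, 1])

prefill_mask = seq_len > 1
num_prefill = int(prefill_mask.sum().item())
num_prefill_tokens = int(seq_len[:num_prefill].sum().item())
num_decode = seq_len.numel() - num_prefill

# Kept on CPU on purpose: .tolist() on a CPU tensor does not synchronize with the GPU.
batch_info = torch.tensor([num_prefill, num_prefill_tokens, num_decode], dtype=torch.int32)
print(batch_info.tolist())  # [2, 8, 2]

# Flattened tokens: prefill block first, then one row per decode sequence.
hs_flat = torch.randn(int(seq_len.sum()), 4, 2)                                 # [tokens, H, D] with toy H=4, D=2
prefill_tokens = hs_flat[:num_prefill_tokens]                                   # rows 0..7
decode_tokens = hs_flat[num_prefill_tokens : num_prefill_tokens + num_decode]   # rows 8..9
print(prefill_tokens.shape, decode_tokens.shape)  # torch.Size([8, 4, 2]) torch.Size([2, 4, 2])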
@@ -184,6 +293,11 @@ def _triton_cached_ssm_fake(
    seq_start: torch.Tensor,  # [num_seq]
    slot_idx: torch.Tensor,  # [num_seq]
    use_initial_states: torch.Tensor,  # [num_seq]
    cu_seqlens: torch.Tensor,  # [num_seq + 1]
    chunk_indices: torch.Tensor,  # [num_seq + 1]
    chunk_offsets: torch.Tensor,  # [num_seq + 1]
    seq_idx_prefill: torch.Tensor,  # [1, num_prefill]
    batch_info_tensor: torch.Tensor,  # [3]
    # CACHES
    ssm_state_cache: torch.Tensor,  # [max_batch_size, num_heads, head_dim, ssm_state_size]
    # CONSTANTS

@@ -226,8 +340,9 @@ class TritonBackendSSM(AttentionDescriptor):

    @classmethod
    def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
        # Returns (seq_len, seq_start, slot_idx, use_initial_states)
        return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4
        # Returns: seq_len, seq_start, slot_idx, use_initial_states,
        # cu_seqlens, chunk_indices, chunk_offsets, seq_idx_prefill, batch_info_tensor
        return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 9

    @classmethod
    def get_cache_initializers(
@@ -182,6 +182,7 @@ def prepare_fused_mla_metadata(
    pages_per_seq: torch.Tensor,
    slot_idx: torch.Tensor,
    page_size: int,
    chunk_size: int,
) -> List[torch.Tensor]:
    num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
    seq_start = torch.zeros_like(seq_len[:num_seq])

@@ -196,7 +197,7 @@ def prepare_fused_mla_metadata(

@prepare_fused_mla_metadata.register_fake
def prepare_fused_mla_metadata_fake(
    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
):
    return (
        torch.empty_like(seq_len),

@@ -363,6 +363,7 @@ def torch_backend_prepare_metadata(
    pages_per_seq: torch.Tensor,
    slot_idx: torch.Tensor,
    page_size: int,
    chunk_size: int,
) -> List[torch.Tensor]:
    """Prepare metadata for torch backend attention (similar to triton backend)."""
    num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)

@@ -291,6 +291,7 @@ def prepare_fused_mha_metadata(
    pages_per_seq: torch.Tensor,
    slot_idx: torch.Tensor,
    page_size: int,
    chunk_size: int,
) -> List[torch.Tensor]:
    # TODO: maybe use slot_idx instead of pages_per_seq??
    num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)

@@ -308,7 +309,7 @@ def prepare_fused_mha_metadata(
# SequenceInfo._get_sanitized_num_sequences could break in fake mode
@prepare_fused_mha_metadata.register_fake
def prepare_fused_mha_metadata_fake(
    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
):
    num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
    return (
@@ -194,6 +194,11 @@ class ModelFactory(ABC):
        """Returns the sharding config for this model."""
        return self._sharding_config

    @property
    def chunk_size(self) -> Optional[int]:
        """Returns the chunk size for this model."""
        return None

    def get_cache_config(self) -> CacheConfig:
        """Return the cache configuration for the model.
@@ -137,6 +137,13 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
        model_config, _ = self._get_model_config()
        return getattr(model_config, "vocab_size", None)

    @property
    def chunk_size(self) -> Optional[int]:
        """Returns the chunk size for this model."""
        model_config, _ = self._get_model_config()
        # chunk_size is an input to a custom op, so it cannot be None. We set it to a default value of 128.
        return getattr(model_config, "chunk_size", 128)

    def _recursive_update_config(
        self, config: PretrainedConfig, update_dict: Dict[str, Any]
    ) -> Tuple[PretrainedConfig, Dict[str, Any]]:
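The factory simply forwards the model config's chunk_size when it exists and falls back to 128 otherwise. A minimal standalone illustration of that getattr fallback; the config objects below are plain stand-ins, not transformers classes:

from types import SimpleNamespace
from typing import Optional

def chunk_size_of(model_config) -> Optional[int]:
    # Mirrors the property above: use the config value when present, else default to 128.
    return getattr(model_config, "chunk_size", 128)

mamba_like = SimpleNamespace(chunk_size=256)  # e.g. a config that defines its own chunk size
plain_llm = SimpleNamespace()                 # no chunk_size attribute at all
print(chunk_size_of(mamba_like))  # 256
print(chunk_size_of(plain_llm))   # 128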
@@ -121,8 +121,8 @@ class ADEngine(ModelEngine):
            page_size=attn_page_size,
            max_num_tokens=max_num_tokens,
            vocab_size_padded=factory.vocab_size_padded,
            chunk_size=factory.chunk_size,
        )

        # TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__,
        # ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm.

@@ -167,7 +167,6 @@ class ADEngine(ModelEngine):

        # build model
        self.model = get_inference_model(self.cache_seq_interface)

        # start fresh with fixed seed
        torch.manual_seed(42)

@@ -324,7 +323,6 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
    torch.cuda.set_device(rank)
    port = mpi_dist.broadcast(dist.get_free_port())  # use MPI broadcast to pick a free port
    dist.initialize_or_skip(rank, world_size, port)

    # some config
    assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"
@@ -186,15 +186,9 @@ def test_prepare_metadata_cuda(conv_env):
    pages_per_seq = torch.ones(b, device=device, dtype=torch.int32)
    slot_idx = torch.tensor([2, 0, 1, 3], device=device, dtype=torch.int32)
    page_size = 128

    chunk_size = 128
    out = torch.ops.auto_deploy.cuda_causal_conv_prepare_metadata(
        position_ids,
        seq_len,
        input_pos,
        cache_loc,
        pages_per_seq,
        slot_idx,
        page_size,
        position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
    )
    assert len(out) == 4
    seq_len_s, seq_start, slot_s, use_initial_states = out

@@ -479,7 +479,7 @@ class TestTorchBackendAttention:

    # Test metadata preparation
    result = torch.ops.auto_deploy.torch_cached_attention_prepare_metadata(
        position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, 128
        position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, 128, 128
    )

    # Verify result structure

@@ -177,15 +177,10 @@ def test_prepare_metadata(conv_env):
    pages_per_seq = torch.ones(b, device=device, dtype=torch.int32)
    slot_idx = torch.tensor([2, 0, 1, 3], device=device, dtype=torch.int32)
    page_size = 128
    chunk_size = 128

    out = torch.ops.auto_deploy.torch_causal_conv_prepare_metadata(
        position_ids,
        seq_len,
        input_pos,
        cache_loc,
        pages_per_seq,
        slot_idx,
        page_size,
        position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
    )
    assert len(out) == 4
    seq_len_s, seq_start, slot_s, use_initial_states = out

@@ -191,15 +191,9 @@ def test_prepare_metadata(mamba_env):
    pages_per_seq = torch.ones(b, device=device, dtype=torch.int32)
    slot_idx = torch.tensor([2, 0, 1, 3], device=device, dtype=torch.int32)
    page_size = 128

    chunk_size = 128
    out = torch.ops.auto_deploy.torch_ssm_prepare_metadata(
        position_ids,
        seq_len,
        input_pos,
        cache_loc,
        pages_per_seq,
        slot_idx,
        page_size,
        position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
    )
    # Returns a list of tensors from custom op API
    assert len(out) == 4

@@ -114,7 +114,7 @@ def test_triton_context_flattened_and_state_writeback(mamba_env):
    )

    max_batch_size = 2
    slot_idx = torch.tensor([1], device=device, dtype=torch.int32)
    slot_idx = torch.tensor([1], device=device, dtype=torch.long)
    ssm_state_cache_torch = torch.randn(
        max_batch_size, num_heads, head_dim, ssm_state_size, device=device, dtype=dtype
    )

@@ -123,6 +123,18 @@ def test_triton_context_flattened_and_state_writeback(mamba_env):
    seq_len = torch.tensor(lens, device=device, dtype=torch.int32)
    seq_start = torch.tensor([0, lens[0]], device=device, dtype=torch.int32)
    use_initial_states = torch.tensor([0] * batch, device=device).to(torch.bool)
    cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device=device),
            torch.cumsum(seq_len.to(torch.int32), dim=0),
        ],
        dim=0,
    )
    seq_idx_prefill = torch.repeat_interleave(
        torch.arange(len(lens), device=device, dtype=torch.int32),
        seq_len,
    ).view(1, -1)
    batch_info_tensor = torch.tensor([len(lens), sum(lens), 0], dtype=torch.int32)
    # Torch reference
    y_torch = torch.ops.auto_deploy.torch_cached_ssm(
        hidden_states,

@@ -154,6 +166,11 @@ def test_triton_context_flattened_and_state_writeback(mamba_env):
        seq_start,
        slot_idx,
        use_initial_states,
        cu_seqlens,
        None,  # chunk indices
        None,  # chunk offsets
        seq_idx_prefill,
        batch_info_tensor,
        ssm_state_cache_triton,
        time_step_limit,
        chunk_size,