[AutoDeploy] Merge feat/ad_2025_06_13 feature branch (#5454)

Signed-off-by: Grzegorz Kwasniewski <213329731+greg-kwasniewski1@users.noreply.github.com>
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Co-authored-by: Grzegorz Kwasniewski <213329731+greg-kwasniewski1@users.noreply.github.com>
Co-authored-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
Co-authored-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Lucas Liebenwein authored on 2025-06-25 04:30:13 +03:00, committed by GitHub
commit 5cffb7e0ec (parent 73ba4fc320)
50 changed files with 527 additions and 282 deletions

View File

@ -0,0 +1,42 @@
## AutoDeploy Custom Operators
All AutoDeploy custom operators follow this naming convention:
`torch.ops.auto_deploy.<kernel_backend>_<op_category>_<op_name>`
The table below lists the available operators, grouped by backend; a minimal registration sketch follows the table.
### Available Custom Operators
| Operator Name | Description |
|--------------|-------------|
| `torch.ops.auto_deploy.flashinfer_attention_mha_with_cache` | FlashInfer attention with KV cache support |
| `torch.ops.auto_deploy.flashinfer_rope` | FlashInfer RoPE (Rotary Position Embedding) implementation |
| `torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa` | Grouped SDPA (Scaled Dot Product Attention) with BSND format |
| `torch.ops.auto_deploy.torch_attention_deepseek_fused_mla` | DeepSeek fused MLA (Multi-head Latent Attention) |
| `torch.ops.auto_deploy.torch_attention_deepseek_mla` | DeepSeek MLA implementation |
| `torch.ops.auto_deploy.torch_attention_grouped_sdpa` | Grouped SDPA implementation |
| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation |
| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation |
| `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer implementation |
| `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation |
| `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation |
| `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values |
| `torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce operation |
| `torch.ops.auto_deploy.torch_quant_fp4_linear` | FP4 quantized linear layer |
| `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer |
| `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies |
| `torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin` | RoPE with explicit cosine/sine |
| `torch.ops.auto_deploy.torch_rope_with_qk_interleaving` | RoPE with QK interleaving |
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache` | Triton fused flattened MHA with cache |
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache_rope_fusion` | Triton fused flattened MHA with cache and RoPE fusion |
| `torch.ops.auto_deploy.triton_attention_fused_mha_with_cache` | Triton fused MHA with cache |
| `torch.ops.auto_deploy.triton_attention_fused_mha_with_paged_cache` | Triton fused MHA with paged cache |
| `torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache` | Triton flattened MHA with cache |
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mla_with_cache` | Triton fused flattened Multi-head Latent Attention with cache support |
| `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs |
| `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions |
| `torch.ops.auto_deploy.trtllm_moe_fused` | TensorRT-LLM fused MoE implementation |
| `torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce` | TensorRT-LLM fused linear layer followed by all-reduce operation |
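
For illustration, here is a minimal sketch of how an operator in this namespace is registered and called. The op name `torch_demo_scale` is hypothetical and only demonstrates the `<kernel_backend>_<op_category>_<op_name>` pattern (backend `torch`, category `demo`, name `scale`); the real operators are registered in the `custom_ops` modules touched by this diff.

```python
import torch

# Hypothetical op used only to illustrate the naming convention.
@torch.library.custom_op("auto_deploy::torch_demo_scale", mutates_args=())
def torch_demo_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# Fake (meta) implementation so the op can be traced/exported.
@torch_demo_scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)

# After registration, the op is reachable under the unified namespace.
y = torch.ops.auto_deploy.torch_demo_scale(torch.randn(4, 8), 2.0)
```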

View File

@ -4,11 +4,12 @@ from ._triton_attention_internal import *
from .dist import *
from .flashinfer_attention import *
from .flashinfer_rope import *
from .fused_moe import *
from .linear import *
from .mla import *
from .quant import *
from .rope import *
from .torch_attention import *
from .torch_moe import *
from .torch_rope import *
from .triton_attention import *
from .triton_rope import *
from .trtllm_moe import *

View File

@ -168,7 +168,9 @@ def _paged_context_mha(
)
@torch.library.custom_op("attention::fused_mha_with_paged_cache", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_fused_mha_with_paged_cache", mutates_args=()
)
def fused_mha_with_paged_cache(
q: torch.Tensor,
k: torch.Tensor,
@ -210,10 +212,10 @@ def fused_mha_with_paged_cache(
if freqs_cis is not None:
if s == 1:
rope_args = (freqs_cis, input_pos, "bsnd")
fn_rope = torch.ops.rope.apply_rope_with_input_pos
fn_rope = torch.ops.auto_deploy.triton_rope_with_input_pos
else:
rope_args = (freqs_cis, input_pos, seq_len, seq_start)
fn_rope = torch.ops.rope.apply_rope_on_flattened_inputs
fn_rope = torch.ops.auto_deploy.triton_rope_on_flattened_inputs
q = fn_rope(q, *rope_args)
k = fn_rope(k, *rope_args)
@ -416,7 +418,9 @@ def _flattened_context_mha_rope_fusion(
)
@torch.library.custom_op("attention::fused_flattened_mha_with_cache_rope_fusion", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_fused_flattened_mha_with_cache_rope_fusion", mutates_args=()
)
def fused_flattened_mha_with_cache_rope_fusion(
q: torch.Tensor,
k: torch.Tensor,
@ -541,7 +545,7 @@ def _context_mha(
)
@torch.library.custom_op("attention::fused_mha_with_cache", mutates_args=())
@torch.library.custom_op("auto_deploy::triton_attention_fused_mha_with_cache", mutates_args=())
def fused_mha_with_cache(
q: torch.Tensor,
k: torch.Tensor,
@ -563,8 +567,8 @@ def fused_mha_with_cache(
# rope embedding
if freqs_cis is not None:
q = torch.ops.rope.apply_rope_with_input_pos(q, freqs_cis, input_pos, "bsnd")
k = torch.ops.rope.apply_rope_with_input_pos(k, freqs_cis, input_pos, "bsnd")
q = torch.ops.auto_deploy.triton_rope_with_input_pos(q, freqs_cis, input_pos, "bsnd")
k = torch.ops.auto_deploy.triton_rope_with_input_pos(k, freqs_cis, input_pos, "bsnd")
# attention (assumed layout is bsnd)
y = torch.empty_like(q)
@ -593,7 +597,9 @@ def fused_mha_fake(
return torch.empty_like(q.contiguous())
@torch.library.custom_op("attention::fused_flattened_mha_with_cache", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_fused_flattened_mha_with_cache", mutates_args=()
)
def fused_flattened_mha_with_cache(
# Q, K, V
q: torch.Tensor,
@ -638,10 +644,10 @@ def fused_flattened_mha_with_cache(
if freqs_cis.numel() > 0:
if s == 1:
rope_args = (freqs_cis, input_pos, "bsnd")
fn_rope = torch.ops.rope.apply_rope_with_input_pos
fn_rope = torch.ops.auto_deploy.triton_rope_with_input_pos
else:
rope_args = (freqs_cis, input_pos, seq_len, seq_start)
fn_rope = torch.ops.rope.apply_rope_on_flattened_inputs
fn_rope = torch.ops.auto_deploy.triton_rope_on_flattened_inputs
q = fn_rope(q, *rope_args)
k = fn_rope(k, *rope_args)

View File

@ -8,7 +8,7 @@ from ..distributed import common as dist
from ..distributed import trtllm as trtllm_dist
@torch.library.custom_op("dist::all_gather", mutates_args=(), device_types="cuda")
@torch.library.custom_op("auto_deploy::torch_dist_all_gather", mutates_args=(), device_types="cuda")
def all_gather(
tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
) -> torch.Tensor:
@ -25,7 +25,7 @@ def all_gather_fake(tensor, dim=0):
return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim)
@torch.library.custom_op("dist::all_reduce", mutates_args=(), device_types="cuda")
@torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda")
def all_reduce(t: torch.Tensor) -> torch.Tensor:
"""All_reduce across the ranks. Reduction op is SUM.

View File

@ -153,7 +153,7 @@ class _FlashInferPlanner:
_GlobalFlashInferPlanner = _FlashInferPlanner()
@torch.library.custom_op("attention::prepare_flashinfer_metadata", mutates_args=())
@torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=())
def prepare_flashinfer_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
@ -228,7 +228,7 @@ def prepare_flashinfer_metadata_fake(
)
@torch.library.custom_op("attention::flashinfer_mha_with_cache", mutates_args=())
@torch.library.custom_op("auto_deploy::flashinfer_attention_mha_with_cache", mutates_args=())
def flashinfer_mha_with_cache(
# Q, K, V
q: torch.Tensor,
@ -355,15 +355,15 @@ class FlashInferAttention(AttentionDescriptor):
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
"""Get the source attention op that we target for replacement."""
return torch.ops.attention.bsnd_grouped_sdpa
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.attention.flashinfer_mha_with_cache
return torch.ops.auto_deploy.flashinfer_attention_mha_with_cache
@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
return torch.ops.attention.prepare_flashinfer_metadata, 6
return torch.ops.auto_deploy.flashinfer_attention_prepare_metadata, 6
@classmethod
def get_cache_initializers(

View File

@ -4,7 +4,7 @@ import flashinfer
import torch
@torch.library.custom_op("rope::flashinfer", mutates_args=())
@torch.library.custom_op("auto_deploy::flashinfer_rope", mutates_args=())
def apply_rope_with_input_pos_flashinfer(
q: torch.Tensor,
k: torch.Tensor,

View File

@ -8,7 +8,7 @@ from ..distributed import common as dist
from ..distributed import trtllm as trtllm_dist
@torch.library.custom_op("linear::simple", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_linear_simple", mutates_args=())
def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
"""A wrapper for the linear functional to control how it is exposed.
@ -30,7 +30,9 @@ def simple_fake(input, weight, bias):
return torch.ops.aten.linear(input, weight, bias)
@torch.library.custom_op("linear::fused_linear_all_reduce", mutates_args=(), device_types="cuda")
@torch.library.custom_op(
"auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
)
def fused_linear_all_reduce(
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:

View File

@ -22,7 +22,9 @@ from .triton_attention import _flattened_context_mha, _generate_mha
Constant = Union[int, float, str, None]
@torch.library.custom_op("attention::fused_flattened_mla_with_cache", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_fused_flattened_mla_with_cache", mutates_args=()
)
def fused_flattened_mla_with_cache(
# Q, K, V
q_nope: torch.Tensor,
@ -94,7 +96,7 @@ def fused_flattened_mla_with_cache(
q_slice = q_pe[start : start + length]
k_slice = k_pe[start : start + length]
q_rot, k_rot = torch.ops.rope.torch_apply_rope_with_qk_interleaving(
q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
q_slice,
k_slice,
cos,
@ -169,7 +171,9 @@ def fused_flattened_mla_with_cache_fake(
return torch.empty_like(kv[..., -v_head_dim:])
@torch.library.custom_op("attention::prepare_fused_mla_metadata", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_prepare_fused_mla_metadata", mutates_args=()
)
def prepare_fused_mla_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
@ -221,15 +225,15 @@ class MultiHeadLatentAttention(AttentionDescriptor):
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
return torch.ops.deepseek.fused_mla
return torch.ops.auto_deploy.torch_attention_deepseek_fused_mla
@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.attention.fused_flattened_mla_with_cache
return torch.ops.auto_deploy.triton_attention_fused_flattened_mla_with_cache
@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
return torch.ops.attention.prepare_fused_mla_metadata, 4
return torch.ops.auto_deploy.triton_attention_prepare_fused_mla_metadata, 4
@classmethod
def get_cache_initializers(

View File

@ -16,7 +16,7 @@ TRTLLM_FP4_OP_AVAILABLE = True
TRTLLM_NVFP4_SCALING_VECTOR_SIZE = 16
@torch.library.custom_op("quant::quant_fn", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_quant_fn", mutates_args=())
def quant_fn(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
scaled_x = x / scale
rounded_x = torch.round(scaled_x)
@ -37,7 +37,7 @@ class QuantModule(nn.Module):
self.register_buffer("scale", torch.tensor(scale))
def forward(self, x: torch.Tensor):
return torch.ops.quant.quant_fn(x, self.scale)
return torch.ops.auto_deploy.torch_quant_fn(x, self.scale)
FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
@ -50,7 +50,7 @@ def _to_fp8(x, scale):
return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
@torch.library.custom_op("quant::fp8_linear", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_quant_fp8_linear", mutates_args=())
@torch.compile(dynamic=True)
def fp8_linear(
input: torch.Tensor,
@ -105,7 +105,7 @@ def fp8_linear_fake(
return torch.ops.aten.linear(input, weight_fp8.to(input.dtype), bias)
@torch.library.custom_op("quant::fused_fp8_linear_all_reduce", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_quant_fused_fp8_linear_all_reduce", mutates_args=())
@torch.compile(dynamic=True)
def fused_fp8_linear_all_reduce(
input: torch.Tensor,
@ -114,7 +114,9 @@ def fused_fp8_linear_all_reduce(
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.ops.quant.fp8_linear(input, weight_fp8, bias, input_scale, weight_scale)
out = torch.ops.auto_deploy.torch_quant_fp8_linear(
input, weight_fp8, bias, input_scale, weight_scale
)
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM)
dist.all_reduce(out, op=dist.ReduceOp.SUM)
@ -129,7 +131,9 @@ def fused_fp8_linear_all_reduce_fake(
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.ops.quant.fp8_linear(input, weight_fp8, bias, input_scale, weight_scale)
return torch.ops.auto_deploy.torch_quant_fp8_linear(
input, weight_fp8, bias, input_scale, weight_scale
)
class FP8Linear(nn.Linear):
@ -146,12 +150,12 @@ class FP8Linear(nn.Linear):
self.bias = nn.Parameter(self.bias.to(torch.half))
def forward(self, x):
return torch.ops.quant.fp8_linear(
return torch.ops.auto_deploy.torch_quant_fp8_linear(
x, self.weight, self.bias, self.input_scale, self.weight_scale
)
@torch.library.custom_op("quant::fp4_linear", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_quant_fp4_linear", mutates_args=())
@torch.compile(dynamic=True)
def fp4_linear(
input: torch.Tensor,
@ -218,4 +222,7 @@ def fp4_linear_fake(
return torch.ops.aten.linear(input, weight_fp4.repeat(1, 2).to(input.dtype), bias)
QUANT_OPS = [torch.ops.quant.fp8_linear, torch.ops.quant.fp4_linear]
QUANT_OPS = [
torch.ops.auto_deploy.torch_quant_fp8_linear,
torch.ops.auto_deploy.torch_quant_fp4_linear,
]
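
As context for the FP8 ops above, the sketch below re-states the cast pattern of the `_to_fp8` helper shown in this diff. The per-tensor scale choice is an illustrative assumption for the example, not what the library computes.

```python
import torch

FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def to_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Scale into the FP8 range, clamp, then cast (mirrors _to_fp8 above).
    return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)

x = torch.randn(2, 4)
scale = x.abs().amax() / FP8_MAX  # illustrative per-tensor scale
x_dequant = to_fp8(x, scale).to(torch.float32) * scale  # approximate round trip
```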

View File

@ -8,7 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F
@torch.library.custom_op("attention::repeat_kv", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_attention_repeat_kv", mutates_args=())
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@ -31,7 +31,7 @@ def repeat_kv_fake(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return torch.empty(replicated_shape, device=hidden_states.device, dtype=hidden_states.dtype)
@torch.library.custom_op("attention::scaled_dot_product_attention", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_attention_sdpa", mutates_args=())
def scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
@ -66,7 +66,7 @@ def scaled_dot_product_attention_fake(
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
@torch.library.custom_op("attention::grouped_sdpa", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_attention_grouped_sdpa", mutates_args=())
def grouped_sdpa(
query: torch.Tensor,
key: torch.Tensor,
@ -104,7 +104,7 @@ def grouped_sdpa_fake(
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
@torch.library.custom_op("attention::bsnd_grouped_sdpa", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=())
def bsnd_grouped_sdpa(
query: torch.Tensor, # layout: [b, n, s_q, d]
key: torch.Tensor, # layout: [b, n, s_k, d]
@ -162,7 +162,7 @@ def update_kv_cache(
)
@torch.library.custom_op("attention::fused_mla_ref", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_attention_fused_mla_ref", mutates_args=())
def fused_mla_ref(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
@ -215,7 +215,7 @@ def fused_mla_ref(
q_slice = q_pe[start : start + length]
k_slice = k_pe[start : start + length]
q_rot, k_rot = torch.ops.rope.torch_apply_rope_with_qk_interleaving(
q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
q_slice,
k_slice,
cos,
@ -315,7 +315,7 @@ def fused_mla_ref_fake(
return torch.empty_like(kv[..., -v_head_dim:])
@torch.library.custom_op("deepseek::fused_mla", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_attention_deepseek_fused_mla", mutates_args=())
def fused_mla(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
@ -340,7 +340,7 @@ def fused_mla(
cos = cos[position_ids]
sin = sin[position_ids]
q_pe, k_pe = torch.ops.rope.torch_apply_rope_with_qk_interleaving(q_pe, k_pe, cos, sin)
q_pe, k_pe = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(q_pe, k_pe, cos, sin)
query_states = k_pe.new_empty(bs, num_heads, q_len, q_head_dim)
query_states[:, :, :, :qk_nope_head_dim] = q_nope
@ -399,7 +399,7 @@ def fused_mla(
return torch.empty_like(kv[..., -v_head_dim:])
@torch.library.custom_op("deepseek::mla", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_attention_deepseek_mla", mutates_args=())
def mla(
q_nope: torch.Tensor, # Down projected q_nope
q_pe: torch.Tensor, # q_pe after applying rope

View File

@ -3,10 +3,8 @@ from typing import List
import torch
import torch.nn.functional as F
from ...modules.fused_moe import MoE # noqa: F401
@torch.library.custom_op("moe::torch_moe", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_moe", mutates_args=())
def torch_moe(
x: torch.Tensor,
selected_experts: torch.Tensor,
@ -80,7 +78,7 @@ def torch_moe(
return torch.empty_like(x)
@torch.library.custom_op("moe::torch_fused_moe", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_moe_fused", mutates_args=())
def torch_fused_moe(
x: torch.Tensor,
selected_experts: torch.Tensor,
@ -90,7 +88,6 @@ def torch_fused_moe(
) -> torch.Tensor:
"""
A reference implementation of a fused MoE layer computation.
Parameters:
x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size,
S is the sequence length, and H is the hidden size.
@ -102,7 +99,6 @@ def torch_fused_moe(
containing the fused weights for w3 and w1 for each expert.
w2_stacked_weight (torch.Tensor): A tensor of shape (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE)
containing the weights for w2 for each expert.
Returns:
torch.Tensor: Output tensor with the same shape as the input x.
"""
@ -145,45 +141,3 @@ def torch_fused_moe(
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.custom_op("moe::trtllm_fused_moe", mutates_args=())
def trtllm_fused_moe(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
x_shape = x.shape
x = x.view(-1, x_shape[-1])
routing_weights = routing_weights.to(torch.float32)
selected_experts = selected_experts.to(torch.int32)
quant_scales = []
return torch.ops.trtllm.fused_moe(
x,
selected_experts,
routing_weights,
w3_w1_stacked_weight,
w2_stacked_weight,
x.dtype,
quant_scales,
tp_size=1,
tp_rank=0,
ep_size=1,
ep_rank=0,
enable_alltoall=False,
)[0].view(x_shape)
@trtllm_fused_moe.register_fake
def trtllm_fused_moe(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(x)

View File

@ -11,7 +11,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
@torch.library.custom_op("rope::torch_apply_rope_with_explicit_cos_sin", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_rope_with_explicit_cos_sin", mutates_args=())
def torch_apply_rope_with_explicit_cos_sin(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
@ -38,7 +38,7 @@ def torch_apply_rope_with_explicit_cos_sin_fake(
return torch.empty_like(q), torch.empty_like(k)
@torch.library.custom_op("rope::torch_apply_rope_with_complex_freqs", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_rope_with_complex_freqs", mutates_args=())
def torch_apply_rope_with_complex_freqs(
xq: torch.Tensor,
xk: torch.Tensor,
@ -69,7 +69,7 @@ def torch_apply_rope_with_complex_freqs_fake(
return torch.empty_like(xq), torch.empty_like(xk)
@torch.library.custom_op("rope::torch_apply_rope_with_qk_interleaving", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_rope_with_qk_interleaving", mutates_args=())
def torch_apply_rope_with_qk_interleaving(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:

View File

@ -169,7 +169,7 @@ def _flattened_context_mha(
)
@torch.library.custom_op("attention::flattened_mha_with_cache", mutates_args=())
@torch.library.custom_op("auto_deploy::triton_attention_flattened_mha_with_cache", mutates_args=())
def flattened_mha_with_cache(
# Q, K, V
q: torch.Tensor,
@ -259,7 +259,9 @@ def flattened_mha_fake(
return q.new_empty(*q.shape[:-1], v.shape[-1]).contiguous()
@torch.library.custom_op("attention::prepare_fused_mha_metadata", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_prepare_fused_mha_metadata", mutates_args=()
)
def prepare_fused_mha_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
@ -314,15 +316,15 @@ class TritonWithFlattenedInputs(AttentionDescriptor):
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
return torch.ops.attention.bsnd_grouped_sdpa
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.attention.flattened_mha_with_cache
return torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache
@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
return torch.ops.attention.prepare_fused_mha_metadata, 4
return torch.ops.auto_deploy.triton_attention_prepare_fused_mha_metadata, 4
@classmethod
def get_cache_initializers(

View File

@ -4,7 +4,7 @@ import triton
from .triton_kernels.rope import rope_fwd_flattened_kernel, rope_fwd_kernel
@torch.library.custom_op("rope::apply_rope_with_input_pos", mutates_args=())
@torch.library.custom_op("auto_deploy::triton_rope_with_input_pos", mutates_args=())
def apply_rope_with_input_pos(
x: torch.Tensor, freqs_cis: torch.Tensor, input_pos: torch.Tensor, layout: str
) -> torch.Tensor:
@ -77,7 +77,7 @@ def apply_rope_with_input_pos_fake(x, freqs_cis, input_pos, layout):
return torch.empty_like(x)
@torch.library.custom_op("rope::apply_rope_on_flattened_inputs", mutates_args=())
@torch.library.custom_op("auto_deploy::triton_rope_on_flattened_inputs", mutates_args=())
def apply_rope_on_flattened_inputs(
x: torch.Tensor,
freqs_cis: torch.Tensor,

View File

@ -0,0 +1,43 @@
import torch
@torch.library.custom_op("auto_deploy::trtllm_moe_fused", mutates_args=())
def trtllm_fused_moe(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
x_shape = x.shape
x = x.view(-1, x_shape[-1])
routing_weights = routing_weights.to(torch.float32)
selected_experts = selected_experts.to(torch.int32)
quant_scales = []
return torch.ops.trtllm.fused_moe(
x,
selected_experts,
routing_weights,
w3_w1_stacked_weight,
w2_stacked_weight,
x.dtype,
quant_scales,
tp_size=1,
tp_rank=0,
ep_size=1,
ep_rank=0,
enable_alltoall=False,
)[0].view(x_shape)
@trtllm_fused_moe.register_fake
def trtllm_fused_moe(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(x)

View File

@ -53,7 +53,7 @@ def deepseek_v3_attention(
# Use custom op to capture mla. This does not handle KV cache
# as passing transformers Cache into a custom op is throwing an error.
# Would not be an issue, cause we intend to replace mla op with our implementation further along the pipeline
attn_output = torch.ops.deepseek.fused_mla(
attn_output = torch.ops.auto_deploy.torch_attention_deepseek_fused_mla(
q_nope,
q_pe,
kv,
@ -131,7 +131,7 @@ def deepseek_v3_moe(self, hidden_states):
"""DeepSeekV3MoE forward function rewritten in Mixtral style to enable torch export."""
selected_experts, routing_weights, *_ = self.gate(hidden_states)
final_hidden_states = torch.ops.moe.torch_moe(
final_hidden_states = torch.ops.auto_deploy.torch_moe(
hidden_states,
selected_experts,
routing_weights,

View File

@ -17,7 +17,7 @@ def _forward_moe(self: Qwen3MoeSparseMoeBlock, hidden_states: torch.Tensor):
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.ops.moe.torch_moe(
final_hidden_states = torch.ops.auto_deploy.torch_moe(
hidden_states,
selected_experts,
routing_weights,

View File

@ -13,7 +13,6 @@ from ....bindings.internal.batch_manager import CacheType
from ....llmapi.llm_args import _AutoDeployLlmArgs
from ....mapping import Mapping
from ...distributed import MPIDist
from ...pyexecutor._util import create_torch_sampler_args
from ...pyexecutor.config import PyTorchConfig
from ...pyexecutor.model_engine import ModelEngine
from ...pyexecutor.py_executor import PyExecutor
@ -266,9 +265,14 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
ad_config: _AutoDeployLlmArgs = executor_config.pytorch_backend_config
max_batch_size = ad_config.max_batch_size
max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
max_seq_len = ad_config.max_seq_len
attn_page_size = ad_config.attn_page_size
max_num_tokens = ad_config.max_num_tokens
max_draft_tokens = (
0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_tokens
)
ad_logger.info(f"{max_seq_len=}, {max_batch_size=}, {attn_page_size=}, {max_num_tokens=}")
# initialize model engine
@ -309,22 +313,30 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler)
# search sampler with speculative decoding
sampler_args = create_torch_sampler_args(
executor_config, dist_mapping, mixed_sampler=False, max_seq_len=max_seq_len
# TODO (lucaslie, fridah-nv): some models require mixed_sampler=True to have good outputs, see
# https://github.com/NVIDIA/TensorRT-LLM/issues/5254
# We should expose mixed_sample to our build_and_run_ad script so we can configure this
# correctly for models as needed.
sampler_args = TorchSampler.Args(
max_seq_len=max_seq_len,
max_draft_tokens=max_draft_tokens,
max_num_sequences=max_num_sequences,
max_beam_width=executor_config.max_beam_width,
mixed_sampler=ad_config.mixed_sampler,
)
sampler = TorchSampler(sampler_args)
# creating the executor object
py_executor = PyExecutor(
resource_manager,
scheduler,
model_engine=engine,
sampler=sampler,
dist=mpi_dist,
max_num_sequences=ad_config.max_batch_size * dist_mapping.pp_size,
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
max_input_len=ad_config.max_input_len,
max_batch_size=ad_config.max_batch_size,
max_draft_tokens=ad_config.speculative_config.max_draft_tokens
if ad_config.speculative_config is not None
else 0,
max_batch_size=max_batch_size,
max_draft_tokens=max_draft_tokens,
)
return py_executor

View File

@ -215,11 +215,12 @@ class DemoEngine(ADEngine):
def _sample(
cls, logits: torch.Tensor, sampling_params: SamplingParams
) -> Tuple[torch.Tensor, torch.Tensor]:
probs = cls._logits_to_probs(
logits, sampling_params.temperature, sampling_params.top_k
) # [*logits.shape]
# idx_next shape is [*logits.shape[:-1]]
idx_next = cls._multinomial_sample_one_no_sync(probs)
from tensorrt_llm._torch.pyexecutor.sampler import top_k_sampling_batch
logits_shape = logits.shape
logits = logits.view(-1, logits_shape[-1]) # top_k_sampling_batch expects 2D logits
idx_next, probs = top_k_sampling_batch(logits, sampling_params.top_k)
idx_next = idx_next.view(logits_shape[:-1])
return idx_next, probs
def _decode_tokens(

View File

@ -213,7 +213,7 @@ _torch_where_patch.where_original = torch.where
def _torch_linear_patch(
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
return torch.ops.linear.simple(input, weight, bias)
return torch.ops.auto_deploy.torch_linear_simple(input, weight, bias)
# TODO: remove once https://github.com/pytorch/pytorch/issues/142439 is resolved
@ -336,7 +336,7 @@ def torch_export_to_gm(
# there is no guarantee how it is represented and we need to make sure it is easily identifiable
# in the graph.
sdpa_original = F.scaled_dot_product_attention
F.scaled_dot_product_attention = torch.ops.attention.scaled_dot_product_attention
F.scaled_dot_product_attention = torch.ops.auto_deploy.torch_attention_sdpa
# We overwrite the linear functional as well. This basically avoids exporting the view ops
# that are used to flatten/unflatten multiple batch dimensions of the input tensor.

View File

@ -18,7 +18,7 @@ def match_repeat_kv(gm: GraphModule) -> GraphModule:
The pattern is:
unsqueeze -> expand -> reshape -> [optional] contiguous
This is replaced with torch.ops.attention.repeat_kv.
This is replaced with torch.ops.auto_deploy.torch_attention_repeat_kv.
"""
graph = gm.graph
@ -49,7 +49,7 @@ def match_eager_attention(gm: GraphModule) -> GraphModule:
The pattern is:
transpose -> matmul -> mul -> (optional) add -> softmax -> to -> dropout -> matmul
This is replaced with torch.ops.attention.scaled_dot_product_attention.
This is replaced with torch.ops.auto_deploy.torch_attention_sdpa.
"""
graph = gm.graph
@ -82,7 +82,7 @@ def match_grouped_attention(gm: GraphModule) -> GraphModule:
repeat_kv(v, n_rep) ->
sdpa(q, repeated_k, repeated_v)
This is replaced with torch.ops.attention.grouped_sdpa.
This is replaced with torch.ops.auto_deploy.torch_attention_grouped_sdpa.
"""
graph = gm.graph
@ -92,7 +92,7 @@ def match_grouped_attention(gm: GraphModule) -> GraphModule:
# Iterate through nodes in the graph
for node in list(graph.nodes):
# Look for SDPA nodes that could be part of our pattern
if is_op(node, torch.ops.attention.scaled_dot_product_attention):
if is_op(node, torch.ops.auto_deploy.torch_attention_sdpa):
match_info = _match_grouped_attention_pattern(node)
if match_info:
ad_logger.debug(f"Found grouped attention pattern at {node}")
@ -126,8 +126,8 @@ def match_causal_attn_mask(gm: GraphModule) -> GraphModule:
for node in list(graph.nodes):
# Look for SDPA nodes or grouped SDPA nodes
if not (
is_op(node, torch.ops.attention.scaled_dot_product_attention)
or is_op(node, torch.ops.attention.grouped_sdpa)
is_op(node, torch.ops.auto_deploy.torch_attention_sdpa)
or is_op(node, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
):
continue
@ -437,7 +437,7 @@ def _match_grouped_attention_pattern(sdpa_node: Node) -> Optional[Dict[str, Node
Returns a dictionary with information about the match or None if no match.
"""
# Check that sdpa_node is an SDPA operation
if not is_op(sdpa_node, torch.ops.attention.scaled_dot_product_attention):
if not is_op(sdpa_node, torch.ops.auto_deploy.torch_attention_sdpa):
return None
# SDPA should have query, key, value as its first three arguments
@ -447,8 +447,8 @@ def _match_grouped_attention_pattern(sdpa_node: Node) -> Optional[Dict[str, Node
query, key_repeated, value_repeated = sdpa_node.args[0:3]
# Key and value should come from repeat_kv operations
if not is_op(key_repeated, torch.ops.attention.repeat_kv) or not is_op(
value_repeated, torch.ops.attention.repeat_kv
if not is_op(key_repeated, torch.ops.auto_deploy.torch_attention_repeat_kv) or not is_op(
value_repeated, torch.ops.auto_deploy.torch_attention_repeat_kv
):
return None
@ -487,7 +487,7 @@ def _replace_with_repeat_kv(graph, match_info: Dict[str, Node]) -> None:
with graph.inserting_before(node_to_replace):
repeat_kv_node = graph.call_function(
torch.ops.attention.repeat_kv, args=(input_tensor, n_rep)
torch.ops.auto_deploy.torch_attention_repeat_kv, args=(input_tensor, n_rep)
)
# Preserve metadata from the original node
@ -502,7 +502,7 @@ def _replace_with_sdpa(graph, match_info: Dict[str, Node]) -> None:
Replace the matched eager attention pattern with scaled_dot_product_attention.
"""
# retrieve the default op for scaled_dot_product_attention
sdpa_op = torch.ops.attention.scaled_dot_product_attention.default
sdpa_op = torch.ops.auto_deploy.torch_attention_sdpa.default
# construct the args for the ops based on the match_info and the op's schema
args = []
@ -530,7 +530,7 @@ def _replace_with_sdpa(graph, match_info: Dict[str, Node]) -> None:
def _replace_with_grouped_sdpa(graph, match_info: Dict[str, Node]) -> None:
"""
Replace the matched grouped attention pattern with torch.ops.attention.grouped_sdpa.
Replace the matched grouped attention pattern with torch.ops.auto_deploy.torch_attention_grouped_sdpa.
"""
sdpa_node = match_info["sdpa_node"]
query = match_info["query"]
@ -543,7 +543,7 @@ def _replace_with_grouped_sdpa(graph, match_info: Dict[str, Node]) -> None:
with graph.inserting_before(sdpa_node):
grouped_sdpa_node = graph.call_function(
torch.ops.attention.grouped_sdpa.default, args=args, kwargs=kwargs
torch.ops.auto_deploy.torch_attention_grouped_sdpa.default, args=args, kwargs=kwargs
)
# Preserve metadata from the original node
@ -763,8 +763,8 @@ def match_attention_layout(gm: GraphModule, attention_op: Type[AttentionDescript
# List of SDPA operations to look for
sdpa_ops = {
torch.ops.attention.scaled_dot_product_attention,
torch.ops.attention.grouped_sdpa,
torch.ops.auto_deploy.torch_attention_sdpa,
torch.ops.auto_deploy.torch_attention_grouped_sdpa,
}
graph = gm.graph

View File

@ -22,14 +22,14 @@ def fuse_collectives(gm: GraphModule) -> GraphModule:
# lookup for fused ops
# TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly.
lookup = {
torch.ops.linear.simple: torch.ops.linear.fused_linear_all_reduce,
torch.ops.aten.linear: torch.ops.linear.fused_linear_all_reduce,
torch.ops.quant.fp8_linear: torch.ops.quant.fused_fp8_linear_all_reduce,
torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce,
}
# go through all nodes and find all_reduce nodes
for node in gm.graph.nodes:
if not is_op(node, torch.ops.dist.all_reduce):
if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
continue
# check if args are as expected
@ -162,7 +162,7 @@ def fuse_allreduce_residual_rmsnorm(gm: GraphModule) -> GraphModule:
# Traverse all nodes
for node in gm.graph.nodes:
if is_op(node, torch.ops.dist.all_reduce):
if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
trace_and_fuse(allreduce_node=node, graph=gm.graph)
gm = canonicalize_graph(gm)
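
To make the lookup-based rewrite above concrete, here is an illustrative before/after of the collective fusion for the linear case. It assumes the auto_deploy custom ops are imported and a distributed process group is initialized; it is a sketch of the pattern the lookup table targets, not the transform itself.

```python
import torch
from typing import Optional

def unfused(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor]) -> torch.Tensor:
    # Pattern the pass searches for: a linear op feeding an all_reduce node.
    y = torch.ops.auto_deploy.torch_linear_simple(x, w, b)
    return torch.ops.auto_deploy.torch_dist_all_reduce(y)

def fused(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor]) -> torch.Tensor:
    # Single node the matched pattern is rewritten to via the lookup table.
    return torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce(x, w, b)
```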

View File

@ -38,7 +38,7 @@ def ep_shard(gm: GraphModule, rank: int, world_size: int) -> GraphModule:
assert isinstance(gm, GraphModule), "Expecting GraphModule"
num_moe_patterns = 0
for node in list(gm.graph.nodes):
if not is_op(node, torch.ops.moe.torch_moe):
if not is_op(node, torch.ops.auto_deploy.torch_moe):
continue
_insert_sharded_moe(gm, node, rank, world_size)
num_moe_patterns += 1
@ -123,6 +123,8 @@ def _insert_sharded_moe(
# -- add an all_reduce node --
with gm.graph.inserting_after(node):
dist_node = gm.graph.call_function(torch.ops.dist.all_reduce, args=(node,))
dist_node = gm.graph.call_function(
torch.ops.auto_deploy.torch_dist_all_reduce, args=(node,)
)
node.replace_all_uses_with(dist_node)
dist_node.replace_input_with(dist_node, node)

View File

@ -69,7 +69,7 @@ def match_moe_pattern(gm: GraphModule) -> GraphModule:
w3_list = expert_weights["w3"]
fused_moe_node = graph.call_function(
torch.ops.moe.torch_moe,
torch.ops.auto_deploy.torch_moe,
args=(
hidden_states,
selected_experts,
@ -99,7 +99,7 @@ def match_moe_pattern(gm: GraphModule) -> GraphModule:
def fuse_moe(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Scan the FX graph and replace all calls to torch.ops.moe.torch_moe with
torch.ops.moe.trtllm_fused_moe.
torch.ops.auto_deploy.trtllm_moe_fused.
"""
ad_logger.debug("Before MoE fusion: " + str(gm))
@ -118,7 +118,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
graph = gm.graph
for node in list(graph.nodes):
if not is_op(node, torch.ops.moe.torch_moe):
if not is_op(node, torch.ops.auto_deploy.torch_moe):
continue
ad_logger.debug(f"Found MoE op to fuse: {node} with args: {node.args}")
@ -146,7 +146,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
with graph.inserting_before(node):
new_node = graph.call_function(
torch.ops.moe.trtllm_fused_moe,
torch.ops.auto_deploy.trtllm_moe_fused,
args=(
hidden_states,
selected_experts,

View File

@ -1,6 +1,6 @@
"""
This transformation defines two main RoPE (Rotary Positional Embedding) pattern matchers used
to identify and replace RoPE subgraphs with a custom op (`torch.ops.rope.flashinfer`).
to identify and replace RoPE subgraphs with a custom op (`torch.ops.auto_deploy.flashinfer_rope`).
Supported RoPE variants:
@ -73,7 +73,7 @@ def _explicit_rope_pattern(q, k, cos, sin, unsqueeze_dim=1):
def _explicit_rope_repl(q, k, cos, sin, unsqueeze_dim):
return torch.ops.rope.torch_apply_rope_with_explicit_cos_sin.default(
return torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin.default(
q, k, cos, sin, unsqueeze_dim
)
@ -91,7 +91,7 @@ def _interleaved_rope_pattern(q, k, cos, sin, unsqueeze_dim=1):
def _interleaved_rope_repl(q, k, cos, sin, unsqueeze_dim):
return torch.ops.rope.torch_apply_rope_with_qk_interleaving.default(
return torch.ops.auto_deploy.torch_rope_with_qk_interleaving.default(
q, k, cos, sin, unsqueeze_dim
)
@ -109,7 +109,7 @@ def _complex_rope_pattern(xq, xk, freqs_cis, unsqueeze_dim=1):
def _complex_rope_repl(q, k, freqs_cis, unsqueeze_dim):
return torch.ops.rope.torch_apply_rope_with_complex_freqs.default(
return torch.ops.auto_deploy.torch_rope_with_complex_freqs.default(
q, k, freqs_cis, unsqueeze_dim
)
@ -195,9 +195,9 @@ def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphMo
graph = gm.graph
rope_ops = {
torch.ops.rope.torch_apply_rope_with_explicit_cos_sin,
torch.ops.rope.torch_apply_rope_with_qk_interleaving,
torch.ops.rope.torch_apply_rope_with_complex_freqs,
torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin,
torch.ops.auto_deploy.torch_rope_with_qk_interleaving,
torch.ops.auto_deploy.torch_rope_with_complex_freqs,
}
need_transpose = False
@ -206,7 +206,7 @@ def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphMo
if not is_op(node, rope_ops):
continue
if is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs):
if is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs):
q_node, k_node, freqs_node, unsq = extract_op_args(
node,
"xq", # argument name in schema
@ -257,7 +257,7 @@ def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphMo
q_for_op_contig.meta["val"] = q_node.meta["val"].transpose(1, 2)
k_for_op_contig.meta["val"] = k_node.meta["val"].transpose(1, 2)
if is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs):
if is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs):
new_args = (
q_for_op_contig,
k_for_op_contig,
@ -309,9 +309,9 @@ def optimize_rope(gm: GraphModule) -> GraphModule:
num_rope_optimizations = 0
for node in list(graph.nodes):
if is_op(node, torch.ops.rope.torch_apply_rope_with_explicit_cos_sin):
if is_op(node, torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin):
_optimize_explicit(graph, node, rope_flash_cache, rope_position_ids_cache)
elif is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs):
elif is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs):
_optimize_complex(graph, node, rope_flash_cache, rope_position_ids_cache)
else:
continue
@ -398,7 +398,7 @@ def _optimize_explicit(
rope_position_ids_cache=pos_cache,
)
flash_node = graph.call_function(
torch.ops.rope.flashinfer,
torch.ops.auto_deploy.flashinfer_rope,
args=(q_node, k_node, position_ids, fused_cos_sin_to, True),
)
@ -478,7 +478,7 @@ def _optimize_complex(
graph, q_node, batch_dim=0, seq_dim=1, rope_position_ids_cache=pos_cache
)
flash_node = graph.call_function(
torch.ops.rope.flashinfer,
torch.ops.auto_deploy.flashinfer_rope,
args=(q_node, k_node, position_ids, cos_sin_flash, False),
)

View File

@ -16,6 +16,7 @@ Our sharding algorithm for tensor parallelism (TP) is based on the following ste
happens automatically via the checkpoint loading hook added in step 2c.
"""
import math
import operator
from collections import defaultdict
from functools import partial
@ -71,7 +72,13 @@ def _load_hook_remove(
def _insert_sharded_matmul(
gm: GraphModule, node: Node, dim: int, rank: int, world_size: int, add_dist: bool = False
gm: GraphModule,
node: Node,
dim: int,
rank: int,
world_size: int,
add_dist: bool = False,
min_local_shape: int = 1,
):
"""Replaces the matmul node with a new matmul node that accepts sharded weights.
@ -83,8 +90,21 @@ def _insert_sharded_matmul(
quantization_impl = QuantizationImpl.create(node)
def split_tensor(
t: torch.Tensor, d: int = dim, r: int = rank, ws: int = world_size
t: torch.Tensor,
d: int = dim,
r: int = rank,
ws: int = world_size,
min_d_shape: int = min_local_shape,
) -> torch.Tensor:
# The local tensor shape has to be divisible by min_d_shape
max_split_size = t.shape[d] // min_d_shape
if ws > max_split_size:
num_groups = math.ceil(ws / max_split_size)
ad_logger.debug(
f"World size {ws} is greater than the max split size {max_split_size}. "
+ f"Splitting tensor to {num_groups} chunks"
)
return torch.tensor_split(t, max_split_size, dim=d)[r // num_groups]
return torch.tensor_split(t, ws, dim=d)[r]
num_users = num_users_of_weight_node(node)
@ -168,8 +188,8 @@ def _insert_sharded_matmul(
# figure out the right dist op
dist_lookup = {
0: (torch.ops.dist.all_gather, -1),
1: (torch.ops.dist.all_reduce,),
0: (torch.ops.auto_deploy.torch_dist_all_gather, -1),
1: (torch.ops.auto_deploy.torch_dist_all_reduce,),
}
fn_dist, *dist_args = dist_lookup[dim]
@ -191,7 +211,10 @@ def _simple_shard(
def column_row_shard(
gm: GraphModule, rank: int, world_size: int, simple_shard_only: bool = False
gm: GraphModule,
rank: int,
world_size: int,
simple_shard_only: bool = False,
) -> GraphModule:
"""A transformation to apply sharding to the model following tensor parallelism.
@ -205,6 +228,9 @@ def column_row_shard(
**all** nodes in the subgraph. The subgraph here is defined as the region between the first
linear node to the last linear node of an identified sharding region.
# 5. Shard the GEMM nodes or skip accordingly.
min_local_shape is the minimum size of the local tensor shard, to prevent TP parallelism
splitting, e.g., the individual heads into smaller shards.
"""
ad_logger.debug("Before sharding graph: " + str(gm))
@ -232,9 +258,9 @@ def column_row_shard(
# acceptable attention nodes between sharded GEMMs
shardable_attention_nodes = {
torch.ops.attention.scaled_dot_product_attention,
torch.ops.attention.grouped_sdpa,
torch.ops.attention.bsnd_grouped_sdpa,
torch.ops.auto_deploy.torch_attention_sdpa,
torch.ops.auto_deploy.torch_attention_grouped_sdpa,
torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa,
}
# This is a heuristic. Basically, we assume those are okay to shard if we also encounter an
@ -244,7 +270,7 @@ def column_row_shard(
shardable_nodes_with_attention = {
torch.ops.aten.view,
torch.ops.aten.reshape,
torch.ops.rope.flashinfer,
torch.ops.auto_deploy.flashinfer_rope,
operator.getitem,
}
@ -327,9 +353,25 @@ def column_row_shard(
# If we can account for all sharded nodes, we can do a two-way shard
# --> row_split (dim 0) + col_split (dim 1) + all_reduce
# check if we are sharding the attention block
if attention_nodes:
if len(attention_nodes) > 1:
# Column-row shard boundary region detection is probably wrong - there should be
# only one attention operation. Fall back to simple shard.
ad_logger.debug(f"More than one attention node: {unaccounted_nodes}")
_simple_shard(gm, nodes_linear, rank, world_size)
continue
# Extract head dimension. We cannot shard below the head_dim size.
# Assume that head_dim is the last (innermost) dimension of the tensor
min_local_shape = attention_nodes.pop().meta["val"].shape[-1]
else:
min_local_shape = 1
for i, group in enumerate(nodes_linear.values()):
for n in group:
_insert_sharded_matmul(gm, n, i, rank, world_size, add_dist=i > 0)
_insert_sharded_matmul(
gm, n, i, rank, world_size, add_dist=i > 0, min_local_shape=min_local_shape
)
# canonicalize and return
if num_shards:
@ -424,7 +466,7 @@ def dp_bmm_shard(gm: GraphModule, rank: int, world_size: int) -> GraphModule:
base_size = bmm_batch_size // world_size
remainder = bmm_batch_size % world_size
# NOTE: our torch.ops.dist.all_gather doesn't support uneven splits at the moment.
# NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support uneven splits at the moment.
if remainder:
ad_logger.warning(
f"BMM batch size {bmm_batch_size} is not divisible by world size {world_size}. "
@ -451,7 +493,7 @@ def dp_bmm_shard(gm: GraphModule, rank: int, world_size: int) -> GraphModule:
# Add all_gather node after BMM to collect results
with gm.graph.inserting_after(node):
gather_node = gm.graph.call_function(
torch.ops.dist.all_gather,
torch.ops.auto_deploy.torch_dist_all_gather,
args=(node, 0), # Gather along batch dimension (0)
)
node.replace_all_uses_with(gather_node)
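
As a standalone illustration of the new `min_local_shape` behavior added to `_insert_sharded_matmul` above, the sketch below re-derives the `split_tensor` logic from this diff and shows why a single-KV-head GQA projection ends up replicated rather than sharded. The shapes are made up for the example.

```python
import math
import torch

def split_for_tp(t: torch.Tensor, dim: int, rank: int, world_size: int,
                 min_local_shape: int = 1) -> torch.Tensor:
    """Mirror of the split logic in this diff: never shard below
    min_local_shape along `dim`; replicate across rank groups instead."""
    max_split_size = t.shape[dim] // min_local_shape
    if world_size > max_split_size:
        num_groups = math.ceil(world_size / max_split_size)
        return torch.tensor_split(t, max_split_size, dim=dim)[rank // num_groups]
    return torch.tensor_split(t, world_size, dim=dim)[rank]

# GQA example: one KV head with head_dim=8 -> the K projection has only 8
# output rows, so with world_size=4 every rank keeps all 8 rows (replicated),
# while a 32-row Q projection is split into 4 shards of 8 rows each.
k_weight = torch.randn(8, 32)
q_weight = torch.randn(32, 32)
print(split_for_tp(k_weight, 0, rank=3, world_size=4, min_local_shape=8).shape)  # (8, 32)
print(split_for_tp(q_weight, 0, rank=3, world_size=4, min_local_shape=8).shape)  # (8, 32)
```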

View File

@ -68,11 +68,11 @@ PytorchExportedProgramAdapterImpl.add_outputs_metadata = add_outputs_metadata
# TODO(yudong): make custom_ops configurable
CUSTOM_OPS = (
torch.ops.dist.all_reduce.default,
torch.ops.auto_deploy.torch_dist_all_reduce.default,
torch.ops.aten.slice.Tensor,
torch.ops.attention.fused_mha_with_cache.default,
torch.ops.linear.fused_linear_all_reduce.default,
torch.ops.linear.simple.default,
torch.ops.auto_deploy.triton_attention_fused_mha_with_cache.default,
torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce.default,
torch.ops.auto_deploy.torch_linear_simple.default,
torch.ops.aten.split_with_sizes.default,
)

View File

@ -222,7 +222,7 @@ def is_linear_op(node: Node, include_quantization: bool = False) -> bool:
"""
lin_ops = {
torch.ops.aten.linear,
torch.ops.linear.simple,
torch.ops.auto_deploy.torch_linear_simple,
}
if include_quantization:
@ -233,8 +233,8 @@ def is_linear_op(node: Node, include_quantization: bool = False) -> bool:
def is_dist_op(node: Node) -> bool:
"""Check if the node is a distributed op."""
dist_ops = {
torch.ops.dist.all_gather,
torch.ops.dist.all_reduce,
torch.ops.auto_deploy.torch_dist_all_gather,
torch.ops.auto_deploy.torch_dist_all_reduce,
}
return is_op(node, dist_ops)

View File

@ -136,7 +136,7 @@ class QuantizationImpl:
class FP8QuantizationImpl(QuantizationImpl):
@staticmethod
def target_op():
return torch.ops.quant.fp8_linear
return torch.ops.auto_deploy.torch_quant_fp8_linear
@staticmethod
def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor:
@ -207,7 +207,7 @@ def _shard_fp4_weight_scale(weight_scale, sharded_uint8_weight_shape, dim, rank,
class FP4QuantizationImpl(QuantizationImpl):
@staticmethod
def target_op():
return torch.ops.quant.fp4_linear
return torch.ops.auto_deploy.torch_quant_fp4_linear
@staticmethod
def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor:

View File

@ -184,7 +184,7 @@ class MoEOpModel(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: Tensor of shape (batch, hidden_size)
Computes router logits via a gate, and then calls the MoE op via torch.moe.torch_moe.
Computes router logits via a gate, and then calls the MoE op via torch.ops.auto_deploy.torch_moe.
"""
router_logits = self.gate(x)
@ -197,7 +197,7 @@ class MoEOpModel(nn.Module):
w2_list = [expert.w2 for expert in self.experts]
w3_list = [expert.w3 for expert in self.experts]
out = torch.ops.moe.torch_moe(
out = torch.ops.auto_deploy.torch_moe(
x, selected_experts, routing_weights, w1_list, w2_list, w3_list
)
return out

View File

@ -9,14 +9,14 @@ from tensorrt_llm._torch.auto_deploy.distributed.common import spawn_multiproces
def _run_all_reduce_test(rank, world_size):
x = torch.ones(10, 10).to("cuda")
y = torch.ops.dist.all_reduce(x)
y = torch.ops.auto_deploy.torch_dist_all_reduce(x)
assert torch.equal(x * world_size, y)
def _run_all_gather_test(rank, world_size):
x = torch.ones(10, 10).to("cuda")
y = torch.ops.dist.all_gather(x)
y = torch.ops.auto_deploy.torch_dist_all_gather(x)
assert torch.sum(y) == world_size * torch.sum(x)
assert y.shape == (world_size * x.shape[0], *x.shape[1:])

View File

@ -74,7 +74,7 @@ def _run_moe_ep_test(num_experts: int, topk: int, rank: int, world_size: int):
final_scales_local = final_scales * rank_mask
output_trt = torch.ops.moe.trtllm_fused_moe(
output_trt = torch.ops.auto_deploy.trtllm_moe_fused(
x,
selected_experts_local,
final_scales_local,

View File

@ -36,7 +36,7 @@ class AllreduceResidualNorm(torch.nn.Module):
self.norm = RMSNorm(hidden_size, 1e-5, dtype)
def forward(self, x, residual):
x = torch.ops.dist.all_reduce(x)
x = torch.ops.auto_deploy.torch_dist_all_reduce(x)
y = x + residual
normed = self.norm(y)
return normed, y

View File

@ -64,7 +64,7 @@ def _run_job(
return num_params
# now run the test
op_expected = getattr(torch.ops.dist, "all_gather")
op_expected = getattr(torch.ops.auto_deploy, "torch_dist_all_gather")
run_test(
model,
x,

View File

@ -26,8 +26,8 @@ class MLPAllReduce(nn.Module):
self.linear2 = cls(4 * in_features, out_features, bias=bias)
def forward(self, x):
y = F.relu(torch.ops.dist.all_reduce(self.linear1(x)))
return torch.ops.dist.all_reduce(self.linear2(y))
y = F.relu(torch.ops.auto_deploy.torch_dist_all_reduce(self.linear1(x)))
return torch.ops.auto_deploy.torch_dist_all_reduce(self.linear2(y))
def _run_job(
@ -58,7 +58,7 @@ def _run_job(
def check_transformed_graph(gm):
return any(is_op(n, op_expected) for n in gm.graph.nodes) and not any(
is_op(n, torch.ops.dist.all_reduce) for n in gm.graph.nodes
is_op(n, torch.ops.auto_deploy.torch_dist_all_reduce) for n in gm.graph.nodes
)
# now run the test
@ -76,10 +76,10 @@ def _run_job(
@pytest.mark.parametrize(
"linear_cls, dist_op_expected",
(
(nn.Linear, "linear.fused_linear_all_reduce"),
(nn.Linear, "auto_deploy.trtllm_dist_fused_linear_all_reduce"),
pytest.param(
FP8Linear,
"quant.fused_fp8_linear_all_reduce",
"auto_deploy.torch_quant_fused_fp8_linear_all_reduce",
marks=pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support"),
),
),

View File

@ -33,7 +33,7 @@ def _run_ep_shard_job(num_experts: int, rank: int, world_size: int) -> None:
expected_expert = num_experts_per_rank * hidden_size * intermediate_size * 3
return n_gate + expected_expert
op_expected = torch.ops.dist.all_reduce
op_expected = torch.ops.auto_deploy.torch_dist_all_reduce
run_test(
model,

View File

@ -15,6 +15,50 @@ from tensorrt_llm._torch.auto_deploy.transformations.library import column_row_s
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
class GQA_Block(nn.Module):
def __init__(
self,
num_attention_heads: int,
hidden_size: int,
num_key_value_heads: int,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.hidden_size = hidden_size
self.head_dim = self.hidden_size // self.num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.is_gqa = num_key_value_heads < num_attention_heads
assert self.hidden_size == self.num_attention_heads * self.head_dim
# key, query, value, out projections
self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.k_proj = nn.Linear(
self.hidden_size,
self.head_dim * self.num_key_value_heads,
bias=False,
)
self.v_proj = nn.Linear(
self.hidden_size,
self.head_dim * self.num_key_value_heads,
bias=False,
)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, s, _ = x.shape
q = self.q_proj(x).view(b, s, -1, self.head_dim)
k = self.k_proj(x).view(b, s, -1, self.head_dim)
v = self.v_proj(x).view(b, s, -1, self.head_dim)
y = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(q, k, v, is_causal=True)
y = y.contiguous().view(b, s, -1)
return self.o_proj(y)
class MLP(nn.Module):
def __init__(self, in_features, out_features, bias=False):
super().__init__()
@ -37,27 +81,80 @@ def _run_job(
) -> None:
# init model and input
batch_size = 4
num_features = 10
model = model_cls(num_features, num_features, bias=bias).to(device="cuda", dtype=torch.float16)
x = torch.randn(batch_size, num_features, device="cuda", dtype=torch.float16)
sequence_len = 8
num_features = 32
# GQA specific parameters
num_heads = 4
num_key_value_heads = 1
if model_cls == GQA_Block:
model = model_cls(
num_attention_heads=num_heads,
hidden_size=num_features,
num_key_value_heads=num_key_value_heads,
).to(device="cuda", dtype=torch.float16)
else:
model = model_cls(num_features, num_features, bias=bias).to(
device="cuda", dtype=torch.float16
)
x = torch.randn(batch_size, sequence_len, num_features, device="cuda", dtype=torch.float16)
if model_cls == GQA_Block:
head_dim = num_features // num_heads
min_local_size = head_dim
else:
min_local_size = 1
def _get_expected_num_params(num_p_og: int) -> int:
num_update = 0
if bias and dist_op_expected == "all_reduce":
if bias and dist_op_expected == "torch_dist_all_reduce":
num_p_og -= num_features
num_update = num_features * (rank == world_size - 1)
num_params = num_p_og // world_size + num_update
if min_local_size > 1:
# it means we are in the GQA. W_kv are partially replicated, we need to count
# the number of parameters manually.
W_q_local_size = num_features * num_features // world_size
W_o_local_size = W_q_local_size
W_k_local_size = num_features * head_dim * max(num_key_value_heads // world_size, 1)
W_v_local_size = W_k_local_size
num_params = W_q_local_size + W_k_local_size + W_v_local_size + W_o_local_size
else:
num_params = num_p_og // world_size + num_update
return num_params
def verify_local_weight_sizes(gm) -> bool:
"""Verify that all weight tensors have first dimension >= min_local_size after sharding."""
for name, param in gm.named_parameters():
# Only check parameters that have at least 1 dimension and are weight matrices
if param.dim() >= 1 and "weight" in name:
if param.shape[0] < min_local_size:
print(
f"Weight {name} has shape {param.shape}, dim {param.shape[0]} < min_local_size {min_local_size}"
)
return False
return True
# now run the test
op_expected = getattr(torch.ops.dist, dist_op_expected)
op_expected = getattr(torch.ops.auto_deploy, dist_op_expected)
transform_func = partial(column_row_shard, rank=rank, world_size=world_size)
def combined_graph_check(gm) -> bool:
# Check for expected distributed operations
has_expected_dist_ops = any(is_op(n, op_expected) for n in gm.graph.nodes) == (
world_size > 1
)
# Check weight size constraints
weight_sizes_valid = verify_local_weight_sizes(gm)
return has_expected_dist_ops and weight_sizes_valid
run_test(
model,
x,
transform=partial(column_row_shard, rank=rank, world_size=world_size),
check_transformed_graph=lambda gm: any(is_op(n, op_expected) for n in gm.graph.nodes)
== (world_size > 1),
transform=transform_func,
check_transformed_graph=combined_graph_check,
_get_expected_num_params=_get_expected_num_params,
)
@ -67,8 +164,9 @@ def _run_job(
@pytest.mark.parametrize(
"model_cls, dist_op_expected",
(
(MLP, "all_reduce"),
(nn.Linear, "all_gather"),
(MLP, "torch_dist_all_reduce"),
(nn.Linear, "torch_dist_all_gather"),
(GQA_Block, "torch_dist_all_reduce"),
),
)
def test_sharding(model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, device_count: int):

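Note on the GQA branch of `_get_expected_num_params` above: because the K/V projections cannot be split below one KV head per rank, the local parameter count is derived explicitly instead of as `num_p_og // world_size`. A standalone sketch of that arithmetic (the helper name and example values are illustrative only, not part of the change):

```python
def expected_gqa_local_params(num_features: int, num_heads: int,
                              num_key_value_heads: int, world_size: int) -> int:
    """Re-derive the per-rank parameter count used in the GQA sharding test above."""
    head_dim = num_features // num_heads
    # Q and O projections shard evenly across ranks (column/row parallel).
    w_q_local = num_features * num_features // world_size
    w_o_local = w_q_local
    # K and V shard per KV head; with fewer KV heads than ranks, each rank
    # keeps at least one full head (partial replication).
    w_k_local = num_features * head_dim * max(num_key_value_heads // world_size, 1)
    w_v_local = w_k_local
    return w_q_local + w_k_local + w_v_local + w_o_local

# Example with the test's configuration: 32 features, 4 heads, 1 KV head, 2 ranks.
print(expected_gqa_local_params(32, 4, 1, 2))  # 512 + 256 + 256 + 512 = 1536
```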
View File

@ -49,7 +49,7 @@ def test_moe_op_run(dtype):
fused_w2_weight.data[expert_id].copy_(w2)
with torch.inference_mode():
output_torch_moe = torch.ops.moe.torch_moe(
output_torch_moe = torch.ops.auto_deploy.torch_moe(
x,
selected_experts,
final_scales,
@ -57,14 +57,14 @@ def test_moe_op_run(dtype):
w2_weight,
w3_weight,
)
output_torch_fused_moe = torch.ops.moe.torch_fused_moe(
output_torch_fused_moe = torch.ops.auto_deploy.torch_moe_fused(
x,
selected_experts,
final_scales,
fused_w3_w1_stacked_weight,
fused_w2_weight,
)
output_trt_fused_moe = torch.ops.moe.trtllm_fused_moe(
output_trt_fused_moe = torch.ops.auto_deploy.trtllm_moe_fused(
x,
selected_experts,
final_scales,

View File

@ -21,7 +21,7 @@ def test_attention_op():
q, k, v = (x.contiguous() for x in torch.split(qkv, 1, dim=1))
output = torch.ops.attention.fused_mha_with_cache(
output = torch.ops.auto_deploy.triton_attention_fused_mha_with_cache(
q, k, v, input_positions, k_cache, v_cache, None
)
ref = torch.nn.functional.scaled_dot_product_attention(
@ -66,7 +66,7 @@ def test_gqa_op(device, dtype, n_heads, group_size, seq_len):
v_cache = torch.randn(BATCH_SIZE, CACHE_SEQ_LEN, n_kv_heads, D_HEAD, dtype=dtype, device=device)
# run custom op
output = torch.ops.attention.fused_mha_with_cache(
output = torch.ops.auto_deploy.triton_attention_fused_mha_with_cache(
q, k, v, input_positions, k_cache, v_cache, None
)
@ -148,7 +148,7 @@ def test_flat_gqa_op(
v = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs)
# run op
output = torch.ops.attention.flattened_mha_with_cache(
output = torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache(
# Q, K, V
q,
k,
@ -274,7 +274,7 @@ def test_flat_gqa_op_with_rope(
source = 1
if source == 1:
# call rope fusion kernels
output = torch.ops.attention.fused_flattened_mha_with_cache_rope_fusion(
output = torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache_rope_fusion(
q,
k,
v,
@ -288,7 +288,7 @@ def test_flat_gqa_op_with_rope(
)
else:
# call stand-alone rope embedding kernel
output = torch.ops.attention.fused_flattened_mha_with_cache(
output = torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache(
q,
k,
v,
@ -466,7 +466,7 @@ def test_paged_gqa_op(
v = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs)
# run op
output = torch.ops.attention.fused_mha_with_paged_cache(
output = torch.ops.auto_deploy.triton_attention_fused_mha_with_paged_cache(
q,
k,
v,

View File

@ -88,7 +88,7 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
),
BATCH_SIZE * SEQ_LEN,
)
flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q,
k,
@ -213,7 +213,7 @@ def test_flashinfer_attention_op_decode(
),
BATCH_SIZE * SEQ_LEN,
)
flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q,
k,
@ -329,7 +329,7 @@ def test_flashinfer_attention_context_and_generate(
),
BATCH_SIZE * PREFILL_SEQ_LEN,
)
flashinfer_output_1 = torch.ops.attention.flashinfer_mha_with_cache(
flashinfer_output_1 = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q_1,
k_1,
@ -404,7 +404,7 @@ def test_flashinfer_attention_context_and_generate(
),
BATCH_SIZE * 1,
)
flashinfer_output_3 = torch.ops.attention.flashinfer_mha_with_cache(
flashinfer_output_3 = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q_3,
k_3,
@ -513,7 +513,7 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
),
BATCH_SIZE * SEQ_LEN,
)
flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q,
k,
@ -660,7 +660,7 @@ def test_flashinfer_attention_with_fp8_cache(
),
BATCH_SIZE * SEQ_LEN,
)
flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q,
k,
@ -757,7 +757,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
),
BATCH_SIZE * SEQ_LEN,
)
flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q,
k,
@ -840,7 +840,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
),
BATCH_SIZE * 1,
)
flashinfer_output_gen = torch.ops.attention.flashinfer_mha_with_cache(
flashinfer_output_gen = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
# Q, K, V
q_gen,
k_gen,

View File

@ -20,7 +20,7 @@ def test_fp8_linear(bias):
weight_scale = (torch.max(torch.abs(weight)) / 448).to("cuda")
weight_fp8 = (weight / weight_scale).to(torch.float8_e4m3fn)
output_fp8_gemm = torch.ops.quant.fp8_linear(
output_fp8_gemm = torch.ops.auto_deploy.torch_quant_fp8_linear(
input,
weight_fp8,
bias=bias,
@ -49,7 +49,7 @@ def test_fp4_linear():
weight, weight_scale_2, scaling_vector_size, False
)
output_fp4_gemm = torch.ops.quant.fp4_linear(
output_fp4_gemm = torch.ops.auto_deploy.torch_quant_fp4_linear(
input,
weight_fp4,
bias=None,

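The FP8 test above derives a per-tensor scale from the weight's absolute maximum divided by 448 (the largest representable `float8_e4m3` value) before casting. A minimal standalone sketch of that quantize/dequantize step (pure PyTorch; the shapes and error check are illustrative):

```python
import torch

weight = torch.randn(128, 64, dtype=torch.float16, device="cuda")

# Per-tensor scale: map the absolute maximum onto the float8_e4m3 max (448).
weight_scale = torch.max(torch.abs(weight)) / 448
weight_fp8 = (weight / weight_scale).to(torch.float8_e4m3fn)

# Dequantize to inspect the error introduced by the 8-bit cast; weight_fp8
# (together with its scale) is the kind of input the FP8 linear op above consumes.
weight_dq = weight_fp8.to(torch.float16) * weight_scale
max_abs_err = (weight - weight_dq).abs().max()
```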
View File

@ -81,7 +81,7 @@ def test_flashinfer_custom_op_and_hf_impl(dtype, atol, rtol, head_dim):
# Custom op call
positions_flat = torch.arange(batch * seq_len, device=device)
custom_q, custom_k = torch.ops.rope.flashinfer(
custom_q, custom_k = torch.ops.auto_deploy.flashinfer_rope(
query, key, positions_flat, cos_sin_cache_expand, True
)
@ -135,7 +135,7 @@ def test_flashinfer_custom_op_and_complex_impl(dtype, atol, rtol, head_dim):
# q/k of llama4 RoPE are interleaved
positions_flat = torch.arange(batch * seq_len, device=device)
custom_q, custom_k = torch.ops.rope.flashinfer(
custom_q, custom_k = torch.ops.auto_deploy.flashinfer_rope(
query, key, positions_flat, cos_sin_cache_expand, False
)
@ -211,8 +211,8 @@ def test_triton_custom_op_and_hf_impl(layout, head_dim, dtype, atol, rtol):
q_hf = q_f32.to(dtype)
k_hf = k_f32.to(dtype)
q_out = torch.ops.rope.apply_rope_with_input_pos(q, cosin_cache, positions, layout)
k_out = torch.ops.rope.apply_rope_with_input_pos(k, cosin_cache, positions, layout)
q_out = torch.ops.auto_deploy.triton_rope_with_input_pos(q, cosin_cache, positions, layout)
k_out = torch.ops.auto_deploy.triton_rope_with_input_pos(k, cosin_cache, positions, layout)
torch.testing.assert_close(q_hf, q_out, atol=atol, rtol=rtol)
torch.testing.assert_close(k_hf, k_out, atol=atol, rtol=rtol)

View File

@ -34,7 +34,7 @@ def test_rope(d_head):
y_ref = torch_rope_reference(x, freqs_cis, input_position)
freqs_cis = freqs_cis.to("cuda")
x_reshaped = x.unflatten(-1, (N_ELEM // 2, 2)).transpose(-1, -2).flatten(-2).contiguous()
y = torch.ops.rope.apply_rope_with_input_pos(
y = torch.ops.auto_deploy.triton_rope_with_input_pos(
x_reshaped.to("cuda"), freqs_cis, input_position, "bsnd"
)
y_reshaped = y.unflatten(-1, (2, N_ELEM // 2)).transpose(-2, -1).flatten(-2).contiguous()
@ -64,7 +64,7 @@ def test_rope_flattened(d_head):
seq_start_indices = torch.zeros(len(SEQ_LENS), dtype=torch.int32, device="cuda")
seq_start_indices[1:] = torch.cumsum(seq_lens[:-1], 0)
y = torch.ops.rope.apply_rope_on_flattened_inputs(
y = torch.ops.auto_deploy.triton_rope_on_flattened_inputs(
x_reshaped.to("cuda"), freqs_cis, input_position, seq_lens, seq_start_indices
)
y_reshaped = y.unflatten(-1, (2, N_ELEM // 2)).transpose(-2, -1).flatten(-2).contiguous()

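The reshapes wrapping the `triton_rope_with_input_pos` call above convert the last dimension between interleaved and half-split layouts and back. A standalone round-trip sketch of those two reshapes (`N_ELEM` here is an illustrative head dimension):

```python
import torch

N_ELEM = 8  # illustrative head dimension
x = torch.arange(2 * N_ELEM, dtype=torch.float32).reshape(1, 2, N_ELEM)

# Interleaved [x0, x1, x2, x3, ...] -> half-split [x0, x2, ..., x1, x3, ...]
x_half = x.unflatten(-1, (N_ELEM // 2, 2)).transpose(-1, -2).flatten(-2).contiguous()

# Inverse reshape applied to the kernel output: half-split -> interleaved
x_back = x_half.unflatten(-1, (2, N_ELEM // 2)).transpose(-2, -1).flatten(-2).contiguous()
assert torch.equal(x, x_back)
```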
View File

@ -55,7 +55,9 @@ def mla_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale):
q_nope_proj = torch.einsum("bhsd,hdc->bhsc", q_nope, wkv_b_weight[:, :qk_nope_head_dim])
# MLA ref operation
x = torch.ops.deepseek.mla(q_nope_proj, q_pe, compressed_kv, k_pe, None, softmax_scale)
x = torch.ops.auto_deploy.torch_attention_deepseek_mla(
q_nope_proj, q_pe, compressed_kv, k_pe, None, softmax_scale
)
# Up project attention scores
x = torch.einsum("bshc,hdc->bshd", x, wkv_b_weight[:, -v_head_dim:])

View File

@ -379,8 +379,8 @@ class GroupedAttentionModel(torch.nn.Module):
# Manually apply repeat_kv to k and v
if self.num_kv_heads != self.num_heads:
k = torch.ops.attention.repeat_kv(k, self.n_rep)
v = torch.ops.attention.repeat_kv(v, self.n_rep)
k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, self.n_rep)
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, self.n_rep)
# Create attention mask if needed
attn_mask = None
@ -396,7 +396,7 @@ class GroupedAttentionModel(torch.nn.Module):
).masked_fill(mask, float("-inf"))
# Apply scaled dot product attention
attn_output = torch.ops.attention.scaled_dot_product_attention(
attn_output = torch.ops.auto_deploy.torch_attention_sdpa(
q,
k,
v,
@ -434,7 +434,9 @@ def test_match_repeat_kv(num_heads, num_kv_heads, model_cls):
expected_matches = 0 if num_heads == num_kv_heads else 2
def verify_matcher(gm):
repeat_kv_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.repeat_kv)]
repeat_kv_nodes = [
n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_repeat_kv)
]
# Check that we have the expected number of replacements
if len(repeat_kv_nodes) != expected_matches:
@ -549,7 +551,7 @@ def test_match_eager_attention(has_mask, use_division, dropout, rtol, atol, mode
def verify_matcher(gm):
sdpa_nodes = [
n for n in gm.graph.nodes if is_op(n, torch.ops.attention.scaled_dot_product_attention)
n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa)
]
# Check that we have the expected number of replacements
@ -655,8 +657,10 @@ def test_counter_example():
dynamic_shapes = model.get_dynamic_shapes()
def verify_no_matches(gm):
# No nodes should be replaced with torch.ops.attention.repeat_kv
repeat_kv_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.repeat_kv)]
# No nodes should be replaced with torch.ops.auto_deploy.torch_attention_repeat_kv
repeat_kv_nodes = [
n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_repeat_kv)
]
return len(repeat_kv_nodes) == 0
# Ensure the pattern matcher doesn't match our counter-examples
@ -693,7 +697,9 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask):
def verify_matcher(gm):
grouped_sdpa_nodes = [
n for n in gm.graph.nodes if is_op(n, torch.ops.attention.grouped_sdpa)
n
for n in gm.graph.nodes
if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
]
# Check that we have the expected number of replacements
@ -790,8 +796,8 @@ class CausalAttentionModel(torch.nn.Module):
# For grouped attention, repeat k and v
if self.use_grouped_sdpa and self.num_kv_heads != self.num_heads:
n_rep = self.num_heads // self.num_kv_heads
k = torch.ops.attention.repeat_kv(k, n_rep)
v = torch.ops.attention.repeat_kv(v, n_rep)
k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
# Create attention mask based on mask_type
if self.mask_type == "triu":
@ -830,7 +836,7 @@ class CausalAttentionModel(torch.nn.Module):
# Choose the appropriate attention implementation
if self.use_grouped_sdpa:
attn_output = torch.ops.attention.grouped_sdpa(
attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa(
q,
k,
v,
@ -840,7 +846,7 @@ class CausalAttentionModel(torch.nn.Module):
scale=1.0 / (self.head_dim**0.5),
)
else:
attn_output = torch.ops.attention.scaled_dot_product_attention(
attn_output = torch.ops.auto_deploy.torch_attention_sdpa(
q,
k,
v,
@ -886,12 +892,14 @@ def test_match_causal_attention(mask_type, use_grouped_sdpa):
def verify_matcher(gm):
# Find attention operations
if use_grouped_sdpa:
attn_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.grouped_sdpa)]
else:
attn_nodes = [
n
for n in gm.graph.nodes
if is_op(n, torch.ops.attention.scaled_dot_product_attention)
if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
]
else:
attn_nodes = [
n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa)
]
if len(attn_nodes) != 1:
@ -990,8 +998,8 @@ class Llama3CausalAttentionModel(torch.nn.Module):
# For grouped attention, repeat k and v
if self.use_grouped_sdpa and self.num_kv_heads != self.num_heads:
n_rep = self.num_heads // self.num_kv_heads
k = torch.ops.attention.repeat_kv(k, n_rep)
v = torch.ops.attention.repeat_kv(v, n_rep)
k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
# Create a llama-3.1 style causal mask
# 1. Create a full tensor with a very negative value
@ -1026,7 +1034,7 @@ class Llama3CausalAttentionModel(torch.nn.Module):
# Choose the appropriate attention implementation
if self.use_grouped_sdpa:
attn_output = torch.ops.attention.grouped_sdpa(
attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa(
q,
k,
v,
@ -1036,7 +1044,7 @@ class Llama3CausalAttentionModel(torch.nn.Module):
scale=1.0 / (self.head_dim**0.5),
)
else:
attn_output = torch.ops.attention.scaled_dot_product_attention(
attn_output = torch.ops.auto_deploy.torch_attention_sdpa(
q,
k,
v,
@ -1078,12 +1086,14 @@ def test_match_llama3_causal_attention(use_grouped_sdpa):
def verify_matcher(gm):
# Find attention operations
if use_grouped_sdpa:
attn_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.grouped_sdpa)]
else:
attn_nodes = [
n
for n in gm.graph.nodes
if is_op(n, torch.ops.attention.scaled_dot_product_attention)
if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
]
else:
attn_nodes = [
n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa)
]
if len(attn_nodes) != 1:
@ -1129,7 +1139,7 @@ class MockAttentionDescriptor:
"""A mock class that mimics the AttentionDescriptor interface for testing."""
layout: str = "bnsd"
source_attention_op: Callable = torch.ops.attention.scaled_dot_product_attention
source_attention_op: Callable = torch.ops.auto_deploy.torch_attention_sdpa
@classmethod
def get_attention_layout(cls) -> str:
@ -1199,7 +1209,7 @@ class AttentionLayoutModel(torch.nn.Module):
# Apply scaled dot product attention
if self.use_grouped_sdpa:
attn_output = torch.ops.attention.grouped_sdpa(
attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa(
q,
k,
v,
@ -1209,7 +1219,7 @@ class AttentionLayoutModel(torch.nn.Module):
scale=1.0 / (self.head_dim**0.5),
)
else:
attn_output = torch.ops.attention.scaled_dot_product_attention(
attn_output = torch.ops.auto_deploy.torch_attention_sdpa(
q,
k,
v,
@ -1246,7 +1256,7 @@ class BsndAttentionModel(AttentionLayoutModel):
attn_mask = self._get_attn_mask(x) if self.has_mask else None
# Apply bsnd_grouped_sdpa directly
attn_output = torch.ops.attention.bsnd_grouped_sdpa.default(
attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa.default(
q,
k,
v,
@ -1284,11 +1294,11 @@ def test_match_attention_layout(layout, model_config, has_mask):
MockAttentionDescriptor.layout = layout
if layout == "bnsd":
if model_config.get("use_grouped_sdpa"):
source_op = torch.ops.attention.grouped_sdpa
source_op = torch.ops.auto_deploy.torch_attention_grouped_sdpa
else:
source_op = torch.ops.attention.scaled_dot_product_attention
source_op = torch.ops.auto_deploy.torch_attention_sdpa
else:
source_op = torch.ops.attention.bsnd_grouped_sdpa
source_op = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
MockAttentionDescriptor.source_attention_op = source_op
# Create appropriate model based on model_config
@ -1319,18 +1329,24 @@ def test_match_attention_layout(layout, model_config, has_mask):
if model_config["type"] == "standard":
if model_config["use_grouped_sdpa"]:
original_nodes = [
n for n in gm.graph.nodes if is_op(n, torch.ops.attention.grouped_sdpa)
n
for n in gm.graph.nodes
if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
]
else:
original_nodes = [
n
for n in gm.graph.nodes
if is_op(n, torch.ops.attention.scaled_dot_product_attention)
if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa)
]
else: # already_bsnd
original_nodes = []
bsnd_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.bsnd_grouped_sdpa)]
bsnd_nodes = [
n
for n in gm.graph.nodes
if is_op(n, torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa)
]
transpose_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.aten.transpose.int)]
# Different expectations based on model type and layout

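For reference, the `repeat_kv` pattern these matchers detect (and replace with grouped SDPA) expands each KV head so that K/V match the query head count. A minimal pure-PyTorch sketch of that expansion, assuming the usual [b, n_kv, s, d] layout; it mirrors the common HF-style helper and is not the registered custom op itself:

```python
import torch

def repeat_kv_reference(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand [b, n_kv, s, d] to [b, n_kv * n_rep, s, d] by repeating each KV head."""
    b, n_kv, s, d = kv.shape
    if n_rep == 1:
        return kv
    kv = kv[:, :, None, :, :].expand(b, n_kv, n_rep, s, d)
    return kv.reshape(b, n_kv * n_rep, s, d)

k = torch.randn(2, 2, 16, 32)      # 2 KV heads
k_rep = repeat_kv_reference(k, 4)  # now matches 8 query heads
assert k_rep.shape == (2, 8, 16, 32)
```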
View File

@ -31,7 +31,7 @@ class MockAttentionDescriptor:
@classmethod
def get_source_attention_op(cls) -> Callable:
return torch.ops.attention.bsnd_grouped_sdpa
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
class HFWrapper(nn.Module):
@ -83,9 +83,11 @@ def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str)
)
def verify_matcher(gm: GraphModule):
"""Ensure that there is exactly one torch.ops.attention.bsnd_grouped_sdpa call in the graph."""
"""Ensure that there is exactly one torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
call in the graph. Also check that there is no repeat_kv pattern left.
"""
nodes = gm.graph.find_nodes(
op="call_function", target=torch.ops.attention.bsnd_grouped_sdpa
op="call_function", target=torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
)
assert len(nodes) == 1, "Expected exactly one bsnd_grouped_sdpa call in the graph"
@ -100,7 +102,9 @@ def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str)
assert attn_node.args[6] == scale # scale
# TODO: check that there is no repeat_kv pattern left...
nodes = gm.graph.find_nodes(op="call_function", target=torch.ops.attention.repeat_kv)
nodes = gm.graph.find_nodes(
op="call_function", target=torch.ops.auto_deploy.torch_attention_repeat_kv
)
assert len(nodes) == 0, "Found repeat_kv pattern in the graph"
return True

View File

@ -53,7 +53,9 @@ class GQAWithSdpa(GQA):
v = v.view(b, s, self.num_kv_heads, self.head_dim)
# Use grouped SDPA in bsnd layout
attn_output = torch.ops.attention.bsnd_grouped_sdpa(q, k, v, None, 0.0, True, None)
attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(
q, k, v, None, 0.0, True, None
)
# SDPA output is already in [b, s, n, h_d] format
# Reshape to [b, s, n*h_d]

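As a usage note for the bsnd variant exercised here: it takes Q/K/V in [batch, seq, heads, head_dim] layout, accepts fewer KV heads than query heads without an explicit repeat, and returns output in the same bsnd layout. A hedged call sketch (shapes are illustrative; assumes the AutoDeploy custom ops have been registered by importing the AutoDeploy package):

```python
import torch

b, s, n_heads, n_kv_heads, head_dim = 2, 16, 8, 2, 64
q = torch.randn(b, s, n_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(b, s, n_kv_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(b, s, n_kv_heads, head_dim, device="cuda", dtype=torch.float16)

# Grouped SDPA in bsnd layout; no KV-head repetition needed beforehand.
y = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(q, k, v, is_causal=True)
out = y.contiguous().view(b, s, n_heads * head_dim)  # back to [b, s, hidden]
```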
View File

@ -100,7 +100,7 @@ def test_moe_matching():
model,
x,
match_moe_pattern,
lambda gm: any(is_op(n, torch.ops.moe.torch_moe) for n in gm.graph.nodes),
lambda gm: any(is_op(n, torch.ops.auto_deploy.torch_moe) for n in gm.graph.nodes),
lambda num_p_og: num_p_og,
atol=1e-3,
rtol=1e-3,
@ -119,7 +119,9 @@ def test_moe_fusion():
x,
fuse_moe,
lambda gm: any(
is_op(n, {torch.ops.moe.torch_fused_moe, torch.ops.moe.trtllm_fused_moe})
is_op(
n, {torch.ops.auto_deploy.torch_moe_fused, torch.ops.auto_deploy.trtllm_moe_fused}
)
for n in gm.graph.nodes
),
lambda num_p_og: num_p_og,

View File

@ -14,7 +14,10 @@ from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
def check_quantized(gm):
op_expected = {torch.ops.quant.fp8_linear, torch.ops.quant.fp4_linear}
op_expected = {
torch.ops.auto_deploy.torch_quant_fp8_linear,
torch.ops.auto_deploy.torch_quant_fp4_linear,
}
return any(is_op(n, op_expected) for n in gm.graph.nodes)

View File

@ -91,7 +91,7 @@ class RoPEModel(torch.nn.Module):
if self.mode == "match":
q_out, k_out = apply_rotary_pos_emb_explicit(q, k, cos, sin, unsq_dim)
else: # optimize
q_out, k_out = torch.ops.rope.torch_apply_rope_with_explicit_cos_sin(
q_out, k_out = torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin(
q, k, cos, sin, unsq_dim
)
@ -119,7 +119,7 @@ class RoPEModel(torch.nn.Module):
if self.mode == "match":
q_out, k_out = apply_rotary_pos_emb_complex(q, k, freqs, unsq_dim)
else:
q_out, k_out = torch.ops.rope.torch_apply_rope_with_complex_freqs(
q_out, k_out = torch.ops.auto_deploy.torch_rope_with_complex_freqs(
q, k, freqs, unsq_dim
)
@ -212,9 +212,9 @@ def test_rope_variants(
if transformation == "match":
fn = match_rope_pattern
check_op = (
torch.ops.rope.torch_apply_rope_with_explicit_cos_sin
torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin
if variant == "explicit" or variant == "explicit_pm"
else torch.ops.rope.torch_apply_rope_with_complex_freqs
else torch.ops.auto_deploy.torch_rope_with_complex_freqs
)
def checker(gm):
@ -228,8 +228,8 @@ def test_rope_variants(
if is_op(
n,
{
torch.ops.rope.torch_apply_rope_with_explicit_cos_sin,
torch.ops.rope.torch_apply_rope_with_complex_freqs,
torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin,
torch.ops.auto_deploy.torch_rope_with_complex_freqs,
},
):
q_arg, k_arg, *rest = n.args
@ -254,7 +254,7 @@ def test_rope_variants(
fn = optimize_rope
def checker(gm):
return any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes)
return any(is_op(n, torch.ops.auto_deploy.flashinfer_rope) for n in gm.graph.nodes)
if transformation == "match_layout":
_ = run_test(
@ -346,7 +346,7 @@ class DSModel(torch.nn.Module):
else:
cos = cos[pos_ids]
sin = sin[pos_ids]
q_out, k_out = torch.ops.rope.torch_apply_rope_with_qk_interleaving(
q_out, k_out = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
q, k, cos, sin, unsq_dim
)
if self.layout == "BNSD":
@ -387,7 +387,7 @@ def test_match_and_layout_deepseek(layout, num_heads, num_kv_heads, mode, target
def checker(gm):
return any(
is_op(n, torch.ops.rope.torch_apply_rope_with_qk_interleaving)
is_op(n, torch.ops.auto_deploy.torch_rope_with_qk_interleaving)
for n in gm.graph.nodes
)
@ -396,7 +396,7 @@ def test_match_and_layout_deepseek(layout, num_heads, num_kv_heads, mode, target
def checker(gm):
for n in gm.graph.nodes:
if is_op(n, torch.ops.rope.torch_apply_rope_with_qk_interleaving):
if is_op(n, torch.ops.auto_deploy.torch_rope_with_qk_interleaving):
q_arg, k_arg, *rest = n.args
if not (
is_op(q_arg, torch.ops.aten.contiguous)