Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

[AutoDeploy] Merge feat/ad_2025_06_13 feature branch (#5454)

Signed-off-by: Grzegorz Kwasniewski <213329731+greg-kwasniewski1@users.noreply.github.com>
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Co-authored-by: Grzegorz Kwasniewski <213329731+greg-kwasniewski1@users.noreply.github.com>
Co-authored-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
Co-authored-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

Parent commit: 73ba4fc320
This commit: 5cffb7e0ec

tensorrt_llm/_torch/auto_deploy/custom_ops/README.md (new file, 42 lines)
@@ -0,0 +1,42 @@
## AutoDeploy Custom Operators

All AutoDeploy custom operators follow this naming convention:

`torch.ops.auto_deploy.<kernel_backend>_<op_category>_<op_name>`

The table below lists the operators, ordered by their backend.

### Available Custom Operators

| Operator Name | Description |
|---------------|-------------|
| `torch.ops.auto_deploy.flashinfer_attention_mha_with_cache` | FlashInfer attention with KV cache support |
| `torch.ops.auto_deploy.flashinfer_rope` | FlashInfer RoPE (Rotary Position Embedding) implementation |
| `torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa` | Grouped SDPA (Scaled Dot Product Attention) with BSND format |
| `torch.ops.auto_deploy.torch_attention_deepseek_fused_mla` | DeepSeek fused MLA (Multi-head Latent Attention) |
| `torch.ops.auto_deploy.torch_attention_deepseek_mla` | DeepSeek MLA implementation |
| `torch.ops.auto_deploy.torch_attention_grouped_sdpa` | Grouped SDPA implementation |
| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation |
| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation |
| `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer implementation |
| `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation |
| `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation |
| `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values |
| `torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce operation |
| `torch.ops.auto_deploy.torch_quant_fp4_linear` | FP4 quantized linear layer |
| `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer |
| `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies |
| `torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin` | RoPE with explicit cosine/sine |
| `torch.ops.auto_deploy.torch_rope_with_qk_interleaving` | RoPE with QK interleaving |
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache` | Triton fused flattened MHA with cache |
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache_rope_fusion` | Triton fused flattened MHA with cache and RoPE fusion |
| `torch.ops.auto_deploy.triton_attention_fused_mha_with_cache` | Triton fused MHA with cache |
| `torch.ops.auto_deploy.triton_attention_fused_mha_with_paged_cache` | Triton fused MHA with paged cache |
| `torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache` | Triton flattened MHA with cache |
| `torch.ops.auto_deploy.triton_attention_fused_flattened_mla_with_cache` | Triton fused flattened Multi-head Latent Attention with cache support |
| `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs |
| `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions |
| `torch.ops.auto_deploy.trtllm_moe_fused` | TensorRT-LLM fused MoE implementation |
| `torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce` | TensorRT-LLM fused linear layer followed by all-reduce operation |
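As a point of reference (an editorial addition, not part of the committed README), the operators listed above are registered through `torch.library.custom_op` under the `auto_deploy` namespace and then invoked as `torch.ops.auto_deploy.<name>`. The sketch below uses a made-up op name, `torch_demo_scale`, purely to illustrate the registration and call pattern implied by the naming convention:

```python
import torch


# Hypothetical op name following <kernel_backend>_<op_category>_<op_name>;
# it is NOT one of the operators shipped by AutoDeploy.
@torch.library.custom_op("auto_deploy::torch_demo_scale", mutates_args=())
def torch_demo_scale(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Eager reference implementation.
    return x * scale


@torch_demo_scale.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Shape/dtype-only variant used during tracing and export.
    return torch.empty_like(x)


y = torch.ops.auto_deploy.torch_demo_scale(torch.ones(2, 3), 2.0)
assert torch.equal(y, torch.full((2, 3), 2.0))
```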
@@ -4,11 +4,12 @@ from ._triton_attention_internal import *
from .dist import *
from .flashinfer_attention import *
from .flashinfer_rope import *
from .fused_moe import *
from .linear import *
from .mla import *
from .quant import *
from .rope import *
from .torch_attention import *
from .torch_moe import *
from .torch_rope import *
from .triton_attention import *
from .triton_rope import *
from .trtllm_moe import *
@@ -168,7 +168,9 @@ def _paged_context_mha(
)

@torch.library.custom_op("attention::fused_mha_with_paged_cache", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_fused_mha_with_paged_cache", mutates_args=()
)
def fused_mha_with_paged_cache(
q: torch.Tensor,
k: torch.Tensor,

@@ -210,10 +212,10 @@ def fused_mha_with_paged_cache(
if freqs_cis is not None:
if s == 1:
rope_args = (freqs_cis, input_pos, "bsnd")
fn_rope = torch.ops.rope.apply_rope_with_input_pos
fn_rope = torch.ops.auto_deploy.triton_rope_with_input_pos
else:
rope_args = (freqs_cis, input_pos, seq_len, seq_start)
fn_rope = torch.ops.rope.apply_rope_on_flattened_inputs
fn_rope = torch.ops.auto_deploy.triton_rope_on_flattened_inputs
q = fn_rope(q, *rope_args)
k = fn_rope(k, *rope_args)

@@ -416,7 +418,9 @@ def _flattened_context_mha_rope_fusion(
)

@torch.library.custom_op("attention::fused_flattened_mha_with_cache_rope_fusion", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_fused_flattened_mha_with_cache_rope_fusion", mutates_args=()
)
def fused_flattened_mha_with_cache_rope_fusion(
q: torch.Tensor,
k: torch.Tensor,

@@ -541,7 +545,7 @@ def _context_mha(
)

@torch.library.custom_op("attention::fused_mha_with_cache", mutates_args=())
@torch.library.custom_op("auto_deploy::triton_attention_fused_mha_with_cache", mutates_args=())
def fused_mha_with_cache(
q: torch.Tensor,
k: torch.Tensor,

@@ -563,8 +567,8 @@ def fused_mha_with_cache(

# rope embedding
if freqs_cis is not None:
q = torch.ops.rope.apply_rope_with_input_pos(q, freqs_cis, input_pos, "bsnd")
k = torch.ops.rope.apply_rope_with_input_pos(k, freqs_cis, input_pos, "bsnd")
q = torch.ops.auto_deploy.triton_rope_with_input_pos(q, freqs_cis, input_pos, "bsnd")
k = torch.ops.auto_deploy.triton_rope_with_input_pos(k, freqs_cis, input_pos, "bsnd")

# attention (assumed layout is bsnd)
y = torch.empty_like(q)

@@ -593,7 +597,9 @@ def fused_mha_fake(
return torch.empty_like(q.contiguous())

@torch.library.custom_op("attention::fused_flattened_mha_with_cache", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_fused_flattened_mha_with_cache", mutates_args=()
)
def fused_flattened_mha_with_cache(
# Q, K, V
q: torch.Tensor,

@@ -638,10 +644,10 @@ def fused_flattened_mha_with_cache(
if freqs_cis.numel() > 0:
if s == 1:
rope_args = (freqs_cis, input_pos, "bsnd")
fn_rope = torch.ops.rope.apply_rope_with_input_pos
fn_rope = torch.ops.auto_deploy.triton_rope_with_input_pos
else:
rope_args = (freqs_cis, input_pos, seq_len, seq_start)
fn_rope = torch.ops.rope.apply_rope_on_flattened_inputs
fn_rope = torch.ops.auto_deploy.triton_rope_on_flattened_inputs
q = fn_rope(q, *rope_args)
k = fn_rope(k, *rope_args)
@@ -8,7 +8,7 @@ from ..distributed import common as dist
from ..distributed import trtllm as trtllm_dist

@torch.library.custom_op("dist::all_gather", mutates_args=(), device_types="cuda")
@torch.library.custom_op("auto_deploy::torch_dist_all_gather", mutates_args=(), device_types="cuda")
def all_gather(
tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
) -> torch.Tensor:

@@ -25,7 +25,7 @@ def all_gather_fake(tensor, dim=0):
return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim)

@torch.library.custom_op("dist::all_reduce", mutates_args=(), device_types="cuda")
@torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda")
def all_reduce(t: torch.Tensor) -> torch.Tensor:
"""All_reduce across the ranks. Reduction op is SUM.
@@ -153,7 +153,7 @@ class _FlashInferPlanner:
_GlobalFlashInferPlanner = _FlashInferPlanner()

@torch.library.custom_op("attention::prepare_flashinfer_metadata", mutates_args=())
@torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=())
def prepare_flashinfer_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,

@@ -228,7 +228,7 @@ def prepare_flashinfer_metadata_fake(
)

@torch.library.custom_op("attention::flashinfer_mha_with_cache", mutates_args=())
@torch.library.custom_op("auto_deploy::flashinfer_attention_mha_with_cache", mutates_args=())
def flashinfer_mha_with_cache(
# Q, K, V
q: torch.Tensor,

@@ -355,15 +355,15 @@ class FlashInferAttention(AttentionDescriptor):
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
"""Get the source attention op that we target for replacement."""
return torch.ops.attention.bsnd_grouped_sdpa
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa

@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.attention.flashinfer_mha_with_cache
return torch.ops.auto_deploy.flashinfer_attention_mha_with_cache

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
return torch.ops.attention.prepare_flashinfer_metadata, 6
return torch.ops.auto_deploy.flashinfer_attention_prepare_metadata, 6

@classmethod
def get_cache_initializers(
@@ -4,7 +4,7 @@ import flashinfer
import torch

@torch.library.custom_op("rope::flashinfer", mutates_args=())
@torch.library.custom_op("auto_deploy::flashinfer_rope", mutates_args=())
def apply_rope_with_input_pos_flashinfer(
q: torch.Tensor,
k: torch.Tensor,
@@ -8,7 +8,7 @@ from ..distributed import common as dist
from ..distributed import trtllm as trtllm_dist

@torch.library.custom_op("linear::simple", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_linear_simple", mutates_args=())
def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
"""A wrapper for the linear functional to control how it is exposed.

@@ -30,7 +30,9 @@ def simple_fake(input, weight, bias):
return torch.ops.aten.linear(input, weight, bias)

@torch.library.custom_op("linear::fused_linear_all_reduce", mutates_args=(), device_types="cuda")
@torch.library.custom_op(
"auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
)
def fused_linear_all_reduce(
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
@@ -22,7 +22,9 @@ from .triton_attention import _flattened_context_mha, _generate_mha
Constant = Union[int, float, str, None]

@torch.library.custom_op("attention::fused_flattened_mla_with_cache", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_fused_flattened_mla_with_cache", mutates_args=()
)
def fused_flattened_mla_with_cache(
# Q, K, V
q_nope: torch.Tensor,

@@ -94,7 +96,7 @@ def fused_flattened_mla_with_cache(
q_slice = q_pe[start : start + length]
k_slice = k_pe[start : start + length]

q_rot, k_rot = torch.ops.rope.torch_apply_rope_with_qk_interleaving(
q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
q_slice,
k_slice,
cos,

@@ -169,7 +171,9 @@ def fused_flattened_mla_with_cache_fake(
return torch.empty_like(kv[..., -v_head_dim:])

@torch.library.custom_op("attention::prepare_fused_mla_metadata", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_prepare_fused_mla_metadata", mutates_args=()
)
def prepare_fused_mla_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,

@@ -221,15 +225,15 @@ class MultiHeadLatentAttention(AttentionDescriptor):

@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
return torch.ops.deepseek.fused_mla
return torch.ops.auto_deploy.torch_attention_deepseek_fused_mla

@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.attention.fused_flattened_mla_with_cache
return torch.ops.auto_deploy.triton_attention_fused_flattened_mla_with_cache

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
return torch.ops.attention.prepare_fused_mla_metadata, 4
return torch.ops.auto_deploy.triton_attention_prepare_fused_mla_metadata, 4

@classmethod
def get_cache_initializers(
@@ -16,7 +16,7 @@ TRTLLM_FP4_OP_AVAILABLE = True
TRTLLM_NVFP4_SCALING_VECTOR_SIZE = 16

@torch.library.custom_op("quant::quant_fn", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_quant_fn", mutates_args=())
def quant_fn(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
scaled_x = x / scale
rounded_x = torch.round(scaled_x)

@@ -37,7 +37,7 @@ class QuantModule(nn.Module):
self.register_buffer("scale", torch.tensor(scale))

def forward(self, x: torch.Tensor):
return torch.ops.quant.quant_fn(x, self.scale)
return torch.ops.auto_deploy.torch_quant_fn(x, self.scale)

FP8_MIN = torch.finfo(torch.float8_e4m3fn).min

@@ -50,7 +50,7 @@ def _to_fp8(x, scale):
return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)

@torch.library.custom_op("quant::fp8_linear", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_quant_fp8_linear", mutates_args=())
@torch.compile(dynamic=True)
def fp8_linear(
input: torch.Tensor,

@@ -105,7 +105,7 @@ def fp8_linear_fake(
return torch.ops.aten.linear(input, weight_fp8.to(input.dtype), bias)

@torch.library.custom_op("quant::fused_fp8_linear_all_reduce", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_quant_fused_fp8_linear_all_reduce", mutates_args=())
@torch.compile(dynamic=True)
def fused_fp8_linear_all_reduce(
input: torch.Tensor,

@@ -114,7 +114,9 @@ def fused_fp8_linear_all_reduce(
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.ops.quant.fp8_linear(input, weight_fp8, bias, input_scale, weight_scale)
out = torch.ops.auto_deploy.torch_quant_fp8_linear(
input, weight_fp8, bias, input_scale, weight_scale
)
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM)
dist.all_reduce(out, op=dist.ReduceOp.SUM)

@@ -129,7 +131,9 @@ def fused_fp8_linear_all_reduce_fake(
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.ops.quant.fp8_linear(input, weight_fp8, bias, input_scale, weight_scale)
return torch.ops.auto_deploy.torch_quant_fp8_linear(
input, weight_fp8, bias, input_scale, weight_scale
)

class FP8Linear(nn.Linear):

@@ -146,12 +150,12 @@ class FP8Linear(nn.Linear):
self.bias = nn.Parameter(self.bias.to(torch.half))

def forward(self, x):
return torch.ops.quant.fp8_linear(
return torch.ops.auto_deploy.torch_quant_fp8_linear(
x, self.weight, self.bias, self.input_scale, self.weight_scale
)

@torch.library.custom_op("quant::fp4_linear", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_quant_fp4_linear", mutates_args=())
@torch.compile(dynamic=True)
def fp4_linear(
input: torch.Tensor,

@@ -218,4 +222,7 @@ def fp4_linear_fake(
return torch.ops.aten.linear(input, weight_fp4.repeat(1, 2).to(input.dtype), bias)

QUANT_OPS = [torch.ops.quant.fp8_linear, torch.ops.quant.fp4_linear]
QUANT_OPS = [
torch.ops.auto_deploy.torch_quant_fp8_linear,
torch.ops.auto_deploy.torch_quant_fp4_linear,
]
@@ -8,7 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F

@torch.library.custom_op("attention::repeat_kv", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_attention_repeat_kv", mutates_args=())
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,

@@ -31,7 +31,7 @@ def repeat_kv_fake(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return torch.empty(replicated_shape, device=hidden_states.device, dtype=hidden_states.dtype)

@torch.library.custom_op("attention::scaled_dot_product_attention", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_attention_sdpa", mutates_args=())
def scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,

@@ -66,7 +66,7 @@ def scaled_dot_product_attention_fake(
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()

@torch.library.custom_op("attention::grouped_sdpa", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_attention_grouped_sdpa", mutates_args=())
def grouped_sdpa(
query: torch.Tensor,
key: torch.Tensor,

@@ -104,7 +104,7 @@ def grouped_sdpa_fake(
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()

@torch.library.custom_op("attention::bsnd_grouped_sdpa", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=())
def bsnd_grouped_sdpa(
query: torch.Tensor,  # layout: [b, n, s_q, d]
key: torch.Tensor,  # layout: [b, n, s_k, d]

@@ -162,7 +162,7 @@ def update_kv_cache(
)

@torch.library.custom_op("attention::fused_mla_ref", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_attention_fused_mla_ref", mutates_args=())
def fused_mla_ref(
q_nope: torch.Tensor,
q_pe: torch.Tensor,

@@ -215,7 +215,7 @@ def fused_mla_ref(
q_slice = q_pe[start : start + length]
k_slice = k_pe[start : start + length]

q_rot, k_rot = torch.ops.rope.torch_apply_rope_with_qk_interleaving(
q_rot, k_rot = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
q_slice,
k_slice,
cos,

@@ -315,7 +315,7 @@ def fused_mla_ref_fake(
return torch.empty_like(kv[..., -v_head_dim:])

@torch.library.custom_op("deepseek::fused_mla", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_attention_deepseek_fused_mla", mutates_args=())
def fused_mla(
q_nope: torch.Tensor,
q_pe: torch.Tensor,

@@ -340,7 +340,7 @@ def fused_mla(

cos = cos[position_ids]
sin = sin[position_ids]
q_pe, k_pe = torch.ops.rope.torch_apply_rope_with_qk_interleaving(q_pe, k_pe, cos, sin)
q_pe, k_pe = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(q_pe, k_pe, cos, sin)

query_states = k_pe.new_empty(bs, num_heads, q_len, q_head_dim)
query_states[:, :, :, :qk_nope_head_dim] = q_nope

@@ -399,7 +399,7 @@ def fused_mla(
return torch.empty_like(kv[..., -v_head_dim:])

@torch.library.custom_op("deepseek::mla", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_attention_deepseek_mla", mutates_args=())
def mla(
q_nope: torch.Tensor,  # Down projected q_nope
q_pe: torch.Tensor,  # q_pe after applying rope
@@ -3,10 +3,8 @@ from typing import List
import torch
import torch.nn.functional as F

from ...modules.fused_moe import MoE  # noqa: F401

@torch.library.custom_op("moe::torch_moe", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_moe", mutates_args=())
def torch_moe(
x: torch.Tensor,
selected_experts: torch.Tensor,

@@ -80,7 +78,7 @@ def torch_moe(
return torch.empty_like(x)

@torch.library.custom_op("moe::torch_fused_moe", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_moe_fused", mutates_args=())
def torch_fused_moe(
x: torch.Tensor,
selected_experts: torch.Tensor,

@@ -90,7 +88,6 @@ def torch_fused_moe(
) -> torch.Tensor:
"""
A reference implementation of a fused MoE layer computation.

Parameters:
x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size,
S is the sequence length, and H is the hidden size.

@@ -102,7 +99,6 @@ def torch_fused_moe(
containing the fused weights for w3 and w1 for each expert.
w2_stacked_weight (torch.Tensor): A tensor of shape (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE)
containing the weights for w2 for each expert.

Returns:
torch.Tensor: Output tensor with the same shape as the input x.
"""

@@ -145,45 +141,3 @@ def torch_fused_moe(
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(x)

@torch.library.custom_op("moe::trtllm_fused_moe", mutates_args=())
def trtllm_fused_moe(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
x_shape = x.shape
x = x.view(-1, x_shape[-1])

routing_weights = routing_weights.to(torch.float32)
selected_experts = selected_experts.to(torch.int32)
quant_scales = []

return torch.ops.trtllm.fused_moe(
x,
selected_experts,
routing_weights,
w3_w1_stacked_weight,
w2_stacked_weight,
x.dtype,
quant_scales,
tp_size=1,
tp_rank=0,
ep_size=1,
ep_rank=0,
enable_alltoall=False,
)[0].view(x_shape)

@trtllm_fused_moe.register_fake
def trtllm_fused_moe(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(x)
@@ -11,7 +11,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)

@torch.library.custom_op("rope::torch_apply_rope_with_explicit_cos_sin", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_rope_with_explicit_cos_sin", mutates_args=())
def torch_apply_rope_with_explicit_cos_sin(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -38,7 +38,7 @@ def torch_apply_rope_with_explicit_cos_sin_fake(
return torch.empty_like(q), torch.empty_like(k)

@torch.library.custom_op("rope::torch_apply_rope_with_complex_freqs", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_rope_with_complex_freqs", mutates_args=())
def torch_apply_rope_with_complex_freqs(
xq: torch.Tensor,
xk: torch.Tensor,

@@ -69,7 +69,7 @@ def torch_apply_rope_with_complex_freqs_fake(
return torch.empty_like(xq), torch.empty_like(xk)

@torch.library.custom_op("rope::torch_apply_rope_with_qk_interleaving", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_rope_with_qk_interleaving", mutates_args=())
def torch_apply_rope_with_qk_interleaving(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -169,7 +169,7 @@ def _flattened_context_mha(
)

@torch.library.custom_op("attention::flattened_mha_with_cache", mutates_args=())
@torch.library.custom_op("auto_deploy::triton_attention_flattened_mha_with_cache", mutates_args=())
def flattened_mha_with_cache(
# Q, K, V
q: torch.Tensor,

@@ -259,7 +259,9 @@ def flattened_mha_fake(
return q.new_empty(*q.shape[:-1], v.shape[-1]).contiguous()

@torch.library.custom_op("attention::prepare_fused_mha_metadata", mutates_args=())
@torch.library.custom_op(
"auto_deploy::triton_attention_prepare_fused_mha_metadata", mutates_args=()
)
def prepare_fused_mha_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,

@@ -314,15 +316,15 @@ class TritonWithFlattenedInputs(AttentionDescriptor):

@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
return torch.ops.attention.bsnd_grouped_sdpa
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa

@classmethod
def get_cached_attention_op(cls) -> MHACallable:
return torch.ops.attention.flattened_mha_with_cache
return torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
return torch.ops.attention.prepare_fused_mha_metadata, 4
return torch.ops.auto_deploy.triton_attention_prepare_fused_mha_metadata, 4

@classmethod
def get_cache_initializers(

@@ -4,7 +4,7 @@ import triton
from .triton_kernels.rope import rope_fwd_flattened_kernel, rope_fwd_kernel

@torch.library.custom_op("rope::apply_rope_with_input_pos", mutates_args=())
@torch.library.custom_op("auto_deploy::triton_rope_with_input_pos", mutates_args=())
def apply_rope_with_input_pos(
x: torch.Tensor, freqs_cis: torch.Tensor, input_pos: torch.Tensor, layout: str
) -> torch.Tensor:

@@ -77,7 +77,7 @@ def apply_rope_with_input_pos_fake(x, freqs_cis, input_pos, layout):
return torch.empty_like(x)

@torch.library.custom_op("rope::apply_rope_on_flattened_inputs", mutates_args=())
@torch.library.custom_op("auto_deploy::triton_rope_on_flattened_inputs", mutates_args=())
def apply_rope_on_flattened_inputs(
x: torch.Tensor,
freqs_cis: torch.Tensor,
tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_moe.py (new file, 43 lines)

@@ -0,0 +1,43 @@
import torch


@torch.library.custom_op("auto_deploy::trtllm_moe_fused", mutates_args=())
def trtllm_fused_moe(
    x: torch.Tensor,
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    w3_w1_stacked_weight: torch.Tensor,
    w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
    x_shape = x.shape
    x = x.view(-1, x_shape[-1])

    routing_weights = routing_weights.to(torch.float32)
    selected_experts = selected_experts.to(torch.int32)
    quant_scales = []

    return torch.ops.trtllm.fused_moe(
        x,
        selected_experts,
        routing_weights,
        w3_w1_stacked_weight,
        w2_stacked_weight,
        x.dtype,
        quant_scales,
        tp_size=1,
        tp_rank=0,
        ep_size=1,
        ep_rank=0,
        enable_alltoall=False,
    )[0].view(x_shape)


@trtllm_fused_moe.register_fake
def trtllm_fused_moe(
    x: torch.Tensor,
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    w3_w1_stacked_weight: torch.Tensor,
    w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
    return torch.empty_like(x)
@@ -53,7 +53,7 @@ def deepseek_v3_attention(
# Use custom op to capture mla. This does not handle KV cache
# as passing transformers Cache into a custom op is throwing an error.
# Would not be an issue, cause we intend to replace mla op with our implementation further along the pipeline
attn_output = torch.ops.deepseek.fused_mla(
attn_output = torch.ops.auto_deploy.torch_attention_deepseek_fused_mla(
q_nope,
q_pe,
kv,

@@ -131,7 +131,7 @@ def deepseek_v3_moe(self, hidden_states):
"""DeepSeekV3MoE forward function rewritten in Mixtral style to enable torch export."""

selected_experts, routing_weights, *_ = self.gate(hidden_states)
final_hidden_states = torch.ops.moe.torch_moe(
final_hidden_states = torch.ops.auto_deploy.torch_moe(
hidden_states,
selected_experts,
routing_weights,

@@ -17,7 +17,7 @@ def _forward_moe(self: Qwen3MoeSparseMoeBlock, hidden_states: torch.Tensor):
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = torch.ops.moe.torch_moe(
final_hidden_states = torch.ops.auto_deploy.torch_moe(
hidden_states,
selected_experts,
routing_weights,
@@ -13,7 +13,6 @@ from ....bindings.internal.batch_manager import CacheType
from ....llmapi.llm_args import _AutoDeployLlmArgs
from ....mapping import Mapping
from ...distributed import MPIDist
from ...pyexecutor._util import create_torch_sampler_args
from ...pyexecutor.config import PyTorchConfig
from ...pyexecutor.model_engine import ModelEngine
from ...pyexecutor.py_executor import PyExecutor

@@ -266,9 +265,14 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
ad_config: _AutoDeployLlmArgs = executor_config.pytorch_backend_config

max_batch_size = ad_config.max_batch_size
max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
max_seq_len = ad_config.max_seq_len
attn_page_size = ad_config.attn_page_size
max_num_tokens = ad_config.max_num_tokens
max_draft_tokens = (
0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_tokens
)

ad_logger.info(f"{max_seq_len=}, {max_batch_size=}, {attn_page_size=}, {max_num_tokens=}")

# initialize model engine

@@ -309,22 +313,30 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler)

# search sampler with speculative decoding
sampler_args = create_torch_sampler_args(
executor_config, dist_mapping, mixed_sampler=False, max_seq_len=max_seq_len
# TODO (lucaslie, fridah-nv): some models require mixed_sampler=True to have good outputs, see
# https://github.com/NVIDIA/TensorRT-LLM/issues/5254
# We should expose mixed_sample to our build_and_run_ad script so we can configure this
# correctly for models as needed.
sampler_args = TorchSampler.Args(
max_seq_len=max_seq_len,
max_draft_tokens=max_draft_tokens,
max_num_sequences=max_num_sequences,
max_beam_width=executor_config.max_beam_width,
mixed_sampler=ad_config.mixed_sampler,
)
sampler = TorchSampler(sampler_args)

# creating the executor object
py_executor = PyExecutor(
resource_manager,
scheduler,
model_engine=engine,
sampler=sampler,
dist=mpi_dist,
max_num_sequences=ad_config.max_batch_size * dist_mapping.pp_size,
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
max_input_len=ad_config.max_input_len,
max_batch_size=ad_config.max_batch_size,
max_draft_tokens=ad_config.speculative_config.max_draft_tokens
if ad_config.speculative_config is not None
else 0,
max_batch_size=max_batch_size,
max_draft_tokens=max_draft_tokens,
)
return py_executor
@@ -215,11 +215,12 @@ class DemoEngine(ADEngine):
def _sample(
cls, logits: torch.Tensor, sampling_params: SamplingParams
) -> Tuple[torch.Tensor, torch.Tensor]:
probs = cls._logits_to_probs(
logits, sampling_params.temperature, sampling_params.top_k
)  # [*logits.shape]
# idx_next shape is [*logits.shape[:-1]]
idx_next = cls._multinomial_sample_one_no_sync(probs)
from tensorrt_llm._torch.pyexecutor.sampler import top_k_sampling_batch

logits_shape = logits.shape
logits = logits.view(-1, logits_shape[-1])  # top_k_sampling_batch expects 2D logits
idx_next, probs = top_k_sampling_batch(logits, sampling_params.top_k)
idx_next = idx_next.view(logits_shape[:-1])
return idx_next, probs

def _decode_tokens(
@@ -213,7 +213,7 @@ _torch_where_patch.where_original = torch.where
def _torch_linear_patch(
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
return torch.ops.linear.simple(input, weight, bias)
return torch.ops.auto_deploy.torch_linear_simple(input, weight, bias)

# TODO: remove once https://github.com/pytorch/pytorch/issues/142439 is resolved

@@ -336,7 +336,7 @@ def torch_export_to_gm(
# there is no guarantee how it is represented and we need to make sure it is easily identifiable
# in the graph.
sdpa_original = F.scaled_dot_product_attention
F.scaled_dot_product_attention = torch.ops.attention.scaled_dot_product_attention
F.scaled_dot_product_attention = torch.ops.auto_deploy.torch_attention_sdpa

# We overwrite the linear functional as well. This basically avoids exporting the view ops
# that are used to flatten/unflatten multiple batch dimensions of the input tensor.
@@ -18,7 +18,7 @@ def match_repeat_kv(gm: GraphModule) -> GraphModule:
The pattern is:
unsqueeze -> expand -> reshape -> [optional] contiguous

This is replaced with torch.ops.attention.repeat_kv.
This is replaced with torch.ops.auto_deploy.torch_attention_repeat_kv.
"""
graph = gm.graph
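As a quick illustration (an editorial addition, not part of this diff) of what `match_repeat_kv` looks for: the eager unsqueeze -> expand -> reshape pattern below is numerically equivalent to `torch.repeat_interleave(x, n_rep, dim=1)`, which is what the replacement op `torch.ops.auto_deploy.torch_attention_repeat_kv` implements. `repeat_kv_eager` is a hypothetical name used only here:

```python
import torch


def repeat_kv_eager(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # The unsqueeze -> expand -> reshape subgraph the matcher targets.
    b, n_kv, s, d = hidden_states.shape
    x = hidden_states[:, :, None, :, :].expand(b, n_kv, n_rep, s, d)
    return x.reshape(b, n_kv * n_rep, s, d)


x = torch.randn(2, 4, 8, 16)
assert torch.equal(repeat_kv_eager(x, 3), torch.repeat_interleave(x, 3, dim=1))
```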
@@ -49,7 +49,7 @@ def match_eager_attention(gm: GraphModule) -> GraphModule:
The pattern is:
transpose -> matmul -> mul -> (optional) add -> softmax -> to -> dropout -> matmul

This is replaced with torch.ops.attention.scaled_dot_product_attention.
This is replaced with torch.ops.auto_deploy.torch_attention_sdpa.
"""
graph = gm.graph

@@ -82,7 +82,7 @@ def match_grouped_attention(gm: GraphModule) -> GraphModule:
repeat_kv(v, n_rep) ->
sdpa(q, repeated_k, repeated_v)

This is replaced with torch.ops.attention.grouped_sdpa.
This is replaced with torch.ops.auto_deploy.torch_attention_grouped_sdpa.
"""
graph = gm.graph

@@ -92,7 +92,7 @@ def match_grouped_attention(gm: GraphModule) -> GraphModule:
# Iterate through nodes in the graph
for node in list(graph.nodes):
# Look for SDPA nodes that could be part of our pattern
if is_op(node, torch.ops.attention.scaled_dot_product_attention):
if is_op(node, torch.ops.auto_deploy.torch_attention_sdpa):
match_info = _match_grouped_attention_pattern(node)
if match_info:
ad_logger.debug(f"Found grouped attention pattern at {node}")

@@ -126,8 +126,8 @@ def match_causal_attn_mask(gm: GraphModule) -> GraphModule:
for node in list(graph.nodes):
# Look for SDPA nodes or grouped SDPA nodes
if not (
is_op(node, torch.ops.attention.scaled_dot_product_attention)
or is_op(node, torch.ops.attention.grouped_sdpa)
is_op(node, torch.ops.auto_deploy.torch_attention_sdpa)
or is_op(node, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
):
continue

@@ -437,7 +437,7 @@ def _match_grouped_attention_pattern(sdpa_node: Node) -> Optional[Dict[str, Node
Returns a dictionary with information about the match or None if no match.
"""
# Check that sdpa_node is an SDPA operation
if not is_op(sdpa_node, torch.ops.attention.scaled_dot_product_attention):
if not is_op(sdpa_node, torch.ops.auto_deploy.torch_attention_sdpa):
return None

# SDPA should have query, key, value as its first three arguments

@@ -447,8 +447,8 @@ def _match_grouped_attention_pattern(sdpa_node: Node) -> Optional[Dict[str, Node
query, key_repeated, value_repeated = sdpa_node.args[0:3]

# Key and value should come from repeat_kv operations
if not is_op(key_repeated, torch.ops.attention.repeat_kv) or not is_op(
value_repeated, torch.ops.attention.repeat_kv
if not is_op(key_repeated, torch.ops.auto_deploy.torch_attention_repeat_kv) or not is_op(
value_repeated, torch.ops.auto_deploy.torch_attention_repeat_kv
):
return None

@@ -487,7 +487,7 @@ def _replace_with_repeat_kv(graph, match_info: Dict[str, Node]) -> None:

with graph.inserting_before(node_to_replace):
repeat_kv_node = graph.call_function(
torch.ops.attention.repeat_kv, args=(input_tensor, n_rep)
torch.ops.auto_deploy.torch_attention_repeat_kv, args=(input_tensor, n_rep)
)

# Preserve metadata from the original node

@@ -502,7 +502,7 @@ def _replace_with_sdpa(graph, match_info: Dict[str, Node]) -> None:
Replace the matched eager attention pattern with scaled_dot_product_attention.
"""
# retrieve the default op for scaled_dot_product_attention
sdpa_op = torch.ops.attention.scaled_dot_product_attention.default
sdpa_op = torch.ops.auto_deploy.torch_attention_sdpa.default

# construct the args for the ops based on the match_info and the op's schema
args = []

@@ -530,7 +530,7 @@ def _replace_with_sdpa(graph, match_info: Dict[str, Node]) -> None:

def _replace_with_grouped_sdpa(graph, match_info: Dict[str, Node]) -> None:
"""
Replace the matched grouped attention pattern with torch.ops.attention.grouped_sdpa.
Replace the matched grouped attention pattern with torch.ops.auto_deploy.torch_attention_grouped_sdpa.
"""
sdpa_node = match_info["sdpa_node"]
query = match_info["query"]

@@ -543,7 +543,7 @@ def _replace_with_grouped_sdpa(graph, match_info: Dict[str, Node]) -> None:

with graph.inserting_before(sdpa_node):
grouped_sdpa_node = graph.call_function(
torch.ops.attention.grouped_sdpa.default, args=args, kwargs=kwargs
torch.ops.auto_deploy.torch_attention_grouped_sdpa.default, args=args, kwargs=kwargs
)

# Preserve metadata from the original node

@@ -763,8 +763,8 @@ def match_attention_layout(gm: GraphModule, attention_op: Type[AttentionDescript

# List of SDPA operations to look for
sdpa_ops = {
torch.ops.attention.scaled_dot_product_attention,
torch.ops.attention.grouped_sdpa,
torch.ops.auto_deploy.torch_attention_sdpa,
torch.ops.auto_deploy.torch_attention_grouped_sdpa,
}

graph = gm.graph
@@ -22,14 +22,14 @@ def fuse_collectives(gm: GraphModule) -> GraphModule:
# lookup for fused ops
# TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly.
lookup = {
torch.ops.linear.simple: torch.ops.linear.fused_linear_all_reduce,
torch.ops.aten.linear: torch.ops.linear.fused_linear_all_reduce,
torch.ops.quant.fp8_linear: torch.ops.quant.fused_fp8_linear_all_reduce,
torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce,
}

# go through all nodes and find all_reduce nodes
for node in gm.graph.nodes:
if not is_op(node, torch.ops.dist.all_reduce):
if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
continue

# check if args are as expected

@@ -162,7 +162,7 @@ def fuse_allreduce_residual_rmsnorm(gm: GraphModule) -> GraphModule:

# Traverse all nodes
for node in gm.graph.nodes:
if is_op(node, torch.ops.dist.all_reduce):
if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
trace_and_fuse(allreduce_node=node, graph=gm.graph)

gm = canonicalize_graph(gm)
@@ -38,7 +38,7 @@ def ep_shard(gm: GraphModule, rank: int, world_size: int) -> GraphModule:
assert isinstance(gm, GraphModule), "Expecting GraphModule"
num_moe_patterns = 0
for node in list(gm.graph.nodes):
if not is_op(node, torch.ops.moe.torch_moe):
if not is_op(node, torch.ops.auto_deploy.torch_moe):
continue
_insert_sharded_moe(gm, node, rank, world_size)
num_moe_patterns += 1

@@ -123,6 +123,8 @@ def _insert_sharded_moe(

# -- add an all_reduce node --
with gm.graph.inserting_after(node):
dist_node = gm.graph.call_function(torch.ops.dist.all_reduce, args=(node,))
dist_node = gm.graph.call_function(
torch.ops.auto_deploy.torch_dist_all_reduce, args=(node,)
)
node.replace_all_uses_with(dist_node)
dist_node.replace_input_with(dist_node, node)

@@ -69,7 +69,7 @@ def match_moe_pattern(gm: GraphModule) -> GraphModule:
w3_list = expert_weights["w3"]

fused_moe_node = graph.call_function(
torch.ops.moe.torch_moe,
torch.ops.auto_deploy.torch_moe,
args=(
hidden_states,
selected_experts,

@@ -99,7 +99,7 @@ def match_moe_pattern(gm: GraphModule) -> GraphModule:
def fuse_moe(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Scan the FX graph and replace all calls to torch.ops.moe.torch_moe with
torch.ops.moe.trtllm_fused_moe.
torch.ops.auto_deploy.trtllm_moe_fused.
"""
ad_logger.debug("Before MoE fusion: " + str(gm))

@@ -118,7 +118,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
graph = gm.graph

for node in list(graph.nodes):
if not is_op(node, torch.ops.moe.torch_moe):
if not is_op(node, torch.ops.auto_deploy.torch_moe):
continue

ad_logger.debug(f"Found MoE op to fuse: {node} with args: {node.args}")

@@ -146,7 +146,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:

with graph.inserting_before(node):
new_node = graph.call_function(
torch.ops.moe.trtllm_fused_moe,
torch.ops.auto_deploy.trtllm_moe_fused,
args=(
hidden_states,
selected_experts,
@@ -1,6 +1,6 @@
"""
This transformation defines two main RoPE (Rotary Positional Embedding) pattern matchers used
to identify and replace RoPE subgraphs with a custom op (`torch.ops.rope.flashinfer`).
to identify and replace RoPE subgraphs with a custom op (`torch.ops.auto_deploy.flashinfer_rope`).

Supported RoPE variants:
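For readers unfamiliar with the variants named above, the sketch below (an editorial addition using the standard Hugging Face-style formulation, not the exact subgraph matched by this file) shows the "explicit cos/sin" RoPE computation that gets rewritten into `torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin`, reusing the `rotate_half` helper that appears in the hunks that follow:

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in half, negate the second half, and swap the halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def explicit_rope(q, k, cos, sin, unsqueeze_dim=1):
    # Explicit cos/sin RoPE: the pattern targeted by the "explicit" matcher.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```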
@@ -73,7 +73,7 @@ def _explicit_rope_pattern(q, k, cos, sin, unsqueeze_dim=1):

def _explicit_rope_repl(q, k, cos, sin, unsqueeze_dim):
return torch.ops.rope.torch_apply_rope_with_explicit_cos_sin.default(
return torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin.default(
q, k, cos, sin, unsqueeze_dim
)

@@ -91,7 +91,7 @@ def _interleaved_rope_pattern(q, k, cos, sin, unsqueeze_dim=1):

def _interleaved_rope_repl(q, k, cos, sin, unsqueeze_dim):
return torch.ops.rope.torch_apply_rope_with_qk_interleaving.default(
return torch.ops.auto_deploy.torch_rope_with_qk_interleaving.default(
q, k, cos, sin, unsqueeze_dim
)

@@ -109,7 +109,7 @@ def _complex_rope_pattern(xq, xk, freqs_cis, unsqueeze_dim=1):

def _complex_rope_repl(q, k, freqs_cis, unsqueeze_dim):
return torch.ops.rope.torch_apply_rope_with_complex_freqs.default(
return torch.ops.auto_deploy.torch_rope_with_complex_freqs.default(
q, k, freqs_cis, unsqueeze_dim
)

@@ -195,9 +195,9 @@ def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphMo

graph = gm.graph
rope_ops = {
torch.ops.rope.torch_apply_rope_with_explicit_cos_sin,
torch.ops.rope.torch_apply_rope_with_qk_interleaving,
torch.ops.rope.torch_apply_rope_with_complex_freqs,
torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin,
torch.ops.auto_deploy.torch_rope_with_qk_interleaving,
torch.ops.auto_deploy.torch_rope_with_complex_freqs,
}

need_transpose = False

@@ -206,7 +206,7 @@ def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphMo
if not is_op(node, rope_ops):
continue

if is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs):
if is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs):
q_node, k_node, freqs_node, unsq = extract_op_args(
node,
"xq",  # argument name in schema

@@ -257,7 +257,7 @@ def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphMo
q_for_op_contig.meta["val"] = q_node.meta["val"].transpose(1, 2)
k_for_op_contig.meta["val"] = k_node.meta["val"].transpose(1, 2)

if is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs):
if is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs):
new_args = (
q_for_op_contig,
k_for_op_contig,

@@ -309,9 +309,9 @@ def optimize_rope(gm: GraphModule) -> GraphModule:

num_rope_optimizations = 0
for node in list(graph.nodes):
if is_op(node, torch.ops.rope.torch_apply_rope_with_explicit_cos_sin):
if is_op(node, torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin):
_optimize_explicit(graph, node, rope_flash_cache, rope_position_ids_cache)
elif is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs):
elif is_op(node, torch.ops.auto_deploy.torch_rope_with_complex_freqs):
_optimize_complex(graph, node, rope_flash_cache, rope_position_ids_cache)
else:
continue

@@ -398,7 +398,7 @@ def _optimize_explicit(
rope_position_ids_cache=pos_cache,
)
flash_node = graph.call_function(
torch.ops.rope.flashinfer,
torch.ops.auto_deploy.flashinfer_rope,
args=(q_node, k_node, position_ids, fused_cos_sin_to, True),
)

@@ -478,7 +478,7 @@ def _optimize_complex(
graph, q_node, batch_dim=0, seq_dim=1, rope_position_ids_cache=pos_cache
)
flash_node = graph.call_function(
torch.ops.rope.flashinfer,
torch.ops.auto_deploy.flashinfer_rope,
args=(q_node, k_node, position_ids, cos_sin_flash, False),
)
@@ -16,6 +16,7 @@ Our sharding algorithm for tensor parallelism (TP) is based on the following ste
happens automatically via the checkpoint loading hook added in step 2c.
"""

import math
import operator
from collections import defaultdict
from functools import partial

@@ -71,7 +72,13 @@ def _load_hook_remove(

def _insert_sharded_matmul(
gm: GraphModule, node: Node, dim: int, rank: int, world_size: int, add_dist: bool = False
gm: GraphModule,
node: Node,
dim: int,
rank: int,
world_size: int,
add_dist: bool = False,
min_local_shape: int = 1,
):
"""Replaces the matmul node with a new matmul node that accepts sharded weights.

@@ -83,8 +90,21 @@ def _insert_sharded_matmul(
quantization_impl = QuantizationImpl.create(node)

def split_tensor(
t: torch.Tensor, d: int = dim, r: int = rank, ws: int = world_size
t: torch.Tensor,
d: int = dim,
r: int = rank,
ws: int = world_size,
min_d_shape: int = min_local_shape,
) -> torch.Tensor:
# The local tensor shape has to be divisible by min_d_shape
max_split_size = t.shape[d] // min_d_shape
if ws > max_split_size:
num_groups = math.ceil(ws / max_split_size)
ad_logger.debug(
f"World size {ws} is greater than the max split size {max_split_size}. "
+ f"Splitting tensor to {num_groups} chunks"
)
return torch.tensor_split(t, max_split_size, dim=d)[r // num_groups]
return torch.tensor_split(t, ws, dim=d)[r]

num_users = num_users_of_weight_node(node)
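To make the `min_local_shape` behavior above concrete, here is a small standalone example (an editorial addition; the shapes are illustrative) of the same splitting logic: when the world size exceeds the number of head-sized chunks, ranks are grouped so that no shard becomes smaller than one head:

```python
import math

import torch


def split_tensor(t, d, r, ws, min_d_shape=1):
    # Same logic as split_tensor in the diff above.
    max_split_size = t.shape[d] // min_d_shape
    if ws > max_split_size:
        num_groups = math.ceil(ws / max_split_size)
        return torch.tensor_split(t, max_split_size, dim=d)[r // num_groups]
    return torch.tensor_split(t, ws, dim=d)[r]


# A weight with 4 heads of head_dim 64 (256 rows) sharded over 8 ranks:
w = torch.randn(256, 512)
shard = split_tensor(w, d=0, r=5, ws=8, min_d_shape=64)
print(shard.shape)  # ranks are paired; each shard keeps one full head -> torch.Size([64, 512])
```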
@@ -168,8 +188,8 @@

# figure out the right dist op
dist_lookup = {
0: (torch.ops.dist.all_gather, -1),
1: (torch.ops.dist.all_reduce,),
0: (torch.ops.auto_deploy.torch_dist_all_gather, -1),
1: (torch.ops.auto_deploy.torch_dist_all_reduce,),
}
fn_dist, *dist_args = dist_lookup[dim]

@@ -191,7 +211,10 @@ def _simple_shard(

def column_row_shard(
gm: GraphModule, rank: int, world_size: int, simple_shard_only: bool = False
gm: GraphModule,
rank: int,
world_size: int,
simple_shard_only: bool = False,
) -> GraphModule:
"""A transformation to apply sharding to the model following tensor parallelism.

@@ -205,6 +228,9 @@ def column_row_shard(
**all** nodes in the subgraph. The subgraph here is defined as the region between the first
linear node to the last linear node of an identified sharding region.
# 5. Shard the GEMM nodes or skip accordingly.

min_local_shape is the minimum size of the local tensor shard, to prevent TP parallelism
splitting, e.g., the individual heads into smaller shards.
"""
ad_logger.debug("Before sharding graph: " + str(gm))

@@ -232,9 +258,9 @@ def column_row_shard(

# acceptable attention nodes between sharded GEMMs
shardable_attention_nodes = {
torch.ops.attention.scaled_dot_product_attention,
torch.ops.attention.grouped_sdpa,
torch.ops.attention.bsnd_grouped_sdpa,
torch.ops.auto_deploy.torch_attention_sdpa,
torch.ops.auto_deploy.torch_attention_grouped_sdpa,
torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa,
}

# This is a heuristic. Basically, we assume those are okay to shard if we also encounter an

@@ -244,7 +270,7 @@ def column_row_shard(
shardable_nodes_with_attention = {
torch.ops.aten.view,
torch.ops.aten.reshape,
torch.ops.rope.flashinfer,
torch.ops.auto_deploy.flashinfer_rope,
operator.getitem,
}

@@ -327,9 +353,25 @@ def column_row_shard(

# If we can account for all sharded nodes, we can do a two-way shard
# --> row_split (dim 0) + col_split (dim 1) + all_reduce

# check if we are sharding the attention block
if attention_nodes:
if len(attention_nodes) > 1:
# Column-row shard boundary region detection is probably wrong - there should be
# only one attention operation. Fall back to simple shard.
ad_logger.debug(f"More than one attention node: {unaccounted_nodes}")
_simple_shard(gm, nodes_linear, rank, world_size)
continue
# Extract head dimension. We cannot shard below the head_dim size.
# Assume that head_dim is the last (innermost) dimension of the tensor
min_local_shape = attention_nodes.pop().meta["val"].shape[-1]
else:
min_local_shape = 1
for i, group in enumerate(nodes_linear.values()):
for n in group:
_insert_sharded_matmul(gm, n, i, rank, world_size, add_dist=i > 0)
_insert_sharded_matmul(
gm, n, i, rank, world_size, add_dist=i > 0, min_local_shape=min_local_shape
)

# canonicalize and return
if num_shards:

@@ -424,7 +466,7 @@ def dp_bmm_shard(gm: GraphModule, rank: int, world_size: int) -> GraphModule:
base_size = bmm_batch_size // world_size
remainder = bmm_batch_size % world_size

# NOTE: our torch.ops.dist.all_gather doesn't support uneven splits at the moment.
# NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support uneven splits at the moment.
if remainder:
ad_logger.warning(
f"BMM batch size {bmm_batch_size} is not divisible by world size {world_size}. "

@@ -451,7 +493,7 @@ def dp_bmm_shard(gm: GraphModule, rank: int, world_size: int) -> GraphModule:
# Add all_gather node after BMM to collect results
with gm.graph.inserting_after(node):
gather_node = gm.graph.call_function(
torch.ops.dist.all_gather,
torch.ops.auto_deploy.torch_dist_all_gather,
args=(node, 0),  # Gather along batch dimension (0)
)
node.replace_all_uses_with(gather_node)
@ -68,11 +68,11 @@ PytorchExportedProgramAdapterImpl.add_outputs_metadata = add_outputs_metadata
|
||||
|
||||
# TODO(yudong): make custom_ops configurable
|
||||
CUSTOM_OPS = (
|
||||
torch.ops.dist.all_reduce.default,
|
||||
torch.ops.auto_deploy.torch_dist_all_reduce.default,
|
||||
torch.ops.aten.slice.Tensor,
|
||||
torch.ops.attention.fused_mha_with_cache.default,
|
||||
torch.ops.linear.fused_linear_all_reduce.default,
|
||||
torch.ops.linear.simple.default,
|
||||
torch.ops.auto_deploy.triton_attention_fused_mha_with_cache.default,
|
||||
torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce.default,
|
||||
torch.ops.auto_deploy.torch_linear_simple.default,
|
||||
torch.ops.aten.split_with_sizes.default,
|
||||
)
|
||||
|
||||
|
||||
@@ -222,7 +222,7 @@ def is_linear_op(node: Node, include_quantization: bool = False) -> bool:
     """
     lin_ops = {
         torch.ops.aten.linear,
-        torch.ops.linear.simple,
+        torch.ops.auto_deploy.torch_linear_simple,
     }
 
     if include_quantization:

@@ -233,8 +233,8 @@ def is_linear_op(node: Node, include_quantization: bool = False) -> bool:
 def is_dist_op(node: Node) -> bool:
     """Check if the node is a distributed op."""
     dist_ops = {
-        torch.ops.dist.all_gather,
-        torch.ops.dist.all_reduce,
+        torch.ops.auto_deploy.torch_dist_all_gather,
+        torch.ops.auto_deploy.torch_dist_all_reduce,
     }
     return is_op(node, dist_ops)
 
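`is_op` from `node_utils` is what all of these predicates lean on: it accepts either a single op or a set of ops and matches FX `call_function` nodes against them. A rough, hypothetical re-implementation of that idea (not the library code) looks like this:

```python
import torch
from torch.fx import Node


def is_op_sketch(node: Node, ops) -> bool:
    """Return True if a call_function node targets one of the given ops (illustrative sketch)."""
    if node.op != "call_function":
        return False
    candidates = ops if isinstance(ops, (set, list, tuple)) else [ops]
    targets = set()
    for op in candidates:
        targets.add(op)
        # an overload packet like torch.ops.aten.linear should also match its concrete overloads
        if isinstance(op, torch._ops.OpOverloadPacket):
            targets.update(getattr(op, name) for name in op.overloads())
    return node.target in targets
```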
@@ -136,7 +136,7 @@ class QuantizationImpl:
 class FP8QuantizationImpl(QuantizationImpl):
     @staticmethod
     def target_op():
-        return torch.ops.quant.fp8_linear
+        return torch.ops.auto_deploy.torch_quant_fp8_linear
 
     @staticmethod
     def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor:

@@ -207,7 +207,7 @@ def _shard_fp4_weight_scale(weight_scale, sharded_uint8_weight_shape, dim, rank,
 class FP4QuantizationImpl(QuantizationImpl):
     @staticmethod
     def target_op():
-        return torch.ops.quant.fp4_linear
+        return torch.ops.auto_deploy.torch_quant_fp4_linear
 
     @staticmethod
     def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor:
@@ -184,7 +184,7 @@ class MoEOpModel(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
        x: Tensor of shape (batch, hidden_size)
-        Computes router logits via a gate, and then calls the MoE op via torch.moe.torch_moe.
+        Computes router logits via a gate, and then calls the MoE op via torch.ops.auto_deploy.torch_moe.
         """
 
         router_logits = self.gate(x)

@@ -197,7 +197,7 @@ class MoEOpModel(nn.Module):
         w2_list = [expert.w2 for expert in self.experts]
         w3_list = [expert.w3 for expert in self.experts]
 
-        out = torch.ops.moe.torch_moe(
+        out = torch.ops.auto_deploy.torch_moe(
             x, selected_experts, routing_weights, w1_list, w2_list, w3_list
         )
         return out
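The call above consumes `selected_experts` and `routing_weights` derived from the router logits. The exact normalization used by `MoEOpModel` is not visible in this hunk, so treat the following as an assumption showing one common top-k routing scheme:

```python
import torch
import torch.nn.functional as F


def route_topk(router_logits: torch.Tensor, top_k: int):
    """Pick top-k experts per token and renormalize their weights (illustrative sketch)."""
    probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(probs, top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    return selected_experts, routing_weights.to(router_logits.dtype)
```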
@@ -9,14 +9,14 @@ from tensorrt_llm._torch.auto_deploy.distributed.common import spawn_multiproces
 
 def _run_all_reduce_test(rank, world_size):
     x = torch.ones(10, 10).to("cuda")
-    y = torch.ops.dist.all_reduce(x)
+    y = torch.ops.auto_deploy.torch_dist_all_reduce(x)
 
     assert torch.equal(x * world_size, y)
 
 
 def _run_all_gather_test(rank, world_size):
     x = torch.ones(10, 10).to("cuda")
-    y = torch.ops.dist.all_gather(x)
+    y = torch.ops.auto_deploy.torch_dist_all_gather(x)
 
     assert torch.sum(y) == world_size * torch.sum(x)
     assert y.shape == (world_size * x.shape[0], *x.shape[1:])
@@ -74,7 +74,7 @@ def _run_moe_ep_test(num_experts: int, topk: int, rank: int, world_size: int):
 
     final_scales_local = final_scales * rank_mask
 
-    output_trt = torch.ops.moe.trtllm_fused_moe(
+    output_trt = torch.ops.auto_deploy.trtllm_moe_fused(
         x,
         selected_experts_local,
         final_scales_local,

@@ -36,7 +36,7 @@ class AllreduceResidualNorm(torch.nn.Module):
         self.norm = RMSNorm(hidden_size, 1e-5, dtype)
 
     def forward(self, x, residual):
-        x = torch.ops.dist.all_reduce(x)
+        x = torch.ops.auto_deploy.torch_dist_all_reduce(x)
         y = x + residual
         normed = self.norm(y)
         return normed, y
@@ -64,7 +64,7 @@ def _run_job(
         return num_params
 
     # now run the test
-    op_expected = getattr(torch.ops.dist, "all_gather")
+    op_expected = getattr(torch.ops.auto_deploy, "torch_dist_all_gather")
     run_test(
         model,
         x,

@@ -26,8 +26,8 @@ class MLPAllReduce(nn.Module):
         self.linear2 = cls(4 * in_features, out_features, bias=bias)
 
     def forward(self, x):
-        y = F.relu(torch.ops.dist.all_reduce(self.linear1(x)))
-        return torch.ops.dist.all_reduce(self.linear2(y))
+        y = F.relu(torch.ops.auto_deploy.torch_dist_all_reduce(self.linear1(x)))
+        return torch.ops.auto_deploy.torch_dist_all_reduce(self.linear2(y))
 
 
 def _run_job(

@@ -58,7 +58,7 @@ def _run_job(
 
     def check_transformed_graph(gm):
         return any(is_op(n, op_expected) for n in gm.graph.nodes) and not any(
-            is_op(n, torch.ops.dist.all_reduce) for n in gm.graph.nodes
+            is_op(n, torch.ops.auto_deploy.torch_dist_all_reduce) for n in gm.graph.nodes
         )
 
     # now run the test
@@ -76,10 +76,10 @@ def _run_job(
 @pytest.mark.parametrize(
     "linear_cls, dist_op_expected",
     (
-        (nn.Linear, "linear.fused_linear_all_reduce"),
+        (nn.Linear, "auto_deploy.trtllm_dist_fused_linear_all_reduce"),
         pytest.param(
             FP8Linear,
-            "quant.fused_fp8_linear_all_reduce",
+            "auto_deploy.torch_quant_fused_fp8_linear_all_reduce",
             marks=pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support"),
         ),
     ),

@@ -33,7 +33,7 @@ def _run_ep_shard_job(num_experts: int, rank: int, world_size: int) -> None:
         expected_expert = num_experts_per_rank * hidden_size * intermediate_size * 3
         return n_gate + expected_expert
 
-    op_expected = torch.ops.dist.all_reduce
+    op_expected = torch.ops.auto_deploy.torch_dist_all_reduce
 
     run_test(
         model,
@@ -15,6 +15,50 @@ from tensorrt_llm._torch.auto_deploy.transformations.library import column_row_s
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
 
 
+class GQA_Block(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        hidden_size: int,
+        num_key_value_heads: int,
+    ):
+        super().__init__()
+        self.num_attention_heads = num_attention_heads
+        self.hidden_size = hidden_size
+        self.head_dim = self.hidden_size // self.num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.is_gqa = num_key_value_heads < num_attention_heads
+        assert self.hidden_size == self.num_attention_heads * self.head_dim
+
+        # key, query, value, out projections
+        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.k_proj = nn.Linear(
+            self.hidden_size,
+            self.head_dim * self.num_key_value_heads,
+            bias=False,
+        )
+        self.v_proj = nn.Linear(
+            self.hidden_size,
+            self.head_dim * self.num_key_value_heads,
+            bias=False,
+        )
+
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+
+    @torch.no_grad()
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        b, s, _ = x.shape
+
+        q = self.q_proj(x).view(b, s, -1, self.head_dim)
+        k = self.k_proj(x).view(b, s, -1, self.head_dim)
+        v = self.v_proj(x).view(b, s, -1, self.head_dim)
+
+        y = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(q, k, v, is_causal=True)
+        y = y.contiguous().view(b, s, -1)
+
+        return self.o_proj(y)
+
+
 class MLP(nn.Module):
     def __init__(self, in_features, out_features, bias=False):
         super().__init__()
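For orientation, the new test-only `GQA_Block` can be exercised standalone. With the sizes used later in `_run_job` (4 query heads, 1 KV head, hidden size 32), a forward pass is shape-preserving; this sketch assumes a CUDA device and that the auto_deploy custom ops are registered:

```python
import torch

# hypothetical standalone usage of the GQA_Block defined above
block = GQA_Block(num_attention_heads=4, hidden_size=32, num_key_value_heads=1)
block = block.to(device="cuda", dtype=torch.float16)
x = torch.randn(2, 8, 32, device="cuda", dtype=torch.float16)  # (batch, seq, hidden)
y = block(x)
assert y.shape == (2, 8, 32)  # o_proj maps back to hidden_size
```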
@@ -37,27 +81,80 @@ def _run_job(
 ) -> None:
     # init model and input
     batch_size = 4
-    num_features = 10
-    model = model_cls(num_features, num_features, bias=bias).to(device="cuda", dtype=torch.float16)
-    x = torch.randn(batch_size, num_features, device="cuda", dtype=torch.float16)
+    sequence_len = 8
+    num_features = 32
+
+    # GQA specific parameters
+    num_heads = 4
+    num_key_value_heads = 1
+
+    if model_cls == GQA_Block:
+        model = model_cls(
+            num_attention_heads=num_heads,
+            hidden_size=num_features,
+            num_key_value_heads=num_key_value_heads,
+        ).to(device="cuda", dtype=torch.float16)
+    else:
+        model = model_cls(num_features, num_features, bias=bias).to(
+            device="cuda", dtype=torch.float16
+        )
+    x = torch.randn(batch_size, sequence_len, num_features, device="cuda", dtype=torch.float16)
+
+    if model_cls == GQA_Block:
+        head_dim = num_features // num_heads
+        min_local_size = head_dim
+    else:
+        min_local_size = 1
 
     def _get_expected_num_params(num_p_og: int) -> int:
         num_update = 0
-        if bias and dist_op_expected == "all_reduce":
+        if bias and dist_op_expected == "torch_dist_all_reduce":
             num_p_og -= num_features
             num_update = num_features * (rank == world_size - 1)
 
-        num_params = num_p_og // world_size + num_update
+        if min_local_size > 1:
+            # it means we are in the GQA. W_kv are partially replicated, we need to count
+            # the number of parameters manually.
+            W_q_local_size = num_features * num_features // world_size
+            W_o_local_size = W_q_local_size
+            W_k_local_size = num_features * head_dim * max(num_key_value_heads // world_size, 1)
+            W_v_local_size = W_k_local_size
+            num_params = W_q_local_size + W_k_local_size + W_v_local_size + W_o_local_size
+        else:
+            num_params = num_p_og // world_size + num_update
         return num_params
 
+    def verify_local_weight_sizes(gm) -> bool:
+        """Verify that all weight tensors have first dimension >= min_local_size after sharding."""
+        for name, param in gm.named_parameters():
+            # Only check parameters that have at least 1 dimension and are weight matrices
+            if param.dim() >= 1 and "weight" in name:
+                if param.shape[0] < min_local_size:
+                    print(
+                        f"Weight {name} has shape {param.shape}, dim {param.shape[0]} < min_local_size {min_local_size}"
+                    )
+                    return False
+        return True
+
     # now run the test
-    op_expected = getattr(torch.ops.dist, dist_op_expected)
+    op_expected = getattr(torch.ops.auto_deploy, dist_op_expected)
+
+    transform_func = partial(column_row_shard, rank=rank, world_size=world_size)
+
+    def combined_graph_check(gm) -> bool:
+        # Check for expected distributed operations
+        has_expected_dist_ops = any(is_op(n, op_expected) for n in gm.graph.nodes) == (
+            world_size > 1
+        )
+        # Check weight size constraints
+        weight_sizes_valid = verify_local_weight_sizes(gm)
+        return has_expected_dist_ops and weight_sizes_valid
+
     run_test(
         model,
         x,
-        transform=partial(column_row_shard, rank=rank, world_size=world_size),
-        check_transformed_graph=lambda gm: any(is_op(n, op_expected) for n in gm.graph.nodes)
-        == (world_size > 1),
+        transform=transform_func,
+        check_transformed_graph=combined_graph_check,
        _get_expected_num_params=_get_expected_num_params,
     )
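To make the GQA branch of `_get_expected_num_params` concrete, here is the arithmetic it encodes for the configuration used in this test (`num_features=32`, `head_dim=8`, `num_key_value_heads=1`) with an assumed `world_size=2`:

```python
num_features, head_dim, num_key_value_heads, world_size = 32, 8, 1, 2

W_q_local = num_features * num_features // world_size                             # 512
W_o_local = W_q_local                                                             # 512
W_k_local = num_features * head_dim * max(num_key_value_heads // world_size, 1)   # 256 (KV weight stays replicated)
W_v_local = W_k_local                                                             # 256
assert W_q_local + W_k_local + W_v_local + W_o_local == 1536                      # expected local parameter count
```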
@@ -67,8 +164,9 @@ def _run_job(
 @pytest.mark.parametrize(
     "model_cls, dist_op_expected",
     (
-        (MLP, "all_reduce"),
-        (nn.Linear, "all_gather"),
+        (MLP, "torch_dist_all_reduce"),
+        (nn.Linear, "torch_dist_all_gather"),
+        (GQA_Block, "torch_dist_all_reduce"),
     ),
 )
 def test_sharding(model_cls: Type[nn.Module], dist_op_expected: str, bias: bool, device_count: int):
@@ -49,7 +49,7 @@ def test_moe_op_run(dtype):
         fused_w2_weight.data[expert_id].copy_(w2)
 
     with torch.inference_mode():
-        output_torch_moe = torch.ops.moe.torch_moe(
+        output_torch_moe = torch.ops.auto_deploy.torch_moe(
             x,
             selected_experts,
             final_scales,

@@ -57,14 +57,14 @@ def test_moe_op_run(dtype):
             w2_weight,
             w3_weight,
         )
-        output_torch_fused_moe = torch.ops.moe.torch_fused_moe(
+        output_torch_fused_moe = torch.ops.auto_deploy.torch_moe_fused(
             x,
             selected_experts,
             final_scales,
             fused_w3_w1_stacked_weight,
             fused_w2_weight,
         )
-        output_trt_fused_moe = torch.ops.moe.trtllm_fused_moe(
+        output_trt_fused_moe = torch.ops.auto_deploy.trtllm_moe_fused(
             x,
             selected_experts,
             final_scales,
@@ -21,7 +21,7 @@ def test_attention_op():
 
     q, k, v = (x.contiguous() for x in torch.split(qkv, 1, dim=1))
 
-    output = torch.ops.attention.fused_mha_with_cache(
+    output = torch.ops.auto_deploy.triton_attention_fused_mha_with_cache(
         q, k, v, input_positions, k_cache, v_cache, None
     )
     ref = torch.nn.functional.scaled_dot_product_attention(

@@ -66,7 +66,7 @@ def test_gqa_op(device, dtype, n_heads, group_size, seq_len):
     v_cache = torch.randn(BATCH_SIZE, CACHE_SEQ_LEN, n_kv_heads, D_HEAD, dtype=dtype, device=device)
 
     # run custom op
-    output = torch.ops.attention.fused_mha_with_cache(
+    output = torch.ops.auto_deploy.triton_attention_fused_mha_with_cache(
         q, k, v, input_positions, k_cache, v_cache, None
     )
 

@@ -148,7 +148,7 @@ def test_flat_gqa_op(
     v = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs)
 
     # run op
-    output = torch.ops.attention.flattened_mha_with_cache(
+    output = torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache(
         # Q, K, V
         q,
         k,
@@ -274,7 +274,7 @@ def test_flat_gqa_op_with_rope(
     source = 1
     if source == 1:
         # call rope fusion kernels
-        output = torch.ops.attention.fused_flattened_mha_with_cache_rope_fusion(
+        output = torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache_rope_fusion(
             q,
             k,
             v,

@@ -288,7 +288,7 @@ def test_flat_gqa_op_with_rope(
         )
     else:
         # call stand-alone rope embedding kernel
-        output = torch.ops.attention.fused_flattened_mha_with_cache(
+        output = torch.ops.auto_deploy.triton_attention_fused_flattened_mha_with_cache(
             q,
             k,
             v,

@@ -466,7 +466,7 @@ def test_paged_gqa_op(
     v = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs)
 
     # run op
-    output = torch.ops.attention.fused_mha_with_paged_cache(
+    output = torch.ops.auto_deploy.triton_attention_fused_mha_with_paged_cache(
         q,
         k,
         v,
@@ -88,7 +88,7 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
         ),
         BATCH_SIZE * SEQ_LEN,
     )
-    flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
+    flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
         # Q, K, V
         q,
         k,

@@ -213,7 +213,7 @@ def test_flashinfer_attention_op_decode(
         ),
         BATCH_SIZE * SEQ_LEN,
     )
-    flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
+    flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
         # Q, K, V
         q,
         k,

@@ -329,7 +329,7 @@ def test_flashinfer_attention_context_and_generate(
         ),
         BATCH_SIZE * PREFILL_SEQ_LEN,
     )
-    flashinfer_output_1 = torch.ops.attention.flashinfer_mha_with_cache(
+    flashinfer_output_1 = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
         # Q, K, V
         q_1,
         k_1,

@@ -404,7 +404,7 @@ def test_flashinfer_attention_context_and_generate(
         ),
         BATCH_SIZE * 1,
     )
-    flashinfer_output_3 = torch.ops.attention.flashinfer_mha_with_cache(
+    flashinfer_output_3 = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
         # Q, K, V
         q_3,
         k_3,

@@ -513,7 +513,7 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
         ),
         BATCH_SIZE * SEQ_LEN,
     )
-    flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
+    flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
         # Q, K, V
         q,
         k,

@@ -660,7 +660,7 @@ def test_flashinfer_attention_with_fp8_cache(
         ),
         BATCH_SIZE * SEQ_LEN,
     )
-    flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
+    flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
         # Q, K, V
         q,
         k,

@@ -757,7 +757,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
         ),
         BATCH_SIZE * SEQ_LEN,
     )
-    flashinfer_output = torch.ops.attention.flashinfer_mha_with_cache(
+    flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
         # Q, K, V
         q,
         k,

@@ -840,7 +840,7 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
         ),
         BATCH_SIZE * 1,
     )
-    flashinfer_output_gen = torch.ops.attention.flashinfer_mha_with_cache(
+    flashinfer_output_gen = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache(
         # Q, K, V
         q_gen,
         k_gen,
@@ -20,7 +20,7 @@ def test_fp8_linear(bias):
     weight_scale = (torch.max(torch.abs(weight)) / 448).to("cuda")
     weight_fp8 = (weight / weight_scale).to(torch.float8_e4m3fn)
 
-    output_fp8_gemm = torch.ops.quant.fp8_linear(
+    output_fp8_gemm = torch.ops.auto_deploy.torch_quant_fp8_linear(
         input,
         weight_fp8,
         bias=bias,

@@ -49,7 +49,7 @@ def test_fp4_linear():
         weight, weight_scale_2, scaling_vector_size, False
     )
 
-    output_fp4_gemm = torch.ops.quant.fp4_linear(
+    output_fp4_gemm = torch.ops.auto_deploy.torch_quant_fp4_linear(
         input,
         weight_fp4,
         bias=None,
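The FP8 test above prepares its weight by scaling into the representable range of `float8_e4m3fn` (whose maximum finite value is 448) before calling the renamed op. A minimal sketch of that preparation step, mirroring the test setup rather than the library's own quantization path:

```python
import torch


def quantize_weight_fp8(weight: torch.Tensor):
    """Scale a weight so its max magnitude maps into the float8_e4m3fn range, then cast (sketch)."""
    weight_scale = torch.max(torch.abs(weight)) / 448.0
    weight_fp8 = (weight / weight_scale).to(torch.float8_e4m3fn)
    return weight_fp8, weight_scale
```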
@@ -81,7 +81,7 @@ def test_flashinfer_custom_op_and_hf_impl(dtype, atol, rtol, head_dim):
 
     # Custom op call
     positions_flat = torch.arange(batch * seq_len, device=device)
-    custom_q, custom_k = torch.ops.rope.flashinfer(
+    custom_q, custom_k = torch.ops.auto_deploy.flashinfer_rope(
         query, key, positions_flat, cos_sin_cache_expand, True
     )
 

@@ -135,7 +135,7 @@ def test_flashinfer_custom_op_and_complex_impl(dtype, atol, rtol, head_dim):
 
     # q/k of llama4 rope is interleaved
     positions_flat = torch.arange(batch * seq_len, device=device)
-    custom_q, custom_k = torch.ops.rope.flashinfer(
+    custom_q, custom_k = torch.ops.auto_deploy.flashinfer_rope(
         query, key, positions_flat, cos_sin_cache_expand, False
     )
 

@@ -211,8 +211,8 @@ def test_triton_custom_op_and_hf_impl(layout, head_dim, dtype, atol, rtol):
     q_hf = q_f32.to(dtype)
     k_hf = k_f32.to(dtype)
 
-    q_out = torch.ops.rope.apply_rope_with_input_pos(q, cosin_cache, positions, layout)
-    k_out = torch.ops.rope.apply_rope_with_input_pos(k, cosin_cache, positions, layout)
+    q_out = torch.ops.auto_deploy.triton_rope_with_input_pos(q, cosin_cache, positions, layout)
+    k_out = torch.ops.auto_deploy.triton_rope_with_input_pos(k, cosin_cache, positions, layout)
 
     torch.testing.assert_close(q_hf, q_out, atol=atol, rtol=rtol)
     torch.testing.assert_close(k_hf, k_out, atol=atol, rtol=rtol)
@@ -34,7 +34,7 @@ def test_rope(d_head):
     y_ref = torch_rope_reference(x, freqs_cis, input_position)
     freqs_cis = freqs_cis.to("cuda")
     x_reshaped = x.unflatten(-1, (N_ELEM // 2, 2)).transpose(-1, -2).flatten(-2).contiguous()
-    y = torch.ops.rope.apply_rope_with_input_pos(
+    y = torch.ops.auto_deploy.triton_rope_with_input_pos(
         x_reshaped.to("cuda"), freqs_cis, input_position, "bsnd"
     )
     y_reshaped = y.unflatten(-1, (2, N_ELEM // 2)).transpose(-2, -1).flatten(-2).contiguous()

@@ -64,7 +64,7 @@ def test_rope_flattened(d_head):
     seq_start_indices = torch.zeros(len(SEQ_LENS), dtype=torch.int32, device="cuda")
     seq_start_indices[1:] = torch.cumsum(seq_lens[:-1], 0)
 
-    y = torch.ops.rope.apply_rope_on_flattened_inputs(
+    y = torch.ops.auto_deploy.triton_rope_on_flattened_inputs(
         x_reshaped.to("cuda"), freqs_cis, input_position, seq_lens, seq_start_indices
     )
     y_reshaped = y.unflatten(-1, (2, N_ELEM // 2)).transpose(-2, -1).flatten(-2).contiguous()

@@ -55,7 +55,9 @@ def mla_attn(q_nope, q_pe, compressed_kv, k_pe, wkv_b, softmax_scale):
     q_nope_proj = torch.einsum("bhsd,hdc->bhsc", q_nope, wkv_b_weight[:, :qk_nope_head_dim])
 
     # MLA ref operation
-    x = torch.ops.deepseek.mla(q_nope_proj, q_pe, compressed_kv, k_pe, None, softmax_scale)
+    x = torch.ops.auto_deploy.torch_attention_deepseek_mla(
+        q_nope_proj, q_pe, compressed_kv, k_pe, None, softmax_scale
+    )
 
     # Up project attention scores
     x = torch.einsum("bshc,hdc->bshd", x, wkv_b_weight[:, -v_head_dim:])
@@ -379,8 +379,8 @@ class GroupedAttentionModel(torch.nn.Module):
 
         # Manually apply repeat_kv to k and v
         if self.num_kv_heads != self.num_heads:
-            k = torch.ops.attention.repeat_kv(k, self.n_rep)
-            v = torch.ops.attention.repeat_kv(v, self.n_rep)
+            k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, self.n_rep)
+            v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, self.n_rep)
 
         # Create attention mask if needed
         attn_mask = None

@@ -396,7 +396,7 @@ class GroupedAttentionModel(torch.nn.Module):
             ).masked_fill(mask, float("-inf"))
 
         # Apply scaled dot product attention
-        attn_output = torch.ops.attention.scaled_dot_product_attention(
+        attn_output = torch.ops.auto_deploy.torch_attention_sdpa(
             q,
             k,
             v,

@@ -434,7 +434,9 @@ def test_match_repeat_kv(num_heads, num_kv_heads, model_cls):
     expected_matches = 0 if num_heads == num_kv_heads else 2
 
     def verify_matcher(gm):
-        repeat_kv_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.repeat_kv)]
+        repeat_kv_nodes = [
+            n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_repeat_kv)
+        ]
 
         # Check that we have the expected number of replacements
         if len(repeat_kv_nodes) != expected_matches:
@@ -549,7 +551,7 @@ def test_match_eager_attention(has_mask, use_division, dropout, rtol, atol, mode
 
     def verify_matcher(gm):
         sdpa_nodes = [
-            n for n in gm.graph.nodes if is_op(n, torch.ops.attention.scaled_dot_product_attention)
+            n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa)
         ]
 
         # Check that we have the expected number of replacements

@@ -655,8 +657,10 @@ def test_counter_example():
     dynamic_shapes = model.get_dynamic_shapes()
 
     def verify_no_matches(gm):
-        # No nodes should be replaced with torch.ops.attention.repeat_kv
-        repeat_kv_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.repeat_kv)]
+        # No nodes should be replaced with torch.ops.auto_deploy.torch_attention_repeat_kv
+        repeat_kv_nodes = [
+            n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_repeat_kv)
+        ]
         return len(repeat_kv_nodes) == 0
 
     # Ensure the pattern matcher doesn't match our counter-examples

@@ -693,7 +697,9 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask):
 
     def verify_matcher(gm):
         grouped_sdpa_nodes = [
-            n for n in gm.graph.nodes if is_op(n, torch.ops.attention.grouped_sdpa)
+            n
+            for n in gm.graph.nodes
+            if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
         ]
 
         # Check that we have the expected number of replacements
@@ -790,8 +796,8 @@ class CausalAttentionModel(torch.nn.Module):
         # For grouped attention, repeat k and v
         if self.use_grouped_sdpa and self.num_kv_heads != self.num_heads:
             n_rep = self.num_heads // self.num_kv_heads
-            k = torch.ops.attention.repeat_kv(k, n_rep)
-            v = torch.ops.attention.repeat_kv(v, n_rep)
+            k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
+            v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
 
         # Create attention mask based on mask_type
         if self.mask_type == "triu":

@@ -830,7 +836,7 @@ class CausalAttentionModel(torch.nn.Module):
 
         # Choose the appropriate attention implementation
         if self.use_grouped_sdpa:
-            attn_output = torch.ops.attention.grouped_sdpa(
+            attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa(
                 q,
                 k,
                 v,

@@ -840,7 +846,7 @@ class CausalAttentionModel(torch.nn.Module):
                 scale=1.0 / (self.head_dim**0.5),
             )
         else:
-            attn_output = torch.ops.attention.scaled_dot_product_attention(
+            attn_output = torch.ops.auto_deploy.torch_attention_sdpa(
                 q,
                 k,
                 v,
@@ -886,12 +892,14 @@ def test_match_causal_attention(mask_type, use_grouped_sdpa):
     def verify_matcher(gm):
         # Find attention operations
         if use_grouped_sdpa:
-            attn_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.grouped_sdpa)]
-        else:
             attn_nodes = [
                 n
                 for n in gm.graph.nodes
-                if is_op(n, torch.ops.attention.scaled_dot_product_attention)
+                if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
             ]
+        else:
+            attn_nodes = [
+                n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa)
+            ]
 
         if len(attn_nodes) != 1:
@@ -990,8 +998,8 @@ class Llama3CausalAttentionModel(torch.nn.Module):
         # For grouped attention, repeat k and v
         if self.use_grouped_sdpa and self.num_kv_heads != self.num_heads:
             n_rep = self.num_heads // self.num_kv_heads
-            k = torch.ops.attention.repeat_kv(k, n_rep)
-            v = torch.ops.attention.repeat_kv(v, n_rep)
+            k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
+            v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
 
         # Create a llama-3.1 style causal mask
         # 1. Create a full tensor with a very negative value

@@ -1026,7 +1034,7 @@ class Llama3CausalAttentionModel(torch.nn.Module):
 
         # Choose the appropriate attention implementation
         if self.use_grouped_sdpa:
-            attn_output = torch.ops.attention.grouped_sdpa(
+            attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa(
                 q,
                 k,
                 v,

@@ -1036,7 +1044,7 @@ class Llama3CausalAttentionModel(torch.nn.Module):
                 scale=1.0 / (self.head_dim**0.5),
             )
         else:
-            attn_output = torch.ops.attention.scaled_dot_product_attention(
+            attn_output = torch.ops.auto_deploy.torch_attention_sdpa(
                 q,
                 k,
                 v,
@@ -1078,12 +1086,14 @@ def test_match_llama3_causal_attention(use_grouped_sdpa):
     def verify_matcher(gm):
         # Find attention operations
         if use_grouped_sdpa:
-            attn_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.grouped_sdpa)]
-        else:
             attn_nodes = [
                 n
                 for n in gm.graph.nodes
-                if is_op(n, torch.ops.attention.scaled_dot_product_attention)
+                if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
             ]
+        else:
+            attn_nodes = [
+                n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa)
+            ]
 
         if len(attn_nodes) != 1:
@@ -1129,7 +1139,7 @@ class MockAttentionDescriptor:
     """A mock class that mimics the AttentionDescriptor interface for testing."""
 
     layout: str = "bnsd"
-    source_attention_op: Callable = torch.ops.attention.scaled_dot_product_attention
+    source_attention_op: Callable = torch.ops.auto_deploy.torch_attention_sdpa
 
     @classmethod
     def get_attention_layout(cls) -> str:

@@ -1199,7 +1209,7 @@ class AttentionLayoutModel(torch.nn.Module):
 
         # Apply scaled dot product attention
         if self.use_grouped_sdpa:
-            attn_output = torch.ops.attention.grouped_sdpa(
+            attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa(
                 q,
                 k,
                 v,

@@ -1209,7 +1219,7 @@ class AttentionLayoutModel(torch.nn.Module):
                 scale=1.0 / (self.head_dim**0.5),
             )
         else:
-            attn_output = torch.ops.attention.scaled_dot_product_attention(
+            attn_output = torch.ops.auto_deploy.torch_attention_sdpa(
                 q,
                 k,
                 v,

@@ -1246,7 +1256,7 @@ class BsndAttentionModel(AttentionLayoutModel):
         attn_mask = self._get_attn_mask(x) if self.has_mask else None
 
         # Apply bsnd_grouped_sdpa directly
-        attn_output = torch.ops.attention.bsnd_grouped_sdpa.default(
+        attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa.default(
             q,
             k,
             v,
@@ -1284,11 +1294,11 @@ def test_match_attention_layout(layout, model_config, has_mask):
     MockAttentionDescriptor.layout = layout
     if layout == "bnsd":
         if model_config.get("use_grouped_sdpa"):
-            source_op = torch.ops.attention.grouped_sdpa
+            source_op = torch.ops.auto_deploy.torch_attention_grouped_sdpa
         else:
-            source_op = torch.ops.attention.scaled_dot_product_attention
+            source_op = torch.ops.auto_deploy.torch_attention_sdpa
     else:
-        source_op = torch.ops.attention.bsnd_grouped_sdpa
+        source_op = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
     MockAttentionDescriptor.source_attention_op = source_op
 
     # Create appropriate model based on model_config

@@ -1319,18 +1329,24 @@ def test_match_attention_layout(layout, model_config, has_mask):
         if model_config["type"] == "standard":
             if model_config["use_grouped_sdpa"]:
                 original_nodes = [
-                    n for n in gm.graph.nodes if is_op(n, torch.ops.attention.grouped_sdpa)
+                    n
+                    for n in gm.graph.nodes
+                    if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
                 ]
             else:
                 original_nodes = [
                     n
                     for n in gm.graph.nodes
-                    if is_op(n, torch.ops.attention.scaled_dot_product_attention)
+                    if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa)
                 ]
         else:  # already_bsnd
             original_nodes = []
 
-        bsnd_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.attention.bsnd_grouped_sdpa)]
+        bsnd_nodes = [
+            n
+            for n in gm.graph.nodes
+            if is_op(n, torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa)
+        ]
         transpose_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.aten.transpose.int)]
 
         # Different expectations based on model type and layout
@@ -31,7 +31,7 @@ class MockAttentionDescriptor:
 
     @classmethod
     def get_source_attention_op(cls) -> Callable:
-        return torch.ops.attention.bsnd_grouped_sdpa
+        return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
 
 
 class HFWrapper(nn.Module):

@@ -83,9 +83,11 @@ def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str)
     )
 
     def verify_matcher(gm: GraphModule):
-        """Ensure that there is exactly one torch.ops.attention.bsnd_grouped_sdpa call in the graph."""
+        """Ensure that there is exactly one torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
+        call in the graph. Also check that there is no repeat_kv pattern left.
+        """
         nodes = gm.graph.find_nodes(
-            op="call_function", target=torch.ops.attention.bsnd_grouped_sdpa
+            op="call_function", target=torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
         )
         assert len(nodes) == 1, "Expected exactly one bsnd_grouped_sdpa call in the graph"
 

@@ -100,7 +102,9 @@ def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str)
         assert attn_node.args[6] == scale  # scale
 
         # TODO: check that there is no repeat_kv pattern left...
-        nodes = gm.graph.find_nodes(op="call_function", target=torch.ops.attention.repeat_kv)
+        nodes = gm.graph.find_nodes(
+            op="call_function", target=torch.ops.auto_deploy.torch_attention_repeat_kv
+        )
         assert len(nodes) == 0, "Found repeat_kv pattern in the graph"
 
         return True
@@ -53,7 +53,9 @@ class GQAWithSdpa(GQA):
         v = v.view(b, s, self.num_kv_heads, self.head_dim)
 
         # Use grouped SDPA in bsnd layout
-        attn_output = torch.ops.attention.bsnd_grouped_sdpa(q, k, v, None, 0.0, True, None)
+        attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(
+            q, k, v, None, 0.0, True, None
+        )
 
         # SDPA output is already in [b, s, n, h_d] format
         # Reshape to [b, s, n*h_d]

@@ -100,7 +100,7 @@ def test_moe_matching():
         model,
         x,
         match_moe_pattern,
-        lambda gm: any(is_op(n, torch.ops.moe.torch_moe) for n in gm.graph.nodes),
+        lambda gm: any(is_op(n, torch.ops.auto_deploy.torch_moe) for n in gm.graph.nodes),
         lambda num_p_og: num_p_og,
         atol=1e-3,
         rtol=1e-3,

@@ -119,7 +119,9 @@ def test_moe_fusion():
         x,
         fuse_moe,
         lambda gm: any(
-            is_op(n, {torch.ops.moe.torch_fused_moe, torch.ops.moe.trtllm_fused_moe})
+            is_op(
+                n, {torch.ops.auto_deploy.torch_moe_fused, torch.ops.auto_deploy.trtllm_moe_fused}
+            )
             for n in gm.graph.nodes
         ),
         lambda num_p_og: num_p_og,
@@ -14,7 +14,10 @@ from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
 
 
 def check_quantized(gm):
-    op_expected = {torch.ops.quant.fp8_linear, torch.ops.quant.fp4_linear}
+    op_expected = {
+        torch.ops.auto_deploy.torch_quant_fp8_linear,
+        torch.ops.auto_deploy.torch_quant_fp4_linear,
+    }
     return any(is_op(n, op_expected) for n in gm.graph.nodes)
 
 
@@ -91,7 +91,7 @@ class RoPEModel(torch.nn.Module):
         if self.mode == "match":
             q_out, k_out = apply_rotary_pos_emb_explicit(q, k, cos, sin, unsq_dim)
         else:  # optimize
-            q_out, k_out = torch.ops.rope.torch_apply_rope_with_explicit_cos_sin(
+            q_out, k_out = torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin(
                 q, k, cos, sin, unsq_dim
             )
 

@@ -119,7 +119,7 @@ class RoPEModel(torch.nn.Module):
         if self.mode == "match":
             q_out, k_out = apply_rotary_pos_emb_complex(q, k, freqs, unsq_dim)
         else:
-            q_out, k_out = torch.ops.rope.torch_apply_rope_with_complex_freqs(
+            q_out, k_out = torch.ops.auto_deploy.torch_rope_with_complex_freqs(
                 q, k, freqs, unsq_dim
            )
 
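`apply_rotary_pos_emb_explicit` is the eager reference that the matcher replaces with `torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin`. For context, the standard rotate-half formulation it corresponds to looks roughly like this; a sketch of the well-known HF-style computation, not the exact test helper:

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_explicit_sketch(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary embeddings given precomputed cos/sin tables (illustrative)."""
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```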
@@ -212,9 +212,9 @@ def test_rope_variants(
     if transformation == "match":
         fn = match_rope_pattern
         check_op = (
-            torch.ops.rope.torch_apply_rope_with_explicit_cos_sin
+            torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin
             if variant == "explicit" or variant == "explicit_pm"
-            else torch.ops.rope.torch_apply_rope_with_complex_freqs
+            else torch.ops.auto_deploy.torch_rope_with_complex_freqs
         )
 
         def checker(gm):

@@ -228,8 +228,8 @@ def test_rope_variants(
                 if is_op(
                     n,
                     {
-                        torch.ops.rope.torch_apply_rope_with_explicit_cos_sin,
-                        torch.ops.rope.torch_apply_rope_with_complex_freqs,
+                        torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin,
+                        torch.ops.auto_deploy.torch_rope_with_complex_freqs,
                     },
                 ):
                     q_arg, k_arg, *rest = n.args

@@ -254,7 +254,7 @@ def test_rope_variants(
         fn = optimize_rope
 
         def checker(gm):
-            return any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes)
+            return any(is_op(n, torch.ops.auto_deploy.flashinfer_rope) for n in gm.graph.nodes)
 
     if transformation == "match_layout":
         _ = run_test(

@@ -346,7 +346,7 @@ class DSModel(torch.nn.Module):
         else:
            cos = cos[pos_ids]
            sin = sin[pos_ids]
-        q_out, k_out = torch.ops.rope.torch_apply_rope_with_qk_interleaving(
+        q_out, k_out = torch.ops.auto_deploy.torch_rope_with_qk_interleaving(
             q, k, cos, sin, unsq_dim
         )
         if self.layout == "BNSD":

@@ -387,7 +387,7 @@ def test_match_and_layout_deepseek(layout, num_heads, num_kv_heads, mode, target
 
     def checker(gm):
         return any(
-            is_op(n, torch.ops.rope.torch_apply_rope_with_qk_interleaving)
+            is_op(n, torch.ops.auto_deploy.torch_rope_with_qk_interleaving)
             for n in gm.graph.nodes
         )
 

@@ -396,7 +396,7 @@ def test_match_and_layout_deepseek(layout, num_heads, num_kv_heads, mode, target
 
     def checker(gm):
         for n in gm.graph.nodes:
-            if is_op(n, torch.ops.rope.torch_apply_rope_with_qk_interleaving):
+            if is_op(n, torch.ops.auto_deploy.torch_rope_with_qk_interleaving):
                 q_arg, k_arg, *rest = n.args
                 if not (
                     is_op(q_arg, torch.ops.aten.contiguous)