mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00

[None][autodeploy] small refactors on attention matching (#8079)

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>

This commit is contained in:
parent 88ea2c4ee9
commit 744246d316
@@ -12,10 +12,9 @@ The table below lists the operators ordered by their backend.
|--------------|-------------|
| `torch.ops.auto_deploy.flashinfer_attention_mha_with_cache` | FlashInfer attention with KV cache support |
| `torch.ops.auto_deploy.flashinfer_rope` | FlashInfer RoPE (Rotary Position Embedding) implementation |
| `torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa` | Grouped SDPA (Scaled Dot Product Attention) with BSND format |
| `torch.ops.auto_deploy.torch_attention_deepseek_fused_mla` | DeepSeek fused MLA (Multi-head Latent Attention) |
| `torch.ops.auto_deploy.torch_attention_deepseek_mla` | DeepSeek MLA implementation |
| `torch.ops.auto_deploy.torch_attention_grouped_sdpa` | Grouped SDPA implementation |
| `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layouts supported |
| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation |
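For reference, an illustrative sketch (not part of this diff) of calling the unified op; the shapes are hypothetical and the call mirrors the GQA_Block test usage further down in this commit:

import torch
import tensorrt_llm._torch.auto_deploy  # assumed to register the auto_deploy custom ops

# bsnd layout: [batch, seq_len, num_heads, head_dim]; GQA with a single KV head
q = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 16, 1, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 16, 1, 64, device="cuda", dtype=torch.float16)
out = torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True, layout="bsnd")
print(out.shape)  # torch.Size([2, 16, 8, 64]) -- the output stays in the input layout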
@@ -355,7 +355,7 @@ class FlashInferAttention(AttentionDescriptor):

    @classmethod
    def get_source_attention_op(cls) -> OpOverloadPacket:
        """Get the source attention op that we target for replacement."""
        return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
        return torch.ops.auto_deploy.torch_attention

    @classmethod
    def get_cached_attention_op(cls) -> MHACallable:

@@ -399,6 +399,21 @@ class FlashInferAttention(AttentionDescriptor):

    @classmethod
    def get_constants(cls, source_attn_node: Node) -> List[Constant]:
        # Sanity check: layout == "bsnd"
        # Prefer kwargs; fall back to the final positional arg if it's a string.
        layout = source_attn_node.kwargs.get("layout", None)
        if (
            layout is None
            and len(source_attn_node.args) > 0
            and isinstance(source_attn_node.args[-1], str)
        ):
            layout = source_attn_node.args[-1]
        if layout != "bsnd":
            raise RuntimeError(
                f"Expected torch_attention layout='bsnd' but got {layout!r} "
                f"for node: {source_attn_node.format_node()}"
            )

        # Double check other arguments
        attn_mask, dropout_p, is_causal = extract_op_args(
            source_attn_node, "attn_mask", "dropout_p", "is_causal"
@@ -91,8 +91,9 @@ def scaled_dot_product_attention_fake(
    return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()


@torch.library.custom_op("auto_deploy::torch_attention_grouped_sdpa", mutates_args=())
def grouped_sdpa(
# Unified attention op
@torch.library.custom_op("auto_deploy::torch_attention", mutates_args=())
def torch_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,

@@ -103,8 +104,25 @@ def grouped_sdpa(
    sinks: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    logit_cap: Optional[float] = None,
    layout: str = "bnsd",  # "bnsd" or "bsnd"
) -> torch.Tensor:
    """SDPA attention that can handle GQA. Expects bnsd format inputs."""
    """
    SDPA attention (with optional GQA) that supports two memory layouts via `layout`:
      - "bnsd": [batch, num_heads, seq_len, head_dim]
      - "bsnd": [batch, seq_len, num_heads, head_dim]

    The `attn_mask` is always interpreted as [b, n, s_q, s_k].

    Returns a tensor in the SAME layout as inputs specified by `layout`.
    """
    if layout not in ("bnsd", "bsnd"):
        raise ValueError(f"layout must be 'bnsd' or 'bsnd', got {layout!r}")

    if layout == "bsnd":
        query = query.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()

    b, n_heads, s_q, head_dim = query.shape  # bnsd format: [batch, num_heads, seq_len, head_dim]
    _, n_kv_heads, s_k, _ = key.shape  # bnsd format: [batch, num_kv_heads, seq_len, head_dim]

@@ -188,72 +206,26 @@ def grouped_sdpa(
    if dropout_p > 0.0:
        attn_out = F.dropout(attn_out, p=dropout_p, training=False)

    # Return in bnsd format (same as input format)
    return attn_out
    if layout == "bsnd":
        return attn_out.transpose(1, 2).contiguous()
    else:
        return attn_out.contiguous()


@grouped_sdpa.register_fake
def grouped_sdpa_fake(
@torch_attention.register_fake
def torch_attention_fake(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    sinks=None,
    sliding_window=None,
    logit_cap=None,
):
    """Fake implementation of grouped SDPA."""
    return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()


@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=())
def bsnd_grouped_sdpa(
    query: torch.Tensor,  # layout: [b, s_q, n, d]
    key: torch.Tensor,  # layout: [b, s_k, n, d]
    value: torch.Tensor,  # layout: [b, s_k, n, d]
    attn_mask: Optional[torch.Tensor] = None,  # layout: [b, n, s_q, s_k]
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    sinks: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    logit_cap: Optional[float] = None,
) -> torch.Tensor:
    """Attention that assumes the input layout is bsnd.

    Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the
    original sdpa op!
    """
    # Transpose inputs to bnsd format for grouped_sdpa
    query = query.transpose(1, 2).contiguous()  # [b, s_q, n, d] -> [b, n, s_q, d]
    key = key.transpose(1, 2).contiguous()  # [b, s_k, n, d] -> [b, n, s_k, d]
    value = value.transpose(1, 2).contiguous()  # [b, s_k, n, d] -> [b, n, s_k, d]

    # Call grouped_sdpa with bnsd inputs
    out = grouped_sdpa(
        query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap
    )
    # Transpose back to bsnd format
    return out.transpose(1, 2).contiguous()


@bsnd_grouped_sdpa.register_fake
def bsnd_grouped_sdpa_fake(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    sinks=None,
    sliding_window=None,
    logit_cap=None,
    layout: str = "bnsd",
):
    """Fake implementation of bnsd grouped SDPA."""
    return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
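A hedged check (not from the diff) of the layout contract in the new docstring: with the ops registered as above, transposed inputs under layout="bsnd" should yield the transpose of the layout="bnsd" result. Shapes are hypothetical:

import torch  # assumes the auto_deploy custom ops are already registered

q = torch.randn(2, 8, 16, 64, device="cuda", dtype=torch.float16)  # bnsd
k = torch.randn(2, 1, 16, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1, 16, 64, device="cuda", dtype=torch.float16)

out_bnsd = torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True, layout="bnsd")
out_bsnd = torch.ops.auto_deploy.torch_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True, layout="bsnd"
)
# The op returns outputs in the layout of its inputs, so the two should differ only by a transpose.
assert torch.allclose(out_bnsd, out_bsnd.transpose(1, 2), atol=1e-3)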
@@ -409,7 +409,7 @@ class TorchBackendAttention(AttentionDescriptor):

    @classmethod
    def get_source_attention_op(cls) -> OpOverloadPacket:
        return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
        return torch.ops.auto_deploy.torch_attention

    @classmethod
    def get_cached_attention_op(cls) -> MHACallable:

@@ -460,6 +460,21 @@ class TorchBackendAttention(AttentionDescriptor):

    @classmethod
    def get_constants(cls, source_attn_node: Node) -> List[Constant]:
        # Sanity check: layout == "bsnd"
        # Prefer kwargs; fall back to the final positional arg if it's a string.
        layout = source_attn_node.kwargs.get("layout", None)
        if (
            layout is None
            and len(source_attn_node.args) > 0
            and isinstance(source_attn_node.args[-1], str)
        ):
            layout = source_attn_node.args[-1]
        if layout != "bsnd":
            raise RuntimeError(
                f"Expected torch_attention layout='bsnd' but got {layout!r} "
                f"for node: {source_attn_node.format_node()}"
            )

        # Check other arguments
        attn_mask, dropout_p, is_causal = extract_op_args(
            source_attn_node, "attn_mask", "dropout_p", "is_causal"
@@ -339,7 +339,7 @@ class TritonAttention(AttentionDescriptor):

    @classmethod
    def get_source_attention_op(cls) -> OpOverloadPacket:
        return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
        return torch.ops.auto_deploy.torch_attention

    @classmethod
    def get_cached_attention_op(cls) -> MHACallable:

@@ -390,6 +390,21 @@ class TritonAttention(AttentionDescriptor):

    @classmethod
    def get_constants(cls, source_attn_node: Node) -> List[Constant]:
        # Sanity check: layout == "bsnd"
        # Prefer kwargs; fall back to the final positional arg if it's a string.
        layout = source_attn_node.kwargs.get("layout", None)
        if (
            layout is None
            and len(source_attn_node.args) > 0
            and isinstance(source_attn_node.args[-1], str)
        ):
            layout = source_attn_node.args[-1]
        if layout != "bsnd":
            raise RuntimeError(
                f"Expected torch_attention layout='bsnd' but got {layout!r} "
                f"for node: {source_attn_node.format_node()}"
            )

        # retrieve head_dim from k_fake
        attn_mask, dropout_p, is_causal = extract_op_args(
            source_attn_node, "attn_mask", "dropout_p", "is_causal"
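The same layout sanity check now appears in the FlashInfer, Torch, and Triton backends. A standalone sketch (not from the diff) of that kwargs-then-positional fallback, using a hypothetical stand-in for an FX node:

from types import SimpleNamespace

def resolve_layout(node):
    """Mirror the lookup in get_constants: prefer the kwarg, else a trailing string positional."""
    layout = node.kwargs.get("layout", None)
    if layout is None and len(node.args) > 0 and isinstance(node.args[-1], str):
        layout = node.args[-1]
    return layout

# Hypothetical nodes: one passes layout as a kwarg, the other positionally (last argument).
kwarg_node = SimpleNamespace(args=("q", "k", "v"), kwargs={"layout": "bsnd"})
positional_node = SimpleNamespace(
    args=("q", "k", "v", None, 0.0, True, None, None, None, None, "bsnd"), kwargs={}
)
assert resolve_layout(kwarg_node) == resolve_layout(positional_node) == "bsnd"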
@@ -59,7 +59,7 @@ def gpt_oss_attention(
    sinks = self.sinks

    # Use custom op to capture attention. This layout is bsnd (batch, seq, num_heads, head_dim)
    attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(
    attn_output = torch.ops.auto_deploy.torch_attention(
        query_states,
        key_states,
        value_states,

@@ -69,6 +69,7 @@ def gpt_oss_attention(
        scale=self.scaling,
        sinks=sinks,
        sliding_window=sliding_window,
        layout="bsnd",
    )

    # Reshape back to original input shape
@@ -1,5 +1,7 @@
"""Pattern matching for detecting repeat_kv, eager, grouped attention patterns from Huggingface models."""

from inspect import Parameter, Signature
from itertools import product
from typing import Any, Callable, Dict, List, Tuple, Type

import torch

@@ -11,7 +13,6 @@ from ...custom_ops.attention_interface import AttentionDescriptor
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ...utils.node_utils import is_op
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ..interface import (
    BaseTransform,
@@ -270,170 +271,6 @@ def _get_sfdp_patterns() -> List[Dict[str, Any]]:
    return patterns


def _grouped_attn_pattern_1(q, k, v, n_rep, attn_mask, dropout_p, scale):
    k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
    v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
    return torch.ops.auto_deploy.torch_attention_sdpa.default(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
    )


def _grouped_attn_replacement_1(q, k, v, n_rep, attn_mask, dropout_p, scale):
    return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
    )


# Only expose torch_attention_grouped_sdpa after the transformation
def _grouped_attn_pattern_2(q, k, v, attn_mask, dropout_p, scale):
    return torch.ops.auto_deploy.torch_attention_sdpa.default(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
    )


def _grouped_attn_replacement_2(q, k, v, attn_mask, dropout_p, scale):
    return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
    )


def _grouped_attn_pattern_3(q, k, v, n_rep, attn_mask, dropout_p, scale):
    k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
    v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
    return torch.ops.auto_deploy.torch_attention_sdpa.default(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
    )


def _grouped_attn_replacement_3(q, k, v, n_rep, attn_mask, dropout_p, scale):
    return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
    )


# Only expose torch_attention_grouped_sdpa after the transformation
def _grouped_attn_pattern_4(q, k, v, attn_mask, dropout_p, scale):
    return torch.ops.auto_deploy.torch_attention_sdpa.default(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
    )


def _grouped_attn_replacement_4(q, k, v, attn_mask, dropout_p, scale):
    return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
    )


def _grouped_attn_pattern_5(q, k, v, n_rep, attn_mask):
    k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
    v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
    return torch.ops.auto_deploy.torch_attention_sdpa.default(q, k, v, attn_mask)


def _grouped_attn_replacement_5(q, k, v, n_rep, attn_mask):
    return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(q, k, v, attn_mask)


def _grouped_attn_pattern_6(q, k, v, attn_mask, dropout_p, scale):
    return torch.ops.auto_deploy.torch_attention_sdpa.default(
        q,
        k,
        v,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=False,
        scale=scale,
        enable_gqa=True,
    )


def _grouped_attn_replacement_6(q, k, v, attn_mask, dropout_p, scale):
    return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
    )


def _grouped_attn_pattern_7(q, k, v, attn_mask, dropout_p, scale):
    return torch.ops.auto_deploy.torch_attention_sdpa.default(
        q,
        k,
        v,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=True,
        scale=scale,
        enable_gqa=True,
    )


def _grouped_attn_replacement_7(q, k, v, attn_mask, dropout_p, scale):
    return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
    )


def _grouped_attn_pattern_8(q, k, v, dropout_p, scale):
    return torch.ops.auto_deploy.torch_attention_sdpa.default(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=False,
        scale=scale,
        enable_gqa=True,
    )


def _grouped_attn_replacement_8(q, k, v, dropout_p, scale):
    return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
        q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=scale
    )


def _grouped_attn_pattern_9(q, k, v, dropout_p, scale):
    return torch.ops.auto_deploy.torch_attention_sdpa.default(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=True,
        scale=scale,
        enable_gqa=True,
    )


def _grouped_attn_replacement_9(q, k, v, dropout_p, scale):
    return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
        q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True, scale=scale
    )


def _grouped_attn_pattern_10(q, k, v, n_rep, dropout_p):
    k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
    v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
    return torch.ops.auto_deploy.torch_attention_sdpa.default(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=True,
    )


def _grouped_attn_replacement_10(q, k, v, n_rep, dropout_p):
    return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
        q,
        k,
        v,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=True,
    )


@TransformRegistry.register("match_repeat_kv")
class MatchRepeatKV(BaseTransform):
    """
@@ -508,11 +345,194 @@ class MatchEagerAttention(BaseTransform):
        return gm, info


def _attach_signature(fn: Callable, argnames: List[str]) -> Callable:
    # Make FX "see" q,k,v[,attn_mask][,dropout_p][,scale] even though fn(*args) internally
    params = [Parameter(n, kind=Parameter.POSITIONAL_OR_KEYWORD) for n in argnames]
    fn.__signature__ = Signature(parameters=params)
    return fn
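A short illustration (not from the diff) of the _attach_signature trick above: assigning __signature__ makes inspect.signature, and hence the FX pattern matcher, see named parameters on a *args function. The parameter names here are arbitrary examples:

from inspect import Parameter, Signature, signature

def fn(*args):
    return args

fn.__signature__ = Signature(
    [Parameter(n, kind=Parameter.POSITIONAL_OR_KEYWORD) for n in ("q", "k", "v", "scale")]
)
print(signature(fn))  # (q, k, v, scale)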
def _call_sdpa(
    q, k, v, *, is_causal: bool, enable_gqa: bool, attn_mask=None, dropout_p=None, scale=None
):
    kwargs = {"is_causal": is_causal}
    if attn_mask is not None:
        kwargs["attn_mask"] = attn_mask
    if dropout_p is not None:
        kwargs["dropout_p"] = dropout_p
    if scale is not None:
        kwargs["scale"] = scale
    if enable_gqa:
        kwargs["enable_gqa"] = True
    return torch.ops.auto_deploy.torch_attention_sdpa.default(q, k, v, **kwargs)


def _call_attn(q, k, v, *, is_causal: bool, attn_mask=None, dropout_p=None, scale=None):
    kwargs = {"is_causal": is_causal}
    if attn_mask is not None:
        kwargs["attn_mask"] = attn_mask
    if dropout_p is not None:
        kwargs["dropout_p"] = dropout_p
    if scale is not None:
        kwargs["scale"] = scale
    return torch.ops.auto_deploy.torch_attention.default(q, k, v, **kwargs)


def make_grouped_attn_pair(
    *,
    repeat_kv: bool,
    is_causal: bool,
    has_scale: bool,
    enable_gqa: bool,
    has_attn_mask: bool,
    has_dropout: bool,
) -> Tuple[Callable, Callable, List[str]]:
    """
    Returns (pattern_fn, replacement_fn, argnames) with exact positional parity.

    Arg order rules:
        Base: (q, k, v)
        +repeat_kv  -> insert n_rep after (q, k, v)
        +attn_mask  -> include attn_mask after n_rep if repeat_kv else after (q, k, v)
        +dropout    -> include dropout_p after attn_mask or after n_rep/base if no attn_mask
        +scale      -> include scale last
    """
    argnames: List[str] = ["q", "k", "v"]
    if repeat_kv:
        argnames.append("n_rep")
    if has_attn_mask:
        argnames.append("attn_mask")
    if has_dropout:
        argnames.append("dropout_p")
    if has_scale:
        argnames.append("scale")

    def pattern_fn(*args):
        if len(args) != len(argnames):
            raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
        m = dict(zip(argnames, args))

        q = m["q"]
        k = m["k"]
        v = m["v"]

        if repeat_kv:
            n_rep = m["n_rep"]
            k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
            v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)

        return _call_sdpa(
            q,
            k,
            v,
            is_causal=is_causal,
            enable_gqa=enable_gqa,
            attn_mask=m.get("attn_mask"),
            dropout_p=m.get("dropout_p"),
            scale=m.get("scale"),
        )

    # Replacement: torch_attention.default mirroring the positional signature exactly.
    # We do NOT pass enable_gqa here (it's SDPA-only). We accept n_rep to mirror signature,
    # but we don't need to use it in the replacement graph.
    def replacement_fn(*args):
        if len(args) != len(argnames):
            raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
        m = dict(zip(argnames, args))
        return _call_attn(
            m["q"],
            m["k"],
            m["v"],
            is_causal=is_causal,
            attn_mask=m.get("attn_mask"),
            dropout_p=m.get("dropout_p"),
            scale=m.get("scale"),
        )

    # Pattern matcher needs to see explicit arg names
    _attach_signature(pattern_fn, argnames)
    _attach_signature(replacement_fn, argnames)

    return pattern_fn, replacement_fn, argnames


def generate_and_register_grouped_attn_patterns(patterns, register_ad_pattern: Callable):
    """
    Auto-generate all grouped attention patterns across these axes:
      1) repeat_kv: [False, True]
      2) is_causal: [False, True]
      3) has_scale: [False, True]
      4) enable_gqa: [False, True] (only a kwarg to SDPA side)
      5) has_attn_mask: [False, True]
      6) has_dropout: [False, True]

    For each valid combo, we:
      - build pattern/replacement functions with exact-arg parity
      - build dummy args matching the signature (with CUDA fp16 tensors etc.)
      - build scalar_workaround dict for any scalars/n_rep present
      - call register_ad_pattern(...)
    """
    q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
    k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
    v1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
    attn_mask_tensor = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16)

    dropout_val = 0.12345
    scale_val = 0.56789
    n_rep_val = 7

    total = 0
    axes = ((False, True),) * 6
    for repeat_kv, is_causal, has_scale, enable_gqa, has_attn_mask, has_dropout in product(*axes):
        pat_fn, rep_fn, argnames = make_grouped_attn_pair(
            repeat_kv=repeat_kv,
            is_causal=is_causal,
            has_scale=has_scale,
            enable_gqa=enable_gqa,
            has_attn_mask=has_attn_mask,
            has_dropout=has_dropout,
        )

        # Build dummy args in the same positional order
        value_map = {
            "q": q,
            "k": k1,
            "v": v1,
            "n_rep": n_rep_val,
            "attn_mask": attn_mask_tensor,
            "dropout_p": dropout_val,
            "scale": scale_val,
        }
        dummy_args: List[object] = []
        for name in argnames:
            try:
                dummy_args.append(value_map[name])
            except KeyError:
                raise RuntimeError(f"Unexpected arg name: {name}")

        scalar_names = {"n_rep", "dropout_p", "scale"}
        scalar_workaround: Dict[str, object] = {
            n: value_map[n] for n in argnames if n in scalar_names
        }
        if not scalar_workaround:
            scalar_workaround = None

        register_ad_pattern(
            search_fn=pat_fn,
            replace_fn=rep_fn,
            patterns=patterns,
            dummy_args=dummy_args,
            scalar_workaround=scalar_workaround,
        )
        total += 1
    return total
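For a sense of scale, a sketch (not from the diff) of the enumeration performed by generate_and_register_grouped_attn_patterns above: six boolean axes give 2**6 = 64 pattern/replacement pairs, which replace the ten hand-written patterns removed earlier in this commit. The argument-name construction mirrors make_grouped_attn_pair:

from itertools import product

def argnames_for(repeat_kv, has_attn_mask, has_dropout, has_scale):
    # Same ordering rules as make_grouped_attn_pair (is_causal/enable_gqa add no arguments)
    names = ["q", "k", "v"]
    if repeat_kv:
        names.append("n_rep")
    if has_attn_mask:
        names.append("attn_mask")
    if has_dropout:
        names.append("dropout_p")
    if has_scale:
        names.append("scale")
    return names

print(len(list(product((False, True), repeat=6))))  # 64 registered pattern/replacement pairs
print(argnames_for(True, True, True, True))  # ['q', 'k', 'v', 'n_rep', 'attn_mask', 'dropout_p', 'scale']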
@TransformRegistry.register("match_grouped_attention")
class MatchGroupedAttention(BaseTransform):
    """
    Match and replace the grouped attention pattern with
    torch.ops.auto_deploy.torch_attention_grouped_sdpa.
    torch.ops.auto_deploy.torch_attention.
    """

    def _apply(
@@ -523,99 +543,10 @@ class MatchGroupedAttention(BaseTransform):
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        def register_grouped_attention(patterns: ADPatternMatcherPass):
            q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
            k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
            v1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
            attn_mask = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16)
            dropout = 0.12345
            scale = 0.56789
            n_rep = 7

            dummy_args_1 = [q, k1, v1, n_rep, attn_mask, dropout, scale]
            dummy_args_2 = [q, k1, v1, attn_mask, dropout, scale]
            dummy_args_3 = [q, k1, v1, n_rep, attn_mask]
            dummy_args_4 = [q, k1, v1, dropout, scale]
            dummy_args_5 = [q, k1, v1, n_rep, dropout]

            register_ad_pattern(
                search_fn=_grouped_attn_pattern_1,
                replace_fn=_grouped_attn_replacement_1,
                patterns=patterns,
                dummy_args=dummy_args_1,
                scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep},
            )
            register_ad_pattern(
                search_fn=_grouped_attn_pattern_2,
                replace_fn=_grouped_attn_replacement_2,
                patterns=patterns,
                dummy_args=dummy_args_2,
                scalar_workaround={
                    "scale": scale,
                    "dropout_p": dropout,
                },
            )
            register_ad_pattern(
                search_fn=_grouped_attn_pattern_3,
                replace_fn=_grouped_attn_replacement_3,
                patterns=patterns,
                dummy_args=dummy_args_1,
                scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep},
            )
            register_ad_pattern(
                search_fn=_grouped_attn_pattern_4,
                replace_fn=_grouped_attn_replacement_4,
                patterns=patterns,
                dummy_args=dummy_args_2,
                scalar_workaround={
                    "scale": scale,
                    "dropout_p": dropout,
                },
            )
            register_ad_pattern(
                search_fn=_grouped_attn_pattern_5,
                replace_fn=_grouped_attn_replacement_5,
                patterns=patterns,
                dummy_args=dummy_args_3,
                scalar_workaround={"n_rep": n_rep},
            )

            register_ad_pattern(
                search_fn=_grouped_attn_pattern_6,
                replace_fn=_grouped_attn_replacement_6,
                patterns=patterns,
                dummy_args=dummy_args_2,
                scalar_workaround={"scale": scale, "dropout_p": dropout},
            )
            register_ad_pattern(
                search_fn=_grouped_attn_pattern_7,
                replace_fn=_grouped_attn_replacement_7,
                patterns=patterns,
                dummy_args=dummy_args_2,
                scalar_workaround={"scale": scale, "dropout_p": dropout},
            )
            register_ad_pattern(
                search_fn=_grouped_attn_pattern_8,
                replace_fn=_grouped_attn_replacement_8,
                patterns=patterns,
                dummy_args=dummy_args_4,
                scalar_workaround={"scale": scale, "dropout_p": dropout},
            )
            register_ad_pattern(
                search_fn=_grouped_attn_pattern_9,
                replace_fn=_grouped_attn_replacement_9,
                patterns=patterns,
                dummy_args=dummy_args_4,
                scalar_workaround={"scale": scale, "dropout_p": dropout},
            )
            register_ad_pattern(
                search_fn=_grouped_attn_pattern_10,
                replace_fn=_grouped_attn_replacement_10,
                patterns=patterns,
                dummy_args=dummy_args_5,
                scalar_workaround={"dropout_p": dropout, "n_rep": n_rep},
            )
            return generate_and_register_grouped_attn_patterns(patterns, register_ad_pattern)

        num_grouped_patterns = _apply_pattern(gm, "Grouped Attention", register_grouped_attention)

        if num_grouped_patterns == 0:
            ad_logger.warning(
                "Fail to find any Group Attention Pattern, output or performance may be incorrect"
@@ -627,10 +558,146 @@ class MatchGroupedAttention(BaseTransform):
            is_clean=False,
            has_valid_shapes=False,
        )

        return gm, info


def _call_torch_attention(
    q, k, v, *, is_causal, layout, attn_mask=None, dropout_p=None, scale=None
):
    kwargs = {"is_causal": is_causal, "layout": layout}
    if attn_mask is not None:
        kwargs["attn_mask"] = attn_mask
    if dropout_p is not None:
        kwargs["dropout_p"] = dropout_p
    if scale is not None:
        kwargs["scale"] = scale
    return torch.ops.auto_deploy.torch_attention.default(q, k, v, **kwargs)


def make_attn_bnsd_pair(
    *,
    has_attn_mask: bool,
    has_dropout: bool,
    is_causal: bool,
    has_scale: bool,
) -> Tuple[Callable, Callable, List[str], str, str]:
    argnames: List[str] = ["q", "k", "v"]
    if has_attn_mask:
        argnames.append("attn_mask")
    if has_dropout:
        argnames.append("dropout_p")
    if has_scale:
        argnames.append("scale")

    def pattern_fn(*args):
        if len(args) != len(argnames):
            raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
        m = dict(zip(argnames, args))
        return _call_torch_attention(
            m["q"],
            m["k"],
            m["v"],
            is_causal=is_causal,
            layout="bnsd",
            attn_mask=m.get("attn_mask"),
            dropout_p=m.get("dropout_p"),
            scale=m.get("scale"),
        )

    def replacement_fn(*args):
        if len(args) != len(argnames):
            raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
        m = dict(zip(argnames, args))
        q_b = torch.ops.aten.transpose.int(m["q"], 1, 2)
        k_b = torch.ops.aten.transpose.int(m["k"], 1, 2)
        v_b = torch.ops.aten.transpose.int(m["v"], 1, 2)
        out_b = _call_torch_attention(
            q_b,
            k_b,
            v_b,
            is_causal=is_causal,
            layout="bsnd",
            attn_mask=m.get("attn_mask"),
            dropout_p=m.get("dropout_p"),
            scale=m.get("scale"),
        )
        return torch.ops.aten.transpose.int(out_b, 1, 2)

    # Pattern matcher needs to see explicit arg names
    _attach_signature(pattern_fn, argnames)
    _attach_signature(replacement_fn, argnames)

    return pattern_fn, replacement_fn, argnames


def generate_and_register_attn_layout_patterns(patterns, register_ad_pattern: Callable):
    """
    Enumerate all combinations across:
      - has_attn_mask in {False, True}
      - has_dropout in {False, True}
      - is_causal in {False, True}
      - has_scale in {False, True}
    Register each pattern/replacement with appropriate dummy args and scalar workarounds.
    """
    # Dummy tensors in BNSD
    bs, n_heads, s_q, head_dim = 8, 8, 16, 64
    q = torch.randn(bs, n_heads, s_q, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn(bs, n_heads, s_q, head_dim, device="cuda", dtype=torch.float16)
    v = torch.randn(bs, n_heads, s_q, head_dim, device="cuda", dtype=torch.float16)
    attn_mask = torch.randn(bs, n_heads, 1, s_q, device="cuda", dtype=torch.float16)

    dropout_p = 0.12345
    scale_val = 0.56789

    total = 0
    axes = ((False, True),) * 4
    for has_attn_mask, has_dropout, is_causal, has_scale in product(*axes):
        pat_fn, rep_fn, argnames = make_attn_bnsd_pair(
            has_attn_mask=has_attn_mask,
            has_dropout=has_dropout,
            is_causal=is_causal,
            has_scale=has_scale,
        )

        # Build dummy args following positional signature
        value_map = {
            "q": q,
            "k": k,
            "v": v,
            "attn_mask": attn_mask,
            "dropout_p": dropout_p,
            "scale": scale_val,
        }
        dummy_args: List[object] = []
        for name in argnames:
            try:
                dummy_args.append(value_map[name])
            except KeyError:
                raise RuntimeError(f"Unexpected arg name: {name}")

        # Scalar workaround for present scalars only
        scalar_names = {"dropout_p", "scale"}
        scalar_workaround: Dict[str, object] = {
            n: value_map[n] for n in argnames if n in scalar_names
        }
        if not scalar_workaround:
            scalar_workaround = None

        register_ad_pattern(
            search_fn=pat_fn,
            replace_fn=rep_fn,
            patterns=patterns,
            dummy_args=dummy_args,
            scalar_workaround=scalar_workaround,
        )
        total += 1
    return total


def register_match_attn_layout(patterns: ADPatternMatcherPass):
    return generate_and_register_attn_layout_patterns(patterns, register_ad_pattern)


class MatchAttentionLayoutConfig(TransformConfig):
    """Configuration for the insert cached attention transform."""
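Similarly, generate_and_register_attn_layout_patterns above walks four boolean axes, so 2**4 = 16 bnsd-to-bsnd rewrite patterns are registered. A minimal sketch (not from the diff):

from itertools import product

layout_axes = ("has_attn_mask", "has_dropout", "is_causal", "has_scale")
combos = list(product((False, True), repeat=len(layout_axes)))
print(len(combos))  # 16 pattern/replacement pairs for the layout rewrite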
@@ -640,13 +707,8 @@ class MatchAttentionLayoutConfig(TransformConfig):
@TransformRegistry.register("match_attention_layout")
class MatchAttentionLayout(BaseTransform):
    """
    Match and transform attention operations to match the layout expected by the attention backend.

    If the attention backend expects 'bnsd' layout (batch, num_heads, seq_len, head_dim), which
    is the default for SDPA operations, we don't need to transform anything.

    If the backend expects 'bsnd' layout (batch, seq_len, num_heads, head_dim), we insert
    appropriate transposes before and after SDPA operations and replace them with bsnd_grouped_sdpa.
    Convert unified torch_attention calls from layout='bnsd' (explicit, positional or default)
    into layout='bsnd' + correct Q/K/V transposes, and transpose the output back to bnsd.
    """

    config: MatchAttentionLayoutConfig
@@ -662,82 +724,26 @@ class MatchAttentionLayout(BaseTransform):
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        # Get attention layout from attention_op
        attention_layout = self.config.attention_op.get_attention_layout()

        # List of SDPA operations to look for
        sdpa_ops = {
            torch.ops.auto_deploy.torch_attention_grouped_sdpa,
        }
        if attention_layout not in ("bnsd", "bsnd"):
            raise ValueError(f"Unsupported attention layout: {attention_layout}")

        graph = gm.graph
        num_bsnd_patterns = 0
        # If backend expects bnsd, nothing to do.
        if attention_layout == "bnsd":
            return gm, TransformInfo(
                skipped=False, num_matches=0, is_clean=False, has_valid_shapes=False
            )

        # Look for SDPA operations
        for sdpa_node in list(graph.nodes):
            if sdpa_node.op != "call_function" or not is_op(sdpa_node, sdpa_ops):
                continue

            ad_logger.debug(f"Found SDPA node to transform for bsnd layout: {sdpa_node}")

            # Extract q, k, v inputs
            q, k, v = sdpa_node.args[:3]

            # Check if we need to transpose the inputs
            if attention_layout == "bsnd":
                # Add transposes before the node (from bnsd to bsnd)
                with graph.inserting_before(sdpa_node):
                    q_updated = graph.call_function(torch.ops.aten.transpose.int, args=(q, 1, 2))
                    k_updated = graph.call_function(torch.ops.aten.transpose.int, args=(k, 1, 2))
                    v_updated = graph.call_function(torch.ops.aten.transpose.int, args=(v, 1, 2))

                # Preserve fake tensor in meta["val"] for the transposed inputs
                q_updated.meta["val"] = q.meta["val"].transpose(1, 2)
                k_updated.meta["val"] = k.meta["val"].transpose(1, 2)
                v_updated.meta["val"] = v.meta["val"].transpose(1, 2)
            elif attention_layout == "bnsd":
                # we don't need to do anything...
                q_updated = q
                k_updated = k
                v_updated = v
            else:
                raise ValueError(f"Unsupported attention layout: {attention_layout}")

            # Create bsnd_grouped_sdpa node with the same args as the original node
            # but using the transposed inputs
            with graph.inserting_before(sdpa_node):
                source_sdpa_node = graph.call_function(
                    self.config.attention_op.get_source_attention_op(),
                    args=(q_updated, k_updated, v_updated) + sdpa_node.args[3:],
                    kwargs=sdpa_node.kwargs,
                )

            # Check if need to update the output node to match the layout
            if attention_layout == "bsnd":
                # Add transpose for the output (from bsnd back to bnsd)
                with graph.inserting_after(source_sdpa_node):
                    output_updated = graph.call_function(
                        torch.ops.aten.transpose.int, args=(source_sdpa_node, 1, 2)
                    )

                # Preserve fake tensor in meta["val"] for the transposed inputs
                source_sdpa_node.meta["val"] = sdpa_node.meta["val"].transpose(1, 2).contiguous()
                output_updated.meta["val"] = source_sdpa_node.meta["val"].transpose(1, 2)
            elif attention_layout == "bnsd":
                output_updated = source_sdpa_node
            else:
                raise ValueError(f"Unsupported attention layout: {attention_layout}")

            # Replace the old node with the transposed output
            sdpa_node.replace_all_uses_with(output_updated)

            num_bsnd_patterns += 1
        num_matches = _apply_pattern(
            gm, "MatchAttentionLayout(bnsd→bsnd)", register_match_attn_layout
        )

        # If we changed any attention calls, the shapes may change around the transposes; flag for shape prop.
        info = TransformInfo(
            skipped=False,
            num_matches=num_bsnd_patterns,
            num_matches=num_matches,
            is_clean=False,
            has_valid_shapes=False,
        )

        return gm, info
@@ -60,8 +60,9 @@ def fake_profiler_mha(
    v_fake = graph.placeholder(name="v_fake")
    v_fake.meta["val"] = torch.empty_like(value.transpose(1, 2), device="meta", dtype=value.dtype)

    node_kwargs["layout"] = "bsnd"
    module._node_ref = graph.call_function(
        torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa,
        torch.ops.auto_deploy.torch_attention,
        args=(q_fake, k_fake, v_fake),
        kwargs=node_kwargs,
    )

@@ -479,8 +479,7 @@ def detect_column_row_shard(
    # acceptable attention nodes between sharded GEMMs
    shardable_attention_nodes = {
        torch.ops.auto_deploy.torch_attention_sdpa,
        torch.ops.auto_deploy.torch_attention_grouped_sdpa,
        torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa,
        torch.ops.auto_deploy.torch_attention,
    }

    # This is a heuristic. Basically, we assume those are okay to shard if we also encounter an
@@ -90,7 +90,7 @@ class GQA_Block(nn.Module):
        k = self.k_proj(x).view(b, s, -1, self.head_dim)
        v = self.v_proj(x).view(b, s, -1, self.head_dim)

        y = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(q, k, v, is_causal=True)
        y = torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True, layout="bsnd")
        y = y.contiguous().view(b, s, -1)

        return self.o_proj(y)
@@ -741,14 +741,12 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask):
    x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
    dynamic_shapes = model.get_dynamic_shapes()

    # We should find 1 instance of torch_attention_grouped_sdpa
    # We should find 1 instance of torch_attention
    expected_matches = 1

    def verify_matcher(gm):
        grouped_sdpa_nodes = [
            n
            for n in gm.graph.nodes
            if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
            n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention)
        ]

        # Check that we have the expected number of replacements

@@ -879,7 +877,7 @@ class CausalAttentionModel(torch.nn.Module):

        # Choose the appropriate attention implementation
        if self.use_grouped_sdpa:
            attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa(
            attn_output = torch.ops.auto_deploy.torch_attention(
                q,
                k,
                v,
@@ -985,7 +983,7 @@ class Llama3CausalAttentionModel(torch.nn.Module):

        # Choose the appropriate attention implementation
        if self.use_grouped_sdpa:
            attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa(
            attn_output = torch.ops.auto_deploy.torch_attention(
                q,
                k,
                v,

@@ -1089,7 +1087,7 @@ class AttentionLayoutModel(torch.nn.Module):

        # Apply scaled dot product attention
        if self.use_grouped_sdpa:
            attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa(
            attn_output = torch.ops.auto_deploy.torch_attention(
                q,
                k,
                v,

@@ -1136,7 +1134,7 @@ class BsndAttentionModel(AttentionLayoutModel):
        attn_mask = self._get_attn_mask(x) if self.has_mask else None

        # Apply bsnd_grouped_sdpa directly
        attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa.default(
        attn_output = torch.ops.auto_deploy.torch_attention.default(
            q,
            k,
            v,

@@ -1144,6 +1142,7 @@ class BsndAttentionModel(AttentionLayoutModel):
            dropout_p=0.0,
            is_causal=True,
            scale=1.0 / (self.head_dim**0.5),
            layout="bsnd",
        )

        # Reshape output for the linear projection (no transpose needed)
@@ -1173,11 +1172,11 @@ def test_match_attention_layout(layout, model_config, has_mask):
    MockAttentionDescriptor.layout = layout
    if layout == "bnsd":
        if model_config.get("use_grouped_sdpa"):
            source_op = torch.ops.auto_deploy.torch_attention_grouped_sdpa
            source_op = torch.ops.auto_deploy.torch_attention
        else:
            source_op = torch.ops.auto_deploy.torch_attention_sdpa
    else:
        source_op = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
        source_op = torch.ops.auto_deploy.torch_attention
    MockAttentionDescriptor.source_attention_op = source_op

    # Create appropriate model based on model_config

@@ -1210,7 +1209,8 @@ def test_match_attention_layout(layout, model_config, has_mask):
        original_nodes = [
            n
            for n in gm.graph.nodes
            if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
            if is_op(n, torch.ops.auto_deploy.torch_attention)
            and not (isinstance(n.args[-1], str) and n.args[-1] == "bsnd")
        ]
    else:
        original_nodes = [

@@ -1224,7 +1224,11 @@ def test_match_attention_layout(layout, model_config, has_mask):
    bsnd_nodes = [
        n
        for n in gm.graph.nodes
        if is_op(n, torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa)
        if (
            is_op(n, torch.ops.auto_deploy.torch_attention)
            and isinstance(n.args[-1], str)
            and n.args[-1] == "bsnd"
        )
    ]
    transpose_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.aten.transpose.int)]
@@ -16,6 +16,7 @@ from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import Atten
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op

torch.manual_seed(0)

@@ -31,7 +32,7 @@ class MockAttentionDescriptor(AttentionDescriptor):

    @classmethod
    def get_source_attention_op(cls) -> Callable:
        return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
        return torch.ops.auto_deploy.torch_attention


class HFWrapper(nn.Module):

@@ -82,12 +83,18 @@ def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str)
        pytest.skip("https://nvbugspro.nvidia.com/bug/5170222")

    def verify_matcher(gm: GraphModule):
        """Ensure that there is exactly one torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
        """Ensure that there is exactly one torch.ops.auto_deploy.torch_attention (layout="bsnd")
        call in the graph. Also check that there is no repeat_kv pattern left.
        """
        nodes = gm.graph.find_nodes(
            op="call_function", target=torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
        )
        nodes = [
            n
            for n in gm.graph.nodes
            if (
                is_op(n, torch.ops.auto_deploy.torch_attention)
                and isinstance(n.args[-1], str)
                and n.args[-1] == "bsnd"
            )
        ]
        assert len(nodes) == 1, "Expected exactly one bsnd_grouped_sdpa call in the graph"

        # TODO: check non-qkv args of node
@@ -74,8 +74,18 @@ class GQAWithSdpa(GQA):
        v = v.view(b, s, self.num_kv_heads, self.head_dim)

        # Use grouped SDPA in bsnd layout
        attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(
            q, k, v, None, 0.0, True, None
        attn_output = torch.ops.auto_deploy.torch_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=True,
            scale=None,
            sinks=None,
            sliding_window=None,
            logit_cap=None,
            layout="bsnd",
        )

        # SDPA output is already in [b, s, n, h_d] format