[None][autodeploy] small refactors on attention matching (#8079)

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
This commit is contained in:
Frida Hou 2025-10-03 22:00:27 -07:00 committed by GitHub
parent 88ea2c4ee9
commit 744246d316
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 462 additions and 418 deletions

View File

@ -12,10 +12,9 @@ The table below lists the operators ordered by their backend.
|--------------|-------------|
| `torch.ops.auto_deploy.flashinfer_attention_mha_with_cache` | FlashInfer attention with KV cache support |
| `torch.ops.auto_deploy.flashinfer_rope` | FlashInfer RoPE (Rotary Position Embedding) implementation |
| `torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa` | Grouped SDPA (Scaled Dot Product Attention) with BSND format |
| `torch.ops.auto_deploy.torch_attention_deepseek_fused_mla` | DeepSeek fused MLA (Multi-head Linear Attention) |
| `torch.ops.auto_deploy.torch_attention_deepseek_mla` | DeepSeek MLA implementation |
| `torch.ops.auto_deploy.torch_attention_grouped_sdpa` | Grouped SDPA implementation |
| `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layout supported |
| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation |

View File

@ -355,7 +355,7 @@ class FlashInferAttention(AttentionDescriptor):
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
"""Get the source attention op that we target for replacement."""
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
return torch.ops.auto_deploy.torch_attention
@classmethod
def get_cached_attention_op(cls) -> MHACallable:
@ -399,6 +399,21 @@ class FlashInferAttention(AttentionDescriptor):
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
# Sanity check: layout == "bsnd"
# Prefer kwargs; fall back to the final positional arg if it's a string.
layout = source_attn_node.kwargs.get("layout", None)
if (
layout is None
and len(source_attn_node.args) > 0
and isinstance(source_attn_node.args[-1], str)
):
layout = source_attn_node.args[-1]
if layout != "bsnd":
raise RuntimeError(
f"Expected torch_attention layout='bsnd' but got {layout!r} "
f"for node: {source_attn_node.format_node()}"
)
# Double check other arguments
attn_mask, dropout_p, is_causal = extract_op_args(
source_attn_node, "attn_mask", "dropout_p", "is_causal"

View File

@ -91,8 +91,9 @@ def scaled_dot_product_attention_fake(
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
@torch.library.custom_op("auto_deploy::torch_attention_grouped_sdpa", mutates_args=())
def grouped_sdpa(
# Unified attention op
@torch.library.custom_op("auto_deploy::torch_attention", mutates_args=())
def torch_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
@ -103,8 +104,25 @@ def grouped_sdpa(
sinks: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
logit_cap: Optional[float] = None,
layout: str = "bnsd", # "bnsd" or "bsnd"
) -> torch.Tensor:
"""SDPA attention that can handle GQA. Expects bnsd format inputs."""
"""
SDPA attention (with optional GQA) that supports two memory layouts via `layout`:
- "bnsd": [batch, num_heads, seq_len, head_dim]
- "bsnd": [batch, seq_len, num_heads, head_dim]
The `attn_mask` is always interpreted as [b, n, s_q, s_k].
Returns a tensor in the SAME layout as inputs specified by `layout`.
"""
if layout not in ("bnsd", "bsnd"):
raise ValueError(f"layout must be 'bnsd' or 'bsnd', got {layout!r}")
if layout == "bsnd":
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
b, n_heads, s_q, head_dim = query.shape # bnsd format: [batch, num_heads, seq_len, head_dim]
_, n_kv_heads, s_k, _ = key.shape # bnsd format: [batch, num_kv_heads, seq_len, head_dim]
@ -188,72 +206,26 @@ def grouped_sdpa(
if dropout_p > 0.0:
attn_out = F.dropout(attn_out, p=dropout_p, training=False)
# Return in bnsd format (same as input format)
return attn_out
if layout == "bsnd":
return attn_out.transpose(1, 2).contiguous()
else:
return attn_out.contiguous()
@grouped_sdpa.register_fake
def grouped_sdpa_fake(
@torch_attention.register_fake
def torch_attention_fake(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
sinks=None,
sliding_window=None,
logit_cap=None,
):
"""Fake implementation of grouped SDPA."""
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=())
def bsnd_grouped_sdpa(
query: torch.Tensor, # layout: [b, s_q, n, d]
key: torch.Tensor, # layout: [b, s_k, n, d]
value: torch.Tensor, # layout: [b, s_k, n, d]
attn_mask: Optional[torch.Tensor] = None, # layout: [b, n, s_q, s_k]
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
sinks: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
logit_cap: Optional[float] = None,
) -> torch.Tensor:
"""Attention that assumes the input layout is bsnd.
Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the
original sdpa op!
"""
# Transpose inputs to bnsd format for grouped_sdpa
query = query.transpose(1, 2).contiguous() # [b, s_q, n, d] -> [b, n, s_q, d]
key = key.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d]
value = value.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d]
# Call grouped_sdpa with bnsd inputs
out = grouped_sdpa(
query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap
)
# Transpose back to bsnd format
return out.transpose(1, 2).contiguous()
@bsnd_grouped_sdpa.register_fake
def bsnd_grouped_sdpa_fake(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
sinks=None,
sliding_window=None,
logit_cap=None,
layout: str = "bnsd",
):
"""Fake implementation of bnsd grouped SDPA."""
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()

View File

@ -409,7 +409,7 @@ class TorchBackendAttention(AttentionDescriptor):
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
return torch.ops.auto_deploy.torch_attention
@classmethod
def get_cached_attention_op(cls) -> MHACallable:
@ -460,6 +460,21 @@ class TorchBackendAttention(AttentionDescriptor):
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
# Sanity check: layout == "bsnd"
# Prefer kwargs; fall back to the final positional arg if it's a string.
layout = source_attn_node.kwargs.get("layout", None)
if (
layout is None
and len(source_attn_node.args) > 0
and isinstance(source_attn_node.args[-1], str)
):
layout = source_attn_node.args[-1]
if layout != "bsnd":
raise RuntimeError(
f"Expected torch_attention layout='bsnd' but got {layout!r} "
f"for node: {source_attn_node.format_node()}"
)
# Check other arguments
attn_mask, dropout_p, is_causal = extract_op_args(
source_attn_node, "attn_mask", "dropout_p", "is_causal"

View File

@ -339,7 +339,7 @@ class TritonAttention(AttentionDescriptor):
@classmethod
def get_source_attention_op(cls) -> OpOverloadPacket:
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
return torch.ops.auto_deploy.torch_attention
@classmethod
def get_cached_attention_op(cls) -> MHACallable:
@ -390,6 +390,21 @@ class TritonAttention(AttentionDescriptor):
@classmethod
def get_constants(cls, source_attn_node: Node) -> List[Constant]:
# Sanity check: layout == "bsnd"
# Prefer kwargs; fall back to the final positional arg if it's a string.
layout = source_attn_node.kwargs.get("layout", None)
if (
layout is None
and len(source_attn_node.args) > 0
and isinstance(source_attn_node.args[-1], str)
):
layout = source_attn_node.args[-1]
if layout != "bsnd":
raise RuntimeError(
f"Expected torch_attention layout='bsnd' but got {layout!r} "
f"for node: {source_attn_node.format_node()}"
)
# retrieve head_dim from k_fake
attn_mask, dropout_p, is_causal = extract_op_args(
source_attn_node, "attn_mask", "dropout_p", "is_causal"

View File

@ -59,7 +59,7 @@ def gpt_oss_attention(
sinks = self.sinks
# Use custom op to capture attention. This layout is bsnd (batch, seq, num_heads, head_dim)
attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(
attn_output = torch.ops.auto_deploy.torch_attention(
query_states,
key_states,
value_states,
@ -69,6 +69,7 @@ def gpt_oss_attention(
scale=self.scaling,
sinks=sinks,
sliding_window=sliding_window,
layout="bsnd",
)
# Reshape back to original input shape

View File

@ -1,5 +1,7 @@
"""Pattern matching for detecting repeat_kv, eager, grouped attention patterns from Huggingface models."""
from inspect import Parameter, Signature
from itertools import product
from typing import Any, Callable, Dict, List, Tuple, Type
import torch
@ -11,7 +13,6 @@ from ...custom_ops.attention_interface import AttentionDescriptor
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ...utils.node_utils import is_op
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ..interface import (
BaseTransform,
@ -270,170 +271,6 @@ def _get_sfdp_patterns() -> List[Dict[str, Any]]:
return patterns
def _grouped_attn_pattern_1(q, k, v, n_rep, attn_mask, dropout_p, scale):
k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
)
def _grouped_attn_replacement_1(q, k, v, n_rep, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
)
# Only expose torch_attention_grouped_sdpa after the transformation
def _grouped_attn_pattern_2(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
)
def _grouped_attn_replacement_2(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
)
def _grouped_attn_pattern_3(q, k, v, n_rep, attn_mask, dropout_p, scale):
k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
)
def _grouped_attn_replacement_3(q, k, v, n_rep, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
)
# Only expose torch_attention_grouped_sdpa after the transformation
def _grouped_attn_pattern_4(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
)
def _grouped_attn_replacement_4(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
)
def _grouped_attn_pattern_5(q, k, v, n_rep, attn_mask):
k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
return torch.ops.auto_deploy.torch_attention_sdpa.default(q, k, v, attn_mask)
def _grouped_attn_replacement_5(q, k, v, n_rep, attn_mask):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(q, k, v, attn_mask)
def _grouped_attn_pattern_6(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=False,
scale=scale,
enable_gqa=True,
)
def _grouped_attn_replacement_6(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
)
def _grouped_attn_pattern_7(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=True,
scale=scale,
enable_gqa=True,
)
def _grouped_attn_replacement_7(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
)
def _grouped_attn_pattern_8(q, k, v, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q,
k,
v,
attn_mask=None,
dropout_p=dropout_p,
is_causal=False,
scale=scale,
enable_gqa=True,
)
def _grouped_attn_replacement_8(q, k, v, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=False, scale=scale
)
def _grouped_attn_pattern_9(q, k, v, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q,
k,
v,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True,
scale=scale,
enable_gqa=True,
)
def _grouped_attn_replacement_9(q, k, v, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True, scale=scale
)
def _grouped_attn_pattern_10(q, k, v, n_rep, dropout_p):
k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q,
k,
v,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True,
)
def _grouped_attn_replacement_10(q, k, v, n_rep, dropout_p):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q,
k,
v,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True,
)
@TransformRegistry.register("match_repeat_kv")
class MatchRepeatKV(BaseTransform):
"""
@ -508,11 +345,194 @@ class MatchEagerAttention(BaseTransform):
return gm, info
def _attach_signature(fn: Callable, argnames: List[str]) -> Callable:
# Make FX "see" q,k,v[,attn_mask][,dropout_p][,scale] even though fn(*args) internally
params = [Parameter(n, kind=Parameter.POSITIONAL_OR_KEYWORD) for n in argnames]
fn.__signature__ = Signature(parameters=params)
return fn
def _call_sdpa(
q, k, v, *, is_causal: bool, enable_gqa: bool, attn_mask=None, dropout_p=None, scale=None
):
kwargs = {"is_causal": is_causal}
if attn_mask is not None:
kwargs["attn_mask"] = attn_mask
if dropout_p is not None:
kwargs["dropout_p"] = dropout_p
if scale is not None:
kwargs["scale"] = scale
if enable_gqa:
kwargs["enable_gqa"] = True
return torch.ops.auto_deploy.torch_attention_sdpa.default(q, k, v, **kwargs)
def _call_attn(q, k, v, *, is_causal: bool, attn_mask=None, dropout_p=None, scale=None):
kwargs = {"is_causal": is_causal}
if attn_mask is not None:
kwargs["attn_mask"] = attn_mask
if dropout_p is not None:
kwargs["dropout_p"] = dropout_p
if scale is not None:
kwargs["scale"] = scale
return torch.ops.auto_deploy.torch_attention.default(q, k, v, **kwargs)
def make_grouped_attn_pair(
*,
repeat_kv: bool,
is_causal: bool,
has_scale: bool,
enable_gqa: bool,
has_attn_mask: bool,
has_dropout: bool,
) -> Tuple[Callable, Callable, List[str]]:
"""
Returns (pattern_fn, replacement_fn, argnames) with exact positional parity.
Arg order rules:
Base: (q, k, v)
+repeat_kv -> insert n_rep after (q, k, v)
+attn_mask -> include attn_mask after n_rep if repeat_kv else after (q, k, v)
+dropout -> include dropout_p after attn_mask or after n_rep/base if no attn_mask
+scale -> include scale last
"""
argnames: List[str] = ["q", "k", "v"]
if repeat_kv:
argnames.append("n_rep")
if has_attn_mask:
argnames.append("attn_mask")
if has_dropout:
argnames.append("dropout_p")
if has_scale:
argnames.append("scale")
def pattern_fn(*args):
if len(args) != len(argnames):
raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
m = dict(zip(argnames, args))
q = m["q"]
k = m["k"]
v = m["v"]
if repeat_kv:
n_rep = m["n_rep"]
k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
return _call_sdpa(
q,
k,
v,
is_causal=is_causal,
enable_gqa=enable_gqa,
attn_mask=m.get("attn_mask"),
dropout_p=m.get("dropout_p"),
scale=m.get("scale"),
)
# Replacement: torch_attention.default mirroring the positional signature exactly.
# We do NOT pass enable_gqa here (its SDPA-only). We accept n_rep to mirror signature,
# but we dont need to use it in the replacement graph.
def replacement_fn(*args):
if len(args) != len(argnames):
raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
m = dict(zip(argnames, args))
return _call_attn(
m["q"],
m["k"],
m["v"],
is_causal=is_causal,
attn_mask=m.get("attn_mask"),
dropout_p=m.get("dropout_p"),
scale=m.get("scale"),
)
# Pattern matcher needs to see explicit arg names
_attach_signature(pattern_fn, argnames)
_attach_signature(replacement_fn, argnames)
return pattern_fn, replacement_fn, argnames
def generate_and_register_grouped_attn_patterns(patterns, register_ad_pattern: Callable):
"""
Auto-generate all grouped attention patterns across these axes:
1) repeat_kv: [False, True]
2) is_causal: [False, True]
3) has_scale: [False, True]
4) enable_gqa: [False, True] (only a kwarg to SDPA side)
5) has_attn_mask: [False, True]
6) has_dropout: [False, True]
For each valid combo, we:
- build pattern/replacement functions with exact-arg parity
- build dummy args matching the signature (with CUDA fp16 tensors etc.)
- build scalar_workaround dict for any scalars/n_rep present
- call register_ad_pattern(...)
"""
q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
v1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
attn_mask_tensor = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16)
dropout_val = 0.12345
scale_val = 0.56789
n_rep_val = 7
total = 0
axes = ((False, True),) * 6
for repeat_kv, is_causal, has_scale, enable_gqa, has_attn_mask, has_dropout in product(*axes):
pat_fn, rep_fn, argnames = make_grouped_attn_pair(
repeat_kv=repeat_kv,
is_causal=is_causal,
has_scale=has_scale,
enable_gqa=enable_gqa,
has_attn_mask=has_attn_mask,
has_dropout=has_dropout,
)
# Build dummy args in the same positional order
value_map = {
"q": q,
"k": k1,
"v": v1,
"n_rep": n_rep_val,
"attn_mask": attn_mask_tensor,
"dropout_p": dropout_val,
"scale": scale_val,
}
dummy_args: List[object] = []
for name in argnames:
try:
dummy_args.append(value_map[name])
except KeyError:
raise RuntimeError(f"Unexpected arg name: {name}")
scalar_names = {"n_rep", "dropout_p", "scale"}
scalar_workaround: Dict[str, object] = {
n: value_map[n] for n in argnames if n in scalar_names
}
if not scalar_workaround:
scalar_workaround = None
register_ad_pattern(
search_fn=pat_fn,
replace_fn=rep_fn,
patterns=patterns,
dummy_args=dummy_args,
scalar_workaround=scalar_workaround,
)
total += 1
return total
@TransformRegistry.register("match_grouped_attention")
class MatchGroupedAttention(BaseTransform):
"""
Match and replace the grouped attention pattern with
torch.ops.auto_deploy.torch_attention_grouped_sdpa.
torch.ops.auto_deploy.torch_attention.
"""
def _apply(
@ -523,99 +543,10 @@ class MatchGroupedAttention(BaseTransform):
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
def register_grouped_attention(patterns: ADPatternMatcherPass):
q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
v1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
attn_mask = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16)
dropout = 0.12345
scale = 0.56789
n_rep = 7
dummy_args_1 = [q, k1, v1, n_rep, attn_mask, dropout, scale]
dummy_args_2 = [q, k1, v1, attn_mask, dropout, scale]
dummy_args_3 = [q, k1, v1, n_rep, attn_mask]
dummy_args_4 = [q, k1, v1, dropout, scale]
dummy_args_5 = [q, k1, v1, n_rep, dropout]
register_ad_pattern(
search_fn=_grouped_attn_pattern_1,
replace_fn=_grouped_attn_replacement_1,
patterns=patterns,
dummy_args=dummy_args_1,
scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_2,
replace_fn=_grouped_attn_replacement_2,
patterns=patterns,
dummy_args=dummy_args_2,
scalar_workaround={
"scale": scale,
"dropout_p": dropout,
},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_3,
replace_fn=_grouped_attn_replacement_3,
patterns=patterns,
dummy_args=dummy_args_1,
scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_4,
replace_fn=_grouped_attn_replacement_4,
patterns=patterns,
dummy_args=dummy_args_2,
scalar_workaround={
"scale": scale,
"dropout_p": dropout,
},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_5,
replace_fn=_grouped_attn_replacement_5,
patterns=patterns,
dummy_args=dummy_args_3,
scalar_workaround={"n_rep": n_rep},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_6,
replace_fn=_grouped_attn_replacement_6,
patterns=patterns,
dummy_args=dummy_args_2,
scalar_workaround={"scale": scale, "dropout_p": dropout},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_7,
replace_fn=_grouped_attn_replacement_7,
patterns=patterns,
dummy_args=dummy_args_2,
scalar_workaround={"scale": scale, "dropout_p": dropout},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_8,
replace_fn=_grouped_attn_replacement_8,
patterns=patterns,
dummy_args=dummy_args_4,
scalar_workaround={"scale": scale, "dropout_p": dropout},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_9,
replace_fn=_grouped_attn_replacement_9,
patterns=patterns,
dummy_args=dummy_args_4,
scalar_workaround={"scale": scale, "dropout_p": dropout},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_10,
replace_fn=_grouped_attn_replacement_10,
patterns=patterns,
dummy_args=dummy_args_5,
scalar_workaround={"dropout_p": dropout, "n_rep": n_rep},
)
return generate_and_register_grouped_attn_patterns(patterns, register_ad_pattern)
num_grouped_patterns = _apply_pattern(gm, "Grouped Attention", register_grouped_attention)
if num_grouped_patterns == 0:
ad_logger.warning(
"Fail to find any Group Attention Pattern, output or performance may be incorrect"
@ -627,10 +558,146 @@ class MatchGroupedAttention(BaseTransform):
is_clean=False,
has_valid_shapes=False,
)
return gm, info
def _call_torch_attention(
q, k, v, *, is_causal, layout, attn_mask=None, dropout_p=None, scale=None
):
kwargs = {"is_causal": is_causal, "layout": layout}
if attn_mask is not None:
kwargs["attn_mask"] = attn_mask
if dropout_p is not None:
kwargs["dropout_p"] = dropout_p
if scale is not None:
kwargs["scale"] = scale
return torch.ops.auto_deploy.torch_attention.default(q, k, v, **kwargs)
def make_attn_bnsd_pair(
*,
has_attn_mask: bool,
has_dropout: bool,
is_causal: bool,
has_scale: bool,
) -> Tuple[Callable, Callable, List[str], str, str]:
argnames: List[str] = ["q", "k", "v"]
if has_attn_mask:
argnames.append("attn_mask")
if has_dropout:
argnames.append("dropout_p")
if has_scale:
argnames.append("scale")
def pattern_fn(*args):
if len(args) != len(argnames):
raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
m = dict(zip(argnames, args))
return _call_torch_attention(
m["q"],
m["k"],
m["v"],
is_causal=is_causal,
layout="bnsd",
attn_mask=m.get("attn_mask"),
dropout_p=m.get("dropout_p"),
scale=m.get("scale"),
)
def replacement_fn(*args):
if len(args) != len(argnames):
raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
m = dict(zip(argnames, args))
q_b = torch.ops.aten.transpose.int(m["q"], 1, 2)
k_b = torch.ops.aten.transpose.int(m["k"], 1, 2)
v_b = torch.ops.aten.transpose.int(m["v"], 1, 2)
out_b = _call_torch_attention(
q_b,
k_b,
v_b,
is_causal=is_causal,
layout="bsnd",
attn_mask=m.get("attn_mask"),
dropout_p=m.get("dropout_p"),
scale=m.get("scale"),
)
return torch.ops.aten.transpose.int(out_b, 1, 2)
# Pattern matcher needs to see explicit arg names
_attach_signature(pattern_fn, argnames)
_attach_signature(replacement_fn, argnames)
return pattern_fn, replacement_fn, argnames
def generate_and_register_attn_layout_patterns(patterns, register_ad_pattern: Callable):
"""
Enumerate all combinations across:
- has_attn_mask in {False, True}
- has_dropout in {False, True}
- is_causal in {False, True}
- has_scale in {False, True}
Register each pattern/replacement with appropriate dummy args and scalar workarounds.
"""
# Dummy tensors in BNSD
bs, n_heads, s_q, head_dim = 8, 8, 16, 64
q = torch.randn(bs, n_heads, s_q, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bs, n_heads, s_q, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bs, n_heads, s_q, head_dim, device="cuda", dtype=torch.float16)
attn_mask = torch.randn(bs, n_heads, 1, s_q, device="cuda", dtype=torch.float16)
dropout_p = 0.12345
scale_val = 0.56789
total = 0
axes = ((False, True),) * 4
for has_attn_mask, has_dropout, is_causal, has_scale in product(*axes):
pat_fn, rep_fn, argnames = make_attn_bnsd_pair(
has_attn_mask=has_attn_mask,
has_dropout=has_dropout,
is_causal=is_causal,
has_scale=has_scale,
)
# Build dummy args following positional signature
value_map = {
"q": q,
"k": k,
"v": v,
"attn_mask": attn_mask,
"dropout_p": dropout_p,
"scale": scale_val,
}
dummy_args: List[object] = []
for name in argnames:
try:
dummy_args.append(value_map[name])
except KeyError:
raise RuntimeError(f"Unexpected arg name: {name}")
# Scalar workaround for present scalars only
scalar_names = {"dropout_p", "scale"}
scalar_workaround: Dict[str, object] = {
n: value_map[n] for n in argnames if n in scalar_names
}
if not scalar_workaround:
scalar_workaround = None
register_ad_pattern(
search_fn=pat_fn,
replace_fn=rep_fn,
patterns=patterns,
dummy_args=dummy_args,
scalar_workaround=scalar_workaround,
)
total += 1
return total
def register_match_attn_layout(patterns: ADPatternMatcherPass):
return generate_and_register_attn_layout_patterns(patterns, register_ad_pattern)
class MatchAttentionLayoutConfig(TransformConfig):
"""Configuration for the insert cached attention transform."""
@ -640,13 +707,8 @@ class MatchAttentionLayoutConfig(TransformConfig):
@TransformRegistry.register("match_attention_layout")
class MatchAttentionLayout(BaseTransform):
"""
Match and transform attention operations to match the layout expected by the attention backend.
If the attention backend expects 'bnsd' layout (batch, num_heads, seq_len, head_dim), which
is the default for SDPA operations, we don't need to transform anything.
If the backend expects 'bsnd' layout (batch, seq_len, num_heads, head_dim), we insert
appropriate transposes before and after SDPA operations and replace them with bsnd_grouped_sdpa.
Convert unified torch_attention calls from layout='bnsd' (explicit, positional or default)
into layout='bsnd' + correct Q/K/V transposes, and transpose the output back to bnsd.
"""
config: MatchAttentionLayoutConfig
@ -662,82 +724,26 @@ class MatchAttentionLayout(BaseTransform):
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
# Get attention layout from attention_op
attention_layout = self.config.attention_op.get_attention_layout()
# List of SDPA operations to look for
sdpa_ops = {
torch.ops.auto_deploy.torch_attention_grouped_sdpa,
}
if attention_layout not in ("bnsd", "bsnd"):
raise ValueError(f"Unsupported attention layout: {attention_layout}")
graph = gm.graph
num_bsnd_patterns = 0
# If backend expects bnsd, nothing to do.
if attention_layout == "bnsd":
return gm, TransformInfo(
skipped=False, num_matches=0, is_clean=False, has_valid_shapes=False
)
# Look for SDPA operations
for sdpa_node in list(graph.nodes):
if sdpa_node.op != "call_function" or not is_op(sdpa_node, sdpa_ops):
continue
ad_logger.debug(f"Found SDPA node to transform for bsnd layout: {sdpa_node}")
# Extract q, k, v inputs
q, k, v = sdpa_node.args[:3]
# Check if we need to transpose the inputs
if attention_layout == "bsnd":
# Add transposes before the node (from bnsd to bsnd)
with graph.inserting_before(sdpa_node):
q_updated = graph.call_function(torch.ops.aten.transpose.int, args=(q, 1, 2))
k_updated = graph.call_function(torch.ops.aten.transpose.int, args=(k, 1, 2))
v_updated = graph.call_function(torch.ops.aten.transpose.int, args=(v, 1, 2))
# Preserve fake tensor in meta["val"] for the transposed inputs
q_updated.meta["val"] = q.meta["val"].transpose(1, 2)
k_updated.meta["val"] = k.meta["val"].transpose(1, 2)
v_updated.meta["val"] = v.meta["val"].transpose(1, 2)
elif attention_layout == "bnsd":
# we don't need to do anything...
q_updated = q
k_updated = k
v_updated = v
else:
raise ValueError(f"Unsupported attention layout: {attention_layout}")
# Create bsnd_grouped_sdpa node with the same args as the original node
# but using the transposed inputs
with graph.inserting_before(sdpa_node):
source_sdpa_node = graph.call_function(
self.config.attention_op.get_source_attention_op(),
args=(q_updated, k_updated, v_updated) + sdpa_node.args[3:],
kwargs=sdpa_node.kwargs,
)
# Check if need to update the output node to match the layout
if attention_layout == "bsnd":
# Add transpose for the output (from bsnd back to bnsd)
with graph.inserting_after(source_sdpa_node):
output_updated = graph.call_function(
torch.ops.aten.transpose.int, args=(source_sdpa_node, 1, 2)
)
# Preserve fake tensor in meta["val"] for the transposed inputs
source_sdpa_node.meta["val"] = sdpa_node.meta["val"].transpose(1, 2).contiguous()
output_updated.meta["val"] = source_sdpa_node.meta["val"].transpose(1, 2)
elif attention_layout == "bnsd":
output_updated = source_sdpa_node
else:
raise ValueError(f"Unsupported attention layout: {attention_layout}")
# Replace the old node with the transposed output
sdpa_node.replace_all_uses_with(output_updated)
num_bsnd_patterns += 1
num_matches = _apply_pattern(
gm, "MatchAttentionLayout(bnsd→bsnd)", register_match_attn_layout
)
# If we changed any attention calls, the shapes may change around the transposes; flag for shape prop.
info = TransformInfo(
skipped=False,
num_matches=num_bsnd_patterns,
num_matches=num_matches,
is_clean=False,
has_valid_shapes=False,
)
return gm, info

View File

@ -60,8 +60,9 @@ def fake_profiler_mha(
v_fake = graph.placeholder(name="v_fake")
v_fake.meta["val"] = torch.empty_like(value.transpose(1, 2), device="meta", dtype=value.dtype)
node_kwargs["layout"] = "bsnd"
module._node_ref = graph.call_function(
torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa,
torch.ops.auto_deploy.torch_attention,
args=(q_fake, k_fake, v_fake),
kwargs=node_kwargs,
)

View File

@ -479,8 +479,7 @@ def detect_column_row_shard(
# acceptable attention nodes between sharded GEMMs
shardable_attention_nodes = {
torch.ops.auto_deploy.torch_attention_sdpa,
torch.ops.auto_deploy.torch_attention_grouped_sdpa,
torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa,
torch.ops.auto_deploy.torch_attention,
}
# This is a heuristic. Basically, we assume those are okay to shard if we also encounter an

View File

@ -90,7 +90,7 @@ class GQA_Block(nn.Module):
k = self.k_proj(x).view(b, s, -1, self.head_dim)
v = self.v_proj(x).view(b, s, -1, self.head_dim)
y = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(q, k, v, is_causal=True)
y = torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True, layout="bsnd")
y = y.contiguous().view(b, s, -1)
return self.o_proj(y)

View File

@ -741,14 +741,12 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask):
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
dynamic_shapes = model.get_dynamic_shapes()
# We should find 1 instance of torch_attention_grouped_sdpa
# We should find 1 instance of torch_attention
expected_matches = 1
def verify_matcher(gm):
grouped_sdpa_nodes = [
n
for n in gm.graph.nodes
if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention)
]
# Check that we have the expected number of replacements
@ -879,7 +877,7 @@ class CausalAttentionModel(torch.nn.Module):
# Choose the appropriate attention implementation
if self.use_grouped_sdpa:
attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa(
attn_output = torch.ops.auto_deploy.torch_attention(
q,
k,
v,
@ -985,7 +983,7 @@ class Llama3CausalAttentionModel(torch.nn.Module):
# Choose the appropriate attention implementation
if self.use_grouped_sdpa:
attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa(
attn_output = torch.ops.auto_deploy.torch_attention(
q,
k,
v,
@ -1089,7 +1087,7 @@ class AttentionLayoutModel(torch.nn.Module):
# Apply scaled dot product attention
if self.use_grouped_sdpa:
attn_output = torch.ops.auto_deploy.torch_attention_grouped_sdpa(
attn_output = torch.ops.auto_deploy.torch_attention(
q,
k,
v,
@ -1136,7 +1134,7 @@ class BsndAttentionModel(AttentionLayoutModel):
attn_mask = self._get_attn_mask(x) if self.has_mask else None
# Apply bsnd_grouped_sdpa directly
attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa.default(
attn_output = torch.ops.auto_deploy.torch_attention.default(
q,
k,
v,
@ -1144,6 +1142,7 @@ class BsndAttentionModel(AttentionLayoutModel):
dropout_p=0.0,
is_causal=True,
scale=1.0 / (self.head_dim**0.5),
layout="bsnd",
)
# Reshape output for the linear projection (no transpose needed)
@ -1173,11 +1172,11 @@ def test_match_attention_layout(layout, model_config, has_mask):
MockAttentionDescriptor.layout = layout
if layout == "bnsd":
if model_config.get("use_grouped_sdpa"):
source_op = torch.ops.auto_deploy.torch_attention_grouped_sdpa
source_op = torch.ops.auto_deploy.torch_attention
else:
source_op = torch.ops.auto_deploy.torch_attention_sdpa
else:
source_op = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
source_op = torch.ops.auto_deploy.torch_attention
MockAttentionDescriptor.source_attention_op = source_op
# Create appropriate model based on model_config
@ -1210,7 +1209,8 @@ def test_match_attention_layout(layout, model_config, has_mask):
original_nodes = [
n
for n in gm.graph.nodes
if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
if is_op(n, torch.ops.auto_deploy.torch_attention)
and not (isinstance(n.args[-1], str) and n.args[-1] == "bsnd")
]
else:
original_nodes = [
@ -1224,7 +1224,11 @@ def test_match_attention_layout(layout, model_config, has_mask):
bsnd_nodes = [
n
for n in gm.graph.nodes
if is_op(n, torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa)
if (
is_op(n, torch.ops.auto_deploy.torch_attention)
and isinstance(n.args[-1], str)
and n.args[-1] == "bsnd"
)
]
transpose_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.aten.transpose.int)]

View File

@ -16,6 +16,7 @@ from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import Atten
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
torch.manual_seed(0)
@ -31,7 +32,7 @@ class MockAttentionDescriptor(AttentionDescriptor):
@classmethod
def get_source_attention_op(cls) -> Callable:
return torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
return torch.ops.auto_deploy.torch_attention
class HFWrapper(nn.Module):
@ -82,12 +83,18 @@ def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str)
pytest.skip("https://nvbugspro.nvidia.com/bug/5170222")
def verify_matcher(gm: GraphModule):
"""Ensure that there is exactly one torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
"""Ensure that there is exactly one torch.ops.auto_deploy.torch_attention (layout="bsnd")
call in the graph. Also check that there is no repeat_kv pattern left.
"""
nodes = gm.graph.find_nodes(
op="call_function", target=torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
)
nodes = [
n
for n in gm.graph.nodes
if (
is_op(n, torch.ops.auto_deploy.torch_attention)
and isinstance(n.args[-1], str)
and n.args[-1] == "bsnd"
)
]
assert len(nodes) == 1, "Expected exactly one bsnd_grouped_sdpa call in the graph"
# TODO: check non-qkv args of node

View File

@ -74,8 +74,18 @@ class GQAWithSdpa(GQA):
v = v.view(b, s, self.num_kv_heads, self.head_dim)
# Use grouped SDPA in bsnd layout
attn_output = torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa(
q, k, v, None, 0.0, True, None
attn_output = torch.ops.auto_deploy.torch_attention(
q,
k,
v,
attn_mask=None,
dropout_p=0.0,
is_causal=True,
scale=None,
sinks=None,
sliding_window=None,
logit_cap=None,
layout="bsnd",
)
# SDPA output is already in [b, s, n, h_d] format