[AutoDeploy] merge feat/ad-2025-07-22 (#6520)

Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
Signed-off-by: Gal Agam <ghubaraagam@cw-dfw-cs-001-login-01.cm.cluster>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: haoguo <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Co-authored-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
Co-authored-by: Gal Agam <ghubaraagam@cw-dfw-h100-004-328-012.cm.cluster>
Co-authored-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Co-authored-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>
Co-authored-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Co-authored-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Lucas Liebenwein 2025-08-01 11:51:08 -04:00 committed by GitHub
parent 16febefee0
commit 5247df6ae2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1281 additions and 1304 deletions

View File

@ -19,3 +19,15 @@ transforms:
stage: post_export
cleanup_input_constraints:
stage: post_export
quantize:
stage: pattern_matcher
quantize_moe:
stage: pattern_matcher
match_repeat_kv:
stage: pattern_matcher
match_eager_attention:
stage: pattern_matcher
match_grouped_attention:
stage: pattern_matcher
match_attention_layout:
stage: pattern_matcher

View File

@ -7,7 +7,28 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
# TODO (nvchenghaoz): Remove related kernels once we have a backend-specific implementation for attention.
def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor:
"""Apply logit softcapping using the formula: logit_cap * tanh(logits / logit_cap)"""
if logit_cap is not None and logit_cap > 0.0:
return logit_cap * torch.tanh(attn_scores / logit_cap)
return attn_scores
def _convert_boolean_mask_to_float(attn_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""Convert boolean attention mask to floating point mask.
Args:
attn_mask: Boolean tensor where True allows attention, False blocks it
dtype: Target dtype for the output mask
Returns:
Floating point mask where True -> 1.0, False -> -inf
"""
if attn_mask.dtype == torch.bool:
float_mask = torch.zeros_like(attn_mask, dtype=dtype)
float_mask = float_mask.masked_fill(attn_mask, 1.0) # True -> 1.0
float_mask = float_mask.masked_fill(~attn_mask, float("-inf")) # False -> -inf
return float_mask
return attn_mask
@torch.library.custom_op("auto_deploy::torch_attention_repeat_kv", mutates_args=())
@ -77,19 +98,96 @@ def grouped_sdpa(
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
sinks: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
logit_cap: Optional[float] = None,
) -> torch.Tensor:
"""SDPA attention that can handle GQA."""
"""SDPA attention that can handle GQA. Expects bnsd format inputs."""
b, n_heads, s_q, head_dim = query.shape # bnsd format: [batch, num_heads, seq_len, head_dim]
_, n_kv_heads, s_k, _ = key.shape # bnsd format: [batch, num_kv_heads, seq_len, head_dim]
return F.scaled_dot_product_attention(
query.contiguous(),
key.contiguous(),
value.contiguous(),
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=True,
)
# Inputs are already in bnsd format, no need to transpose
query_t = query # [b, n_heads, s_q, head_dim]
key_t = key # [b, n_kv_heads, s_k, head_dim]
value_t = value # [b, n_kv_heads, s_k, v_head_dim]
# Handle GQA by repeating KV if needed
if n_heads != n_kv_heads:
n_rep = n_heads // n_kv_heads
key_t = repeat_kv(key_t, n_rep)
value_t = repeat_kv(value_t, n_rep)
# Set scale
if scale is None:
scale = 1.0 / math.sqrt(head_dim)
# Compute attention scores: Q @ K^T
attn_scores = torch.matmul(query_t, key_t.transpose(-2, -1)) * scale # [b, n_heads, s_q, s_k]
# Apply attention mask if provided
if attn_mask is not None:
# Convert boolean mask to float if needed
attn_mask = _convert_boolean_mask_to_float(attn_mask, attn_scores.dtype)
attn_scores = attn_scores + attn_mask
# Apply causal mask if specified and only during the context phase
if is_causal and s_q == s_k: # Only apply causal mask during context processing
causal_mask = torch.triu(
torch.ones(s_q, s_k, device=query.device, dtype=torch.bool),
diagonal=1, # Use diagonal=1 for standard causal masking
)
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
# Apply sliding window mask if specified
if sliding_window is not None and sliding_window > 0:
# Handle position calculation for both context and generation phases
if s_q == s_k:
# Context phase: standard position calculation
query_positions = torch.arange(s_q, device=query.device)
key_positions = torch.arange(s_k, device=query.device)
else:
# Generation phase: query is at position s_k (after the cache)
query_positions = torch.arange(s_k, s_k + s_q, device=query.device) # e.g., [s_k] when s_q == 1
key_positions = torch.arange(s_k, device=query.device) # [0,1,2,...,s_k-1]
# Create position difference matrix: query_pos - key_pos
pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0) # [s_q, s_k]
# Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window
sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window) # [s_q, s_k]
attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
# Apply logit softcapping if enabled
attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)
# Apply sinks if provided
if sinks is not None:
# Incorporate the sinks into the softmax following the reference implementation
# sinks should have n_heads elements; each head gets its own sink value
# Expand sinks to [b, n_heads, s_q, 1] - one sink logit per head and query position
sinks_expanded = sinks.reshape(1, -1, 1, 1).expand(
b, n_heads, s_q, 1
) # [b, n_heads, s_q, 1]
# Add the sink term to the softmax normalizer (equivalent to an extra key column)
logits_max = torch.max(attn_scores, dim=-1, keepdim=True).values
sinks = torch.exp(sinks_expanded - logits_max)
unnormalized_scores = torch.exp(attn_scores - logits_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
scores = unnormalized_scores / normalizer
# The sink term only enters the normalizer, so scores already excludes it and matches attn_scores in shape
attn_out = torch.matmul(scores, value_t) # [b, n_heads, s_q, v_head_dim]
else:
attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype)
attn_out = torch.matmul(attn_weights, value_t) # [b, n_heads, s_q, v_head_dim]
# Apply dropout if specified
if dropout_p > 0.0:
attn_out = F.dropout(attn_out, p=dropout_p, training=False)
# Return in bnsd format (same as input format)
return attn_out
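For reference, a minimal numeric sketch (standalone, not part of this diff) of the sink handling above: the sink term only enters the softmax normalizer, which is equivalent to appending one extra logit column per head and dropping it before the value matmul. Shapes and values are illustrative.

import torch

attn_scores = torch.randn(1, 2, 1, 3)  # [b, n_heads, s_q, s_k]
sinks = torch.randn(2)                 # one sink logit per head

sinks_expanded = sinks.reshape(1, -1, 1, 1).expand(1, 2, 1, 1)
logits_max = attn_scores.max(dim=-1, keepdim=True).values
num = torch.exp(attn_scores - logits_max)
den = num.sum(dim=-1, keepdim=True) + torch.exp(sinks_expanded - logits_max)
scores = num / den  # same shape as attn_scores

# Equivalent formulation: softmax over [attn_scores | sink], then drop the sink column.
ref = torch.softmax(torch.cat([attn_scores, sinks_expanded], dim=-1), dim=-1)[..., :-1]
print(torch.allclose(scores, ref, atol=1e-6))  # True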
@grouped_sdpa.register_fake
@ -101,6 +199,9 @@ def grouped_sdpa_fake(
dropout_p=0.0,
is_causal=False,
scale=None,
sinks=None,
sliding_window=None,
logit_cap=None,
):
"""Fake implementation of grouped SDPA."""
return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
@ -108,9 +209,9 @@ def grouped_sdpa_fake(
@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=())
def bsnd_grouped_sdpa(
query: torch.Tensor, # layout: [b, n, s_q, d]
key: torch.Tensor, # layout: [b, n, s_k, d]
value: torch.Tensor, # layout: [b, n, s_k, d]
query: torch.Tensor, # layout: [b, s_q, n, d]
key: torch.Tensor, # layout: [b, s_k, n, d]
value: torch.Tensor, # layout: [b, s_k, n, d]
attn_mask: Optional[torch.Tensor] = None, # layout: [b, n, s_q, s_k]
dropout_p: float = 0.0,
is_causal: bool = False,
@ -124,14 +225,16 @@ def bsnd_grouped_sdpa(
Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the
original sdpa op!
"""
# let's transpose to bnsd so we can use the grouped sdpa
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
# Transpose inputs to bnsd format for grouped_sdpa
query = query.transpose(1, 2).contiguous() # [b, s_q, n, d] -> [b, n, s_q, d]
key = key.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d]
value = value.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d]
out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
# let's transpose back to bnsd
# Call grouped_sdpa with bnsd inputs
out = grouped_sdpa(
query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap
)
# Transpose back to bsnd format
return out.transpose(1, 2).contiguous()
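A self-contained sanity sketch (plain PyTorch, not the custom ops above) of the bsnd/bnsd convention used by this wrapper: transpose bsnd inputs to bnsd, attend, and transpose the output back. Assumes PyTorch 2.x; shapes are illustrative.

import torch
import torch.nn.functional as F

b, s, n, d = 2, 5, 4, 8
q_bsnd, k_bsnd, v_bsnd = (torch.randn(b, s, n, d) for _ in range(3))

# [b, s, n, d] -> [b, n, s, d], attend, then back to [b, s, n, d]
q, k, v = (t.transpose(1, 2).contiguous() for t in (q_bsnd, k_bsnd, v_bsnd))
out_bnsd = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_bsnd = out_bnsd.transpose(1, 2).contiguous()
assert out_bsnd.shape == (b, s, n, d)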

View File

@ -103,7 +103,7 @@ def _torch_generate_mha(
# Apply sinks if provided (following the model file pattern)
if sinks is not None:
# Concatenate sinks to attention scores
sinks = sinks.reshape(-1, 1, 1).expand(-1, attn_scores.shape[-2], -1)
sinks = sinks.reshape(-1, 1, 1)
attn_weights = torch.cat([attn_scores, sinks], dim=-1)
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
# Use only the non-sink portion for computing output (ignore sinks)
@ -202,9 +202,7 @@ def _torch_context_mha(
) # [seq_len_i, kv_seq_len]
# Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size
sliding_window_mask = (pos_diff < 0) | (
pos_diff >= sliding_window_size
) # [seq_len_i, kv_seq_len]
sliding_window_mask = pos_diff >= sliding_window_size
# Combine causal and sliding window masks
combined_mask = causal_mask | sliding_window_mask
@ -219,14 +217,14 @@ def _torch_context_mha(
# Apply sinks if provided (following the model file pattern)
if sinks is not None:
# Concatenate sinks to attention scores
sinks = sinks.reshape(1, -1, 1, 1).expand(
attn_scores.shape[0], -1, attn_scores.shape[-2], -1
new_sinks = sinks.reshape(1, -1, 1, 1).expand(
attn_scores.shape[0], -1, attn_scores.shape[2], 1
)
attn_weights = torch.cat([attn_scores, sinks], dim=-1)
attn_weights = torch.cat([attn_scores, new_sinks], dim=-1)
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
# Use only the non-sink portion for computing output (ignore sinks)
attn_out = torch.matmul(
attn_weights[..., : -sinks.size(-1)], v_seq_t
attn_weights[..., : -new_sinks.size(-1)], v_seq_t
) # [1, n_heads, seq_len_i, v_head_dim]
else:
attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)

View File

@ -17,7 +17,8 @@ try:
rank, world_size = get_rank_world_size()
assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.AUTO)
# Use AllReduceStrategy.NCCL until https://nvbugspro.nvidia.com/bug/5331013 is fixed, then change back to AllReduceStrategy.AUTO
torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.NCCL)
return torch_op(tensor, all_reduce_params=all_reduce_params)
@torch.library.custom_op(

View File

@ -76,6 +76,12 @@ class AutoModelForCausalLMFactory(ModelFactory):
"max_position_embeddings": 1024,
}
def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
"""Get the max position embeddings config for the model."""
return {
"max_position_embeddings": self.max_seq_len,
}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -83,7 +89,11 @@ class AutoModelForCausalLMFactory(ModelFactory):
# Ingest defaults for tokenizer and model kwargs
self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
self.model_kwargs = deep_merge_dicts(self._model_defaults, self.model_kwargs)
self.model_kwargs = deep_merge_dicts(
self._model_defaults,
self.model_kwargs,
self._get_max_position_embeddings_config(),
)
# special handling for torch_dtype in model_kwargs since HF does not correctly update
# torch_dtype string to an actual torch.dtype object (only with default)
@ -295,7 +305,7 @@ class AutoModelForCausalLMFactory(ModelFactory):
# at this point it should be a directory (either the original one or the download dir)
assert os.path.isdir(fetched_dir), f"Checkpoint path {fetched_dir} is not a directory."
self._load_quantization_config()
self._load_quantization_config(fetched_dir)
return fetched_dir
@ -313,13 +323,13 @@ class AutoModelForCausalLMFactory(ModelFactory):
# model-transformed weights, leading to unexpected key mismatches or format issues.
load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
def _load_quantization_config(self):
def _load_quantization_config(self, fetched_dir: str):
"""Load the quantization config from the model directory if not done already."""
if self._quant_config is not None:
return
assert self.model
hf_quant_config_file = os.path.join(self.model, "hf_quant_config.json")
hf_quant_config_file = os.path.join(fetched_dir, "hf_quant_config.json")
if os.path.exists(hf_quant_config_file):
with open(hf_quant_config_file, "r") as file:
quantization_config = json.load(file)
@ -344,6 +354,15 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
},
}
def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
"""Get the max position embeddings config for the model."""
return {
"max_position_embeddings": self.max_seq_len,
"text_config": {
"max_position_embeddings": self.max_seq_len,
},
}
@property
def automodel_from_config(self):
return AutoModelForImageTextToText.from_config
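To illustrate how the hooks above land in model_kwargs: the factory deep-merges the class defaults, the user kwargs, and the max_position_embeddings override, with later dictionaries winning. A rough sketch follows; merge is a generic stand-in for the repo's deep_merge_dicts helper and the values are illustrative only.

from typing import Any, Dict

def merge(*dicts: Dict[str, Any]) -> Dict[str, Any]:
    """Later dicts win; nested dicts merge recursively (stand-in for deep_merge_dicts)."""
    out: Dict[str, Any] = {}
    for d in dicts:
        for k, v in d.items():
            out[k] = merge(out[k], v) if isinstance(v, dict) and isinstance(out.get(k), dict) else v
    return out

model_defaults = {"max_position_embeddings": 1024}
user_kwargs = {"torch_dtype": "bfloat16"}
max_pos_override = {"max_position_embeddings": 4096, "text_config": {"max_position_embeddings": 4096}}
print(merge(model_defaults, user_kwargs, max_pos_override))
# {'max_position_embeddings': 4096, 'torch_dtype': 'bfloat16', 'text_config': {'max_position_embeddings': 4096}}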

View File

@ -227,18 +227,26 @@ class BaseTransform(ABC):
# run or skip the transform
if self.config.enabled:
# run graph pre-cleanup
self._run_pre_cleanup(gm, info_last)
is_clean_pre, has_valid_shapes_pre = self._run_pre_cleanup(gm, info_last)
# run the transform in an error-handling wrapper
try:
gm, info = self._apply(gm, cm, factory)
except Exception as e:
error_msg = f"Transform {t_name} failed"
if self.config.skip_on_error:
# run the transform in an error-handling wrapper if desired
if self.config.skip_on_error:
try:
gm, info = self._apply(gm, cm, factory)
except Exception as e:
error_msg = f"Transform {t_name} failed"
ad_logger.warning(f"{error_msg}: {e}")
info = TransformInfo(skipped=True, num_matches=0)
else:
raise TransformError(error_msg) from e
else:
# handle this here normally to improve debugging and error messages
gm, info = self._apply(gm, cm, factory)
# we cannot say it's clean if the previous wasn't clean even if this one is
# create new info object with updated cleanup status
info_dict = info.model_dump()
info_dict["is_clean"] &= is_clean_pre
info_dict["has_valid_shapes"] &= has_valid_shapes_pre
info = TransformInfo(**info_dict)
# run graph post-cleanup
info = self._run_post_cleanup(gm, info)
@ -279,20 +287,36 @@ class BaseTransform(ABC):
gm.meta[self._autodeploy_meta_key] = autodeploy_meta
@final
def _run_pre_cleanup(self, gm: GraphModule, info: TransformInfo) -> None:
def _run_pre_cleanup(self, gm: GraphModule, info: TransformInfo) -> Tuple[bool, bool]:
"""Run graph cleanup before the transform.
Args:
gm: The graph module to run cleanup on.
info: The last transform info.
Returns:
A tuple of (is_clean, has_valid_shapes) indicating the cleanup status after the
pre-cleanup.
This is used to ensure the transform is applied to a clean graph when it requires one.
"""
if not self.config.requires_clean_graph:
return
return info.is_clean, info.has_valid_shapes
is_clean = info.is_clean
has_valid_shapes = is_clean and info.has_valid_shapes
# check whether to run cleanup depending on the config and info
if self.config.requires_shape_prop and not (info.is_clean and info.has_valid_shapes):
if self.config.requires_shape_prop and not has_valid_shapes:
with lift_to_meta(gm):
canonicalize_graph(gm, shape_prop=True)
elif self.config.requires_clean_graph and not info.is_clean:
is_clean = True
has_valid_shapes = True
elif self.config.requires_clean_graph and not is_clean:
canonicalize_graph(gm)
is_clean = True
return is_clean, has_valid_shapes
@final
def _run_post_cleanup(self, gm: GraphModule, info: TransformInfo) -> TransformInfo:

View File

@ -0,0 +1,562 @@
"""Pattern matching for detecting repeat_kv, eager, grouped attention patterns from Huggingface models."""
from typing import Any, Callable, Dict, List, Tuple, Type
import torch
import torch.nn.functional as F
from pydantic import Field
from torch.fx import GraphModule
from ...custom_ops.attention_interface import AttentionDescriptor
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ...utils.node_utils import is_op
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ..interface import BaseTransform, TransformConfig, TransformInfo, TransformRegistry
def _apply_pattern(
gm: GraphModule,
pattern_name: str,
register_fn: Callable[[ADPatternMatcherPass], None],
) -> int:
"""Utility to register and apply a pattern."""
patterns = ADPatternMatcherPass()
register_fn(patterns)
num_matches = patterns.apply(gm.graph)
return num_matches
def _repeat_kv_pattern(hidden_states, n_rep) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = torch.unsqueeze(hidden_states, 2)
hidden_states = hidden_states.expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def _repeat_kv_repl(hidden_states, n_rep) -> torch.Tensor:
return torch.ops.auto_deploy.torch_attention_repeat_kv(hidden_states, n_rep)
# with causal_mask, no division
def _sfdp_pattern_1(query, key, value, attention_mask, scaling, dropout):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = F.dropout(attn_weights, p=dropout, training=False)
attn_output = torch.matmul(attn_weights, value)
return attn_output
def _sfdp_replacement_1(query, key, value, attention_mask, scaling, dropout):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
query,
key,
value,
attn_mask=None,
dropout_p=dropout,
is_causal=True,
scale=scaling,
)
# no causal_mask, no division
def _sfdp_pattern_2(query, key, value, scaling, dropout):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = F.dropout(attn_weights, p=dropout, training=False)
attn_output = torch.matmul(attn_weights, value)
return attn_output
def _sfdp_replacement_2(query, key, value, scaling, dropout):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
query,
key,
value,
attn_mask=None,
dropout_p=dropout,
is_causal=False,
scale=scaling,
)
# with causal_mask, with division
def _sfdp_pattern_3(query, key, value, attention_mask, scaling, dropout):
attn_weights = torch.matmul(query, key.transpose(2, 3)) / scaling
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = F.dropout(attn_weights, p=dropout, training=False)
attn_output = torch.matmul(attn_weights, value)
return attn_output
def _sfdp_replacement_3(query, key, value, attention_mask, scaling, dropout):
scaling = 1.0 / scaling
return torch.ops.auto_deploy.torch_attention_sdpa.default(
query,
key,
value,
attn_mask=None,
dropout_p=dropout,
is_causal=True,
scale=scaling,
)
# no causal_mask, with division
def _sfdp_pattern_4(query, key, value, scaling, dropout):
attn_weights = torch.matmul(query, key.transpose(2, 3)) / scaling
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = F.dropout(attn_weights, p=dropout, training=False)
attn_output = torch.matmul(attn_weights, value)
return attn_output
def _sfdp_replacement_4(query, key, value, scaling, dropout):
scaling = 1.0 / scaling
return torch.ops.auto_deploy.torch_attention_sdpa.default(
query,
key,
value,
attn_mask=None,
dropout_p=dropout,
is_causal=False,
scale=scaling,
)
# no causal_mask, with division, explicit cast to float32 before softmax
def _sfdp_pattern_5(query, key, value, scaling, dropout):
attn_weights = torch.matmul(query, key.transpose(2, 3)) / scaling
attn_weights = attn_weights.to(torch.float32)
attn_weights = F.softmax(attn_weights, dim=-1).to(query.dtype)
attn_weights = F.dropout(attn_weights, p=dropout, training=False)
attn_output = torch.matmul(attn_weights, value)
return attn_output
def _sfdp_replacement_5(query, key, value, scaling, dropout):
scaling = 1.0 / scaling
return torch.ops.auto_deploy.torch_attention_sdpa.default(
query,
key,
value,
attn_mask=None,
dropout_p=dropout,
is_causal=False,
scale=scaling,
)
# with causal_mask, with division, explicit cast to float32 before softmax
def _sfdp_pattern_6(query, key, value, attention_mask, scaling, dropout):
attn_weights = torch.matmul(query, key.transpose(2, 3)) / scaling
attn_weights = attn_weights + attention_mask
attn_weights = attn_weights.to(torch.float32)
attn_weights = F.softmax(attn_weights, dim=-1).to(query.dtype)
attn_weights = F.dropout(attn_weights, p=dropout, training=False)
attn_output = torch.matmul(attn_weights, value)
return attn_output
def _sfdp_replacement_6(query, key, value, attention_mask, scaling, dropout):
scaling = 1.0 / scaling
return torch.ops.auto_deploy.torch_attention_sdpa.default(
query,
key,
value,
attn_mask=None,
dropout_p=dropout,
is_causal=True,
scale=scaling,
)
# Only a causal attention mask is passed in the downstream standardized pipeline
def _sfdp_pattern_7(query, key, value, attention_mask, scaling, dropout):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=dropout,
is_causal=False,
scale=scaling,
)
def _sfdp_replacement_7(query, key, value, attention_mask, scaling, dropout):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
query,
key,
value,
attn_mask=None,
dropout_p=dropout,
is_causal=True if attention_mask is not None else False,
scale=scaling,
)
# with causal_mask, no division, does not cast to fp32 for softmax
def _sfdp_pattern_8(query, key, value, attention_mask, scaling, dropout):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = F.dropout(attn_weights, p=dropout, training=False)
attn_output = torch.matmul(attn_weights, value)
return attn_output
def _sfdp_replacement_8(query, key, value, attention_mask, scaling, dropout):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
query,
key,
value,
attn_mask=None,
dropout_p=dropout,
is_causal=True,
scale=scaling,
)
def _get_sfdp_patterns() -> List[Dict[str, Any]]:
bs, seq_len, n_heads, hidden_size = 8, 16, 8, 512
head_dim = hidden_size // n_heads
def common_tensor():
return torch.randn(bs, n_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
def causal_mask():
return torch.randn(bs, 1, 1, seq_len, device="cuda", dtype=torch.bfloat16)
configs = [
(_sfdp_pattern_1, _sfdp_replacement_1, True, 0.1234743, 0.85849734),
(_sfdp_pattern_2, _sfdp_replacement_2, False, 0.234743, 0.5849734),
(_sfdp_pattern_3, _sfdp_replacement_3, True, 0.34743, 0.849734),
(_sfdp_pattern_4, _sfdp_replacement_4, False, 0.74321, 0.9734),
(_sfdp_pattern_5, _sfdp_replacement_5, False, 0.874321, 0.89734),
(_sfdp_pattern_6, _sfdp_replacement_6, True, 0.634743, 0.6849734),
(_sfdp_pattern_7, _sfdp_replacement_7, True, 0.34743, 0.849734),
(_sfdp_pattern_8, _sfdp_replacement_8, True, 0.2234743, 0.95849734),
]
patterns = []
for search_fn, replace_fn, has_mask, scale, dropout in configs:
dummy_args = [common_tensor(), common_tensor(), common_tensor()]
if has_mask:
dummy_args.append(causal_mask())
dummy_args.extend([scale, dropout])
patterns.append(
{
"search_fn": search_fn,
"replace_fn": replace_fn,
"dummy_args": dummy_args,
"scalar_workaround": {"scaling": scale, "dropout": dropout},
"op_ignore_types": {torch.ops.aten.to.dtype: (torch.dtype,)},
}
)
return patterns
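A quick standalone check of the rewrite these patterns encode, using the shape of _sfdp_pattern_2 (no mask, multiplicative scaling, dropout disabled): the eager computation matches F.scaled_dot_product_attention. Assumes PyTorch >= 2.1 for the scale argument; the custom torch_attention_sdpa op is not required for the check.

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))
scaling = 0.125

# Eager pattern: scaled matmul -> softmax (fp32) -> matmul, dropout disabled.
w = torch.matmul(q, k.transpose(2, 3)) * scaling
w = F.softmax(w, dim=-1, dtype=torch.float32).to(q.dtype)
eager_out = torch.matmul(w, v)

sdpa_out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=scaling
)
print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # True up to numerics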
def _grouped_attn_pattern_1(q, k, v, n_rep, attn_mask, dropout_p, scale):
k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
)
def _grouped_attn_replacement_1(q, k, v, n_rep, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
)
# Only expose torch_attention_grouped_sdpa after the transformation
def _grouped_attn_pattern_2(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
)
def _grouped_attn_replacement_2(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
)
def _grouped_attn_pattern_3(q, k, v, n_rep, attn_mask, dropout_p, scale):
k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
)
def _grouped_attn_replacement_3(q, k, v, n_rep, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
)
# Only expose torch_attention_grouped_sdpa after the transformation
def _grouped_attn_pattern_4(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
)
def _grouped_attn_replacement_4(q, k, v, attn_mask, dropout_p, scale):
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
)
@TransformRegistry.register("match_repeat_kv")
class MatchRepeatKV(BaseTransform):
"""
Match and replace the repeat_kv pattern with torch.ops.auto_deploy.torch_attention_repeat_kv.
"""
def _apply(
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
) -> Tuple[GraphModule, TransformInfo]:
def register_repeat_kv(patterns: ADPatternMatcherPass):
dummy_args = [
torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16),
7,
]
register_ad_pattern(
search_fn=_repeat_kv_pattern,
replace_fn=_repeat_kv_repl,
patterns=patterns,
dummy_args=dummy_args,
op_ignore_types={
torch.ops.aten.reshape.default: (int,),
torch.ops.aten.expand.default: (int,),
},
scalar_workaround={"n_rep": dummy_args[1]},
)
num_kv_patterns = _apply_pattern(gm, "Repeat KV", register_repeat_kv)
if num_kv_patterns > 0:
self.config.run_shape_prop = True
info = TransformInfo(
skipped=False,
num_matches=num_kv_patterns,
is_clean=False,
has_valid_shapes=False,
)
return gm, info
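As a standalone sanity check of what this transform matches, the unsqueeze/expand/reshape sequence in _repeat_kv_pattern is simply a head-wise repeat_interleave; tensor sizes below are illustrative.

import torch

b, n_kv, s, d, n_rep = 2, 2, 5, 8, 3
kv = torch.randn(b, n_kv, s, d)

# unsqueeze -> expand -> reshape, as matched by _repeat_kv_pattern
expanded = kv[:, :, None].expand(b, n_kv, n_rep, s, d).reshape(b, n_kv * n_rep, s, d)
print(torch.equal(expanded, torch.repeat_interleave(kv, n_rep, dim=1)))  # True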
@TransformRegistry.register("match_eager_attention")
class MatchEagerAttention(BaseTransform):
"""
Match and replace the eager attention pattern with torch.ops.auto_deploy.torch_attention_sdpa.
"""
def _apply(
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
) -> Tuple[GraphModule, TransformInfo]:
def register_eager_attention(patterns: ADPatternMatcherPass):
for pattern_config in _get_sfdp_patterns():
register_ad_pattern(**pattern_config, patterns=patterns)
num_eager_patterns = _apply_pattern(gm, "Eager Attention", register_eager_attention)
info = TransformInfo(
skipped=False,
num_matches=num_eager_patterns,
is_clean=False,
has_valid_shapes=False,
)
return gm, info
@TransformRegistry.register("match_grouped_attention")
class MatchGroupedAttention(BaseTransform):
"""
Match and replace the grouped attention pattern with
torch.ops.auto_deploy.torch_attention_grouped_sdpa.
"""
def _apply(
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
) -> Tuple[GraphModule, TransformInfo]:
def register_grouped_attention(patterns: ADPatternMatcherPass):
q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
v1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
attn_mask = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16)
dropout = 0.12345
scale = 0.56789
n_rep = 7
dummy_args_1 = [q, k1, v1, n_rep, attn_mask, dropout, scale]
dummy_args_2 = [q, k1, v1, attn_mask, dropout, scale]
register_ad_pattern(
search_fn=_grouped_attn_pattern_1,
replace_fn=_grouped_attn_replacement_1,
patterns=patterns,
dummy_args=dummy_args_1,
scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_2,
replace_fn=_grouped_attn_replacement_2,
patterns=patterns,
dummy_args=dummy_args_2,
scalar_workaround={
"scale": scale,
"dropout_p": dropout,
},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_3,
replace_fn=_grouped_attn_replacement_3,
patterns=patterns,
dummy_args=dummy_args_1,
scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep},
)
register_ad_pattern(
search_fn=_grouped_attn_pattern_4,
replace_fn=_grouped_attn_replacement_4,
patterns=patterns,
dummy_args=dummy_args_2,
scalar_workaround={
"scale": scale,
"dropout_p": dropout,
},
)
num_grouped_patterns = _apply_pattern(gm, "Grouped Attention", register_grouped_attention)
info = TransformInfo(
skipped=False,
num_matches=num_grouped_patterns,
is_clean=False,
has_valid_shapes=False,
)
return gm, info
class MatchAttentionLayoutConfig(TransformConfig):
"""Configuration for the insert cached attention transform."""
attention_op: Type[AttentionDescriptor] = Field(description="The attention descriptor to use.")
@TransformRegistry.register("match_attention_layout")
class MatchAttentionLayout(BaseTransform):
"""
Match and transform attention operations to match the layout expected by the attention backend.
If the attention backend expects 'bnsd' layout (batch, num_heads, seq_len, head_dim), which
is the default for SDPA operations, we don't need to transform anything.
If the backend expects 'bsnd' layout (batch, seq_len, num_heads, head_dim), we insert
appropriate transposes before and after SDPA operations and replace them with bsnd_grouped_sdpa.
"""
config: MatchAttentionLayoutConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return MatchAttentionLayoutConfig
def _apply(
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
) -> Tuple[GraphModule, TransformInfo]:
# Get attention layout from attention_op
attention_layout = self.config.attention_op.get_attention_layout()
# List of SDPA operations to look for
sdpa_ops = {
torch.ops.auto_deploy.torch_attention_sdpa,
torch.ops.auto_deploy.torch_attention_grouped_sdpa,
}
graph = gm.graph
num_bsnd_patterns = 0
# Look for SDPA operations
for sdpa_node in list(graph.nodes):
if sdpa_node.op != "call_function" or not is_op(sdpa_node, sdpa_ops):
continue
ad_logger.debug(f"Found SDPA node to transform for bsnd layout: {sdpa_node}")
# Extract q, k, v inputs
q, k, v = sdpa_node.args[:3]
# Check if we need to transpose the inputs
if attention_layout == "bsnd":
# Add transposes before the node (from bnsd to bsnd)
with graph.inserting_before(sdpa_node):
q_updated = graph.call_function(torch.ops.aten.transpose.int, args=(q, 1, 2))
k_updated = graph.call_function(torch.ops.aten.transpose.int, args=(k, 1, 2))
v_updated = graph.call_function(torch.ops.aten.transpose.int, args=(v, 1, 2))
# Preserve fake tensor in meta["val"] for the transposed inputs
q_updated.meta["val"] = q.meta["val"].transpose(1, 2)
k_updated.meta["val"] = k.meta["val"].transpose(1, 2)
v_updated.meta["val"] = v.meta["val"].transpose(1, 2)
elif attention_layout == "bnsd":
# we don't need to do anything...
q_updated = q
k_updated = k
v_updated = v
else:
raise ValueError(f"Unsupported attention layout: {attention_layout}")
# Create bsnd_grouped_sdpa node with the same args as the original node
# but using the transposed inputs
with graph.inserting_before(sdpa_node):
source_sdpa_node = graph.call_function(
self.config.attention_op.get_source_attention_op(),
args=(q_updated, k_updated, v_updated) + sdpa_node.args[3:],
kwargs=sdpa_node.kwargs,
)
# Check if need to update the output node to match the layout
if attention_layout == "bsnd":
# Add transpose for the output (from bsnd back to bnsd)
with graph.inserting_after(source_sdpa_node):
output_updated = graph.call_function(
torch.ops.aten.transpose.int, args=(source_sdpa_node, 1, 2)
)
# Preserve fake tensors in meta["val"] for the new source node and the transposed output
source_sdpa_node.meta["val"] = sdpa_node.meta["val"].transpose(1, 2).contiguous()
output_updated.meta["val"] = source_sdpa_node.meta["val"].transpose(1, 2)
elif attention_layout == "bnsd":
output_updated = source_sdpa_node
else:
raise ValueError(f"Unsupported attention layout: {attention_layout}")
# Replace the old node with the transposed output
sdpa_node.replace_all_uses_with(output_updated)
num_bsnd_patterns += 1
info = TransformInfo(
skipped=False,
num_matches=num_bsnd_patterns,
is_clean=False,
has_valid_shapes=False,
)
return gm, info
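A compact, self-contained torch.fx sketch of the insert-and-replace mechanic used above, applied to a toy graph; a plain torch.matmul stands in for the custom SDPA ops and all names are illustrative.

import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, a, b):
        return torch.matmul(a, b)

gm = fx.symbolic_trace(Toy())
graph = gm.graph
for node in list(graph.nodes):
    if node.op != "call_function" or node.target is not torch.matmul:
        continue
    a, b = node.args
    with graph.inserting_before(node):
        # transpose the inputs, re-emit the op on them, then transpose the result back
        a_t = graph.call_function(torch.transpose, args=(a, -2, -1))
        b_t = graph.call_function(torch.transpose, args=(b, -2, -1))
        new_mm = graph.call_function(torch.matmul, args=(b_t, a_t))
    with graph.inserting_after(new_mm):
        out_t = graph.call_function(torch.transpose, args=(new_mm, -2, -1))
    node.replace_all_uses_with(out_t)
    graph.erase_node(node)
graph.lint()
gm.recompile()

a, b = torch.randn(3, 4), torch.randn(4, 5)
print(torch.allclose(gm(a, b), torch.matmul(a, b), atol=1e-6))  # rewrite is numerically a no-op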

View File

@ -1,11 +1,12 @@
from collections import defaultdict
from functools import partial
from typing import Any, Dict
from typing import Dict, Tuple
import torch.nn as nn
from torch.fx import GraphModule, Node
from ...utils.logger import ad_logger
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import (
extract_param_names_from_lin_node,
get_quantization_params_from_linear_node,
@ -20,7 +21,7 @@ from ...utils.quantization_utils import (
remove_output_quantizers,
should_skip_quantization,
)
from .._graph import canonicalize_graph
from ..interface import BaseTransform, TransformInfo, TransformRegistry
def _insert_quantized_linear(
@ -138,12 +139,8 @@ def _insert_quantized_bmm(
scale_target_module = gm # Register in root module
scale_name_prefix = ""
ad_logger.info(f"Quantized BMM with dynamic weight tensor for node {node}")
else:
# If we can't determine the shape, skip quantization
ad_logger.warning(
f"BMM weight is dynamic tensor without shape metadata, skipping quantization for node {node}"
)
return
# Common logic for both parameter and dynamic tensor cases
@ -169,53 +166,70 @@ def _insert_quantized_bmm(
node.args = (*node.args, *scale_values)
def quantize(gm: GraphModule, quant_config: Dict[str, Any]) -> None:
"""Quantize the GraphModule and replace linear with quantized linear."""
# extract info from quant_config
is_quant_graph = is_quantized_graph(gm)
quant_algo = quant_config.get("quant_algo")
excluded_patterns = quant_config.get("exclude_modules", [])
@TransformRegistry.register("quantize")
class Quantization(BaseTransform):
"""Quantize the GraphModule and replace linear/BMM with quantized linear/BMM."""
# no quantization to do
if not (is_quant_graph or quant_config):
ad_logger.info("No quantization to do.")
return
# tracking quantized operations in the graph
quantized_nodes: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
for n in gm.graph.nodes:
if should_skip_quantization(n, excluded_patterns):
continue
# Process linear operations
if is_linear_op(n, include_quantization=False):
# get per-layer quantization format from the node
quant_algo_n: str = (
get_quantization_from_linear_node(n) if is_quant_graph else quant_algo
def _apply(
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
) -> Tuple[GraphModule, TransformInfo]:
# extract info from quant_config
quant_config = factory.get_quant_config()
if not quant_config:
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
if not quant_algo_n:
is_quant_graph = is_quantized_graph(gm)
quant_algo = quant_config.get("quant_algo")
excluded_patterns = quant_config.get("exclude_modules", [])
if not quant_algo:
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
# tracking quantized operations in the graph
quantized_nodes: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
for n in gm.graph.nodes:
if should_skip_quantization(n, excluded_patterns):
continue
# insert quantized linear node
_insert_quantized_linear(gm, n, QuantizationImpl.create(quant_algo_n), is_quant_graph)
quantized_nodes[quant_algo_n]["linear"] += 1
# Process linear operations
if is_linear_op(n, include_quantization=False):
# get per-layer quantization format from the node
quant_algo_n: str = (
get_quantization_from_linear_node(n) if is_quant_graph else quant_algo
)
if not quant_algo_n:
continue
# Process BMM operations
elif is_bmm_op(n):
if not quant_algo:
continue
# insert quantized linear node
_insert_quantized_linear(
gm, n, QuantizationImpl.create(quant_algo_n), is_quant_graph
)
quantized_nodes[quant_algo_n]["linear"] += 1
# insert quantized bmm node
_insert_quantized_bmm(
gm, n, QuantizationImpl.create(quant_algo, is_bmm=True), is_quant_graph
)
quantized_nodes[quant_algo]["bmm"] += 1
# Process BMM operations
elif is_bmm_op(n):
if not quant_algo:
continue
if is_quant_graph:
remove_output_quantizers(gm)
# insert quantized bmm node
_insert_quantized_bmm(
gm, n, QuantizationImpl.create(quant_algo, is_bmm=True), is_quant_graph
)
quantized_nodes[quant_algo]["bmm"] += 1
canonicalize_graph(gm)
for quant_algo in quantized_nodes:
for op_type, count in quantized_nodes[quant_algo].items():
ad_logger.info(f"Found {count} {quant_algo} quantized {op_type} nodes.")
ad_logger.debug("After quantization: " + str(gm))
if is_quant_graph:
remove_output_quantizers(gm)
num_matches = 0
for quant_algo in quantized_nodes:
for op_type, count in quantized_nodes[quant_algo].items():
num_matches += count
info = TransformInfo(
skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
)
return gm, info
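The function-to-class conversion above (and the quantize_moe one below) relies on a decorator-registry pattern. A generic sketch of that pattern follows, with names that are illustrative rather than the repo's actual TransformRegistry implementation.

from typing import Callable, Dict, Type

class Registry:
    """Minimal decorator-based registry (illustrative stand-in for TransformRegistry)."""
    _registry: Dict[str, Type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type], Type]:
        def _decorator(transform_cls: Type) -> Type:
            cls._registry[name] = transform_cls
            return transform_cls
        return _decorator

    @classmethod
    def get(cls, name: str) -> Type:
        return cls._registry[name]

@Registry.register("quantize")
class Quantization:
    def _apply(self, gm, cm, factory):
        ...  # graph rewrite would go here

print(Registry.get("quantize"))  # <class '__main__.Quantization'>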

View File

@ -1,14 +1,15 @@
from functools import partial
from typing import Any, Callable, Dict, List, Tuple
from typing import Callable, List, Tuple
import torch
import torch.nn as nn
from torch.fx import GraphModule, Node
from ...utils.logger import ad_logger
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import is_op
from ...utils.quantization_utils import QuantizationImpl, should_skip_quantization
from .._graph import canonicalize_graph
from ..interface import BaseTransform, TransformInfo, TransformRegistry
quantized_moe_op_map = {
"FP8": torch.ops.auto_deploy.torch_quant_fp8_moe,
@ -92,47 +93,10 @@ def _quantize_moe_node(
quantized_op,
args=tuple(args),
)
ad_logger.debug(f"Updating {node.name} args to {new_node.args}")
node.replace_all_uses_with(new_node)
gm.graph.erase_node(node)
def quantize_moe(gm: GraphModule, quant_config: Dict[str, Any]) -> None:
"""
Traverse gm, find every torch.ops.auto_deploy.torch_moe, and replace it with the
quantized version using the quant_algo from quant_config.
"""
quant_algo = quant_config.get("quant_algo")
if not quant_algo:
ad_logger.info("No quantization to do.")
return gm
excluded_patterns = quant_config.get("exclude_modules", [])
quant_impl = QuantizationImpl.create(quant_algo)
quantized_op = quantized_moe_op_map[quant_algo]
count = 0
for node in list(gm.graph.nodes):
if is_op(node, torch.ops.auto_deploy.torch_moe):
# Check that all expert weights should be quantized
w1_names, w2_names, w3_names = _extract_moe_weight_param_lists(node)
if any(
should_skip_quantization(n, excluded_patterns)
for n in w1_names + w2_names + w3_names
):
continue
_quantize_moe_node(gm, node, quant_impl, quantized_op)
count += 1
if count == 0:
return gm
gm = canonicalize_graph(gm)
ad_logger.info(f"Found {count} {quant_algo} quantized {quantized_op} nodes.")
return
# TODO(Fridah-nv): robust handling similar to `extract_param_names_from_lin_node` or expand it
def _extract_moe_weight_param_lists(moe_node: Node) -> Tuple[List[str], List[str], List[str]]:
"""
@ -165,3 +129,51 @@ def _extract_moe_weight_param_lists(moe_node: Node) -> Tuple[List[str], List[str
w3_names = _unwrap_list(w3_list)
return w1_names, w2_names, w3_names
@TransformRegistry.register("quantize_moe")
class QuantizeMOE(BaseTransform):
"""
Traverse gm, find every torch.ops.auto_deploy.torch_moe, and replace it with the
quantized version using the quant_algo from quant_config.
"""
def _apply(
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
) -> Tuple[GraphModule, TransformInfo]:
quant_config = factory.get_quant_config()
quant_algo = quant_config.get("quant_algo") if quant_config else None
if not quant_config or not quant_algo:
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
excluded_patterns = quant_config.get("exclude_modules", [])
quant_impl = QuantizationImpl.create(quant_algo)
quantized_op = quantized_moe_op_map[quant_algo]
count = 0
for node in list(gm.graph.nodes):
if is_op(node, torch.ops.auto_deploy.torch_moe):
# Only quantize when every expert weight is eligible (none excluded)
w1_names, w2_names, w3_names = _extract_moe_weight_param_lists(node)
if any(
should_skip_quantization(n, excluded_patterns)
for n in w1_names + w2_names + w3_names
):
continue
_quantize_moe_node(gm, node, quant_impl, quantized_op)
count += 1
if count == 0:
return gm, TransformInfo(
skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True
)
info = TransformInfo(
skipped=False, num_matches=count, is_clean=False, has_valid_shapes=False
)
return gm, info

View File

@ -96,23 +96,24 @@ def named_graphmodules(gm: fx.GraphModule) -> Iterator[Tuple[str, fx.GraphModule
yield name, m
def _move_single_gm_to_device(
gm: GraphModule, device: torch.device, recompile_graph: bool = False
) -> None:
def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None:
"""Move one GraphModule and its nodes to the specified device in-place.
Partially inspired by https://github.com/pytorch/pytorch/blob/05cb98f91d49df9eadfcb3fc29bbd1b621d88860/torch/export/passes/__init__.py#L11
"""
# move state dict
gm.to(device)
recompile_graph = False
for node in gm.graph.nodes:
# move all node kwargs with a burnt-in device
if "device" in node.kwargs:
recompile_graph = True
kwargs = node.kwargs.copy()
kwargs["device"] = device
node.kwargs = kwargs
if is_op(node, torch.ops.aten.to.device):
recompile_graph = True
args = list(node.args)
args[1] = device
node.args = tuple(args)
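A toy torch.fx example (illustrative, not the repo's helper) of this burnt-in device kwarg rewrite; the module and device choice are assumptions for the demo.

import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x):
        ones = torch.ones(x.shape[0], device="cpu")  # device recorded as a node kwarg
        return x + ones

gm = fx.symbolic_trace(Toy())
target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gm.to(target)  # move the state dict (none here), mirroring gm.to(device) above
for node in gm.graph.nodes:
    if "device" in node.kwargs:
        kwargs = dict(node.kwargs)
        kwargs["device"] = target
        node.kwargs = kwargs
gm.recompile()
print(gm(torch.randn(4, device=target)).device)  # matches target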
@ -135,7 +136,7 @@ def move_to_device(gm: fx.GraphModule, device: DeviceLikeType) -> fx.GraphModule
for _, subgm in reversed(list(named_graphmodules(gm))):
# recompile graph to update self-generated code in subgraphs
_move_single_gm_to_device(subgm, device, subgm is not gm)
_move_single_gm_to_device(subgm, device)
def _is_impure_node(node: Node) -> bool:

View File

@ -1,13 +1,10 @@
"""A library of transformation passes."""
from .attention import *
from .collectives import *
from .eliminate_redundant_transposes import *
from .fused_moe import *
from .fusion import *
from .kvcache import *
from .quantization import *
from .quantize_moe import *
from .rms_norm import *
from .rope import *
from .sharding import *

View File

@ -1,833 +0,0 @@
"""Pattern matching for detecting repeat_kv pattern from Huggingface models."""
from typing import Dict, Optional, Type
import torch
from torch.fx import GraphModule, Node
from ...custom_ops.attention_interface import AttentionDescriptor
from ...utils.logger import ad_logger
from ...utils.node_utils import is_op
from .._graph import canonicalize_graph
def match_repeat_kv(gm: GraphModule) -> None:
"""
Match and replace the repeat_kv pattern in fx graphs.
The pattern is:
unsqueeze -> expand -> reshape -> [optional] contiguous
This is replaced with torch.ops.auto_deploy.torch_attention_repeat_kv.
"""
graph = gm.graph
num_kv_patterns = 0
# Iterate through nodes in the graph
for node in list(graph.nodes):
# Look for reshape nodes that could be the end of our pattern
if is_op(node, torch.ops.aten.reshape):
match_info = _match_repeat_kv_pattern(node)
if match_info:
ad_logger.debug(f"Found repeat_kv pattern at {node}")
_replace_with_repeat_kv(graph, match_info)
num_kv_patterns += 1
# Clean up the graph if we made any replacements
if num_kv_patterns:
canonicalize_graph(gm)
ad_logger.info(f"Found {num_kv_patterns} repeat_kv patterns")
def match_eager_attention(gm: GraphModule) -> None:
"""
Match and replace the eager attention pattern in fx graphs.
The pattern is:
transpose -> matmul -> mul -> (optional) add -> softmax -> to -> dropout -> matmul
This is replaced with torch.ops.auto_deploy.torch_attention_sdpa.
"""
graph = gm.graph
# Track replacements to avoid processing nodes multiple times
num_eager_patterns = 0
# Iterate through nodes in the graph
for node in list(graph.nodes):
# Look for the final matmul nodes that could be part of our pattern
if is_op(node, torch.ops.aten.matmul):
match_info = _match_eager_attention_pattern(node)
if match_info:
ad_logger.debug(f"Found eager attention pattern at {node}")
_replace_with_sdpa(graph, match_info)
num_eager_patterns += 1
# Clean up the graph if we made any replacements
if num_eager_patterns:
canonicalize_graph(gm)
ad_logger.info(f"Found {num_eager_patterns} eager attention patterns")
def match_grouped_attention(gm: GraphModule) -> None:
"""
Match and replace the grouped attention pattern in fx graphs.
The pattern is:
repeat_kv(k, n_rep) ->
repeat_kv(v, n_rep) ->
sdpa(q, repeated_k, repeated_v)
This is replaced with torch.ops.auto_deploy.torch_attention_grouped_sdpa.
"""
graph = gm.graph
# Track replacements to avoid processing nodes multiple times
num_grouped_patterns = 0
# Iterate through nodes in the graph
for node in list(graph.nodes):
# Look for SDPA nodes that could be part of our pattern
if is_op(node, torch.ops.auto_deploy.torch_attention_sdpa):
match_info = _match_grouped_attention_pattern(node)
if match_info:
ad_logger.debug(f"Found grouped attention pattern at {node}")
_replace_with_grouped_sdpa(graph, match_info)
num_grouped_patterns += 1
# Clean up the graph if we made any replacements
if num_grouped_patterns:
canonicalize_graph(gm)
ad_logger.info(f"Found {num_grouped_patterns} grouped attention patterns")
def match_causal_attn_mask(gm: GraphModule) -> None:
"""
Match attention operations with causal attention masks and optimize them.
For operations that use explicit causal masks, this replaces:
- sdpa(q, k, v, causal_mask, dropout_p, False, scale)
with:
- sdpa(q, k, v, None, dropout_p, True, scale)
This optimization enables more efficient implementations on supported backends.
"""
graph = gm.graph
# Track replacements to avoid processing nodes multiple times
num_causal_patterns = 0
# Iterate through nodes in the graph
for node in list(graph.nodes):
# Look for SDPA nodes or grouped SDPA nodes
if not (
is_op(node, torch.ops.auto_deploy.torch_attention_sdpa)
or is_op(node, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
):
continue
# Get the attention mask argument (4th argument)
if len(node.args) < 4 or node.args[3] is None:
continue
attn_mask = node.args[3]
# Check if this mask is a causal mask
if not _is_causal_mask(attn_mask):
ad_logger.debug(f"Found non-causal attention mask at {node=}!")
continue
ad_logger.debug(f"Found causal attention mask at {node}")
# construct the new args list with args provided to the node and the default values otherwise
new_args = []
for idx, arg in enumerate(node.target._schema.arguments):
# In case arg is provided to the node, use it
if idx < len(node.args):
new_args.append(node.args[idx])
# In case arg is not provided to the node, use the default value
elif arg.has_default_value:
new_args.append(arg.default_value)
else:
raise ValueError(f"Missing required argument: {arg.name}")
# Create new arguments with None mask and is_causal=True
new_args[3] = None # Set mask to None
new_args[5] = True # Set is_causal to True
# Create new node with updated arguments
with graph.inserting_before(node):
new_node = graph.call_function(node.target, args=tuple(new_args), kwargs=node.kwargs)
# Preserve metadata
new_node.meta = node.meta.copy()
# Replace the old node with the new one
node.replace_all_uses_with(new_node)
num_causal_patterns += 1
# Clean up the graph if we made any replacements
if num_causal_patterns:
canonicalize_graph(gm)
ad_logger.info(f"Found {num_causal_patterns} causal mask attention patterns")
def _match_repeat_kv_pattern(reshape_node: Node) -> Optional[Dict[str, Node]]:
"""
Match the repeat_kv pattern starting from a reshape node.
The pattern is:
unsqueeze -> expand -> reshape -> [optional] contiguous
Returns a dictionary with information about the match or None if no match.
"""
# Check that reshape_node is a reshape operation
if not is_op(reshape_node, torch.ops.aten.reshape):
return None
# The reshape should have expand as its first argument
if len(reshape_node.args) < 1:
return None
expand_node = reshape_node.args[0]
if not is_op(expand_node, torch.ops.aten.expand):
return None
# The expand should have unsqueeze as its first argument
if len(expand_node.args) < 1:
return None
unsqueeze_node = expand_node.args[0]
if not is_op(unsqueeze_node, torch.ops.aten.unsqueeze):
return None
# The unsqueeze should be inserting a dimension at position 2
if len(unsqueeze_node.args) < 2 or unsqueeze_node.args[1] != 2:
return None
# Get the input tensor to unsqueeze
if len(unsqueeze_node.args) < 1:
return None
input_tensor = unsqueeze_node.args[0]
# Check input dimensions - should be 4D (batch, num_key_value_heads, seq_len, head_dim)
input_val = input_tensor.meta.get("val", None)
if input_val is None or len(input_val.shape) != 4:
return None
# Extract batch size, num_kv_heads, seq_len, and head_dim from the input tensor shape
batch_size, num_kv_heads, seq_len, head_dim = input_val.shape
# Check reshape args
if len(reshape_node.args) < 2 or not isinstance(reshape_node.args[1], list):
return None
reshape_args = reshape_node.args[1]
if len(reshape_args) != 4:
return None
# Check expand args
if len(expand_node.args) < 2 or not isinstance(expand_node.args[1], list):
return None
expand_args = expand_node.args[1]
if len(expand_args) != 5:
return None
# Determine n_rep by comparing the output and input head dimensions
# In the expand args, we should have [batch, num_kv_heads, n_rep, seq_len, head_dim]
# In the reshape args, we should have [batch, num_heads, seq_len, head_dim]
# where num_heads = num_kv_heads * n_rep
_, _, n_rep, _, _ = expand_args
_, reshape_num_heads, _, _ = reshape_args
# Check that n_rep is an integer
if not isinstance(n_rep, int):
return None
# Check that num_heads = num_kv_heads * n_rep
# This may be a symbolic expression, so we need to compare with caution
reshape_out_val = reshape_node.meta.get("val", None)
if reshape_out_val is None or len(reshape_out_val.shape) != 4:
return None
# Ensure output shape is correct
out_batch, out_heads, out_seq, out_dim = reshape_out_val.shape
# Check that input batch and seq dimensions match output
if out_batch != batch_size or out_seq != seq_len or out_dim != head_dim:
return None
# Check if reshape is followed by a contiguous node
contiguous_node = None
users = list(reshape_node.users)
# Only consider contiguous if reshape has exactly one user
if len(users) == 1 and is_op(users[0], torch.ops.aten.contiguous):
contiguous_node = users[0]
result = {
"input_tensor": input_tensor,
"unsqueeze_node": unsqueeze_node,
"expand_node": expand_node,
"reshape_node": reshape_node,
"n_rep": n_rep,
}
if contiguous_node:
result["contiguous_node"] = contiguous_node
return result
def _match_eager_attention_pattern(final_matmul_node: Node) -> Optional[Dict[str, Node]]:
"""
Match the eager attention pattern starting from the final matmul node.
The pattern is:
transpose -> matmul -> mul/div -> (optional) add -> (optional) to -> softmax -> (optional) to -> dropout -> matmul
Returns a dictionary with information about the match or None if no match.
"""
# Check that final_matmul_node is a matmul operation
if not is_op(final_matmul_node, torch.ops.aten.matmul):
return None
# Check we have two arguments
if len(final_matmul_node.args) < 2:
return None
# The first arg of final matmul should be dropout
dropout_node = final_matmul_node.args[0]
if not is_op(dropout_node, torch.ops.aten.dropout):
return None
# The second arg of final matmul is the value tensor (possibly repeated/transformed)
value = final_matmul_node.args[1]
# The dropout should have a to_dtype node (or directly softmax) as input
if len(dropout_node.args) < 1:
return None
# Allow optional to_dtype node after softmax
to_dtype_after_softmax = dropout_node.args[0]
if is_op(to_dtype_after_softmax, torch.ops.aten.to):
if len(to_dtype_after_softmax.args) < 1:
return None
softmax_node = to_dtype_after_softmax.args[0]
else:
softmax_node = to_dtype_after_softmax
# Now we should have a softmax node
if not is_op(softmax_node, torch.ops.aten.softmax):
return None
# The softmax should have dim=-1 (may be specified in different ways)
if len(softmax_node.args) < 2 or (
isinstance(softmax_node.args[1], int) and softmax_node.args[1] != -1
):
# Check kwargs if not in args
if softmax_node.kwargs.get("dim", -1) != -1:
return None
# The softmax node's input can be:
# - direct from add/mul/div
# - or through a to_dtype node (like to_35 in the example)
if len(softmax_node.args) < 1:
return None
# Handle optional to_dtype node before softmax
prev_node = softmax_node.args[0]
if is_op(prev_node, torch.ops.aten.to):
if len(prev_node.args) < 1:
return None
prev_node = prev_node.args[0]
# Check for attention mask pattern (add node)
if is_op(prev_node, torch.ops.aten.add):
add_node = prev_node
attn_mask = add_node.args[1] # Second arg is the mask
# The add should have a mul or div node as its first argument
if len(add_node.args) < 1:
return None
scaling_node = add_node.args[0]
if not (is_op(scaling_node, torch.ops.aten.mul) or is_op(scaling_node, torch.ops.aten.div)):
return None
elif is_op(prev_node, torch.ops.aten.mul) or is_op(prev_node, torch.ops.aten.div):
# No mask case - the softmax input is directly the mul or div node
scaling_node = prev_node
attn_mask = None
else:
return None
# Check the scaling operation and extract the scaling factor
is_division = is_op(scaling_node, torch.ops.aten.div)
# The mul/div node should have a matmul node as input
if len(scaling_node.args) < 2:
return None
# Extract the scaling factor, adjusting for division vs multiplication
scale = scaling_node.args[1]
# Allow for constant or tensor scale
if not isinstance(scale, (float, int, Node)):
return None
# For division, we need to invert the scaling factor if it's a constant
if is_division and isinstance(scale, (float, int)):
scale = 1.0 / scale
first_matmul_node = scaling_node.args[0]
if not is_op(first_matmul_node, torch.ops.aten.matmul):
return None
# The first matmul should have the query and key transpose as inputs
if len(first_matmul_node.args) < 2:
return None
query = first_matmul_node.args[0]
transpose_key = first_matmul_node.args[1]
# Check for transpose, could be any dimensions
if not is_op(transpose_key, torch.ops.aten.transpose):
return None
# The transpose should have the key as input
if len(transpose_key.args) < 1:
return None
key = transpose_key.args[0]
# Create the match info dictionary
match_info = {
"query": query,
"key": key,
"value": value,
"scale": scale,
"dropout_p": dropout_node.args[1] if len(dropout_node.args) > 1 else 0.0,
"final_matmul": final_matmul_node,
}
# Add the attention mask if it exists
if attn_mask is not None:
match_info["attn_mask"] = attn_mask
return match_info
def _match_grouped_attention_pattern(sdpa_node: Node) -> Optional[Dict[str, Node]]:
"""
Match the grouped attention pattern starting from an SDPA node.
The pattern is:
repeat_kv(k, n_rep) ->
repeat_kv(v, n_rep) ->
sdpa(q, repeated_k, repeated_v)
Returns a dictionary with information about the match or None if no match.
"""
# Check that sdpa_node is an SDPA operation
if not is_op(sdpa_node, torch.ops.auto_deploy.torch_attention_sdpa):
return None
# SDPA should have query, key, value as its first three arguments
if len(sdpa_node.args) < 3:
return None
query, key_repeated, value_repeated = sdpa_node.args[0:3]
# Key and value should come from repeat_kv operations
if not is_op(key_repeated, torch.ops.auto_deploy.torch_attention_repeat_kv) or not is_op(
value_repeated, torch.ops.auto_deploy.torch_attention_repeat_kv
):
return None
# Extract the original key, value, and n_rep
orig_key = key_repeated.args[0]
orig_value = value_repeated.args[0]
key_n_rep = key_repeated.args[1]
value_n_rep = value_repeated.args[1]
# Both repeat_kv operations should have the same n_rep
if key_n_rep != value_n_rep:
return None
# Return the match information
return {
"query": query,
"key": orig_key,
"value": orig_value,
"key_repeated": key_repeated,
"value_repeated": value_repeated,
"n_rep": key_n_rep,
"sdpa_node": sdpa_node,
}
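# Hedged illustration (names hypothetical): the grouped-attention shape matched above, where
# key/value are expanded with the repeat_kv custom op before a vanilla SDPA call. After
# `_replace_with_grouped_sdpa` runs, the three ops collapse into a single
# torch_attention_grouped_sdpa call on the un-repeated key/value.
def _example_grouped_attention(q, k, v, n_rep: int):
    k_rep = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
    v_rep = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
    return torch.ops.auto_deploy.torch_attention_sdpa(
        q, k_rep, v_rep, attn_mask=None, dropout_p=0.0, is_causal=True
    )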
def _replace_with_repeat_kv(graph, match_info: Dict[str, Node]) -> None:
"""
Replace the matched repeat_kv pattern with the custom op.
"""
input_tensor = match_info["input_tensor"]
reshape_node = match_info["reshape_node"]
n_rep = match_info["n_rep"]
# Determine the node to replace (either reshape or contiguous if present)
node_to_replace = match_info.get("contiguous_node", reshape_node)
with graph.inserting_before(node_to_replace):
repeat_kv_node = graph.call_function(
torch.ops.auto_deploy.torch_attention_repeat_kv, args=(input_tensor, n_rep)
)
# Preserve metadata from the original node
repeat_kv_node.meta = node_to_replace.meta.copy()
# Replace all uses of the node with the repeat_kv node
node_to_replace.replace_all_uses_with(repeat_kv_node)
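# Hedged sketch of the HF-style repeat_kv chain this replacement collapses; exact shapes and
# names are illustrative and not taken from this file.
def _example_hf_repeat_kv(hidden_states, n_rep: int):
    batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
    expanded = hidden_states[:, :, None, :, :].expand(batch, n_kv_heads, n_rep, seq_len, head_dim)
    # the reshape (plus an optional .contiguous()) is the node that gets swapped for
    # torch.ops.auto_deploy.torch_attention_repeat_kv(hidden_states, n_rep)
    return expanded.reshape(batch, n_kv_heads * n_rep, seq_len, head_dim)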
def _replace_with_sdpa(graph, match_info: Dict[str, Node]) -> None:
"""
Replace the matched eager attention pattern with scaled_dot_product_attention.
"""
# retrieve the default op for scaled_dot_product_attention
sdpa_op = torch.ops.auto_deploy.torch_attention_sdpa.default
# construct the args for the ops based on the match_info and the op's schema
args = []
for arg in sdpa_op._schema.arguments:
if arg.name in match_info:
args.append(match_info[arg.name])
elif arg.has_default_value:
args.append(arg.default_value)
else:
raise ValueError(f"Missing required argument: {arg.name}")
args = tuple(args)
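    # For illustration (hedged; assuming the schema order is (query, key, value, attn_mask,
    # dropout_p, is_causal, scale)): a mask-less match with scale 0.125 would pack to roughly
    # (q_node, k_node, v_node, None, 0.0, False, 0.125).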
# retrieve the final matmul node to know where to insert the sdpa node
final_matmul = match_info["final_matmul"]
with graph.inserting_before(final_matmul):
sdpa_node = graph.call_function(sdpa_op, args=args)
# Preserve metadata from the original node
sdpa_node.meta = final_matmul.meta.copy()
# Replace all uses of the final matmul node with the sdpa node
final_matmul.replace_all_uses_with(sdpa_node)
def _replace_with_grouped_sdpa(graph, match_info: Dict[str, Node]) -> None:
"""
Replace the matched grouped attention pattern with torch.ops.auto_deploy.torch_attention_grouped_sdpa.
"""
sdpa_node = match_info["sdpa_node"]
query = match_info["query"]
key = match_info["key"]
value = match_info["value"]
# Construct the new args and kwargs
args = (query, key, value) + sdpa_node.args[3:]
kwargs = sdpa_node.kwargs.copy()
with graph.inserting_before(sdpa_node):
grouped_sdpa_node = graph.call_function(
torch.ops.auto_deploy.torch_attention_grouped_sdpa.default, args=args, kwargs=kwargs
)
# Preserve metadata from the original node
grouped_sdpa_node.meta = sdpa_node.meta.copy()
# Replace all uses of the SDPA node with the grouped_sdpa node
sdpa_node.replace_all_uses_with(grouped_sdpa_node)
def _is_causal_mask(mask_node: Node) -> bool:
"""
Determine if a node represents a causal attention mask.
Causal masks typically involve:
1. Creating a matrix with very negative values (e.g., -inf or close to it)
2. Using triu with offset 1 to create an upper triangular matrix
3. Usually involves comparison operations (gt, lt) with position indices
Returns True if the node appears to be a causal mask pattern.
"""
# Direct pattern from the test case: masked_fill with triu(ones,1) and -inf
if is_op(mask_node, torch.ops.aten.masked_fill):
mask_args = mask_node.args
if len(mask_args) >= 2:
_ = mask_args[0] # zero tensor
mask_tensor = mask_args[1]
fill_value = mask_args[2] if len(mask_args) > 2 else mask_node.kwargs.get("value", None)
# Check if fill value is very negative (e.g., -inf)
if fill_value is not None and (
fill_value == float("-inf")
or (isinstance(fill_value, (int, float)) and fill_value < -1e4)
):
# Try to trace back to find a triu pattern
if _has_triu_ancestor(mask_tensor, offset=1):
return True
# Pattern from negative_fill test case: masked_fill with ~triu(ones,1) and 0.0
# The negative_fill pattern has a pre-filled tensor with very negative values
# and zeros in the lower triangle
if is_op(mask_node, torch.ops.aten.masked_fill):
mask_args = mask_node.args
if len(mask_args) >= 2:
negative_tensor = mask_args[0]
mask_tensor = mask_args[1]
fill_value = mask_args[2] if len(mask_args) > 2 else mask_node.kwargs.get("value", None)
# Check if fill value is zero and the tensor is pre-filled with negative values
if fill_value == 0.0 or fill_value == 0:
# Check for the full tensor with negative values
if is_op(negative_tensor, torch.ops.aten.full):
fill_args = negative_tensor.args
if (
len(fill_args) > 1
and isinstance(fill_args[1], (int, float))
and fill_args[1] < -1e4
):
# This is likely a negative-filled tensor
# Now check if the mask is a bitwise_not of triu
if is_op(mask_tensor, torch.ops.aten.bitwise_not):
if len(mask_tensor.args) > 0 and _has_triu_ancestor(
mask_tensor.args[0], offset=1
):
return True
# Pattern for llama-3.1 style causal mask: slice of expand(unsqueeze(unsqueeze(mul_(triu, gt))))
if is_op(mask_node, torch.ops.aten.slice):
# Follow the chain backward to the source of the slice
if len(mask_node.args) == 0:
return False
slice_source = mask_node.args[0]
# Check for typical expand pattern
if not (slice_source and is_op(slice_source, torch.ops.aten.expand)):
return False
# Continue tracing back through the pattern
if len(slice_source.args) == 0:
return False
expand_source = slice_source.args[0]
# Check for first unsqueeze operation
if not (expand_source and is_op(expand_source, torch.ops.aten.unsqueeze)):
return False
# Look for the source of first unsqueeze
if len(expand_source.args) == 0:
return False
first_unsqueeze_source = expand_source.args[0]
# Check for second unsqueeze operation
if not (first_unsqueeze_source and is_op(first_unsqueeze_source, torch.ops.aten.unsqueeze)):
return False
# Look for the source of the second unsqueeze
if len(first_unsqueeze_source.args) == 0:
return False
second_unsqueeze_source = first_unsqueeze_source.args[0]
# Check for mul_ operation
if is_op(second_unsqueeze_source, torch.ops.aten.mul_):
# Check if one of the mul_ arguments is a triu operation
has_triu = False
for arg in second_unsqueeze_source.args:
if is_op(arg, torch.ops.aten.triu):
if len(arg.args) > 1 and arg.args[1] == 1:
has_triu = True
break
if has_triu:
# Check if one of the mul_ arguments involves a full tensor with negative values
for arg in second_unsqueeze_source.args:
if is_op(arg, torch.ops.aten.full):
if (
len(arg.args) > 1
and isinstance(arg.args[1], (int, float))
and arg.args[1] < -1e4
):
return True
return has_triu
# Original implementation for backward compatibility
if is_op(mask_node, torch.ops.aten.slice):
# Follow the chain backward to the source of the slice
if len(mask_node.args) == 0:
return False
slice_source = mask_node.args[0]
# Check for typical expand pattern
if not (slice_source and is_op(slice_source, torch.ops.aten.expand)):
return False
# Continue tracing back through the pattern
if len(slice_source.args) == 0:
return False
expand_source = slice_source.args[0]
# Check for unsqueeze operations
if not (expand_source and is_op(expand_source, torch.ops.aten.unsqueeze)):
return False
# Look for the source of the unsqueeze
if len(expand_source.args) == 0:
return False
unsqueeze_source = expand_source.args[0]
if not unsqueeze_source:
return False
# Check for triu pattern which is common in causal masks
if is_op(unsqueeze_source, torch.ops.aten.mul_):
for arg in unsqueeze_source.args:
if not is_op(arg, torch.ops.aten.triu):
continue
if len(arg.args) <= 1:
continue
triu_offset = arg.args[1]
# Causal masks typically use triu with offset 1
if triu_offset == 1:
return True
return False
# Check if we have a full tensor filled with a very negative number
if not is_op(unsqueeze_source, torch.ops.aten.full):
return False
if len(unsqueeze_source.args) <= 1:
return False
fill_value = unsqueeze_source.args[1]
# Check if the fill value is very negative (likely -inf or close)
if isinstance(fill_value, float) and fill_value < -1e10:
return True
# If we can't definitively identify it as causal, return False
return False
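# Hedged examples (shapes and names illustrative) of masks the heuristic above accepts as causal:
def _example_causal_masks(seq_len: int, device: str = "cpu", dtype=torch.float32):
    upper = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)
    # "triu" style: zeros with -inf above the diagonal
    triu_mask = torch.zeros(seq_len, seq_len, dtype=dtype, device=device).masked_fill(
        upper, float("-inf")
    )
    # "negative_fill" style: a fully negative tensor with zeros on and below the diagonal
    neg_fill_mask = torch.full(
        (seq_len, seq_len), float("-inf"), dtype=dtype, device=device
    ).masked_fill(~upper, 0.0)
    return triu_mask, neg_fill_mask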
def _has_triu_ancestor(node: Node, offset: int = 1, depth: int = 0, max_depth: int = 5) -> bool:
"""Helper function to find a triu operation in the ancestry of a node."""
if depth > max_depth: # Prevent infinite recursion
return False
if is_op(node, torch.ops.aten.triu):
if len(node.args) > 1 and node.args[1] == offset:
return True
# Check if any of the arguments has a triu ancestor
for arg in node.args:
if isinstance(arg, Node) and _has_triu_ancestor(arg, offset, depth + 1, max_depth):
return True
# Check if any of the kwargs has a triu ancestor
for value in node.kwargs.values():
if isinstance(value, Node) and _has_triu_ancestor(value, offset, depth + 1, max_depth):
return True
return False
def match_attention_layout(gm: GraphModule, attention_op: Type[AttentionDescriptor]) -> None:
"""
Match and transform attention operations to match the layout expected by the attention backend.
If the attention backend expects 'bnsd' layout (batch, num_heads, seq_len, head_dim), which
is the default for SDPA operations, we don't need to transform anything.
If the backend expects 'bsnd' layout (batch, seq_len, num_heads, head_dim), we insert
appropriate transposes before and after SDPA operations and replace them with bsnd_grouped_sdpa.
"""
# Get attention layout from attention_op
attention_layout = attention_op.get_attention_layout()
# List of SDPA operations to look for
sdpa_ops = {
torch.ops.auto_deploy.torch_attention_sdpa,
torch.ops.auto_deploy.torch_attention_grouped_sdpa,
}
graph = gm.graph
num_bsnd_patterns = 0
# Look for SDPA operations
for sdpa_node in list(graph.nodes):
if sdpa_node.op != "call_function" or not is_op(sdpa_node, sdpa_ops):
continue
ad_logger.debug(f"Found SDPA node to transform for bsnd layout: {sdpa_node}")
# Extract q, k, v inputs
q, k, v = sdpa_node.args[:3]
# Check if we need to transpose the inputs
if attention_layout == "bsnd":
# Add transposes before the node (from bnsd to bsnd)
with graph.inserting_before(sdpa_node):
q_updated = graph.call_function(torch.ops.aten.transpose.int, args=(q, 1, 2))
k_updated = graph.call_function(torch.ops.aten.transpose.int, args=(k, 1, 2))
v_updated = graph.call_function(torch.ops.aten.transpose.int, args=(v, 1, 2))
# Preserve fake tensor in meta["val"] for the transposed inputs
q_updated.meta["val"] = q.meta["val"].transpose(1, 2)
k_updated.meta["val"] = k.meta["val"].transpose(1, 2)
v_updated.meta["val"] = v.meta["val"].transpose(1, 2)
elif attention_layout == "bnsd":
# we don't need to do anything...
q_updated = q
k_updated = k
v_updated = v
else:
raise ValueError(f"Unsupported attention layout: {attention_layout}")
# Create the backend's source attention node with the same args as the original node,
# but using the (possibly transposed) inputs
with graph.inserting_before(sdpa_node):
source_sdpa_node = graph.call_function(
attention_op.get_source_attention_op(),
args=(q_updated, k_updated, v_updated) + sdpa_node.args[3:],
kwargs=sdpa_node.kwargs,
)
# Check if need to update the output node to match the layout
if attention_layout == "bsnd":
# Add transpose for the output (from bsnd back to bnsd)
with graph.inserting_after(source_sdpa_node):
output_updated = graph.call_function(
torch.ops.aten.transpose.int, args=(source_sdpa_node, 1, 2)
)
# Preserve fake tensors in meta["val"] for the new attention node and its transposed output
source_sdpa_node.meta["val"] = sdpa_node.meta["val"].transpose(1, 2).contiguous()
output_updated.meta["val"] = source_sdpa_node.meta["val"].transpose(1, 2)
elif attention_layout == "bnsd":
output_updated = source_sdpa_node
else:
raise ValueError(f"Unsupported attention layout: {attention_layout}")
# Replace the old node with the transposed output
sdpa_node.replace_all_uses_with(output_updated)
num_bsnd_patterns += 1
# Clean up the graph if we made any replacements
if num_bsnd_patterns:
canonicalize_graph(gm)
ad_logger.debug(f"Transformed graph for bsnd layout: {gm}")
ad_logger.info(f"Found and matched {num_bsnd_patterns} attention layouts")

View File

@ -174,7 +174,7 @@ def resize_kv_cache(
memory_for_forward_pass = free_mem_pre - free_mem_post
ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
new_cache_size = free_mem_post * free_mem_ratio + current_cache_size
new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
# Need to sync all the GPUs

View File

@ -141,6 +141,12 @@ def match_rope_pattern(gm: GraphModule) -> int:
torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float16),
torch.randn(batch_size, seq_len, head_dim // 2, device="meta", dtype=torch.float16),
]
# float32 inputs can change the traced graph when the pattern contains .float()
dummy_complex_2 = [
torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float32),
torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float32),
torch.randn(batch_size, seq_len, head_dim // 2, device="meta", dtype=torch.float32),
]
register_ad_pattern(
search_fn=_explicit_rope_pattern,
replace_fn=_explicit_rope_repl,
@ -172,6 +178,16 @@ def match_rope_pattern(gm: GraphModule) -> int:
},
scalar_workaround={"unsqueeze_dim": 1},
)
register_ad_pattern(
search_fn=_complex_rope_pattern,
replace_fn=_complex_rope_repl,
patterns=patterns,
dummy_args=dummy_complex_2,
op_ignore_types={
torch.ops.aten.reshape.default: (int,),
},
scalar_workaround={"unsqueeze_dim": 1},
)
num_matches = patterns.apply(graph)
canonicalize_graph(gm)

View File

@ -24,17 +24,10 @@ from .library import (
fuse_collectives,
fuse_rmsnorm,
insert_cached_attention,
match_attention_layout,
match_causal_attn_mask,
match_eager_attention,
match_grouped_attention,
match_moe_pattern,
match_repeat_kv,
match_rope_layout,
match_rope_pattern,
optimize_rope,
quantize,
quantize_moe,
resize_kv_cache,
sharding_transform_executor,
update_in_out_nodes,
@ -63,6 +56,12 @@ class InferenceOptimizer:
############################################################################################
# RUN MODULAR INFERENCE OPTIMIZER FOR ALREADY-MIGRATED TRANSFORMS
############################################################################################
# TODO (hg): default values that are not representable in YAML.
if "match_attention_layout" in self.ad_config.transforms:
self.ad_config.transforms[
"match_attention_layout"
].attention_op = AttentionRegistry.get(self.ad_config.attn_backend)
new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms)
egm = new_optimizer(cm)
@ -71,28 +70,10 @@ class InferenceOptimizer:
############################################################################################
# RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION
############################################################################################
# quantization
quantize(egm, self.factory.get_quant_config())
quantize_moe(egm, self.factory.get_quant_config())
# Match MoE pattern
match_moe_pattern(egm)
# Match repeat_kv pattern
match_repeat_kv(egm)
# Match eager attention pattern
match_eager_attention(egm)
# Match grouped attention pattern
match_grouped_attention(egm)
# Match and optimize causal attention masks
match_causal_attn_mask(egm)
# Match attention layout expected by our backend
match_attention_layout(egm, AttentionRegistry.get(self.ad_config.attn_backend))
# Match rope
match_rope_pattern(egm)

View File

@ -153,6 +153,8 @@ def register_ad_pattern(
5. register_replacement can auto-generate `search_fn_pattern` if you input `example_inputs`,
but that approach will fail when symbolic shapes are involved. Here
we explicitly trace & convert via `fx_to_pattern`.
6. The PatternMatcherPass checks num_users of the nodes, so the pattern must be
functionally isolated: no intermediate node may be shared with the rest of the graph.
"""
argnames = list(inspect.signature(search_fn).parameters.keys())

View File

@ -8,16 +8,49 @@ from _torch_test_utils import all_close, reset_parameters
from torch.export import export
from torch.fx import GraphModule
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory
from tensorrt_llm._torch.auto_deploy.transformations.library.sharding import ShardingTransformInfo
class FakeFactory:
def __init__(self, model: nn.Module):
self.model = model
class FakeFactory(ModelFactory):
"""Dummy factory to pass cache_config for testing."""
def build_model(self, device: str) -> nn.Module:
return self.model.to(device=device)
def __init__(self, model=None, cache_config=None, quant_config=None):
self._model = model
self.cache_config = cache_config
self.quant_config = quant_config
def build_model(self, device: str):
return self._model.to(device=device) if self._model else None
def _build_model(self, device: str):
return
def _load_checkpoint(self, model, device):
return
def get_cache_config(self):
return self.cache_config
def get_quant_config(self):
return self.quant_config
class SequenceEmbeddingInfo(SequenceInfo):
hidden_size: int
dtype: torch.dtype
def set_example_sequence(self) -> None:
super().set_example_sequence()
# set input ids to a 3D tensor (actually input embeddings)
self.input_ids = torch.rand(
*self.input_ids.shape,
self.hidden_size,
device=self.input_ids.device,
dtype=self.dtype,
)
def count_parameters(model: torch.nn.Module):
@ -32,6 +65,79 @@ def count_buffers(model: torch.nn.Module):
return sum(np.prod(b.shape) for b in model.buffers())
def run_test_transformed_gm(
model: nn.Module,
x: torch.Tensor,
gm_transformed: GraphModule,
check_transformed_graph: Callable[[GraphModule], bool],
_get_expected_num_params: Callable[[int], int],
atol: float = 1e-3,
rtol: float = 1e-3,
test_load_hook: bool = True,
strict_loading: bool = True,
dynamic_shapes: Dict = None,
skip_output_assert: bool = False,
*args, # Additional arguments for transform
) -> GraphModule:
# run model once
y_model = model(x)
# num params
num_params_model = count_parameters(model)
print(num_params_model)
# export + check (we clone the state dict to have a bit more freedom in testing below)
gm_ref = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
print(gm_ref)
y_gm = gm_ref(x)
num_params_gm = count_parameters(gm_ref)
assert num_params_model == num_params_gm
if not skip_output_assert:
torch.testing.assert_close(y_model, y_gm, atol=atol, rtol=rtol)
print(gm_transformed)
# in case buffers or other tensors were added during the transform
gm_transformed = gm_transformed.to("cuda")
y_transformed = gm_transformed(x)
n_p_transformed = count_parameters(gm_transformed)
n_p_t_expected = _get_expected_num_params(num_params_model)
assert n_p_transformed == n_p_t_expected, (
f"actual params {n_p_transformed} != expected params {n_p_t_expected}"
)
# check if the transformation worked
assert check_transformed_graph(gm_transformed)
if strict_loading and not skip_output_assert:
# check that the output matches without loading the state dict
torch.testing.assert_close(y_model, y_transformed, atol=atol, rtol=rtol)
if test_load_hook and not skip_output_assert:
# check if loading hook works from original state dict
reset_parameters(gm_transformed)
y_random = gm_transformed(x)
assert not all_close(y_model, y_random), f"{y_model=}, {y_random=}"
gm_transformed.load_state_dict(model.state_dict(), strict=True if strict_loading else False)
y_loaded_from_original = gm_transformed(x)
torch.testing.assert_close(y_model, y_loaded_from_original, atol=atol, rtol=rtol)
# check if loading hook works from state_dict of a transformed model
state_dict_sharded = copy.deepcopy(gm_transformed.state_dict())
reset_parameters(gm_transformed)
y_random2 = gm_transformed(x)
assert not all_close(y_model, y_random2), f"{y_model=}, {y_random2=}"
gm_transformed.load_state_dict(state_dict_sharded, strict=True if strict_loading else False)
y_loaded_from_transformed = gm_transformed(x)
torch.testing.assert_close(y_model, y_loaded_from_transformed, atol=atol, rtol=rtol)
# check if we can still export the model as expected
export(gm_transformed, args=(x,))
def run_test(
model: nn.Module,
x: torch.Tensor,

View File

@ -19,9 +19,6 @@ from build_and_run_ad import ExperimentConfig, main
],
)
def test_build_ad(world_size: int, experiment_config: Dict):
if world_size > 1:
pytest.skip("https://nvbugspro.nvidia.com/bug/5331013")
experiment_config["args"]["world_size"] = world_size
experiment_config["args"]["runtime"] = "trtllm" # Default runtime set to trtllm
experiment_config = ExperimentConfig(**experiment_config)

View File

@ -7,14 +7,8 @@ from torch.export import Dim
from torch.fx import GraphModule
from transformers.integrations.sdpa_attention import repeat_kv as hf_repeat_kv
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionDescriptor
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.transformations.library.attention import (
match_attention_layout,
match_causal_attn_mask,
match_eager_attention,
match_grouped_attention,
match_repeat_kv,
)
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
torch.manual_seed(0)
@ -164,16 +158,15 @@ class EagerAttentionModel(torch.nn.Module):
# Multiplication pattern
attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scaling
# Add attention mask if enabled
# Add causal attention mask if enabled
if self.has_mask:
# Create a simple causal mask for testing - make sure all tensors are on the same device
mask = torch.triu(
torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
diagonal=1,
# [1, 1, seq_len, seq_len] causal mask with -inf in the upper triangle
attn_mask = torch.triu(
torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1
)
mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len]
attn_mask = torch.zeros_like(attn_weights, device=device)
attn_mask = attn_mask.masked_fill(mask, float("-inf"))
attn_mask = (
attn_mask.unsqueeze(0).unsqueeze(0).to(x.dtype)
) # shape: [1, 1, seq_len, seq_len]
attn_weights = attn_weights + attn_mask
# Apply softmax, dtype conversion, and dropout
@ -249,13 +242,13 @@ class ComplexEagerAttentionModel(torch.nn.Module):
# Add attention mask if enabled
if self.has_mask:
mask = torch.triu(
torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
diagonal=1,
# [1, 1, seq_len, seq_len] causal mask with -inf in the upper triangle
attn_mask = torch.triu(
torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1
)
mask = mask.unsqueeze(0).unsqueeze(0)
attn_mask = torch.zeros_like(attn_weights, device=device)
attn_mask = attn_mask.masked_fill(mask, float("-inf"))
attn_mask = (
attn_mask.unsqueeze(0).unsqueeze(0).to(x.dtype)
) # shape: [1, 1, seq_len, seq_len]
attn_weights = attn_weights + attn_mask
# Add a to_dtype node before softmax to match pattern in the graph
@ -366,8 +359,6 @@ class GroupedAttentionModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = x.shape
device = x.device
dtype = x.dtype
# Generate q, k, v
q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
@ -385,28 +376,26 @@ class GroupedAttentionModel(torch.nn.Module):
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, self.n_rep)
# Create attention mask if needed
attn_mask = None
if self.has_mask:
# Simple causal mask
mask = torch.triu(
torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
diagonal=1,
attn_output = torch.ops.auto_deploy.torch_attention_sdpa(
q,
k,
v,
attn_mask=None,
dropout_p=self.dropout,
is_causal=True,
scale=1.0 / (self.head_dim**0.5),
)
else:
attn_output = torch.ops.auto_deploy.torch_attention_sdpa(
q,
k,
v,
attn_mask=None,
dropout_p=self.dropout,
is_causal=False,
scale=1.0 / (self.head_dim**0.5),
)
mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len]
attn_mask = torch.zeros(
(batch_size, 1, seq_len, seq_len), device=device, dtype=dtype
).masked_fill(mask, float("-inf"))
# Apply scaled dot product attention
attn_output = torch.ops.auto_deploy.torch_attention_sdpa(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=self.dropout,
is_causal=False,
scale=1.0 / (self.head_dim**0.5),
)
# Reshape output for the linear projection
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
@ -423,11 +412,47 @@ def _get_match_repeat_kv_optimizer() -> Callable:
"cleanup_noop_slice": {
"stage": "post_export",
},
"match_repeat_kv": {
"stage": "pattern_matcher",
},
}
def _transform(gm: GraphModule) -> GraphModule:
gm = InferenceOptimizer(None, config)(None, gm)
return gm
return _transform
def _get_match_eager_attention_optimizer() -> Callable:
config = {
"cleanup_noop_slice": {
"stage": "post_export",
},
"match_eager_attention": {
"stage": "pattern_matcher",
},
}
def _transform(gm: GraphModule) -> GraphModule:
gm = InferenceOptimizer(None, config)(None, gm)
return gm
return _transform
def _get_match_grouped_attention_optimizer() -> Callable:
config = {
"cleanup_noop_slice": {
"stage": "post_export",
},
"match_grouped_attention": {
"stage": "pattern_matcher",
},
}
def _transform(gm: GraphModule) -> GraphModule:
gm = InferenceOptimizer(None, config)(None, gm)
match_repeat_kv(gm)
return gm
return _transform
@ -516,8 +541,8 @@ def test_match_repeat_kv(num_heads, num_kv_heads, model_cls):
)
@pytest.mark.parametrize("has_mask", [True, False])
@pytest.mark.parametrize("use_division", [False, True])
@pytest.mark.parametrize("has_mask", [False, True])
@pytest.mark.parametrize("use_division", [True, False])
@pytest.mark.parametrize(
"dropout, skip_output_assert",
[
@ -537,8 +562,10 @@ def test_match_eager_attention(has_mask, use_division, dropout, skip_output_asse
# Create different model types based on the parameter
if model_type == "standard":
model = EagerAttentionModel(hidden_size, num_heads, has_mask, dropout, use_division).to(
"cuda", dtype=torch.float16
model = (
EagerAttentionModel(hidden_size, num_heads, has_mask, dropout, use_division)
.to("cuda", dtype=torch.float16)
.eval()
)
# Print the original scaling approach and value
if use_division:
@ -549,8 +576,10 @@ def test_match_eager_attention(has_mask, use_division, dropout, skip_output_asse
expected_scale = model.scaling
else: # complex
# Complex model only uses division for scaling
model = ComplexEagerAttentionModel(hidden_size, num_heads, has_mask, dropout).to(
"cuda", dtype=torch.float16
model = (
ComplexEagerAttentionModel(hidden_size, num_heads, has_mask, dropout)
.to("cuda", dtype=torch.float16)
.eval()
)
expected_scale = 1.0 / model.scale_divisor
# Override use_division and only run test once (ignore the parameterization)
@ -567,6 +596,7 @@ def test_match_eager_attention(has_mask, use_division, dropout, skip_output_asse
expected_matches = 1
def verify_matcher(gm):
# the eager attention pattern is replaced with torch_attention_sdpa after the transformation
sdpa_nodes = [
n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa)
]
@ -636,13 +666,15 @@ def test_match_eager_attention(has_mask, use_division, dropout, skip_output_asse
# Check mask handling for masked attention
if has_mask:
has_mask_arg = "attn_mask" in kwargs
if not has_mask_arg and len(node.args) >= 4:
has_mask_arg = node.args[3] is not None
is_causal = kwargs.get("is_causal", None)
if is_causal is None and len(node.args) >= 6:
is_causal = node.args[5]
if not has_mask_arg:
print("❌ Missing mask information in SDPA node")
if is_causal is not True:
print(f"❌ Expected is_causal=True for masked attention, got {is_causal}")
valid = False
else:
print("✅ is_causal correctly set to True")
print("Graph verification successful" if valid else "Graph verification failed")
return valid
@ -651,7 +683,7 @@ def test_match_eager_attention(has_mask, use_division, dropout, skip_output_asse
run_test(
model,
x,
match_eager_attention,
_get_match_eager_attention_optimizer(),
verify_matcher,
lambda num_p_og: num_p_og,
atol=1e-3,
@ -685,7 +717,7 @@ def test_counter_example():
_ = run_test(
model,
x,
match_repeat_kv,
_get_match_eager_attention_optimizer(),
verify_no_matches,
lambda num_p_og: num_p_og,
atol=1e-3,
@ -709,9 +741,8 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask):
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
dynamic_shapes = model.get_dynamic_shapes()
# We should find 1 instance of the pattern if num_heads != num_kv_heads
# Otherwise, no pattern should be matched (no grouped attention)
expected_matches = 1 if num_heads != num_kv_heads else 0
# We should find 1 instance of torch_attention_grouped_sdpa
expected_matches = 1
def verify_matcher(gm):
grouped_sdpa_nodes = [
@ -727,10 +758,6 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask):
)
return False
# If we don't expect any matches, we're done
if expected_matches == 0:
return True
# Otherwise, check the node properties
for node in grouped_sdpa_nodes:
# Basic checks: should have at least 3 positional args (q, k, v)
@ -743,16 +770,14 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask):
# Mask handling should be preserved
if has_mask:
# Check if attn_mask is in kwargs or provided via args
has_mask_arg = "attn_mask" in kwargs
if (
not has_mask_arg and len(node.args) >= 4
): # Assuming attn_mask is the 4th positional arg
has_mask_arg = node.args[3] is not None
is_causal = kwargs.get("is_causal", None)
if is_causal is None and len(node.args) >= 6:
is_causal = node.args[5]
if not has_mask_arg:
print("❌ Expected attn_mask in args or kwargs but not found")
return False
if is_causal is not True:
print(f"❌ Expected is_causal=True for masked attention, got {is_causal}")
else:
print("✅ is_causal correctly set to True")
return True
@ -760,7 +785,7 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask):
_ = run_test(
model,
x,
match_grouped_attention,
_get_match_grouped_attention_optimizer(),
verify_matcher,
lambda num_p_og: num_p_og,
atol=1e-3,
@ -884,98 +909,6 @@ class CausalAttentionModel(torch.nn.Module):
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=4, max=16)}
@pytest.mark.parametrize("mask_type", ["triu", "negative_fill", "non_causal"])
@pytest.mark.parametrize("use_grouped_sdpa", [False, True])
@torch.inference_mode()
def test_match_causal_attention(mask_type, use_grouped_sdpa):
batch_size, seq_len = 4, 12
hidden_size = 512
num_heads = 8
num_kv_heads = 4 if use_grouped_sdpa else num_heads
model = CausalAttentionModel(
hidden_size,
num_heads,
mask_type=mask_type,
use_grouped_sdpa=use_grouped_sdpa,
num_kv_heads=num_kv_heads,
).to("cuda", dtype=torch.float16)
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
dynamic_shapes = model.get_dynamic_shapes()
# We expect optimization (None mask + is_causal=True) when using causal masks
should_optimize = mask_type in ["triu", "negative_fill"]
def verify_matcher(gm):
# Find attention operations
if use_grouped_sdpa:
attn_nodes = [
n
for n in gm.graph.nodes
if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
]
else:
attn_nodes = [
n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa)
]
if len(attn_nodes) != 1:
print(f"Expected 1 attention node, found {len(attn_nodes)}")
return False
node = attn_nodes[0]
# Check if attention mask was set to None and is_causal was set to True
if should_optimize:
# Attention mask (4th arg) should be None
has_mask = (
node.args[3] is not None if len(node.args) > 3 else "attn_mask" in node.kwargs
)
# is_causal (6th arg) should be True
is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False)
# Check if optimization was correctly applied
if has_mask or not is_causal:
print("❌ Expected optimization: mask=None, is_causal=True")
print(
f" Got: mask={node.args[3] if len(node.args) > 3 else 'not in args'}, "
f"is_causal={is_causal}"
)
return False
print("✅ Successfully optimized causal mask: mask=None, is_causal=True")
else:
# Non-causal masks should remain as is
has_mask = (
node.args[3] is not None if len(node.args) > 3 else "attn_mask" in node.kwargs
)
# Check if non-optimization was correctly preserved
if not has_mask:
print("❌ Expected non-causal mask to be preserved")
return False
print("✅ Successfully preserved non-causal mask")
return True
# Run the test
_ = run_test(
model,
x,
match_causal_attn_mask,
verify_matcher,
lambda num_p_og: num_p_og,
atol=1e-3,
rtol=1e-3,
test_load_hook=True,
strict_loading=True,
dynamic_shapes=dynamic_shapes,
)
class Llama3CausalAttentionModel(torch.nn.Module):
"""Model that creates a causal attention mask mimicking the llama-3.1 pattern."""
@ -1082,78 +1015,7 @@ class Llama3CausalAttentionModel(torch.nn.Module):
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=4, max=16)}
@pytest.mark.parametrize("use_grouped_sdpa", [False, True])
@pytest.mark.skip(reason="Skip until we have more robust attention masking handling, see #4783")
@torch.inference_mode()
def test_match_llama3_causal_attention(use_grouped_sdpa):
batch_size, seq_len = 4, 12
hidden_size = 512
num_heads = 8
num_kv_heads = 4 if use_grouped_sdpa else num_heads
model = Llama3CausalAttentionModel(
hidden_size,
num_heads,
use_grouped_sdpa=use_grouped_sdpa,
num_kv_heads=num_kv_heads,
).to("cuda", dtype=torch.float32)
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float32)
dynamic_shapes = model.get_dynamic_shapes()
def verify_matcher(gm):
# Find attention operations
if use_grouped_sdpa:
attn_nodes = [
n
for n in gm.graph.nodes
if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
]
else:
attn_nodes = [
n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa)
]
if len(attn_nodes) != 1:
print(f"Expected 1 attention node, found {len(attn_nodes)}")
return False
node = attn_nodes[0]
# Attention mask (4th arg) should be None
has_mask = node.args[3] is not None if len(node.args) > 3 else "attn_mask" in node.kwargs
# is_causal (6th arg) should be True
is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False)
# Check if optimization was correctly applied
if has_mask or not is_causal:
print("❌ Expected optimization: mask=None, is_causal=True")
print(
f" Got: mask={node.args[3] if len(node.args) > 3 else 'not in args'}, "
f"is_causal={is_causal}"
)
return False
print("✅ Successfully optimized llama-3.1 causal mask: mask=None, is_causal=True")
return True
# Run the test
run_test(
model,
x,
match_causal_attn_mask,
verify_matcher,
lambda num_p_og: num_p_og,
atol=1e-3,
rtol=1e-3,
test_load_hook=True,
strict_loading=True,
dynamic_shapes=dynamic_shapes,
)
class MockAttentionDescriptor:
class MockAttentionDescriptor(AttentionDescriptor):
"""A mock class that mimics the AttentionDescriptor interface for testing."""
layout: str = "bnsd"
@ -1458,7 +1320,15 @@ def test_match_attention_layout(layout, model_config, has_mask):
run_test(
model,
x,
lambda gm: match_attention_layout(gm, MockAttentionDescriptor),
lambda gm: InferenceOptimizer(
None,
{
"match_attention_layout": {
"stage": "pattern_matcher",
"attention_op": MockAttentionDescriptor,
},
},
)(None, gm),
verify_matcher,
lambda num_p_og: num_p_og,
atol=1e-3,

View File

@ -1,26 +1,26 @@
"""Test that the attention matcher works with HF's SDPA backends."""
import copy
from typing import Any, Callable, Dict
import pytest
import torch
import torch.nn as nn
from _graph_test_helpers import run_test
from accelerate import init_empty_weights
from torch.export import Dim
from torch.fx import GraphModule
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
from tensorrt_llm._torch.auto_deploy.transformations.library import (
match_attention_layout,
match_causal_attn_mask,
match_eager_attention,
match_grouped_attention,
match_repeat_kv,
)
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionDescriptor
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device
torch.manual_seed(0)
class MockAttentionDescriptor:
class MockAttentionDescriptor(AttentionDescriptor):
"""A mock class that mimics the AttentionDescriptor interface for testing."""
layout: str = "bsnd"
@ -45,11 +45,24 @@ class HFWrapper(nn.Module):
def _joint_transform(gm: GraphModule) -> None:
match_repeat_kv(gm)
match_eager_attention(gm)
match_grouped_attention(gm)
match_causal_attn_mask(gm)
match_attention_layout(gm, MockAttentionDescriptor())
gm = InferenceOptimizer(
None,
{
"match_repeat_kv": {
"stage": "pattern_matcher",
},
"match_eager_attention": {
"stage": "pattern_matcher",
},
"match_grouped_attention": {
"stage": "pattern_matcher",
},
"match_attention_layout": {
"stage": "pattern_matcher",
"attention_op": MockAttentionDescriptor,
},
},
)(None, gm)
@pytest.mark.parametrize(
@ -65,23 +78,6 @@ def _joint_transform(gm: GraphModule) -> None:
["eager", "sdpa"],
)
def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str):
batch_size, seq_len = 4, 12
full_config = {
"num_hidden_layers": 1,
"vocab_size": 256,
"hidden_size": 128,
"intermediate_size": 128,
"attn_implementation": attn_implementation,
**config,
}
dynamic_shapes = {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=4, max=16)}
model = HFWrapper(LlamaModel(LlamaConfig(**full_config))).to("cuda")
model.eval()
x = torch.randint(
0, full_config["vocab_size"], (batch_size, seq_len), dtype=torch.long, device="cuda"
)
def verify_matcher(gm: GraphModule):
"""Ensure that there is exactly one torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
call in the graph. Also check that there is no repeat_kv pattern left.
@ -106,18 +102,69 @@ def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str)
op="call_function", target=torch.ops.auto_deploy.torch_attention_repeat_kv
)
assert len(nodes) == 0, "Found repeat_kv pattern in the graph"
attn_nodes = gm.graph.find_nodes(
op="call_function", target=torch.ops.auto_deploy.torch_attention_sdpa
)
assert len(attn_nodes) == 0, "Found torch_attention_sdpa node in the graph"
return True
_ = run_test(
model,
x,
_joint_transform,
verify_matcher,
lambda num_p_og: num_p_og,
atol=1e-3,
rtol=5e-2,
test_load_hook=True,
strict_loading=True,
dynamic_shapes=dynamic_shapes,
batch_size, seq_len = 2, 4
full_config = {
"num_hidden_layers": 1,
"vocab_size": 256,
"hidden_size": 128,
"intermediate_size": 128,
"attn_implementation": attn_implementation,
**config,
}
dynamic_shapes = {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=2, max=8)}
# Build and export model on meta device
with init_empty_weights():
model = HFWrapper(LlamaModel(LlamaConfig(**full_config))).eval()
x = torch.randint(
0, full_config["vocab_size"], (batch_size, seq_len), dtype=torch.long, device="cuda"
)
gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
print("Exported gm", gm)
gm_exported = copy.deepcopy(gm)
# Move model to cuda
device = "cuda"
model._apply(
lambda t: torch.normal(0.0, 1.0, size=t.shape, device=device).to(t.dtype)
if t.device == torch.device("meta")
else t.to(device)
)
y_model = model(x)
gm_exported._apply(
lambda t: torch.normal(0.0, 1.0, size=t.shape, device=device).to(t.dtype)
if t.device == torch.device("meta")
else t.to(device)
)
gm_exported.load_state_dict(model.state_dict())
move_to_device(gm_exported, "cuda")
y_gm_exported = gm_exported(x)
torch.testing.assert_close(y_gm_exported, y_model, atol=5e-3, rtol=5e-3)
# Apply transformation
_joint_transform(gm)
assert verify_matcher(gm)
print("Transformed gm", gm)
# Move gm to cuda
gm._apply(
lambda t: torch.normal(0.0, 1.0, size=t.shape, device=device).to(t.dtype)
if t.device == torch.device("meta")
else t.to(device)
)
gm.load_state_dict(model.state_dict())
move_to_device(gm, "cuda")
# Verify output
y_gm = gm(x)
torch.testing.assert_close(y_gm_exported, y_gm, atol=5e-2, rtol=5e-2)
torch.testing.assert_close(y_model, y_gm, atol=5e-2, rtol=5e-2)

View File

@ -1,10 +1,11 @@
import pytest
import torch
from _graph_test_helpers import run_test
from _graph_test_helpers import FakeFactory, run_test_transformed_gm
from _model_test_utils import MoEOpModel
from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available
from tensorrt_llm._torch.auto_deploy.transformations.library import quantize_moe
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
@ -62,13 +63,20 @@ def test_quantize_moe_transformation(quant_algo, expected_op):
quant_config = {"quant_algo": quant_algo}
def _transform(gm, *args):
return quantize_moe(gm, quant_config)
gm = torch_export_to_gm(model, args=(x,), clone=True)
gm_transformed = InferenceOptimizer(
FakeFactory(quant_config=quant_config),
{
"quantize_moe": {
"stage": "pattern_matcher",
},
},
)(None, gm)
_ = run_test(
run_test_transformed_gm(
model=model,
x=x,
transform=_transform,
gm_transformed=gm_transformed,
check_transformed_graph=_check_transformed_graph,
_get_expected_num_params=_expected_num_params,
atol=0.5,

View File

@ -4,13 +4,14 @@ Tests for basic graph sharding.
import pytest
import torch
from _graph_test_helpers import run_test
from _graph_test_helpers import run_test_transformed_gm
from _model_test_utils import MLP, BMMDynamicModel, BMMModel
from _torch_test_utils import fp4_compatible, fp8_compatible
from tensorrt_llm._torch.auto_deploy.custom_ops.quant import QUANT_OPS
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations.library import quantize
from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp8_scale
@ -19,6 +20,22 @@ def check_quantized(gm):
return any(is_op(n, QUANT_OPS) for n in gm.graph.nodes)
class DummyFactory(ModelFactory):
"""Dummy factory to pass quant_config for testing."""
def __init__(self, quant_config):
self.quant_config = quant_config
def _build_model(self, device: str):
return
def _load_checkpoint(self, model, device):
return
def get_quant_config(self):
return self.quant_config
@pytest.mark.parametrize(
"quant_config,atol,rtol,num_p_og",
[
@ -51,11 +68,22 @@ def test_quantization(quant_config, atol, rtol, num_p_og):
model.linear2.register_buffer(
"input_scale", torch.tensor([1.0], device=model.linear2.weight.device)
)
# export the model to a GraphModule
gm = torch_export_to_gm(model, args=(x,), clone=True)
gm_transformed = InferenceOptimizer(
DummyFactory(quant_config),
{
"quantize": {
"stage": "pattern_matcher",
},
},
)(None, gm)
gm_transformed.to("cuda")
gm_transformed = run_test(
run_test_transformed_gm(
model,
x,
quantize,
gm_transformed,
check_quantized,
num_p_og,
atol,
@ -122,10 +150,22 @@ def test_bmm_quantization(quant_config, atol, rtol, num_p_og, model_class):
model.register_buffer("bmm_dynamic_input_scale", fp8_scale(x))
model.register_buffer("bmm_dynamic_weight_scale", fp8_scale(dummy_weight))
gm_transformed = run_test(
# export the model to a GraphModule
gm = torch_export_to_gm(model, args=(x,), clone=True)
gm_transformed = InferenceOptimizer(
DummyFactory(quant_config),
{
"quantize": {
"stage": "pattern_matcher",
},
},
)(None, gm)
gm_transformed.to("cuda")
run_test_transformed_gm(
model,
x,
quantize,
gm_transformed,
check_quantized,
num_p_og,
atol,