Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[AutoDeploy] merge feat/ad-2025-07-22 (#6520)
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
Signed-off-by: Gal Agam <ghubaraagam@cw-dfw-cs-001-login-01.cm.cluster>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: haoguo <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Co-authored-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
Co-authored-by: Gal Agam <ghubaraagam@cw-dfw-h100-004-328-012.cm.cluster>
Co-authored-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Co-authored-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>
Co-authored-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Co-authored-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
This commit is contained in:
parent: 16febefee0
commit: 5247df6ae2
@@ -19,3 +19,15 @@ transforms:
    stage: post_export
  cleanup_input_constraints:
    stage: post_export
  quantize:
    stage: pattern_matcher
  quantize_moe:
    stage: pattern_matcher
  match_repeat_kv:
    stage: pattern_matcher
  match_eager_attention:
    stage: pattern_matcher
  match_grouped_attention:
    stage: pattern_matcher
  match_attention_layout:
    stage: pattern_matcher

@@ -7,7 +7,28 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

# TODO (nvchenghaoz): Remove related kernels once we have a backend-specific implementation for attention.

def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor:
    """Apply logit softcapping using the formula: logit_cap * tanh(logits / logit_cap)."""
    if logit_cap is not None and logit_cap > 0.0:
        return logit_cap * torch.tanh(attn_scores / logit_cap)
    return attn_scores


def _convert_boolean_mask_to_float(attn_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert boolean attention mask to floating point mask.

    Args:
        attn_mask: Boolean tensor where True allows attention, False blocks it.
        dtype: Target dtype for the output mask.

    Returns:
        Floating point mask where True -> 1.0, False -> -inf.
    """
    if attn_mask.dtype == torch.bool:
        float_mask = torch.zeros_like(attn_mask, dtype=dtype)
        float_mask = float_mask.masked_fill(attn_mask, 1.0)  # True -> 1.0
        float_mask = float_mask.masked_fill(~attn_mask, float("-inf"))  # False -> -inf
        return float_mask
    return attn_mask


@torch.library.custom_op("auto_deploy::torch_attention_repeat_kv", mutates_args=())
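As a quick aside on the two helpers added above: the sketch below (standalone, not part of this commit; the helper is re-implemented locally) shows that softcapping squashes every logit into (-logit_cap, logit_cap) and that a boolean mask becomes an additive float mask that softmax can consume. The commit's helper writes 1.0 rather than 0.0 into allowed positions; since the same constant is added to every unblocked score, the softmax result is unchanged.

import torch

def apply_logit_softcapping(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    # logit_cap * tanh(scores / logit_cap) keeps every value strictly inside (-logit_cap, logit_cap)
    return logit_cap * torch.tanh(scores / logit_cap)

scores = torch.tensor([[-120.0, -5.0, 0.0, 5.0, 120.0]])
capped = apply_logit_softcapping(scores, logit_cap=30.0)
assert capped.abs().max() < 30.0  # large magnitudes saturate near +/-30

# Boolean mask -> additive float mask: blocked (False) positions become -inf,
# so softmax assigns them zero probability.
bool_mask = torch.tensor([[True, True, False, True, True]])
float_mask = torch.zeros_like(scores).masked_fill(~bool_mask, float("-inf"))
probs = torch.softmax(capped + float_mask, dim=-1)
assert probs[0, 2] == 0.0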
@@ -77,19 +98,96 @@ def grouped_sdpa(
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    sinks: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    logit_cap: Optional[float] = None,
) -> torch.Tensor:
    """SDPA attention that can handle GQA."""
    """SDPA attention that can handle GQA. Expects bnsd format inputs."""
    b, n_heads, s_q, head_dim = query.shape  # bnsd format: [batch, num_heads, seq_len, head_dim]
    _, n_kv_heads, s_k, _ = key.shape  # bnsd format: [batch, num_kv_heads, seq_len, head_dim]

    return F.scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        enable_gqa=True,
    )
    # Inputs are already in bnsd format, no need to transpose
    query_t = query  # [b, n_heads, s_q, head_dim]
    key_t = key  # [b, n_kv_heads, s_k, head_dim]
    value_t = value  # [b, n_kv_heads, s_k, v_head_dim]

    # Handle GQA by repeating KV if needed
    if n_heads != n_kv_heads:
        n_rep = n_heads // n_kv_heads
        key_t = repeat_kv(key_t, n_rep)
        value_t = repeat_kv(value_t, n_rep)

    # Set scale
    if scale is None:
        scale = 1.0 / math.sqrt(head_dim)

    # Compute attention scores: Q @ K^T
    attn_scores = torch.matmul(query_t, key_t.transpose(-2, -1)) * scale  # [b, n_heads, s_q, s_k]

    # Apply attention mask if provided
    if attn_mask is not None:
        # Convert boolean mask to float if needed
        attn_mask = _convert_boolean_mask_to_float(attn_mask, attn_scores.dtype)
        attn_scores = attn_scores + attn_mask

    # Apply causal mask if specified and only during the context phase
    if is_causal and s_q == s_k:  # Only apply causal mask during context processing
        causal_mask = torch.triu(
            torch.ones(s_q, s_k, device=query.device, dtype=torch.bool),
            diagonal=1,  # Use diagonal=1 for standard causal masking
        )
        attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))

    # Apply sliding window mask if specified
    if sliding_window is not None and sliding_window > 0:
        # Handle position calculation for both context and generation phases
        if s_q == s_k:
            # Context phase: standard position calculation
            query_positions = torch.arange(s_q, device=query.device)
            key_positions = torch.arange(s_k, device=query.device)
        else:
            # Generation phase: query is at position s_k (after the cache)
            query_positions = torch.arange(s_k, s_k + s_q, device=query.device)  # [s_k] for s_q=1
            key_positions = torch.arange(s_k, device=query.device)  # [0, 1, 2, ..., s_k-1]

        # Create position difference matrix: query_pos - key_pos
        pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0)  # [s_q, s_k]

        # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window
        sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window)  # [s_q, s_k]
        attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf"))

    # Apply logit softcapping if enabled
    attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)

    # Apply sinks if provided
    if sinks is not None:
        # Fold the per-head sink logit into the softmax normalizer, following the reference implementation.
        # sinks should have n_heads elements; each head gets its own sink value.
        # Expand sinks to [b, n_heads, s_q, 1] - one sink column per head
        sinks_expanded = sinks.reshape(1, -1, 1, 1).expand(
            b, n_heads, s_q, 1
        )  # [b, n_heads, s_q, 1]

        # Normalize against the sink along the key dimension (last dimension)
        logits_max = torch.max(attn_scores, dim=-1, keepdim=True).values
        sinks = torch.exp(sinks_expanded - logits_max)
        unnormalized_scores = torch.exp(attn_scores - logits_max)
        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
        scores = unnormalized_scores / normalizer
        # The sink only shifts the normalizer; the output is computed from the regular scores
        attn_out = torch.matmul(scores, value_t)  # [b, n_heads, s_q, v_head_dim]
    else:
        attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype)
        attn_out = torch.matmul(attn_weights, value_t)  # [b, n_heads, s_q, v_head_dim]

    # Apply dropout if specified
    if dropout_p > 0.0:
        attn_out = F.dropout(attn_out, p=dropout_p, training=False)

    # Return in bnsd format (same as input format)
    return attn_out


@grouped_sdpa.register_fake
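For readers unfamiliar with attention sinks: the branch above never attends to a sink value; it only lets a per-head sink logit absorb part of the softmax mass. A standalone sketch of that normalization (plain PyTorch, independent of the custom op; names here are illustrative):

import torch

def softmax_with_sink(scores: torch.Tensor, sink_logit: torch.Tensor) -> torch.Tensor:
    """Normalize scores as if one extra 'sink' column with logit `sink_logit` were present,
    then drop that column, so some probability mass is absorbed by the sink."""
    logits_max = scores.max(dim=-1, keepdim=True).values
    exp_scores = torch.exp(scores - logits_max)
    exp_sink = torch.exp(sink_logit - logits_max)
    return exp_scores / (exp_scores.sum(dim=-1, keepdim=True) + exp_sink)

scores = torch.randn(2, 4, 1, 8)   # [b, n_heads, s_q, s_k]
sink = torch.zeros(2, 4, 1, 1)     # one sink logit per head and query position
probs = softmax_with_sink(scores, sink)

# Equivalent reference: softmax over [scores | sink] and discard the sink column.
ref = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)[..., :-1]
assert torch.allclose(probs, ref, atol=1e-6)
print(probs.sum(-1))  # row sums are < 1: the missing mass went to the sink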
@@ -101,6 +199,9 @@ def grouped_sdpa_fake(
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    sinks=None,
    sliding_window=None,
    logit_cap=None,
):
    """Fake implementation of grouped SDPA."""
    return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous()
@@ -108,9 +209,9 @@ def grouped_sdpa_fake(

@torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=())
def bsnd_grouped_sdpa(
    query: torch.Tensor,  # layout: [b, n, s_q, d]
    key: torch.Tensor,  # layout: [b, n, s_k, d]
    value: torch.Tensor,  # layout: [b, n, s_k, d]
    query: torch.Tensor,  # layout: [b, s_q, n, d]
    key: torch.Tensor,  # layout: [b, s_k, n, d]
    value: torch.Tensor,  # layout: [b, s_k, n, d]
    attn_mask: Optional[torch.Tensor] = None,  # layout: [b, n, s_q, s_k]
    dropout_p: float = 0.0,
    is_causal: bool = False,
@@ -124,14 +225,16 @@ def bsnd_grouped_sdpa(
    Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the
    original sdpa op!
    """
    # let's transpose to bnsd so we can use the grouped sdpa
    query = query.transpose(1, 2).contiguous()
    key = key.transpose(1, 2).contiguous()
    value = value.transpose(1, 2).contiguous()
    # Transpose inputs to bnsd format for grouped_sdpa
    query = query.transpose(1, 2).contiguous()  # [b, s_q, n, d] -> [b, n, s_q, d]
    key = key.transpose(1, 2).contiguous()  # [b, s_k, n, d] -> [b, n, s_k, d]
    value = value.transpose(1, 2).contiguous()  # [b, s_k, n, d] -> [b, n, s_k, d]

    out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)

    # let's transpose back to bnsd
    # Call grouped_sdpa with bnsd inputs
    out = grouped_sdpa(
        query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap
    )
    # Transpose back to bsnd format
    return out.transpose(1, 2).contiguous()

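The wrapper above only changes the tensor layout: callers pass bsnd tensors, the kernel works in bnsd, and the result is transposed back. A minimal sketch of that round trip, using torch's built-in scaled_dot_product_attention as a stand-in for grouped_sdpa:

import torch
import torch.nn.functional as F

b, s, n, d = 2, 16, 8, 64
q_bsnd = torch.randn(b, s, n, d)
k_bsnd = torch.randn(b, s, n, d)
v_bsnd = torch.randn(b, s, n, d)

# bsnd -> bnsd for the kernel, then back to bsnd for the caller,
# mirroring the wrapper's transpose(1, 2) round trip.
q, k, v = (t.transpose(1, 2).contiguous() for t in (q_bsnd, k_bsnd, v_bsnd))
out_bnsd = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_bsnd = out_bnsd.transpose(1, 2).contiguous()
assert out_bsnd.shape == (b, s, n, d)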
@@ -103,7 +103,7 @@ def _torch_generate_mha(
    # Apply sinks if provided (following the model file pattern)
    if sinks is not None:
        # Concatenate sinks to attention scores
        sinks = sinks.reshape(-1, 1, 1).expand(-1, attn_scores.shape[-2], -1)
        sinks = sinks.reshape(-1, 1, 1)
        attn_weights = torch.cat([attn_scores, sinks], dim=-1)
        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        # Use only the non-sink portion for computing output (ignore sinks)
@@ -202,9 +202,7 @@ def _torch_context_mha(
        )  # [seq_len_i, kv_seq_len]

        # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size
        sliding_window_mask = (pos_diff < 0) | (
            pos_diff >= sliding_window_size
        )  # [seq_len_i, kv_seq_len]
        sliding_window_mask = pos_diff >= sliding_window_size

        # Combine causal and sliding window masks
        combined_mask = causal_mask | sliding_window_mask
@@ -219,14 +217,14 @@ def _torch_context_mha(
        # Apply sinks if provided (following the model file pattern)
        if sinks is not None:
            # Concatenate sinks to attention scores
            sinks = sinks.reshape(1, -1, 1, 1).expand(
                attn_scores.shape[0], -1, attn_scores.shape[-2], -1
            new_sinks = sinks.reshape(1, -1, 1, 1).expand(
                attn_scores.shape[0], -1, attn_scores.shape[2], 1
            )
            attn_weights = torch.cat([attn_scores, sinks], dim=-1)
            attn_weights = torch.cat([attn_scores, new_sinks], dim=-1)
            attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
            # Use only the non-sink portion for computing output (ignore sinks)
            attn_out = torch.matmul(
                attn_weights[..., : -sinks.size(-1)], v_seq_t
                attn_weights[..., : -new_sinks.size(-1)], v_seq_t
            )  # [1, n_heads, seq_len_i, v_head_dim]
        else:
            attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)

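The _torch_context_mha change above drops the pos_diff < 0 term from the sliding-window mask. That is safe because the mask is OR-ed with the causal mask, and (assuming the causal mask blocks every position where the key comes after the query, i.e. pos_diff < 0, as a standard causal mask does) those positions are already blocked. A small standalone check of the equivalence:

import torch

seq_len, kv_seq_len, window = 6, 6, 3
q_pos = torch.arange(seq_len).unsqueeze(1)
k_pos = torch.arange(kv_seq_len).unsqueeze(0)
pos_diff = q_pos - k_pos                      # query_pos - key_pos

causal_mask = pos_diff < 0                    # key after query -> blocked
old_window_mask = (pos_diff < 0) | (pos_diff >= window)
new_window_mask = pos_diff >= window          # the simplified form in this commit

# Once OR-ed with the causal mask, both formulations block exactly the same positions.
assert torch.equal(causal_mask | old_window_mask, causal_mask | new_window_mask)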
@@ -17,7 +17,8 @@ try:
        rank, world_size = get_rank_world_size()
        assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
        p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
        torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.AUTO)
        # Use Strategy.NCCL until https://nvbugspro.nvidia.com/bug/5331013 is fixed, then change to Strategy.AUTO
        torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.NCCL)
        return torch_op(tensor, all_reduce_params=all_reduce_params)

    @torch.library.custom_op(

@@ -76,6 +76,12 @@ class AutoModelForCausalLMFactory(ModelFactory):
        "max_position_embeddings": 1024,
    }

    def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
        """Get the max position embeddings config for the model."""
        return {
            "max_position_embeddings": self.max_seq_len,
        }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

@@ -83,7 +89,11 @@ class AutoModelForCausalLMFactory(ModelFactory):

        # Ingest defaults for tokenizer and model kwargs
        self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs)
        self.model_kwargs = deep_merge_dicts(self._model_defaults, self.model_kwargs)
        self.model_kwargs = deep_merge_dicts(
            self._model_defaults,
            self.model_kwargs,
            self._get_max_position_embeddings_config(),
        )

        # special handling for torch_dtype in model_kwargs since HF does not correctly update
        # torch_dtype string to an actual torch.dtype object (only with default)
@@ -295,7 +305,7 @@ class AutoModelForCausalLMFactory(ModelFactory):
        # at this point it should be a directory (either the original one or the download dir)
        assert os.path.isdir(fetched_dir), f"Checkpoint path {fetched_dir} is not a directory."

        self._load_quantization_config()
        self._load_quantization_config(fetched_dir)

        return fetched_dir

@@ -313,13 +323,13 @@ class AutoModelForCausalLMFactory(ModelFactory):
        # model-transformed weights,leading to unexpected key mismatches or format issues.
        load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)

    def _load_quantization_config(self):
    def _load_quantization_config(self, fetched_dir: str):
        """Load the quantization config from the model directory if not done already."""
        if self._quant_config is not None:
            return

        assert self.model
        hf_quant_config_file = os.path.join(self.model, "hf_quant_config.json")
        hf_quant_config_file = os.path.join(fetched_dir, "hf_quant_config.json")
        if os.path.exists(hf_quant_config_file):
            with open(hf_quant_config_file, "r") as file:
                quantization_config = json.load(file)
@@ -344,6 +354,15 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
        },
    }

    def _get_max_position_embeddings_config(self) -> Dict[str, Any]:
        """Get the max position embeddings config for the model."""
        return {
            "max_position_embeddings": self.max_seq_len,
            "text_config": {
                "max_position_embeddings": self.max_seq_len,
            },
        }

    @property
    def automodel_from_config(self):
        return AutoModelForImageTextToText.from_config

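The factory change above folds a max_position_embeddings override (derived from max_seq_len) into the model kwargs merge. deep_merge_dicts is AutoDeploy's own utility; the sketch below uses a hypothetical minimal recursive merge and illustrative default values just to show the intended precedence: built-in defaults < user model_kwargs < the max-position override, with nested dicts such as text_config merged rather than replaced.

from typing import Any, Dict

def deep_merge(*dicts: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical stand-in for AutoDeploy's deep_merge_dicts: later dicts win,
    nested dicts are merged recursively instead of being overwritten wholesale."""
    merged: Dict[str, Any] = {}
    for d in dicts:
        for key, value in d.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = deep_merge(merged[key], value)
            else:
                merged[key] = value
    return merged

defaults = {"max_position_embeddings": 1024, "text_config": {"max_position_embeddings": 1024}}
user_kwargs = {"torch_dtype": "bfloat16"}
max_seq_len = 8192
override = {"max_position_embeddings": max_seq_len, "text_config": {"max_position_embeddings": max_seq_len}}

model_kwargs = deep_merge(defaults, user_kwargs, override)
print(model_kwargs)
# {'max_position_embeddings': 8192, 'text_config': {'max_position_embeddings': 8192}, 'torch_dtype': 'bfloat16'}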
@@ -227,18 +227,26 @@ class BaseTransform(ABC):
        # run or skip the transform
        if self.config.enabled:
            # run graph pre-cleanup
            self._run_pre_cleanup(gm, info_last)
            is_clean_pre, has_valid_shapes_pre = self._run_pre_cleanup(gm, info_last)

            # run the transform in an error-handling wrapper
            try:
                gm, info = self._apply(gm, cm, factory)
            except Exception as e:
                error_msg = f"Transform {t_name} failed"
                if self.config.skip_on_error:
            # run the transform in an error-handling wrapper if desired
            if self.config.skip_on_error:
                try:
                    gm, info = self._apply(gm, cm, factory)
                except Exception as e:
                    error_msg = f"Transform {t_name} failed"
                    ad_logger.warning(f"{error_msg}: {e}")
                    info = TransformInfo(skipped=True, num_matches=0)
                else:
                    raise TransformError(error_msg) from e
            else:
                # handle this here normally to improve debugging and error message
                gm, info = self._apply(gm, cm, factory)

            # we cannot say it's clean if the previous wasn't clean even if this one is
            # create new info object with updated cleanup status
            info_dict = info.model_dump()
            info_dict["is_clean"] &= is_clean_pre
            info_dict["has_valid_shapes"] &= has_valid_shapes_pre
            info = TransformInfo(**info_dict)

            # run graph post-cleanup
            info = self._run_post_cleanup(gm, info)
@@ -279,20 +287,36 @@ class BaseTransform(ABC):
        gm.meta[self._autodeploy_meta_key] = autodeploy_meta

    @final
    def _run_pre_cleanup(self, gm: GraphModule, info: TransformInfo) -> None:
    def _run_pre_cleanup(self, gm: GraphModule, info: TransformInfo) -> Tuple[bool, bool]:
        """Run graph cleanup before the transform.

        Args:
            gm: The graph module to run cleanup on.
            info: The last transform info.

        Returns:
            A tuple of (is_clean, has_valid_shapes) indicating the cleanup status after the
            pre-cleanup.

        This is used to ensure the transform is applied to a clean graph as needed by the transform.
        """
        if not self.config.requires_clean_graph:
            return
            return info.is_clean, info.has_valid_shapes

        is_clean = info.is_clean
        has_valid_shapes = is_clean and info.has_valid_shapes

        # check if run cleanup depending on the config and info
        if self.config.requires_shape_prop and not (info.is_clean and info.has_valid_shapes):
        if self.config.requires_shape_prop and not has_valid_shapes:
            with lift_to_meta(gm):
                canonicalize_graph(gm, shape_prop=True)
        elif self.config.requires_clean_graph and not info.is_clean:
            is_clean = True
            has_valid_shapes = True
        elif self.config.requires_clean_graph and not is_clean:
            canonicalize_graph(gm)
            is_clean = True

        return is_clean, has_valid_shapes

    @final
    def _run_post_cleanup(self, gm: GraphModule, info: TransformInfo) -> TransformInfo:

tensorrt_llm/_torch/auto_deploy/transform/library/attention.py (new file, 562 lines)
@@ -0,0 +1,562 @@
|
||||
"""Pattern matching for detecting repeat_kv, eager, grouped attention patterns from Huggingface models."""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pydantic import Field
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...custom_ops.attention_interface import AttentionDescriptor
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import is_op
|
||||
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
|
||||
from ..interface import BaseTransform, TransformConfig, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
def _apply_pattern(
|
||||
gm: GraphModule,
|
||||
pattern_name: str,
|
||||
register_fn: Callable[[ADPatternMatcherPass], None],
|
||||
) -> int:
|
||||
"""Utility to register and apply a pattern."""
|
||||
patterns = ADPatternMatcherPass()
|
||||
register_fn(patterns)
|
||||
num_matches = patterns.apply(gm.graph)
|
||||
return num_matches
|
||||
|
||||
|
||||
def _repeat_kv_pattern(hidden_states, n_rep) -> torch.Tensor:
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = torch.unsqueeze(hidden_states, 2)
|
||||
hidden_states = hidden_states.expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def _repeat_kv_repl(hidden_states, n_rep) -> torch.Tensor:
|
||||
return torch.ops.auto_deploy.torch_attention_repeat_kv(hidden_states, n_rep)
|
||||
|
||||
|
||||
# with causal_mask, no division
|
||||
def _sfdp_pattern_1(query, key, value, attention_mask, scaling, dropout):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = F.dropout(attn_weights, p=dropout, training=False)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
return attn_output
|
||||
|
||||
|
||||
def _sfdp_replacement_1(query, key, value, attention_mask, scaling, dropout):
|
||||
return torch.ops.auto_deploy.torch_attention_sdpa.default(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=dropout,
|
||||
is_causal=True,
|
||||
scale=scaling,
|
||||
)
|
||||
|
||||
|
||||
# no causal_mask, no division
|
||||
def _sfdp_pattern_2(query, key, value, scaling, dropout):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = F.dropout(attn_weights, p=dropout, training=False)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
return attn_output
|
||||
|
||||
|
||||
def _sfdp_replacement_2(query, key, value, scaling, dropout):
|
||||
return torch.ops.auto_deploy.torch_attention_sdpa.default(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=dropout,
|
||||
is_causal=False,
|
||||
scale=scaling,
|
||||
)
|
||||
|
||||
|
||||
# with causal_mask, with division
|
||||
def _sfdp_pattern_3(query, key, value, attention_mask, scaling, dropout):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) / scaling
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = F.dropout(attn_weights, p=dropout, training=False)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
return attn_output
|
||||
|
||||
|
||||
def _sfdp_replacement_3(query, key, value, attention_mask, scaling, dropout):
|
||||
scaling = 1.0 / scaling
|
||||
return torch.ops.auto_deploy.torch_attention_sdpa.default(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=dropout,
|
||||
is_causal=True,
|
||||
scale=scaling,
|
||||
)
|
||||
|
||||
|
||||
# no causal_mask, with division
|
||||
def _sfdp_pattern_4(query, key, value, scaling, dropout):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) / scaling
|
||||
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = F.dropout(attn_weights, p=dropout, training=False)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
return attn_output
|
||||
|
||||
|
||||
def _sfdp_replacement_4(query, key, value, scaling, dropout):
|
||||
scaling = 1.0 / scaling
|
||||
return torch.ops.auto_deploy.torch_attention_sdpa.default(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=dropout,
|
||||
is_causal=False,
|
||||
scale=scaling,
|
||||
)
|
||||
|
||||
|
||||
# no causal_mask, with division, explicit casting model
|
||||
def _sfdp_pattern_5(query, key, value, scaling, dropout):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) / scaling
|
||||
attn_weights = attn_weights.to(torch.float32)
|
||||
attn_weights = F.softmax(attn_weights, dim=-1).to(query.dtype)
|
||||
attn_weights = F.dropout(attn_weights, p=dropout, training=False)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
return attn_output
|
||||
|
||||
|
||||
def _sfdp_replacement_5(query, key, value, scaling, dropout):
|
||||
scaling = 1.0 / scaling
|
||||
return torch.ops.auto_deploy.torch_attention_sdpa.default(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=dropout,
|
||||
is_causal=False,
|
||||
scale=scaling,
|
||||
)
|
||||
|
||||
|
||||
# with causal_mask, with division, explicit casting model
|
||||
def _sfdp_pattern_6(query, key, value, attention_mask, scaling, dropout):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) / scaling
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = attn_weights.to(torch.float32)
|
||||
attn_weights = F.softmax(attn_weights, dim=-1).to(query.dtype)
|
||||
attn_weights = F.dropout(attn_weights, p=dropout, training=False)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
return attn_output
|
||||
|
||||
|
||||
def _sfdp_replacement_6(query, key, value, attention_mask, scaling, dropout):
|
||||
scaling = 1.0 / scaling
|
||||
return torch.ops.auto_deploy.torch_attention_sdpa.default(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=dropout,
|
||||
is_causal=True,
|
||||
scale=scaling,
|
||||
)
|
||||
|
||||
|
||||
# Only pass in causal attention mask in downstream standardized pipeline
|
||||
def _sfdp_pattern_7(query, key, value, attention_mask, scaling, dropout):
|
||||
return torch.ops.auto_deploy.torch_attention_sdpa.default(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=dropout,
|
||||
is_causal=False,
|
||||
scale=scaling,
|
||||
)
|
||||
|
||||
|
||||
def _sfdp_replacement_7(query, key, value, attention_mask, scaling, dropout):
|
||||
return torch.ops.auto_deploy.torch_attention_sdpa.default(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=dropout,
|
||||
is_causal=True if attention_mask is not None else False,
|
||||
scale=scaling,
|
||||
)
|
||||
|
||||
|
||||
# with causal_mask, no division, does not cast to fp32 for softmax
|
||||
def _sfdp_pattern_8(query, key, value, attention_mask, scaling, dropout):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||
attn_weights = F.dropout(attn_weights, p=dropout, training=False)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
return attn_output
|
||||
|
||||
|
||||
def _sfdp_replacement_8(query, key, value, attention_mask, scaling, dropout):
|
||||
return torch.ops.auto_deploy.torch_attention_sdpa.default(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=dropout,
|
||||
is_causal=True,
|
||||
scale=scaling,
|
||||
)
|
||||
|
||||
|
||||
def _get_sfdp_patterns() -> List[Dict[str, Any]]:
|
||||
bs, seq_len, n_heads, hidden_size = 8, 16, 8, 512
|
||||
head_dim = hidden_size // n_heads
|
||||
|
||||
def common_tensor():
|
||||
return torch.randn(bs, n_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
def causal_mask():
|
||||
return torch.randn(bs, 1, 1, seq_len, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
configs = [
|
||||
(_sfdp_pattern_1, _sfdp_replacement_1, True, 0.1234743, 0.85849734),
|
||||
(_sfdp_pattern_2, _sfdp_replacement_2, False, 0.234743, 0.5849734),
|
||||
(_sfdp_pattern_3, _sfdp_replacement_3, True, 0.34743, 0.849734),
|
||||
(_sfdp_pattern_4, _sfdp_replacement_4, False, 0.74321, 0.9734),
|
||||
(_sfdp_pattern_5, _sfdp_replacement_5, False, 0.874321, 0.89734),
|
||||
(_sfdp_pattern_6, _sfdp_replacement_6, True, 0.634743, 0.6849734),
|
||||
(_sfdp_pattern_7, _sfdp_replacement_7, True, 0.34743, 0.849734),
|
||||
(_sfdp_pattern_8, _sfdp_replacement_8, True, 0.2234743, 0.95849734),
|
||||
]
|
||||
|
||||
patterns = []
|
||||
for search_fn, replace_fn, has_mask, scale, dropout in configs:
|
||||
dummy_args = [common_tensor(), common_tensor(), common_tensor()]
|
||||
if has_mask:
|
||||
dummy_args.append(causal_mask())
|
||||
dummy_args.extend([scale, dropout])
|
||||
|
||||
patterns.append(
|
||||
{
|
||||
"search_fn": search_fn,
|
||||
"replace_fn": replace_fn,
|
||||
"dummy_args": dummy_args,
|
||||
"scalar_workaround": {"scaling": scale, "dropout": dropout},
|
||||
"op_ignore_types": {torch.ops.aten.to.dtype: (torch.dtype,)},
|
||||
}
|
||||
)
|
||||
|
||||
return patterns
|
||||
|
||||
|
||||
def _grouped_attn_pattern_1(q, k, v, n_rep, attn_mask, dropout_p, scale):
|
||||
k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
|
||||
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
|
||||
return torch.ops.auto_deploy.torch_attention_sdpa.default(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
|
||||
)
|
||||
|
||||
|
||||
def _grouped_attn_replacement_1(q, k, v, n_rep, attn_mask, dropout_p, scale):
|
||||
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
|
||||
)
|
||||
|
||||
|
||||
# Only expose torch_attention_grouped_sdpa after the transformation
|
||||
def _grouped_attn_pattern_2(q, k, v, attn_mask, dropout_p, scale):
|
||||
return torch.ops.auto_deploy.torch_attention_sdpa.default(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
|
||||
)
|
||||
|
||||
|
||||
def _grouped_attn_replacement_2(q, k, v, attn_mask, dropout_p, scale):
|
||||
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale
|
||||
)
|
||||
|
||||
|
||||
def _grouped_attn_pattern_3(q, k, v, n_rep, attn_mask, dropout_p, scale):
|
||||
k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
|
||||
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
|
||||
return torch.ops.auto_deploy.torch_attention_sdpa.default(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
|
||||
)
|
||||
|
||||
|
||||
def _grouped_attn_replacement_3(q, k, v, n_rep, attn_mask, dropout_p, scale):
|
||||
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
|
||||
)
|
||||
|
||||
|
||||
# Only expose torch_attention_grouped_sdpa after the transformation
|
||||
def _grouped_attn_pattern_4(q, k, v, attn_mask, dropout_p, scale):
|
||||
return torch.ops.auto_deploy.torch_attention_sdpa.default(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
|
||||
)
|
||||
|
||||
|
||||
def _grouped_attn_replacement_4(q, k, v, attn_mask, dropout_p, scale):
|
||||
return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("match_repeat_kv")
|
||||
class MatchRepeatKV(BaseTransform):
|
||||
"""
|
||||
Match and replace the repeat_kv pattern with torch.ops.auto_deploy.torch_attention_repeat_kv.
|
||||
"""
|
||||
|
||||
def _apply(
|
||||
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
def register_repeat_kv(patterns: ADPatternMatcherPass):
|
||||
dummy_args = [
|
||||
torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16),
|
||||
7,
|
||||
]
|
||||
register_ad_pattern(
|
||||
search_fn=_repeat_kv_pattern,
|
||||
replace_fn=_repeat_kv_repl,
|
||||
patterns=patterns,
|
||||
dummy_args=dummy_args,
|
||||
op_ignore_types={
|
||||
torch.ops.aten.reshape.default: (int,),
|
||||
torch.ops.aten.expand.default: (int,),
|
||||
},
|
||||
scalar_workaround={"n_rep": dummy_args[1]},
|
||||
)
|
||||
|
||||
num_kv_patterns = _apply_pattern(gm, "Repeat KV", register_repeat_kv)
|
||||
|
||||
if num_kv_patterns > 0:
|
||||
self.config.run_shape_prop = True
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
num_matches=num_kv_patterns,
|
||||
is_clean=False,
|
||||
has_valid_shapes=False,
|
||||
)
|
||||
|
||||
return gm, info
|
||||
|
||||
|
||||
@TransformRegistry.register("match_eager_attention")
|
||||
class MatchEagerAttention(BaseTransform):
|
||||
"""
|
||||
Match and replace the eager attention pattern with torch.ops.auto_deploy.torch_attention_sdpa.
|
||||
"""
|
||||
|
||||
def _apply(
|
||||
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
def register_eager_attention(patterns: ADPatternMatcherPass):
|
||||
for pattern_config in _get_sfdp_patterns():
|
||||
register_ad_pattern(**pattern_config, patterns=patterns)
|
||||
|
||||
num_eager_patterns = _apply_pattern(gm, "Eager Attention", register_eager_attention)
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
num_matches=num_eager_patterns,
|
||||
is_clean=False,
|
||||
has_valid_shapes=False,
|
||||
)
|
||||
|
||||
return gm, info
|
||||
|
||||
|
||||
@TransformRegistry.register("match_grouped_attention")
|
||||
class MatchGroupedAttention(BaseTransform):
|
||||
"""
|
||||
Match and replace the grouped attention pattern with
|
||||
torch.ops.auto_deploy.torch_attention_grouped_sdpa.
|
||||
"""
|
||||
|
||||
def _apply(
|
||||
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
def register_grouped_attention(patterns: ADPatternMatcherPass):
|
||||
q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
|
||||
k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
|
||||
v1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
|
||||
attn_mask = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16)
|
||||
dropout = 0.12345
|
||||
scale = 0.56789
|
||||
n_rep = 7
|
||||
|
||||
dummy_args_1 = [q, k1, v1, n_rep, attn_mask, dropout, scale]
|
||||
dummy_args_2 = [q, k1, v1, attn_mask, dropout, scale]
|
||||
|
||||
register_ad_pattern(
|
||||
search_fn=_grouped_attn_pattern_1,
|
||||
replace_fn=_grouped_attn_replacement_1,
|
||||
patterns=patterns,
|
||||
dummy_args=dummy_args_1,
|
||||
scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep},
|
||||
)
|
||||
register_ad_pattern(
|
||||
search_fn=_grouped_attn_pattern_2,
|
||||
replace_fn=_grouped_attn_replacement_2,
|
||||
patterns=patterns,
|
||||
dummy_args=dummy_args_2,
|
||||
scalar_workaround={
|
||||
"scale": scale,
|
||||
"dropout_p": dropout,
|
||||
},
|
||||
)
|
||||
register_ad_pattern(
|
||||
search_fn=_grouped_attn_pattern_3,
|
||||
replace_fn=_grouped_attn_replacement_3,
|
||||
patterns=patterns,
|
||||
dummy_args=dummy_args_1,
|
||||
scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep},
|
||||
)
|
||||
register_ad_pattern(
|
||||
search_fn=_grouped_attn_pattern_4,
|
||||
replace_fn=_grouped_attn_replacement_4,
|
||||
patterns=patterns,
|
||||
dummy_args=dummy_args_2,
|
||||
scalar_workaround={
|
||||
"scale": scale,
|
||||
"dropout_p": dropout,
|
||||
},
|
||||
)
|
||||
|
||||
num_grouped_patterns = _apply_pattern(gm, "Grouped Attention", register_grouped_attention)
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
num_matches=num_grouped_patterns,
|
||||
is_clean=False,
|
||||
has_valid_shapes=False,
|
||||
)
|
||||
|
||||
return gm, info
|
||||
|
||||
|
||||
class MatchAttentionLayoutConfig(TransformConfig):
|
||||
"""Configuration for the insert cached attention transform."""
|
||||
|
||||
attention_op: Type[AttentionDescriptor] = Field(description="The attention descriptor to use.")
|
||||
|
||||
|
||||
@TransformRegistry.register("match_attention_layout")
|
||||
class MatchAttentionLayout(BaseTransform):
|
||||
"""
|
||||
Match and transform attention operations to match the layout expected by the attention backend.
|
||||
|
||||
If the attention backend expects 'bnsd' layout (batch, num_heads, seq_len, head_dim), which
|
||||
is the default for SDPA operations, we don't need to transform anything.
|
||||
|
||||
If the backend expects 'bsnd' layout (batch, seq_len, num_heads, head_dim), we insert
|
||||
appropriate transposes before and after SDPA operations and replace them with bsnd_grouped_sdpa.
|
||||
"""
|
||||
|
||||
config: MatchAttentionLayoutConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return MatchAttentionLayoutConfig
|
||||
|
||||
def _apply(
|
||||
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
# Get attention layout from attention_op
|
||||
attention_layout = self.config.attention_op.get_attention_layout()
|
||||
|
||||
# List of SDPA operations to look for
|
||||
sdpa_ops = {
|
||||
torch.ops.auto_deploy.torch_attention_sdpa,
|
||||
torch.ops.auto_deploy.torch_attention_grouped_sdpa,
|
||||
}
|
||||
|
||||
graph = gm.graph
|
||||
num_bsnd_patterns = 0
|
||||
|
||||
# Look for SDPA operations
|
||||
for sdpa_node in list(graph.nodes):
|
||||
if sdpa_node.op != "call_function" or not is_op(sdpa_node, sdpa_ops):
|
||||
continue
|
||||
|
||||
ad_logger.debug(f"Found SDPA node to transform for bsnd layout: {sdpa_node}")
|
||||
|
||||
# Extract q, k, v inputs
|
||||
q, k, v = sdpa_node.args[:3]
|
||||
|
||||
# Check if we need to transpose the inputs
|
||||
if attention_layout == "bsnd":
|
||||
# Add transposes before the node (from bnsd to bsnd)
|
||||
with graph.inserting_before(sdpa_node):
|
||||
q_updated = graph.call_function(torch.ops.aten.transpose.int, args=(q, 1, 2))
|
||||
k_updated = graph.call_function(torch.ops.aten.transpose.int, args=(k, 1, 2))
|
||||
v_updated = graph.call_function(torch.ops.aten.transpose.int, args=(v, 1, 2))
|
||||
|
||||
# Preserve fake tensor in meta["val"] for the transposed inputs
|
||||
q_updated.meta["val"] = q.meta["val"].transpose(1, 2)
|
||||
k_updated.meta["val"] = k.meta["val"].transpose(1, 2)
|
||||
v_updated.meta["val"] = v.meta["val"].transpose(1, 2)
|
||||
elif attention_layout == "bnsd":
|
||||
# we don't need to do anything...
|
||||
q_updated = q
|
||||
k_updated = k
|
||||
v_updated = v
|
||||
else:
|
||||
raise ValueError(f"Unsupported attention layout: {attention_layout}")
|
||||
|
||||
# Create bsnd_grouped_sdpa node with the same args as the original node
|
||||
# but using the transposed inputs
|
||||
with graph.inserting_before(sdpa_node):
|
||||
source_sdpa_node = graph.call_function(
|
||||
self.config.attention_op.get_source_attention_op(),
|
||||
args=(q_updated, k_updated, v_updated) + sdpa_node.args[3:],
|
||||
kwargs=sdpa_node.kwargs,
|
||||
)
|
||||
|
||||
# Check if need to update the output node to match the layout
|
||||
if attention_layout == "bsnd":
|
||||
# Add transpose for the output (from bsnd back to bnsd)
|
||||
with graph.inserting_after(source_sdpa_node):
|
||||
output_updated = graph.call_function(
|
||||
torch.ops.aten.transpose.int, args=(source_sdpa_node, 1, 2)
|
||||
)
|
||||
|
||||
# Preserve fake tensor in meta["val"] for the transposed inputs
|
||||
source_sdpa_node.meta["val"] = sdpa_node.meta["val"].transpose(1, 2).contiguous()
|
||||
output_updated.meta["val"] = source_sdpa_node.meta["val"].transpose(1, 2)
|
||||
elif attention_layout == "bnsd":
|
||||
output_updated = source_sdpa_node
|
||||
else:
|
||||
raise ValueError(f"Unsupported attention layout: {attention_layout}")
|
||||
|
||||
# Replace the old node with the transposed output
|
||||
sdpa_node.replace_all_uses_with(output_updated)
|
||||
|
||||
num_bsnd_patterns += 1
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
num_matches=num_bsnd_patterns,
|
||||
is_clean=False,
|
||||
has_valid_shapes=False,
|
||||
)
|
||||
|
||||
return gm, info
|
||||
@ -1,11 +1,12 @@
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from typing import Any, Dict
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...utils.logger import ad_logger
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils.node_utils import (
|
||||
extract_param_names_from_lin_node,
|
||||
get_quantization_params_from_linear_node,
|
||||
@ -20,7 +21,7 @@ from ...utils.quantization_utils import (
|
||||
remove_output_quantizers,
|
||||
should_skip_quantization,
|
||||
)
|
||||
from .._graph import canonicalize_graph
|
||||
from ..interface import BaseTransform, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
def _insert_quantized_linear(
|
||||
@ -138,12 +139,8 @@ def _insert_quantized_bmm(
|
||||
scale_target_module = gm # Register in root module
|
||||
scale_name_prefix = ""
|
||||
|
||||
ad_logger.info(f"Quantized BMM with dynamic weight tensor for node {node}")
|
||||
else:
|
||||
# If we can't determine the shape, skip quantization
|
||||
ad_logger.warning(
|
||||
f"BMM weight is dynamic tensor without shape metadata, skipping quantization for node {node}"
|
||||
)
|
||||
return
|
||||
|
||||
# Common logic for both parameter and dynamic tensor cases
|
||||
@ -169,53 +166,70 @@ def _insert_quantized_bmm(
|
||||
node.args = (*node.args, *scale_values)
|
||||
|
||||
|
||||
def quantize(gm: GraphModule, quant_config: Dict[str, Any]) -> None:
|
||||
"""Quantize the GraphModule and replace linear with quantized linear."""
|
||||
# extract info from quant_config
|
||||
is_quant_graph = is_quantized_graph(gm)
|
||||
quant_algo = quant_config.get("quant_algo")
|
||||
excluded_patterns = quant_config.get("exclude_modules", [])
|
||||
@TransformRegistry.register("quantize")
|
||||
class Quantization(BaseTransform):
|
||||
"""Quantize the GraphModule and replace linear/BMM with quantized linear/BMM."""
|
||||
|
||||
# no quantization to do
|
||||
if not (is_quant_graph or quant_config):
|
||||
ad_logger.info("No quantization to do.")
|
||||
return
|
||||
|
||||
# tracking quantized operations in the graph
|
||||
quantized_nodes: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
for n in gm.graph.nodes:
|
||||
if should_skip_quantization(n, excluded_patterns):
|
||||
continue
|
||||
|
||||
# Process linear operations
|
||||
if is_linear_op(n, include_quantization=False):
|
||||
# get per-layer quantization format from the node
|
||||
quant_algo_n: str = (
|
||||
get_quantization_from_linear_node(n) if is_quant_graph else quant_algo
|
||||
def _apply(
|
||||
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
# extract info from quant_config
|
||||
quant_config = factory.get_quant_config()
|
||||
if not quant_config:
|
||||
return gm, TransformInfo(
|
||||
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
if not quant_algo_n:
|
||||
|
||||
is_quant_graph = is_quantized_graph(gm)
|
||||
quant_algo = quant_config.get("quant_algo")
|
||||
excluded_patterns = quant_config.get("exclude_modules", [])
|
||||
if not quant_algo:
|
||||
return gm, TransformInfo(
|
||||
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
|
||||
# tracking quantized operations in the graph
|
||||
quantized_nodes: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
for n in gm.graph.nodes:
|
||||
if should_skip_quantization(n, excluded_patterns):
|
||||
continue
|
||||
|
||||
# insert quantized linear node
|
||||
_insert_quantized_linear(gm, n, QuantizationImpl.create(quant_algo_n), is_quant_graph)
|
||||
quantized_nodes[quant_algo_n]["linear"] += 1
|
||||
# Process linear operations
|
||||
if is_linear_op(n, include_quantization=False):
|
||||
# get per-layer quantization format from the node
|
||||
quant_algo_n: str = (
|
||||
get_quantization_from_linear_node(n) if is_quant_graph else quant_algo
|
||||
)
|
||||
if not quant_algo_n:
|
||||
continue
|
||||
|
||||
# Process BMM operations
|
||||
elif is_bmm_op(n):
|
||||
if not quant_algo:
|
||||
continue
|
||||
# insert quantized linear node
|
||||
_insert_quantized_linear(
|
||||
gm, n, QuantizationImpl.create(quant_algo_n), is_quant_graph
|
||||
)
|
||||
quantized_nodes[quant_algo_n]["linear"] += 1
|
||||
|
||||
# insert quantized bmm node
|
||||
_insert_quantized_bmm(
|
||||
gm, n, QuantizationImpl.create(quant_algo, is_bmm=True), is_quant_graph
|
||||
)
|
||||
quantized_nodes[quant_algo]["bmm"] += 1
|
||||
# Process BMM operations
|
||||
elif is_bmm_op(n):
|
||||
if not quant_algo:
|
||||
continue
|
||||
|
||||
if is_quant_graph:
|
||||
remove_output_quantizers(gm)
|
||||
# insert quantized bmm node
|
||||
_insert_quantized_bmm(
|
||||
gm, n, QuantizationImpl.create(quant_algo, is_bmm=True), is_quant_graph
|
||||
)
|
||||
quantized_nodes[quant_algo]["bmm"] += 1
|
||||
|
||||
canonicalize_graph(gm)
|
||||
for quant_algo in quantized_nodes:
|
||||
for op_type, count in quantized_nodes[quant_algo].items():
|
||||
ad_logger.info(f"Found {count} {quant_algo} quantized {op_type} nodes.")
|
||||
ad_logger.debug("After quantization: " + str(gm))
|
||||
if is_quant_graph:
|
||||
remove_output_quantizers(gm)
|
||||
|
||||
num_matches = 0
|
||||
for quant_algo in quantized_nodes:
|
||||
for op_type, count in quantized_nodes[quant_algo].items():
|
||||
num_matches += count
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
|
||||
)
|
||||
|
||||
return gm, info
|
||||
@ -1,14 +1,15 @@
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...utils.logger import ad_logger
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils.node_utils import is_op
|
||||
from ...utils.quantization_utils import QuantizationImpl, should_skip_quantization
|
||||
from .._graph import canonicalize_graph
|
||||
from ..interface import BaseTransform, TransformInfo, TransformRegistry
|
||||
|
||||
quantized_moe_op_map = {
|
||||
"FP8": torch.ops.auto_deploy.torch_quant_fp8_moe,
|
||||
@ -92,47 +93,10 @@ def _quantize_moe_node(
|
||||
quantized_op,
|
||||
args=tuple(args),
|
||||
)
|
||||
ad_logger.debug(f"Updating {node.name} args to {new_node.args}")
|
||||
node.replace_all_uses_with(new_node)
|
||||
gm.graph.erase_node(node)
|
||||
|
||||
|
||||
def quantize_moe(gm: GraphModule, quant_config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Traverse gm, find every torch.ops.auto_deploy.torch_moe, and replace it with the
|
||||
quantized version using the quant_algo from quant_config.
|
||||
"""
|
||||
quant_algo = quant_config.get("quant_algo")
|
||||
if not quant_algo:
|
||||
ad_logger.info("No quantization to do.")
|
||||
return gm
|
||||
excluded_patterns = quant_config.get("exclude_modules", [])
|
||||
|
||||
quant_impl = QuantizationImpl.create(quant_algo)
|
||||
quantized_op = quantized_moe_op_map[quant_algo]
|
||||
|
||||
count = 0
|
||||
|
||||
for node in list(gm.graph.nodes):
|
||||
if is_op(node, torch.ops.auto_deploy.torch_moe):
|
||||
# Check that all expert weights should be quantized
|
||||
w1_names, w2_names, w3_names = _extract_moe_weight_param_lists(node)
|
||||
if any(
|
||||
should_skip_quantization(n, excluded_patterns)
|
||||
for n in w1_names + w2_names + w3_names
|
||||
):
|
||||
continue
|
||||
_quantize_moe_node(gm, node, quant_impl, quantized_op)
|
||||
count += 1
|
||||
|
||||
if count == 0:
|
||||
return gm
|
||||
|
||||
gm = canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {count} {quant_algo} quantized {quantized_op} nodes.")
|
||||
return
|
||||
|
||||
|
||||
# TODO(Fridah-nv): robust handling similar to `extract_param_names_from_lin_node` or expand it
|
||||
def _extract_moe_weight_param_lists(moe_node: Node) -> Tuple[List[str], List[str], List[str]]:
|
||||
"""
|
||||
@ -165,3 +129,51 @@ def _extract_moe_weight_param_lists(moe_node: Node) -> Tuple[List[str], List[str
|
||||
w3_names = _unwrap_list(w3_list)
|
||||
|
||||
return w1_names, w2_names, w3_names
|
||||
|
||||
|
||||
@TransformRegistry.register("quantize_moe")
|
||||
class QuantizeMOE(BaseTransform):
|
||||
"""
|
||||
Traverse gm, find every torch.ops.auto_deploy.torch_moe, and replace it with the
|
||||
quantized version using the quant_algo from quant_config.
|
||||
"""
|
||||
|
||||
def _apply(
|
||||
self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
quant_config = factory.get_quant_config()
|
||||
quant_algo = quant_config.get("quant_algo") if quant_config else None
|
||||
|
||||
if not quant_config or not quant_algo:
|
||||
return gm, TransformInfo(
|
||||
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
excluded_patterns = quant_config.get("exclude_modules", [])
|
||||
|
||||
quant_impl = QuantizationImpl.create(quant_algo)
|
||||
quantized_op = quantized_moe_op_map[quant_algo]
|
||||
|
||||
count = 0
|
||||
|
||||
for node in list(gm.graph.nodes):
|
||||
if is_op(node, torch.ops.auto_deploy.torch_moe):
|
||||
# Check that all expert weights should be quantized
|
||||
w1_names, w2_names, w3_names = _extract_moe_weight_param_lists(node)
|
||||
if any(
|
||||
should_skip_quantization(n, excluded_patterns)
|
||||
for n in w1_names + w2_names + w3_names
|
||||
):
|
||||
continue
|
||||
_quantize_moe_node(gm, node, quant_impl, quantized_op)
|
||||
count += 1
|
||||
|
||||
if count == 0:
|
||||
return gm, TransformInfo(
|
||||
skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False, num_matches=count, is_clean=False, has_valid_shapes=False
|
||||
)
|
||||
|
||||
return gm, info
|
||||
@ -96,23 +96,24 @@ def named_graphmodules(gm: fx.GraphModule) -> Iterator[Tuple[str, fx.GraphModule
|
||||
yield name, m
|
||||
|
||||
|
||||
def _move_single_gm_to_device(
|
||||
gm: GraphModule, device: torch.device, recompile_graph: bool = False
|
||||
) -> None:
|
||||
def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None:
|
||||
"""Move one GraphModule and its nodes to the specified device in-place.
|
||||
Partially inspired by https://github.com/pytorch/pytorch/blob/05cb98f91d49df9eadfcb3fc29bbd1b621d88860/torch/export/passes/__init__.py#L11
|
||||
"""
|
||||
# move state dict
|
||||
gm.to(device)
|
||||
recompile_graph = False
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
# move all the nodes kwargs with burnt-in device
|
||||
if "device" in node.kwargs:
|
||||
recompile_graph = True
|
||||
kwargs = node.kwargs.copy()
|
||||
kwargs["device"] = device
|
||||
node.kwargs = kwargs
|
||||
|
||||
if is_op(node, torch.ops.aten.to.device):
|
||||
recompile_graph = True
|
||||
args = list(node.args)
|
||||
args[1] = device
|
||||
node.args = tuple(args)
|
||||
@ -135,7 +136,7 @@ def move_to_device(gm: fx.GraphModule, device: DeviceLikeType) -> fx.GraphModule
|
||||
|
||||
for _, subgm in reversed(list(named_graphmodules(gm))):
|
||||
# recompile graph to update self generated codes in subgraph
|
||||
_move_single_gm_to_device(subgm, device, subgm is not gm)
|
||||
_move_single_gm_to_device(subgm, device)
|
||||
|
||||
|
||||
def _is_impure_node(node: Node) -> bool:
|
||||
|
||||
@@ -1,13 +1,10 @@
"""A library of transformation passes."""

from .attention import *
from .collectives import *
from .eliminate_redundant_transposes import *
from .fused_moe import *
from .fusion import *
from .kvcache import *
from .quantization import *
from .quantize_moe import *
from .rms_norm import *
from .rope import *
from .sharding import *

@ -1,833 +0,0 @@
|
||||
"""Pattern matching for detecting repeat_kv pattern from Huggingface models."""
|
||||
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...custom_ops.attention_interface import AttentionDescriptor
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import is_op
|
||||
from .._graph import canonicalize_graph
|
||||
|
||||
|
||||
def match_repeat_kv(gm: GraphModule) -> None:
|
||||
"""
|
||||
Match and replace the repeat_kv pattern in fx graphs.
|
||||
|
||||
The pattern is:
|
||||
unsqueeze -> expand -> reshape -> [optional] contiguous
|
||||
|
||||
This is replaced with torch.ops.auto_deploy.torch_attention_repeat_kv.
|
||||
"""
|
||||
graph = gm.graph
|
||||
|
||||
num_kv_patterns = 0
|
||||
|
||||
# Iterate through nodes in the graph
|
||||
for node in list(graph.nodes):
|
||||
# Look for reshape nodes that could be the end of our pattern
|
||||
if is_op(node, torch.ops.aten.reshape):
|
||||
match_info = _match_repeat_kv_pattern(node)
|
||||
if match_info:
|
||||
ad_logger.debug(f"Found repeat_kv pattern at {node}")
|
||||
_replace_with_repeat_kv(graph, match_info)
|
||||
num_kv_patterns += 1
|
||||
|
||||
# Clean up the graph if we made any replacements
|
||||
if num_kv_patterns:
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_kv_patterns} repeat_kv patterns")
|
||||
|
||||
|
||||
def match_eager_attention(gm: GraphModule) -> None:
|
||||
"""
|
||||
Match and replace the eager attention pattern in fx graphs.
|
||||
|
||||
The pattern is:
|
||||
transpose -> matmul -> mul -> (optional) add -> softmax -> to -> dropout -> matmul
|
||||
|
||||
This is replaced with torch.ops.auto_deploy.torch_attention_sdpa.
|
||||
"""
|
||||
graph = gm.graph
|
||||
|
||||
# Track replacements to avoid processing nodes multiple times
|
||||
num_eager_patterns = 0
|
||||
|
||||
# Iterate through nodes in the graph
|
||||
for node in list(graph.nodes):
|
||||
# Look for the final matmul nodes that could be part of our pattern
|
||||
if is_op(node, torch.ops.aten.matmul):
|
||||
match_info = _match_eager_attention_pattern(node)
|
||||
if match_info:
|
||||
ad_logger.debug(f"Found eager attention pattern at {node}")
|
||||
_replace_with_sdpa(graph, match_info)
|
||||
num_eager_patterns += 1
|
||||
|
||||
# Clean up the graph if we made any replacements
|
||||
if num_eager_patterns:
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_eager_patterns} eager attention patterns")
|
||||
|
||||
|
||||
def match_grouped_attention(gm: GraphModule) -> None:
|
||||
"""
|
||||
Match and replace the grouped attention pattern in fx graphs.
|
||||
|
||||
The pattern is:
|
||||
repeat_kv(k, n_rep) ->
|
||||
repeat_kv(v, n_rep) ->
|
||||
sdpa(q, repeated_k, repeated_v)
|
||||
|
||||
This is replaced with torch.ops.auto_deploy.torch_attention_grouped_sdpa.
|
||||
"""
|
||||
graph = gm.graph
|
||||
|
||||
# Track replacements to avoid processing nodes multiple times
|
||||
num_grouped_patterns = 0
|
||||
|
||||
# Iterate through nodes in the graph
|
||||
for node in list(graph.nodes):
|
||||
# Look for SDPA nodes that could be part of our pattern
|
||||
if is_op(node, torch.ops.auto_deploy.torch_attention_sdpa):
|
||||
match_info = _match_grouped_attention_pattern(node)
|
||||
if match_info:
|
||||
ad_logger.debug(f"Found grouped attention pattern at {node}")
|
||||
_replace_with_grouped_sdpa(graph, match_info)
|
||||
num_grouped_patterns += 1
|
||||
|
||||
# Clean up the graph if we made any replacements
|
||||
if num_grouped_patterns:
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_grouped_patterns} grouped attention patterns")
|
||||
|
||||
|
||||
def match_causal_attn_mask(gm: GraphModule) -> None:
|
||||
"""
|
||||
Match attention operations with causal attention masks and optimize them.
|
||||
|
||||
For operations that use explicit causal masks, this replaces:
|
||||
- sdpa(q, k, v, causal_mask, dropout_p, False, scale)
|
||||
with:
|
||||
- sdpa(q, k, v, None, dropout_p, True, scale)
|
||||
|
||||
This optimization enables more efficient implementations on supported backends.
|
||||
"""
|
||||
graph = gm.graph
|
||||
|
||||
# Track replacements to avoid processing nodes multiple times
|
||||
num_causal_patterns = 0
|
||||
|
||||
# Iterate through nodes in the graph
|
||||
for node in list(graph.nodes):
|
||||
# Look for SDPA nodes or grouped SDPA nodes
|
||||
if not (
|
||||
is_op(node, torch.ops.auto_deploy.torch_attention_sdpa)
|
||||
or is_op(node, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
|
||||
):
|
||||
continue
|
||||
|
||||
# Get the attention mask argument (4th argument)
|
||||
if len(node.args) < 4 or node.args[3] is None:
|
||||
continue
|
||||
|
||||
attn_mask = node.args[3]
|
||||
|
||||
# Check if this mask is a causal mask
|
||||
if not _is_causal_mask(attn_mask):
|
||||
ad_logger.debug(f"Found non-causal attention mask at {node=}!")
|
||||
continue
|
||||
|
||||
ad_logger.debug(f"Found causal attention mask at {node}")
|
||||
|
||||
# construct the new args list with args provided to the node and the default values otherwise
|
||||
new_args = []
|
||||
for idx, arg in enumerate(node.target._schema.arguments):
|
||||
# In case arg is provided to the node, use it
|
||||
if idx < len(node.args):
|
||||
new_args.append(node.args[idx])
|
||||
# In case arg is not provided to the node, use the default value
|
||||
elif arg.has_default_value:
|
||||
new_args.append(arg.default_value)
|
||||
else:
|
||||
raise ValueError(f"Missing required argument: {arg.name}")
|
||||
|
||||
# Create new arguments with None mask and is_causal=True
|
||||
new_args[3] = None # Set mask to None
|
||||
new_args[5] = True # Set is_causal to True
|
||||
|
||||
# Create new node with updated arguments
|
||||
with graph.inserting_before(node):
|
||||
new_node = graph.call_function(node.target, args=tuple(new_args), kwargs=node.kwargs)
|
||||
|
||||
# Preserve metadata
|
||||
new_node.meta = node.meta.copy()
|
||||
|
||||
# Replace the old node with the new one
|
||||
node.replace_all_uses_with(new_node)
|
||||
|
||||
num_causal_patterns += 1
|
||||
|
||||
# Clean up the graph if we made any replacements
|
||||
if num_causal_patterns:
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_causal_patterns} causal mask attention patterns")
|
||||
|
||||
|
||||
def _match_repeat_kv_pattern(reshape_node: Node) -> Optional[Dict[str, Node]]:
|
||||
"""
|
||||
Match the repeat_kv pattern starting from a reshape node.
|
||||
|
||||
The pattern is:
|
||||
unsqueeze -> expand -> reshape -> [optional] contiguous
|
||||
|
||||
Returns a dictionary with information about the match or None if no match.
|
||||
"""
|
||||
# Check that reshape_node is a reshape operation
|
||||
if not is_op(reshape_node, torch.ops.aten.reshape):
|
||||
return None
|
||||
|
||||
# The reshape should have expand as its first argument
|
||||
if len(reshape_node.args) < 1:
|
||||
return None
|
||||
|
||||
expand_node = reshape_node.args[0]
|
||||
if not is_op(expand_node, torch.ops.aten.expand):
|
||||
return None
|
||||
|
||||
# The expand should have unsqueeze as its first argument
|
||||
if len(expand_node.args) < 1:
|
||||
return None
|
||||
|
||||
unsqueeze_node = expand_node.args[0]
|
||||
if not is_op(unsqueeze_node, torch.ops.aten.unsqueeze):
|
||||
return None
|
||||
|
||||
# The unsqueeze should be inserting a dimension at position 2
|
||||
if len(unsqueeze_node.args) < 2 or unsqueeze_node.args[1] != 2:
|
||||
return None
|
||||
|
||||
# Get the input tensor to unsqueeze
|
||||
if len(unsqueeze_node.args) < 1:
|
||||
return None
|
||||
|
||||
input_tensor = unsqueeze_node.args[0]
|
||||
|
||||
# Check input dimensions - should be 4D (batch, num_key_value_heads, seq_len, head_dim)
|
||||
input_val = input_tensor.meta.get("val", None)
|
||||
if input_val is None or len(input_val.shape) != 4:
|
||||
return None
|
||||
|
||||
# Extract batch size, num_kv_heads, seq_len, and head_dim from the input tensor shape
|
||||
batch_size, num_kv_heads, seq_len, head_dim = input_val.shape
|
||||
|
||||
# Check reshape args
|
||||
if len(reshape_node.args) < 2 or not isinstance(reshape_node.args[1], list):
|
||||
return None
|
||||
|
||||
reshape_args = reshape_node.args[1]
|
||||
if len(reshape_args) != 4:
|
||||
return None
|
||||
|
||||
# Check expand args
|
||||
if len(expand_node.args) < 2 or not isinstance(expand_node.args[1], list):
|
||||
return None
|
||||
|
||||
expand_args = expand_node.args[1]
|
||||
if len(expand_args) != 5:
|
||||
return None
|
||||
|
||||
# Determine n_rep by comparing the output and input head dimensions
|
||||
# In the expand args, we should have [batch, num_kv_heads, n_rep, seq_len, head_dim]
|
||||
# In the reshape args, we should have [batch, num_heads, seq_len, head_dim]
|
||||
# where num_heads = num_kv_heads * n_rep
|
||||
_, _, n_rep, _, _ = expand_args
|
||||
_, reshape_num_heads, _, _ = reshape_args
|
||||
|
||||
# Check that n_rep is an integer
|
||||
if not isinstance(n_rep, int):
|
||||
return None
|
||||
|
||||
# Check that num_heads = num_kv_heads * n_rep
|
||||
# This may be a symbolic expression, so we need to compare with caution
|
||||
reshape_out_val = reshape_node.meta.get("val", None)
|
||||
if reshape_out_val is None or len(reshape_out_val.shape) != 4:
|
||||
return None
|
||||
|
||||
# Ensure output shape is correct
|
||||
out_batch, out_heads, out_seq, out_dim = reshape_out_val.shape
|
||||
|
||||
# Check that input batch and seq dimensions match output
|
||||
if out_batch != batch_size or out_seq != seq_len or out_dim != head_dim:
|
||||
return None
|
||||
|
||||
# Check if reshape is followed by a contiguous node
|
||||
contiguous_node = None
|
||||
users = list(reshape_node.users)
|
||||
|
||||
# Only consider contiguous if reshape has exactly one user
|
||||
if len(users) == 1 and is_op(users[0], torch.ops.aten.contiguous):
|
||||
contiguous_node = users[0]
|
||||
|
||||
result = {
|
||||
"input_tensor": input_tensor,
|
||||
"unsqueeze_node": unsqueeze_node,
|
||||
"expand_node": expand_node,
|
||||
"reshape_node": reshape_node,
|
||||
"n_rep": n_rep,
|
||||
}
|
||||
|
||||
if contiguous_node:
|
||||
result["contiguous_node"] = contiguous_node
|
||||
|
||||
return result
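# Shape walk-through of the matched subgraph (illustrative, symbolic sizes):
#   input:        [b, n_kv, s, d]
#   unsqueeze(2): [b, n_kv, 1, s, d]
#   expand:       [b, n_kv, n_rep, s, d]
#   reshape:      [b, n_kv * n_rep, s, d]   (optionally followed by contiguous)
# n_rep is read from the expand args and cross-checked against the reshape output shape.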
|
||||
|
||||
|
||||
def _match_eager_attention_pattern(final_matmul_node: Node) -> Optional[Dict[str, Node]]:
|
||||
"""
|
||||
Match the eager attention pattern starting from the final matmul node.
|
||||
|
||||
The pattern is:
|
||||
transpose -> matmul -> mul/div -> (optional) add -> (optional) to -> softmax -> (optional) to -> dropout -> matmul
|
||||
|
||||
Returns a dictionary with information about the match or None if no match.
|
||||
"""
|
||||
# Check that final_matmul_node is a matmul operation
|
||||
if not is_op(final_matmul_node, torch.ops.aten.matmul):
|
||||
return None
|
||||
|
||||
# Check we have two arguments
|
||||
if len(final_matmul_node.args) < 2:
|
||||
return None
|
||||
|
||||
# The first arg of final matmul should be dropout
|
||||
dropout_node = final_matmul_node.args[0]
|
||||
if not is_op(dropout_node, torch.ops.aten.dropout):
|
||||
return None
|
||||
|
||||
# The second arg of final matmul is the value tensor (possibly repeated/transformed)
|
||||
value = final_matmul_node.args[1]
|
||||
|
||||
# The dropout should have a to_dtype node (or directly softmax) as input
|
||||
if len(dropout_node.args) < 1:
|
||||
return None
|
||||
|
||||
# Allow optional to_dtype node after softmax
|
||||
to_dtype_after_softmax = dropout_node.args[0]
|
||||
if is_op(to_dtype_after_softmax, torch.ops.aten.to):
|
||||
if len(to_dtype_after_softmax.args) < 1:
|
||||
return None
|
||||
softmax_node = to_dtype_after_softmax.args[0]
|
||||
else:
|
||||
softmax_node = to_dtype_after_softmax
|
||||
|
||||
# Now we should have a softmax node
|
||||
if not is_op(softmax_node, torch.ops.aten.softmax):
|
||||
return None
|
||||
|
||||
# The softmax should have dim=-1 (may be specified in different ways)
|
||||
if len(softmax_node.args) < 2 or (
|
||||
isinstance(softmax_node.args[1], int) and softmax_node.args[1] != -1
|
||||
):
|
||||
# Check kwargs if not in args
|
||||
if softmax_node.kwargs.get("dim", -1) != -1:
|
||||
return None
|
||||
|
||||
# The softmax node's input can be:
|
||||
# - direct from add/mul/div
|
||||
# - or through a to_dtype node (like to_35 in the example)
|
||||
if len(softmax_node.args) < 1:
|
||||
return None
|
||||
|
||||
# Handle optional to_dtype node before softmax
|
||||
prev_node = softmax_node.args[0]
|
||||
if is_op(prev_node, torch.ops.aten.to):
|
||||
if len(prev_node.args) < 1:
|
||||
return None
|
||||
prev_node = prev_node.args[0]
|
||||
|
||||
# Check for attention mask pattern (add node)
|
||||
if is_op(prev_node, torch.ops.aten.add):
|
||||
add_node = prev_node
|
||||
attn_mask = add_node.args[1] # Second arg is the mask
|
||||
|
||||
# The add should have a mul or div node as its first argument
|
||||
if len(add_node.args) < 1:
|
||||
return None
|
||||
|
||||
scaling_node = add_node.args[0]
|
||||
if not (is_op(scaling_node, torch.ops.aten.mul) or is_op(scaling_node, torch.ops.aten.div)):
|
||||
return None
|
||||
elif is_op(prev_node, torch.ops.aten.mul) or is_op(prev_node, torch.ops.aten.div):
|
||||
# No mask case - the softmax input is directly the mul or div node
|
||||
scaling_node = prev_node
|
||||
attn_mask = None
|
||||
else:
|
||||
return None
|
||||
|
||||
# Check the scaling operation and extract the scaling factor
|
||||
is_division = is_op(scaling_node, torch.ops.aten.div)
|
||||
|
||||
# The mul/div node should have a matmul node as input
|
||||
if len(scaling_node.args) < 2:
|
||||
return None
|
||||
|
||||
# Extract the scaling factor, adjusting for division vs multiplication
|
||||
scale = scaling_node.args[1]
|
||||
# Allow for constant or tensor scale
|
||||
if not isinstance(scale, (float, int, Node)):
|
||||
return None
|
||||
|
||||
# For division, we need to invert the scaling factor if it's a constant
|
||||
if is_division and isinstance(scale, (float, int)):
|
||||
scale = 1.0 / scale
|
||||
|
||||
first_matmul_node = scaling_node.args[0]
|
||||
if not is_op(first_matmul_node, torch.ops.aten.matmul):
|
||||
return None
|
||||
|
||||
# The first matmul should have the query and key transpose as inputs
|
||||
if len(first_matmul_node.args) < 2:
|
||||
return None
|
||||
|
||||
query = first_matmul_node.args[0]
|
||||
transpose_key = first_matmul_node.args[1]
|
||||
|
||||
# Check for transpose, could be any dimensions
|
||||
if not is_op(transpose_key, torch.ops.aten.transpose):
|
||||
return None
|
||||
|
||||
# The transpose should have the key as input
|
||||
if len(transpose_key.args) < 1:
|
||||
return None
|
||||
|
||||
key = transpose_key.args[0]
|
||||
|
||||
# Create the match info dictionary
|
||||
match_info = {
|
||||
"query": query,
|
||||
"key": key,
|
||||
"value": value,
|
||||
"scale": scale,
|
||||
"dropout_p": dropout_node.args[1] if len(dropout_node.args) > 1 else 0.0,
|
||||
"final_matmul": final_matmul_node,
|
||||
}
|
||||
|
||||
# Add the attention mask if it exists
|
||||
if attn_mask is not None:
|
||||
match_info["attn_mask"] = attn_mask
|
||||
|
||||
return match_info
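# Illustrative example of the scale normalization above: a division-style pattern such as
#   attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
# is recorded with scale = 1.0 / math.sqrt(head_dim), so multiplication- and division-style
# models map onto the same SDPA `scale` argument. (`head_dim` is an assumed name; any
# constant divisor is handled the same way, while tensor-valued scales are passed through.)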
|
||||
|
||||
|
||||
def _match_grouped_attention_pattern(sdpa_node: Node) -> Optional[Dict[str, Node]]:
|
||||
"""
|
||||
Match the grouped attention pattern starting from an SDPA node.
|
||||
|
||||
The pattern is:
|
||||
repeat_kv(k, n_rep) ->
|
||||
repeat_kv(v, n_rep) ->
|
||||
sdpa(q, repeated_k, repeated_v)
|
||||
|
||||
Returns a dictionary with information about the match or None if no match.
|
||||
"""
|
||||
# Check that sdpa_node is an SDPA operation
|
||||
if not is_op(sdpa_node, torch.ops.auto_deploy.torch_attention_sdpa):
|
||||
return None
|
||||
|
||||
# SDPA should have query, key, value as its first three arguments
|
||||
if len(sdpa_node.args) < 3:
|
||||
return None
|
||||
|
||||
query, key_repeated, value_repeated = sdpa_node.args[0:3]
|
||||
|
||||
# Key and value should come from repeat_kv operations
|
||||
if not is_op(key_repeated, torch.ops.auto_deploy.torch_attention_repeat_kv) or not is_op(
|
||||
value_repeated, torch.ops.auto_deploy.torch_attention_repeat_kv
|
||||
):
|
||||
return None
|
||||
|
||||
# Extract the original key, value, and n_rep
|
||||
orig_key = key_repeated.args[0]
|
||||
orig_value = value_repeated.args[0]
|
||||
key_n_rep = key_repeated.args[1]
|
||||
value_n_rep = value_repeated.args[1]
|
||||
|
||||
# Both repeat_kv operations should have the same n_rep
|
||||
if key_n_rep != value_n_rep:
|
||||
return None
|
||||
|
||||
# Return the match information
|
||||
return {
|
||||
"query": query,
|
||||
"key": orig_key,
|
||||
"value": orig_value,
|
||||
"key_repeated": key_repeated,
|
||||
"value_repeated": value_repeated,
|
||||
"n_rep": key_n_rep,
|
||||
"sdpa_node": sdpa_node,
|
||||
}
|
||||
|
||||
|
||||
def _replace_with_repeat_kv(graph, match_info: Dict[str, Node]) -> None:
|
||||
"""
|
||||
Replace the matched repeat_kv pattern with the custom op.
|
||||
"""
|
||||
input_tensor = match_info["input_tensor"]
|
||||
reshape_node = match_info["reshape_node"]
|
||||
n_rep = match_info["n_rep"]
|
||||
|
||||
# Determine the node to replace (either reshape or contiguous if present)
|
||||
node_to_replace = match_info.get("contiguous_node", reshape_node)
|
||||
|
||||
with graph.inserting_before(node_to_replace):
|
||||
repeat_kv_node = graph.call_function(
|
||||
torch.ops.auto_deploy.torch_attention_repeat_kv, args=(input_tensor, n_rep)
|
||||
)
|
||||
|
||||
# Preserve metadata from the original node
|
||||
repeat_kv_node.meta = node_to_replace.meta.copy()
|
||||
|
||||
# Replace all uses of the node with the repeat_kv node
|
||||
node_to_replace.replace_all_uses_with(repeat_kv_node)
|
||||
|
||||
|
||||
def _replace_with_sdpa(graph, match_info: Dict[str, Node]) -> None:
|
||||
"""
|
||||
Replace the matched eager attention pattern with scaled_dot_product_attention.
|
||||
"""
|
||||
# retrieve the default op for scaled_dot_product_attention
|
||||
sdpa_op = torch.ops.auto_deploy.torch_attention_sdpa.default
|
||||
|
||||
# construct the args for the ops based on the match_info and the op's schema
|
||||
args = []
|
||||
for arg in sdpa_op._schema.arguments:
|
||||
if arg.name in match_info:
|
||||
args.append(match_info[arg.name])
|
||||
elif arg.has_default_value:
|
||||
args.append(arg.default_value)
|
||||
else:
|
||||
raise ValueError(f"Missing required argument: {arg.name}")
|
||||
args = tuple(args)
|
||||
|
||||
# retrieve the final matmul node to know where to insert the sdpa node
|
||||
final_matmul = match_info["final_matmul"]
|
||||
|
||||
with graph.inserting_before(final_matmul):
|
||||
sdpa_node = graph.call_function(sdpa_op, args=args)
|
||||
|
||||
# Preserve metadata from the original node
|
||||
sdpa_node.meta = final_matmul.meta.copy()
|
||||
|
||||
# Replace all uses of the final matmul node with the sdpa node
|
||||
final_matmul.replace_all_uses_with(sdpa_node)
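# Illustrative result of the schema-driven argument construction above: if the match only
# supplies query/key/value, dropout_p, and scale, the remaining arguments fall back to their
# schema defaults, e.g. args = (query, key, value, None, dropout_p, False, scale), where None
# is the default attn_mask and False the default is_causal.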
|
||||
|
||||
|
||||
def _replace_with_grouped_sdpa(graph, match_info: Dict[str, Node]) -> None:
|
||||
"""
|
||||
Replace the matched grouped attention pattern with torch.ops.auto_deploy.torch_attention_grouped_sdpa.
|
||||
"""
|
||||
sdpa_node = match_info["sdpa_node"]
|
||||
query = match_info["query"]
|
||||
key = match_info["key"]
|
||||
value = match_info["value"]
|
||||
|
||||
# Construct the new args and kwargs
|
||||
args = (query, key, value) + sdpa_node.args[3:]
|
||||
kwargs = sdpa_node.kwargs.copy()
|
||||
|
||||
with graph.inserting_before(sdpa_node):
|
||||
grouped_sdpa_node = graph.call_function(
|
||||
torch.ops.auto_deploy.torch_attention_grouped_sdpa.default, args=args, kwargs=kwargs
|
||||
)
|
||||
|
||||
# Preserve metadata from the original node
|
||||
grouped_sdpa_node.meta = sdpa_node.meta.copy()
|
||||
|
||||
# Replace all uses of the SDPA node with the grouped_sdpa node
|
||||
sdpa_node.replace_all_uses_with(grouped_sdpa_node)
|
||||
|
||||
|
||||
def _is_causal_mask(mask_node: Node) -> bool:
|
||||
"""
|
||||
Determine if a node represents a causal attention mask.
|
||||
|
||||
Causal masks typically involve:
|
||||
1. Creating a matrix with very negative values (e.g., -inf or close to it)
|
||||
2. Using triu with offset 1 to create an upper triangular matrix
|
||||
3. Usually involves comparison operations (gt, lt) with position indices
|
||||
|
||||
Returns True if the node appears to be a causal mask pattern.
|
||||
"""
|
||||
# Direct pattern from the test case: masked_fill with triu(ones,1) and -inf
|
||||
if is_op(mask_node, torch.ops.aten.masked_fill):
|
||||
mask_args = mask_node.args
|
||||
if len(mask_args) >= 2:
|
||||
_ = mask_args[0] # zero tensor
|
||||
mask_tensor = mask_args[1]
|
||||
fill_value = mask_args[2] if len(mask_args) > 2 else mask_node.kwargs.get("value", None)
|
||||
|
||||
# Check if fill value is very negative (e.g., -inf)
|
||||
if fill_value is not None and (
|
||||
fill_value == float("-inf")
|
||||
or (isinstance(fill_value, (int, float)) and fill_value < -1e4)
|
||||
):
|
||||
# Try to trace back to find a triu pattern
|
||||
if _has_triu_ancestor(mask_tensor, offset=1):
|
||||
return True
|
||||
|
||||
# Pattern from negative_fill test case: masked_fill with ~triu(ones,1) and 0.0
|
||||
# The negative_fill pattern has a pre-filled tensor with very negative values
|
||||
# and zeros in the lower triangle
|
||||
if is_op(mask_node, torch.ops.aten.masked_fill):
|
||||
mask_args = mask_node.args
|
||||
if len(mask_args) >= 2:
|
||||
negative_tensor = mask_args[0]
|
||||
mask_tensor = mask_args[1]
|
||||
fill_value = mask_args[2] if len(mask_args) > 2 else mask_node.kwargs.get("value", None)
|
||||
|
||||
# Check if fill value is zero and the tensor is pre-filled with negative values
|
||||
if fill_value == 0.0 or fill_value == 0:
|
||||
# Check for the full tensor with negative values
|
||||
if is_op(negative_tensor, torch.ops.aten.full):
|
||||
fill_args = negative_tensor.args
|
||||
if (
|
||||
len(fill_args) > 1
|
||||
and isinstance(fill_args[1], (int, float))
|
||||
and fill_args[1] < -1e4
|
||||
):
|
||||
# This is likely a negative-filled tensor
|
||||
# Now check if the mask is a bitwise_not of triu
|
||||
if is_op(mask_tensor, torch.ops.aten.bitwise_not):
|
||||
if len(mask_tensor.args) > 0 and _has_triu_ancestor(
|
||||
mask_tensor.args[0], offset=1
|
||||
):
|
||||
return True
|
||||
|
||||
# Pattern for llama-3.1 style causal mask: slice of expand(unsqueeze(unsqueeze(mul_(triu, gt))))
|
||||
if is_op(mask_node, torch.ops.aten.slice):
|
||||
# Follow the chain backward to the source of the slice
|
||||
if len(mask_node.args) == 0:
|
||||
return False
|
||||
slice_source = mask_node.args[0]
|
||||
|
||||
# Check for typical expand pattern
|
||||
if not (slice_source and is_op(slice_source, torch.ops.aten.expand)):
|
||||
return False
|
||||
|
||||
# Continue tracing back through the pattern
|
||||
if len(slice_source.args) == 0:
|
||||
return False
|
||||
expand_source = slice_source.args[0]
|
||||
|
||||
# Check for first unsqueeze operation
|
||||
if not (expand_source and is_op(expand_source, torch.ops.aten.unsqueeze)):
|
||||
return False
|
||||
|
||||
# Look for the source of first unsqueeze
|
||||
if len(expand_source.args) == 0:
|
||||
return False
|
||||
first_unsqueeze_source = expand_source.args[0]
|
||||
|
||||
# Check for second unsqueeze operation
|
||||
if not (first_unsqueeze_source and is_op(first_unsqueeze_source, torch.ops.aten.unsqueeze)):
|
||||
return False
|
||||
|
||||
# Look for the source of the second unsqueeze
|
||||
if len(first_unsqueeze_source.args) == 0:
|
||||
return False
|
||||
second_unsqueeze_source = first_unsqueeze_source.args[0]
|
||||
|
||||
# Check for mul_ operation
|
||||
if is_op(second_unsqueeze_source, torch.ops.aten.mul_):
|
||||
# Check if one of the mul_ arguments is a triu operation
|
||||
has_triu = False
|
||||
for arg in second_unsqueeze_source.args:
|
||||
if is_op(arg, torch.ops.aten.triu):
|
||||
if len(arg.args) > 1 and arg.args[1] == 1:
|
||||
has_triu = True
|
||||
break
|
||||
|
||||
if has_triu:
|
||||
# Check if one of the mul_ arguments involves a full tensor with negative values
|
||||
for arg in second_unsqueeze_source.args:
|
||||
if is_op(arg, torch.ops.aten.full):
|
||||
if (
|
||||
len(arg.args) > 1
|
||||
and isinstance(arg.args[1], (int, float))
|
||||
and arg.args[1] < -1e4
|
||||
):
|
||||
return True
|
||||
|
||||
return has_triu
|
||||
|
||||
# Original implementation for backward compatibility
|
||||
if is_op(mask_node, torch.ops.aten.slice):
|
||||
# Follow the chain backward to the source of the slice
|
||||
if len(mask_node.args) == 0:
|
||||
return False
|
||||
slice_source = mask_node.args[0]
|
||||
|
||||
# Check for typical expand pattern
|
||||
if not (slice_source and is_op(slice_source, torch.ops.aten.expand)):
|
||||
return False
|
||||
|
||||
# Continue tracing back through the pattern
|
||||
if len(slice_source.args) == 0:
|
||||
return False
|
||||
expand_source = slice_source.args[0]
|
||||
|
||||
# Check for unsqueeze operations
|
||||
if not (expand_source and is_op(expand_source, torch.ops.aten.unsqueeze)):
|
||||
return False
|
||||
|
||||
# Look for the source of the unsqueeze
|
||||
if len(expand_source.args) == 0:
|
||||
return False
|
||||
unsqueeze_source = expand_source.args[0]
|
||||
|
||||
if not unsqueeze_source:
|
||||
return False
|
||||
|
||||
# Check for triu pattern which is common in causal masks
|
||||
if is_op(unsqueeze_source, torch.ops.aten.mul_):
|
||||
for arg in unsqueeze_source.args:
|
||||
if not is_op(arg, torch.ops.aten.triu):
|
||||
continue
|
||||
|
||||
if len(arg.args) <= 1:
|
||||
continue
|
||||
|
||||
triu_offset = arg.args[1]
|
||||
# Causal masks typically use triu with offset 1
|
||||
if triu_offset == 1:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# Check if we have a full tensor filled with a very negative number
|
||||
if not is_op(unsqueeze_source, torch.ops.aten.full):
|
||||
return False
|
||||
|
||||
if len(unsqueeze_source.args) <= 1:
|
||||
return False
|
||||
|
||||
fill_value = unsqueeze_source.args[1]
|
||||
# Check if the fill value is very negative (likely -inf or close)
|
||||
if isinstance(fill_value, float) and fill_value < -1e10:
|
||||
return True
|
||||
|
||||
# If we can't definitively identify it as causal, return False
|
||||
return False
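# Illustrative constructions (hypothetical helper, names assumed) that the heuristics above
# accept as causal masks:
def _example_causal_masks(seq_len: int, device, dtype):
    upper = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1
    )  # True strictly above the diagonal, i.e. future positions
    # 1) zeros masked_fill'ed with -inf above the diagonal (triu with offset 1)
    m1 = torch.zeros(seq_len, seq_len, dtype=dtype, device=device).masked_fill(
        upper, float("-inf")
    )
    # 2) "negative fill": a tensor pre-filled with a very negative value whose lower
    #    triangle (diagonal included) is reset to 0.0 via the bitwise_not of the triu mask
    m2 = torch.full(
        (seq_len, seq_len), float("-inf"), dtype=dtype, device=device
    ).masked_fill(~upper, 0.0)
    return m1, m2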
|
||||
|
||||
|
||||
def _has_triu_ancestor(node: Node, offset: int = 1, depth: int = 0, max_depth: int = 5) -> bool:
|
||||
"""Helper function to find a triu operation in the ancestry of a node."""
|
||||
if depth > max_depth: # Prevent infinite recursion
|
||||
return False
|
||||
|
||||
if is_op(node, torch.ops.aten.triu):
|
||||
if len(node.args) > 1 and node.args[1] == offset:
|
||||
return True
|
||||
|
||||
# Check if any of the arguments has a triu ancestor
|
||||
for arg in node.args:
|
||||
if isinstance(arg, Node) and _has_triu_ancestor(arg, offset, depth + 1, max_depth):
|
||||
return True
|
||||
|
||||
# Check if any of the kwargs has a triu ancestor
|
||||
for value in node.kwargs.values():
|
||||
if isinstance(value, Node) and _has_triu_ancestor(value, offset, depth + 1, max_depth):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def match_attention_layout(gm: GraphModule, attention_op: Type[AttentionDescriptor]) -> None:
|
||||
"""
|
||||
Match and transform attention operations to match the layout expected by the attention backend.
|
||||
|
||||
If the attention backend expects 'bnsd' layout (batch, num_heads, seq_len, head_dim), which
|
||||
is the default for SDPA operations, we don't need to transform anything.
|
||||
|
||||
If the backend expects 'bsnd' layout (batch, seq_len, num_heads, head_dim), we insert
|
||||
appropriate transposes before and after SDPA operations and replace them with bsnd_grouped_sdpa.
|
||||
"""
|
||||
# Get attention layout from attention_op
|
||||
attention_layout = attention_op.get_attention_layout()
|
||||
|
||||
# List of SDPA operations to look for
|
||||
sdpa_ops = {
|
||||
torch.ops.auto_deploy.torch_attention_sdpa,
|
||||
torch.ops.auto_deploy.torch_attention_grouped_sdpa,
|
||||
}
|
||||
|
||||
graph = gm.graph
|
||||
num_bsnd_patterns = 0
|
||||
|
||||
# Look for SDPA operations
|
||||
for sdpa_node in list(graph.nodes):
|
||||
if sdpa_node.op != "call_function" or not is_op(sdpa_node, sdpa_ops):
|
||||
continue
|
||||
|
||||
ad_logger.debug(f"Found SDPA node to transform for bsnd layout: {sdpa_node}")
|
||||
|
||||
# Extract q, k, v inputs
|
||||
q, k, v = sdpa_node.args[:3]
|
||||
|
||||
# Check if we need to transpose the inputs
|
||||
if attention_layout == "bsnd":
|
||||
# Add transposes before the node (from bnsd to bsnd)
|
||||
with graph.inserting_before(sdpa_node):
|
||||
q_updated = graph.call_function(torch.ops.aten.transpose.int, args=(q, 1, 2))
|
||||
k_updated = graph.call_function(torch.ops.aten.transpose.int, args=(k, 1, 2))
|
||||
v_updated = graph.call_function(torch.ops.aten.transpose.int, args=(v, 1, 2))
|
||||
|
||||
# Preserve fake tensor in meta["val"] for the transposed inputs
|
||||
q_updated.meta["val"] = q.meta["val"].transpose(1, 2)
|
||||
k_updated.meta["val"] = k.meta["val"].transpose(1, 2)
|
||||
v_updated.meta["val"] = v.meta["val"].transpose(1, 2)
|
||||
elif attention_layout == "bnsd":
|
||||
# we don't need to do anything...
|
||||
q_updated = q
|
||||
k_updated = k
|
||||
v_updated = v
|
||||
else:
|
||||
raise ValueError(f"Unsupported attention layout: {attention_layout}")
|
||||
|
||||
# Create bsnd_grouped_sdpa node with the same args as the original node
|
||||
# but using the transposed inputs
|
||||
with graph.inserting_before(sdpa_node):
|
||||
source_sdpa_node = graph.call_function(
|
||||
attention_op.get_source_attention_op(),
|
||||
args=(q_updated, k_updated, v_updated) + sdpa_node.args[3:],
|
||||
kwargs=sdpa_node.kwargs,
|
||||
)
|
||||
|
||||
# Check if need to update the output node to match the layout
|
||||
if attention_layout == "bsnd":
|
||||
# Add transpose for the output (from bsnd back to bnsd)
|
||||
with graph.inserting_after(source_sdpa_node):
|
||||
output_updated = graph.call_function(
|
||||
torch.ops.aten.transpose.int, args=(source_sdpa_node, 1, 2)
|
||||
)
|
||||
|
||||
# Preserve fake tensors in meta["val"] for the new attention node and the transposed output
|
||||
source_sdpa_node.meta["val"] = sdpa_node.meta["val"].transpose(1, 2).contiguous()
|
||||
output_updated.meta["val"] = source_sdpa_node.meta["val"].transpose(1, 2)
|
||||
elif attention_layout == "bnsd":
|
||||
output_updated = source_sdpa_node
|
||||
else:
|
||||
raise ValueError(f"Unsupported attention layout: {attention_layout}")
|
||||
|
||||
# Replace the old node with the transposed output
|
||||
sdpa_node.replace_all_uses_with(output_updated)
|
||||
|
||||
num_bsnd_patterns += 1
|
||||
|
||||
# Clean up the graph if we made any replacements
|
||||
if num_bsnd_patterns:
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.debug(f"Transformed graph for bsnd layout: {gm}")
|
||||
|
||||
ad_logger.info(f"Found and matched {num_bsnd_patterns} attention layouts")
|
||||
@ -174,7 +174,7 @@ def resize_kv_cache(
|
||||
memory_for_forward_pass = free_mem_pre - free_mem_post
|
||||
ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
|
||||
|
||||
new_cache_size = free_mem_post * free_mem_ratio + current_cache_size
|
||||
new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
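# Note (assumed units): free_mem_post is reported in MB, as logged above, while the cache
# sizes are tracked in bytes, hence the 1024 * 1024 conversion before combining them.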
|
||||
new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
|
||||
|
||||
# Need to sync all the GPUs
|
||||
|
||||
@ -141,6 +141,12 @@ def match_rope_pattern(gm: GraphModule) -> int:
|
||||
torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float16),
|
||||
torch.randn(batch_size, seq_len, head_dim // 2, device="meta", dtype=torch.float16),
|
||||
]
|
||||
# float32 inputs can change the traced graph when the pattern contains .float()
|
||||
dummy_complex_2 = [
|
||||
torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float32),
|
||||
torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float32),
|
||||
torch.randn(batch_size, seq_len, head_dim // 2, device="meta", dtype=torch.float32),
|
||||
]
|
||||
register_ad_pattern(
|
||||
search_fn=_explicit_rope_pattern,
|
||||
replace_fn=_explicit_rope_repl,
|
||||
@ -172,6 +178,16 @@ def match_rope_pattern(gm: GraphModule) -> int:
|
||||
},
|
||||
scalar_workaround={"unsqueeze_dim": 1},
|
||||
)
|
||||
register_ad_pattern(
|
||||
search_fn=_complex_rope_pattern,
|
||||
replace_fn=_complex_rope_repl,
|
||||
patterns=patterns,
|
||||
dummy_args=dummy_complex_2,
|
||||
op_ignore_types={
|
||||
torch.ops.aten.reshape.default: (int,),
|
||||
},
|
||||
scalar_workaround={"unsqueeze_dim": 1},
|
||||
)
|
||||
|
||||
num_matches = patterns.apply(graph)
|
||||
canonicalize_graph(gm)
|
||||
|
||||
@ -24,17 +24,10 @@ from .library import (
|
||||
fuse_collectives,
|
||||
fuse_rmsnorm,
|
||||
insert_cached_attention,
|
||||
match_attention_layout,
|
||||
match_causal_attn_mask,
|
||||
match_eager_attention,
|
||||
match_grouped_attention,
|
||||
match_moe_pattern,
|
||||
match_repeat_kv,
|
||||
match_rope_layout,
|
||||
match_rope_pattern,
|
||||
optimize_rope,
|
||||
quantize,
|
||||
quantize_moe,
|
||||
resize_kv_cache,
|
||||
sharding_transform_executor,
|
||||
update_in_out_nodes,
|
||||
@ -63,6 +56,12 @@ class InferenceOptimizer:
|
||||
############################################################################################
|
||||
# RUN MODULAR INFERENCE OPTIMIZER FOR ALREADY-MIGRATED TRANSFORMS
|
||||
############################################################################################
|
||||
# TODO (hg): default values that are not representable in YAML.
|
||||
if "match_attention_layout" in self.ad_config.transforms:
|
||||
self.ad_config.transforms[
|
||||
"match_attention_layout"
|
||||
].attention_op = AttentionRegistry.get(self.ad_config.attn_backend)
|
||||
|
||||
new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms)
|
||||
egm = new_optimizer(cm)
|
||||
|
||||
@ -71,28 +70,10 @@ class InferenceOptimizer:
|
||||
############################################################################################
|
||||
# RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION
|
||||
############################################################################################
|
||||
# quantization
|
||||
quantize(egm, self.factory.get_quant_config())
|
||||
quantize_moe(egm, self.factory.get_quant_config())
|
||||
|
||||
# Match MoE pattern
|
||||
match_moe_pattern(egm)
|
||||
|
||||
# Match repeat_kv pattern
|
||||
match_repeat_kv(egm)
|
||||
|
||||
# Match eager attention pattern
|
||||
match_eager_attention(egm)
|
||||
|
||||
# Match grouped attention pattern
|
||||
match_grouped_attention(egm)
|
||||
|
||||
# Match and optimize causal attention masks
|
||||
match_causal_attn_mask(egm)
|
||||
|
||||
# Match attention layout expected by our backend
|
||||
match_attention_layout(egm, AttentionRegistry.get(self.ad_config.attn_backend))
|
||||
|
||||
# Match rope
|
||||
match_rope_pattern(egm)
|
||||
|
||||
|
||||
@ -153,6 +153,8 @@ def register_ad_pattern(
|
||||
5. register_replacement can auto-generate `search_fn_pattern` if you input `example_inputs`,
|
||||
but that approach will fail when symbolic shapes are involved. Here
|
||||
we explicitly trace & convert via `fx_to_pattern`.
|
||||
6. The PatternMatcherPass checks num_users of the nodes, so the pattern must be functionally
isolated: no intermediate nodes may be shared with the rest of the graph.
|
||||
|
||||
"""
|
||||
argnames = list(inspect.signature(search_fn).parameters.keys())
|
||||
|
||||
@ -8,16 +8,49 @@ from _torch_test_utils import all_close, reset_parameters
|
||||
from torch.export import export
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
|
||||
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
|
||||
from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory
|
||||
from tensorrt_llm._torch.auto_deploy.transformations.library.sharding import ShardingTransformInfo
|
||||
|
||||
|
||||
class FakeFactory:
|
||||
def __init__(self, model: nn.Module):
|
||||
self.model = model
|
||||
class FakeFactory(ModelFactory):
|
||||
"""Dummy factory to pass cache_config for testing."""
|
||||
|
||||
def build_model(self, device: str) -> nn.Module:
|
||||
return self.model.to(device=device)
|
||||
def __init__(self, model=None, cache_config=None, quant_config=None):
|
||||
self._model = model
|
||||
self.cache_config = cache_config
|
||||
self.quant_config = quant_config
|
||||
|
||||
def build_model(self, device: str):
|
||||
return self._model.to(device=device) if self._model else None
|
||||
|
||||
def _build_model(self, device: str):
|
||||
return
|
||||
|
||||
def _load_checkpoint(self, model, device):
|
||||
return
|
||||
|
||||
def get_cache_config(self):
|
||||
return self.cache_config
|
||||
|
||||
def get_quant_config(self):
|
||||
return self.quant_config
|
||||
|
||||
|
||||
class SequenceEmbeddingInfo(SequenceInfo):
|
||||
hidden_size: int
|
||||
dtype: torch.dtype
|
||||
|
||||
def set_example_sequence(self) -> None:
|
||||
super().set_example_sequence()
|
||||
# set input ids to a 3D tensor (actually input embeddings)
|
||||
self.input_ids = torch.rand(
|
||||
*self.input_ids.shape,
|
||||
self.hidden_size,
|
||||
device=self.input_ids.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
|
||||
def count_parameters(model: torch.nn.Module):
|
||||
@ -32,6 +65,79 @@ def count_buffers(model: torch.nn.Module):
|
||||
return sum(np.prod(b.shape) for b in model.buffers())
|
||||
|
||||
|
||||
def run_test_transformed_gm(
|
||||
model: nn.Module,
|
||||
x: torch.Tensor,
|
||||
gm_transformed: GraphModule,
|
||||
check_transformed_graph: Callable[[GraphModule], bool],
|
||||
_get_expected_num_params: Callable[[int], int],
|
||||
atol: float = 1e-3,
|
||||
rtol: float = 1e-3,
|
||||
test_load_hook: bool = True,
|
||||
strict_loading: bool = True,
|
||||
dynamic_shapes: Dict = None,
|
||||
skip_output_assert: bool = False,
|
||||
*args, # Additional arguments for transform
|
||||
) -> GraphModule:
|
||||
# run model once
|
||||
y_model = model(x)
|
||||
|
||||
# num params
|
||||
num_params_model = count_parameters(model)
|
||||
print(num_params_model)
|
||||
|
||||
# export + check (we clone the state dict to have a bit more freedom in testing below)
|
||||
gm_ref = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
|
||||
print(gm_ref)
|
||||
y_gm = gm_ref(x)
|
||||
num_params_gm = count_parameters(gm_ref)
|
||||
|
||||
assert num_params_model == num_params_gm
|
||||
if not skip_output_assert:
|
||||
torch.testing.assert_close(y_model, y_gm, atol=atol, rtol=rtol)
|
||||
|
||||
print(gm_transformed)
|
||||
# in case buffers or other tensors were added during the transform
|
||||
gm_transformed = gm_transformed.to("cuda")
|
||||
y_transformed = gm_transformed(x)
|
||||
n_p_transformed = count_parameters(gm_transformed)
|
||||
|
||||
n_p_t_expected = _get_expected_num_params(num_params_model)
|
||||
assert n_p_transformed == n_p_t_expected, (
|
||||
f"actual params {n_p_transformed} != expected params {n_p_t_expected}"
|
||||
)
|
||||
|
||||
# check if the transformation worked
|
||||
assert check_transformed_graph(gm_transformed)
|
||||
|
||||
if strict_loading and not skip_output_assert:
|
||||
# check if output equals without loading state dict
|
||||
torch.testing.assert_close(y_model, y_transformed, atol=atol, rtol=rtol)
|
||||
|
||||
if test_load_hook and not skip_output_assert:
|
||||
# check if loading hook works from original state dict
|
||||
reset_parameters(gm_transformed)
|
||||
y_random = gm_transformed(x)
|
||||
assert not all_close(y_model, y_random), f"{y_model=}, {y_random=}"
|
||||
|
||||
gm_transformed.load_state_dict(model.state_dict(), strict=True if strict_loading else False)
|
||||
y_loaded_from_original = gm_transformed(x)
|
||||
torch.testing.assert_close(y_model, y_loaded_from_original, atol=atol, rtol=rtol)
|
||||
|
||||
# check if loading hook works from state_dict of a transformed model
|
||||
state_dict_sharded = copy.deepcopy(gm_transformed.state_dict())
|
||||
reset_parameters(gm_transformed)
|
||||
y_random2 = gm_transformed(x)
|
||||
assert not all_close(y_model, y_random2), f"{y_model=}, {y_random2=}"
|
||||
|
||||
gm_transformed.load_state_dict(state_dict_sharded, strict=True if strict_loading else False)
|
||||
y_loaded_from_transformed = gm_transformed(x)
|
||||
torch.testing.assert_close(y_model, y_loaded_from_transformed, atol=atol, rtol=rtol)
|
||||
|
||||
# check if we can still export the model as expected
|
||||
export(gm_transformed, args=(x,))
|
||||
|
||||
|
||||
def run_test(
|
||||
model: nn.Module,
|
||||
x: torch.Tensor,
|
||||
|
||||
@ -19,9 +19,6 @@ from build_and_run_ad import ExperimentConfig, main
|
||||
],
|
||||
)
|
||||
def test_build_ad(world_size: int, experiment_config: Dict):
|
||||
if world_size > 1:
|
||||
pytest.skip("https://nvbugspro.nvidia.com/bug/5331013")
|
||||
|
||||
experiment_config["args"]["world_size"] = world_size
|
||||
experiment_config["args"]["runtime"] = "trtllm" # Default runtime set to trtllm
|
||||
experiment_config = ExperimentConfig(**experiment_config)
|
||||
|
||||
@ -7,14 +7,8 @@ from torch.export import Dim
|
||||
from torch.fx import GraphModule
|
||||
from transformers.integrations.sdpa_attention import repeat_kv as hf_repeat_kv
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionDescriptor
|
||||
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
|
||||
from tensorrt_llm._torch.auto_deploy.transformations.library.attention import (
|
||||
match_attention_layout,
|
||||
match_causal_attn_mask,
|
||||
match_eager_attention,
|
||||
match_grouped_attention,
|
||||
match_repeat_kv,
|
||||
)
|
||||
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
|
||||
|
||||
torch.manual_seed(0)
|
||||
@ -164,16 +158,15 @@ class EagerAttentionModel(torch.nn.Module):
|
||||
# Multiplication pattern
|
||||
attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scaling
|
||||
|
||||
# Add attention mask if enabled
|
||||
# Add causal attention mask if enabled
|
||||
if self.has_mask:
|
||||
# Create a simple causal mask for testing - make sure all tensors are on the same device
|
||||
mask = torch.triu(
|
||||
torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
|
||||
diagonal=1,
|
||||
# [1, 1, seq_len, seq_len] causal mask with -inf in the upper triangle
|
||||
attn_mask = torch.triu(
|
||||
torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1
|
||||
)
|
||||
mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len]
|
||||
attn_mask = torch.zeros_like(attn_weights, device=device)
|
||||
attn_mask = attn_mask.masked_fill(mask, float("-inf"))
|
||||
attn_mask = (
|
||||
attn_mask.unsqueeze(0).unsqueeze(0).to(x.dtype)
|
||||
) # shape: [1, 1, seq_len, seq_len]
|
||||
attn_weights = attn_weights + attn_mask
|
||||
|
||||
# Apply softmax, dtype conversion, and dropout
|
||||
@ -249,13 +242,13 @@ class ComplexEagerAttentionModel(torch.nn.Module):
|
||||
|
||||
# Add attention mask if enabled
|
||||
if self.has_mask:
|
||||
mask = torch.triu(
|
||||
torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
|
||||
diagonal=1,
|
||||
# [1, 1, seq_len, seq_len] causal mask with -inf in the upper triangle
|
||||
attn_mask = torch.triu(
|
||||
torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1
|
||||
)
|
||||
mask = mask.unsqueeze(0).unsqueeze(0)
|
||||
attn_mask = torch.zeros_like(attn_weights, device=device)
|
||||
attn_mask = attn_mask.masked_fill(mask, float("-inf"))
|
||||
attn_mask = (
|
||||
attn_mask.unsqueeze(0).unsqueeze(0).to(x.dtype)
|
||||
) # shape: [1, 1, seq_len, seq_len]
|
||||
attn_weights = attn_weights + attn_mask
|
||||
|
||||
# Add a to_dtype node before softmax to match pattern in the graph
|
||||
@ -366,8 +359,6 @@ class GroupedAttentionModel(torch.nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = x.shape
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
|
||||
# Generate q, k, v
|
||||
q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
@ -385,28 +376,26 @@ class GroupedAttentionModel(torch.nn.Module):
|
||||
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, self.n_rep)
|
||||
|
||||
# Create attention mask if needed
|
||||
attn_mask = None
|
||||
if self.has_mask:
|
||||
# Simple causal mask
|
||||
mask = torch.triu(
|
||||
torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
|
||||
diagonal=1,
|
||||
attn_output = torch.ops.auto_deploy.torch_attention_sdpa(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=None,
|
||||
dropout_p=self.dropout,
|
||||
is_causal=True,
|
||||
scale=1.0 / (self.head_dim**0.5),
|
||||
)
|
||||
else:
|
||||
attn_output = torch.ops.auto_deploy.torch_attention_sdpa(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=None,
|
||||
dropout_p=self.dropout,
|
||||
is_causal=False,
|
||||
scale=1.0 / (self.head_dim**0.5),
|
||||
)
|
||||
mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len]
|
||||
attn_mask = torch.zeros(
|
||||
(batch_size, 1, seq_len, seq_len), device=device, dtype=dtype
|
||||
).masked_fill(mask, float("-inf"))
|
||||
|
||||
# Apply scaled dot product attention
|
||||
attn_output = torch.ops.auto_deploy.torch_attention_sdpa(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.dropout,
|
||||
is_causal=False,
|
||||
scale=1.0 / (self.head_dim**0.5),
|
||||
)
|
||||
|
||||
# Reshape output for the linear projection
|
||||
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
|
||||
@ -423,11 +412,47 @@ def _get_match_repeat_kv_optimizer() -> Callable:
|
||||
"cleanup_noop_slice": {
|
||||
"stage": "post_export",
|
||||
},
|
||||
"match_repeat_kv": {
|
||||
"stage": "pattern_matcher",
|
||||
},
|
||||
}
|
||||
|
||||
def _transform(gm: GraphModule) -> GraphModule:
|
||||
gm = InferenceOptimizer(None, config)(None, gm)
|
||||
return gm
|
||||
|
||||
return _transform
|
||||
|
||||
|
||||
def _get_match_eager_attention_optimizer() -> Callable:
|
||||
config = {
|
||||
"cleanup_noop_slice": {
|
||||
"stage": "post_export",
|
||||
},
|
||||
"match_eager_attention": {
|
||||
"stage": "pattern_matcher",
|
||||
},
|
||||
}
|
||||
|
||||
def _transform(gm: GraphModule) -> GraphModule:
|
||||
gm = InferenceOptimizer(None, config)(None, gm)
|
||||
return gm
|
||||
|
||||
return _transform
|
||||
|
||||
|
||||
def _get_match_grouped_attention_optimizer() -> Callable:
|
||||
config = {
|
||||
"cleanup_noop_slice": {
|
||||
"stage": "post_export",
|
||||
},
|
||||
"match_grouped_attention": {
|
||||
"stage": "pattern_matcher",
|
||||
},
|
||||
}
|
||||
|
||||
def _transform(gm: GraphModule) -> GraphModule:
|
||||
gm = InferenceOptimizer(None, config)(None, gm)
|
||||
match_repeat_kv(gm)
|
||||
return gm
|
||||
|
||||
return _transform
|
||||
@ -516,8 +541,8 @@ def test_match_repeat_kv(num_heads, num_kv_heads, model_cls):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("has_mask", [True, False])
|
||||
@pytest.mark.parametrize("use_division", [False, True])
|
||||
@pytest.mark.parametrize("has_mask", [False, True])
|
||||
@pytest.mark.parametrize("use_division", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"dropout, skip_output_assert",
|
||||
[
|
||||
@ -537,8 +562,10 @@ def test_match_eager_attention(has_mask, use_division, dropout, skip_output_asse
|
||||
|
||||
# Create different model types based on the parameter
|
||||
if model_type == "standard":
|
||||
model = EagerAttentionModel(hidden_size, num_heads, has_mask, dropout, use_division).to(
|
||||
"cuda", dtype=torch.float16
|
||||
model = (
|
||||
EagerAttentionModel(hidden_size, num_heads, has_mask, dropout, use_division)
|
||||
.to("cuda", dtype=torch.float16)
|
||||
.eval()
|
||||
)
|
||||
# Print the original scaling approach and value
|
||||
if use_division:
|
||||
@ -549,8 +576,10 @@ def test_match_eager_attention(has_mask, use_division, dropout, skip_output_asse
|
||||
expected_scale = model.scaling
|
||||
else: # complex
|
||||
# Complex model only uses division for scaling
|
||||
model = ComplexEagerAttentionModel(hidden_size, num_heads, has_mask, dropout).to(
|
||||
"cuda", dtype=torch.float16
|
||||
model = (
|
||||
ComplexEagerAttentionModel(hidden_size, num_heads, has_mask, dropout)
|
||||
.to("cuda", dtype=torch.float16)
|
||||
.eval()
|
||||
)
|
||||
expected_scale = 1.0 / model.scale_divisor
|
||||
# Override use_division and only run test once (ignore the parameterization)
|
||||
@ -567,6 +596,7 @@ def test_match_eager_attention(has_mask, use_division, dropout, skip_output_asse
|
||||
expected_matches = 1
|
||||
|
||||
def verify_matcher(gm):
|
||||
# the eager attention pattern is replaced with torch_attention_sdpa after the transformation
|
||||
sdpa_nodes = [
|
||||
n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa)
|
||||
]
|
||||
@ -636,13 +666,15 @@ def test_match_eager_attention(has_mask, use_division, dropout, skip_output_asse
|
||||
|
||||
# Check mask handling for masked attention
|
||||
if has_mask:
|
||||
has_mask_arg = "attn_mask" in kwargs
|
||||
if not has_mask_arg and len(node.args) >= 4:
|
||||
has_mask_arg = node.args[3] is not None
|
||||
is_causal = kwargs.get("is_causal", None)
|
||||
if is_causal is None and len(node.args) >= 6:
|
||||
is_causal = node.args[5]
|
||||
|
||||
if not has_mask_arg:
|
||||
print("❌ Missing mask information in SDPA node")
|
||||
if is_causal is not True:
|
||||
print(f"❌ Expected is_causal=True for masked attention, got {is_causal}")
|
||||
valid = False
|
||||
else:
|
||||
print("✅ is_causal correctly set to True")
|
||||
|
||||
print("Graph verification successful" if valid else "Graph verification failed")
|
||||
return valid
|
||||
@ -651,7 +683,7 @@ def test_match_eager_attention(has_mask, use_division, dropout, skip_output_asse
|
||||
run_test(
|
||||
model,
|
||||
x,
|
||||
match_eager_attention,
|
||||
_get_match_eager_attention_optimizer(),
|
||||
verify_matcher,
|
||||
lambda num_p_og: num_p_og,
|
||||
atol=1e-3,
|
||||
@ -685,7 +717,7 @@ def test_counter_example():
|
||||
_ = run_test(
|
||||
model,
|
||||
x,
|
||||
match_repeat_kv,
|
||||
_get_match_eager_attention_optimizer(),
|
||||
verify_no_matches,
|
||||
lambda num_p_og: num_p_og,
|
||||
atol=1e-3,
|
||||
@ -709,9 +741,8 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask):
|
||||
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
|
||||
dynamic_shapes = model.get_dynamic_shapes()
|
||||
|
||||
# We should find 1 instance of the pattern if num_heads != num_kv_heads
|
||||
# Otherwise, no pattern should be matched (no grouped attention)
|
||||
expected_matches = 1 if num_heads != num_kv_heads else 0
|
||||
# We should find 1 instance of torch_attention_grouped_sdpa
|
||||
expected_matches = 1
|
||||
|
||||
def verify_matcher(gm):
|
||||
grouped_sdpa_nodes = [
|
||||
@ -727,10 +758,6 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask):
|
||||
)
|
||||
return False
|
||||
|
||||
# If we don't expect any matches, we're done
|
||||
if expected_matches == 0:
|
||||
return True
|
||||
|
||||
# Otherwise, check the node properties
|
||||
for node in grouped_sdpa_nodes:
|
||||
# Basic checks: should have at least 3 positional args (q, k, v)
|
||||
@ -743,16 +770,14 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask):
|
||||
|
||||
# Mask handling should be preserved
|
||||
if has_mask:
|
||||
# Check if attn_mask is in kwargs or provided via args
|
||||
has_mask_arg = "attn_mask" in kwargs
|
||||
if (
|
||||
not has_mask_arg and len(node.args) >= 4
|
||||
): # Assuming attn_mask is the 4th positional arg
|
||||
has_mask_arg = node.args[3] is not None
|
||||
is_causal = kwargs.get("is_causal", None)
|
||||
if is_causal is None and len(node.args) >= 6:
|
||||
is_causal = node.args[5]
|
||||
|
||||
if not has_mask_arg:
|
||||
print("❌ Expected attn_mask in args or kwargs but not found")
|
||||
return False
|
||||
if is_causal is not True:
|
||||
print(f"❌ Expected is_causal=True for masked attention, got {is_causal}")
|
||||
else:
|
||||
print("✅ is_causal correctly set to True")
|
||||
|
||||
return True
|
||||
|
||||
@ -760,7 +785,7 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask):
|
||||
_ = run_test(
|
||||
model,
|
||||
x,
|
||||
match_grouped_attention,
|
||||
_get_match_grouped_attention_optimizer(),
|
||||
verify_matcher,
|
||||
lambda num_p_og: num_p_og,
|
||||
atol=1e-3,
|
||||
@ -884,98 +909,6 @@ class CausalAttentionModel(torch.nn.Module):
|
||||
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=4, max=16)}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mask_type", ["triu", "negative_fill", "non_causal"])
|
||||
@pytest.mark.parametrize("use_grouped_sdpa", [False, True])
|
||||
@torch.inference_mode()
|
||||
def test_match_causal_attention(mask_type, use_grouped_sdpa):
|
||||
batch_size, seq_len = 4, 12
|
||||
hidden_size = 512
|
||||
num_heads = 8
|
||||
num_kv_heads = 4 if use_grouped_sdpa else num_heads
|
||||
|
||||
model = CausalAttentionModel(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mask_type=mask_type,
|
||||
use_grouped_sdpa=use_grouped_sdpa,
|
||||
num_kv_heads=num_kv_heads,
|
||||
).to("cuda", dtype=torch.float16)
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
|
||||
dynamic_shapes = model.get_dynamic_shapes()
|
||||
|
||||
# We expect optimization (None mask + is_causal=True) when using causal masks
|
||||
should_optimize = mask_type in ["triu", "negative_fill"]
|
||||
|
||||
def verify_matcher(gm):
|
||||
# Find attention operations
|
||||
if use_grouped_sdpa:
|
||||
attn_nodes = [
|
||||
n
|
||||
for n in gm.graph.nodes
|
||||
if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
|
||||
]
|
||||
else:
|
||||
attn_nodes = [
|
||||
n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa)
|
||||
]
|
||||
|
||||
if len(attn_nodes) != 1:
|
||||
print(f"Expected 1 attention node, found {len(attn_nodes)}")
|
||||
return False
|
||||
|
||||
node = attn_nodes[0]
|
||||
|
||||
# Check if attention mask was set to None and is_causal was set to True
|
||||
if should_optimize:
|
||||
# Attention mask (4th arg) should be None
|
||||
has_mask = (
|
||||
node.args[3] is not None if len(node.args) > 3 else "attn_mask" in node.kwargs
|
||||
)
|
||||
|
||||
# is_causal (6th arg) should be True
|
||||
is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False)
|
||||
|
||||
# Check if optimization was correctly applied
|
||||
if has_mask or not is_causal:
|
||||
print("❌ Expected optimization: mask=None, is_causal=True")
|
||||
print(
|
||||
f" Got: mask={node.args[3] if len(node.args) > 3 else 'not in args'}, "
|
||||
f"is_causal={is_causal}"
|
||||
)
|
||||
return False
|
||||
|
||||
print("✅ Successfully optimized causal mask: mask=None, is_causal=True")
|
||||
else:
|
||||
# Non-causal masks should remain as is
|
||||
has_mask = (
|
||||
node.args[3] is not None if len(node.args) > 3 else "attn_mask" in node.kwargs
|
||||
)
|
||||
|
||||
# Check if non-optimization was correctly preserved
|
||||
if not has_mask:
|
||||
print("❌ Expected non-causal mask to be preserved")
|
||||
return False
|
||||
|
||||
print("✅ Successfully preserved non-causal mask")
|
||||
|
||||
return True
|
||||
|
||||
# Run the test
|
||||
_ = run_test(
|
||||
model,
|
||||
x,
|
||||
match_causal_attn_mask,
|
||||
verify_matcher,
|
||||
lambda num_p_og: num_p_og,
|
||||
atol=1e-3,
|
||||
rtol=1e-3,
|
||||
test_load_hook=True,
|
||||
strict_loading=True,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
|
||||
|
||||
class Llama3CausalAttentionModel(torch.nn.Module):
|
||||
"""Model that creates a causal attention mask mimicking the llama-3.1 pattern."""
|
||||
|
||||
@ -1082,78 +1015,7 @@ class Llama3CausalAttentionModel(torch.nn.Module):
|
||||
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=4, max=16)}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_grouped_sdpa", [False, True])
|
||||
@pytest.mark.skip(reason="Skip until we have more robust attention masking handling, see #4783")
|
||||
@torch.inference_mode()
|
||||
def test_match_llama3_causal_attention(use_grouped_sdpa):
|
||||
batch_size, seq_len = 4, 12
|
||||
hidden_size = 512
|
||||
num_heads = 8
|
||||
num_kv_heads = 4 if use_grouped_sdpa else num_heads
|
||||
|
||||
model = Llama3CausalAttentionModel(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
use_grouped_sdpa=use_grouped_sdpa,
|
||||
num_kv_heads=num_kv_heads,
|
||||
).to("cuda", dtype=torch.float32)
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float32)
|
||||
dynamic_shapes = model.get_dynamic_shapes()
|
||||
|
||||
def verify_matcher(gm):
|
||||
# Find attention operations
|
||||
if use_grouped_sdpa:
|
||||
attn_nodes = [
|
||||
n
|
||||
for n in gm.graph.nodes
|
||||
if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa)
|
||||
]
|
||||
else:
|
||||
attn_nodes = [
|
||||
n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa)
|
||||
]
|
||||
|
||||
if len(attn_nodes) != 1:
|
||||
print(f"Expected 1 attention node, found {len(attn_nodes)}")
|
||||
return False
|
||||
|
||||
node = attn_nodes[0]
|
||||
|
||||
# Attention mask (4th arg) should be None
|
||||
has_mask = node.args[3] is not None if len(node.args) > 3 else "attn_mask" in node.kwargs
|
||||
|
||||
# is_causal (6th arg) should be True
|
||||
is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False)
|
||||
|
||||
# Check if optimization was correctly applied
|
||||
if has_mask or not is_causal:
|
||||
print("❌ Expected optimization: mask=None, is_causal=True")
|
||||
print(
|
||||
f" Got: mask={node.args[3] if len(node.args) > 3 else 'not in args'}, "
|
||||
f"is_causal={is_causal}"
|
||||
)
|
||||
return False
|
||||
|
||||
print("✅ Successfully optimized llama-3.1 causal mask: mask=None, is_causal=True")
|
||||
return True
|
||||
|
||||
# Run the test
|
||||
run_test(
|
||||
model,
|
||||
x,
|
||||
match_causal_attn_mask,
|
||||
verify_matcher,
|
||||
lambda num_p_og: num_p_og,
|
||||
atol=1e-3,
|
||||
rtol=1e-3,
|
||||
test_load_hook=True,
|
||||
strict_loading=True,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
|
||||
|
||||
class MockAttentionDescriptor:
|
||||
class MockAttentionDescriptor(AttentionDescriptor):
|
||||
"""A mock class that mimics the AttentionDescriptor interface for testing."""
|
||||
|
||||
layout: str = "bnsd"
|
||||
@ -1458,7 +1320,15 @@ def test_match_attention_layout(layout, model_config, has_mask):
|
||||
run_test(
|
||||
model,
|
||||
x,
|
||||
lambda gm: match_attention_layout(gm, MockAttentionDescriptor),
|
||||
lambda gm: InferenceOptimizer(
|
||||
None,
|
||||
{
|
||||
"match_attention_layout": {
|
||||
"stage": "pattern_matcher",
|
||||
"attention_op": MockAttentionDescriptor,
|
||||
},
|
||||
},
|
||||
)(None, gm),
|
||||
verify_matcher,
|
||||
lambda num_p_og: num_p_og,
|
||||
atol=1e-3,
|
||||
|
||||
@ -1,26 +1,26 @@
"""Test that the attention matcher works with HF's SDPA backends."""

import copy
from typing import Any, Callable, Dict

import pytest
import torch
import torch.nn as nn
from _graph_test_helpers import run_test
from accelerate import init_empty_weights
from torch.export import Dim
from torch.fx import GraphModule
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel

from tensorrt_llm._torch.auto_deploy.transformations.library import (
    match_attention_layout,
    match_causal_attn_mask,
    match_eager_attention,
    match_grouped_attention,
    match_repeat_kv,
)
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionDescriptor
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device

torch.manual_seed(0)


class MockAttentionDescriptor:
class MockAttentionDescriptor(AttentionDescriptor):
    """A mock class that mimics the AttentionDescriptor interface for testing."""

    layout: str = "bsnd"
@ -45,11 +45,24 @@ class HFWrapper(nn.Module):


def _joint_transform(gm: GraphModule) -> None:
    match_repeat_kv(gm)
    match_eager_attention(gm)
    match_grouped_attention(gm)
    match_causal_attn_mask(gm)
    match_attention_layout(gm, MockAttentionDescriptor())
    gm = InferenceOptimizer(
        None,
        {
            "match_repeat_kv": {
                "stage": "pattern_matcher",
            },
            "match_eager_attention": {
                "stage": "pattern_matcher",
            },
            "match_grouped_attention": {
                "stage": "pattern_matcher",
            },
            "match_attention_layout": {
                "stage": "pattern_matcher",
                "attention_op": MockAttentionDescriptor,
            },
        },
    )(None, gm)
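Since every entry above runs in the same stage, the config dict can also be assembled programmatically. A minimal sketch (the helper name is illustrative; the `InferenceOptimizer(factory, config)(..., gm)` call shape with `None` placeholders is the one used by the tests in this diff):

from typing import Any, Dict, Iterable, Optional


def pattern_matcher_config(
    names: Iterable[str], extra: Optional[Dict[str, Dict[str, Any]]] = None
) -> Dict[str, Dict[str, Any]]:
    """Build an optimizer config where every listed transform runs in the pattern_matcher stage."""
    extra = extra or {}
    return {name: {"stage": "pattern_matcher", **extra.get(name, {})} for name in names}


# Hypothetical usage mirroring _joint_transform above:
# config = pattern_matcher_config(
#     ["match_repeat_kv", "match_eager_attention", "match_grouped_attention", "match_attention_layout"],
#     extra={"match_attention_layout": {"attention_op": MockAttentionDescriptor}},
# )
# gm = InferenceOptimizer(None, config)(None, gm)
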
@pytest.mark.parametrize(
@ -65,23 +78,6 @@ def _joint_transform(gm: GraphModule) -> None:
    ["eager", "sdpa"],
)
def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str):
    batch_size, seq_len = 4, 12
    full_config = {
        "num_hidden_layers": 1,
        "vocab_size": 256,
        "hidden_size": 128,
        "intermediate_size": 128,
        "attn_implementation": attn_implementation,
        **config,
    }
    dynamic_shapes = {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=4, max=16)}

    model = HFWrapper(LlamaModel(LlamaConfig(**full_config))).to("cuda")
    model.eval()
    x = torch.randint(
        0, full_config["vocab_size"], (batch_size, seq_len), dtype=torch.long, device="cuda"
    )

    def verify_matcher(gm: GraphModule):
        """Ensure that there is exactly one torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa
        call in the graph. Also check that there is no repeat_kv pattern left.
@ -106,18 +102,69 @@ def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str)
            op="call_function", target=torch.ops.auto_deploy.torch_attention_repeat_kv
        )
        assert len(nodes) == 0, "Found repeat_kv pattern in the graph"
        attn_nodes = gm.graph.find_nodes(
            op="call_function", target=torch.ops.auto_deploy.torch_attention_sdpa
        )
        assert len(attn_nodes) == 0, "Found torch_attention_sdpa node in the graph"

        return True

    _ = run_test(
        model,
        x,
        _joint_transform,
        verify_matcher,
        lambda num_p_og: num_p_og,
        atol=1e-3,
        rtol=5e-2,
        test_load_hook=True,
        strict_loading=True,
        dynamic_shapes=dynamic_shapes,
    batch_size, seq_len = 2, 4
    full_config = {
        "num_hidden_layers": 1,
        "vocab_size": 256,
        "hidden_size": 128,
        "intermediate_size": 128,
        "attn_implementation": attn_implementation,
        **config,
    }
    dynamic_shapes = {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=2, max=8)}

    # Build and export model on meta device
    with init_empty_weights():
        model = HFWrapper(LlamaModel(LlamaConfig(**full_config))).eval()
    x = torch.randint(
        0, full_config["vocab_size"], (batch_size, seq_len), dtype=torch.long, device="cuda"
    )
    gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)

    print("Exported gm", gm)
    gm_exported = copy.deepcopy(gm)

    # Move model to cuda
    device = "cuda"
    model._apply(
        lambda t: torch.normal(0.0, 1.0, size=t.shape, device=device).to(t.dtype)
        if t.device == torch.device("meta")
        else t.to(device)
    )
    y_model = model(x)

    gm_exported._apply(
        lambda t: torch.normal(0.0, 1.0, size=t.shape, device=device).to(t.dtype)
        if t.device == torch.device("meta")
        else t.to(device)
    )
    gm_exported.load_state_dict(model.state_dict())
    move_to_device(gm_exported, "cuda")
    y_gm_exported = gm_exported(x)
    torch.testing.assert_close(y_gm_exported, y_model, atol=5e-3, rtol=5e-3)

    # Apply transformation
    _joint_transform(gm)
    assert verify_matcher(gm)
    print("Transformed gm", gm)

    # Move gm to cuda
    gm._apply(
        lambda t: torch.normal(0.0, 1.0, size=t.shape, device=device).to(t.dtype)
        if t.device == torch.device("meta")
        else t.to(device)
    )
    gm.load_state_dict(model.state_dict())
    move_to_device(gm, "cuda")

    # Verify output
    y_gm = gm(x)
    torch.testing.assert_close(y_gm_exported, y_gm, atol=5e-2, rtol=5e-2)
    torch.testing.assert_close(y_model, y_gm, atol=5e-2, rtol=5e-2)
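The three `_apply` calls above share the same meta-weight materialization trick; it can be captured in one small helper. A minimal sketch, assuming random weights are acceptable for the numerical comparison (the helper name is illustrative, not part of the library):

import torch
import torch.nn as nn


def materialize_randomly(module: nn.Module, device: str = "cuda") -> nn.Module:
    """Fill meta tensors with random values and move all other tensors to `device` (in place)."""
    return module._apply(
        lambda t: torch.normal(0.0, 1.0, size=t.shape, device=device).to(t.dtype)
        if t.device == torch.device("meta")
        else t.to(device)
    )


# Hypothetical usage mirroring the test above:
# materialize_randomly(model)
# materialize_randomly(gm_exported)
# gm_exported.load_state_dict(model.state_dict())
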
@ -1,10 +1,11 @@
import pytest
import torch
from _graph_test_helpers import run_test
from _graph_test_helpers import FakeFactory, run_test_transformed_gm
from _model_test_utils import MoEOpModel
from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available

from tensorrt_llm._torch.auto_deploy.transformations.library import quantize_moe
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op


@ -62,13 +63,20 @@ def test_quantize_moe_transformation(quant_algo, expected_op):

    quant_config = {"quant_algo": quant_algo}

    def _transform(gm, *args):
        return quantize_moe(gm, quant_config)
    gm = torch_export_to_gm(model, args=(x,), clone=True)
    gm_transformed = InferenceOptimizer(
        FakeFactory(quant_config=quant_config),
        {
            "quantize_moe": {
                "stage": "pattern_matcher",
            },
        },
    )(None, gm)

    _ = run_test(
    run_test_transformed_gm(
        model=model,
        x=x,
        transform=_transform,
        gm_transformed=gm_transformed,
        check_transformed_graph=_check_transformed_graph,
        _get_expected_num_params=_expected_num_params,
        atol=0.5,

@ -4,13 +4,14 @@ Tests for basic graph sharding.

import pytest
import torch
from _graph_test_helpers import run_test
from _graph_test_helpers import run_test_transformed_gm
from _model_test_utils import MLP, BMMDynamicModel, BMMModel
from _torch_test_utils import fp4_compatible, fp8_compatible

from tensorrt_llm._torch.auto_deploy.custom_ops.quant import QUANT_OPS
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations.library import quantize
from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp8_scale

@ -19,6 +20,22 @@ def check_quantized(gm):
    return any(is_op(n, QUANT_OPS) for n in gm.graph.nodes)


class DummyFactory(ModelFactory):
    """Dummy factory to pass quant_config for testing."""

    def __init__(self, quant_config):
        self.quant_config = quant_config

    def _build_model(self, device: str):
        return

    def _load_checkpoint(self, model, device):
        return

    def get_quant_config(self):
        return self.quant_config
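For reference, these pieces compose the same way as in the MoE test above; a minimal sketch mirroring the calls in the tests below (the helper name is illustrative and not part of the library):

def quantize_gm_for_test(model, x, quant_config):
    """Export `model` and run only the quantize transform, with DummyFactory supplying quant_config."""
    gm = torch_export_to_gm(model, args=(x,), clone=True)
    return InferenceOptimizer(
        DummyFactory(quant_config),
        {"quantize": {"stage": "pattern_matcher"}},
    )(None, gm)


# Hypothetical usage, assuming a quant_config such as {"quant_algo": "FP8"}:
# gm_transformed = quantize_gm_for_test(model, x, quant_config)
# assert check_quantized(gm_transformed)
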
@pytest.mark.parametrize(
    "quant_config,atol,rtol,num_p_og",
    [
@ -51,11 +68,22 @@ def test_quantization(quant_config, atol, rtol, num_p_og):
    model.linear2.register_buffer(
        "input_scale", torch.tensor([1.0], device=model.linear2.weight.device)
    )
    # export the model and apply the quantize transform
    gm = torch_export_to_gm(model, args=(x,), clone=True)
    gm_transformed = InferenceOptimizer(
        DummyFactory(quant_config),
        {
            "quantize": {
                "stage": "pattern_matcher",
            },
        },
    )(None, gm)
    gm_transformed.to("cuda")

    gm_transformed = run_test(
    run_test_transformed_gm(
        model,
        x,
        quantize,
        gm_transformed,
        check_quantized,
        num_p_og,
        atol,
@ -122,10 +150,22 @@ def test_bmm_quantization(quant_config, atol, rtol, num_p_og, model_class):
    model.register_buffer("bmm_dynamic_input_scale", fp8_scale(x))
    model.register_buffer("bmm_dynamic_weight_scale", fp8_scale(dummy_weight))

    gm_transformed = run_test(
    # export the model and apply the quantize transform
    gm = torch_export_to_gm(model, args=(x,), clone=True)
    gm_transformed = InferenceOptimizer(
        DummyFactory(quant_config),
        {
            "quantize": {
                "stage": "pattern_matcher",
            },
        },
    )(None, gm)
    gm_transformed.to("cuda")

    run_test_transformed_gm(
        model,
        x,
        quantize,
        gm_transformed,
        check_quantized,
        num_p_og,
        atol,
