[#4585][feat] Replace unified attention before export (#8303)

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
h-guo18 2025-10-23 15:02:04 -07:00 committed by GitHub
parent 32e1ad68e1
commit 23920223ab
4 changed files with 84 additions and 126 deletions

View File

@@ -0,0 +1,81 @@
"""Patch for torch.export.export to detect and replace hf attention_interface with unified attention."""

from typing import Optional

import torch
import torch.export as te
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from ..interface import BaseExportPatch, ExportPatchRegistry

# Kwargs mapping for HF attention_interface to auto_deploy::torch_attention
HF_ATTN_KWARGS_MAPPING = {
    "dropout": "dropout_p",
    "is_causal": "is_causal",
    "scaling": "scale",
    "scale": "scale",
    "s_aux": "sinks",
    "sinks": "sinks",
    "sliding_window": "sliding_window",
    "logit_cap": "logit_cap",
}


def torch_attention_hf_wrapper(
    self: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    **kwargs,
):
    """Wrapper of auto_deploy::torch_attention with HF attention_interface signature."""
    # Convert from [batch, num_heads, seq_len, head_dim] to [batch, seq_len, num_heads, head_dim]
    query_states = query.transpose(1, 2)
    key_states = key.transpose(1, 2)
    value_states = value.transpose(1, 2)

    ad_attn_kwargs = {
        HF_ATTN_KWARGS_MAPPING[k]: v for k, v in kwargs.items() if k in HF_ATTN_KWARGS_MAPPING
    }
    attn_output = torch.ops.auto_deploy.torch_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        layout="bsnd",
        **ad_attn_kwargs,
    )
    return attn_output, None


@ExportPatchRegistry.register("unified_attn")
class UnifiedAttnPatch(BaseExportPatch):
    """
    Patch on torch.export.export to replace attention_interface with torch.ops.auto_deploy.torch_attention.
    """

    def _apply_patch(self):
        """Apply the te.export patch."""
        # Store original torch.export.export
        self.original_values["te.export"] = te.export

        # Register the wrapper function
        ALL_ATTENTION_FUNCTIONS.register("ad_unified_attn", torch_attention_hf_wrapper)

        def _export_with_unified_attn(model, *args, **kwargs):
            # torch_export_to_gm is called at both export stage and attn matching stage;
            # we only patch attn implementation for export stage
            if hasattr(model, "config") and hasattr(model.config, "_attn_implementation"):
                model.config._attn_implementation = "ad_unified_attn"
            return self.original_values["te.export"](model, *args, **kwargs)

        # Apply patch
        te.export = _export_with_unified_attn

    def _revert_patch(self):
        """Revert the te.export patch."""
        te.export = self.original_values["te.export"]
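Aside (not part of the commit): the patch above boils down to a store, wrap, restore cycle around torch.export.export. A minimal, self-contained sketch of that lifecycle, with the HF-specific pieces stubbed out and a toy model standing in for the real one:

import torch
import torch.export as te


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x + 1


# _apply_patch: remember the original entry point before wrapping it.
_original_export = te.export


def _export_wrapper(model, *args, **kwargs):
    # UnifiedAttnPatch would flip model.config._attn_implementation to "ad_unified_attn" here
    # so that HF modules dispatch to the registered wrapper during tracing.
    return _original_export(model, *args, **kwargs)


te.export = _export_wrapper
try:
    exported = te.export(TinyModel(), (torch.randn(2),))
finally:
    # _revert_patch: always restore the original torch.export.export.
    te.export = _original_export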

View File

@@ -1,16 +0,0 @@
import re

from transformers import AutoConfig

_orig_from_pretrained = AutoConfig.from_pretrained


def _from_pretrained_patched(pretrained_model_name_or_path, **kwargs):
    print(str(pretrained_model_name_or_path))
    if re.search(r"Llama-3_(?:1|3)-Nemotron-(?:Ultra|Super)", str(pretrained_model_name_or_path)):
        kwargs["attn_implementation"] = "eager"
    return _orig_from_pretrained(pretrained_model_name_or_path, **kwargs)


# TODO: figure out how this can be incorporated into the export patch system
AutoConfig.from_pretrained = _from_pretrained_patched
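Aside (not part of the commit): the removed workaround above forced eager attention only for the Nemotron Ultra/Super checkpoints. A quick standalone check of the regex, with model names chosen purely for illustration:

import re

pattern = r"Llama-3_(?:1|3)-Nemotron-(?:Ultra|Super)"

assert re.search(pattern, "nvidia/Llama-3_1-Nemotron-Ultra-253B-v1")
assert re.search(pattern, "nvidia/Llama-3_3-Nemotron-Super-49B-v1")
assert not re.search(pattern, "meta-llama/Llama-3.1-8B-Instruct")  # other Llama variants keep their default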

View File

@@ -1,100 +0,0 @@
import types
from typing import Callable, Dict, Optional

import torch
from transformers.models.auto.modeling_auto import AutoModelForCausalLM


def gpt_oss_attention(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_value: Optional[torch.Tensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """GPT OSS Attention forward function rewritten to wrap attention as a custom op."""
    from transformers.models.gpt_oss.modeling_gpt_oss import apply_rotary_pos_emb

    # Add new parameters
    sliding_window = getattr(self, "sliding_window", -1)  # Default to -1 if not present

    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    # Apply Q, K, V projections (same as original)
    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    # Use original rope implementation
    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    # Handle KV cache properly
    if past_key_value is not None:
        # Update KV cache - check if it has update method (modern cache objects)
        if hasattr(past_key_value, "update"):
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
        else:
            # Handle legacy tuple-based cache
            if isinstance(past_key_value, tuple) and len(past_key_value) == 2:
                past_key, past_value = past_key_value
                key_states = torch.cat([past_key, key_states], dim=2)
                value_states = torch.cat([past_value, value_states], dim=2)

    # Convert from [batch, num_heads, seq_len, head_dim] to [batch, seq_len, num_heads, head_dim]
    query_states = query_states.transpose(1, 2).contiguous()
    key_states = key_states.transpose(1, 2).contiguous()
    value_states = value_states.transpose(1, 2).contiguous()

    # Get sinks parameter from model if available
    sinks = None
    if hasattr(self, "sinks"):
        # If sinks is a model parameter, use it directly
        sinks = self.sinks

    # Use custom op to capture attention. This layout is bsnd (batch, seq, num_heads, head_dim)
    attn_output = torch.ops.auto_deploy.torch_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        dropout_p=0.0,
        is_causal=True,
        scale=self.scaling,
        sinks=sinks,
        sliding_window=sliding_window,
        layout="bsnd",
    )

    # Reshape back to original input shape
    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, past_key_value


_from_config_original = AutoModelForCausalLM.from_config

CUSTOM_MODULE_PATCHES: Dict[str, Callable] = {
    "GptOssAttention": gpt_oss_attention,
}


def get_model_from_config_patched(config, **kwargs):
    model = _from_config_original(config, **kwargs)
    # Patch modules
    for _, module in model.named_modules():
        if type(module).__name__ in CUSTOM_MODULE_PATCHES.keys():
            # Replace the forward method
            module.forward = types.MethodType(CUSTOM_MODULE_PATCHES[type(module).__name__], module)
    return model


AutoModelForCausalLM.from_config = get_model_from_config_patched
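Aside (not part of the commit): the removed file relied on rebinding forward per instance via types.MethodType, keyed on the class name. A toy version of that pattern, with a stand-in module in place of GptOssAttention:

import types

import torch


class ToyAttention(torch.nn.Module):
    def forward(self, x):
        return x


def patched_forward(self, x):
    # 'self' is the ToyAttention instance; MethodType makes this behave like a bound method.
    return x * 2


model = torch.nn.Sequential(ToyAttention(), torch.nn.Identity())
for _, module in model.named_modules():
    if type(module).__name__ == "ToyAttention":
        # Only instances whose class name is in the patch table get the new forward.
        module.forward = types.MethodType(patched_forward, module)

assert torch.equal(model(torch.ones(3)), torch.full((3,), 2.0))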

View File

@@ -12,6 +12,7 @@ from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...custom_ops.attention_interface import AttentionDescriptor, Constant
from ...export.library.unified_attn import HF_ATTN_KWARGS_MAPPING
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
@@ -39,16 +40,8 @@ def fake_profiler_mha(
    # construct kwargs for bsnd_grouped_sdpa
    node_kwargs = {"attn_mask": attention_mask, "is_causal": is_causal}
    kwargs_to_op = {
        "dropout": "dropout_p",
        "scaling": "scale",
        "scale": "scale",
        "s_aux": "sinks",
        "sinks": "sinks",
        "sliding_window": "sliding_window",
        "logit_cap": "logit_cap",
    }
    for k_kwargs, k_op_kwargs in kwargs_to_op.items():
    for k_kwargs, k_op_kwargs in HF_ATTN_KWARGS_MAPPING.items():
        if k_kwargs in kwargs:
            node_kwargs[k_op_kwargs] = kwargs[k_kwargs]
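Aside (not part of the commit): the hunk above drops the transform's local kwargs_to_op copy in favor of the shared HF_ATTN_KWARGS_MAPPING from the new export patch. A small self-contained illustration of the translation loop; the mapping is copied here so the snippet runs on its own, and the kwargs values are made up:

# Copied from the new unified_attn patch for a self-contained example.
HF_ATTN_KWARGS_MAPPING = {
    "dropout": "dropout_p",
    "is_causal": "is_causal",
    "scaling": "scale",
    "scale": "scale",
    "s_aux": "sinks",
    "sinks": "sinks",
    "sliding_window": "sliding_window",
    "logit_cap": "logit_cap",
}

kwargs = {"scaling": 0.125, "sliding_window": 4096, "position_ids": None}  # illustrative HF kwargs
node_kwargs = {"attn_mask": None, "is_causal": True}
for k_kwargs, k_op_kwargs in HF_ATTN_KWARGS_MAPPING.items():
    if k_kwargs in kwargs:
        node_kwargs[k_op_kwargs] = kwargs[k_kwargs]

# Only keys the op understands are forwarded; position_ids is ignored.
assert node_kwargs == {
    "attn_mask": None,
    "is_causal": True,
    "scale": 0.125,
    "sliding_window": 4096,
}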