Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
This commit is contained in:
parent 32e1ad68e1
commit 23920223ab
@@ -0,0 +1,81 @@
"""Patch for torch.export.export to detect and replace the HF attention_interface with unified attention."""

from typing import Optional

import torch
import torch.export as te
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from ..interface import BaseExportPatch, ExportPatchRegistry

# Kwargs mapping from the HF attention_interface to auto_deploy::torch_attention
HF_ATTN_KWARGS_MAPPING = {
    "dropout": "dropout_p",
    "is_causal": "is_causal",
    "scaling": "scale",
    "scale": "scale",
    "s_aux": "sinks",
    "sinks": "sinks",
    "sliding_window": "sliding_window",
    "logit_cap": "logit_cap",
}


def torch_attention_hf_wrapper(
    self: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    **kwargs,
):
    """Wrapper of auto_deploy::torch_attention with the HF attention_interface signature."""
    # Convert from [batch, num_heads, seq_len, head_dim] to [batch, seq_len, num_heads, head_dim]
    query_states = query.transpose(1, 2)
    key_states = key.transpose(1, 2)
    value_states = value.transpose(1, 2)

    ad_attn_kwargs = {
        HF_ATTN_KWARGS_MAPPING[k]: v for k, v in kwargs.items() if k in HF_ATTN_KWARGS_MAPPING
    }

    attn_output = torch.ops.auto_deploy.torch_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        layout="bsnd",
        **ad_attn_kwargs,
    )

    return attn_output, None


@ExportPatchRegistry.register("unified_attn")
class UnifiedAttnPatch(BaseExportPatch):
    """Patch torch.export.export to replace the attention_interface with torch.ops.auto_deploy.torch_attention."""

    def _apply_patch(self):
        """Apply the te.export patch."""
        # Store the original torch.export.export
        self.original_values["te.export"] = te.export

        # Register the wrapper function
        ALL_ATTENTION_FUNCTIONS.register("ad_unified_attn", torch_attention_hf_wrapper)

        def _export_with_unified_attn(model, *args, **kwargs):
            # torch_export_to_gm is called at both the export stage and the attention-matching
            # stage; we only patch the attention implementation for the export stage.
            if hasattr(model, "config") and hasattr(model.config, "_attn_implementation"):
                model.config._attn_implementation = "ad_unified_attn"
            return self.original_values["te.export"](model, *args, **kwargs)

        # Apply the patch
        te.export = _export_with_unified_attn

    def _revert_patch(self):
        """Revert the te.export patch."""
        te.export = self.original_values["te.export"]
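For reference, a minimal, self-contained sketch (illustrative shapes only, not part of the commit) of the layout change torch_attention_hf_wrapper performs: HF's attention_interface hands it tensors in [batch, num_heads, seq_len, head_dim] order, and the wrapper transposes them to [batch, seq_len, num_heads, head_dim] before calling the custom op with layout="bsnd".

import torch

# Illustrative sizes only.
batch, num_heads, seq_len, head_dim = 2, 8, 16, 64

# HF attention_interface passes bnsd tensors ...
query = torch.randn(batch, num_heads, seq_len, head_dim)

# ... which the wrapper transposes to bsnd before invoking the custom op.
query_states = query.transpose(1, 2)
assert query_states.shape == (batch, seq_len, num_heads, head_dim)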
@@ -1,16 +0,0 @@
import re

from transformers import AutoConfig

_orig_from_pretrained = AutoConfig.from_pretrained


def _from_pretrained_patched(pretrained_model_name_or_path, **kwargs):
    print(str(pretrained_model_name_or_path))
    if re.search(r"Llama-3_(?:1|3)-Nemotron-(?:Ultra|Super)", str(pretrained_model_name_or_path)):
        kwargs["attn_implementation"] = "eager"
    return _orig_from_pretrained(pretrained_model_name_or_path, **kwargs)


# TODO: figure out how this can be incorporated into the export patch system
AutoConfig.from_pretrained = _from_pretrained_patched
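A small sketch of what the regex check above selects (the model identifiers below are illustrative, not taken from the commit): only Nemotron Ultra/Super names derived from Llama 3.1/3.3 get attn_implementation forced to "eager".

import re

_PATTERN = r"Llama-3_(?:1|3)-Nemotron-(?:Ultra|Super)"

assert re.search(_PATTERN, "nvidia/Llama-3_1-Nemotron-Ultra-253B-v1")    # matches
assert re.search(_PATTERN, "nvidia/Llama-3_3-Nemotron-Super-49B-v1")     # matches
assert re.search(_PATTERN, "meta-llama/Llama-3.1-8B-Instruct") is None   # no match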
@@ -1,100 +0,0 @@
import types
from typing import Callable, Dict, Optional

import torch
from transformers.models.auto.modeling_auto import AutoModelForCausalLM


def gpt_oss_attention(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_value: Optional[torch.Tensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """GPT-OSS attention forward function rewritten to wrap attention as a custom op."""
    from transformers.models.gpt_oss.modeling_gpt_oss import apply_rotary_pos_emb

    # Add new parameters
    sliding_window = getattr(self, "sliding_window", -1)  # Default to -1 if not present

    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    # Apply Q, K, V projections (same as original)
    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    # Use the original rope implementation
    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    # Handle the KV cache properly
    if past_key_value is not None:
        # Update the KV cache - check if it has an update method (modern cache objects)
        if hasattr(past_key_value, "update"):
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
        else:
            # Handle a legacy tuple-based cache
            if isinstance(past_key_value, tuple) and len(past_key_value) == 2:
                past_key, past_value = past_key_value
                key_states = torch.cat([past_key, key_states], dim=2)
                value_states = torch.cat([past_value, value_states], dim=2)

    # Convert from [batch, num_heads, seq_len, head_dim] to [batch, seq_len, num_heads, head_dim]
    query_states = query_states.transpose(1, 2).contiguous()
    key_states = key_states.transpose(1, 2).contiguous()
    value_states = value_states.transpose(1, 2).contiguous()

    # Get the sinks parameter from the model if available
    sinks = None
    if hasattr(self, "sinks"):
        # If sinks is a model parameter, use it directly
        sinks = self.sinks

    # Use the custom op to capture attention. The layout is bsnd (batch, seq, num_heads, head_dim)
    attn_output = torch.ops.auto_deploy.torch_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        dropout_p=0.0,
        is_causal=True,
        scale=self.scaling,
        sinks=sinks,
        sliding_window=sliding_window,
        layout="bsnd",
    )

    # Reshape back to the original input shape
    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    return attn_output, past_key_value


_from_config_original = AutoModelForCausalLM.from_config

CUSTOM_MODULE_PATCHES: Dict[str, Callable] = {
    "GptOssAttention": gpt_oss_attention,
}


def get_model_from_config_patched(config, **kwargs):
    model = _from_config_original(config, **kwargs)
    # Patch modules
    for _, module in model.named_modules():
        if type(module).__name__ in CUSTOM_MODULE_PATCHES.keys():
            # Replace the forward method
            module.forward = types.MethodType(CUSTOM_MODULE_PATCHES[type(module).__name__], module)

    return model


AutoModelForCausalLM.from_config = get_model_from_config_patched
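The module patching above relies on rebinding forward per instance with types.MethodType, dispatched by class name. A minimal sketch of that pattern (toy class and names, not part of the commit):

import types


class GptOssAttentionToy:
    """Toy stand-in; only the name-based dispatch pattern is illustrated."""

    def forward(self, x):
        return "original"


def patched_forward(self, x):
    return "patched"


TOY_PATCHES = {"GptOssAttentionToy": patched_forward}

module = GptOssAttentionToy()
if type(module).__name__ in TOY_PATCHES:
    # Bind the replacement as a bound method of this instance, as the patch
    # above does for every matching submodule from model.named_modules().
    module.forward = types.MethodType(TOY_PATCHES[type(module).__name__], module)

assert module.forward(None) == "patched"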
@@ -12,6 +12,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 
 from ...custom_ops.attention_interface import AttentionDescriptor, Constant
+from ...export.library.unified_attn import HF_ATTN_KWARGS_MAPPING
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
@@ -39,16 +40,8 @@ def fake_profiler_mha(
     # construct kwargs for bsnd_grouped_sdpa
     node_kwargs = {"attn_mask": attention_mask, "is_causal": is_causal}
-    kwargs_to_op = {
-        "dropout": "dropout_p",
-        "scaling": "scale",
-        "scale": "scale",
-        "s_aux": "sinks",
-        "sinks": "sinks",
-        "sliding_window": "sliding_window",
-        "logit_cap": "logit_cap",
-    }
-    for k_kwargs, k_op_kwargs in kwargs_to_op.items():
+    for k_kwargs, k_op_kwargs in HF_ATTN_KWARGS_MAPPING.items():
         if k_kwargs in kwargs:
             node_kwargs[k_op_kwargs] = kwargs[k_kwargs]
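A minimal, plain-Python sketch (illustrative kwargs only, not part of the commit) of how the loop above fills node_kwargs from the shared mapping; keys without a mapping are simply ignored:

# Same table as HF_ATTN_KWARGS_MAPPING in unified_attn.py above.
mapping = {
    "dropout": "dropout_p",
    "is_causal": "is_causal",
    "scaling": "scale",
    "scale": "scale",
    "s_aux": "sinks",
    "sinks": "sinks",
    "sliding_window": "sliding_window",
    "logit_cap": "logit_cap",
}

# Illustrative HF-style kwargs; "position_ids" has no mapping and is dropped.
kwargs = {"dropout": 0.0, "scaling": 0.125, "s_aux": None, "position_ids": None}

node_kwargs = {"attn_mask": None, "is_causal": True}
for k_kwargs, k_op_kwargs in mapping.items():
    if k_kwargs in kwargs:
        node_kwargs[k_op_kwargs] = kwargs[k_kwargs]

assert node_kwargs == {
    "attn_mask": None,
    "is_causal": True,
    "dropout_p": 0.0,
    "scale": 0.125,
    "sinks": None,
}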