mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-04 18:21:52 +08:00
[None][fix] fix Qwen2/3 export for AutoDeploy (#11007)
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
This commit is contained in:
parent
4e10bf8950
commit
f03908cf9e
@ -22,7 +22,7 @@ HF_ATTN_KWARGS_MAPPING = {
|
||||
|
||||
|
||||
def torch_attention_hf_wrapper(
|
||||
self: torch.nn.Module,
|
||||
module: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
@ -40,6 +40,13 @@ def torch_attention_hf_wrapper(
|
||||
HF_ATTN_KWARGS_MAPPING[k]: v for k, v in kwargs.items() if k in HF_ATTN_KWARGS_MAPPING
|
||||
}
|
||||
|
||||
# Handle is_causal logic to match HF's SDPA behavior exactly.
|
||||
# See: transformers.integrations.sdpa_attention.sdpa_attention_forward
|
||||
is_causal = kwargs.get("is_causal", None)
|
||||
if is_causal is None:
|
||||
is_causal = getattr(module, "is_causal", True)
|
||||
ad_attn_kwargs["is_causal"] = is_causal
|
||||
|
||||
attn_output = torch.ops.auto_deploy.torch_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
@ -70,6 +77,8 @@ class UnifiedAttnPatch(BaseExportPatch):
|
||||
# torch_export_to_gm is called at both export stage and attn matching stage
|
||||
# we only patch attn implementation for export stage
|
||||
if hasattr(model, "config") and hasattr(model.config, "_attn_implementation"):
|
||||
self.original_values["model_config"] = model.config
|
||||
self.original_values["_attn_implementation"] = model.config._attn_implementation
|
||||
model.config._attn_implementation = "ad_unified_attn"
|
||||
return self.original_values["te.export"](model, *args, **kwargs)
|
||||
|
||||
@ -79,3 +88,12 @@ class UnifiedAttnPatch(BaseExportPatch):
|
||||
def _revert_patch(self):
    """Undo the export patch.

    Reinstalls the saved ``te.export`` callable and, if a model config and
    its prior ``_attn_implementation`` were both recorded in
    ``self.original_values``, writes that value back onto the config.
    """
    saved = self.original_values
    te.export = saved["te.export"]

    # The attention implementation is only restored when both the config
    # object and its previous setting were recorded.
    if "model_config" not in saved or "_attn_implementation" not in saved:
        return

    saved["model_config"]._attn_implementation = saved["_attn_implementation"]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user