[None][fix] fix Qwen2/3 export for AutoDeploy (#11007)

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Frida Hou 2026-01-28 16:53:21 -08:00 committed by GitHub
parent 4e10bf8950
commit f03908cf9e

@@ -22,7 +22,7 @@ HF_ATTN_KWARGS_MAPPING = {
 def torch_attention_hf_wrapper(
-    self: torch.nn.Module,
+    module: torch.nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -40,6 +40,13 @@ def torch_attention_hf_wrapper(
         HF_ATTN_KWARGS_MAPPING[k]: v for k, v in kwargs.items() if k in HF_ATTN_KWARGS_MAPPING
     }
+    # Handle is_causal logic to match HF's SDPA behavior exactly.
+    # See: transformers.integrations.sdpa_attention.sdpa_attention_forward
+    is_causal = kwargs.get("is_causal", None)
+    if is_causal is None:
+        is_causal = getattr(module, "is_causal", True)
+    ad_attn_kwargs["is_causal"] = is_causal
     attn_output = torch.ops.auto_deploy.torch_attention(
         query_states,
         key_states,
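
The added block reproduces the causality decision on HF's SDPA attention path: an explicit is_causal kwarg wins, otherwise the attention module's own is_causal attribute is consulted, and a missing attribute defaults to causal. A minimal, self-contained sketch of that precedence (the helper name resolve_is_causal is illustrative, not part of the patch):

import torch

def resolve_is_causal(module: torch.nn.Module, **kwargs) -> bool:
    # Explicit kwarg takes precedence; otherwise fall back to the module's
    # is_causal attribute, defaulting to causal attention when it is absent.
    is_causal = kwargs.get("is_causal", None)
    if is_causal is None:
        is_causal = getattr(module, "is_causal", True)
    return is_causal

attn = torch.nn.Module()
attn.is_causal = False
assert resolve_is_causal(attn) is False                  # module attribute respected
assert resolve_is_causal(attn, is_causal=True) is True   # explicit kwarg wins
assert resolve_is_causal(torch.nn.Module()) is True      # default is causal
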
@@ -70,6 +77,8 @@ class UnifiedAttnPatch(BaseExportPatch):
             # torch_export_to_gm is called at both export stage and attn matching stage
             # we only patch attn implementation for export stage
             if hasattr(model, "config") and hasattr(model.config, "_attn_implementation"):
+                self.original_values["model_config"] = model.config
+                self.original_values["_attn_implementation"] = model.config._attn_implementation
                 model.config._attn_implementation = "ad_unified_attn"
             return self.original_values["te.export"](model, *args, **kwargs)
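
Pointing config._attn_implementation at "ad_unified_attn" only takes effect because HF models dispatch each layer's attention through the callable registered under that name; AutoDeploy presumably registers torch_attention_hf_wrapper for it elsewhere. As a hedged illustration of the mechanism only (not AutoDeploy's registration code; assumes a recent transformers release that ships AttentionInterface):

import torch.nn.functional as F
from transformers import AttentionInterface

def toy_unified_attn(module, query, key, value, attention_mask=None, **kwargs):
    # Simplified stand-in: ignores GQA head repetition, dropout, and custom scaling.
    out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
    # HF expects the output back in (batch, seq, heads, head_dim) layout.
    return out.transpose(1, 2).contiguous(), None

AttentionInterface.register("toy_unified_attn", toy_unified_attn)
# With the backend registered, model.config._attn_implementation = "toy_unified_attn"
# routes every attention call through toy_unified_attn during forward and export.
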
@@ -79,3 +88,12 @@
     def _revert_patch(self):
         """Revert the te.export patch."""
         te.export = self.original_values["te.export"]
+        # Restore original _attn_implementation if we modified it
+        if (
+            "model_config" in self.original_values
+            and "_attn_implementation" in self.original_values
+        ):
+            self.original_values["model_config"]._attn_implementation = self.original_values[
+                "_attn_implementation"
+            ]
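
Together, the two hunks above form a save/override/restore round-trip on the config: the original implementation string is stashed when the patched export runs and written back when the patch is reverted, so later calls (e.g. the attention-matching stage mentioned in the comment above) see the model unchanged. A hypothetical, self-contained sketch of that round-trip (DummyModel and the helper names are illustrative, not the AutoDeploy API):

import types

saved = {}

def apply_override(model):
    # Stash the config object and its current implementation string, then override.
    if hasattr(model, "config") and hasattr(model.config, "_attn_implementation"):
        saved["model_config"] = model.config
        saved["_attn_implementation"] = model.config._attn_implementation
        model.config._attn_implementation = "ad_unified_attn"

def revert_override():
    # Write the saved implementation string back onto the saved config object.
    if "model_config" in saved and "_attn_implementation" in saved:
        saved["model_config"]._attn_implementation = saved["_attn_implementation"]

class DummyModel:
    def __init__(self):
        self.config = types.SimpleNamespace(_attn_implementation="sdpa")

model = DummyModel()
apply_override(model)
assert model.config._attn_implementation == "ad_unified_attn"
revert_override()
assert model.config._attn_implementation == "sdpa"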