From f03908cf9e5c70504059892001a903e2199dfb0b Mon Sep 17 00:00:00 2001
From: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Date: Wed, 28 Jan 2026 16:53:21 -0800
Subject: [PATCH] [None][fix] fix Qwen2/3 export for AutoDeploy (#11007)

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
---
 .../export/library/unified_attn.py           | 20 ++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/auto_deploy/export/library/unified_attn.py b/tensorrt_llm/_torch/auto_deploy/export/library/unified_attn.py
index cb8ffbe68c..f0999951af 100644
--- a/tensorrt_llm/_torch/auto_deploy/export/library/unified_attn.py
+++ b/tensorrt_llm/_torch/auto_deploy/export/library/unified_attn.py
@@ -22,7 +22,7 @@ HF_ATTN_KWARGS_MAPPING = {
 
 
 def torch_attention_hf_wrapper(
-    self: torch.nn.Module,
+    module: torch.nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -40,6 +40,13 @@ def torch_attention_hf_wrapper(
         HF_ATTN_KWARGS_MAPPING[k]: v for k, v in kwargs.items() if k in HF_ATTN_KWARGS_MAPPING
     }
 
+    # Handle is_causal logic to match HF's SDPA behavior exactly.
+    # See: transformers.integrations.sdpa_attention.sdpa_attention_forward
+    is_causal = kwargs.get("is_causal", None)
+    if is_causal is None:
+        is_causal = getattr(module, "is_causal", True)
+    ad_attn_kwargs["is_causal"] = is_causal
+
     attn_output = torch.ops.auto_deploy.torch_attention(
         query_states,
         key_states,
@@ -70,6 +77,8 @@ class UnifiedAttnPatch(BaseExportPatch):
             # torch_export_to_gm is called at both export stage and attn matching stage
             # we only patch attn implementation for export stage
             if hasattr(model, "config") and hasattr(model.config, "_attn_implementation"):
+                self.original_values["model_config"] = model.config
+                self.original_values["_attn_implementation"] = model.config._attn_implementation
                 model.config._attn_implementation = "ad_unified_attn"
             return self.original_values["te.export"](model, *args, **kwargs)
 
@@ -79,3 +88,12 @@ class UnifiedAttnPatch(BaseExportPatch):
     def _revert_patch(self):
         """Revert the te.export patch."""
         te.export = self.original_values["te.export"]
+
+        # Restore original _attn_implementation if we modified it
+        if (
+            "model_config" in self.original_values
+            and "_attn_implementation" in self.original_values
+        ):
+            self.original_values["model_config"]._attn_implementation = self.original_values[
+                "_attn_implementation"
+            ]
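
A minimal standalone sketch (not part of the patch) of the is_causal
resolution the wrapper now performs, mirroring the precedence used by
transformers.integrations.sdpa_attention.sdpa_attention_forward: an explicit
is_causal kwarg wins; otherwise the attention module's is_causal attribute is
consulted, defaulting to True. FakeAttnModule and resolve_is_causal are
hypothetical names used only for illustration.

import torch


class FakeAttnModule(torch.nn.Module):
    """Hypothetical stand-in for an HF attention module (e.g. Qwen2/3)."""

    # HF attention modules carry an is_causal attribute on the module itself.
    is_causal = True


def resolve_is_causal(module: torch.nn.Module, **kwargs) -> bool:
    # An explicit kwarg takes precedence; fall back to the module attribute,
    # then default to causal attention.
    is_causal = kwargs.get("is_causal", None)
    if is_causal is None:
        is_causal = getattr(module, "is_causal", True)
    return is_causal


if __name__ == "__main__":
    mod = FakeAttnModule()
    assert resolve_is_causal(mod) is True
    assert resolve_is_causal(mod, is_causal=False) is False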