From 0ead17bb85db6baf23ddbdf8d876fb70f4853568 Mon Sep 17 00:00:00 2001
From: Jin Li <59594262+liji-nv@users.noreply.github.com>
Date: Tue, 27 Jan 2026 20:47:48 +0800
Subject: [PATCH] [https://nvbugs/5800646][fix] Fix hang issue by avoid
 exposing UB buf… (#10842)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
---
 .../compilation/patterns/ar_residual_norm.py  | 62 +++++++++++++++++++
 tensorrt_llm/_torch/modules/linear.py         |  6 +-
 tests/integration/test_lists/waives.txt       |  1 -
 .../_torch/multi_gpu/test_user_buffers.py     | 45 ++++++++++----
 4 files changed, 98 insertions(+), 16 deletions(-)

diff --git a/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py b/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
index 55e79f72a1..3d2b2e9624 100644
--- a/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
+++ b/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
@@ -713,6 +713,65 @@ def register_ub_patterns(custom_passes: List[PatternMatcherPass],
             search_fn_pattern=trtllm_allreduce_default,
         )

+    def insert_copy_for_graph_output(custom_pass: PatternMatcherPass):
+        trtllm_allreduce_default = CallFunction(
+            torch.ops.trtllm.allreduce.default, KeywordArg("input"),
+            KeywordArg("residual"), KeywordArg("gamma"), KeywordArg("scale"),
+            None, None, mapping.tp_group, int(AllReduceStrategy.UB),
+            KeywordArg("fusion_op"), KeywordArg("eps"), Ignored())
+
+        def empty_copy_for_graph_output_pattern(
+            input: torch.Tensor,
+            residual: torch.Tensor,
+            gamma: torch.Tensor,
+            scale: Optional[torch.Tensor],
+            fusion_op: int,
+            eps: float,
+        ):
+            return
+
+        def target_copy_for_graph_output_pattern(
+            input: torch.Tensor,
+            residual: torch.Tensor,
+            gamma: torch.Tensor,
+            scale: Optional[torch.Tensor],
+            fusion_op: int,
+            eps: float,
+        ):
+            allreduce_output = torch.ops.trtllm.allreduce(
+                input, residual, gamma, scale, None, None, mapping.tp_group,
+                int(AllReduceStrategy.UB), fusion_op, eps, False)
+            non_ub_tensor = torch.empty_like(allreduce_output[0])
+            non_ub_tensor.copy_(allreduce_output[0])
+            allreduce_output[0] = non_ub_tensor
+            return allreduce_output
+
+        def extra_check(match: Match) -> bool:
+            ar_node = match.ctx.pattern_to_node[trtllm_allreduce_default]
+            assert isinstance(ar_node, torch.fx.graph.Node)
+            for user_node in ar_node.users:
+                if not isinstance(user_node, torch.fx.graph.Node):
+                    continue
+                if user_node.op == "call_function" and user_node.target == getitem and user_node.args[
+                        1] == 0:
+                    # Check whether the getitem is connected to output
+                    for getitem_user in user_node.users:
+                        if not isinstance(getitem_user, torch.fx.graph.Node):
+                            continue
+                        if getitem_user.op == "output":
+                            return True
+            return False
+
+        register_replacement(
+            empty_copy_for_graph_output_pattern,
+            target_copy_for_graph_output_pattern,
+            [],
+            fwd_only,
+            custom_pass,
+            search_fn_pattern=trtllm_allreduce_default,
+            extra_check=extra_check,
+        )
+
     custom_passes.append(PatternMatcherPass())
     register_convert_supported_ar_to_ub(custom_passes[-1])

@@ -722,6 +781,9 @@ def register_ub_patterns(custom_passes: List[PatternMatcherPass],
     custom_passes.append(PatternMatcherPass())
     register_ub_finalize_patterns(custom_passes[-1])

+    custom_passes.append(PatternMatcherPass())
+    insert_copy_for_graph_output(custom_passes[-1])
+

 def register_ar_fusions(custom_passes: List[PatternMatcherPass],
                         mapping: Mapping, enable_ub: bool):
diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index 61ce2de2b7..1bf48e1821 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -2353,6 +2353,7 @@ class Linear(nn.Module):
         disable_deep_gemm: bool = False,
         fused_weight_shard_indices_mapping: Optional[dict] = None,
         nvfp4_allowed_backends: Optional[List[str]] = None,
+        enable_gemm_allreduce_fusion: bool = True,
     ):
         """
         Args:
@@ -2430,13 +2431,14 @@ class Linear(nn.Module):
         )

         device_supported = get_sm_version() >= 100
-        enable_gemm_allreduce_fusion = (os.environ.get(
+        enable_gemm_allreduce_fusion_env = (os.environ.get(
             "TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "0") == "1")

         self.use_fused_gemm_allreduce = all([
             self.reduce_output, mpi_enabled, dtype_supported,
             in_features_aligned, out_features_aligned, tp_valid, quant_valid,
-            device_supported, enable_gemm_allreduce_fusion
+            device_supported, enable_gemm_allreduce_fusion,
+            enable_gemm_allreduce_fusion_env
         ])
         if self.use_fused_gemm_allreduce:
             self.use_fused_gemm_allreduce = ipc_nvls_supported()
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 00d88b7686..4aa5bee47d 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -257,7 +257,6 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False-torch_compile=False] SKIP (https://nvbugs/5787892)
 accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_chunked_prefill[tp8ep8-cuda_graph=False] SKIP (https://nvbugs/5795918)
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True] SKIP (https://nvbugs/5800591)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5800646)
 full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5800672)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5853997)
 examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_cpp_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1] SKIP (https://nvbugs/5802248)
diff --git a/tests/unittest/_torch/multi_gpu/test_user_buffers.py b/tests/unittest/_torch/multi_gpu/test_user_buffers.py
index 00abc2229b..eb5051afcb 100644
--- a/tests/unittest/_torch/multi_gpu/test_user_buffers.py
+++ b/tests/unittest/_torch/multi_gpu/test_user_buffers.py
@@ -330,35 +330,40 @@ class UBTestModel(nn.Module):
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l1 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l2 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l3 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l4 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.norm0 = RMSNorm(hidden_size=hidden_size, eps=eps,
                              dtype=dtype).cuda()
         self.norm1 = RMSNorm(hidden_size=hidden_size, eps=eps,
@@ -460,7 +465,10 @@ def run_single_rank_ub_pass(
     # 3 AR_NORM replacement
     # 3 Scaled MM Prologue
     # 2 UB Finalize Removal
-    assert backend.match_count == [3, 0, 2, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0]
+    # 1 Insert copy for graph output
+    assert backend.match_count == [
+        3, 0, 2, 0, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0, 1, 0
+    ]

     torch.cuda.synchronize()
     if rank == 0:
@@ -759,7 +767,10 @@ def run_single_rank_ub_mm_add_pass(tensor_parallel_size, num_tokens,
     # 3 AR_NORM replacement
     # 3 Prologue
     # 1 UB Finalize Removal
-    assert backend.match_count == [3, 0, 0, 0, 0, 0, 3, 0, 3, 0, 1, 0]
+    # 1 Insert copy for graph output
+    assert backend.match_count == [
+        3, 0, 0, 0, 0, 0, 0, 3, 0, 3, 0, 1, 0, 1, 0
+    ]

     torch.cuda.synchronize()
     if rank == 0:
@@ -819,35 +830,40 @@ class UBFp4TestModel(nn.Module):
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l1 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l2 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l3 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l4 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.norm0 = RMSNorm(hidden_size=hidden_size, eps=eps,
                              dtype=dtype).cuda()
         self.norm1 = RMSNorm(hidden_size=hidden_size, eps=eps,
@@ -993,7 +1009,10 @@ def run_single_rank_ub_pass_fp4(
     # 3 AR_NORM replacement
     # 3 Scaled MM Prologue
     # 2 UB Finalize Removal
-    assert backend.match_count == [3, 0, 2, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0]
+    # 1 Insert copy for graph output
+    assert backend.match_count == [
+        3, 0, 2, 0, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0, 1, 0
+    ]

     torch.cuda.synchronize()
     torch.testing.assert_close(output_fused,
                                output_ref,
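For context, a minimal sketch of the aliasing problem the new insert_copy_for_graph_output pass addresses. It uses plain PyTorch only; the helper names below are illustrative stand-ins, not TensorRT-LLM code. When a graph output aliases a buffer that is later reused (as a UserBuffers-backed allreduce result is), the output is silently overwritten on the next iteration; copying into a freshly allocated tensor, as target_copy_for_graph_output_pattern does with torch.empty_like + copy_, breaks the alias and avoids exposing the UB buffer as a graph output.

import torch


def write_into_shared_buffer(buf: torch.Tensor, value: float) -> torch.Tensor:
    # Stand-in for an op that writes its result into a pre-allocated, reused
    # buffer (as the UB allreduce does) and returns a view of that buffer.
    buf.fill_(value)
    return buf


def graph_without_copy(buf: torch.Tensor) -> torch.Tensor:
    # Problematic shape: the returned tensor aliases the shared buffer.
    return write_into_shared_buffer(buf, 1.0)


def graph_with_copy(buf: torch.Tensor) -> torch.Tensor:
    # Shape produced by the replacement: copy the buffer-backed result into a
    # freshly allocated tensor before exposing it as an output.
    out = write_into_shared_buffer(buf, 1.0)
    non_shared = torch.empty_like(out)
    non_shared.copy_(out)
    return non_shared


if __name__ == "__main__":
    shared = torch.zeros(4)

    aliased_output = graph_without_copy(shared)
    copied_output = graph_with_copy(shared)

    # A later step reuses the same buffer for a different result.
    write_into_shared_buffer(shared, 2.0)

    print(aliased_output)  # tensor([2., 2., 2., 2.]) -- clobbered by buffer reuse
    print(copied_output)   # tensor([1., 1., 1., 1.]) -- preserved by the copy

Relatedly, the test models above pass enable_gemm_allreduce_fusion=False to Linear so that the new constructor switch, combined with the existing TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED environment variable, keeps the fused GEMM+all-reduce path out of the UB pattern-matching tests.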