[https://nvbugs/5800646][fix] Fix hang issue by avoid exposing UB buf… (#10842)

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Authored by Jin Li on 2026-01-27 20:47:48 +08:00; committed by Yanchao Lu
parent d348dd95a7
commit 0ead17bb85
4 changed files with 98 additions and 16 deletions
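
For context, the hang addressed here comes from returning a tensor that lives in the shared userbuffers (UB) pool directly as a compiled-graph output. Below is a minimal, hedged sketch of the remedy the new pattern pass applies; the helper name `detach_from_ub` is illustrative and not part of the codebase, but the `empty_like` + `copy_` pair is exactly what `target_copy_for_graph_output_pattern` inserts in the diff.

```python
import torch

def detach_from_ub(ub_backed: torch.Tensor) -> torch.Tensor:
    """Sketch: copy a UB-backed result into an ordinary CUDA allocation
    before it escapes the graph, so no graph output aliases the shared
    UB buffer (which later kernels may still be reusing)."""
    regular = torch.empty_like(ub_backed)  # fresh, non-UB allocation
    regular.copy_(ub_backed)               # device-to-device copy out of the UB pool
    return regular
```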


@@ -713,6 +713,65 @@ def register_ub_patterns(custom_passes: List[PatternMatcherPass],
             search_fn_pattern=trtllm_allreduce_default,
         )
 
+    def insert_copy_for_graph_output(custom_pass: PatternMatcherPass):
+        trtllm_allreduce_default = CallFunction(
+            torch.ops.trtllm.allreduce.default, KeywordArg("input"),
+            KeywordArg("residual"), KeywordArg("gamma"), KeywordArg("scale"),
+            None, None, mapping.tp_group, int(AllReduceStrategy.UB),
+            KeywordArg("fusion_op"), KeywordArg("eps"), Ignored())
+
+        def empty_copy_for_graph_output_pattern(
+            input: torch.Tensor,
+            residual: torch.Tensor,
+            gamma: torch.Tensor,
+            scale: Optional[torch.Tensor],
+            fusion_op: int,
+            eps: float,
+        ):
+            return
+
+        def target_copy_for_graph_output_pattern(
+            input: torch.Tensor,
+            residual: torch.Tensor,
+            gamma: torch.Tensor,
+            scale: Optional[torch.Tensor],
+            fusion_op: int,
+            eps: float,
+        ):
+            allreduce_output = torch.ops.trtllm.allreduce(
+                input, residual, gamma, scale, None, None, mapping.tp_group,
+                int(AllReduceStrategy.UB), fusion_op, eps, False)
+            non_ub_tensor = torch.empty_like(allreduce_output[0])
+            non_ub_tensor.copy_(allreduce_output[0])
+            allreduce_output[0] = non_ub_tensor
+            return allreduce_output
+
+        def extra_check(match: Match) -> bool:
+            ar_node = match.ctx.pattern_to_node[trtllm_allreduce_default]
+            assert isinstance(ar_node, torch.fx.graph.Node)
+            for user_node in ar_node.users:
+                if not isinstance(user_node, torch.fx.graph.Node):
+                    continue
+                if user_node.op == "call_function" and user_node.target == getitem and user_node.args[
+                        1] == 0:
+                    # Check whether the getitem is connected to output
+                    for getitem_user in user_node.users:
+                        if not isinstance(getitem_user, torch.fx.graph.Node):
+                            continue
+                        if getitem_user.op == "output":
+                            return True
+            return False
+
+        register_replacement(
+            empty_copy_for_graph_output_pattern,
+            target_copy_for_graph_output_pattern,
+            [],
+            fwd_only,
+            custom_pass,
+            search_fn_pattern=trtllm_allreduce_default,
+            extra_check=extra_check,
+        )
+
     custom_passes.append(PatternMatcherPass())
     register_convert_supported_ar_to_ub(custom_passes[-1])
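
To see the shape of the `extra_check` guard in isolation, here is a toy, self-contained illustration (not TRT-LLM code): `torch.topk` stands in for the multi-output allreduce op, and the multi-output node is located through its `getitem` user rather than by target identity. The helper mirrors the user-walk in the pass: it asks whether element 0 of the node's tuple result reaches the graph output.

```python
import operator
import torch

def first_item_feeds_output(node: torch.fx.Node) -> bool:
    """True if getitem(node, 0) is consumed by the graph's output node."""
    for user in node.users:
        is_getitem_0 = (user.op == "call_function"
                        and user.target == operator.getitem
                        and user.args[1] == 0)
        if is_getitem_0 and any(u.op == "output" for u in user.users):
            return True
    return False

class Toy(torch.nn.Module):
    def forward(self, x):
        values_and_indices = torch.topk(x, k=2)  # multi-output op, like allreduce
        return values_and_indices[0]             # element 0 escapes as a graph output

gm = torch.fx.symbolic_trace(Toy())
getitem_node = next(n for n in gm.graph.nodes
                    if n.op == "call_function" and n.target is operator.getitem)
multi_out_node = getitem_node.args[0]
print(first_item_feeds_output(multi_out_node))  # expected: True
```

When this condition holds for the UB allreduce node, the pass rewrites the graph so the escaping tensor is first copied into a regular allocation, as in the replacement function above.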
@@ -722,6 +781,9 @@ def register_ub_patterns(custom_passes: List[PatternMatcherPass],
     custom_passes.append(PatternMatcherPass())
     register_ub_finalize_patterns(custom_passes[-1])
+    custom_passes.append(PatternMatcherPass())
+    insert_copy_for_graph_output(custom_passes[-1])
 
 
 def register_ar_fusions(custom_passes: List[PatternMatcherPass],
                         mapping: Mapping, enable_ub: bool):


@@ -2353,6 +2353,7 @@ class Linear(nn.Module):
         disable_deep_gemm: bool = False,
         fused_weight_shard_indices_mapping: Optional[dict] = None,
         nvfp4_allowed_backends: Optional[List[str]] = None,
+        enable_gemm_allreduce_fusion: bool = True,
     ):
         """
         Args:
@@ -2430,13 +2431,14 @@ class Linear(nn.Module):
         )
         device_supported = get_sm_version() >= 100
-        enable_gemm_allreduce_fusion = (os.environ.get(
+        enable_gemm_allreduce_fusion_env = (os.environ.get(
             "TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "0") == "1")
         self.use_fused_gemm_allreduce = all([
             self.reduce_output, mpi_enabled, dtype_supported,
             in_features_aligned, out_features_aligned, tp_valid, quant_valid,
-            device_supported, enable_gemm_allreduce_fusion
+            device_supported, enable_gemm_allreduce_fusion,
+            enable_gemm_allreduce_fusion_env
         ])
 
         if self.use_fused_gemm_allreduce:
             self.use_fused_gemm_allreduce = ipc_nvls_supported()
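
The Linear-side change makes the GEMM+allreduce fusion opt-in along two axes: the existing `TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED` environment variable and the new `enable_gemm_allreduce_fusion` constructor flag, both of which must agree with the capability checks. A condensed sketch of that gate follows; the helper name and the collapsed `capability_checks_pass` argument are illustrative, not the actual `Linear` code.

```python
import os

def gemm_allreduce_fusion_allowed(constructor_flag: bool,
                                  capability_checks_pass: bool) -> bool:
    """Sketch of the gate: the hardware/config checks, the per-layer
    constructor flag, and the environment opt-in must all hold."""
    env_opt_in = os.environ.get("TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED",
                                "0") == "1"
    return all([capability_checks_pass, constructor_flag, env_opt_in])
```

The unit tests below pass `enable_gemm_allreduce_fusion=False` per layer, presumably so the UB pattern passes, rather than the fused GEMM+allreduce path, are what gets exercised.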


@@ -257,7 +257,6 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False-torch_compile=False] SKIP (https://nvbugs/5787892)
 accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_chunked_prefill[tp8ep8-cuda_graph=False] SKIP (https://nvbugs/5795918)
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True] SKIP (https://nvbugs/5800591)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5800646)
 full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5800672)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5853997)
 examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_cpp_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1] SKIP (https://nvbugs/5802248)


@@ -330,35 +330,40 @@ class UBTestModel(nn.Module):
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l1 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l2 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l3 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l4 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.norm0 = RMSNorm(hidden_size=hidden_size, eps=eps,
                              dtype=dtype).cuda()
         self.norm1 = RMSNorm(hidden_size=hidden_size, eps=eps,
@@ -460,7 +465,10 @@ def run_single_rank_ub_pass(
     # 3 AR_NORM replacement
     # 3 Scaled MM Prologue
     # 2 UB Finalize Removal
-    assert backend.match_count == [3, 0, 2, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0]
+    # 1 Insert copy for graph output
+    assert backend.match_count == [
+        3, 0, 2, 0, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0, 1, 0
+    ]
     torch.cuda.synchronize()
     if rank == 0:
@@ -759,7 +767,10 @@ def run_single_rank_ub_mm_add_pass(tensor_parallel_size, num_tokens,
     # 3 AR_NORM replacement
     # 3 Prologue
     # 1 UB Finalize Removal
-    assert backend.match_count == [3, 0, 0, 0, 0, 0, 3, 0, 3, 0, 1, 0]
+    # 1 Insert copy for graph output
+    assert backend.match_count == [
+        3, 0, 0, 0, 0, 0, 0, 3, 0, 3, 0, 1, 0, 1, 0
+    ]
     torch.cuda.synchronize()
     if rank == 0:
@@ -819,35 +830,40 @@ class UBFp4TestModel(nn.Module):
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l1 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l2 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l3 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l4 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.norm0 = RMSNorm(hidden_size=hidden_size, eps=eps,
                              dtype=dtype).cuda()
         self.norm1 = RMSNorm(hidden_size=hidden_size, eps=eps,
@@ -993,7 +1009,10 @@ def run_single_rank_ub_pass_fp4(
     # 3 AR_NORM replacement
     # 3 Scaled MM Prologue
     # 2 UB Finalize Removal
-    assert backend.match_count == [3, 0, 2, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0]
+    # 1 Insert copy for graph output
+    assert backend.match_count == [
+        3, 0, 2, 0, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0, 1, 0
+    ]
     torch.cuda.synchronize()
     torch.testing.assert_close(output_fused,
                                output_ref,