Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-16 15:55:08 +08:00)
[https://nvbugs/5800646][fix] Fix hang issue by avoid exposing UB buf… (#10842)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Parent: d348dd95a7
Commit: 0ead17bb85
@@ -713,6 +713,65 @@ def register_ub_patterns(custom_passes: List[PatternMatcherPass],
             search_fn_pattern=trtllm_allreduce_default,
         )
 
+    def insert_copy_for_graph_output(custom_pass: PatternMatcherPass):
+        trtllm_allreduce_default = CallFunction(
+            torch.ops.trtllm.allreduce.default, KeywordArg("input"),
+            KeywordArg("residual"), KeywordArg("gamma"), KeywordArg("scale"),
+            None, None, mapping.tp_group, int(AllReduceStrategy.UB),
+            KeywordArg("fusion_op"), KeywordArg("eps"), Ignored())
+
+        def empty_copy_for_graph_output_pattern(
+            input: torch.Tensor,
+            residual: torch.Tensor,
+            gamma: torch.Tensor,
+            scale: Optional[torch.Tensor],
+            fusion_op: int,
+            eps: float,
+        ):
+            return
+
+        def target_copy_for_graph_output_pattern(
+            input: torch.Tensor,
+            residual: torch.Tensor,
+            gamma: torch.Tensor,
+            scale: Optional[torch.Tensor],
+            fusion_op: int,
+            eps: float,
+        ):
+            allreduce_output = torch.ops.trtllm.allreduce(
+                input, residual, gamma, scale, None, None, mapping.tp_group,
+                int(AllReduceStrategy.UB), fusion_op, eps, False)
+            non_ub_tensor = torch.empty_like(allreduce_output[0])
+            non_ub_tensor.copy_(allreduce_output[0])
+            allreduce_output[0] = non_ub_tensor
+            return allreduce_output
+
+        def extra_check(match: Match) -> bool:
+            ar_node = match.ctx.pattern_to_node[trtllm_allreduce_default]
+            assert isinstance(ar_node, torch.fx.graph.Node)
+            for user_node in ar_node.users:
+                if not isinstance(user_node, torch.fx.graph.Node):
+                    continue
+                if user_node.op == "call_function" and user_node.target == getitem and user_node.args[
+                        1] == 0:
+                    # Check whether the getitem is connected to output
+                    for getitem_user in user_node.users:
+                        if not isinstance(getitem_user, torch.fx.graph.Node):
+                            continue
+                        if getitem_user.op == "output":
+                            return True
+            return False
+
+        register_replacement(
+            empty_copy_for_graph_output_pattern,
+            target_copy_for_graph_output_pattern,
+            [],
+            fwd_only,
+            custom_pass,
+            search_fn_pattern=trtllm_allreduce_default,
+            extra_check=extra_check,
+        )
+
     custom_passes.append(PatternMatcherPass())
     register_convert_supported_ar_to_ub(custom_passes[-1])
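Aside from the pattern-matcher plumbing, the extra_check above boils down to asking whether the allreduce node's item-0 result is returned directly from the FX graph. A minimal, self-contained sketch of that check (illustrative only, not TensorRT-LLM code; the name feeds_graph_output is made up for this note) could look like:

import operator

import torch.fx as fx


def feeds_graph_output(node: fx.Node) -> bool:
    """Return True if item 0 of a multi-output node escapes as a graph output."""
    for user in node.users:
        # Look for `operator.getitem(node, 0)` consumers of the node.
        if user.op == "call_function" and user.target is operator.getitem and user.args[1] == 0:
            # If any consumer of that getitem is the graph's output node, the
            # tensor leaves the graph and needs the copy inserted by the pass.
            if any(u.op == "output" for u in user.users):
                return True
    return False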
@@ -722,6 +781,9 @@ def register_ub_patterns(custom_passes: List[PatternMatcherPass],
     custom_passes.append(PatternMatcherPass())
     register_ub_finalize_patterns(custom_passes[-1])
 
+    custom_passes.append(PatternMatcherPass())
+    insert_copy_for_graph_output(custom_passes[-1])
+
 
 def register_ar_fusions(custom_passes: List[PatternMatcherPass],
                         mapping: Mapping, enable_ub: bool):
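The replacement that this new pass registers is mechanical: when the UB allreduce result would otherwise be exposed as a graph output, it is first materialized into a freshly allocated tensor, so the shared userbuffers region itself never leaves the compiled graph. A hedged, standalone sketch of just that copy-out step (plain PyTorch; detach_from_ub_buffer is an illustrative name, not a TensorRT-LLM API):

import torch


def detach_from_ub_buffer(ub_owned: torch.Tensor) -> torch.Tensor:
    # Allocate a regular tensor with the same shape/dtype/device as the
    # UB-owned allreduce result ...
    non_ub = torch.empty_like(ub_owned)
    # ... and copy the data out, so only the private tensor is returned
    # from the graph while the UB buffer stays internal.
    non_ub.copy_(ub_owned)
    return non_ub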
@@ -2353,6 +2353,7 @@ class Linear(nn.Module):
         disable_deep_gemm: bool = False,
         fused_weight_shard_indices_mapping: Optional[dict] = None,
         nvfp4_allowed_backends: Optional[List[str]] = None,
+        enable_gemm_allreduce_fusion: bool = True,
     ):
         """
         Args:
@@ -2430,13 +2431,14 @@ class Linear(nn.Module):
         )
 
         device_supported = get_sm_version() >= 100
-        enable_gemm_allreduce_fusion = (os.environ.get(
+        enable_gemm_allreduce_fusion_env = (os.environ.get(
             "TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "0") == "1")
 
         self.use_fused_gemm_allreduce = all([
            self.reduce_output, mpi_enabled, dtype_supported,
            in_features_aligned, out_features_aligned, tp_valid, quant_valid,
-           device_supported, enable_gemm_allreduce_fusion
+           device_supported, enable_gemm_allreduce_fusion,
+           enable_gemm_allreduce_fusion_env
         ])
         if self.use_fused_gemm_allreduce:
             self.use_fused_gemm_allreduce = ipc_nvls_supported()
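With this change the fused GEMM+allreduce path is gated twice: by the new per-layer constructor argument enable_gemm_allreduce_fusion (default True) and by the existing opt-in environment variable, now read into the separately named enable_gemm_allreduce_fusion_env. A hedged usage sketch (only the flag name and the env var come from this diff; the Linear call is abbreviated and left commented out because its other arguments are not shown here):

import os

# Opt in globally; without this the env check above evaluates to False and the
# fused path stays off regardless of the constructor flag.
os.environ["TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED"] = "1"

# Per-layer control: the default keeps a layer eligible for fusion, while
# passing False (as the user-buffers tests below do) excludes it.
# layer = Linear(..., enable_gemm_allreduce_fusion=True)   # eligible (default)
# layer = Linear(..., enable_gemm_allreduce_fusion=False)  # never fused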
@@ -257,7 +257,6 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False-torch_compile=False] SKIP (https://nvbugs/5787892)
 accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_chunked_prefill[tp8ep8-cuda_graph=False] SKIP (https://nvbugs/5795918)
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True] SKIP (https://nvbugs/5800591)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5800646)
 full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5800672)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler] SKIP (https://nvbugs/5853997)
 examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_cpp_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1] SKIP (https://nvbugs/5802248)
@@ -330,35 +330,40 @@ class UBTestModel(nn.Module):
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l1 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l2 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l3 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l4 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.norm0 = RMSNorm(hidden_size=hidden_size, eps=eps,
                              dtype=dtype).cuda()
         self.norm1 = RMSNorm(hidden_size=hidden_size, eps=eps,
@@ -460,7 +465,10 @@ def run_single_rank_ub_pass(
     # 3 AR_NORM replacement
     # 3 Scaled MM Prologue
     # 2 UB Finalize Removal
-    assert backend.match_count == [3, 0, 2, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0]
+    # 1 Insert copy for graph output
+    assert backend.match_count == [
+        3, 0, 2, 0, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0, 1, 0
+    ]
     torch.cuda.synchronize()
 
     if rank == 0:
@@ -759,7 +767,10 @@ def run_single_rank_ub_mm_add_pass(tensor_parallel_size, num_tokens,
     # 3 AR_NORM replacement
     # 3 Prologue
     # 1 UB Finalize Removal
-    assert backend.match_count == [3, 0, 0, 0, 0, 0, 3, 0, 3, 0, 1, 0]
+    # 1 Insert copy for graph output
+    assert backend.match_count == [
+        3, 0, 0, 0, 0, 0, 0, 3, 0, 3, 0, 1, 0, 1, 0
+    ]
     torch.cuda.synchronize()
 
     if rank == 0:
@@ -819,35 +830,40 @@ class UBFp4TestModel(nn.Module):
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l1 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l2 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l3 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.l4 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
                          mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
-                         quant_config=quant_config).cuda()
+                         quant_config=quant_config,
+                         enable_gemm_allreduce_fusion=False).cuda()
         self.norm0 = RMSNorm(hidden_size=hidden_size, eps=eps,
                              dtype=dtype).cuda()
         self.norm1 = RMSNorm(hidden_size=hidden_size, eps=eps,
@@ -993,7 +1009,10 @@ def run_single_rank_ub_pass_fp4(
     # 3 AR_NORM replacement
     # 3 Scaled MM Prologue
     # 2 UB Finalize Removal
-    assert backend.match_count == [3, 0, 2, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0]
+    # 1 Insert copy for graph output
+    assert backend.match_count == [
+        3, 0, 2, 0, 0, 0, 0, 0, 3, 0, 3, 0, 2, 0, 1, 0
+    ]
     torch.cuda.synchronize()
     torch.testing.assert_close(output_fused,
                                output_ref,