[fix] Fix Llama4 allgather error due to None tensor (#4511)

* [fix] Fix Llama4 allgather error due to None tensor

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>

* Refactor modifications

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>

* Minor modification

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>

* Minor fix

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>

---------

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Jinyang Yuan 2025-05-24 19:12:12 +08:00 committed by GitHub
parent ad4d947b24
commit f9a9a1af2e
5 changed files with 51 additions and 19 deletions

@@ -1,6 +1,7 @@
import math
import os
import threading
from itertools import accumulate
from typing import List, Optional, Tuple, Union
import torch
@@ -62,6 +63,24 @@ def get_output_info(input: torch.Tensor, dim: int) -> List[int]:
return {'output_shape': output_shape, 'numel_base': numel_base}
def filter_valid_input(
input_list: List[torch.Tensor]
) -> Tuple[List[torch.Tensor], List[bool]]:
func_valid = lambda x: x is not None
valid_list = list(map(func_valid, input_list))
input_list = list(filter(func_valid, input_list))
return input_list, valid_list
def restore_full_output(output_list: List[torch.Tensor],
valid_list: List[bool]) -> List[torch.Tensor]:
index_list = list(accumulate(map(int, valid_list)))
output_list = list(
map(lambda valid, index: output_list[index - 1]
if valid else None, valid_list, index_list))
return output_list
def allgather(
input: Union[torch.Tensor, List[torch.Tensor]],
mapping: Mapping,
@@ -101,8 +120,10 @@ def allgather(
if isinstance(input, torch.Tensor):
assert input.shape[dim] == sizes[mapping.tp_rank]
else:
assert all(
[val.shape[dim] == sizes[mapping.tp_rank] for val in input])
assert all([
val.shape[dim] == sizes[mapping.tp_rank] for val in input
if val is not None
])
# 'sizes' is not needed if all inputs in the same TP group have the same shape
for split_size in sizes[1:]:
if split_size != sizes[0]:
@@ -116,6 +137,7 @@ def allgather(
output_info = get_output_info(input, dim)
input = input.contiguous().view(-1, output_info['numel_base'])
else:
input, valid = filter_valid_input(input)
torch_op = torch.ops.trtllm.allgather_list
output_info = [get_output_info(val, dim) for val in input]
input = [
@@ -148,6 +170,7 @@ def allgather(
convert_output(val, val_info)
for val, val_info in zip(output, output_info)
]
output = restore_full_output(output, valid)
return output
@@ -166,7 +189,10 @@ def reducescatter(
if isinstance(input, torch.Tensor):
assert input.shape[dim] == sum_split_size
else:
assert all([val.shape[dim] == sum_split_size for val in input])
assert all([
val.shape[dim] == sum_split_size for val in input
if val is not None
])
# 'sizes' is not needed if all outputs in the same TP group have the same shape
for split_size in sizes[1:]:
if split_size != sizes[0]:
@@ -191,6 +217,7 @@ def reducescatter(
output_info = get_output_info(input, dim)
input = convert_input(input, output_info)
else:
input, valid = filter_valid_input(input)
torch_op = torch.ops.trtllm.reducescatter_list
output_info = [get_output_info(val, dim) for val in input]
input = [
@@ -211,6 +238,7 @@ def reducescatter(
val.view(val_info['output_shape'])
for val, val_info in zip(output, output_info)
]
output = restore_full_output(output, valid)
return output
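For reference, the bookkeeping performed by the two new helpers can be exercised on its own. The sketch below mirrors filter_valid_input and restore_full_output with plain Python lists standing in for tensors and a fake collective standing in for torch.ops.trtllm.allgather_list; it only demonstrates how None entries are dropped before the collective and reinserted at their original positions afterwards (the stand-in names and values are illustrative, not part of the change).

from itertools import accumulate

def filter_valid_input(input_list):
    # Drop None entries, remembering which positions held real values.
    valid_list = [x is not None for x in input_list]
    return [x for x in input_list if x is not None], valid_list

def restore_full_output(output_list, valid_list):
    # The running count of valid positions gives, for each original slot,
    # the 1-based index of its result in the compacted output list.
    index_list = list(accumulate(int(v) for v in valid_list))
    return [output_list[i - 1] if v else None
            for v, i in zip(valid_list, index_list)]

inputs = ["x", None, "scales"]                  # e.g. x_sf is None on the non-FP4 path
compact, valid = filter_valid_input(inputs)     # (["x", "scales"], [True, False, True])
gathered = [f"gathered({v})" for v in compact]  # stand-in for the real all-gather
print(restore_full_output(gathered, valid))     # ['gathered(x)', None, 'gathered(scales)']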

@@ -821,19 +821,13 @@ class FusedMoE(nn.Module):
if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
) and not self.enable_alltoall:
if x_sf is None:
x, token_selected_experts, token_final_scales = allgather(
[x, token_selected_experts, token_final_scales],
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
else:
# Fp4 gemm has extra scaling factor
x, x_sf, token_selected_experts, token_final_scales = allgather(
[x, x_sf, token_selected_experts, token_final_scales],
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
x, x_sf, token_selected_experts, token_final_scales = allgather(
[x, x_sf, token_selected_experts, token_final_scales],
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
# Fp4 gemm has extra scaling factor
if x_sf is not None:
x_sf = reswizzle_sf(x_sf, x_row, x_col,
self.scaling_vector_size)
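Because allgather now tolerates None entries, the FP4 and non-FP4 paths above collapse into a single call, and x_sf simply comes back as None when it was not provided. A toy illustration of that unpacking pattern follows (toy_allgather, world_size, and the list contents are invented stand-ins, not the real torch.ops.trtllm collective).

def toy_allgather(tensor_list, world_size=2):
    # Stand-in for the None-tolerant allgather: replicate each rank-local
    # list world_size times and pass None entries through unchanged.
    return [t * world_size if t is not None else None for t in tensor_list]

# Non-FP4 path: x_sf is None, yet one unified call and unpacking still work.
x, x_sf, token_selected_experts, token_final_scales = toy_allgather(
    [[1.0], None, [3], [0.5]])
assert x_sf is None and x == [1.0, 1.0]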

@@ -19,5 +19,6 @@ l0_dgx_h200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] # 1h
- accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False]
- unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-enable_graph-tp8-trtllm-scout]
- unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout]
- unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep4-enable_adp-enable_graph-tp8-trtllm-scout]
- unittest/llmapi/test_llm_pytorch.py::test_nemotron_nas_lora

@@ -432,7 +432,8 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd
disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5247271)
disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_attention_dp_overlap_one_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugspro.nvidia.com/bug/5273945)
disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5279438)
unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5274229)
unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5274229)
unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep4-enable_adp-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5274229)
accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[enable_gemm_allreduce_plugin] SKIP (https://nvbugs/5247786)
full:B200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
full:B200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)

@@ -19,15 +19,22 @@ from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@pytest.mark.parametrize("tp_size", [1, 8], ids=["tp1", "tp8"])
@pytest.mark.parametrize("use_cuda_graph", [True, False],
ids=["enable_graph", "disable_graph"])
@pytest.mark.parametrize("enable_attention_dp", [True, False],
ids=["enable_adp", "disable_adp"])
@pytest.mark.parametrize("ep_size", [4, 1], ids=["ep4", "ep1"])
@pytest.mark.parametrize("pp_size", [1, 8], ids=["pp1", "pp8"])
def test_llama4(model_name, backend, tp_size, use_cuda_graph, ep_size, pp_size):
def test_llama4(model_name, backend, tp_size, use_cuda_graph,
enable_attention_dp, ep_size, pp_size):
if pp_size > 1 and (ep_size > 1 or tp_size > 1):
return
if pp_size == 1 and tp_size == 1:
return
if enable_attention_dp and not (tp_size == 8 and ep_size == 4
and pp_size == 1):
pytest.skip("Skip this attention DP test case to avoid too many tests")
prompts = [{
"prompt": "The president of the United States is"
}, {
@@ -52,6 +59,7 @@ def test_llama4(model_name, backend, tp_size, use_cuda_graph, ep_size, pp_size):
moe_tensor_parallel_size=tp_size // ep_size,
pytorch_backend_config=pytorch_config,
pipeline_parallel_size=pp_size,
enable_attention_dp=enable_attention_dp,
)
with llm:
outputs = llm.generate(