[fix] Fix Llama4 allgather error due to None tensor (#4511)
* [fix] Fix Llama4 allgather error due to None tensor
* Refactor modifications
* Minor modification
* Minor fix

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
parent ad4d947b24
commit f9a9a1af2e
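The gist of the fix: on the non-FP4 Llama4 path the scaling-factor tensor x_sf is None, and passing the list [x, x_sf, token_selected_experts, token_final_scales] to allgather failed on the None entry. The change teaches the list path to drop None entries before the collective and to put None back in the same slots afterwards, so callers no longer have to branch on x_sf. A minimal, self-contained sketch of that filter/restore pattern (illustrative names only, not the TensorRT-LLM API):

from itertools import accumulate
from typing import List, Optional, Tuple


def filter_valid(items: List[Optional[str]]) -> Tuple[List[str], List[bool]]:
    # Remember which slots held real values, then drop the None entries.
    valid = [item is not None for item in items]
    return [item for item in items if item is not None], valid


def restore(outputs: List[str], valid: List[bool]) -> List[Optional[str]]:
    # A running count of the valid slots maps each position back into the compacted list.
    index = list(accumulate(map(int, valid)))
    return [outputs[i - 1] if v else None for v, i in zip(valid, index)]


compact, valid = filter_valid(["x", None, "scales"])        # ["x", "scales"], [True, False, True]
print(restore([item.upper() for item in compact], valid))   # ['X', None, 'SCALES']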
@@ -1,6 +1,7 @@
 import math
 import os
 import threading
+from itertools import accumulate
 from typing import List, Optional, Tuple, Union

 import torch
@@ -62,6 +63,24 @@ def get_output_info(input: torch.Tensor, dim: int) -> List[int]:
     return {'output_shape': output_shape, 'numel_base': numel_base}


+def filter_valid_input(
+        input_list: List[torch.Tensor]
+) -> Tuple[List[torch.Tensor], List[bool]]:
+    func_valid = lambda x: x is not None
+    valid_list = list(map(func_valid, input_list))
+    input_list = list(filter(func_valid, input_list))
+    return input_list, valid_list
+
+
+def restore_full_output(output_list: List[torch.Tensor],
+                        valid_list: List[bool]) -> List[torch.Tensor]:
+    index_list = list(accumulate(map(int, valid_list)))
+    output_list = list(
+        map(lambda valid, index: output_list[index - 1]
+            if valid else None, valid_list, index_list))
+    return output_list
+
+
 def allgather(
     input: Union[torch.Tensor, List[torch.Tensor]],
     mapping: Mapping,
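For reference, a quick round-trip through the two helpers added above. Plain strings stand in for tensors, since the functions only test entries against None, and the list comprehension stands in for the real collective:

# Round-trip check for filter_valid_input / restore_full_output.
tensors = ["x", None, "scales"]
compact, valid = filter_valid_input(tensors)
assert compact == ["x", "scales"]
assert valid == [True, False, True]
gathered = [[t, t] for t in compact]  # stand-in for the per-tensor all-gather
assert restore_full_output(gathered, valid) == [["x", "x"], None, ["scales", "scales"]]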
@@ -101,8 +120,10 @@ def allgather(
         if isinstance(input, torch.Tensor):
             assert input.shape[dim] == sizes[mapping.tp_rank]
         else:
-            assert all(
-                [val.shape[dim] == sizes[mapping.tp_rank] for val in input])
+            assert all([
+                val.shape[dim] == sizes[mapping.tp_rank] for val in input
+                if val is not None
+            ])
         # 'sizes' is not needed if all inputs in the same TP group have the same shape
         for split_size in sizes[1:]:
             if split_size != sizes[0]:
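The reshaped assert tolerates None entries because the comprehension filter runs before the shape comparison; a tiny standalone check of the pattern (toy tensors, not the real allgather inputs):

import torch

vals = [torch.ones(4, 8), None, torch.ones(4, 8)]
# Only non-None entries reach the comparison, so val.shape is never taken on None.
assert all([val.shape[0] == 4 for val in vals if val is not None])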
@@ -116,6 +137,7 @@ def allgather(
         output_info = get_output_info(input, dim)
         input = input.contiguous().view(-1, output_info['numel_base'])
     else:
+        input, valid = filter_valid_input(input)
         torch_op = torch.ops.trtllm.allgather_list
         output_info = [get_output_info(val, dim) for val in input]
         input = [
@@ -148,6 +170,7 @@ def allgather(
             convert_output(val, val_info)
             for val, val_info in zip(output, output_info)
         ]
+        output = restore_full_output(output, valid)
     return output

@@ -166,7 +189,10 @@ def reducescatter(
         if isinstance(input, torch.Tensor):
             assert input.shape[dim] == sum_split_size
         else:
-            assert all([val.shape[dim] == sum_split_size for val in input])
+            assert all([
+                val.shape[dim] == sum_split_size for val in input
+                if val is not None
+            ])
         # 'sizes' is not needed if all outputs in the same TP group have the same shape
         for split_size in sizes[1:]:
             if split_size != sizes[0]:
@@ -191,6 +217,7 @@ def reducescatter(
         output_info = get_output_info(input, dim)
         input = convert_input(input, output_info)
     else:
+        input, valid = filter_valid_input(input)
         torch_op = torch.ops.trtllm.reducescatter_list
         output_info = [get_output_info(val, dim) for val in input]
         input = [
@@ -211,6 +238,7 @@ def reducescatter(
             val.view(val_info['output_shape'])
             for val, val_info in zip(output, output_info)
         ]
+        output = restore_full_output(output, valid)
     return output

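reducescatter gets the mirror-image treatment: None entries are dropped before torch.ops.trtllm.reducescatter_list and restored afterwards, so the same possibly-sparse list can feed both collectives. For readers less familiar with the op itself, a pure-Python picture of what a two-rank reduce-scatter does to one tensor (illustrative only, no real communication):

# Each rank contributes a full-length row; after reduce-scatter, rank r keeps only
# its slice of the elementwise sum.
rank_inputs = [[1, 2, 3, 4], [10, 20, 30, 40]]   # world_size = 2, gathered dim = 4
summed = [a + b for a, b in zip(*rank_inputs)]   # reduce: [11, 22, 33, 44]
chunks = [summed[:2], summed[2:]]                # scatter: rank 0 keeps [11, 22], rank 1 keeps [33, 44]
print(chunks)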
@@ -821,19 +821,13 @@ class FusedMoE(nn.Module):

         if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
         ) and not self.enable_alltoall:
-            if x_sf is None:
-                x, token_selected_experts, token_final_scales = allgather(
-                    [x, token_selected_experts, token_final_scales],
-                    self.mapping,
-                    dim=0,
-                    sizes=None if use_dp_padding else all_rank_num_tokens)
-            else:
-                # Fp4 gemm has extra scaling factor
-                x, x_sf, token_selected_experts, token_final_scales = allgather(
-                    [x, x_sf, token_selected_experts, token_final_scales],
-                    self.mapping,
-                    dim=0,
-                    sizes=None if use_dp_padding else all_rank_num_tokens)
+            x, x_sf, token_selected_experts, token_final_scales = allgather(
+                [x, x_sf, token_selected_experts, token_final_scales],
+                self.mapping,
+                dim=0,
+                sizes=None if use_dp_padding else all_rank_num_tokens)
+            # Fp4 gemm has extra scaling factor
+            if x_sf is not None:
                 x_sf = reswizzle_sf(x_sf, x_row, x_col,
                                     self.scaling_vector_size)

@@ -19,5 +19,6 @@ l0_dgx_h200:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] # 1h
   - accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True]
   - accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False]
-  - unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-enable_graph-tp8-trtllm-scout]
+  - unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout]
+  - unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep4-enable_adp-enable_graph-tp8-trtllm-scout]
   - unittest/llmapi/test_llm_pytorch.py::test_nemotron_nas_lora
@@ -432,7 +432,8 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd
 disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5247271)
 disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_attention_dp_overlap_one_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugspro.nvidia.com/bug/5273945)
 disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5279438)
-unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5274229)
+unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5274229)
+unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep4-enable_adp-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5274229)
 accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[enable_gemm_allreduce_plugin] SKIP (https://nvbugs/5247786)
 full:B200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
 full:B200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
@@ -19,15 +19,22 @@ from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
 @pytest.mark.parametrize("tp_size", [1, 8], ids=["tp1", "tp8"])
 @pytest.mark.parametrize("use_cuda_graph", [True, False],
                          ids=["enable_graph", "disable_graph"])
+@pytest.mark.parametrize("enable_attention_dp", [True, False],
+                         ids=["enable_adp", "disable_adp"])
 @pytest.mark.parametrize("ep_size", [4, 1], ids=["ep4", "ep1"])
 @pytest.mark.parametrize("pp_size", [1, 8], ids=["pp1", "pp8"])
-def test_llama4(model_name, backend, tp_size, use_cuda_graph, ep_size, pp_size):
+def test_llama4(model_name, backend, tp_size, use_cuda_graph,
+                enable_attention_dp, ep_size, pp_size):
     if pp_size > 1 and (ep_size > 1 or tp_size > 1):
         return

     if pp_size == 1 and tp_size == 1:
         return

+    if enable_attention_dp and not (tp_size == 8 and ep_size == 4
+                                    and pp_size == 1):
+        pytest.skip("Skip this attention DP test case to avoid too many tests")
+
     prompts = [{
         "prompt": "The president of the United States is"
     }, {
@@ -52,6 +59,7 @@ def test_llama4(model_name, backend, tp_size, use_cuda_graph, ep_size, pp_size):
         moe_tensor_parallel_size=tp_size // ep_size,
         pytorch_backend_config=pytorch_config,
         pipeline_parallel_size=pp_size,
+        enable_attention_dp=enable_attention_dp,
     )
     with llm:
         outputs = llm.generate(
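Cross-referencing note for the l0_dgx_h200 and waives entries above: the bracketed suffixes are the ids= values of the stacked parametrize decorators joined bottom-up, which is why adding enable_attention_dp renames every Llama4 case (for example pp1-ep4-enable_adp-enable_graph-tp8-trtllm-scout). A minimal toy test showing how pytest composes those node IDs (hypothetical example, not part of this change):

import pytest


@pytest.mark.parametrize("use_cuda_graph", [True, False],
                         ids=["enable_graph", "disable_graph"])
@pytest.mark.parametrize("enable_attention_dp", [True, False],
                         ids=["enable_adp", "disable_adp"])
@pytest.mark.parametrize("ep_size", [4, 1], ids=["ep4", "ep1"])
def test_ids(use_cuda_graph, enable_attention_dp, ep_size):
    assert True

# "pytest --collect-only -q" lists node IDs such as:
#   test_ids[ep4-enable_adp-enable_graph]
#   test_ids[ep1-disable_adp-disable_graph]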