[None][fix] default disable gemm+allreduce fusion (#10656)

commit 4c8468c5d3 (parent 26bc16842e)
Author: benzh-2025 (committed by GitHub)
Date: 2026-01-20 12:31:17 +08:00
10 changed files with 42 additions and 24 deletions

@@ -434,7 +434,7 @@ bool ipcNvlsSupported()
     TLLM_CUDA_CHECK(cudaDriverGetVersion(&cuda_driver_version));
     if (cuda_driver_version < 12010)
     {
-        TLLM_LOG_ERROR("CUDA Driver version < 12010");
+        TLLM_LOG_DEBUG("CUDA Driver version < 12010");
         return false;
     }
@@ -448,7 +448,7 @@ bool ipcNvlsSupported()
         CUCHECK(cuDeviceGetAttribute(&multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, current_dev));
         if (!multicast_supported)
        {
-            TLLM_LOG_ERROR("CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED not supported on GPU%d.", cuda_dev);
+            TLLM_LOG_DEBUG("CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED not supported on GPU%d.", cuda_dev);
             return false;
         }
     }
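
Note: the drop from ERROR to DEBUG is consistent with the fusion becoming opt-in in the Python changes below; an old driver or a GPU without multicast support is then an expected probe result rather than a failure, so it should not land in the error log. A minimal sketch of how such a probe result is consumed on the Python side (the import path of ipc_nvls_supported is an assumption for illustration, not the exact module layout):

    # Sketch only: a negative NVLS probe selects the unfused path instead of erroring.
    from tensorrt_llm._torch.distributed import ipc_nvls_supported  # assumed import path

    def pick_allreduce_path():
        return "fused_gemm_allreduce" if ipc_nvls_supported() else "standard_allreduce"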

@@ -692,19 +692,19 @@ class LlamaDecoderLayer(DecoderLayer):
         self.enable_fusion &= config.hidden_size > 4096
         enable_gemm_allreduce_fusion = (os.environ.get(
-            "TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "1") == "1")
+            "TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "0") == "1")
         mpi_enabled = not mpi_disabled()
         dtype_supported = config.torch_dtype in (torch.float16, torch.bfloat16)
         tp_valid = self.mapping.tp_size > 1
         quant_valid = self.is_nvfp4 is not None and self.is_nvfp4
         device_supported = get_sm_version() >= 100
-        nvls_supported = ipc_nvls_supported()
         use_fused_gemm_allreduce = all([
             enable_gemm_allreduce_fusion, mpi_enabled, dtype_supported,
-            tp_valid, quant_valid, device_supported, nvls_supported
+            tp_valid, quant_valid, device_supported
         ])
+        if use_fused_gemm_allreduce:
+            use_fused_gemm_allreduce = ipc_nvls_supported()
 
         def check_in_out_features(in_features, out_features):
             in_feature_valid = in_features % 128 == 0 and in_features >= 1024
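
Two things change here: the TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED default flips from "1" to "0", and the NVLS probe is deferred until every cheaper condition has already passed, so disabled or unsupported configurations never touch the driver query. A minimal sketch of the resulting gating pattern (function and parameter names are illustrative, not part of this diff):

    import os

    def should_use_fused_gemm_allreduce(cheap_conditions, nvls_probe):
        # Opt-in via environment variable; the default is now off.
        opted_in = os.environ.get("TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "0") == "1"
        if not (opted_in and all(cheap_conditions)):
            return False
        # Only reached when opted in and all cheap checks pass,
        # mirroring the deferred ipc_nvls_supported() call above.
        return bool(nvls_probe())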

@@ -2166,13 +2166,16 @@ class Linear(nn.Module):
         )
         device_supported = get_sm_version() >= 100
-        nvls_supported = ipc_nvls_supported()
+        enable_gemm_allreduce_fusion = (os.environ.get(
+            "TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "0") == "1")
         self.use_fused_gemm_allreduce = all([
             self.reduce_output, mpi_enabled, dtype_supported,
             in_features_aligned, out_features_aligned, tp_valid, quant_valid,
-            device_supported, nvls_supported
+            device_supported, enable_gemm_allreduce_fusion
         ])
+        if self.use_fused_gemm_allreduce:
+            self.use_fused_gemm_allreduce = ipc_nvls_supported()
 
         self.enable_cuda_core = False
         if torch.cuda.is_available():
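
Linear applies the same pattern as the decoder layer: the opt-in variable joins the all([...]) gate and the NVLS probe moves behind it. To opt back into the fused path, the variable has to be set before the modules are constructed; in multi-rank launches it also has to be visible to every rank, which is an assumption about the deployment rather than something this diff enforces. A sketch (the checkpoint path is a placeholder):

    import os
    os.environ["TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED"] = "1"  # set before model construction

    from tensorrt_llm import LLM  # LLM API, as used in the accuracy test below

    llm = LLM("/path/to/Llama-3.3-70B-Instruct-FP4",  # placeholder path
              tensor_parallel_size=2,
              pipeline_parallel_size=2)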

@@ -14,6 +14,7 @@
 # limitations under the License.
 import os
 import sys
+from unittest import mock
 
 import pytest
 import torch
@@ -705,17 +706,24 @@ class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
     @pytest.mark.skip_less_device(4)
     @skip_pre_blackwell
+    @parametrize_with_ids("enable_gemm_allreduce_fusion", [False, True])
     @parametrize_with_ids("torch_compile", [False, True])
-    def test_fp4_tp2pp2(self, torch_compile):
+    def test_fp4_tp2pp2(self, enable_gemm_allreduce_fusion, torch_compile):
         model_path = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct-FP4"
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5)
         torch_compile_config = _get_default_torch_compile_config(torch_compile)
-        with LLM(model_path,
-                 tensor_parallel_size=2,
-                 pipeline_parallel_size=2,
-                 max_batch_size=32,
-                 kv_cache_config=kv_cache_config,
-                 torch_compile_config=torch_compile_config) as llm:
+        with (mock.patch.dict(
+                os.environ, {
+                    "TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED":
+                        str(int(enable_gemm_allreduce_fusion))
+                }),
+              LLM(model_path,
+                  tensor_parallel_size=2,
+                  pipeline_parallel_size=2,
+                  max_batch_size=32,
+                  kv_cache_config=kv_cache_config,
+                  torch_compile_config=torch_compile_config) as llm):
             assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
             sampling_params = SamplingParams(
                 max_tokens=256,
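
The test now parametrizes over enable_gemm_allreduce_fusion and injects the environment variable with mock.patch.dict, which scopes the override to the with-block; this is what produces the four test IDs added to the test lists below. A small standalone sketch of the mock.patch.dict semantics relied on here:

    import os
    from unittest import mock

    with mock.patch.dict(os.environ, {"TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED": "1"}):
        assert os.environ["TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED"] == "1"
    # on exit the variable is restored to its previous value (or removed if it was unset)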

@@ -37,8 +37,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_fp8_prequantized
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=True]
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False]
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True-torch_compile=False]

@@ -120,8 +120,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantized
 accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_fp8_prequantized
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False]
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False-torch_compile=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True-torch_compile=False]

@@ -89,8 +89,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False-sampler_async_worker=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=True]
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False]
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True-torch_compile=False]

@@ -199,7 +199,7 @@ l0_dgx_b200:
   - unittest/_torch/multi_gpu_modeling -k "deepseek"
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTEDSL-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
-  - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False]
+  - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=False]
 # ------------- AutoDeploy Backend Stages ---------------
 - condition:
     ranges:

@@ -52,7 +52,7 @@ l0_gb200_multi_gpus:
   - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=True]
-  - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True]
+  - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=False]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus_online_eplb[fp8]

@@ -433,7 +433,8 @@ def fp4_row_linear_allreduce_run_single_rank(func, tp_size, seq_len,
                          indirect=True,
                          ids=lambda x: f"tp_size:{x}")
 def test_fp4_row_linear_allreduce(seq_len, output_size, hidden_size, dtype,
-                                  mpi_pool_executor):
+                                  mpi_pool_executor, monkeypatch):
+    monkeypatch.setenv("TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "1")
     torch.manual_seed(42)
     tp_size = mpi_pool_executor.num_workers
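
Because the default is now "0", the unit test pins the variable to "1" via pytest's monkeypatch fixture so the fused GEMM+AllReduce path is still exercised; monkeypatch.setenv undoes the change automatically when the test finishes. A standalone sketch of that behavior:

    import os

    def test_fusion_env_is_scoped(monkeypatch):
        # Applies only for the duration of this test; pytest restores the old value afterwards.
        monkeypatch.setenv("TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "1")
        assert os.environ["TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED"] == "1"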