From 4c8468c5d3cdcfa64761af15dac868207bb02e28 Mon Sep 17 00:00:00 2001
From: benzh-2025
Date: Tue, 20 Jan 2026 12:31:17 +0800
Subject: [PATCH] [None][fix] default disable gemm+allreduce fusion (#10656)

---
 cpp/tensorrt_llm/runtime/ipcNvlsMemory.cu     |  4 ++--
 tensorrt_llm/_torch/models/modeling_llama.py  |  8 ++++----
 tensorrt_llm/_torch/modules/linear.py         |  7 +++++--
 .../defs/accuracy/test_llm_api_pytorch.py     | 22 +++++++++++++++-------
 .../test_lists/qa/llm_function_core.txt       |  6 ++++--
 .../qa/llm_function_core_sanity.txt           |  6 ++++--
 .../test_lists/qa/llm_function_rtx6k.txt      |  6 ++++--
 .../test_lists/test-db/l0_dgx_b200.yml        |  2 +-
 .../test-db/l0_gb200_multi_gpus.yml           |  2 +-
 .../unittest/_torch/multi_gpu/test_linear.py  |  3 ++-
 10 files changed, 42 insertions(+), 24 deletions(-)

diff --git a/cpp/tensorrt_llm/runtime/ipcNvlsMemory.cu b/cpp/tensorrt_llm/runtime/ipcNvlsMemory.cu
index 031ac92168..345930ab37 100644
--- a/cpp/tensorrt_llm/runtime/ipcNvlsMemory.cu
+++ b/cpp/tensorrt_llm/runtime/ipcNvlsMemory.cu
@@ -434,7 +434,7 @@ bool ipcNvlsSupported()
     TLLM_CUDA_CHECK(cudaDriverGetVersion(&cuda_driver_version));
     if (cuda_driver_version < 12010)
     {
-        TLLM_LOG_ERROR("CUDA Driver version < 12010");
+        TLLM_LOG_DEBUG("CUDA Driver version < 12010");
         return false;
     }
 
@@ -448,7 +448,7 @@ bool ipcNvlsSupported()
         CUCHECK(cuDeviceGetAttribute(&multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, current_dev));
         if (!multicast_supported)
        {
-            TLLM_LOG_ERROR("CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED not supported on GPU%d.", cuda_dev);
+            TLLM_LOG_DEBUG("CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED not supported on GPU%d.", cuda_dev);
             return false;
         }
     }
diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index e4a57c3fc7..dab230ec86 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -692,19 +692,19 @@ class LlamaDecoderLayer(DecoderLayer):
         self.enable_fusion &= config.hidden_size > 4096
 
         enable_gemm_allreduce_fusion = (os.environ.get(
-            "TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "1") == "1")
+            "TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "0") == "1")
 
         mpi_enabled = not mpi_disabled()
         dtype_supported = config.torch_dtype in (torch.float16,
                                                  torch.bfloat16)
         tp_valid = self.mapping.tp_size > 1
         quant_valid = self.is_nvfp4 is not None and self.is_nvfp4
-        device_supported = get_sm_version() >= 100
-        nvls_supported = ipc_nvls_supported()
         use_fused_gemm_allreduce = all([
             enable_gemm_allreduce_fusion, mpi_enabled, dtype_supported,
-            tp_valid, quant_valid, device_supported, nvls_supported
+            tp_valid, quant_valid, device_supported
         ])
+        if use_fused_gemm_allreduce:
+            use_fused_gemm_allreduce = ipc_nvls_supported()
 
         def check_in_out_features(in_features, out_features):
             in_feature_valid = in_features % 128 == 0 and in_features >= 1024
diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index 210fd9dd6a..3c0ecf8716 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -2166,13 +2166,16 @@ class Linear(nn.Module):
         )
 
         device_supported = get_sm_version() >= 100
-        nvls_supported = ipc_nvls_supported()
+        enable_gemm_allreduce_fusion = (os.environ.get(
+            "TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "0") == "1")
 
         self.use_fused_gemm_allreduce = all([
             self.reduce_output, mpi_enabled, dtype_supported,
             in_features_aligned, out_features_aligned, tp_valid, quant_valid,
-            device_supported, nvls_supported
+            device_supported, enable_gemm_allreduce_fusion
         ])
+        if self.use_fused_gemm_allreduce:
+            self.use_fused_gemm_allreduce = ipc_nvls_supported()
 
         self.enable_cuda_core = False
         if torch.cuda.is_available():
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index cc98054316..3abf542338 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import os
 import sys
+from unittest import mock
 
 import pytest
 import torch
@@ -705,17 +706,24 @@ class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
 
     @pytest.mark.skip_less_device(4)
     @skip_pre_blackwell
+    @parametrize_with_ids("enable_gemm_allreduce_fusion", [False, True])
     @parametrize_with_ids("torch_compile", [False, True])
-    def test_fp4_tp2pp2(self, torch_compile):
+    def test_fp4_tp2pp2(self, enable_gemm_allreduce_fusion, torch_compile):
         model_path = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct-FP4"
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5)
         torch_compile_config = _get_default_torch_compile_config(torch_compile)
-        with LLM(model_path,
-                 tensor_parallel_size=2,
-                 pipeline_parallel_size=2,
-                 max_batch_size=32,
-                 kv_cache_config=kv_cache_config,
-                 torch_compile_config=torch_compile_config) as llm:
+
+        with (mock.patch.dict(
+                os.environ, {
+                    "TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED":
+                    str(int(enable_gemm_allreduce_fusion))
+                }),
+              LLM(model_path,
+                  tensor_parallel_size=2,
+                  pipeline_parallel_size=2,
+                  max_batch_size=32,
+                  kv_cache_config=kv_cache_config,
+                  torch_compile_config=torch_compile_config) as llm):
             assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
             sampling_params = SamplingParams(
                 max_tokens=256,
diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt
index 98dcabc075..cf7157dc69 100644
--- a/tests/integration/test_lists/qa/llm_function_core.txt
+++ b/tests/integration/test_lists/qa/llm_function_core.txt
@@ -37,8 +37,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_fp8_prequantized
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=True]
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False]
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True-torch_compile=False]
diff --git a/tests/integration/test_lists/qa/llm_function_core_sanity.txt b/tests/integration/test_lists/qa/llm_function_core_sanity.txt
index eaac92b76e..06ae327870 100644
--- a/tests/integration/test_lists/qa/llm_function_core_sanity.txt
+++ b/tests/integration/test_lists/qa/llm_function_core_sanity.txt
@@ -120,8 +120,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantized
 accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_fp8_prequantized
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False]
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False-torch_compile=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True-torch_compile=False]
diff --git a/tests/integration/test_lists/qa/llm_function_rtx6k.txt b/tests/integration/test_lists/qa/llm_function_rtx6k.txt
index 750011c726..395f3f2a5e 100644
--- a/tests/integration/test_lists/qa/llm_function_rtx6k.txt
+++ b/tests/integration/test_lists/qa/llm_function_rtx6k.txt
@@ -89,8 +89,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False-sampler_async_worker=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=True]
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False]
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=True]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=False]
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True-torch_compile=False]
diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
index 69972e5000..493f4d354f 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -199,7 +199,7 @@ l0_dgx_b200:
   - unittest/_torch/multi_gpu_modeling -k "deepseek"
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTEDSL-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
-  - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False]
+  - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=False]
 # ------------- AutoDeploy Backend Stages ---------------
 - condition:
     ranges:
diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
index 62c0af24f8..7b79187604 100644
--- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
+++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
@@ -52,7 +52,7 @@ l0_gb200_multi_gpus:
   - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=True]
-  - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True]
+  - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=False]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus_online_eplb[fp8]
diff --git a/tests/unittest/_torch/multi_gpu/test_linear.py b/tests/unittest/_torch/multi_gpu/test_linear.py
index 11466818cf..8452baa19e 100644
--- a/tests/unittest/_torch/multi_gpu/test_linear.py
+++ b/tests/unittest/_torch/multi_gpu/test_linear.py
@@ -433,7 +433,8 @@ def fp4_row_linear_allreduce_run_single_rank(func, tp_size, seq_len,
                          indirect=True,
                          ids=lambda x: f"tp_size:{x}")
 def test_fp4_row_linear_allreduce(seq_len, output_size, hidden_size, dtype,
-                                  mpi_pool_executor):
+                                  mpi_pool_executor, monkeypatch):
+    monkeypatch.setenv("TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED", "1")
     torch.manual_seed(42)
     tp_size = mpi_pool_executor.num_workers
 
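
Note: with this patch the fused GEMM+AllReduce path is opt-in rather than enabled by
default. A minimal sketch of opting back in, mirroring what the updated tests do (the
checkpoint path and parallel sizes below are placeholders, not part of this patch):

    import os

    # The flag is read when the modules are constructed (LlamaDecoderLayer.__init__
    # and Linear.__init__), so it must be set before the model is built.
    os.environ["TRTLLM_GEMM_ALLREDUCE_FUSION_ENABLED"] = "1"

    from tensorrt_llm import LLM

    llm = LLM("path/to/Llama-3.3-70B-Instruct-FP4",  # placeholder checkpoint
              tensor_parallel_size=2,
              pipeline_parallel_size=2)

Even with the flag set, the fusion still requires NVLS support: ipc_nvls_supported()
is now probed only after the cheaper gating conditions pass, and the probe logs at
DEBUG instead of ERROR when the driver or GPU lacks multicast support.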