From 83ee91e17b85a34ea72a263dd18f3703f75a8ea8 Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Wed, 6 Aug 2025 09:18:09 +0300 Subject: [PATCH 1/5] [None][fix] Fix 6522 mpi.pkl5.intracomm.Request has wait not Wait (#6646) Signed-off-by: Netanel Haber --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 3227e18718..d87dbef4e7 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -800,7 +800,7 @@ class PyExecutor: # Second last rank does not need to since last rank has original decoded tokens if not self.dist.is_second_last_pp_rank: if self.send_handles[prev_microbatch_id] is not None: - self.send_handles[prev_microbatch_id].Wait() + self.send_handles[prev_microbatch_id].wait() needs_logits = ( self._need_return_logits(scheduled_batch) or (self._need_return_log_probs(scheduled_batch) From 43bd861ce17fc22d0a8dee7f355c52ffd2e9b3d0 Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+Tabrizian@users.noreply.github.com> Date: Tue, 5 Aug 2025 23:25:23 -0700 Subject: [PATCH 2/5] Update allreduce benchmark for torch (#6271) Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> --- tests/microbenchmarks/all_reduce.py | 180 +++++++++++++++++----------- 1 file changed, 111 insertions(+), 69 deletions(-) diff --git a/tests/microbenchmarks/all_reduce.py b/tests/microbenchmarks/all_reduce.py index 13cb4271bb..8fdcc0a827 100644 --- a/tests/microbenchmarks/all_reduce.py +++ b/tests/microbenchmarks/all_reduce.py @@ -18,21 +18,21 @@ from argparse import ArgumentParser # isort: off import torch # isort: on -from cuda import cuda, cudart +from cuda import cudart import tensorrt_llm as tllm -from tensorrt_llm import Mapping, Tensor +from tensorrt_llm import Mapping +from tensorrt_llm._torch.distributed import AllReduce, AllReduceFusionOp +from tensorrt_llm._torch.modules.rms_norm import RMSNorm from tensorrt_llm._utils import local_mpi_rank, local_mpi_size -from tensorrt_llm.functional import (AllReduceParams, AllReduceStrategy, - allreduce) -from tensorrt_llm.plugin.plugin import (current_all_reduce_helper, - init_all_reduce_helper) -from tensorrt_llm.runtime import Session +from tensorrt_llm.bindings.internal.runtime import delay_kernel +from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy def allreduce_benchmark(dtype: str, - test_range: str = "10,10000000,10", - no_header: bool = False): + test_range: str = "1,10000000,10", + no_header: bool = False, + enable_cudagraph: bool = False): tllm.logger.set_level('error') world_size = tllm.mpi_world_size() rank = tllm.mpi_rank() @@ -49,80 +49,120 @@ def allreduce_benchmark(dtype: str, torch_dtype = tllm._utils.str_dtype_to_torch(dtype) min_size, max_size, ratio = [int(i) for i in test_range.split(",")] - inner_loop = 1000 + inner_loop = 1200 + outer_loop = 10 size = min_size - dtype_size = torch.finfo(torch_dtype).bits // 8 + hidden_size = size + bs = 1 if mapping.rank == 0 and not no_header: print( - f"{'world_size':<15}, {'dtype':<10}, {'message size':<15}, {'strategy':<15}, {'duration (ms)':<10}" + f"{'world_size':<15}, {'dtype':<10}, {'message size':<15}, {'strategy':<10}, {'fusion':<20}, {'version':<10}, {'duration (ms)':<10}" ) while size < max_size: - input = torch.ones(size, dtype=torch_dtype, device="cuda") + input = torch.ones((bs, hidden_size), dtype=torch_dtype, device="cuda") - for strategy in [ - AllReduceStrategy.AUTO, - AllReduceStrategy.NCCL, - AllReduceStrategy.ONESHOT, - AllReduceStrategy.TWOSHOT, - ]: - builder = tllm.Builder() - net = builder.create_network() - net.plugin_config.set_nccl_plugin(dtype) - init_all_reduce_helper() - _buffers, workspace = current_all_reduce_helper( - ).allocate_workspace(mapping, size * dtype_size) + for version in ["v1"]: + for fusion in [ + AllReduceFusionOp.RESIDUAL_RMS_NORM, AllReduceFusionOp.NONE + ]: + for strategy in [ + AllReduceStrategy.NCCL, + AllReduceStrategy.ONESHOT, + AllReduceStrategy.TWOSHOT, + ]: + if size >= 25600000 and fusion != AllReduceFusionOp.NONE: + continue + allreduce = AllReduce(mapping=mapping, strategy=strategy) + if fusion == AllReduceFusionOp.RESIDUAL_RMS_NORM: + norm_weight = torch.randn((hidden_size, ), + dtype=torch_dtype, + device="cuda") + norm = RMSNorm(hidden_size=hidden_size, + dtype=torch_dtype, + eps=1e-5).cuda() + norm.weight.data.copy_(norm_weight) + if version == "v1": + params = { + "all_reduce_params": + AllReduceParams(fusion_op=fusion, + residual=input, + norm_weight=norm.weight, + eps=norm.variance_epsilon) + } + else: + params = { + "reduce_fusion_inputs": [input, norm.weight], + "eps": norm.variance_epsilon, + "fusion_op": fusion + } + else: + if version == "v1": + params = { + "all_reduce_params": + AllReduceParams(fusion_op=fusion) + } + else: + continue - with tllm.net_guard(net): - tllm.default_trtnet() + def func(input): + for _ in range(inner_loop): + input = allreduce(input, **params) + if fusion == AllReduceFusionOp.RESIDUAL_RMS_NORM: + input = input[0] + return input - x = Tensor(name='x', - shape=input.shape, - dtype=tllm.str_dtype_to_trt(dtype)) + start = [ + torch.cuda.Event(enable_timing=True) + for _ in range(outer_loop) + ] + stop = [ + torch.cuda.Event(enable_timing=True) + for _ in range(outer_loop) + ] + graph = torch.cuda.CUDAGraph() - current_all_reduce_helper().set_workspace_tensor(mapping) + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + if enable_cudagraph: + for _ in range(2): + func(input) + with torch.cuda.graph(graph, stream=stream): + output = func(input) + tllm.mpi_barrier() + delay_kernel(2000000, stream) + torch.cuda.profiler.start() + for i in range(outer_loop): + start[i].record(stream) + if enable_cudagraph: + graph.replay() + else: + output = func(input) + stop[i].record(stream) - current = x - for _ in range(inner_loop): - current = allreduce( - current, - mapping.tp_group, - all_reduce_params=AllReduceParams(strategy=strategy)) - current.mark_output('output', dtype) - feed_dict = {'x': input, 'all_reduce_workspace': workspace} - builder_config = builder.create_builder_config(precision=dtype) - engine = builder.build_engine(net, builder_config) - assert engine is not None, "Failed to build engine" - session = Session.from_serialized_engine(engine) + torch.cuda.synchronize() + torch.cuda.profiler.stop() + runtimes = [ + start[i].elapsed_time(stop[i]) + for i in range(outer_loop) + ] + median_ms = sorted(runtimes)[len(runtimes) // 2] - _, start = cuda.cuEventCreate(0) - _, stop = cuda.cuEventCreate(0) - runtimes = [] + if fusion == AllReduceFusionOp.NONE: + allreduce_ref = (input * world_size)**inner_loop + torch.testing.assert_close(output, allreduce_ref) - tllm.mpi_barrier() - output = torch.empty(input.shape, dtype=torch_dtype, device='cuda') - stream = torch.cuda.current_stream() - for _ in range(10): - cuda.cuEventRecord(start, stream.cuda_stream) - session.run(inputs=feed_dict, - outputs={"output": output}, - stream=stream.cuda_stream) - cuda.cuEventRecord(stop, stream.cuda_stream) - torch.cuda.synchronize() - _, ms = cuda.cuEventElapsedTime(start, stop) - runtimes.append(ms) - - median_ms = sorted(runtimes)[len(runtimes) // 2] - - allreduce_ref = (input * world_size)**inner_loop - assert torch.allclose(output, allreduce_ref) - - if mapping.rank == 0: - print( - f"{mapping.world_size:<15}, {dtype:<10}, {size:<15}, {strategy.name:<15}, {median_ms:<10.2f}" - ) + if mapping.rank == 0: + print( + f"{mapping.world_size:<15}, {dtype:<10}, {size:<15}, {strategy.name:<10}, {fusion.name:<20}, {version:<10}, {median_ms:<10.2f}" + ) size *= ratio + if hidden_size * ratio > 4096: + bs *= ratio + else: + hidden_size *= ratio + assert size == bs * hidden_size if __name__ == "__main__": @@ -134,6 +174,8 @@ if __name__ == "__main__": default="256,256000000,10", # 256 to 256M help="min_size,max_size,multiplicative_ratio") parser.add_argument("--no-header", action="store_true") + parser.add_argument("--enable-cudagraph", action="store_true") args = parser.parse_args() - allreduce_benchmark(args.dtype, args.range, args.no_header) + allreduce_benchmark(args.dtype, args.range, args.no_header, + args.enable_cudagraph) From 907c180eb27aa4e78ef5eaf6fa862ce0efee1a9c Mon Sep 17 00:00:00 2001 From: ruodil <200874449+ruodil@users.noreply.github.com> Date: Wed, 6 Aug 2025 14:25:57 +0800 Subject: [PATCH 3/5] [None][test] align kv_frac in perf test with perflab and add more cases for 4 gpus GB200 (#6632) Signed-off-by: ruodil <200874449+ruodil@users.noreply.github.com> --- tests/integration/defs/perf/test_perf.py | 2 + .../test_lists/qa/llm_perf_cluster.yml | 47 ++++++++++++++----- .../test_lists/qa/llm_perf_full.yml | 41 +++++++++------- 3 files changed, 60 insertions(+), 30 deletions(-) diff --git a/tests/integration/defs/perf/test_perf.py b/tests/integration/defs/perf/test_perf.py index 566cbbef28..f443ca1035 100644 --- a/tests/integration/defs/perf/test_perf.py +++ b/tests/integration/defs/perf/test_perf.py @@ -95,6 +95,8 @@ MODEL_PATH_DICT = { "gemma_3_1b_it": "gemma/gemma-3-1b-it", "deepseek_r1_fp8": "DeepSeek-R1/DeepSeek-R1", "deepseek_r1_nvfp4": "DeepSeek-R1/DeepSeek-R1-FP4", + "deepseek_r1_0528_fp8": "DeepSeek-R1/DeepSeek-R1-0528/", + "deepseek_r1_0528_fp4": "DeepSeek-R1/DeepSeek-R1-0528-FP4/", "deepseek_v3_lite_fp8": "DeepSeek-V3-Lite/fp8", "deepseek_v3_lite_nvfp4": "DeepSeek-V3-Lite/nvfp4_moe_only", "qwen2_7b_instruct": "Qwen2-7B-Instruct", diff --git a/tests/integration/test_lists/qa/llm_perf_cluster.yml b/tests/integration/test_lists/qa/llm_perf_cluster.yml index 29045d35ca..47877a3fcc 100644 --- a/tests/integration/test_lists/qa/llm_perf_cluster.yml +++ b/tests/integration/test_lists/qa/llm_perf_cluster.yml @@ -68,8 +68,26 @@ trt_llm_release_perf_cluster_test: - perf/test_perf.py::test_perf[starcoder_15b-bench-float16-input_output_len:512,200-gpus:4] - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-maxbs:512-input_output_len:128,128-ep:4-tp:4-gpus:4] - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-streaming-float4-maxbs:512-input_output_len:128,128-ep:4-tp:4-gpus:4] + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:1000,1000-reqs:20000-ep:4-tp:4-gpus:4] TIMEOUT(120) + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:1000,2000-reqs:3000-ep:4-tp:4-gpus:4] TIMEOUT(120) + - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-maxbs:1000-maxnt:5000-kv_frac:0.85-input_output_len:5000,500-reqs:20000-ep:4-tp:4-gpus:4] TIMEOUT(120) - perf/test_perf.py::test_perf[qwen3_235b_a22b_fp4-bench-pytorch-float4-input_output_len:1000,2000-con:512-ep:4-gpus:4] - #- perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-input_output_len:128,128-gpus:4] + #llama_v3.1_405b_instruct_fp4 + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:128,128-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:1000,1000-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1024,2048-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-maxbs:4096-maxnt:20000-kv_frac:0.85-input_output_len:20000,2000-reqs:1000-gpus:4] TIMEOUT(120) + #llama_v3.3_70b_instruct_fp4 + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-input_output_len:128,128-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-input_output_len:512,32-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1000,1000-reqs:3000-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-maxbs:4096-maxnt:20000-kv_frac:0.85-input_output_len:20000,2000-reqs:1000-gpus:4] TIMEOUT(120) + #llama_v4_scout_17b_16e_instruct_fp4 + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-input_output_len:128,128-gpus:4] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:500,2000-reqs:1500-gpus:4] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1000,1000-reqs:1500-gpus:4] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-maxbs:4096-maxnt:20000-kv_frac:0.85-input_output_len:20000,2000-reqs:1000-gpus:4] TIMEOUT(120) + # Tests for systems with 8+ GPUs - condition: @@ -83,20 +101,19 @@ trt_llm_release_perf_cluster_test: - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b-bench-bfloat16-input_output_len:500,2000-quant:fp8-con:250-gpus:8] - perf/test_perf.py::test_perf[llama_v3.3_nemotron_super_49b_fp8-bench-pytorch-bfloat16-input_output_len:500,2000-con:250-gpus:8] #llama_v3.3_70b_instruct_fp4 - - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-input_output_len:128,128-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-input_output_len:512,32-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-input_output_len:500,2000-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-input_output_len:1000,1000-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:500,2000-reqs:3000-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1000,1000-reqs:3000-tp:8-gpus:8] #llama_v3.1_405b_instruct_fp4 - - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-input_output_len:128,128-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-input_output_len:512,32-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-input_output_len:500,2000-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-input_output_len:1000,1000-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:128,128-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:512,32-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:500,2000-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1000,1000-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp4-bench-pytorch-float4-maxbs:4096-maxnt:20000-kv_frac:0.85-input_output_len:20000,2000-tp:8-gpus:8] TIMEOUT(120) #llama_v4_scout_17b_16e_instruct_fp4 - - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-input_output_len:128,128-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-input_output_len:512,32-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-input_output_len:500,2000-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-input_output_len:1000,1000-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:128,128-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:512,32-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:500,2000-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp4-bench-pytorch-float4-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1000,1000-tp:8-gpus:8] #mixtral_8x22b_v0.1 - perf/test_perf.py::test_perf[mixtral_8x22b_v0.1-bench-float16-input_output_len:512,512-quant:fp8-tp:8] - perf/test_perf.py::test_perf[gpt_20b-bench-float16-maxbs:8-input_output_len:128,128-reqs:80-gpus:8] @@ -111,6 +128,10 @@ trt_llm_release_perf_cluster_test: - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-streaming-float4-maxbs:1-input_output_len:1000,2000-reqs:10-ep:4-tp:8-gpus:8] #min latency test - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-float4-maxbs:384-maxnt:1536-input_output_len:1000,2000-reqs:49152-con:3072-ep:8-tp:8-gpus:8] TIMEOUT (120) #max throughput test - perf/test_perf.py::test_perf[deepseek_r1_nvfp4-bench-pytorch-streaming-float4-maxbs:384-maxnt:1536-input_output_len:1000,2000-reqs:49152-con:3072-ep:8-tp:8-gpus:8] #max throughput test + #deepseek_r1_0528_fp4 + - perf/test_perf.py::test_perf[deepseek_r1_0528_fp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:1000,1000-reqs:20000-ep:8-tp:8-gpus:8] TIMEOUT(120) + - perf/test_perf.py::test_perf[deepseek_r1_0528_fp4-bench-pytorch-float4-kv_frac:0.85-input_output_len:1000,2000-reqs:3000-ep:8-tp:8-gpus:8] TIMEOUT(120) + - perf/test_perf.py::test_perf[deepseek_r1_0528_fp4-bench-pytorch-float4-maxbs:1000-maxnt:5000-kv_frac:0.85-input_output_len:5000,500-reqs:20000-ep:4-tp:4-gpus:4] TIMEOUT(120) - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct-bench-pytorch-bfloat16-input_output_len:128,128-ep:8-tp:8-gpus:8] - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct-bench-pytorch-bfloat16-input_output_len:500,2000-ep:8-tp:8-gpus:8] - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct-bench-pytorch-bfloat16-input_output_len:2000,500-ep:8-tp:8-gpus:8] diff --git a/tests/integration/test_lists/qa/llm_perf_full.yml b/tests/integration/test_lists/qa/llm_perf_full.yml index 6796527701..c4778586b5 100644 --- a/tests/integration/test_lists/qa/llm_perf_full.yml +++ b/tests/integration/test_lists/qa/llm_perf_full.yml @@ -465,29 +465,34 @@ trt_llm_release_perf_test: - perf/test_perf.py::test_perf[llama_v3.1_nemotron_ultra_253b_fp8-bench-pytorch-float8-maxbs:256-input_output_len:500,2000-reqs:250-con:250-tp:8-gpus:8] # llama_v3.1_405b_fp8 #pytorch backend - - perf/test_perf.py::test_perf[llama_v3.1_405b_fp8-bench-pytorch-float8-maxbs:1-input_output_len:2000,500-reqs:8-con:1-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_405b_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-input_output_len:500,2000-reqs:3000-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_405b_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-input_output_len:1000,1000-reqs:3000-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_405b_fp8-bench-pytorch-float8-input_output_len:128,128-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_405b_fp8-bench-pytorch-float8-input_output_len:512,32-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp8-bench-pytorch-float8-maxbs:1-input_output_len:2000,500-reqs:8-con:1-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:5000-kv_frac:0.85-input_output_len:5000,500-reqs:1500-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1000,1000-reqs:3000-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp8-bench-pytorch-float8-maxbs:4096-maxnt:20000-kv_frac:0.85-input_output_len:20000,2000-reqs:1000-tp:8-gpus:8] TIMEOUT(120) + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_405b_instruct_fp8-bench-pytorch-float8-input_output_len:512,32-tp:8-gpus:8] + #llama_v3.3_70b_instruct_fp8 + #pytorch backend + - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-pytorch-float8-maxbs:4096-maxnt:20000-kv_frac:0.85-input_output_len:20000,2000-reqs:1000-tp:8-gpus:8] TIMEOUT(120) #llama_v4_maverick_17b_128e_instruct_fp8 #pytorch backend - - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-kv_frac:0.6-input_output_len:2000,500-reqs:3000-ep:8-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-kv_frac:0.6-input_output_len:500,2000-reqs:3000-ep:8-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-kv_frac:0.6-input_output_len:1000,1000-reqs:3000-ep:8-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-kv_frac:0.6-input_output_len:128,128-ep:8-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-kv_frac:0.6-input_output_len:512,32-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:2000,500-reqs:3000-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:500,2000-reqs:3000-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1000,1000-reqs:3000-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-input_output_len:512,32-ep:8-tp:8-gpus:8] #rcca case - - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-kv_frac:0.6-input_output_len:20000,2000-reqs:1000-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_maverick_17b_128e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:20000-kv_frac:0.6-input_output_len:20000,2000-reqs:1000-ep:8-tp:8-gpus:8] #llama_v4_scout_17b_16e_instruct_fp8 #pytorch backend - - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-kv_frac:0.6-input_output_len:2000,500-reqs:3000-ep:8-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-kv_frac:0.6-input_output_len:500,2000-reqs:3000-ep:8-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-kv_frac:0.6-input_output_len:1000,1000-reqs:3000-ep:8-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-kv_frac:0.6-input_output_len:128,128-ep:8-tp:8-gpus:8] - - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-kv_frac:0.6-input_output_len:512,32-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:5000-kv_frac:0.85-input_output_len:5000,500-reqs:3000-ep:8-tp:8-gpus:8] TIMEOUT(100) + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:500,2000-reqs:3000-ep:8-tp:8-gpus:8] TIMEOUT(100) + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:4096-kv_frac:0.85-input_output_len:1000,1000-reqs:3000-ep:8-tp:8-gpus:8] TIMEOUT(100) + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-maxbs:1024-maxnt:20000-kv_frac:0.85-input_output_len:20000,2000-reqs:1000-ep:8-tp:8-gpus:8] TIMEOUT(100) + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-kv_frac:0.85-input_output_len:128,128-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct_fp8-bench-pytorch-float8-kv_frac:0.85-input_output_len:512,32-ep:8-tp:8-gpus:8] #deepseek_r1_fp8 #pytorch backend @@ -509,7 +514,9 @@ trt_llm_release_perf_test: - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-streaming-float8-maxbs:32-input_output_len:128,128-ep:8-tp:8-gpus:8] - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:1-input_output_len:1000,2000-reqs:10-con:1-ep:4-tp:8-gpus:8] TIMEOUT(40)#min latency test - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:128-maxnt:1127-input_output_len:1000,2000-reqs:5120-con:1024-ep:8-tp:8-gpus:8] TIMEOUT(80) #max throughput test - - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct-bench-pytorch-bfloat16-input_output_len:128,128-ep:8-tp:8-gpus:8] TIMEOUT(20) + - perf/test_perf.py::test_perf[deepseek_r1_0528_fp8-bench-pytorch-float8-input_output_len:1000,1000-reqs:20000-ep:8-tp:8-gpus:8] TIMEOUT(120) + - perf/test_perf.py::test_perf[deepseek_r1_0528_fp8-bench-pytorch-float8-input_output_len:1000,2000-reqs:3000-ep:8-tp:8-gpus:8] TIMEOUT(100) + - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct-bench-pytorch-bfloat16-input_output_len:128,128-ep:8-tp:8-gpus:8] - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct-bench-pytorch-streaming-bfloat16-input_output_len:128,128-ep:8-tp:8-gpus:8] - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct-bench-pytorch-bfloat16-input_output_len:500,2000-ep:8-tp:8-gpus:8] - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct-bench-pytorch-bfloat16-input_output_len:2000,500-ep:8-tp:8-gpus:8] From 0ff8df95b7ccf0412b32be7befddbec3503115b6 Mon Sep 17 00:00:00 2001 From: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com> Date: Wed, 6 Aug 2025 16:44:21 +0800 Subject: [PATCH 4/5] [https://nvbugs/5433581][fix] DeepGEMM installation on SBSA (#6588) Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com> --- .gitignore | 3 + .gitmodules | 3 + 3rdparty/DeepGEMM | 1 + cpp/CMakeLists.txt | 9 +- cpp/tensorrt_llm/CMakeLists.txt | 4 + cpp/tensorrt_llm/deep_gemm/CMakeLists.txt | 126 ++++++++++++++++++ .../deep_gemm/deep_gemm_cpp_tllm.version | 4 + requirements.txt | 1 - scripts/build_wheel.py | 25 +++- setup.py | 3 +- .../modules/fused_moe/fused_moe_deepgemm.py | 2 +- tensorrt_llm/_torch/modules/linear.py | 2 +- .../_torch/thop/test_fp8_block_scale_gemm.py | 2 +- 13 files changed, 176 insertions(+), 9 deletions(-) create mode 160000 3rdparty/DeepGEMM create mode 100644 cpp/tensorrt_llm/deep_gemm/CMakeLists.txt create mode 100644 cpp/tensorrt_llm/deep_gemm/deep_gemm_cpp_tllm.version diff --git a/.gitignore b/.gitignore index a4c45ea691..7ae724e708 100644 --- a/.gitignore +++ b/.gitignore @@ -43,6 +43,9 @@ tensorrt_llm/bindings/**/*.pyi tensorrt_llm/deep_ep/ tensorrt_llm/deep_ep_cpp_tllm.*.so tensorrt_llm/deep_ep_cpp_tllm.pyi +tensorrt_llm/deep_gemm/ +tensorrt_llm/deep_gemm_cpp_tllm.*.so +tensorrt_llm/deep_gemm_cpp_tllm.pyi *docs/cpp_docs* *docs/source/_cpp_gen* docs/source/**/*.rst diff --git a/.gitmodules b/.gitmodules index 00ff73d136..45d99f8fe4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -26,3 +26,6 @@ [submodule "3rdparty/cppzmq"] path = 3rdparty/cppzmq url = https://github.com/zeromq/cppzmq.git +[submodule "3rdparty/DeepGEMM"] + path = 3rdparty/DeepGEMM + url = https://github.com/deepseek-ai/DeepGEMM.git diff --git a/3rdparty/DeepGEMM b/3rdparty/DeepGEMM new file mode 160000 index 0000000000..7b6b5563b9 --- /dev/null +++ b/3rdparty/DeepGEMM @@ -0,0 +1 @@ +Subproject commit 7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6732db6eaa..43b86df937 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -31,6 +31,7 @@ option(BUILD_PYT "Build in PyTorch TorchScript class mode" ON) option(BUILD_TESTS "Build Google tests" ON) option(BUILD_BENCHMARKS "Build benchmarks" ON) option(BUILD_DEEP_EP "Build the Deep EP module" ON) +option(BUILD_DEEP_GEMM "Build the DeepGEMM module" ON) option(BUILD_MICRO_BENCHMARKS "Build C++ micro benchmarks" OFF) option(NVTX_DISABLE "Disable all NVTX features" ON) option(WARNING_IS_ERROR "Treat all warnings as errors" OFF) @@ -199,7 +200,9 @@ set(TRT_LIB TensorRT::NvInfer) get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH) set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty) -if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP) +if(BINDING_TYPE STREQUAL "pybind" + OR BUILD_DEEP_EP + OR BUILD_DEEP_GEMM) add_subdirectory(${3RDPARTY_DIR}/pybind11 ${CMAKE_CURRENT_BINARY_DIR}/pybind11) endif() @@ -218,7 +221,9 @@ include_directories( ${3RDPARTY_DIR}/cutlass/tools/util/include ${3RDPARTY_DIR}/NVTX/include ${3RDPARTY_DIR}/json/include) -if(BINDING_TYPE STREQUAL "pybind" OR BUILD_DEEP_EP) +if(BINDING_TYPE STREQUAL "pybind" + OR BUILD_DEEP_EP + OR BUILD_DEEP_GEMM) include_directories(${3RDPARTY_DIR}/pybind11/include) endif() if(BINDING_TYPE STREQUAL "nanobind") diff --git a/cpp/tensorrt_llm/CMakeLists.txt b/cpp/tensorrt_llm/CMakeLists.txt index f8e3aaad9e..c4814c1d4e 100644 --- a/cpp/tensorrt_llm/CMakeLists.txt +++ b/cpp/tensorrt_llm/CMakeLists.txt @@ -314,4 +314,8 @@ if(BUILD_DEEP_EP) add_subdirectory(deep_ep) endif() +if(BUILD_DEEP_GEMM) + add_subdirectory(deep_gemm) +endif() + add_subdirectory(plugins) diff --git a/cpp/tensorrt_llm/deep_gemm/CMakeLists.txt b/cpp/tensorrt_llm/deep_gemm/CMakeLists.txt new file mode 100644 index 0000000000..631e5029c7 --- /dev/null +++ b/cpp/tensorrt_llm/deep_gemm/CMakeLists.txt @@ -0,0 +1,126 @@ +add_custom_target(deep_gemm) + +if(WIN32) + return() +endif() + +# Prepare files +# ============= + +# Use DeepGEMM submodule +set(DEEP_GEMM_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../3rdparty/DeepGEMM) +get_filename_component(DEEP_GEMM_SOURCE_DIR ${DEEP_GEMM_SOURCE_DIR} ABSOLUTE) + +if(NOT EXISTS ${DEEP_GEMM_SOURCE_DIR}) + message( + FATAL_ERROR + "DeepGEMM submodule not found at ${DEEP_GEMM_SOURCE_DIR}. Please run: git submodule update --init --recursive" + ) +endif() + +# Check if submodules are initialized +if(NOT EXISTS ${DEEP_GEMM_SOURCE_DIR}/third-party/cutlass/include) + message( + FATAL_ERROR + "DeepGEMM submodules not initialized. Please run: git submodule update --init --recursive" + ) +endif() + +# Copy and update python files +set(DEEP_GEMM_PYTHON_DEST ${CMAKE_CURRENT_BINARY_DIR}/python/deep_gemm) +file(REMOVE_RECURSE ${DEEP_GEMM_PYTHON_DEST}) +file(MAKE_DIRECTORY ${DEEP_GEMM_PYTHON_DEST}) + +# Copy all files from deep_gemm directory +file(GLOB_RECURSE DEEP_GEMM_ALL_FILES ${DEEP_GEMM_SOURCE_DIR}/deep_gemm/*) +configure_file(${DEEP_GEMM_SOURCE_DIR}/LICENSE ${DEEP_GEMM_PYTHON_DEST}/LICENSE + COPYONLY) +foreach(SOURCE_FILE ${DEEP_GEMM_ALL_FILES}) + file(RELATIVE_PATH REL_PATH ${DEEP_GEMM_SOURCE_DIR}/deep_gemm ${SOURCE_FILE}) + get_filename_component(REL_DIR ${REL_PATH} DIRECTORY) + file(MAKE_DIRECTORY ${DEEP_GEMM_PYTHON_DEST}/${REL_DIR}) + + # Check if it's a Python file that needs import renaming + get_filename_component(FILE_EXT ${SOURCE_FILE} EXT) + if(FILE_EXT STREQUAL ".py") + # Read file content and replace module imports for Python files + file(READ ${SOURCE_FILE} _content) + string(REPLACE "deep_gemm_cpp" "tensorrt_llm.deep_gemm_cpp_tllm" _content + "${_content}") + + # Add adaptation header + string( + PREPEND + _content + "# Adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/${REL_PATH}\n" + ) + + # Write modified content + set(_dst "${DEEP_GEMM_PYTHON_DEST}/${REL_PATH}") + file(WRITE ${_dst} "${_content}") + else() + # Copy non-Python files as-is + set(_dst "${DEEP_GEMM_PYTHON_DEST}/${REL_PATH}") + file(COPY ${SOURCE_FILE} DESTINATION ${DEEP_GEMM_PYTHON_DEST}/${REL_DIR}) + endif() + + # Add dependency tracking + set_property( + DIRECTORY + APPEND + PROPERTY CMAKE_CONFIGURE_DEPENDS ${SOURCE_FILE}) +endforeach() + +# Copy third-party includes (cutlass and fmt) to the include directory +set(DEEP_GEMM_INCLUDE_DEST ${DEEP_GEMM_PYTHON_DEST}/include) +file(MAKE_DIRECTORY ${DEEP_GEMM_INCLUDE_DEST}) +file(COPY ${DEEP_GEMM_SOURCE_DIR}/third-party/cutlass/include/cute + DESTINATION ${DEEP_GEMM_INCLUDE_DEST}) +file(COPY ${DEEP_GEMM_SOURCE_DIR}/third-party/cutlass/include/cutlass + DESTINATION ${DEEP_GEMM_INCLUDE_DEST}) + +# Find torch_python +find_library(TORCH_PYTHON_LIB torch_python REQUIRED + HINTS ${TORCH_INSTALL_PREFIX}/lib) + +# Build deep_gemm_cpp_tllm extension (matching deep_gemm's setup.py) +set(DEEP_GEMM_SOURCES ${DEEP_GEMM_SOURCE_DIR}/csrc/python_api.cpp) + +pybind11_add_module(deep_gemm_cpp_tllm ${DEEP_GEMM_SOURCES}) +set_target_properties( + deep_gemm_cpp_tllm + PROPERTIES CXX_STANDARD_REQUIRED ON + CXX_STANDARD 17 + CXX_SCAN_FOR_MODULES OFF + CUDA_SEPARABLE_COMPILATION ON + LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/deep_gemm_cpp_tllm.version + INSTALL_RPATH "${TORCH_INSTALL_PREFIX}/lib" + BUILD_WITH_INSTALL_RPATH TRUE) + +target_compile_options(deep_gemm_cpp_tllm PRIVATE ${TORCH_CXX_FLAGS} -std=c++17 + -O3 -fPIC -Wno-psabi) + +# Extension name definition +target_compile_definitions(deep_gemm_cpp_tllm + PRIVATE TORCH_EXTENSION_NAME=deep_gemm_cpp_tllm) + +# Include directories matching deep_gemm setup.py +target_include_directories( + deep_gemm_cpp_tllm + PRIVATE ${CUDA_INCLUDE_DIRS} ${DEEP_GEMM_SOURCE_DIR}/deep_gemm/include + ${DEEP_GEMM_SOURCE_DIR}/third-party/cutlass/include + ${DEEP_GEMM_SOURCE_DIR}/third-party/fmt/include) + +# Link libraries (matching deep_gemm setup.py: cuda, cudart + torch) +target_link_libraries( + deep_gemm_cpp_tllm PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIB} + CUDA::cuda_driver CUDA::cudart) + +# Link directories +target_link_directories( + deep_gemm_cpp_tllm PRIVATE ${CUDA_TOOLKIT_ROOT_DIR}/lib64 + ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs) + +# Set targets +# =========== +add_dependencies(deep_gemm deep_gemm_cpp_tllm) diff --git a/cpp/tensorrt_llm/deep_gemm/deep_gemm_cpp_tllm.version b/cpp/tensorrt_llm/deep_gemm/deep_gemm_cpp_tllm.version new file mode 100644 index 0000000000..24e3b31952 --- /dev/null +++ b/cpp/tensorrt_llm/deep_gemm/deep_gemm_cpp_tllm.version @@ -0,0 +1,4 @@ +{ + global: PyInit_deep_gemm_cpp_tllm; + local: *; +}; diff --git a/requirements.txt b/requirements.txt index d5a3a8ecb9..2ecd9ef5c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -61,6 +61,5 @@ etcd3 blake3 llguidance==0.7.29 soundfile -deep_gemm @ git+https://github.com/zongfeijing/DeepGEMM.git@a9d538ef4dff0326fe521c6ca0bfde115703b56a triton==3.3.1; platform_machine == "x86_64" blobfile diff --git a/scripts/build_wheel.py b/scripts/build_wheel.py index 8138995273..52abdbcb84 100755 --- a/scripts/build_wheel.py +++ b/scripts/build_wheel.py @@ -448,10 +448,12 @@ def main(*, if cpp_only: build_pyt = "OFF" build_deep_ep = "OFF" + build_deep_gemm = "OFF" else: - targets.extend(["th_common", "bindings", "deep_ep"]) + targets.extend(["th_common", "bindings", "deep_ep", "deep_gemm"]) build_pyt = "ON" build_deep_ep = "ON" + build_deep_gemm = "ON" if benchmarks: targets.append("benchmarks") @@ -490,7 +492,7 @@ def main(*, ) cmake_def_args = " ".join(cmake_def_args) cmake_configure_command = ( - f'cmake -DCMAKE_BUILD_TYPE="{build_type}" -DBUILD_PYT="{build_pyt}" -DBINDING_TYPE="{binding_type}" -DBUILD_DEEP_EP="{build_deep_ep}"' + f'cmake -DCMAKE_BUILD_TYPE="{build_type}" -DBUILD_PYT="{build_pyt}" -DBINDING_TYPE="{binding_type}" -DBUILD_DEEP_EP="{build_deep_ep}" -DBUILD_DEEP_GEMM="{build_deep_gemm}"' f' -DNVTX_DISABLE="{disable_nvtx}" -DBUILD_MICRO_BENCHMARKS={build_micro_benchmarks}' f' -DBUILD_WHEEL_TARGETS="{";".join(targets)}"' f' -DPython_EXECUTABLE={venv_python} -DPython3_EXECUTABLE={venv_python}' @@ -637,6 +639,14 @@ def main(*, clear_folder(deep_ep_dir) deep_ep_dir.rmdir() + # Handle deep_gemm installation + deep_gemm_dir = pkg_dir / "deep_gemm" + if deep_gemm_dir.is_symlink(): + deep_gemm_dir.unlink() + elif deep_gemm_dir.is_dir(): + clear_folder(deep_gemm_dir) + deep_gemm_dir.rmdir() + bin_dir = pkg_dir / "bin" if bin_dir.exists(): clear_folder(bin_dir) @@ -684,6 +694,14 @@ def main(*, build_dir / "tensorrt_llm/deep_ep/nvshmem-build/src/lib/nvshmem_transport_ibgda.so.103", lib_dir / "nvshmem") + + install_file(get_binding_lib("deep_gemm", "deep_gemm_cpp_tllm"), + pkg_dir) + install_tree(build_dir / "tensorrt_llm" / "deep_gemm" / "python" / + "deep_gemm", + deep_gemm_dir, + dirs_exist_ok=True) + if not skip_stubs: with working_directory(project_dir): if binding_type == "nanobind": @@ -757,6 +775,9 @@ def main(*, build_run( f"\"{venv_python}\" -m pybind11_stubgen -o . deep_ep_cpp_tllm --exit-code", env=env_ld) + build_run( + f"\"{venv_python}\" -m pybind11_stubgen -o . deep_gemm_cpp_tllm --exit-code", + env=env_ld) if not skip_building_wheel: if dist_dir is None: diff --git a/setup.py b/setup.py index c436dfd834..d3293c4bee 100644 --- a/setup.py +++ b/setup.py @@ -107,7 +107,8 @@ else: 'libs/libdecoder_attention_1.so', 'libs/nvshmem/License.txt', 'libs/nvshmem/nvshmem_bootstrap_uid.so.3', 'libs/nvshmem/nvshmem_transport_ibgda.so.103', 'bindings.*.so', - 'deep_ep/LICENSE', 'deep_ep_cpp_tllm.*.so', "include/**/*" + 'deep_ep/LICENSE', 'deep_ep_cpp_tllm.*.so', "include/**/*", + 'deep_gemm/LICENSE', 'deep_gemm/include/**/*', 'deep_gemm_cpp_tllm.*.so' ] package_data += [ diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 3721a5d2af..bcdf8d4415 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -1,12 +1,12 @@ from typing import List, Optional, Union -import deep_gemm import torch import torch.nn.functional as F import triton import triton.language as tl import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils +from tensorrt_llm import deep_gemm from tensorrt_llm._utils import nvtx_range from ...distributed import allgather diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 9653a3530e..44d69076fc 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -573,7 +573,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase): assert input.dtype == torch.bfloat16 if get_sm_version() == 100: - import deep_gemm + from tensorrt_llm import deep_gemm a, a_sf = fp8_utils.per_token_quant_and_transform(input) output = torch.empty((input.shape[0], module.weight.shape[0]), device=input.device, diff --git a/tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py b/tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py index 0fb78c01f3..9d8dfa14f3 100644 --- a/tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py +++ b/tests/unittest/_torch/thop/test_fp8_block_scale_gemm.py @@ -50,7 +50,7 @@ def test_fp8_block_scale_deep_gemm(dtype, m, k, n): act_b_fp8, act_b_sf = per_block_cast_to_fp8_e8m0(b) output_expected = a @ b.t() - import deep_gemm + from tensorrt_llm import deep_gemm output = torch.empty((act_a_fp8.shape[0], act_b_fp8.shape[0]), device=act_a_fp8.device, dtype=torch.bfloat16) From 80f918cc22cb7ba801fce4374963545953cb322c Mon Sep 17 00:00:00 2001 From: Hanjun Cho Date: Wed, 6 Aug 2025 18:02:35 +0900 Subject: [PATCH 5/5] [None][feat] Add Qwen3 MoE support to TensorRT backend (#6470) Signed-off-by: gkswns0531 Signed-off-by: hanjuncho Co-authored-by: bhsueh_NV <11360707+byshiue@users.noreply.github.com> --- tensorrt_llm/models/__init__.py | 1 + tensorrt_llm/models/qwen/config.py | 15 ++++- tensorrt_llm/models/qwen/convert.py | 98 +++++++++++++++-------------- tensorrt_llm/models/qwen/model.py | 39 ++++++++---- 4 files changed, 92 insertions(+), 61 deletions(-) diff --git a/tensorrt_llm/models/__init__.py b/tensorrt_llm/models/__init__.py index 2ca427ab47..96bd4eff96 100755 --- a/tensorrt_llm/models/__init__.py +++ b/tensorrt_llm/models/__init__.py @@ -196,6 +196,7 @@ MODEL_MAP = { 'Qwen2VLForConditionalGeneration': QWenForCausalLM, 'Qwen2VLModel': QWenForCausalLM, 'Qwen3ForCausalLM': QWenForCausalLM, + 'Qwen3MoeForCausalLM': QWenForCausalLM, 'WhisperEncoder': WhisperEncoder, 'EncoderModel': EncoderModel, 'DecoderModel': DecoderModel, diff --git a/tensorrt_llm/models/qwen/config.py b/tensorrt_llm/models/qwen/config.py index c9e57ecdf6..e2c2290953 100644 --- a/tensorrt_llm/models/qwen/config.py +++ b/tensorrt_llm/models/qwen/config.py @@ -32,6 +32,8 @@ class QWenConfig(PretrainedConfig): use_logn_attn: bool = False, moe: Optional[Union[MoeConfig, dict]] = None, num_labels: int = 1, + mlp_only_layers: Optional[list] = None, + decoder_sparse_step: int = 1, **kwargs): self.mlp_bias = mlp_bias self.attn_bias = attn_bias @@ -40,6 +42,8 @@ class QWenConfig(PretrainedConfig): self.disable_weight_only_quant_plugin = disable_weight_only_quant_plugin self.num_labels = num_labels self.use_logn_attn = use_logn_attn + self.mlp_only_layers = mlp_only_layers or [] + self.decoder_sparse_step = decoder_sparse_step if moe is None: # Legacy MOE config fields moe = MoeConfig(num_experts=kwargs.pop('moe_num_experts', 0), @@ -64,6 +68,8 @@ class QWenConfig(PretrainedConfig): output[ 'disable_weight_only_quant_plugin'] = self.disable_weight_only_quant_plugin output['use_logn_attn'] = self.use_logn_attn + output['mlp_only_layers'] = self.mlp_only_layers + output['decoder_sparse_step'] = self.decoder_sparse_step output['moe'] = self.moe.to_dict() return output @@ -114,7 +120,7 @@ class QWenConfig(PretrainedConfig): hf_config.hidden_size // hf_config.num_attention_heads) head_size = getattr(hf_config, "kv_channels", head_dim) hidden_act = getattr(hf_config, "hidden_act", "silu") - if qwen_type == "qwen2_moe": + if qwen_type in ("qwen2_moe", "qwen3_moe"): hidden_act = "swiglu" # Qwen3 models have no attention bias, while legacy models have bias @@ -144,6 +150,11 @@ class QWenConfig(PretrainedConfig): moe_shared_expert_intermediate_size = getattr( hf_config, "shared_expert_intermediate_size", 0) moe_normalization_mode = MoeConfig.ExpertScaleNormalizationMode.NONE + + # Add support for mlp_only_layers and decoder_sparse_step (Qwen3 MoE) + mlp_only_layers = getattr(hf_config, "mlp_only_layers", []) + decoder_sparse_step = getattr(hf_config, "decoder_sparse_step", 1) + moe_config = MoeConfig(num_experts=moe_num_experts, top_k=moe_top_k, normalization_mode=moe_normalization_mode) @@ -189,6 +200,8 @@ class QWenConfig(PretrainedConfig): moe_intermediate_size=moe_intermediate_size, moe_shared_expert_intermediate_size= moe_shared_expert_intermediate_size, + mlp_only_layers=mlp_only_layers, + decoder_sparse_step=decoder_sparse_step, moe=moe_config, mapping=mapping, quantization=quant_config, diff --git a/tensorrt_llm/models/qwen/convert.py b/tensorrt_llm/models/qwen/convert.py index 1e9f045304..0188b4c0d5 100644 --- a/tensorrt_llm/models/qwen/convert.py +++ b/tensorrt_llm/models/qwen/convert.py @@ -714,57 +714,58 @@ def convert_hf_qwen(hf_model, dtype, use_gemm_woq_plugin)) - if qwen_type == "qwen2_moe" and moe_config and moe_config.has_moe(): + if moe_config and moe_config.has_moe(): + if qwen_type == "qwen2_moe": + # shared_expert for qwen2_moe + shared_expert_up_proj = model_params[ + f'model.layers.{l}.mlp.shared_expert.up_proj.weight'] + shared_expert_down_proj = model_params[ + f'model.layers.{l}.mlp.shared_expert.down_proj.weight'] + shared_expert_gate = model_params[ + f'model.layers.{l}.mlp.shared_expert.gate_proj.weight'] + shared_expert_up_proj = split(shared_expert_up_proj, + mapping.tp_size, + mapping.tp_rank, + dim=0) + shared_expert_down_proj = split(shared_expert_down_proj, + mapping.tp_size, + mapping.tp_rank, + dim=1) + shared_expert_gate = split(shared_expert_gate, + mapping.tp_size, + mapping.tp_rank, + dim=0) + shared_expert_gate_up_proj = torch.concat( + [shared_expert_up_proj, shared_expert_gate], + dim=-2).to(dtype) - # shared_expert for qwen2_moe - shared_expert_up_proj = model_params[ - f'model.layers.{l}.mlp.shared_expert.up_proj.weight'] - shared_expert_down_proj = model_params[ - f'model.layers.{l}.mlp.shared_expert.down_proj.weight'] - shared_expert_gate = model_params[ - f'model.layers.{l}.mlp.shared_expert.gate_proj.weight'] - shared_expert_up_proj = split(shared_expert_up_proj, - mapping.tp_size, - mapping.tp_rank, - dim=0) - shared_expert_down_proj = split(shared_expert_down_proj, - mapping.tp_size, - mapping.tp_rank, - dim=1) - shared_expert_gate = split(shared_expert_gate, - mapping.tp_size, - mapping.tp_rank, - dim=0) - shared_expert_gate_up_proj = torch.concat( - [shared_expert_up_proj, shared_expert_gate], dim=-2).to(dtype) + ## mlp.shared_expert.gate_up_proj.weight + weights.update( + get_tllm_linear_weight(shared_expert_gate_up_proj, + tllm_prex + 'mlp.shared_expert.fc.', + None, use_weight_only, + plugin_weight_only_quant_type, dtype, + use_gemm_woq_plugin)) - ## mlp.shared_expert.gate_up_proj.weight - weights.update( - get_tllm_linear_weight(shared_expert_gate_up_proj, - tllm_prex + 'mlp.shared_expert.fc.', - None, use_weight_only, - plugin_weight_only_quant_type, dtype, - use_gemm_woq_plugin)) + ## mlp.shared_expert.down_proj.weight + weights.update( + get_tllm_linear_weight( + shared_expert_down_proj.to(dtype), + tllm_prex + 'mlp.shared_expert.proj.', None, + use_weight_only, plugin_weight_only_quant_type, dtype, + use_gemm_woq_plugin)) - ## mlp.shared_expert.down_proj.weight - weights.update( - get_tllm_linear_weight(shared_expert_down_proj.to(dtype), - tllm_prex + 'mlp.shared_expert.proj.', - None, use_weight_only, - plugin_weight_only_quant_type, dtype, - use_gemm_woq_plugin)) - - moe_shared_expert_gate_weights = get_weight( - model_params, prefix + 'mlp.shared_expert_gate', dtype) - weights.update( - get_tllm_linear_weight( - moe_shared_expert_gate_weights, - tllm_prex + 'mlp.shared_expert_gate.', - None, - False, # Router should never be quantized - plugin_weight_only_quant_type, - dtype, - use_gemm_woq_plugin)) + moe_shared_expert_gate_weights = get_weight( + model_params, prefix + 'mlp.shared_expert_gate', dtype) + weights.update( + get_tllm_linear_weight( + moe_shared_expert_gate_weights, + tllm_prex + 'mlp.shared_expert_gate.', + None, + False, # Router should never be quantized + plugin_weight_only_quant_type, + dtype, + use_gemm_woq_plugin)) ## fine-grained experts rank_experts = list(range(moe_config.num_experts)) @@ -811,6 +812,7 @@ def convert_hf_qwen(hf_model, plugin_weight_only_quant_type, dtype, use_gemm_woq_plugin)) + else: mlp_gate_weight = get_weight(model_params, prefix + key_list[2], dtype) diff --git a/tensorrt_llm/models/qwen/model.py b/tensorrt_llm/models/qwen/model.py index 0fb003a90c..0eb6e8ac44 100644 --- a/tensorrt_llm/models/qwen/model.py +++ b/tensorrt_llm/models/qwen/model.py @@ -90,11 +90,15 @@ class QWenDecoderLayer(Module): if config.moe.has_moe(): mlp_kwargs = {'moe_config': config.moe, 'mapping': config.mapping} if config.qwen_type == 'qwen2_moe': + # Qwen2 MoE uses SharedMoE with shared expert ClsMLP = SharedMoE mlp_kwargs['use_shared_gate'] = True mlp_kwargs['use_side_stream'] = True mlp_kwargs['moe_config'].shared_expert_intermediate_size = \ config.moe_shared_expert_intermediate_size + elif config.qwen_type == 'qwen3_moe': + # Qwen3 MoE uses standard MOE without shared expert + ClsMLP = MOE else: ClsMLP = MOE else: @@ -104,7 +108,7 @@ class QWenDecoderLayer(Module): # Qwen's real inter_size depends on qwen_type if self.config.qwen_type == 'qwen': intermediate_size = config.intermediate_size // 2 - elif self.config.qwen_type == 'qwen2_moe': + elif self.config.qwen_type in ('qwen2_moe', 'qwen3_moe'): intermediate_size = config.moe_intermediate_size else: intermediate_size = config.intermediate_size @@ -264,18 +268,11 @@ class QWenForCausalLM(DecoderModelForCausalLM): "mlp_4h_to_h": "mlp.c_proj", "mlp_gate": "w1", } - elif config.qwen_type == 'qwen2_moe': + elif config.qwen_type in ('qwen2_moe', 'qwen3_moe'): self.trtllm_modules_to_hf_modules = copy.copy( get_default_trtllm_modules_to_hf_modules()) + # Common MoE expert mappings for both Qwen2 and Qwen3 MoE self.trtllm_modules_to_hf_modules.update({ - "mlp_h_to_4h": - "mlp.shared_expert.gate_proj", - "mlp_4h_to_h": - "mlp.shared_expert.down_proj", - "mlp_gate": - "mlp.shared_expert.up_proj", - "mlp_router": - "mlp.shared_expert_gate", "moe_h_to_4h": "mlp.experts.gate_proj", "moe_4h_to_h": @@ -283,6 +280,18 @@ class QWenForCausalLM(DecoderModelForCausalLM): "moe_gate": "mlp.experts.up_proj", }) + # Qwen2 MoE additionally has shared expert + if config.qwen_type == 'qwen2_moe': + self.trtllm_modules_to_hf_modules.update({ + "mlp_h_to_4h": + "mlp.shared_expert.gate_proj", + "mlp_4h_to_h": + "mlp.shared_expert.down_proj", + "mlp_gate": + "mlp.shared_expert.up_proj", + "mlp_router": + "mlp.shared_expert_gate", + }) else: self.trtllm_modules_to_hf_modules = None super().__init__(config, transformer, lm_head) @@ -343,6 +352,12 @@ class QWenForCausalLM(DecoderModelForCausalLM): "mlp.shared_expert_gate": "mlp.shared_expert_gate", "fc": ["up_proj", "gate_proj"], } + elif config.qwen_type == "qwen3_moe": + custom_dict = { + "fc": ["up_proj", "gate_proj"], + "q_layernorm": "q_norm", + "k_layernorm": "k_norm", + } elif config.qwen_type in {"qwen2", "qwen2_vl" } and config.tie_word_embeddings: custom_dict = {"lm_head": "model.embed_tokens"} @@ -360,7 +375,7 @@ class QWenForCausalLM(DecoderModelForCausalLM): "transformer": "language_model.model", "lm_head": "language_model.lm_head", } - elif config.qwen_type in ("qwen3", "qwen3_moe"): + elif config.qwen_type == "qwen3": custom_dict = { "q_layernorm": "q_norm", "k_layernorm": "k_norm", @@ -412,7 +427,7 @@ class QWenForCausalLM(DecoderModelForCausalLM): loader.load(tllm_key, custom_postprocess_kwargs=arg_dict)) loader.fill(tllm_weights) - elif config.qwen_type == "qwen2_moe": + elif config.qwen_type in ("qwen2_moe", "qwen3_moe"): for tllm_key, _ in model.named_parameters(): sub_module = model for attr in tllm_key.split(".")[:-1]: