diff --git a/cpp/tensorrt_llm/nanobind/runtime/hostfunc.cpp b/cpp/tensorrt_llm/nanobind/runtime/hostfunc.cpp
index 9319b58e5a..b6f9f47c4a 100644
--- a/cpp/tensorrt_llm/nanobind/runtime/hostfunc.cpp
+++ b/cpp/tensorrt_llm/nanobind/runtime/hostfunc.cpp
@@ -78,9 +78,13 @@ std::optional launchHostFunc(
 {
     auto const stream = reinterpret_cast(streamPtr);

+    nb::gil_scoped_acquire gil;
+
     auto hostFuncUserData
         = std::make_unique(freeUserData, pyHostFunc, nb::tuple(pyArgs), nb::dict(pyKwargs));

+    nb::gil_scoped_release release;
+
     cudaError_t err = cudaLaunchHostFunc(stream, cudaHostFuncTrampoline, hostFuncUserData.get());
     if (err != cudaSuccess)
     {
@@ -110,6 +114,7 @@ void initHostFuncBindings(nb::module_& m)
 {
     m.def("launch_hostfunc", &launchHostFunc, "Launch a Python host function to a CUDA stream",
         nb::call_guard());
-    m.def("free_hostfunc_user_data", &freeHostFuncUserData, "Free the user data for the Python host function");
+    m.def("free_hostfunc_user_data", &freeHostFuncUserData, "Free the user data for the Python host function",
+        nb::call_guard());
 }

 } // namespace tensorrt_llm::nanobind::runtime
diff --git a/cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp b/cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp
index 7704ff2fd1..8839e9b8b6 100644
--- a/cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp
+++ b/cpp/tensorrt_llm/pybind/runtime/hostfunc.cpp
@@ -78,9 +78,13 @@ std::optional launchHostFunc(
 {
     auto const stream = reinterpret_cast(streamPtr);

+    py::gil_scoped_acquire gil;
+
    auto hostFuncUserData
         = std::make_unique(freeUserData, pyHostFunc, py::tuple(pyArgs), py::dict(pyKwargs));

+    py::gil_scoped_release release;
+
     cudaError_t err = cudaLaunchHostFunc(stream, cudaHostFuncTrampoline, hostFuncUserData.get());
     if (err != cudaSuccess)
     {
@@ -110,6 +114,7 @@ void initHostFuncBindings(pybind11::module_& m)
 {
     m.def("launch_hostfunc", &launchHostFunc, "Launch a Python host function to a CUDA stream",
         py::call_guard());
-    m.def("free_hostfunc_user_data", &freeHostFuncUserData, "Free the user data for the Python host function");
+    m.def("free_hostfunc_user_data", &freeHostFuncUserData, "Free the user data for the Python host function",
+        py::call_guard());
 }

 } // namespace tensorrt_llm::pybind::runtime
diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
index 2405d3e5fc..fbde925a21 100644
--- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
@@ -2178,32 +2178,6 @@ if IS_CUTLASS_DSL_AVAILABLE:
                                device=input_scale.device)
             return output, output_scale

-    class FusedMoEInputsHelper:
-
-        def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
-                     local_expert_offset: int):
-            self.num_experts = num_experts
-            self.top_k = top_k
-            self.num_local_experts = num_local_experts
-            self.local_expert_offset = local_expert_offset
-
-        def infer_shape_num_tokens(self, input_shapes: List[torch.Size]) -> int:
-            return input_shapes[0][0]
-
-        def inputs_pre_hook(self,
-                            inputs: List[torch.Tensor]) -> List[torch.Tensor]:
-            x, x_sf, token_selected_experts, token_final_scales, *others = inputs
-            num_tokens = token_selected_experts.size(0)
-            new_token_final_scales, new_token_selected_experts = torch.randn(
-                num_tokens,
-                self.num_experts,
-                device=token_selected_experts.device).topk(self.top_k, dim=-1)
-            new_token_selected_experts = new_token_selected_experts.to(
-                token_selected_experts.dtype)
-            new_token_final_scales = new_token_final_scales.softmax(dim=-1).to(
-                token_final_scales.dtype)
-
-            return x, x_sf, new_token_selected_experts, new_token_final_scales, *others
-
     class Sm100BlockScaledFusedMoERunner(TunableRunner):

         tuning_config_cache = dict()
diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
index 3540f91550..5dd84c57ef 100644
--- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
+++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
@@ -35,44 +35,18 @@ import cutlass.pipeline as pipeline
 import cutlass.utils as utils
 import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
-from cutlass._mlir.dialects import math, nvvm
+from cutlass._mlir.dialects import math
 from cutlass.cute.nvgpu import cpasync, tcgen05
-from cutlass.cute.typing import Float32
-from cutlass.cutlass_dsl import T, dsl_user_op

 from .custom_pipeline import PipelineCpAsyncUmma
-from .utils import is_power_of_2
-
-
-@dsl_user_op
-def fmin(
-    a: Union[float, Float32], b: Union[float, Float32], *, nan=False, loc=None, ip=None
-) -> Float32:
-    return Float32(
-        nvvm.fmin(
-            T.f32(),
-            Float32(a).ir_value(loc=loc, ip=ip),
-            Float32(b).ir_value(loc=loc, ip=ip),
-            nan=nan,
-            loc=loc,
-            ip=ip,
-        )
-    )
-
-
-def sigmoid_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
-    """
-    Compute the sigmoid of the input tensor.
-    """
-    return cute.arch.rcp_approx(1.0 + cute.math.exp(-a, fastmath=fastmath))
-
-
-def silu_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
-    """
-    Compute the silu of the input tensor.
- """ - return a * sigmoid_f32(a, fastmath=fastmath) - +from .utils import ( + TRTLLM_ENABLE_PDL, + fmin, + griddepcontrol_launch_dependents, + griddepcontrol_wait, + is_power_of_2, + silu_f32, +) """ High-performance persistent blockscaled contiguous grouped dense GEMM with gather and SwiGLU fusion @@ -819,6 +793,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel: smem=self.shared_storage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, + use_pdl=TRTLLM_ENABLE_PDL, ) return @@ -1148,6 +1123,8 @@ class BlockScaledContiguousGatherGroupedGemmKernel: else: self.cta_sync_barrier.arrive_and_wait() + griddepcontrol_wait() + # # Specialized Schedule warp # @@ -2282,6 +2259,8 @@ class BlockScaledContiguousGatherGroupedGemmKernel: # c_pipeline.producer_tail() + griddepcontrol_launch_dependents() + def epilog_tmem_copy_and_partition( self, tidx: cutlass.Int32, diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm.py index b6ea02cf36..be62291cd3 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm.py @@ -52,7 +52,12 @@ import cutlass.utils.blackwell_helpers as sm100_utils import cutlass.utils.blockscaled_layout as blockscaled_utils from cutlass.cute.nvgpu import cpasync, tcgen05 -from .utils import is_power_of_2 +from .utils import ( + TRTLLM_ENABLE_PDL, + griddepcontrol_launch_dependents, + griddepcontrol_wait, + is_power_of_2, +) class Sm100BlockScaledContiguousGroupedGemmKernel: @@ -597,6 +602,7 @@ class Sm100BlockScaledContiguousGroupedGemmKernel: smem=self.shared_storage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, + use_pdl=TRTLLM_ENABLE_PDL, ) return @@ -933,6 +939,8 @@ class Sm100BlockScaledContiguousGroupedGemmKernel: else: self.cta_sync_barrier.arrive_and_wait() + griddepcontrol_wait() + # # Specialized Schedule warp # @@ -1597,6 +1605,8 @@ class Sm100BlockScaledContiguousGroupedGemmKernel: # c_pipeline.producer_tail() + griddepcontrol_launch_dependents() + def epilog_tmem_copy_and_partition( self, tidx: cutlass.Int32, diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py index 08fe6c91e8..d424e00fd8 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py @@ -35,11 +35,17 @@ import cutlass.pipeline as pipeline import cutlass.utils as utils import cutlass.utils.blackwell_helpers as sm100_utils import cutlass.utils.blockscaled_layout as blockscaled_utils -from cutlass._mlir.dialects import llvm from cutlass.cute.nvgpu import cpasync, tcgen05 -from cutlass.cutlass_dsl import Int32, T, dsl_user_op -from .utils import is_power_of_2 +from .utils import ( + TRTLLM_ENABLE_PDL, + atomic_add_func, + griddepcontrol_launch_dependents, + griddepcontrol_wait, + is_power_of_2, + vectorized_atomic_add_bf16x8, + vectorized_atomic_add_fp32x2, +) """ High-performance persistent blockscaled contiguous grouped dense GEMM (C = alpha * (SFA * A) * (SFB * B)) example for @@ -259,8 +265,8 @@ def hooked_PersistentTileSchedulerParams_init( def hooked_get_cluster_work_idx_with_fastdivmod( - self, 
-    self, current_work_linear_idx: Int32, *, loc=None, ip=None
-) -> Tuple[Int32, Int32, Int32]:
+    self, current_work_linear_idx: cutlass.Int32, *, loc=None, ip=None
+) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]:
     work_iteration, work_unit_id = divmod(current_work_linear_idx, self.params.batch_fdd)

     if self.params._raster_along_m:
@@ -287,69 +293,6 @@ cutlass.utils.StaticPersistentTileScheduler._get_cluster_work_idx_with_fastdivmo
 )


-# TODO(zhichenj): try to move these to NVVM wrapper or helper functions
-@dsl_user_op
-def vectorized_atomic_add_bf16x8(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
-    llvm.inline_asm(
-        None,
-        [
-            scatter_out_offset.iterator.llvm_ptr,
-            llvm.bitcast(T.i32(), rOut_epi_packed[0, None].load().ir_value()),
-            llvm.bitcast(T.i32(), rOut_epi_packed[1, None].load().ir_value()),
-            llvm.bitcast(T.i32(), rOut_epi_packed[2, None].load().ir_value()),
-            llvm.bitcast(T.i32(), rOut_epi_packed[3, None].load().ir_value()),
-        ],
-        "red.global.v4.bf16x2.add.noftz [$0], {$1, $2, $3, $4};",
-        "l,r,r,r,r",
-        has_side_effects=True,
-    )
-
-
-@dsl_user_op
-def vectorized_atomic_add_fp32x2(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
-    llvm.inline_asm(
-        None,
-        [
-            scatter_out_offset.iterator.llvm_ptr,
-            rOut_epi_packed[0].ir_value(),
-            rOut_epi_packed[1].ir_value(),
-        ],
-        "red.global.v2.f32.add [$0], {$1, $2};",
-        "l,f,f",
-        has_side_effects=True,
-    )
-
-
-@dsl_user_op
-def atomic_add_func(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
-    if cutlass.const_expr(rOut_epi_packed.dtype == cutlass.Float32):
-        llvm.inline_asm(
-            None,
-            [
-                scatter_out_offset.iterator.llvm_ptr,
-                rOut_epi_packed.ir_value(),
-            ],
-            "red.global.add.f32 [$0], $1;",
-            "l,f",
-            has_side_effects=True,
-            loc=loc,
-            ip=ip,
-        )
-    elif cutlass.const_expr(rOut_epi_packed.dtype == cutlass.BFloat16):
-        llvm.inline_asm(
-            None,
-            [
-                scatter_out_offset.iterator.llvm_ptr,
-                llvm.bitcast(T.i16(), rOut_epi_packed.ir_value()),
-            ],
-            "red.add.noftz.bf16 [$0], $1;",
-            "l,h",
-            has_side_effects=True,
-            loc=loc,
-            ip=ip,
-        )
-
-
 class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
     """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
     and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
@@ -931,6 +874,7 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
             smem=self.shared_storage.size_in_bytes(),
             stream=stream,
             min_blocks_per_mp=1,
+            use_pdl=TRTLLM_ENABLE_PDL,
         )
         return

@@ -1286,6 +1230,8 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
         else:
             self.cta_sync_barrier.arrive_and_wait()

+        griddepcontrol_wait()
+
         #
         # Specialized Schedule warp
         #
@@ -1940,6 +1886,8 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
         self.epilog_sync_barrier.arrive_and_wait()
         tmem.free(tmem_ptr)

+        griddepcontrol_launch_dependents()
+
     def epilog_tmem_copy_and_partition(
         self,
         tidx: cutlass.Int32,
diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py
index 12a37c31b8..815996707a 100644
--- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py
+++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py
@@ -35,43 +35,17 @@ import cutlass.pipeline as pipeline
 import cutlass.utils as utils
 import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
-from cutlass._mlir.dialects import math, nvvm
+from cutlass._mlir.dialects import math
 from cutlass.cute.nvgpu import cpasync, tcgen05
-from cutlass.cute.typing import Float32
-from cutlass.cutlass_dsl import T, dsl_user_op
-
-from .utils import is_power_of_2
-
-
-@dsl_user_op
-def fmin(
-    a: Union[float, Float32], b: Union[float, Float32], *, nan=False, loc=None, ip=None
-) -> Float32:
-    return Float32(
-        nvvm.fmin(
-            T.f32(),
-            Float32(a).ir_value(loc=loc, ip=ip),
-            Float32(b).ir_value(loc=loc, ip=ip),
-            nan=nan,
-            loc=loc,
-            ip=ip,
-        )
-    )
-
-
-def sigmoid_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
-    """
-    Compute the sigmoid of the input tensor.
-    """
-    return cute.arch.rcp_approx(1.0 + cute.math.exp(-a, fastmath=fastmath))
-
-
-def silu_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
-    """
-    Compute the silu of the input tensor.
- """ - return a * sigmoid_f32(a, fastmath=fastmath) +from .utils import ( + TRTLLM_ENABLE_PDL, + fmin, + griddepcontrol_launch_dependents, + griddepcontrol_wait, + is_power_of_2, + silu_f32, +) """ High-performance persistent blockscaled contiguous grouped dense GEMM (C = alpha * (SFA * A) * (SFB * B)) example for @@ -749,6 +723,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel: smem=self.shared_storage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, + use_pdl=TRTLLM_ENABLE_PDL, ) return @@ -1087,6 +1062,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel: else: self.cta_sync_barrier.arrive_and_wait() + griddepcontrol_wait() + # # Specialized Schedule warp # @@ -1949,6 +1926,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel: # c_pipeline.producer_tail() + griddepcontrol_launch_dependents() + def epilog_tmem_copy_and_partition( self, tidx: cutlass.Int32, diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py index 913473cf20..e143812105 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py @@ -55,7 +55,8 @@ import cutlass.utils.blockscaled_layout as blockscaled_utils from cutlass.cute.nvgpu import cpasync, tcgen05 from .custom_pipeline import PipelineTmaUmma, PipelineUmmaAsync -from .utils import is_power_of_2 +from .utils import (TRTLLM_ENABLE_PDL, griddepcontrol_launch_dependents, + griddepcontrol_wait, is_power_of_2) class Sm100BlockScaledPersistentDenseGemmKernel: @@ -578,6 +579,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: smem=self.shared_storage.size_in_bytes(), min_blocks_per_mp=1, stream=stream, + use_pdl=TRTLLM_ENABLE_PDL, ) return @@ -869,6 +871,8 @@ class Sm100BlockScaledPersistentDenseGemmKernel: cute.arch.barrier(barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta) + griddepcontrol_wait() + # # Specialized TMA load warp # @@ -1473,6 +1477,8 @@ class Sm100BlockScaledPersistentDenseGemmKernel: # c_pipeline.producer_tail() + griddepcontrol_launch_dependents() + def mainloop_s2t_copy_and_partition( self, sSF: cute.Tensor, diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py index 347eece9d9..afc3430875 100644 --- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py @@ -44,11 +44,18 @@ # This file is copied and modified from cutlass https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/core.py import ctypes +import os from typing import Union +import cutlass import cutlass._mlir.dialects.cute as _cute_ir +import cutlass.cute as cute from cutlass._mlir import ir +from cutlass._mlir.dialects import llvm, nvvm from cutlass.cute.typing import AddressSpace, Numeric, Pointer, Type +from cutlass.cutlass_dsl import T, dsl_user_op + +TRTLLM_ENABLE_PDL = os.environ.get("TRTLLM_ENABLE_PDL", "0") == "1" # WAR for CuTeDSL make_ptr implementation @@ -190,3 +197,150 @@ def make_ptr( def is_power_of_2(x: int) -> bool: return x > 0 and (x & (x - 1)) == 0 + + +@dsl_user_op +def fmin(a: Union[float, cutlass.Float32], + b: Union[float, cutlass.Float32], + *, + nan=False, + loc=None, + ip=None) -> cutlass.Float32: + return cutlass.Float32( + nvvm.fmin( + T.f32(), + cutlass.Float32(a).ir_value(loc=loc, ip=ip), 
+            cutlass.Float32(b).ir_value(loc=loc, ip=ip),
+            nan=nan,
+            loc=loc,
+            ip=ip,
+        ))
+
+
+def sigmoid_f32(a: Union[float, cutlass.Float32],
+                fastmath: bool = False) -> Union[float, cutlass.Float32]:
+    """
+    Compute the sigmoid of the input tensor.
+    """
+    return cute.arch.rcp_approx(1.0 + cute.math.exp(-a, fastmath=fastmath))
+
+
+def silu_f32(a: Union[float, cutlass.Float32],
+             fastmath: bool = False) -> Union[float, cutlass.Float32]:
+    """
+    Compute the silu of the input tensor.
+    """
+    return a * sigmoid_f32(a, fastmath=fastmath)
+
+
+# TODO(zhichenj): try to move these to NVVM wrapper or helper functions
+@dsl_user_op
+def vectorized_atomic_add_bf16x8(rOut_epi_packed,
+                                 scatter_out_offset,
+                                 loc=None,
+                                 ip=None):
+    llvm.inline_asm(
+        None,
+        [
+            scatter_out_offset.iterator.llvm_ptr,
+            llvm.bitcast(T.i32(), rOut_epi_packed[0, None].load().ir_value()),
+            llvm.bitcast(T.i32(), rOut_epi_packed[1, None].load().ir_value()),
+            llvm.bitcast(T.i32(), rOut_epi_packed[2, None].load().ir_value()),
+            llvm.bitcast(T.i32(), rOut_epi_packed[3, None].load().ir_value()),
+        ],
+        "red.global.v4.bf16x2.add.noftz [$0], {$1, $2, $3, $4};",
+        "l,r,r,r,r",
+        has_side_effects=True,
+        loc=loc,
+        ip=ip,
+    )
+
+
+@dsl_user_op
+def vectorized_atomic_add_fp32x2(rOut_epi_packed,
+                                 scatter_out_offset,
+                                 loc=None,
+                                 ip=None):
+    llvm.inline_asm(
+        None,
+        [
+            scatter_out_offset.iterator.llvm_ptr,
+            rOut_epi_packed[0].ir_value(),
+            rOut_epi_packed[1].ir_value(),
+        ],
+        "red.global.v2.f32.add [$0], {$1, $2};",
+        "l,f,f",
+        has_side_effects=True,
+        loc=loc,
+        ip=ip,
+    )
+
+
+@dsl_user_op
+def atomic_add_func(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
+    if cutlass.const_expr(rOut_epi_packed.dtype == cutlass.Float32):
+        llvm.inline_asm(
+            None,
+            [
+                scatter_out_offset.iterator.llvm_ptr,
+                rOut_epi_packed.ir_value(),
+            ],
+            "red.global.add.f32 [$0], $1;",
+            "l,f",
+            has_side_effects=True,
+            loc=loc,
+            ip=ip,
+        )
+    elif cutlass.const_expr(rOut_epi_packed.dtype == cutlass.BFloat16):
+        llvm.inline_asm(
+            None,
+            [
+                scatter_out_offset.iterator.llvm_ptr,
+                llvm.bitcast(T.i16(), rOut_epi_packed.ir_value()),
+            ],
+            "red.add.noftz.bf16 [$0], $1;",
+            "l,h",
+            has_side_effects=True,
+            loc=loc,
+            ip=ip,
+        )
+
+
+@dsl_user_op
+def griddepcontrol_wait(*, loc=None, ip=None) -> None:
+    """
+    This instruction is used to wait for the previous kernel's grid ending
+    (all blocks of the previous kernel have finished and memflushed), i.e.,
+    the instruction after this instruction will not be issued until the previous
+    grid has finished.
+    """
+    llvm.inline_asm(
+        res=None,
+        operands_=[],
+        asm_string="griddepcontrol.wait;",
+        constraints="",
+        has_side_effects=True,
+        asm_dialect=llvm.AsmDialect.AD_ATT,
+        loc=loc,
+        ip=ip,
+    )
+
+
+@dsl_user_op
+def griddepcontrol_launch_dependents(*, loc=None, ip=None) -> None:
+    """
+    Issuing the launch_dependents instruction hints a dependent kernel to launch earlier.
+    launch_dependents doesn't impact the functionality but the performance:
+    Launching a dependent kernel too early can compete with current kernels,
+    while launching too late can lead to a long latency.
+ """ + llvm.inline_asm( + res=None, + operands_=[], + asm_string="griddepcontrol.launch_dependents;", + constraints="", + has_side_effects=True, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 605972ab5c..11b92dbb8d 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -1452,12 +1452,13 @@ class DeepseekV3Model(DecoderModel): config = model_config.pretrained_config self.vocab_size = config.vocab_size self.num_hidden_layers = config.num_hidden_layers - aux_stream_list = [torch.cuda.Stream() for _ in range(3)] + aux_stream_list = [torch.cuda.Stream() for _ in range(4)] self.aux_stream_dict = { AuxStreamType.Attention: aux_stream_list[0], AuxStreamType.MoeShared: aux_stream_list[0], AuxStreamType.MoeChunkingOverlap: aux_stream_list[1], AuxStreamType.MoeBalancer: aux_stream_list[2], + AuxStreamType.MoeOutputMemset: aux_stream_list[3], } self.embed_tokens = Embedding( diff --git a/tensorrt_llm/_torch/models/modeling_glm.py b/tensorrt_llm/_torch/models/modeling_glm.py index 868e43195b..94ae57cef9 100644 --- a/tensorrt_llm/_torch/models/modeling_glm.py +++ b/tensorrt_llm/_torch/models/modeling_glm.py @@ -890,12 +890,13 @@ class Glm4Model(DecoderModel): config = model_config.pretrained_config self.vocab_size = config.vocab_size self.num_hidden_layers = config.num_hidden_layers - aux_stream_list = [torch.cuda.Stream() for _ in range(3)] + aux_stream_list = [torch.cuda.Stream() for _ in range(4)] self.aux_stream_dict = { AuxStreamType.Attention: aux_stream_list[0], AuxStreamType.MoeShared: aux_stream_list[0], AuxStreamType.MoeChunkingOverlap: aux_stream_list[1], AuxStreamType.MoeBalancer: aux_stream_list[2], + AuxStreamType.MoeOutputMemset: aux_stream_list[3], } self.embed_tokens = Embedding( diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py index e05ad149bd..04190199d8 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py @@ -327,6 +327,7 @@ class Qwen3MoEModel(DecoderModel): self.aux_stream_dict = { AuxStreamType.MoeChunkingOverlap: torch.cuda.Stream(), AuxStreamType.MoeBalancer: torch.cuda.Stream(), + AuxStreamType.MoeOutputMemset: torch.cuda.Stream(), } self.preload_weight_modules = [] if config.moe_backend == "TRTLLM": diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index 8adb412d01..b4e1dc7503 100755 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -487,11 +487,13 @@ class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel, def __init__(self, model_config: ModelConfig[PretrainedConfig]): self.model_config = model_config - aux_stream_list = [torch.cuda.Stream() for _ in range(2)] + aux_stream_list = [torch.cuda.Stream() for _ in range(4)] self.aux_stream_dict = { AuxStreamType.Attention: aux_stream_list[0], AuxStreamType.MoeShared: aux_stream_list[0], AuxStreamType.MoeChunkingOverlap: aux_stream_list[1], + AuxStreamType.MoeBalancer: aux_stream_list[2], + AuxStreamType.MoeOutputMemset: aux_stream_list[3], } super().__init__( MTPDraftModel(self.model_config, diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py index 7aa51a938d..a43ae9fa25 100644 --- 
--- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
@@ -175,7 +175,6 @@ class ConfigurableMoE(MoE):
             assert not torch.compiler.is_compiling(), (
                 "Backend should not be none if not in torch.compile"
             )
-            self.backend.aux_stream_dict = self.aux_stream_dict
             self.backend.layer_idx = self.layer_idx
             self.backend.layer_idx_str = self.layer_idx_str
             self.backend.num_slots = self.num_slots
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
index 2cec8a269e..2480ff22e8 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
@@ -8,7 +8,7 @@ from tensorrt_llm._utils import is_sm_100f

 from ...distributed import allgather
 from ...model_config import ModelConfig
-from ...utils import AuxStreamType, Fp4QuantizedTensor
+from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .fused_moe_cutlass import CutlassFusedMoE
 from .interface import AlltoallMethodType
 from .quantization import MoEWeightLoadingMode, NVFP4CuteDslFusedMoEMethod
@@ -183,7 +183,6 @@ class CuteDslFusedMoE(CutlassFusedMoE):
         init_load_balancer: bool = True,
         without_comm: bool = False,
     ):
-
         super().__init__(
             routing_method=routing_method,
             num_experts=num_experts,
@@ -199,6 +198,16 @@ class CuteDslFusedMoE(CutlassFusedMoE):
             init_load_balancer=init_load_balancer,
             without_comm=without_comm,
         )
+        if self.aux_stream_dict is None:
+            self.aux_stream_dict = aux_stream_dict if aux_stream_dict is not None else {}
+        if AuxStreamType.MoeOutputMemset not in self.aux_stream_dict:
+            self.aux_stream_dict[
+                AuxStreamType.MoeOutputMemset] = torch.cuda.Stream()
+        if self.event_dict is None:
+            self.event_dict = {}
+        for key in [EventType.Main, EventType.MoeOutputMemset]:
+            if key not in self.event_dict:
+                self.event_dict[key] = torch.cuda.Event()

     def select_alltoall_method_type(self) -> AlltoallMethodType:
         return AlltoallMethodType.NotEnabled
@@ -288,6 +297,11 @@ class CuteDslFusedMoE(CutlassFusedMoE):
                                  self.hidden_size)
         assert moe_output.dtype == output_dtype

+        if self.use_fused_finalize:
+            self.event_dict[EventType.Main].record()
+            moe_output.record_stream(
+                self.aux_stream_dict[AuxStreamType.MoeOutputMemset])
+
         x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell(
             input=x.view(torch.float4_e2m1fn_x2),
             weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2),
@@ -307,17 +321,24 @@ class CuteDslFusedMoE(CutlassFusedMoE):
         )

         if self.use_fused_finalize:
-            torch.ops.trtllm.moe_output_memset_inplace(
-                input=moe_output,
-                tile_idx_to_mn_limit=tile_idx_to_mn_limit,
-                expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
-                permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
-                num_non_exiting_tiles=num_non_exiting_tiles,
-                tile_tokens_dim=tile_size,
-                top_k=self.routing_method.experts_per_token,
-                ep_size=self.mapping.moe_ep_size,
-                enable_alltoall=enable_alltoall,
-            )
+            with torch.cuda.stream(
+                    self.aux_stream_dict[AuxStreamType.MoeOutputMemset]):
+                self.event_dict[EventType.Main].wait()
+                torch.ops.trtllm.moe_output_memset_inplace(
+                    input=moe_output,
+                    tile_idx_to_mn_limit=tile_idx_to_mn_limit,
+                    expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
+                    permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
+                    num_non_exiting_tiles=num_non_exiting_tiles,
+                    tile_tokens_dim=tile_size,
+                    top_k=self.routing_method.experts_per_token,
+                    ep_size=self.mapping.moe_ep_size,
+                    enable_alltoall=enable_alltoall,
+                )
+                self.event_dict[EventType.MoeOutputMemset].record()
+
+            self.event_dict[EventType.MoeOutputMemset].wait()
+
         torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell(
             input=x.view(torch.float4_e2m1fn_x2),
             weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py
index dac655b1c3..1c3c02ca34 100644
--- a/tensorrt_llm/_torch/utils.py
+++ b/tensorrt_llm/_torch/utils.py
@@ -21,6 +21,7 @@ aux_stream_name_list = [
     'MoeShared',
     'MoeChunkingOverlap',
     'MoeBalancer',
+    'MoeOutputMemset',
 ]
 AuxStreamType = Enum(
     'AuxStreamType',
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 20daae513a..55430f62fb 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -330,7 +330,6 @@ accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False] SKIP
 test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True] SKIP (https://nvbugs/5648560)
 test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-False] SKIP (https://nvbugs/5648560)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen_adp_lmtp] SKIP (https://nvbugs/5629136)
-unittest/bindings/test_hostfunc.py::test_hostfunc SKIP (https://nvbugs/5643631)
 examples/test_multimodal.py::test_llm_multimodal_general[nougat-base-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5568052)
 accuracy/test_llm_api_pytorch_multimodal.py::TestNVILA_8B::test_auto_dtype SKIP (https://nvbugs/5648441)
 accuracy/test_llm_api_pytorch_multimodal.py::TestVILA1_5_3B::test_auto_dtype SKIP (https://nvbugs/5648441)
diff --git a/tests/unittest/bindings/test_hostfunc.py b/tests/unittest/bindings/test_hostfunc.py
index e778dc3486..f208ca0ce9 100644
--- a/tests/unittest/bindings/test_hostfunc.py
+++ b/tests/unittest/bindings/test_hostfunc.py
@@ -15,6 +15,7 @@ def test_hostfunc():
     with torch.cuda.stream(stream):
         for _ in range(5):
             increase(x)
+    torch.cuda.synchronize()

     g = torch.cuda.CUDAGraph()
     with torch.cuda.graph(g, stream=stream):
@@ -25,7 +26,7 @@ def test_hostfunc():
     with torch.cuda.stream(stream):
         for _ in range(10):
             g.replay()
-        torch.cuda.synchronize()
+
     assert (x == 25).all().item()

     assert len(HOSTFUNC_USER_DATA_HANDLES) == 2
diff --git a/tests/unittest/llmapi/apps/_test_openai_responses.py b/tests/unittest/llmapi/apps/_test_openai_responses.py
index a5a26f2067..18271f6b76 100644
--- a/tests/unittest/llmapi/apps/_test_openai_responses.py
+++ b/tests/unittest/llmapi/apps/_test_openai_responses.py
@@ -94,6 +94,7 @@ async def test_reasoning(client: openai.AsyncOpenAI, model: str):
     check_reponse(response, "test_reasoning: ")


+@pytest.mark.skip(reason="https://nvbugs/5753250")
 @pytest.mark.asyncio(loop_scope="module")
 async def test_reasoning_effort(client: openai.AsyncOpenAI, model: str):
     for effort in ["low", "medium", "high"]:
@@ -106,6 +107,7 @@ async def test_reasoning_effort(client: openai.AsyncOpenAI, model: str):
     check_reponse(response, f"test_reasoning_effort_{effort}: ")


+@pytest.mark.skip(reason="https://nvbugs/5753250")
 @pytest.mark.asyncio(loop_scope="module")
 async def test_chat(client: openai.AsyncOpenAI, model: str):
     response = await client.responses.create(model=model,
@@ -150,6 +152,7 @@ def get_current_weather(location: str, format: str = "celsius") -> dict:
True, "temperature": 20 if format == "celsius" else 68} +@pytest.mark.skip(reason="https://nvbugs/5753250") @pytest.mark.asyncio(loop_scope="module") async def test_tool_calls(client: openai.AsyncOpenAI, model: str): if model.startswith("DeepSeek-R1"): @@ -201,6 +204,7 @@ async def test_tool_calls(client: openai.AsyncOpenAI, model: str): check_tool_calling(response, False, "test_tool_calls: ") +@pytest.mark.skip(reason="https://nvbugs/5753250") @pytest.mark.asyncio(loop_scope="module") async def test_streaming(client: openai.AsyncOpenAI, model: str): stream = await client.responses.create( @@ -222,6 +226,7 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str): assert full_reasoning_response +@pytest.mark.skip(reason="https://nvbugs/5753250") @pytest.mark.asyncio(loop_scope="module") async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str): if model.startswith("DeepSeek-R1"):