Merge branch 'main' into fix_spec_gate

Zheyu Fu 2025-12-20 15:39:49 -08:00 committed by GitHub
commit 9ce84b8f3d
19 changed files with 278 additions and 187 deletions

View File

@ -78,9 +78,13 @@ std::optional<uintptr_t> launchHostFunc(
{
auto const stream = reinterpret_cast<cudaStream_t>(streamPtr);
nb::gil_scoped_acquire gil;
auto hostFuncUserData
= std::make_unique<HostFuncUserData>(freeUserData, pyHostFunc, nb::tuple(pyArgs), nb::dict(pyKwargs));
nb::gil_scoped_release release;
cudaError_t err = cudaLaunchHostFunc(stream, cudaHostFuncTrampoline, hostFuncUserData.get());
if (err != cudaSuccess)
{
@ -110,6 +114,7 @@ void initHostFuncBindings(nb::module_& m)
{
m.def("launch_hostfunc", &launchHostFunc, "Launch a Python host function to a CUDA stream",
nb::call_guard<nb::gil_scoped_release>());
m.def("free_hostfunc_user_data", &freeHostFuncUserData, "Free the user data for the Python host function");
m.def("free_hostfunc_user_data", &freeHostFuncUserData, "Free the user data for the Python host function",
nb::call_guard<nb::gil_scoped_release>());
}
} // namespace tensorrt_llm::nanobind::runtime
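Note: the binding change above matters because another thread may need the GIL to run the Python host function while the launching/freeing thread is inside native code. A minimal pure-Python analogy of the hazard that nb::call_guard<nb::gil_scoped_release> (and the explicit gil_scoped_release before cudaLaunchHostFunc) avoids, with a threading.Lock standing in for the GIL; nothing below uses the TensorRT-LLM bindings themselves:

import threading

gil = threading.Lock()          # stands in for the Python GIL
done = threading.Event()

def cuda_callback_thread():
    # Like cudaHostFuncTrampoline, the callback must hold the "GIL" to run Python.
    with gil:
        done.set()

t = threading.Thread(target=cuda_callback_thread)
with gil:
    # The caller holds the "GIL" here; if it blocked on the callback while still
    # inside this block, both sides would deadlock. Releasing it before blocking,
    # which is what the call_guard does for the bindings, avoids that.
    t.start()
done.wait(timeout=5)
t.join()
assert done.is_set()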

View File

@ -78,9 +78,13 @@ std::optional<uintptr_t> launchHostFunc(
{
auto const stream = reinterpret_cast<cudaStream_t>(streamPtr);
py::gil_scoped_acquire gil;
auto hostFuncUserData
= std::make_unique<HostFuncUserData>(freeUserData, pyHostFunc, py::tuple(pyArgs), py::dict(pyKwargs));
py::gil_scoped_release release;
cudaError_t err = cudaLaunchHostFunc(stream, cudaHostFuncTrampoline, hostFuncUserData.get());
if (err != cudaSuccess)
{
@ -110,6 +114,7 @@ void initHostFuncBindings(pybind11::module_& m)
{
m.def("launch_hostfunc", &launchHostFunc, "Launch a Python host function to a CUDA stream",
py::call_guard<py::gil_scoped_release>());
m.def("free_hostfunc_user_data", &freeHostFuncUserData, "Free the user data for the Python host function");
m.def("free_hostfunc_user_data", &freeHostFuncUserData, "Free the user data for the Python host function",
py::call_guard<py::gil_scoped_release>());
}
} // namespace tensorrt_llm::pybind::runtime

View File

@ -2178,32 +2178,6 @@ if IS_CUTLASS_DSL_AVAILABLE:
device=input_scale.device)
return output, output_scale
class FusedMoEInputsHelper:
def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
local_expert_offset: int):
self.num_experts = num_experts
self.top_k = top_k
self.num_local_experts = num_local_experts
self.local_expert_offset = local_expert_offset
def infer_shape_num_tokens(self, input_shapes: List[torch.Size]) -> int:
return input_shapes[0][0]
def inputs_pre_hook(self,
inputs: List[torch.Tensor]) -> List[torch.Tensor]:
x, x_sf, token_selected_experts, token_final_scales, *others = inputs
num_tokens = token_selected_experts.size(0)
new_token_final_scales, new_token_selected_experts = torch.randn(
num_tokens,
self.num_experts,
device=token_selected_experts.device).topk(self.top_k, dim=-1)
new_token_selected_experts = new_token_selected_experts.to(
token_selected_experts.dtype)
new_token_final_scales = new_token_final_scales.softmax(dim=-1).to(
token_final_scales.dtype)
return x, x_sf, new_token_selected_experts, new_token_final_scales, *others
class Sm100BlockScaledFusedMoERunner(TunableRunner):
tuning_config_cache = dict()

View File

@ -35,44 +35,18 @@ import cutlass.pipeline as pipeline
import cutlass.utils as utils
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.utils.blockscaled_layout as blockscaled_utils
from cutlass._mlir.dialects import math, nvvm
from cutlass._mlir.dialects import math
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cute.typing import Float32
from cutlass.cutlass_dsl import T, dsl_user_op
from .custom_pipeline import PipelineCpAsyncUmma
from .utils import is_power_of_2
@dsl_user_op
def fmin(
a: Union[float, Float32], b: Union[float, Float32], *, nan=False, loc=None, ip=None
) -> Float32:
return Float32(
nvvm.fmin(
T.f32(),
Float32(a).ir_value(loc=loc, ip=ip),
Float32(b).ir_value(loc=loc, ip=ip),
nan=nan,
loc=loc,
ip=ip,
)
)
def sigmoid_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
"""
Compute the sigmoid of the input tensor.
"""
return cute.arch.rcp_approx(1.0 + cute.math.exp(-a, fastmath=fastmath))
def silu_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
"""
Compute the silu of the input tensor.
"""
return a * sigmoid_f32(a, fastmath=fastmath)
from .utils import (
TRTLLM_ENABLE_PDL,
fmin,
griddepcontrol_launch_dependents,
griddepcontrol_wait,
is_power_of_2,
silu_f32,
)
"""
High-performance persistent blockscaled contiguous grouped dense GEMM with gather and SwiGLU fusion
@ -819,6 +793,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
smem=self.shared_storage.size_in_bytes(),
stream=stream,
min_blocks_per_mp=1,
use_pdl=TRTLLM_ENABLE_PDL,
)
return
@ -1148,6 +1123,8 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
else:
self.cta_sync_barrier.arrive_and_wait()
griddepcontrol_wait()
#
# Specialized Schedule warp
#
@ -2282,6 +2259,8 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
#
c_pipeline.producer_tail()
griddepcontrol_launch_dependents()
def epilog_tmem_copy_and_partition(
self,
tidx: cutlass.Int32,

View File

@ -52,7 +52,12 @@ import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.utils.blockscaled_layout as blockscaled_utils
from cutlass.cute.nvgpu import cpasync, tcgen05
from .utils import is_power_of_2
from .utils import (
TRTLLM_ENABLE_PDL,
griddepcontrol_launch_dependents,
griddepcontrol_wait,
is_power_of_2,
)
class Sm100BlockScaledContiguousGroupedGemmKernel:
@ -597,6 +602,7 @@ class Sm100BlockScaledContiguousGroupedGemmKernel:
smem=self.shared_storage.size_in_bytes(),
stream=stream,
min_blocks_per_mp=1,
use_pdl=TRTLLM_ENABLE_PDL,
)
return
@ -933,6 +939,8 @@ class Sm100BlockScaledContiguousGroupedGemmKernel:
else:
self.cta_sync_barrier.arrive_and_wait()
griddepcontrol_wait()
#
# Specialized Schedule warp
#
@ -1597,6 +1605,8 @@ class Sm100BlockScaledContiguousGroupedGemmKernel:
#
c_pipeline.producer_tail()
griddepcontrol_launch_dependents()
def epilog_tmem_copy_and_partition(
self,
tidx: cutlass.Int32,

View File

@ -35,11 +35,17 @@ import cutlass.pipeline as pipeline
import cutlass.utils as utils
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.utils.blockscaled_layout as blockscaled_utils
from cutlass._mlir.dialects import llvm
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cutlass_dsl import Int32, T, dsl_user_op
from .utils import is_power_of_2
from .utils import (
TRTLLM_ENABLE_PDL,
atomic_add_func,
griddepcontrol_launch_dependents,
griddepcontrol_wait,
is_power_of_2,
vectorized_atomic_add_bf16x8,
vectorized_atomic_add_fp32x2,
)
"""
High-performance persistent blockscaled contiguous grouped dense GEMM (C = alpha * (SFA * A) * (SFB * B)) example for
@ -259,8 +265,8 @@ def hooked_PersistentTileSchedulerParams_init(
def hooked_get_cluster_work_idx_with_fastdivmod(
self, current_work_linear_idx: Int32, *, loc=None, ip=None
) -> Tuple[Int32, Int32, Int32]:
self, current_work_linear_idx: cutlass.Int32, *, loc=None, ip=None
) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]:
work_iteration, work_unit_id = divmod(current_work_linear_idx, self.params.batch_fdd)
if self.params._raster_along_m:
@ -287,69 +293,6 @@ cutlass.utils.StaticPersistentTileScheduler._get_cluster_work_idx_with_fastdivmo
)
# TODO(zhichenj): try to move these to NVVM wrapper or helper functions
@dsl_user_op
def vectorized_atomic_add_bf16x8(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
llvm.inline_asm(
None,
[
scatter_out_offset.iterator.llvm_ptr,
llvm.bitcast(T.i32(), rOut_epi_packed[0, None].load().ir_value()),
llvm.bitcast(T.i32(), rOut_epi_packed[1, None].load().ir_value()),
llvm.bitcast(T.i32(), rOut_epi_packed[2, None].load().ir_value()),
llvm.bitcast(T.i32(), rOut_epi_packed[3, None].load().ir_value()),
],
"red.global.v4.bf16x2.add.noftz [$0], {$1, $2, $3, $4};",
"l,r,r,r,r",
has_side_effects=True,
)
@dsl_user_op
def vectorized_atomic_add_fp32x2(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
llvm.inline_asm(
None,
[
scatter_out_offset.iterator.llvm_ptr,
rOut_epi_packed[0].ir_value(),
rOut_epi_packed[1].ir_value(),
],
"red.global.v2.f32.add [$0], {$1, $2};",
"l,f,f",
has_side_effects=True,
)
@dsl_user_op
def atomic_add_func(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
if cutlass.const_expr(rOut_epi_packed.dtype == cutlass.Float32):
llvm.inline_asm(
None,
[
scatter_out_offset.iterator.llvm_ptr,
rOut_epi_packed.ir_value(),
],
"red.global.add.f32 [$0], $1;",
"l,f",
has_side_effects=True,
loc=loc,
ip=ip,
)
elif cutlass.const_expr(rOut_epi_packed.dtype == cutlass.BFloat16):
llvm.inline_asm(
None,
[
scatter_out_offset.iterator.llvm_ptr,
llvm.bitcast(T.i16(), rOut_epi_packed.ir_value()),
],
"red.add.noftz.bf16 [$0], $1;",
"l,h",
has_side_effects=True,
loc=loc,
ip=ip,
)
class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
"""This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
@ -931,6 +874,7 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
smem=self.shared_storage.size_in_bytes(),
stream=stream,
min_blocks_per_mp=1,
use_pdl=TRTLLM_ENABLE_PDL,
)
return
@ -1286,6 +1230,8 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
else:
self.cta_sync_barrier.arrive_and_wait()
griddepcontrol_wait()
#
# Specialized Schedule warp
#
@ -1940,6 +1886,8 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
self.epilog_sync_barrier.arrive_and_wait()
tmem.free(tmem_ptr)
griddepcontrol_launch_dependents()
def epilog_tmem_copy_and_partition(
self,
tidx: cutlass.Int32,

View File

@ -35,43 +35,17 @@ import cutlass.pipeline as pipeline
import cutlass.utils as utils
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.utils.blockscaled_layout as blockscaled_utils
from cutlass._mlir.dialects import math, nvvm
from cutlass._mlir.dialects import math
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cute.typing import Float32
from cutlass.cutlass_dsl import T, dsl_user_op
from .utils import is_power_of_2
@dsl_user_op
def fmin(
a: Union[float, Float32], b: Union[float, Float32], *, nan=False, loc=None, ip=None
) -> Float32:
return Float32(
nvvm.fmin(
T.f32(),
Float32(a).ir_value(loc=loc, ip=ip),
Float32(b).ir_value(loc=loc, ip=ip),
nan=nan,
loc=loc,
ip=ip,
)
)
def sigmoid_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
"""
Compute the sigmoid of the input tensor.
"""
return cute.arch.rcp_approx(1.0 + cute.math.exp(-a, fastmath=fastmath))
def silu_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
"""
Compute the silu of the input tensor.
"""
return a * sigmoid_f32(a, fastmath=fastmath)
from .utils import (
TRTLLM_ENABLE_PDL,
fmin,
griddepcontrol_launch_dependents,
griddepcontrol_wait,
is_power_of_2,
silu_f32,
)
"""
High-performance persistent blockscaled contiguous grouped dense GEMM (C = alpha * (SFA * A) * (SFB * B)) example for
@ -749,6 +723,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
smem=self.shared_storage.size_in_bytes(),
stream=stream,
min_blocks_per_mp=1,
use_pdl=TRTLLM_ENABLE_PDL,
)
return
@ -1087,6 +1062,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
else:
self.cta_sync_barrier.arrive_and_wait()
griddepcontrol_wait()
#
# Specialized Schedule warp
#
@ -1949,6 +1926,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
#
c_pipeline.producer_tail()
griddepcontrol_launch_dependents()
def epilog_tmem_copy_and_partition(
self,
tidx: cutlass.Int32,

View File

@ -55,7 +55,8 @@ import cutlass.utils.blockscaled_layout as blockscaled_utils
from cutlass.cute.nvgpu import cpasync, tcgen05
from .custom_pipeline import PipelineTmaUmma, PipelineUmmaAsync
from .utils import is_power_of_2
from .utils import (TRTLLM_ENABLE_PDL, griddepcontrol_launch_dependents,
griddepcontrol_wait, is_power_of_2)
class Sm100BlockScaledPersistentDenseGemmKernel:
@ -578,6 +579,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
smem=self.shared_storage.size_in_bytes(),
min_blocks_per_mp=1,
stream=stream,
use_pdl=TRTLLM_ENABLE_PDL,
)
return
@ -869,6 +871,8 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
cute.arch.barrier(barrier_id=self.cta_sync_bar_id,
number_of_threads=self.threads_per_cta)
griddepcontrol_wait()
#
# Specialized TMA load warp
#
@ -1473,6 +1477,8 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
#
c_pipeline.producer_tail()
griddepcontrol_launch_dependents()
def mainloop_s2t_copy_and_partition(
self,
sSF: cute.Tensor,

View File

@ -44,11 +44,18 @@
# This file is copied and modified from cutlass https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/core.py
import ctypes
import os
from typing import Union
import cutlass
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass.cute as cute
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm, nvvm
from cutlass.cute.typing import AddressSpace, Numeric, Pointer, Type
from cutlass.cutlass_dsl import T, dsl_user_op
TRTLLM_ENABLE_PDL = os.environ.get("TRTLLM_ENABLE_PDL", "0") == "1"
# WAR for CuTeDSL make_ptr implementation
@ -190,3 +197,150 @@ def make_ptr(
def is_power_of_2(x: int) -> bool:
return x > 0 and (x & (x - 1)) == 0
@dsl_user_op
def fmin(a: Union[float, cutlass.Float32],
b: Union[float, cutlass.Float32],
*,
nan=False,
loc=None,
ip=None) -> cutlass.Float32:
return cutlass.Float32(
nvvm.fmin(
T.f32(),
cutlass.Float32(a).ir_value(loc=loc, ip=ip),
cutlass.Float32(b).ir_value(loc=loc, ip=ip),
nan=nan,
loc=loc,
ip=ip,
))
def sigmoid_f32(a: Union[float, cutlass.Float32],
fastmath: bool = False) -> Union[float, cutlass.Float32]:
"""
Compute the sigmoid of the input value.
"""
return cute.arch.rcp_approx(1.0 + cute.math.exp(-a, fastmath=fastmath))
def silu_f32(a: Union[float, cutlass.Float32],
fastmath: bool = False) -> Union[float, cutlass.Float32]:
"""
Compute the SiLU of the input value.
"""
return a * sigmoid_f32(a, fastmath=fastmath)
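For reference, these helpers implement the standard scalar formulas sigmoid(x) = 1 / (1 + exp(-x)) and silu(x) = x * sigmoid(x); the DSL version goes through cute.arch.rcp_approx, so it is an approximation of the exact reciprocal. A small plain-Python sanity check of the math (not the DSL code itself):

import math

def sigmoid_ref(x: float) -> float:
    # Exact scalar sigmoid; the kernel helper approximates the reciprocal.
    return 1.0 / (1.0 + math.exp(-x))

def silu_ref(x: float) -> float:
    return x * sigmoid_ref(x)

for v in (-4.0, -1.0, 0.0, 1.0, 4.0):
    assert abs(silu_ref(v) - v / (1.0 + math.exp(-v))) < 1e-12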
# TODO(zhichenj): try to move these to NVVM wrapper or helper functions
@dsl_user_op
def vectorized_atomic_add_bf16x8(rOut_epi_packed,
scatter_out_offset,
loc=None,
ip=None):
llvm.inline_asm(
None,
[
scatter_out_offset.iterator.llvm_ptr,
llvm.bitcast(T.i32(), rOut_epi_packed[0, None].load().ir_value()),
llvm.bitcast(T.i32(), rOut_epi_packed[1, None].load().ir_value()),
llvm.bitcast(T.i32(), rOut_epi_packed[2, None].load().ir_value()),
llvm.bitcast(T.i32(), rOut_epi_packed[3, None].load().ir_value()),
],
"red.global.v4.bf16x2.add.noftz [$0], {$1, $2, $3, $4};",
"l,r,r,r,r",
has_side_effects=True,
loc=loc,
ip=ip,
)
@dsl_user_op
def vectorized_atomic_add_fp32x2(rOut_epi_packed,
scatter_out_offset,
loc=None,
ip=None):
llvm.inline_asm(
None,
[
scatter_out_offset.iterator.llvm_ptr,
rOut_epi_packed[0].ir_value(),
rOut_epi_packed[1].ir_value(),
],
"red.global.v2.f32.add [$0], {$1, $2};",
"l,f,f",
has_side_effects=True,
loc=loc,
ip=ip,
)
@dsl_user_op
def atomic_add_func(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
if cutlass.const_expr(rOut_epi_packed.dtype == cutlass.Float32):
llvm.inline_asm(
None,
[
scatter_out_offset.iterator.llvm_ptr,
rOut_epi_packed.ir_value(),
],
"red.global.add.f32 [$0], $1;",
"l,f",
has_side_effects=True,
loc=loc,
ip=ip,
)
elif cutlass.const_expr(rOut_epi_packed.dtype == cutlass.BFloat16):
llvm.inline_asm(
None,
[
scatter_out_offset.iterator.llvm_ptr,
llvm.bitcast(T.i16(), rOut_epi_packed.ir_value()),
],
"red.add.noftz.bf16 [$0], $1;",
"l,h",
has_side_effects=True,
loc=loc,
ip=ip,
)
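The inline-PTX wrappers above issue red.global atomic adds (scalar f32/bf16 plus the vectorized f32x2 and bf16x2x4 forms) so that the finalize-fusion epilogue can accumulate partial results directly at a scattered output offset in global memory. A hedged host-side illustration of that accumulate semantics using plain torch; variable names are illustrative and this shows the effect, not the kernel code:

import torch

num_tokens, hidden = 4, 8
out = torch.zeros(num_tokens, hidden)             # destination rows
contrib = torch.randn(6, hidden)                  # per-expanded-row partial results
row_to_token = torch.tensor([0, 0, 1, 2, 3, 3])   # destination row for each contribution

# index_add_ accumulates overlapping rows, matching the read-modify-write adds
# that the red.global.*.add reductions perform on device.
out.index_add_(0, row_to_token, contrib)

ref = torch.zeros_like(out)
for r, t in enumerate(row_to_token.tolist()):
    ref[t] += contrib[r]
assert torch.allclose(out, ref)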
@dsl_user_op
def griddepcontrol_wait(*, loc=None, ip=None) -> None:
"""
Wait for the previous kernel's grid to complete (all blocks of the previous
kernel have finished and their memory writes are flushed); instructions after
this one will not be issued until the previous grid has finished.
"""
llvm.inline_asm(
res=None,
operands_=[],
asm_string="griddepcontrol.wait;",
constraints="",
has_side_effects=True,
asm_dialect=llvm.AsmDialect.AD_ATT,
loc=loc,
ip=ip,
)
@dsl_user_op
def griddepcontrol_launch_dependents(*, loc=None, ip=None) -> None:
"""
Issuing the launch_dependents instruction hints that the dependent kernel may be
launched earlier. It affects performance, not functionality: launching the
dependent kernel too early competes with the current kernel for resources, while
launching it too late adds latency.
"""
llvm.inline_asm(
res=None,
operands_=[],
asm_string="griddepcontrol.launch_dependents;",
constraints="",
has_side_effects=True,
asm_dialect=llvm.AsmDialect.AD_ATT,
loc=loc,
ip=ip,
)
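Together with TRTLLM_ENABLE_PDL, these two wrappers give the kernels in this change programmatic dependent launch: each kernel is launched with use_pdl=TRTLLM_ENABLE_PDL, calls griddepcontrol_wait() right after its CTA-wide sync barrier so it does not consume the previous grid's outputs too early, and calls griddepcontrol_launch_dependents() at the end of its epilogue once the final stores are committed. The feature is opt-in through an environment variable that is read once at import time; a minimal sketch of enabling it (the commented import path is an assumption, not the real module name):

import os

# Must be set before the CuTe DSL kernels are imported, because
# TRTLLM_ENABLE_PDL is evaluated once at module import.
os.environ["TRTLLM_ENABLE_PDL"] = "1"
assert os.environ.get("TRTLLM_ENABLE_PDL", "0") == "1"   # mirrors the gate above

# from tensorrt_llm._torch... import <kernels>   # hypothetical import path;
# kernels imported afterwards launch with use_pdl=True and emit
# griddepcontrol.wait / griddepcontrol.launch_dependents in their bodies.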

View File

@ -1452,12 +1452,13 @@ class DeepseekV3Model(DecoderModel):
config = model_config.pretrained_config
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
aux_stream_list = [torch.cuda.Stream() for _ in range(4)]
self.aux_stream_dict = {
AuxStreamType.Attention: aux_stream_list[0],
AuxStreamType.MoeShared: aux_stream_list[0],
AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
AuxStreamType.MoeBalancer: aux_stream_list[2],
AuxStreamType.MoeOutputMemset: aux_stream_list[3],
}
self.embed_tokens = Embedding(

View File

@ -890,12 +890,13 @@ class Glm4Model(DecoderModel):
config = model_config.pretrained_config
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
aux_stream_list = [torch.cuda.Stream() for _ in range(4)]
self.aux_stream_dict = {
AuxStreamType.Attention: aux_stream_list[0],
AuxStreamType.MoeShared: aux_stream_list[0],
AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
AuxStreamType.MoeBalancer: aux_stream_list[2],
AuxStreamType.MoeOutputMemset: aux_stream_list[3],
}
self.embed_tokens = Embedding(

View File

@ -327,6 +327,7 @@ class Qwen3MoEModel(DecoderModel):
self.aux_stream_dict = {
AuxStreamType.MoeChunkingOverlap: torch.cuda.Stream(),
AuxStreamType.MoeBalancer: torch.cuda.Stream(),
AuxStreamType.MoeOutputMemset: torch.cuda.Stream(),
}
self.preload_weight_modules = []
if config.moe_backend == "TRTLLM":

View File

@ -487,11 +487,13 @@ class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel,
def __init__(self, model_config: ModelConfig[PretrainedConfig]):
self.model_config = model_config
aux_stream_list = [torch.cuda.Stream() for _ in range(2)]
aux_stream_list = [torch.cuda.Stream() for _ in range(4)]
self.aux_stream_dict = {
AuxStreamType.Attention: aux_stream_list[0],
AuxStreamType.MoeShared: aux_stream_list[0],
AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
AuxStreamType.MoeBalancer: aux_stream_list[2],
AuxStreamType.MoeOutputMemset: aux_stream_list[3],
}
super().__init__(
MTPDraftModel(self.model_config,

View File

@ -175,7 +175,6 @@ class ConfigurableMoE(MoE):
assert not torch.compiler.is_compiling(), (
"Backend should not be none if not in torch.compile"
)
self.backend.aux_stream_dict = self.aux_stream_dict
self.backend.layer_idx = self.layer_idx
self.backend.layer_idx_str = self.layer_idx_str
self.backend.num_slots = self.num_slots

View File

@ -8,7 +8,7 @@ from tensorrt_llm._utils import is_sm_100f
from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import AuxStreamType, Fp4QuantizedTensor
from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .interface import AlltoallMethodType
from .quantization import MoEWeightLoadingMode, NVFP4CuteDslFusedMoEMethod
@ -183,7 +183,6 @@ class CuteDslFusedMoE(CutlassFusedMoE):
init_load_balancer: bool = True,
without_comm: bool = False,
):
super().__init__(
routing_method=routing_method,
num_experts=num_experts,
@ -199,6 +198,16 @@ class CuteDslFusedMoE(CutlassFusedMoE):
init_load_balancer=init_load_balancer,
without_comm=without_comm,
)
if self.aux_stream_dict is None:
self.aux_stream_dict = aux_stream_dict if aux_stream_dict is not None else {}
if AuxStreamType.MoeOutputMemset not in self.aux_stream_dict:
self.aux_stream_dict[
AuxStreamType.MoeOutputMemset] = torch.cuda.Stream()
if self.event_dict is None:
self.event_dict = {}
for key in [EventType.Main, EventType.MoeOutputMemset]:
if key not in self.event_dict:
self.event_dict[key] = torch.cuda.Event()
def select_alltoall_method_type(self) -> AlltoallMethodType:
return AlltoallMethodType.NotEnabled
@ -288,6 +297,11 @@ class CuteDslFusedMoE(CutlassFusedMoE):
self.hidden_size)
assert moe_output.dtype == output_dtype
if self.use_fused_finalize:
self.event_dict[EventType.Main].record()
moe_output.record_stream(
self.aux_stream_dict[AuxStreamType.MoeOutputMemset])
x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell(
input=x.view(torch.float4_e2m1fn_x2),
weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2),
@ -307,17 +321,24 @@ class CuteDslFusedMoE(CutlassFusedMoE):
)
if self.use_fused_finalize:
torch.ops.trtllm.moe_output_memset_inplace(
input=moe_output,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
tile_tokens_dim=tile_size,
top_k=self.routing_method.experts_per_token,
ep_size=self.mapping.moe_ep_size,
enable_alltoall=enable_alltoall,
)
with torch.cuda.stream(
self.aux_stream_dict[AuxStreamType.MoeOutputMemset]):
self.event_dict[EventType.Main].wait()
torch.ops.trtllm.moe_output_memset_inplace(
input=moe_output,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
tile_tokens_dim=tile_size,
top_k=self.routing_method.experts_per_token,
ep_size=self.mapping.moe_ep_size,
enable_alltoall=enable_alltoall,
)
self.event_dict[EventType.MoeOutputMemset].record()
self.event_dict[EventType.MoeOutputMemset].wait()
torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell(
input=x.view(torch.float4_e2m1fn_x2),
weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
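The rewritten block above overlaps the output memset with the SwiGLU grouped GEMM: the Main event is recorded before that GEMM is launched, the memset runs on the MoeOutputMemset aux stream after waiting on that event, and the main stream waits on the memset event before the finalize GEMM accumulates into moe_output. A minimal sketch of the same stream/event handshake, with plain torch ops standing in for the custom kernels (requires a CUDA device; shapes and names are illustrative):

import torch

main = torch.cuda.current_stream()
aux = torch.cuda.Stream()
ev_main, ev_memset = torch.cuda.Event(), torch.cuda.Event()

moe_output = torch.empty(1024, 4096, device="cuda")

ev_main.record(main)              # everything that produced moe_output so far
moe_output.record_stream(aux)     # tell the allocator the aux stream uses this tensor
with torch.cuda.stream(aux):
    ev_main.wait()                # aux work starts only after main's earlier work
    moe_output.zero_()            # stand-in for moe_output_memset_inplace
    ev_memset.record()
ev_memset.wait()                  # main stream resumes only after the memset
# ... the finalize grouped GEMM that accumulates into moe_output runs here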

View File

@ -21,6 +21,7 @@ aux_stream_name_list = [
'MoeShared',
'MoeChunkingOverlap',
'MoeBalancer',
'MoeOutputMemset',
]
AuxStreamType = Enum(
'AuxStreamType',
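Because AuxStreamType is generated from aux_stream_name_list, adding 'MoeOutputMemset' to the list is all that is needed for AuxStreamType.MoeOutputMemset to exist. A tiny reproduction of that pattern; the list below is abbreviated and illustrative, and the construction mirrors (but is not copied from) the real code:

from enum import Enum

aux_stream_name_list = ['MoeShared', 'MoeChunkingOverlap', 'MoeBalancer', 'MoeOutputMemset']
AuxStreamType = Enum('AuxStreamType', aux_stream_name_list)

assert AuxStreamType.MoeOutputMemset.name == 'MoeOutputMemset'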

View File

@ -330,7 +330,6 @@ accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False] SKIP
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True] SKIP (https://nvbugs/5648560)
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-False] SKIP (https://nvbugs/5648560)
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen_adp_lmtp] SKIP (https://nvbugs/5629136)
unittest/bindings/test_hostfunc.py::test_hostfunc SKIP (https://nvbugs/5643631)
examples/test_multimodal.py::test_llm_multimodal_general[nougat-base-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5568052)
accuracy/test_llm_api_pytorch_multimodal.py::TestNVILA_8B::test_auto_dtype SKIP (https://nvbugs/5648441)
accuracy/test_llm_api_pytorch_multimodal.py::TestVILA1_5_3B::test_auto_dtype SKIP (https://nvbugs/5648441)

View File

@ -15,6 +15,7 @@ def test_hostfunc():
with torch.cuda.stream(stream):
for _ in range(5):
increase(x)
torch.cuda.synchronize()
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, stream=stream):
@ -25,7 +26,7 @@ def test_hostfunc():
with torch.cuda.stream(stream):
for _ in range(10):
g.replay()
torch.cuda.synchronize()
assert (x == 25).all().item()
assert len(HOSTFUNC_USER_DATA_HANDLES) == 2

View File

@ -94,6 +94,7 @@ async def test_reasoning(client: openai.AsyncOpenAI, model: str):
check_reponse(response, "test_reasoning: ")
@pytest.mark.skip(reason="https://nvbugs/5753250")
@pytest.mark.asyncio(loop_scope="module")
async def test_reasoning_effort(client: openai.AsyncOpenAI, model: str):
for effort in ["low", "medium", "high"]:
@ -106,6 +107,7 @@ async def test_reasoning_effort(client: openai.AsyncOpenAI, model: str):
check_reponse(response, f"test_reasoning_effort_{effort}: ")
@pytest.mark.skip(reason="https://nvbugs/5753250")
@pytest.mark.asyncio(loop_scope="module")
async def test_chat(client: openai.AsyncOpenAI, model: str):
response = await client.responses.create(model=model,
@ -150,6 +152,7 @@ def get_current_weather(location: str, format: str = "celsius") -> dict:
return {"sunny": True, "temperature": 20 if format == "celsius" else 68}
@pytest.mark.skip(reason="https://nvbugs/5753250")
@pytest.mark.asyncio(loop_scope="module")
async def test_tool_calls(client: openai.AsyncOpenAI, model: str):
if model.startswith("DeepSeek-R1"):
@ -201,6 +204,7 @@ async def test_tool_calls(client: openai.AsyncOpenAI, model: str):
check_tool_calling(response, False, "test_tool_calls: ")
@pytest.mark.skip(reason="https://nvbugs/5753250")
@pytest.mark.asyncio(loop_scope="module")
async def test_streaming(client: openai.AsyncOpenAI, model: str):
stream = await client.responses.create(
@ -222,6 +226,7 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str):
assert full_reasoning_response
@pytest.mark.skip(reason="https://nvbugs/5753250")
@pytest.mark.asyncio(loop_scope="module")
async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
if model.startswith("DeepSeek-R1"):