Merge branch 'main' into fix_spec_gate
commit 9ce84b8f3d
@@ -78,9 +78,13 @@ std::optional<uintptr_t> launchHostFunc(
 {
     auto const stream = reinterpret_cast<cudaStream_t>(streamPtr);
 
+    nb::gil_scoped_acquire gil;
+
     auto hostFuncUserData
         = std::make_unique<HostFuncUserData>(freeUserData, pyHostFunc, nb::tuple(pyArgs), nb::dict(pyKwargs));
 
+    nb::gil_scoped_release release;
+
     cudaError_t err = cudaLaunchHostFunc(stream, cudaHostFuncTrampoline, hostFuncUserData.get());
     if (err != cudaSuccess)
     {
@@ -110,6 +114,7 @@ void initHostFuncBindings(nb::module_& m)
 {
     m.def("launch_hostfunc", &launchHostFunc, "Launch a Python host function to a CUDA stream",
         nb::call_guard<nb::gil_scoped_release>());
-    m.def("free_hostfunc_user_data", &freeHostFuncUserData, "Free the user data for the Python host function");
+    m.def("free_hostfunc_user_data", &freeHostFuncUserData, "Free the user data for the Python host function",
+        nb::call_guard<nb::gil_scoped_release>());
 }
 } // namespace tensorrt_llm::nanobind::runtime

@@ -78,9 +78,13 @@ std::optional<uintptr_t> launchHostFunc(
 {
     auto const stream = reinterpret_cast<cudaStream_t>(streamPtr);
 
+    py::gil_scoped_acquire gil;
+
     auto hostFuncUserData
         = std::make_unique<HostFuncUserData>(freeUserData, pyHostFunc, py::tuple(pyArgs), py::dict(pyKwargs));
 
+    py::gil_scoped_release release;
+
     cudaError_t err = cudaLaunchHostFunc(stream, cudaHostFuncTrampoline, hostFuncUserData.get());
     if (err != cudaSuccess)
     {
@@ -110,6 +114,7 @@ void initHostFuncBindings(pybind11::module_& m)
 {
     m.def("launch_hostfunc", &launchHostFunc, "Launch a Python host function to a CUDA stream",
         py::call_guard<py::gil_scoped_release>());
-    m.def("free_hostfunc_user_data", &freeHostFuncUserData, "Free the user data for the Python host function");
+    m.def("free_hostfunc_user_data", &freeHostFuncUserData, "Free the user data for the Python host function",
+        py::call_guard<py::gil_scoped_release>());
 }
 } // namespace tensorrt_llm::pybind::runtime

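Both bindings now hold the GIL only while touching Python objects and release it around the CUDA calls; without that, a thread that keeps the GIL while waiting on work that itself needs the GIL hangs. A pure-Python sketch of that deadlock shape (an illustrative analogy, not the bindings' actual code path; the lock stands in for the GIL):

    import threading

    gil = threading.Lock()
    done = threading.Event()

    def host_func():            # stand-in for the CUDA host-function trampoline,
        with gil:               # which must take the "GIL" to run Python code
            done.set()

    worker = threading.Thread(target=host_func)
    with gil:                   # a binding without the call_guard keeps the lock...
        worker.start()
        finished = done.wait(timeout=0.2)   # ...while waiting -> would deadlock
    assert not finished         # host_func never got the lock while we waited
    worker.join()               # released now, so the worker can finish
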
@@ -2178,32 +2178,6 @@ if IS_CUTLASS_DSL_AVAILABLE:
             device=input_scale.device)
         return output, output_scale
 
-    class FusedMoEInputsHelper:
-
-        def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
-                     local_expert_offset: int):
-            self.num_experts = num_experts
-            self.top_k = top_k
-            self.num_local_experts = num_local_experts
-            self.local_expert_offset = local_expert_offset
-
-        def infer_shape_num_tokens(self, input_shapes: List[torch.Size]) -> int:
-            return input_shapes[0][0]
-
-        def inputs_pre_hook(self,
-                            inputs: List[torch.Tensor]) -> List[torch.Tensor]:
-            x, x_sf, token_selected_experts, token_final_scales, *others = inputs
-            num_tokens = token_selected_experts.size(0)
-            new_token_final_scales, new_token_selected_experts = torch.randn(
-                num_tokens,
-                self.num_experts,
-                device=token_selected_experts.device).topk(self.top_k, dim=-1)
-            new_token_selected_experts = new_token_selected_experts.to(
-                token_selected_experts.dtype)
-            new_token_final_scales = new_token_final_scales.softmax(dim=-1).to(
-                token_final_scales.dtype)
-            return x, x_sf, new_token_selected_experts, new_token_final_scales, *others
-
     class Sm100BlockScaledFusedMoERunner(TunableRunner):
         tuning_config_cache = dict()
 
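For reference, the removed FusedMoEInputsHelper.inputs_pre_hook regenerated valid-but-random routing tensors for autotuning runs. A standalone sketch of the same recipe (hypothetical sizes, CPU tensors for illustration):

    import torch

    # Draw random router logits, keep the top-k experts per token, and
    # renormalize the kept scores with softmax -- exactly the shape of the
    # routing inputs the tuner needs.
    num_tokens, num_experts, top_k = 8, 32, 4
    logits = torch.randn(num_tokens, num_experts)
    scales, experts = logits.topk(top_k, dim=-1)   # (values, indices) per token
    scales = scales.softmax(dim=-1)                # final per-expert weights
    assert experts.shape == (num_tokens, top_k)
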
@@ -35,44 +35,18 @@ import cutlass.pipeline as pipeline
 import cutlass.utils as utils
 import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
-from cutlass._mlir.dialects import math, nvvm
+from cutlass._mlir.dialects import math
 from cutlass.cute.nvgpu import cpasync, tcgen05
-from cutlass.cute.typing import Float32
-from cutlass.cutlass_dsl import T, dsl_user_op
 
 from .custom_pipeline import PipelineCpAsyncUmma
-from .utils import is_power_of_2
-
-
-@dsl_user_op
-def fmin(
-    a: Union[float, Float32], b: Union[float, Float32], *, nan=False, loc=None, ip=None
-) -> Float32:
-    return Float32(
-        nvvm.fmin(
-            T.f32(),
-            Float32(a).ir_value(loc=loc, ip=ip),
-            Float32(b).ir_value(loc=loc, ip=ip),
-            nan=nan,
-            loc=loc,
-            ip=ip,
-        )
-    )
-
-
-def sigmoid_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
-    """
-    Compute the sigmoid of the input tensor.
-    """
-    return cute.arch.rcp_approx(1.0 + cute.math.exp(-a, fastmath=fastmath))
-
-
-def silu_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
-    """
-    Compute the silu of the input tensor.
-    """
-    return a * sigmoid_f32(a, fastmath=fastmath)
+from .utils import (
+    TRTLLM_ENABLE_PDL,
+    fmin,
+    griddepcontrol_launch_dependents,
+    griddepcontrol_wait,
+    is_power_of_2,
+    silu_f32,
+)
 
 """
 High-performance persistent blockscaled contiguous grouped dense GEMM with gather and SwiGLU fusion
@@ -819,6 +793,7 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
             smem=self.shared_storage.size_in_bytes(),
             stream=stream,
             min_blocks_per_mp=1,
+            use_pdl=TRTLLM_ENABLE_PDL,
         )
         return
@@ -1148,6 +1123,8 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
         else:
            self.cta_sync_barrier.arrive_and_wait()
 
+        griddepcontrol_wait()
+
        #
        # Specialized Schedule warp
        #
@@ -2282,6 +2259,8 @@ class BlockScaledContiguousGatherGroupedGemmKernel:
        #
        c_pipeline.producer_tail()
 
+        griddepcontrol_launch_dependents()
+
    def epilog_tmem_copy_and_partition(
        self,
        tidx: cutlass.Int32,

@@ -52,7 +52,12 @@ import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
 from cutlass.cute.nvgpu import cpasync, tcgen05
 
-from .utils import is_power_of_2
+from .utils import (
+    TRTLLM_ENABLE_PDL,
+    griddepcontrol_launch_dependents,
+    griddepcontrol_wait,
+    is_power_of_2,
+)
 
 
 class Sm100BlockScaledContiguousGroupedGemmKernel:
@@ -597,6 +602,7 @@ class Sm100BlockScaledContiguousGroupedGemmKernel:
            smem=self.shared_storage.size_in_bytes(),
            stream=stream,
            min_blocks_per_mp=1,
+            use_pdl=TRTLLM_ENABLE_PDL,
        )
        return
@@ -933,6 +939,8 @@ class Sm100BlockScaledContiguousGroupedGemmKernel:
        else:
            self.cta_sync_barrier.arrive_and_wait()
 
+        griddepcontrol_wait()
+
        #
        # Specialized Schedule warp
        #
@@ -1597,6 +1605,8 @@ class Sm100BlockScaledContiguousGroupedGemmKernel:
        #
        c_pipeline.producer_tail()
 
+        griddepcontrol_launch_dependents()
+
    def epilog_tmem_copy_and_partition(
        self,
        tidx: cutlass.Int32,

@@ -35,11 +35,17 @@ import cutlass.pipeline as pipeline
 import cutlass.utils as utils
 import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
-from cutlass._mlir.dialects import llvm
 from cutlass.cute.nvgpu import cpasync, tcgen05
-from cutlass.cutlass_dsl import Int32, T, dsl_user_op
 
-from .utils import is_power_of_2
+from .utils import (
+    TRTLLM_ENABLE_PDL,
+    atomic_add_func,
+    griddepcontrol_launch_dependents,
+    griddepcontrol_wait,
+    is_power_of_2,
+    vectorized_atomic_add_bf16x8,
+    vectorized_atomic_add_fp32x2,
+)
 
 """
 High-performance persistent blockscaled contiguous grouped dense GEMM (C = alpha * (SFA * A) * (SFB * B)) example for
@@ -259,8 +265,8 @@ def hooked_PersistentTileSchedulerParams_init(
 
 
 def hooked_get_cluster_work_idx_with_fastdivmod(
-    self, current_work_linear_idx: Int32, *, loc=None, ip=None
-) -> Tuple[Int32, Int32, Int32]:
+    self, current_work_linear_idx: cutlass.Int32, *, loc=None, ip=None
+) -> Tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]:
    work_iteration, work_unit_id = divmod(current_work_linear_idx, self.params.batch_fdd)
 
    if self.params._raster_along_m:
@@ -287,69 +293,6 @@ cutlass.utils.StaticPersistentTileScheduler._get_cluster_work_idx_with_fastdivmod
 )
 
 
-# TODO(zhichenj): try to move these to NVVM wrapper or helper functions
-@dsl_user_op
-def vectorized_atomic_add_bf16x8(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
-    llvm.inline_asm(
-        None,
-        [
-            scatter_out_offset.iterator.llvm_ptr,
-            llvm.bitcast(T.i32(), rOut_epi_packed[0, None].load().ir_value()),
-            llvm.bitcast(T.i32(), rOut_epi_packed[1, None].load().ir_value()),
-            llvm.bitcast(T.i32(), rOut_epi_packed[2, None].load().ir_value()),
-            llvm.bitcast(T.i32(), rOut_epi_packed[3, None].load().ir_value()),
-        ],
-        "red.global.v4.bf16x2.add.noftz [$0], {$1, $2, $3, $4};",
-        "l,r,r,r,r",
-        has_side_effects=True,
-    )
-
-
-@dsl_user_op
-def vectorized_atomic_add_fp32x2(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
-    llvm.inline_asm(
-        None,
-        [
-            scatter_out_offset.iterator.llvm_ptr,
-            rOut_epi_packed[0].ir_value(),
-            rOut_epi_packed[1].ir_value(),
-        ],
-        "red.global.v2.f32.add [$0], {$1, $2};",
-        "l,f,f",
-        has_side_effects=True,
-    )
-
-
-@dsl_user_op
-def atomic_add_func(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
-    if cutlass.const_expr(rOut_epi_packed.dtype == cutlass.Float32):
-        llvm.inline_asm(
-            None,
-            [
-                scatter_out_offset.iterator.llvm_ptr,
-                rOut_epi_packed.ir_value(),
-            ],
-            "red.global.add.f32 [$0], $1;",
-            "l,f",
-            has_side_effects=True,
-            loc=loc,
-            ip=ip,
-        )
-    elif cutlass.const_expr(rOut_epi_packed.dtype == cutlass.BFloat16):
-        llvm.inline_asm(
-            None,
-            [
-                scatter_out_offset.iterator.llvm_ptr,
-                llvm.bitcast(T.i16(), rOut_epi_packed.ir_value()),
-            ],
-            "red.add.noftz.bf16 [$0], $1;",
-            "l,h",
-            has_side_effects=True,
-            loc=loc,
-            ip=ip,
-        )
-
-
 class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
    """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
    and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
@@ -931,6 +874,7 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
            smem=self.shared_storage.size_in_bytes(),
            stream=stream,
            min_blocks_per_mp=1,
+            use_pdl=TRTLLM_ENABLE_PDL,
        )
        return
@@ -1286,6 +1230,8 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
        else:
            self.cta_sync_barrier.arrive_and_wait()
 
+        griddepcontrol_wait()
+
        #
        # Specialized Schedule warp
        #
@@ -1940,6 +1886,8 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
        self.epilog_sync_barrier.arrive_and_wait()
        tmem.free(tmem_ptr)
 
+        griddepcontrol_launch_dependents()
+
    def epilog_tmem_copy_and_partition(
        self,
        tidx: cutlass.Int32,

@@ -35,43 +35,17 @@ import cutlass.pipeline as pipeline
 import cutlass.utils as utils
 import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
-from cutlass._mlir.dialects import math, nvvm
+from cutlass._mlir.dialects import math
 from cutlass.cute.nvgpu import cpasync, tcgen05
-from cutlass.cute.typing import Float32
-from cutlass.cutlass_dsl import T, dsl_user_op
 
-from .utils import is_power_of_2
-
-
-@dsl_user_op
-def fmin(
-    a: Union[float, Float32], b: Union[float, Float32], *, nan=False, loc=None, ip=None
-) -> Float32:
-    return Float32(
-        nvvm.fmin(
-            T.f32(),
-            Float32(a).ir_value(loc=loc, ip=ip),
-            Float32(b).ir_value(loc=loc, ip=ip),
-            nan=nan,
-            loc=loc,
-            ip=ip,
-        )
-    )
-
-
-def sigmoid_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
-    """
-    Compute the sigmoid of the input tensor.
-    """
-    return cute.arch.rcp_approx(1.0 + cute.math.exp(-a, fastmath=fastmath))
-
-
-def silu_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]:
-    """
-    Compute the silu of the input tensor.
-    """
-    return a * sigmoid_f32(a, fastmath=fastmath)
+from .utils import (
+    TRTLLM_ENABLE_PDL,
+    fmin,
+    griddepcontrol_launch_dependents,
+    griddepcontrol_wait,
+    is_power_of_2,
+    silu_f32,
+)
 
 """
 High-performance persistent blockscaled contiguous grouped dense GEMM (C = alpha * (SFA * A) * (SFB * B)) example for
@@ -749,6 +723,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
            smem=self.shared_storage.size_in_bytes(),
            stream=stream,
            min_blocks_per_mp=1,
+            use_pdl=TRTLLM_ENABLE_PDL,
        )
        return
@@ -1087,6 +1062,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
        else:
            self.cta_sync_barrier.arrive_and_wait()
 
+        griddepcontrol_wait()
+
        #
        # Specialized Schedule warp
        #
@@ -1949,6 +1926,8 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
        #
        c_pipeline.producer_tail()
 
+        griddepcontrol_launch_dependents()
+
    def epilog_tmem_copy_and_partition(
        self,
        tidx: cutlass.Int32,

@@ -55,7 +55,8 @@ import cutlass.utils.blockscaled_layout as blockscaled_utils
 from cutlass.cute.nvgpu import cpasync, tcgen05
 
 from .custom_pipeline import PipelineTmaUmma, PipelineUmmaAsync
-from .utils import is_power_of_2
+from .utils import (TRTLLM_ENABLE_PDL, griddepcontrol_launch_dependents,
+                    griddepcontrol_wait, is_power_of_2)
 
 
 class Sm100BlockScaledPersistentDenseGemmKernel:
@@ -578,6 +579,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
            smem=self.shared_storage.size_in_bytes(),
            min_blocks_per_mp=1,
            stream=stream,
+            use_pdl=TRTLLM_ENABLE_PDL,
        )
        return
@@ -869,6 +871,8 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
        cute.arch.barrier(barrier_id=self.cta_sync_bar_id,
                          number_of_threads=self.threads_per_cta)
 
+        griddepcontrol_wait()
+
        #
        # Specialized TMA load warp
        #
@@ -1473,6 +1477,8 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
        #
        c_pipeline.producer_tail()
 
+        griddepcontrol_launch_dependents()
+
    def mainloop_s2t_copy_and_partition(
        self,
        sSF: cute.Tensor,

@@ -44,11 +44,18 @@
 # This file is copied and modified from cutlass https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/core.py
 
 import ctypes
+import os
+from typing import Union
 
 import cutlass
 import cutlass._mlir.dialects.cute as _cute_ir
+import cutlass.cute as cute
 from cutlass._mlir import ir
+from cutlass._mlir.dialects import llvm, nvvm
 from cutlass.cute.typing import AddressSpace, Numeric, Pointer, Type
+from cutlass.cutlass_dsl import T, dsl_user_op
+
+TRTLLM_ENABLE_PDL = os.environ.get("TRTLLM_ENABLE_PDL", "0") == "1"
 
 
 # WAR for CuTeDSL make_ptr implementation
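TRTLLM_ENABLE_PDL is evaluated once at module import, so it has to be set in the environment before these modules load. A minimal sketch (the import-order caveat is an inference from the module-level read above, not documented behavior):

    import os

    os.environ["TRTLLM_ENABLE_PDL"] = "1"  # must happen before importing the module

    # Mirrors the module-level read in utils.py; any module imported after this
    # point sees the flag enabled.
    TRTLLM_ENABLE_PDL = os.environ.get("TRTLLM_ENABLE_PDL", "0") == "1"
    assert TRTLLM_ENABLE_PDL
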
@@ -190,3 +197,150 @@ def make_ptr(
 
 def is_power_of_2(x: int) -> bool:
     return x > 0 and (x & (x - 1)) == 0
+
+
+@dsl_user_op
+def fmin(a: Union[float, cutlass.Float32],
+         b: Union[float, cutlass.Float32],
+         *,
+         nan=False,
+         loc=None,
+         ip=None) -> cutlass.Float32:
+    return cutlass.Float32(
+        nvvm.fmin(
+            T.f32(),
+            cutlass.Float32(a).ir_value(loc=loc, ip=ip),
+            cutlass.Float32(b).ir_value(loc=loc, ip=ip),
+            nan=nan,
+            loc=loc,
+            ip=ip,
+        ))
+
+
+def sigmoid_f32(a: Union[float, cutlass.Float32],
+                fastmath: bool = False) -> Union[float, cutlass.Float32]:
+    """
+    Compute the sigmoid of the input tensor.
+    """
+    return cute.arch.rcp_approx(1.0 + cute.math.exp(-a, fastmath=fastmath))
+
+
+def silu_f32(a: Union[float, cutlass.Float32],
+             fastmath: bool = False) -> Union[float, cutlass.Float32]:
+    """
+    Compute the silu of the input tensor.
+    """
+    return a * sigmoid_f32(a, fastmath=fastmath)
+
+
+# TODO(zhichenj): try to move these to NVVM wrapper or helper functions
+@dsl_user_op
+def vectorized_atomic_add_bf16x8(rOut_epi_packed,
+                                 scatter_out_offset,
+                                 loc=None,
+                                 ip=None):
+    llvm.inline_asm(
+        None,
+        [
+            scatter_out_offset.iterator.llvm_ptr,
+            llvm.bitcast(T.i32(), rOut_epi_packed[0, None].load().ir_value()),
+            llvm.bitcast(T.i32(), rOut_epi_packed[1, None].load().ir_value()),
+            llvm.bitcast(T.i32(), rOut_epi_packed[2, None].load().ir_value()),
+            llvm.bitcast(T.i32(), rOut_epi_packed[3, None].load().ir_value()),
+        ],
+        "red.global.v4.bf16x2.add.noftz [$0], {$1, $2, $3, $4};",
+        "l,r,r,r,r",
+        has_side_effects=True,
+        loc=loc,
+        ip=ip,
+    )
+
+
+@dsl_user_op
+def vectorized_atomic_add_fp32x2(rOut_epi_packed,
+                                 scatter_out_offset,
+                                 loc=None,
+                                 ip=None):
+    llvm.inline_asm(
+        None,
+        [
+            scatter_out_offset.iterator.llvm_ptr,
+            rOut_epi_packed[0].ir_value(),
+            rOut_epi_packed[1].ir_value(),
+        ],
+        "red.global.v2.f32.add [$0], {$1, $2};",
+        "l,f,f",
+        has_side_effects=True,
+        loc=loc,
+        ip=ip,
+    )
+
+
+@dsl_user_op
+def atomic_add_func(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
+    if cutlass.const_expr(rOut_epi_packed.dtype == cutlass.Float32):
+        llvm.inline_asm(
+            None,
+            [
+                scatter_out_offset.iterator.llvm_ptr,
+                rOut_epi_packed.ir_value(),
+            ],
+            "red.global.add.f32 [$0], $1;",
+            "l,f",
+            has_side_effects=True,
+            loc=loc,
+            ip=ip,
+        )
+    elif cutlass.const_expr(rOut_epi_packed.dtype == cutlass.BFloat16):
+        llvm.inline_asm(
+            None,
+            [
+                scatter_out_offset.iterator.llvm_ptr,
+                llvm.bitcast(T.i16(), rOut_epi_packed.ir_value()),
+            ],
+            "red.add.noftz.bf16 [$0], $1;",
+            "l,h",
+            has_side_effects=True,
+            loc=loc,
+            ip=ip,
+        )
+
+
+@dsl_user_op
+def griddepcontrol_wait(*, loc=None, ip=None) -> None:
+    """
+    This instruction is used to wait for the previous kernel's grid ending
+    (all blocks of the previous kernel have finished and memflushed), i.e.,
+    the instruction after this instruction will not be issued until the previous
+    grid has finished.
+    """
+    llvm.inline_asm(
+        res=None,
+        operands_=[],
+        asm_string="griddepcontrol.wait;",
+        constraints="",
+        has_side_effects=True,
+        asm_dialect=llvm.AsmDialect.AD_ATT,
+        loc=loc,
+        ip=ip,
+    )
+
+
+@dsl_user_op
+def griddepcontrol_launch_dependents(*, loc=None, ip=None) -> None:
+    """
+    Issuing the launch_dependents instruction hints a dependent kernel to launch earlier.
+    launch_dependents doesn't impact the functionality but the performance:
+    launching a dependent kernel too early can compete with current kernels,
+    while launching too late can lead to a long latency.
+    """
+    llvm.inline_asm(
+        res=None,
+        operands_=[],
+        asm_string="griddepcontrol.launch_dependents;",
+        constraints="",
+        has_side_effects=True,
+        asm_dialect=llvm.AsmDialect.AD_ATT,
+        loc=loc,
+        ip=ip,
+    )

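The device-side sigmoid_f32/silu_f32 above compute sigmoid via a fast reciprocal approximation of 1 + exp(-x). A plain-Python reference for the same math (host-side sanity check only, not the CuTe DSL path):

    import math

    def sigmoid_ref(x: float) -> float:
        # sigmoid(x) = 1 / (1 + exp(-x)); rcp_approx approximates this division
        return 1.0 / (1.0 + math.exp(-x))

    def silu_ref(x: float) -> float:
        # silu(x) = x * sigmoid(x)
        return x * sigmoid_ref(x)

    assert abs(silu_ref(1.0) - 1.0 * sigmoid_ref(1.0)) < 1e-12
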
@@ -1452,12 +1452,13 @@ class DeepseekV3Model(DecoderModel):
         config = model_config.pretrained_config
         self.vocab_size = config.vocab_size
         self.num_hidden_layers = config.num_hidden_layers
-        aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
+        aux_stream_list = [torch.cuda.Stream() for _ in range(4)]
         self.aux_stream_dict = {
             AuxStreamType.Attention: aux_stream_list[0],
             AuxStreamType.MoeShared: aux_stream_list[0],
             AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
             AuxStreamType.MoeBalancer: aux_stream_list[2],
+            AuxStreamType.MoeOutputMemset: aux_stream_list[3],
         }
 
         self.embed_tokens = Embedding(

@@ -890,12 +890,13 @@ class Glm4Model(DecoderModel):
         config = model_config.pretrained_config
         self.vocab_size = config.vocab_size
         self.num_hidden_layers = config.num_hidden_layers
-        aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
+        aux_stream_list = [torch.cuda.Stream() for _ in range(4)]
         self.aux_stream_dict = {
             AuxStreamType.Attention: aux_stream_list[0],
             AuxStreamType.MoeShared: aux_stream_list[0],
             AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
             AuxStreamType.MoeBalancer: aux_stream_list[2],
+            AuxStreamType.MoeOutputMemset: aux_stream_list[3],
         }
 
         self.embed_tokens = Embedding(

@@ -327,6 +327,7 @@ class Qwen3MoEModel(DecoderModel):
         self.aux_stream_dict = {
             AuxStreamType.MoeChunkingOverlap: torch.cuda.Stream(),
             AuxStreamType.MoeBalancer: torch.cuda.Stream(),
+            AuxStreamType.MoeOutputMemset: torch.cuda.Stream(),
         }
         self.preload_weight_modules = []
         if config.moe_backend == "TRTLLM":

@@ -487,11 +487,13 @@ class MTPDraftModelForCausalLM(DecoderModelForCausalLM[MTPDraftModel,
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         self.model_config = model_config
-        aux_stream_list = [torch.cuda.Stream() for _ in range(2)]
+        aux_stream_list = [torch.cuda.Stream() for _ in range(4)]
         self.aux_stream_dict = {
             AuxStreamType.Attention: aux_stream_list[0],
             AuxStreamType.MoeShared: aux_stream_list[0],
             AuxStreamType.MoeChunkingOverlap: aux_stream_list[1],
+            AuxStreamType.MoeBalancer: aux_stream_list[2],
+            AuxStreamType.MoeOutputMemset: aux_stream_list[3],
         }
         super().__init__(
             MTPDraftModel(self.model_config,

@@ -175,7 +175,6 @@ class ConfigurableMoE(MoE):
             assert not torch.compiler.is_compiling(), (
                 "Backend should not be none if not in torch.compile"
             )
-            self.backend.aux_stream_dict = self.aux_stream_dict
             self.backend.layer_idx = self.layer_idx
             self.backend.layer_idx_str = self.layer_idx_str
             self.backend.num_slots = self.num_slots

@@ -8,7 +8,7 @@ from tensorrt_llm._utils import is_sm_100f
 
 from ...distributed import allgather
 from ...model_config import ModelConfig
-from ...utils import AuxStreamType, Fp4QuantizedTensor
+from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .fused_moe_cutlass import CutlassFusedMoE
 from .interface import AlltoallMethodType
 from .quantization import MoEWeightLoadingMode, NVFP4CuteDslFusedMoEMethod
@@ -183,7 +183,6 @@ class CuteDslFusedMoE(CutlassFusedMoE):
         init_load_balancer: bool = True,
         without_comm: bool = False,
     ):
-
         super().__init__(
             routing_method=routing_method,
             num_experts=num_experts,
@@ -199,6 +198,16 @@ class CuteDslFusedMoE(CutlassFusedMoE):
             init_load_balancer=init_load_balancer,
             without_comm=without_comm,
         )
+        if self.aux_stream_dict is None:
+            self.aux_stream_dict = aux_stream_dict if aux_stream_dict is not None else {}
+        if AuxStreamType.MoeOutputMemset not in self.aux_stream_dict:
+            self.aux_stream_dict[
+                AuxStreamType.MoeOutputMemset] = torch.cuda.Stream()
+        if self.event_dict is None:
+            self.event_dict = {}
+        for key in [EventType.Main, EventType.MoeOutputMemset]:
+            if key not in self.event_dict:
+                self.event_dict[key] = torch.cuda.Event()
 
     def select_alltoall_method_type(self) -> AlltoallMethodType:
         return AlltoallMethodType.NotEnabled
@@ -288,6 +297,11 @@ class CuteDslFusedMoE(CutlassFusedMoE):
                 self.hidden_size)
         assert moe_output.dtype == output_dtype
 
+        if self.use_fused_finalize:
+            self.event_dict[EventType.Main].record()
+            moe_output.record_stream(
+                self.aux_stream_dict[AuxStreamType.MoeOutputMemset])
+
         x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell(
             input=x.view(torch.float4_e2m1fn_x2),
             weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2),
@@ -307,17 +321,24 @@ class CuteDslFusedMoE(CutlassFusedMoE):
         )
 
         if self.use_fused_finalize:
-            torch.ops.trtllm.moe_output_memset_inplace(
-                input=moe_output,
-                tile_idx_to_mn_limit=tile_idx_to_mn_limit,
-                expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
-                permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
-                num_non_exiting_tiles=num_non_exiting_tiles,
-                tile_tokens_dim=tile_size,
-                top_k=self.routing_method.experts_per_token,
-                ep_size=self.mapping.moe_ep_size,
-                enable_alltoall=enable_alltoall,
-            )
+            with torch.cuda.stream(
+                    self.aux_stream_dict[AuxStreamType.MoeOutputMemset]):
+                self.event_dict[EventType.Main].wait()
+                torch.ops.trtllm.moe_output_memset_inplace(
+                    input=moe_output,
+                    tile_idx_to_mn_limit=tile_idx_to_mn_limit,
+                    expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
+                    permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
+                    num_non_exiting_tiles=num_non_exiting_tiles,
+                    tile_tokens_dim=tile_size,
+                    top_k=self.routing_method.experts_per_token,
+                    ep_size=self.mapping.moe_ep_size,
+                    enable_alltoall=enable_alltoall,
+                )
+                self.event_dict[EventType.MoeOutputMemset].record()
+
+            self.event_dict[EventType.MoeOutputMemset].wait()
 
         torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell(
             input=x.view(torch.float4_e2m1fn_x2),
             weight=self.w2_weight.view(torch.float4_e2m1fn_x2),

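The output memset now runs on the MoeOutputMemset stream, fenced by the Main and MoeOutputMemset events so the finalize GEMM never races the zero-fill. A minimal PyTorch sketch of that handshake (requires a CUDA device; out.zero_() stands in for moe_output_memset_inplace):

    import torch

    main = torch.cuda.current_stream()
    aux = torch.cuda.Stream()
    ev_main, ev_memset = torch.cuda.Event(), torch.cuda.Event()

    out = torch.empty(1024, device="cuda")
    out.record_stream(aux)        # keep the allocation alive for the aux stream

    ev_main.record(main)          # EventType.Main in the diff
    with torch.cuda.stream(aux):  # AuxStreamType.MoeOutputMemset in the diff
        ev_main.wait()            # aux starts only after main reaches the record
        out.zero_()               # the memset, overlapped with main-stream work
        ev_memset.record()
    ev_memset.wait(main)          # main resumes only after the memset finishes
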
@@ -21,6 +21,7 @@ aux_stream_name_list = [
     'MoeShared',
     'MoeChunkingOverlap',
     'MoeBalancer',
+    'MoeOutputMemset',
 ]
 AuxStreamType = Enum(
     'AuxStreamType',

@@ -330,7 +330,6 @@ accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False] SKIP
 test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True] SKIP (https://nvbugs/5648560)
 test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-False] SKIP (https://nvbugs/5648560)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_trtllmgen_adp_lmtp] SKIP (https://nvbugs/5629136)
-unittest/bindings/test_hostfunc.py::test_hostfunc SKIP (https://nvbugs/5643631)
 examples/test_multimodal.py::test_llm_multimodal_general[nougat-base-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5568052)
 accuracy/test_llm_api_pytorch_multimodal.py::TestNVILA_8B::test_auto_dtype SKIP (https://nvbugs/5648441)
 accuracy/test_llm_api_pytorch_multimodal.py::TestVILA1_5_3B::test_auto_dtype SKIP (https://nvbugs/5648441)

@@ -15,6 +15,7 @@ def test_hostfunc():
     with torch.cuda.stream(stream):
         for _ in range(5):
             increase(x)
+    torch.cuda.synchronize()
 
     g = torch.cuda.CUDAGraph()
     with torch.cuda.graph(g, stream=stream):
@@ -25,7 +26,7 @@ def test_hostfunc():
     with torch.cuda.stream(stream):
         for _ in range(10):
             g.replay()
-
     torch.cuda.synchronize()
 
     assert (x == 25).all().item()
+    assert len(HOSTFUNC_USER_DATA_HANDLES) == 2

@@ -94,6 +94,7 @@ async def test_reasoning(client: openai.AsyncOpenAI, model: str):
     check_reponse(response, "test_reasoning: ")
 
 
+@pytest.mark.skip(reason="https://nvbugs/5753250")
 @pytest.mark.asyncio(loop_scope="module")
 async def test_reasoning_effort(client: openai.AsyncOpenAI, model: str):
     for effort in ["low", "medium", "high"]:
@@ -106,6 +107,7 @@ async def test_reasoning_effort(client: openai.AsyncOpenAI, model: str):
         check_reponse(response, f"test_reasoning_effort_{effort}: ")
 
 
+@pytest.mark.skip(reason="https://nvbugs/5753250")
 @pytest.mark.asyncio(loop_scope="module")
 async def test_chat(client: openai.AsyncOpenAI, model: str):
     response = await client.responses.create(model=model,
@@ -150,6 +152,7 @@ def get_current_weather(location: str, format: str = "celsius") -> dict:
     return {"sunny": True, "temperature": 20 if format == "celsius" else 68}
 
 
+@pytest.mark.skip(reason="https://nvbugs/5753250")
 @pytest.mark.asyncio(loop_scope="module")
 async def test_tool_calls(client: openai.AsyncOpenAI, model: str):
     if model.startswith("DeepSeek-R1"):
@@ -201,6 +204,7 @@ async def test_tool_calls(client: openai.AsyncOpenAI, model: str):
     check_tool_calling(response, False, "test_tool_calls: ")
 
 
+@pytest.mark.skip(reason="https://nvbugs/5753250")
 @pytest.mark.asyncio(loop_scope="module")
 async def test_streaming(client: openai.AsyncOpenAI, model: str):
     stream = await client.responses.create(
@@ -222,6 +226,7 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str):
     assert full_reasoning_response
 
 
+@pytest.mark.skip(reason="https://nvbugs/5753250")
 @pytest.mark.asyncio(loop_scope="module")
 async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
     if model.startswith("DeepSeek-R1"):