From df8be0c50c5058b7b31878e87d8f03b60ed9b791 Mon Sep 17 00:00:00 2001 From: ameynaik-hub <212485788+ameynaik-hub@users.noreply.github.com> Date: Mon, 26 Jan 2026 19:08:47 -0800 Subject: [PATCH] [TRTLLM-10276][feat] Integrate cutedsl argmax kernel (#10476) Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com> Signed-off-by: Tyler Burt <195370667+tburt-nv@users.noreply.github.com> Co-authored-by: Tyler Burt <195370667+tburt-nv@users.noreply.github.com> --- LICENSE | 8 + .../_torch/cute_dsl_kernels/argmax.py | 642 ++++++++++++++++++ .../_torch/cute_dsl_kernels/test_argmax.py | 240 +++++++ tensorrt_llm/_torch/speculative/interface.py | 7 +- 4 files changed, 895 insertions(+), 2 deletions(-) create mode 100644 tensorrt_llm/_torch/cute_dsl_kernels/argmax.py create mode 100644 tensorrt_llm/_torch/cute_dsl_kernels/test_argmax.py diff --git a/LICENSE b/LICENSE index 350926f256..e6dc9b3096 100644 --- a/LICENSE +++ b/LICENSE @@ -41,6 +41,14 @@ Original Source: https://github.com/state-spaces/mamba Copyright 2023 Tri Dao, Albert Gu Licensed under the Apache License 2.0 +-------------------------------------------------------------------------------- +Quack +-------------------------------------------------------------------------------- +Original Source: https://github.com/Dao-AILab/quack +Copyright (c) 2025, Tri Dao. +Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. +Licensed under the Apache License 2.0 + -------------------------------------------------------------------------------- SGLang -------------------------------------------------------------------------------- diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/argmax.py b/tensorrt_llm/_torch/cute_dsl_kernels/argmax.py new file mode 100644 index 0000000000..12719acd00 --- /dev/null +++ b/tensorrt_llm/_torch/cute_dsl_kernels/argmax.py @@ -0,0 +1,642 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, Tri Dao. +# SPDX-FileCopyrightText: Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file contains code derived from the Quack library: +# https://github.com/Dao-AILab/quack +# +# Argmax kernel using CuTE DSL for TensorRT-LLM speculative decoding. 
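+#
+# Illustrative usage (a minimal sketch, using only the public entry point
+# `argmax` defined below): the function takes a 2D CUDA tensor and returns an
+# (M, 2) tensor whose column 0 holds the row-wise maximum and whose column 1
+# holds its index, stored in the input dtype.
+#
+#   import torch
+#   from tensorrt_llm._torch.cute_dsl_kernels.argmax import argmax
+#
+#   logits = torch.randn(4, 32000, device="cuda", dtype=torch.float32)
+#   out = argmax(logits)          # shape (4, 2), dtype float32
+#   token_ids = out[:, 1].long()  # integer argmax indices per row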
+ +from typing import Optional, Tuple, Type + +import torch + +from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE + +if IS_CUTLASS_DSL_AVAILABLE: + import cuda.bindings.driver as cuda + import cutlass + import cutlass.cute as cute + from cutlass._mlir.dialects import llvm + from cutlass.cute.arch.nvvm_wrappers import FULL_MASK + from cutlass.cute.runtime import from_dlpack + from cutlass.cute.typing import Float32, Int, Int32 + from cutlass.cutlass_dsl import T, dsl_user_op + + # ============================================================================ + # Torch to CuTE dtype mapping + # ============================================================================ + torch2cute_dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, + } + + # ============================================================================ + # CUDA Graph compatibility wrapper + # ============================================================================ + class CUDAGraphCompatibleWrapper: + """Wrapper to make tensors compatible with CUDA graph capture for DLPack export.""" + + def __init__(self, tensor): + self._tensor = tensor + + def __dlpack__(self, stream=None): + return self._tensor.__dlpack__(stream=-1) + + def __dlpack_device__(self): + return self._tensor.__dlpack_device__() + + # ============================================================================ + # Utility functions from quack/utils.py + # ============================================================================ + @dsl_user_op + def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + @dsl_user_op + def set_block_rank( + smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: cute.Int32, *, loc=None, ip=None + ) -> cutlass.Int32: + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()], + "mapa.shared::cluster.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op + def store_shared_remote( + val: float | Float32 | cutlass.Int64, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: cute.typing.Int, + *, + loc=None, + ip=None, + ) -> None: + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + if cutlass.const_expr(isinstance(val, float)): + val = Float32(val) + assert isinstance(val, (Float32, Int32, cutlass.Int64)) + suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)] + constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)] + llvm.inline_asm( + None, + [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32], + f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];", + f"r,{constraint},r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + @cute.jit + def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: + tApA = cute.make_fragment( + cute.make_layout( + ( + cute.size(tAcA, mode=[0, 1]), + cute.size(tAcA, mode=[1]), + cute.size(tAcA, mode=[2]), + ), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + for rest_v in 
cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + return tApA + + @cute.jit + def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None: + tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0]) + tXrX_fill.fill(fill_value) + for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]): + for rest_k in cutlass.range_constexpr(tXsX.shape[2]): + if cutlass.const_expr(tXpX is not None): + if not tXpX[rest_v, 0, rest_k]: + cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k]) + else: + cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k]) + + @dsl_user_op + def domain_offset_i64( + coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None + ) -> cute.Tensor: + flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) + flat_stride = cute.flatten_to_tuple(tensor.stride) + offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) + new_ptr = cute.make_ptr( + tensor.element_type, + tensor.iterator.toint() + offset * tensor.element_type.width // 8, + tensor.memspace, + assumed_align=tensor.iterator.max_alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + # ============================================================================ + # Inline PTX for redux.sync operations + # ============================================================================ + @dsl_user_op + def ptx_redux_sync_max_f32( + value: Float32, mask: Int = FULL_MASK, *, loc=None, ip=None + ) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(value).ir_value(loc=loc, ip=ip), Int32(mask).ir_value(loc=loc, ip=ip)], + """redux.sync.max.f32 $0, $1, $2;""", + "=f,f,i", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op + def ptx_redux_sync_min_u32(value: Int32, mask: Int = FULL_MASK, *, loc=None, ip=None) -> Int32: + return Int32( + llvm.inline_asm( + T.i32(), + [Int32(value).ir_value(loc=loc, ip=ip), Int32(mask).ir_value(loc=loc, ip=ip)], + """redux.sync.min.u32 $0, $1, $2;""", + "=r,r,i", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op + def ptx_select_argmax_candidate( + current_max: Float32, warp_max: Float32, current_argmax: Int32, *, loc=None, ip=None + ) -> Int32: + return Int32( + llvm.inline_asm( + T.i32(), + [ + Float32(current_max).ir_value(loc=loc, ip=ip), + Float32(warp_max).ir_value(loc=loc, ip=ip), + Int32(current_argmax).ir_value(loc=loc, ip=ip), + ], + """{ + .reg .pred p; + setp.eq.f32 p, $1, $2; + selp.s32 $0, $3, 0xffffffff, p; + }""", + "=r,f,f,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @cute.jit + def warp_argmax_redux(current_max: Float32, current_argmax: Int32): + """Redux-based warp argmax - only works on sm_100+ (Blackwell).""" + warp_max = ptx_redux_sync_max_f32(current_max) + candidate_idx = ptx_select_argmax_candidate(current_max, warp_max, current_argmax) + winning_idx = ptx_redux_sync_min_u32(candidate_idx) + return warp_max, winning_idx + + @cute.jit + def warp_reduce_argmax(current_max: Float32, current_argmax: Int32): + """Shuffle-based warp argmax - works on all architectures (Hopper+).""" + warp_max = current_max + warp_argmax = current_argmax + + # Use butterfly shuffle pattern for warp reduction + for i in cutlass.range_constexpr(int(5)): # log2(32) = 5 iterations + # 
Get values from other lanes using butterfly pattern + other_max = cute.arch.shuffle_sync_bfly(warp_max, offset=1 << i) + other_argmax = cute.arch.shuffle_sync_bfly(warp_argmax, offset=1 << i) + + # Inline argmax comparison + if other_max > warp_max: + warp_max = other_max + warp_argmax = other_argmax + + return warp_max, warp_argmax + + # ============================================================================ + # Reduction Base class + # ============================================================================ + class ReductionBase: + def __init__( + self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=cutlass.Float32 + ): + self.dtype = dtype + self.N = N + self.stage = stage + self.reduction_dtype = reduction_dtype + + def _calculate_threads_per_row(self): + raise NotImplementedError() + + def _set_cluster_n(self): + self.cluster_n = 1 + + def _get_num_threads(self): + return 128 if self.N <= 16384 else 256 + + def _get_tv_layout(self, num_copy_bits=128): + vecsize = num_copy_bits // self.dtype.width + num_threads = self._get_num_threads() + threads_per_row = self._calculate_threads_per_row() + num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n) + cols_per_block = num_threads // threads_per_row + tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) + tv_layout = cute.make_layout( + ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)), + stride=( + (vecsize * cols_per_block, 1), + (cols_per_block, cols_per_block * vecsize * threads_per_row), + ), + ) + return tiler_mn, tv_layout + + def _smem_size_in_bytes(self, tiler_mn, num_warps): + return ( + cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) + + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8) + + self.stage * (cutlass.Int64.width // 8) + ) + + def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int): + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1) + return cute.make_ordered_layout( + (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage), + order=(1, 0, 2), + ) + + def _allocate_reduction_buffer_and_mbar( + self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout + ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]: + reduction_buffer = smem.allocate_tensor( + self.reduction_dtype, + self._get_reduction_buffer_layout(tv_layout, self.cluster_n), + byte_alignment=4, + ) + if cutlass.const_expr(self.cluster_n > 1): + mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage) + else: + mbar_ptr = None + return reduction_buffer, mbar_ptr + + @cute.jit + def _initialize_cluster(self, tidx: cutlass.Int32, mbar_ptr: cute.Pointer, num_warps: int): + if cutlass.const_expr(self.cluster_n > 1): + if tidx < self.stage: + cute.arch.mbarrier_init(mbar_ptr + tidx, 1) + cute.arch.mbarrier_init_fence() + cute.arch.cluster_arrive_relaxed() + + # ============================================================================ + # Argmax Kernel class + # ============================================================================ + class ArgmaxKernel(ReductionBase): + def __init__(self, dtype: Type[cutlass.Numeric], N: int, use_redux: bool = False): + super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32) + # use_redux=True for Blackwell (sm_100+), False for Hopper (sm_90) + self.use_redux = use_redux + + def _calculate_threads_per_row(self): + N = self.N + return ( + 8 + if N <= 64 + else ( 
+ 16 + if N <= 128 + else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))) + ) + ) + + def _set_cluster_n(self): + N = self.N + if cutlass.const_expr(self.dtype.width == 16): + self.cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 32 * 1024 + else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) + ) + ) + else: + self.cluster_n = ( + 1 + if N <= 32 * 1024 + else ( + 2 + if N <= 64 * 1024 + else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) + ) + ) + + def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int): + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1) + return cute.make_ordered_layout( + (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage, 2), + order=(1, 0, 2, 3), + ) + + def _smem_size_in_bytes(self, tiler_mn, num_warps): + return ( + cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) + + 2 * self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8) + + self.stage * (cutlass.Int64.width // 8) + ) + + @cute.jit + def __call__(self, mX: cute.Tensor, mO: cute.Tensor, stream: cuda.CUstream): + self._set_cluster_n() + tiler_mn, tv_layout = self._get_tv_layout() + num_threads = cute.size(tv_layout, mode=[0]) + num_warps = num_threads // cute.arch.WARP_SIZE + + self.kernel(mX, mO, tv_layout, tiler_mn).launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if cutlass.const_expr(self.cluster_n > 1) else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps), + stream=stream, + ) + + @cute.kernel + def kernel( + self, mX: cute.Tensor, mO: cute.Tensor, tv_layout: cute.Layout, tiler_mn: cute.Shape + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + + if cutlass.const_expr(self.cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = cutlass.const_expr(0) + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + + mX, mO = [domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)] + gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)] + cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16 + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + + copy_atom_load_X = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, num_bits_per_copy=128 + ) + thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx) + + tXgX = thr_copy_X.partition_S(gX) + tXsX = thr_copy_X.partition_D(sX) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + + tvlayout_cX = cute.composition(cX, tv_layout) + thr_coord = (tidx, (None, None)) + thr_cX = tvlayout_cX[thr_coord] + + tXrX = cute.make_fragment_like(tXgX) + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps) + + is_even_N = cutlass.const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + tXpX = ( + predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) if not is_even_N else None + ) + + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + if cutlass.const_expr(not 
is_even_N): + fill_oob(tXsX, tXpX, -tXsX.element_type.inf) + + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(cute.Float32) + + current_max = -tXsX.element_type.inf + current_argmax = Int32(0xFFFFFFFF) + + for i in cutlass.range_constexpr(thr_cX.shape[0]): + for j in cutlass.range_constexpr(thr_cX.shape[1]): + col_idx = thr_cX[i, j][1] + linear_idx = i + j * thr_cX.shape[0] + element_value1 = x[linear_idx] + if element_value1 > current_max: + current_max = element_value1 + current_argmax = Int32(col_idx) + + lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx() + if cutlass.const_expr(self.use_redux): + warp_max, warp_argmax = warp_argmax_redux(current_max, current_argmax) + else: + warp_max, warp_argmax = warp_reduce_argmax(current_max, current_argmax) + + if cutlass.const_expr(self.cluster_n == 1): + warps_per_row = cute.size(reduction_buffer.shape[1]) + row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row + + if lane_idx == 0: + reduction_buffer[row_idx, col_idx, 0, 0] = warp_max + reduction_buffer[row_idx, col_idx, 0, 1] = warp_argmax.to(cutlass.Float32) + + cute.arch.barrier() + block_reduce_max = -tXsX.element_type.inf + block_reduce_argmax = Int32(0xFFFFFFFF) + + if lane_idx < warps_per_row: + block_reduce_max = reduction_buffer[row_idx, lane_idx, 0, 0] + block_reduce_argmax = reduction_buffer[row_idx, lane_idx, 0, 1].to( + cutlass.Int32 + ) + + if cutlass.const_expr(self.use_redux): + warp_max, warp_argmax = warp_argmax_redux(block_reduce_max, block_reduce_argmax) + else: + warp_max, warp_argmax = warp_reduce_argmax( + block_reduce_max, block_reduce_argmax + ) + else: + cute.arch.cluster_wait() + warps_per_row, cluster_n = reduction_buffer.shape[1] + cta_rank_in_cluster = cute.arch.block_idx_in_cluster() + rows_per_block, (warps_per_row, cluster_n), _, _ = reduction_buffer.shape + row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row + + if warp_idx == 0: + with cute.arch.elect_one(): + num_warps = rows_per_block * warps_per_row + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_ptr, + num_warps * cluster_n * 2 * reduction_buffer.element_type.width // 8, + ) + + if lane_idx < cluster_n: + store_shared_remote( + warp_max, + elem_pointer( + reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster), 0, 0) + ), + mbar_ptr, + peer_cta_rank_in_cluster=lane_idx, + ) + store_shared_remote( + warp_argmax.to(cutlass.Float32), + elem_pointer( + reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster), 0, 1) + ), + mbar_ptr, + peer_cta_rank_in_cluster=lane_idx, + ) + + cute.arch.mbarrier_wait(mbar_ptr, phase=0) + block_reduce_val = -tXsX.element_type.inf + block_reduce_argmax = Int32(0xFFFFFFFF) + num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE) + + for i in cutlass.range_constexpr(num_iter): + idx = lane_idx + i * cute.arch.WARP_SIZE + if idx < cute.size(reduction_buffer, mode=[1]): + element_max = reduction_buffer[row_idx, idx, 0, 0] + element_argmax = reduction_buffer[row_idx, idx, 0, 1].to(cutlass.Int32) + if element_max > block_reduce_val: + block_reduce_val = element_max + block_reduce_argmax = element_argmax + + if cutlass.const_expr(self.use_redux): + warp_max, warp_argmax = warp_argmax_redux(block_reduce_val, block_reduce_argmax) + else: + warp_max, warp_argmax = warp_reduce_argmax( + block_reduce_val, block_reduce_argmax + ) + + row_idx = tXcX[0][0] + warps_per_row = tv_layout.shape[0][0] // cute.arch.WARP_SIZE + local_row_idx = row_idx - (bidx * tiler_mn[0]) + first_warp_for_row = local_row_idx * 
warps_per_row + first_thread_for_row = first_warp_for_row * cute.arch.WARP_SIZE + + if ( + tidx == first_thread_for_row + and row_idx < shape[0] + and local_row_idx >= 0 + and local_row_idx < tiler_mn[0] + and (self.cluster_n == 1 or bidy == 0) + ): + mO[local_row_idx, 0] = warp_max.to(mO.element_type) + mO[local_row_idx, 1] = warp_argmax.to(mO.element_type) + + # ============================================================================ + # Compiled kernel cache and forward function + # ============================================================================ + _argmax_compile_cache = {} + + # Minimum vocab size for the CuTE tiled kernel. + _MIN_VOCAB_SIZE_FOR_CUTE_KERNEL = 256 + + # The async copy requires 128-byte alignment: + # Since we only support float32 currently, use 32. + _VOCAB_SIZE_ALIGNMENT = 32 + + def _should_use_torch_fallback(N: int, dtype: torch.dtype) -> bool: + """Check if we should fall back to torch.max instead of CuTE kernel.""" + if dtype != torch.float32: + return True + if N < _MIN_VOCAB_SIZE_FOR_CUTE_KERNEL: + return True + if N % _VOCAB_SIZE_ALIGNMENT != 0: + return True + return False + + def argmax(x: torch.Tensor) -> torch.Tensor: + """ + Compute argmax along the last dimension of the input tensor. + + Args: + x: Input tensor of shape (M, N) + + Returns: + Output tensor of shape (M, 2) where: + - Column 0: Maximum value in each row + - Column 1: Index of maximum value in each row (argmax) + """ + assert x.dim() == 2, "Input must be 2D" + assert x.is_cuda, "Tensor must be on CUDA device" + assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype" + + M, N = x.shape + + if _should_use_torch_fallback(N, x.dtype): + max_vals, max_indices = torch.max(x, dim=-1, keepdim=True) + return torch.cat([max_vals, max_indices.to(x.dtype)], dim=-1) + + out = torch.empty((M, 2), dtype=x.dtype, device=x.device) + dtype = torch2cute_dtype_map[x.dtype] + + def convert_from_dlpack(tensor): + return from_dlpack( + CUDAGraphCompatibleWrapper(tensor.detach()), assumed_align=16 + ).mark_compact_shape_dynamic(mode=0, stride_order=(0, 1)) + + x_tensor = convert_from_dlpack(x) + out_tensor = convert_from_dlpack(out) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Detect compute capability: use redux instructions only on Blackwell (sm_100+) + # redux.sync.max.f32 is only supported on sm_100+ + from ..._utils import get_sm_version + + use_redux = get_sm_version() >= 100 # sm_100+ (Blackwell) + + compile_key = (dtype, N, use_redux) + if compile_key not in _argmax_compile_cache: + argmax_kernel = ArgmaxKernel(dtype, N, use_redux=use_redux) + _argmax_compile_cache[compile_key] = cute.compile( + argmax_kernel, x_tensor, out_tensor, current_stream + ) + + _argmax_compile_cache[compile_key](x_tensor, out_tensor, current_stream) + return out + +else: + # Fallback if CUTLASS DSL is not available + def argmax(x: torch.Tensor) -> torch.Tensor: + """Fallback argmax using PyTorch when CUTLASS DSL is not available.""" + max_vals, max_indices = torch.max(x, dim=-1, keepdim=True) + return torch.cat([max_vals, max_indices.to(x.dtype)], dim=-1) diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/test_argmax.py b/tensorrt_llm/_torch/cute_dsl_kernels/test_argmax.py new file mode 100644 index 0000000000..878c209896 --- /dev/null +++ b/tensorrt_llm/_torch/cute_dsl_kernels/test_argmax.py @@ -0,0 +1,240 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +""" +Test cases for the CuTE DSL argmax kernel. + +The kernel uses CuTE for N >= 256 (aligned to 32), otherwise falls back to torch.max. +Only float32 uses the CuTE kernel; float16/bfloat16 use torch.max fallback. +""" + +import pytest +import torch + +from tensorrt_llm._torch.cute_dsl_kernels.argmax import argmax + +# Increase dynamo cache for parameterized tests +torch._dynamo.config.cache_size_limit = 1024 +torch._dynamo.config.accumulated_cache_size_limit = 1024 + +# ============================================================================ +# Constants for test configurations +# ============================================================================ +# N values where CuTE kernel is used (N >= 256, aligned to 32) +LARGE_N_VALUES = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 201088, 262144] + +# N values that use torch.max fallback (N < 256) +SMALL_N_VALUES = [8, 16, 32, 64, 128] + +# Typical LLM vocab sizes for performance testing +VOCAB_SIZES = [32000, 32768, 65536, 128256, 131072, 201088, 262144] + + +# ============================================================================ +# Correctness Tests +# ============================================================================ +@pytest.mark.parametrize("input_dtype", [torch.float32]) +@pytest.mark.parametrize("N", LARGE_N_VALUES) +@pytest.mark.parametrize("M", [1, 4, 37, 199, 1024]) +def test_argmax_large_n(M, N, input_dtype): + """Test argmax with CuTE kernel (N >= 256, aligned).""" + device = "cuda" + atol, rtol = 1e-4, 1e-4 + + torch.random.manual_seed(42) + x = 0.1 * torch.randn(M, N, device=device, dtype=input_dtype) + + out = argmax(x) + expected_max, expected_idx = torch.max(x, dim=-1, keepdim=True) + + assert out.shape == (M, 2) + assert out.dtype == input_dtype + + torch.testing.assert_close(out[:, 0:1], expected_max, atol=atol, rtol=rtol) + torch.testing.assert_close(out[:, 1:2].long(), expected_idx, atol=0, rtol=0) + + +@pytest.mark.parametrize("input_dtype", [torch.float32]) +@pytest.mark.parametrize("N", SMALL_N_VALUES) +def test_argmax_small_n(N, input_dtype): + """Test argmax with torch.max fallback (N < 256).""" + device = "cuda" + M = 4 + + for max_pos in [0, N // 4, N // 2, 3 * N // 4, N - 1]: + x = torch.full((M, N), -100.0, dtype=input_dtype, device=device) + x[:, max_pos] = 0.0 + + out = argmax(x) + + for row in range(M): + assert out[row, 0].item() == 0.0, f"Row {row}: max value should be 0.0" + assert out[row, 1].item() == float(max_pos), f"Row {row}: argmax wrong" + + +@pytest.mark.parametrize("input_dtype", [torch.float32]) +def test_argmax_mtp_case(input_dtype): + """Test the specific MTP test case (N=8, max at index 1).""" + x = torch.tensor( + [[-100.0, 0.0, -100.0, -100.0, -100.0, -100.0, -100.0, -100.0]], + dtype=input_dtype, + device="cuda", + ) + out = argmax(x) + + assert out.shape == (1, 2) + assert out[0, 0].item() == 0.0, f"Expected max=0, got {out[0, 0].item()}" + assert out[0, 1].item() == 1.0, f"Expected argmax=1, got {out[0, 1].item()}" + + +@pytest.mark.parametrize("input_dtype", [torch.float32]) +@pytest.mark.parametrize("N", [255, 256, 257, 1023, 1024, 1025, 32000, 32001]) +def test_argmax_alignment_fallback(N, input_dtype): + """Test aligned vs unaligned N values (unaligned falls back to torch.max).""" + device = "cuda" + M = 4 + + torch.random.manual_seed(42) + x = torch.randn(M, N, device=device, dtype=input_dtype) + + out = argmax(x) + expected_max, expected_idx = torch.max(x, dim=-1, keepdim=True) + + 
torch.testing.assert_close(out[:, 0:1], expected_max, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(out[:, 1:2].long(), expected_idx, atol=0, rtol=0) + + +# ============================================================================ +# CUDA Graph Tests +# ============================================================================ +@pytest.mark.parametrize("input_dtype", [torch.float32]) +@pytest.mark.parametrize("N", [1024, 32768, 131072]) +@pytest.mark.parametrize("M", [1, 16, 256]) +def test_argmax_cudagraphs(M, N, input_dtype): + """Test that argmax is CUDA graph capturable.""" + device = "cuda" + + torch.random.manual_seed(0) + x = 0.1 * torch.randn(M, N, device=device, dtype=input_dtype) + + # Warmup + _ = argmax(x) + torch.cuda.synchronize() + + # Capture graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + out = argmax(x) + + graph.replay() + torch.cuda.synchronize() + + # Verify + expected_max, expected_idx = torch.max(x, dim=-1, keepdim=True) + torch.testing.assert_close(out[:, 0:1], expected_max, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(out[:, 1:2].long(), expected_idx, atol=0, rtol=0) + + +# ============================================================================ +# Performance Tests +# ============================================================================ +@pytest.mark.parametrize("N", VOCAB_SIZES) +def test_argmax_performance(N): + """Compare CuTE argmax vs torch.max performance.""" + device = "cuda" + dtype = torch.float32 + num_iters = 100 + M_values = [4, 16, 64, 256, 1024] + + print(f"\n{'=' * 70}") + print(f"N={N:>6} | {'M':>5} | {'CuTE':>10} | {'torch':>10} | {'Speedup':>8}") + print(f"{'-' * 70}") + + for M in M_values: + torch.random.manual_seed(0) + x = 0.1 * torch.randn(M, N, device=device, dtype=dtype) + + # Warmup + _ = argmax(x) + _ = torch.max(x, dim=-1) + torch.cuda.synchronize() + + # CuTE graph + g1 = torch.cuda.CUDAGraph() + with torch.cuda.graph(g1): + for _ in range(num_iters): + out = argmax(x) + + # torch.max graph + g2 = torch.cuda.CUDAGraph() + with torch.cuda.graph(g2): + for _ in range(num_iters): + _, _ = torch.max(x, dim=-1) + + # Time CuTE + torch.cuda.synchronize() + t1 = torch.cuda.Event(enable_timing=True) + t2 = torch.cuda.Event(enable_timing=True) + t1.record() + g1.replay() + t2.record() + torch.cuda.synchronize() + cute_ms = t1.elapsed_time(t2) / num_iters + + # Time torch.max + t1.record() + g2.replay() + t2.record() + torch.cuda.synchronize() + torch_ms = t1.elapsed_time(t2) / num_iters + + speedup = torch_ms / cute_ms if cute_ms > 0 else float("inf") + status = "✓" if speedup > 1 else "✗" + + print( + f" | {M:>5} | {cute_ms:>8.4f}ms | {torch_ms:>8.4f}ms | {speedup:>5.2f}x {status}" + ) + + print(f"{'=' * 70}") + + # Verify correctness + expected_max, expected_idx = torch.max(x, dim=-1, keepdim=True) + torch.testing.assert_close(out[:, 0:1], expected_max, atol=1e-4, rtol=1e-4) + + +# ============================================================================ +# Manual Test Runner +# ============================================================================ +if __name__ == "__main__": + print("Running argmax tests...\n") + + print("1. MTP case (N=8)...") + test_argmax_mtp_case(torch.float32) + print(" ✓ Passed\n") + + print("2. Small N (torch.max fallback)...") + for N in SMALL_N_VALUES: + test_argmax_small_n(N, torch.float32) + print(" ✓ Passed\n") + + print("3. 
Large N (CuTE kernel)...") + for N in [256, 1024, 32768, 131072]: + test_argmax_large_n(4, N, torch.float32) + print(" ✓ Passed\n") + + print("4. Alignment fallback...") + for N in [255, 256, 257, 1023, 1024, 1025]: + test_argmax_alignment_fallback(N, torch.float32) + print(" ✓ Passed\n") + + print("5. CUDA graphs...") + test_argmax_cudagraphs(16, 32768, torch.float32) + print(" ✓ Passed\n") + + print("6. Performance comparison...") + for N in [1024, 4096, 16384, 32768, 65536, 131072, 201088, 262144]: + test_argmax_performance(N) + + print("\n" + "=" * 70) + print("✓ All tests passed!") + print("=" * 70) diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 6c52fa813e..6bc0846650 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -12,6 +12,7 @@ from tensorrt_llm.logger import logger from ..._utils import get_sm_version from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention +from ..cute_dsl_kernels.argmax import argmax as cute_argmax from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE from ..pyexecutor.resource_manager import BaseResourceManager @@ -534,7 +535,8 @@ class SpecWorkerBase(nn.Module, ABC): Returns: draft_tokens: [num_tokens] - Sampled draft token ids (int32) """ - draft_tokens = torch.argmax(logits, dim=-1) + # cute_argmax returns (M, 2) where col 0 = max value, col 1 = argmax index + draft_tokens = cute_argmax(logits)[:, 1].long() # Apply d2t (offsets between draft and target model dictionaries) if d2t is not None: @@ -636,6 +638,7 @@ class SpecWorkerBase(nn.Module, ABC): seed=self.seed, offset=self.offset) else: - sampled_tokens = torch.argmax(logits, dim=-1) + # cute_argmax returns (M, 2) where col 0 = max value, col 1 = argmax index + sampled_tokens = cute_argmax(logits)[:, 1].long() return sampled_tokens
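
A minimal sketch of the contract the interface.py call sites rely on, assuming a
CUDA device with the kernel importable as in the tests above: both the CuTE kernel
and the torch.max fallback return an (M, 2) tensor with the row maximum in column 0
and its index in column 1, stored in the input dtype, so callers recover the index
with .long(). float32 represents integers exactly up to 2**24, so the cast is
lossless for the vocab sizes exercised here (up to 262144).

    import torch
    from tensorrt_llm._torch.cute_dsl_kernels.argmax import argmax

    logits = torch.randn(16, 131072, device="cuda", dtype=torch.float32)
    token_ids = argmax(logits)[:, 1].long()
    # For random inputs (no ties) this matches the op it replaces in interface.py:
    assert torch.equal(token_ids, torch.argmax(logits, dim=-1))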