/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/cuteDslKernels/moeUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh"
#include "tensorrt_llm/kernels/quantization.cuh"
#include "tensorrt_llm/kernels/quantization.h"

#include <algorithm>
#include <optional>
#include <type_traits>

TRTLLM_NAMESPACE_BEGIN

namespace kernels::cute_dsl
{

namespace
{
using ElemCopyType = uint4;
using SFCopyType = uint32_t;
using ActivationType = tensorrt_llm::kernels::cutlass_kernels::ActivationType;

template <typename T>
auto constexpr bitsPerElem()
{
#ifdef ENABLE_FP4
    return std::is_same_v<T, __nv_fp4_e2m1> ? 4 : cute::sizeof_bits_v<T>;
#else
    return cute::sizeof_bits_v<T>;
#endif
}

// Number of elements of T moved by one vectorized ElemCopyType (uint4) copy.
template <typename T>
auto constexpr elemPerCopy()
{
    return bitsPerElem<ElemCopyType>() / bitsPerElem<T>();
}

// Number of scale factors of T moved by one SFCopyType (uint32_t) copy.
template <typename T>
auto constexpr sfElemPerCopy()
{
    return bitsPerElem<SFCopyType>() / bitsPerElem<T>();
}
} // namespace

// Gathers each valid expanded (token, expert) slot from the dense input into the expert-permuted layout.
template <typename InputType, typename SFType, int32_t kSFVecSize, int32_t kThreadsPerBlock>
__global__ void moePermuteKernel(InputType const* input, InputType* permuted_output, SFType const* input_sf,
    SFType* permuted_sf, int32_t const* tile_idx_to_mn_limit, int32_t const* permuted_idx_to_expanded_idx,
    int32_t const* num_non_exiting_tiles, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size)
{
    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
    int32_t constexpr kSFElemPerCopy = sfElemPerCopy<SFType>();
    // Need int64_t to prevent overflow when computing pointer offsets.
    int64_t const kCopyPerToken = hidden_size / kElemPerCopy;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize();
#endif

    int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
    for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x)
    {
        int32_t const tile_idx = permuted_idx / tile_size;
        if (permuted_idx >= tile_idx_to_mn_limit[tile_idx])
        {
            continue;
        }
        int32_t const expanded_idx = permuted_idx_to_expanded_idx[permuted_idx];
        int32_t const token_idx = expanded_idx / top_k;
        auto const* src_ptr = reinterpret_cast<ElemCopyType const*>(input) + token_idx * kCopyPerToken;
        auto* dst_ptr = reinterpret_cast<ElemCopyType*>(permuted_output) + permuted_idx * kCopyPerToken;
        for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
        {
            dst_ptr[i] = src_ptr[i];
        }

#ifdef ENABLE_FP4
        if constexpr (std::is_same_v<InputType, __nv_fp4_e2m1>)
        {
            int32_t const sf_hidden_size = hidden_size / kSFVecSize;
            int64_t const kSFCopyPerToken = sf_hidden_size / kSFElemPerCopy;
            auto const* sf_src_ptr = reinterpret_cast<SFCopyType const*>(input_sf);
            auto* sf_dst_ptr = reinterpret_cast<SFCopyType*>(permuted_sf);
            for (int32_t i = threadIdx.x; i < kSFCopyPerToken; i += kThreadsPerBlock)
            {
                // input_sf is not swizzled, while permuted_sf is swizzled.
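                // The destination offset is computed by get_sf_out_offset_128x4 in units of single
                // scale factors and then converted to SFCopyType (4-SF) copies. This is valid
                // because the 128x4 swizzled layout keeps each group of 4 consecutive scale
                // factors along K contiguous, and the host side asserts that kSFElemPerCopy
                // equals that group size.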
                int64_t const src_offset = token_idx * kSFCopyPerToken + i;
                int64_t const dst_offset = get_sf_out_offset_128x4(/* batchIdx= */ std::nullopt, permuted_idx,
                                               i * kSFElemPerCopy, /* numRows= */ std::nullopt, sf_hidden_size)
                    / kSFElemPerCopy;
                sf_dst_ptr[dst_offset] = sf_src_ptr[src_offset];
            }
        }
#endif
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

template <typename InputType, typename SFType>
void moePermute(InputType const* input, InputType* permuted_output, SFType const* input_sf, SFType* permuted_sf,
    int32_t const* tile_idx_to_mn_limit, int32_t const* permuted_idx_to_expanded_idx,
    int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, int32_t const hidden_size,
    int32_t const top_k, int32_t const tile_size, cudaStream_t stream)
{
    int32_t constexpr kThreadsPerBlock = 256;
    int32_t constexpr kSFVecSize = 16;
    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
    TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy);
#ifdef ENABLE_FP4
    if constexpr (std::is_same_v<InputType, __nv_fp4_e2m1>)
    {
        int32_t constexpr kSFMAlignment = 128;
        int32_t constexpr kSFKAlignment = 4;
        int32_t constexpr kSFElemPerCopy = sfElemPerCopy<SFType>();
        static_assert(kSFElemPerCopy == kSFKAlignment);
        TLLM_CHECK_WITH_INFO(max_num_permuted_tokens % kSFMAlignment == 0,
            "max_num_permuted_tokens must be divisible by %d.", kSFMAlignment);
        TLLM_CHECK_WITH_INFO(hidden_size % (kSFVecSize * kSFKAlignment) == 0, "hidden_size must be divisible by %d.",
            kSFVecSize * kSFKAlignment);
        TLLM_CHECK_WITH_INFO(input_sf != nullptr, "input_sf is required for NVFP4.");
        TLLM_CHECK_WITH_INFO(permuted_sf != nullptr, "permuted_sf is required for NVFP4.");
    }
#endif

    auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
    static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
    int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
    int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
    int32_t const threads = kThreadsPerBlock;

    cudaLaunchConfig_t config;
    config.gridDim = blocks;
    config.blockDim = threads;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    cudaLaunchKernelEx(&config, kernel, input, permuted_output, input_sf, permuted_sf, tile_idx_to_mn_limit,
        permuted_idx_to_expanded_idx, num_non_exiting_tiles, hidden_size, top_k, tile_size);
}

#define INSTANTIATE_MOE_PERMUTE(InputType, SFType)                                                                    \
    template void moePermute(InputType const* input, InputType* permuted_output,                                      \
        SFType const* input_sf, SFType* permuted_sf, int32_t const* tile_idx_to_mn_limit,                             \
        int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles,                            \
        int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k,                        \
        int32_t const tile_size, cudaStream_t stream)

INSTANTIATE_MOE_PERMUTE(half, uint8_t);
#ifdef ENABLE_BF16
INSTANTIATE_MOE_PERMUTE(__nv_bfloat16, uint8_t);
#endif
#ifdef ENABLE_FP8
INSTANTIATE_MOE_PERMUTE(__nv_fp8_e4m3, uint8_t);
#endif
#ifdef ENABLE_FP4
INSTANTIATE_MOE_PERMUTE(__nv_fp4_e2m1, uint8_t);
#endif
#undef INSTANTIATE_MOE_PERMUTE
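
// Reduces the permuted expert outputs back to token order. For each token t:
//   output[t, :] = sum_k topk_scales[t, k] * permuted_input[expanded_idx_to_permuted_idx[t * top_k + k], :]
// Accumulation is done in float and cast back to InputType; entries with a negative permuted
// index (dropped slots) are skipped.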
template <typename InputType, typename TopKScaleType, int32_t kThreadsPerBlock>
__global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* output,
    int32_t const* expanded_idx_to_permuted_idx, TopKScaleType const* topk_scales, int32_t const hidden_size,
    int32_t const top_k)
{
    using AccumType = float;
    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
    // Need int64_t to prevent overflow when computing pointer offsets.
    int64_t const kCopyPerToken = hidden_size / kElemPerCopy;

    InputType rmem[kElemPerCopy];
    AccumType rmemAccum[kElemPerCopy];

    int32_t const token_idx = blockIdx.x;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize();
#endif

    auto* dst_ptr = reinterpret_cast<ElemCopyType*>(output) + token_idx * kCopyPerToken;
    for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
    {
#pragma unroll
        for (int32_t j = 0; j < kElemPerCopy; j++)
        {
            rmemAccum[j] = 0;
        }

        for (int32_t k = 0; k < top_k; k++)
        {
            int32_t const permuted_idx = expanded_idx_to_permuted_idx[token_idx * top_k + k];
            if (permuted_idx < 0)
            {
                continue;
            }
            auto const* src_ptr = reinterpret_cast<ElemCopyType const*>(permuted_input) + permuted_idx * kCopyPerToken;
            *reinterpret_cast<ElemCopyType*>(rmem) = src_ptr[i];
            TopKScaleType const scale = topk_scales[token_idx * top_k + k];
#pragma unroll
            for (int32_t j = 0; j < kElemPerCopy; j++)
            {
                rmemAccum[j] += static_cast<AccumType>(rmem[j]) * static_cast<AccumType>(scale);
            }
        }

#pragma unroll
        for (int32_t j = 0; j < kElemPerCopy; j++)
        {
            rmem[j] = static_cast<InputType>(rmemAccum[j]);
        }
        dst_ptr[i] = *reinterpret_cast<ElemCopyType*>(rmem);
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

template <typename InputType, typename TopKScaleType>
void moeUnpermute(InputType const* permuted_input, InputType* output, int32_t const* expanded_idx_to_permuted_idx,
    TopKScaleType const* topk_scales, int32_t const num_tokens, int32_t const hidden_size, int32_t const top_k,
    cudaStream_t stream)
{
    int32_t constexpr kThreadsPerBlock = 256;
    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
    TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy);

    int32_t const blocks = num_tokens;
    int32_t const threads = kThreadsPerBlock;
    auto kernel = &moeUnpermuteKernel<InputType, TopKScaleType, kThreadsPerBlock>;

    cudaLaunchConfig_t config;
    config.gridDim = blocks;
    config.blockDim = threads;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    cudaLaunchKernelEx(
        &config, kernel, permuted_input, output, expanded_idx_to_permuted_idx, topk_scales, hidden_size, top_k);
}

#define INSTANTIATE_MOE_UNPERMUTE(InputType, TopKScaleType)                                                            \
    template void moeUnpermute(InputType const* permuted_input, InputType* output,                                     \
        int32_t const* expanded_idx_to_permuted_idx, TopKScaleType const* topk_scales, int32_t const num_tokens,       \
        int32_t const hidden_size, int32_t const top_k, cudaStream_t stream)

INSTANTIATE_MOE_UNPERMUTE(half, float);
INSTANTIATE_MOE_UNPERMUTE(half, half);
#ifdef ENABLE_BF16
INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, float);
INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, __nv_bfloat16);
#endif
#undef INSTANTIATE_MOE_UNPERMUTE
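
// Zero-initializes the unpermuted output row of every token that has at least one valid
// (non-dropped) expert slot. Each row is cleared exactly once: only the permuted slot that
// corresponds to the token's first valid top-k entry performs the write, so the row is ready
// for subsequent accumulation without a full-buffer memset.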
template <typename InputType, int32_t kThreadsPerBlock>
__global__ void moeOutputMemsetKernel(InputType* input, int32_t const* tile_idx_to_mn_limit,
    int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx,
    int32_t const* num_non_exiting_tiles, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size)
{
    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
    int64_t const kCopyPerToken = hidden_size / kElemPerCopy;

    InputType rmem[kElemPerCopy];
#pragma unroll
    for (int32_t j = 0; j < kElemPerCopy; j++)
    {
        rmem[j] = 0;
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize();
#endif

    int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
    for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x)
    {
        int32_t const tile_idx = permuted_idx / tile_size;
        if (permuted_idx >= tile_idx_to_mn_limit[tile_idx])
        {
            continue;
        }
        int32_t const expanded_idx = permuted_idx_to_expanded_idx[permuted_idx];
        int32_t const token_idx = expanded_idx / top_k;
        int32_t const topk_idx = expanded_idx % top_k;

        // Only the first valid top-k slot of a token clears that token's output row.
        bool is_first_in_topk = true;
        for (int32_t k = 0; k < topk_idx; k++)
        {
            if (expanded_idx_to_permuted_idx[token_idx * top_k + k] >= 0)
            {
                is_first_in_topk = false;
                break;
            }
        }
        if (!is_first_in_topk)
        {
            continue;
        }

        auto* dst_ptr = reinterpret_cast<ElemCopyType*>(input) + token_idx * kCopyPerToken;
        for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
        {
            dst_ptr[i] = *reinterpret_cast<ElemCopyType*>(rmem);
        }
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

template <typename InputType>
void moeOutputMemset(InputType* input, int32_t const* tile_idx_to_mn_limit,
    int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx,
    int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, int32_t const hidden_size,
    int32_t const top_k, int32_t const tile_size, cudaStream_t stream)
{
    int32_t constexpr kThreadsPerBlock = 256;
    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
    TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy);

    auto kernel = &moeOutputMemsetKernel<InputType, kThreadsPerBlock>;
    static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
    int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
    int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
    int32_t const threads = kThreadsPerBlock;

    cudaLaunchConfig_t config;
    config.gridDim = blocks;
    config.blockDim = threads;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    cudaLaunchKernelEx(&config, kernel, input, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx,
        permuted_idx_to_expanded_idx, num_non_exiting_tiles, hidden_size, top_k, tile_size);
}

#define INSTANTIATE_MOE_OUTPUT_MEMSET(InputType)                                                                       \
    template void moeOutputMemset(InputType * input, int32_t const* tile_idx_to_mn_limit,                              \
        int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx,                      \
        int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, int32_t const hidden_size,        \
        int32_t const top_k, int32_t const tile_size, cudaStream_t stream)

INSTANTIATE_MOE_OUTPUT_MEMSET(half);
#ifdef ENABLE_BF16
INSTANTIATE_MOE_OUTPUT_MEMSET(__nv_bfloat16);
#endif
#undef INSTANTIATE_MOE_OUTPUT_MEMSET
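
// Applies the activation functor to each valid permuted row. For gated (GLU-style) activations
// the input row stores two interm_size-wide halves; element j of the second half (the gate) is
// combined with element j of the first half by the functor. When OutputType is __nv_fp4_e2m1 the
// activated row is additionally quantized to NVFP4, scaled by global_sf, with one scale factor
// per kSFVecSize elements written to output_sf in the swizzled layout.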
template <typename InputType, typename OutputType, typename SFType, int32_t kSFVecSize, typename ActFn,
    int32_t kThreadsPerBlock>
__global__ void moeActivationKernel(InputType const* input, OutputType* output, float const* global_sf,
    SFType* output_sf, int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles,
    int32_t const interm_size, int32_t const tile_size)
{
    using ComputeType = float;
#ifdef ENABLE_FP4
    using ElemOutputCopyType = std::conditional_t<std::is_same_v<OutputType, __nv_fp4_e2m1>, uint32_t, ElemCopyType>;
#else
    using ElemOutputCopyType = ElemCopyType;
#endif
    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
    // Need int64_t to prevent overflow when computing pointer offsets.
    int64_t const kCopyPerToken = interm_size / kElemPerCopy;

    InputType rmem[kElemPerCopy];
    InputType rmemGate[kElemPerCopy];
    ActFn act{};

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize();
#endif

    float global_sf_val = global_sf == nullptr ? 1.0f : global_sf[0];
    int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
    for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x)
    {
        int32_t const tile_idx = permuted_idx / tile_size;
        if (permuted_idx >= tile_idx_to_mn_limit[tile_idx])
        {
            continue;
        }
        auto const* src_ptr
            = reinterpret_cast<ElemCopyType const*>(input) + permuted_idx * kCopyPerToken * (ActFn::IS_GLU ? 2 : 1);
        auto* dst_ptr = reinterpret_cast<ElemOutputCopyType*>(output) + permuted_idx * kCopyPerToken;
        for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
        {
            *reinterpret_cast<ElemCopyType*>(rmem) = src_ptr[i];
            if constexpr (ActFn::IS_GLU)
            {
                *reinterpret_cast<ElemCopyType*>(rmemGate) = src_ptr[i + kCopyPerToken];
#pragma unroll
                for (int32_t j = 0; j < kElemPerCopy; j++)
                {
                    rmem[j] = static_cast<InputType>(
                        act(static_cast<ComputeType>(rmemGate[j]), static_cast<ComputeType>(rmem[j])));
                }
            }
            else
            {
#pragma unroll
                for (int32_t j = 0; j < kElemPerCopy; j++)
                {
                    rmem[j] = static_cast<InputType>(act(static_cast<ComputeType>(rmem[j])));
                }
            }

#ifdef ENABLE_FP4
            if constexpr (std::is_same_v<OutputType, __nv_fp4_e2m1>)
            {
                auto* sf_dst_ptr = cvt_quant_get_sf_out_offset(/* batchIdx= */ std::nullopt, permuted_idx, i,
                    /* numRows= */ std::nullopt, interm_size / kSFVecSize, output_sf, QuantizationSFLayout::SWIZZLED);
                dst_ptr[i] = cvt_warp_fp16_to_fp4(
                    *reinterpret_cast<PackedVec<InputType>*>(rmem), global_sf_val, sf_dst_ptr);
            }
            else
#endif
            {
                dst_ptr[i] = *reinterpret_cast<ElemOutputCopyType*>(rmem);
            }
        }
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

template <typename InputType, typename OutputType, typename SFType>
void moeActivation(InputType const* input, OutputType* output, float const* global_sf, SFType* output_sf,
    int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles,
    cutlass_kernels::ActivationParams activation_params, int32_t const max_num_permuted_tokens,
    int32_t const interm_size, int32_t const tile_size, cudaStream_t stream)
{
    int32_t constexpr kThreadsPerBlock = 256;
    int32_t constexpr kSFVecSize = 16;
    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
    TLLM_CHECK_WITH_INFO(interm_size % kElemPerCopy == 0, "interm_size must be divisible by %d.", kElemPerCopy);
#ifdef ENABLE_FP4
    if constexpr (std::is_same_v<OutputType, __nv_fp4_e2m1>)
    {
        int32_t constexpr kSFMAlignment = 128;
        int32_t constexpr kSFKAlignment = 4;
        TLLM_CHECK_WITH_INFO(max_num_permuted_tokens % kSFMAlignment == 0,
            "max_num_permuted_tokens must be divisible by %d.", kSFMAlignment);
        TLLM_CHECK_WITH_INFO(interm_size % (kSFVecSize * kSFKAlignment) == 0, "interm_size must be divisible by %d.",
            kSFVecSize * kSFKAlignment);
        TLLM_CHECK_WITH_INFO(global_sf != nullptr, "global_sf is required for NVFP4.");
        TLLM_CHECK_WITH_INFO(output_sf != nullptr, "output_sf is required for NVFP4.");
    }
#endif
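
    // Select the kernel instantiation that matches the requested activation. The functor type
    // names below are assumed stand-ins for the GLU-aware activation functors provided via
    // moe_kernels.cuh; each functor exposes a static IS_GLU flag and the call operator expected
    // by moeActivationKernel.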
    auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,
                              float const* global_sf, SFType* output_sf, int32_t const* tile_idx_to_mn_limit,
                              int32_t const* num_non_exiting_tiles, int32_t const interm_size, int32_t const tile_size)
    {
        switch (activation_type)
        {
        case ActivationType::Identity:
            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, IdentityActFn<float>,
                kThreadsPerBlock>;
        case ActivationType::Gelu:
            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, GeluActFn<float>,
                kThreadsPerBlock>;
        case ActivationType::Geglu:
            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, GegluActFn<float>,
                kThreadsPerBlock>;
        case ActivationType::Relu:
            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, ReluActFn<float>,
                kThreadsPerBlock>;
        case ActivationType::Silu:
            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, SiluActFn<float>,
                kThreadsPerBlock>;
        case ActivationType::Swiglu:
            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, SwigluActFn<float>,
                kThreadsPerBlock>;
        case ActivationType::SwigluBias:
            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, SwigluBiasActFn, kThreadsPerBlock>;
        case ActivationType::Relu2:
            // Unsupported activation type
            break;
        }
        TLLM_CHECK_WITH_INFO(false, "Unsupported activation type: %d", int(activation_type));
        return nullptr;
    };

    auto kernel = get_act_kernel(activation_params.activation_type);
    static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
    int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
    int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
    int32_t const threads = kThreadsPerBlock;

    cudaLaunchConfig_t config;
    config.gridDim = blocks;
    config.blockDim = threads;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    cudaLaunchKernelEx(&config, kernel, input, output, global_sf, output_sf, tile_idx_to_mn_limit,
        num_non_exiting_tiles, interm_size, tile_size);
}

#define INSTANTIATE_MOE_ACTIVATION(InputType, OutputType, SFType)                                                      \
    template void moeActivation(InputType const* input, OutputType* output,                                            \
        float const* global_sf, SFType* output_sf, int32_t const* tile_idx_to_mn_limit,                                \
        int32_t const* num_non_exiting_tiles, cutlass_kernels::ActivationParams activation_params,                     \
        int32_t const max_num_permuted_tokens, int32_t const interm_size, int32_t const tile_size,                     \
        cudaStream_t stream)

INSTANTIATE_MOE_ACTIVATION(half, half, uint8_t);
#ifdef ENABLE_BF16
INSTANTIATE_MOE_ACTIVATION(__nv_bfloat16, __nv_bfloat16, uint8_t);
#endif
#ifdef ENABLE_FP4
INSTANTIATE_MOE_ACTIVATION(half, __nv_fp4_e2m1, uint8_t);
#ifdef ENABLE_BF16
INSTANTIATE_MOE_ACTIVATION(__nv_bfloat16, __nv_fp4_e2m1, uint8_t);
#endif
#endif
#undef INSTANTIATE_MOE_ACTIVATION

} // namespace kernels::cute_dsl

TRTLLM_NAMESPACE_END