mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[TRTLLM-9286][feat] Integration of CuteDSL NVFP4 grouped GEMM (#8880)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
This commit is contained in:
parent
c789000a62
commit
7c4777a571
@ -189,6 +189,7 @@ set(TRTLLM_LINK_LIBS
|
||||
fb_gemm_src
|
||||
gemm_swiglu_sm90_src
|
||||
cutlass_src
|
||||
cute_dsl_src
|
||||
layers_src
|
||||
runtime_src
|
||||
testing_src
|
||||
|
||||
@ -22,6 +22,8 @@ file(GLOB_RECURSE SRC_CU *.cu)
|
||||
# selectiveScan trtllmGenKernels folder
|
||||
list(FILTER SRC_CPP EXCLUDE REGEX "cutlass_kernels/.*")
|
||||
list(FILTER SRC_CU EXCLUDE REGEX "cutlass_kernels/.*")
|
||||
list(FILTER SRC_CPP EXCLUDE REGEX "cuteDslKernels/.*")
|
||||
list(FILTER SRC_CU EXCLUDE REGEX "cuteDslKernels/.*")
|
||||
list(FILTER SRC_CPP EXCLUDE REGEX "flashMLA/.*")
|
||||
list(FILTER SRC_CU EXCLUDE REGEX "flashMLA/.*")
|
||||
list(FILTER SRC_CPP EXCLUDE REGEX "contextFusedMultiHeadAttention/.*")
|
||||
@ -75,6 +77,7 @@ target_include_directories(
|
||||
add_cuda_architectures(kernels_src 89)
|
||||
|
||||
add_subdirectory(cutlass_kernels)
|
||||
add_subdirectory(cuteDslKernels)
|
||||
add_subdirectory(flashMLA)
|
||||
add_subdirectory(contextFusedMultiHeadAttention)
|
||||
add_subdirectory(decoderMaskedMultiheadAttention)
|
||||
|
||||
23
cpp/tensorrt_llm/kernels/cuteDslKernels/CMakeLists.txt
Normal file
23
cpp/tensorrt_llm/kernels/cuteDslKernels/CMakeLists.txt
Normal file
@ -0,0 +1,23 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
|
||||
# All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain a copy of
|
||||
# the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations under
|
||||
# the License.
|
||||
#
|
||||
|
||||
file(GLOB_RECURSE SRC_CPP *.cpp)
|
||||
file(GLOB_RECURSE SRC_CU *.cu)
|
||||
|
||||
add_library(cute_dsl_src OBJECT ${SRC_CPP} ${SRC_CU})
|
||||
set_property(TARGET cute_dsl_src PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET cute_dsl_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
|
||||
439
cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu
Normal file
439
cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu
Normal file
@ -0,0 +1,439 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/kernels/cuteDslKernels/moeUtils.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh"
|
||||
#include "tensorrt_llm/kernels/quantization.cuh"
|
||||
#include "tensorrt_llm/kernels/quantization.h"
|
||||
|
||||
#include <cuda_fp4.h>
|
||||
#include <cute/numeric/numeric_types.hpp>
|
||||
|
||||
namespace tensorrt_llm::kernels::cute_dsl
|
||||
{
|
||||
namespace
|
||||
{
|
||||
using ElemCopyType = uint4;
|
||||
using SFCopyType = uint32_t;
|
||||
|
||||
template <typename T>
|
||||
auto constexpr bitsPerElem()
|
||||
{
|
||||
#ifdef ENABLE_FP4
|
||||
return std::is_same_v<T, __nv_fp4_e2m1> ? 4 : cute::sizeof_bits_v<T>;
|
||||
#else
|
||||
return cute::sizeof_bits_v<T>;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto constexpr elemPerCopy()
|
||||
{
|
||||
return bitsPerElem<ElemCopyType>() / bitsPerElem<T>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto constexpr sfElemPerCopy()
|
||||
{
|
||||
return bitsPerElem<SFCopyType>() / bitsPerElem<T>();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename InputType, typename SFType, int32_t kSFVecSize, int32_t kThreadsPerBlock>
|
||||
__global__ void moePermuteKernel(InputType const* input, InputType* permuted_output, SFType const* input_sf,
|
||||
SFType* permuted_sf, int32_t const* tile_idx_to_mn_limit, int32_t const* permuted_idx_to_expanded_idx,
|
||||
int32_t const* num_non_exiting_tiles, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size)
|
||||
{
|
||||
int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
|
||||
int32_t constexpr kSFElemPerCopy = sfElemPerCopy<SFType>();
|
||||
// Need int64_t to prevent overflow when computing pointer offsets.
|
||||
int64_t const kCopyPerToken = hidden_size / kElemPerCopy;
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
|
||||
int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
|
||||
for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x)
|
||||
{
|
||||
int32_t const tile_idx = permuted_idx / tile_size;
|
||||
if (permuted_idx >= tile_idx_to_mn_limit[tile_idx])
|
||||
{
|
||||
continue;
|
||||
}
|
||||
int32_t const expanded_idx = permuted_idx_to_expanded_idx[permuted_idx];
|
||||
int32_t const token_idx = expanded_idx / top_k;
|
||||
|
||||
auto const* src_ptr = reinterpret_cast<ElemCopyType const*>(input) + token_idx * kCopyPerToken;
|
||||
auto* dst_ptr = reinterpret_cast<ElemCopyType*>(permuted_output) + permuted_idx * kCopyPerToken;
|
||||
for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
|
||||
{
|
||||
dst_ptr[i] = src_ptr[i];
|
||||
}
|
||||
|
||||
#ifdef ENABLE_FP4
|
||||
if constexpr (std::is_same_v<InputType, __nv_fp4_e2m1>)
|
||||
{
|
||||
int32_t const sf_hidden_size = hidden_size / kSFVecSize;
|
||||
int64_t const kSFCopyPerToken = sf_hidden_size / kSFElemPerCopy;
|
||||
auto const* sf_src_ptr = reinterpret_cast<SFCopyType const*>(input_sf);
|
||||
auto* sf_dst_ptr = reinterpret_cast<SFCopyType*>(permuted_sf);
|
||||
for (int32_t i = threadIdx.x; i < kSFCopyPerToken; i += kThreadsPerBlock)
|
||||
{
|
||||
// input_sf is not swizzled, while permuted_sf is swizzled.
|
||||
int64_t const src_offset = token_idx * kSFCopyPerToken + i;
|
||||
int64_t const dst_offset = get_sf_out_offset_128x4(/* batchIdx= */ std::nullopt, permuted_idx,
|
||||
i * kSFElemPerCopy, /* numRows= */ std::nullopt, sf_hidden_size)
|
||||
/ kSFElemPerCopy;
|
||||
|
||||
sf_dst_ptr[dst_offset] = sf_src_ptr[src_offset];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename InputType, typename SFType>
|
||||
void moePermute(InputType const* input, InputType* permuted_output, SFType const* input_sf, SFType* permuted_sf,
|
||||
int32_t const* tile_idx_to_mn_limit, int32_t const* permuted_idx_to_expanded_idx,
|
||||
int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, int32_t const hidden_size,
|
||||
int32_t const top_k, int32_t const tile_size, cudaStream_t stream)
|
||||
{
|
||||
int32_t constexpr kThreadsPerBlock = 256;
|
||||
int32_t constexpr kSFVecSize = 16;
|
||||
int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
|
||||
TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy);
|
||||
|
||||
#ifdef ENABLE_FP4
|
||||
if constexpr (std::is_same_v<InputType, __nv_fp4_e2m1>)
|
||||
{
|
||||
int32_t constexpr kSFMAlignment = 128;
|
||||
int32_t constexpr kSFKAlignment = 4;
|
||||
int32_t constexpr kSFElemPerCopy = sfElemPerCopy<SFType>();
|
||||
static_assert(kSFElemPerCopy == kSFKAlignment);
|
||||
TLLM_CHECK_WITH_INFO(max_num_permuted_tokens % kSFMAlignment == 0,
|
||||
"max_num_permuted_tokens must be divisible by %d.", kSFMAlignment);
|
||||
TLLM_CHECK_WITH_INFO(hidden_size % (kSFVecSize * kSFKAlignment) == 0, "hidden_size must be divisible by %d.",
|
||||
kSFVecSize * kSFKAlignment);
|
||||
TLLM_CHECK_WITH_INFO(input_sf != nullptr, "input_sf is required for NVFP4.");
|
||||
TLLM_CHECK_WITH_INFO(permuted_sf != nullptr, "permuted_sf is required for NVFP4.");
|
||||
}
|
||||
#endif
|
||||
|
||||
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
|
||||
int32_t const threads = kThreadsPerBlock;
|
||||
|
||||
auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
|
||||
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = blocks;
|
||||
config.blockDim = threads;
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config, kernel, input, permuted_output, input_sf, permuted_sf, tile_idx_to_mn_limit,
|
||||
permuted_idx_to_expanded_idx, num_non_exiting_tiles, hidden_size, top_k, tile_size);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MOE_PERMUTE(InputType, SFType) \
|
||||
template void moePermute<InputType, SFType>(InputType const* input, InputType* permuted_output, \
|
||||
SFType const* input_sf, SFType* permuted_sf, int32_t const* tile_idx_to_mn_limit, \
|
||||
int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles, \
|
||||
int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k, \
|
||||
int32_t const tile_size, cudaStream_t stream)
|
||||
|
||||
INSTANTIATE_MOE_PERMUTE(half, uint8_t);
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_MOE_PERMUTE(__nv_bfloat16, uint8_t);
|
||||
#endif
|
||||
#ifdef ENABLE_FP8
|
||||
INSTANTIATE_MOE_PERMUTE(__nv_fp8_e4m3, uint8_t);
|
||||
#endif
|
||||
#ifdef ENABLE_FP4
|
||||
INSTANTIATE_MOE_PERMUTE(__nv_fp4_e2m1, uint8_t);
|
||||
#endif
|
||||
#undef INSTANTIATE_MOE_PERMUTE
|
||||
|
||||
template <typename InputType, typename TopKScaleType, int32_t kThreadsPerBlock>
|
||||
__global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* output,
|
||||
int32_t const* expanded_idx_to_permuted_idx, TopKScaleType const* topk_scales, int32_t const hidden_size,
|
||||
int32_t const top_k)
|
||||
{
|
||||
using AccumType = float;
|
||||
int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
|
||||
// Need int64_t to prevent overflow when computing pointer offsets.
|
||||
int64_t const kCopyPerToken = hidden_size / kElemPerCopy;
|
||||
InputType rmem[kElemPerCopy];
|
||||
AccumType rmemAccum[kElemPerCopy];
|
||||
|
||||
int32_t const token_idx = blockIdx.x;
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
|
||||
auto* dst_ptr = reinterpret_cast<ElemCopyType*>(output) + token_idx * kCopyPerToken;
|
||||
for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int32_t j = 0; j < kElemPerCopy; j++)
|
||||
{
|
||||
rmemAccum[j] = 0;
|
||||
}
|
||||
for (int32_t k = 0; k < top_k; k++)
|
||||
{
|
||||
int32_t const permuted_idx = expanded_idx_to_permuted_idx[token_idx * top_k + k];
|
||||
if (permuted_idx < 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
auto const* src_ptr = reinterpret_cast<ElemCopyType const*>(permuted_input) + permuted_idx * kCopyPerToken;
|
||||
*reinterpret_cast<ElemCopyType*>(rmem) = src_ptr[i];
|
||||
TopKScaleType const scale = topk_scales[token_idx * top_k + k];
|
||||
|
||||
#pragma unroll
|
||||
for (int32_t j = 0; j < kElemPerCopy; j++)
|
||||
{
|
||||
rmemAccum[j] += static_cast<AccumType>(rmem[j]) * static_cast<AccumType>(scale);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int32_t j = 0; j < kElemPerCopy; j++)
|
||||
{
|
||||
rmem[j] = static_cast<InputType>(rmemAccum[j]);
|
||||
}
|
||||
dst_ptr[i] = *reinterpret_cast<ElemCopyType*>(rmem);
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename InputType, typename TopKScaleType>
|
||||
void moeUnpermute(InputType const* permuted_input, InputType* output, int32_t const* expanded_idx_to_permuted_idx,
|
||||
TopKScaleType const* topk_scales, int32_t const num_tokens, int32_t const hidden_size, int32_t const top_k,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
int32_t constexpr kThreadsPerBlock = 256;
|
||||
int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
|
||||
TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy);
|
||||
|
||||
int32_t const blocks = num_tokens;
|
||||
int32_t const threads = kThreadsPerBlock;
|
||||
|
||||
auto kernel = &moeUnpermuteKernel<InputType, TopKScaleType, kThreadsPerBlock>;
|
||||
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = blocks;
|
||||
config.blockDim = threads;
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(
|
||||
&config, kernel, permuted_input, output, expanded_idx_to_permuted_idx, topk_scales, hidden_size, top_k);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MOE_UNPERMUTE(InputType, TopKScaleType) \
|
||||
template void moeUnpermute<InputType>(InputType const* permuted_input, InputType* output, \
|
||||
int32_t const* expanded_idx_to_permuted_idx, TopKScaleType const* topk_scales, int32_t const num_tokens, \
|
||||
int32_t const hidden_size, int32_t const top_k, cudaStream_t stream)
|
||||
|
||||
INSTANTIATE_MOE_UNPERMUTE(half, float);
|
||||
INSTANTIATE_MOE_UNPERMUTE(half, half);
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, float);
|
||||
INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, __nv_bfloat16);
|
||||
#endif
|
||||
#undef INSTANTIATE_MOE_UNPERMUTE
|
||||
|
||||
template <typename InputType, typename OutputType, typename SFType, int32_t kSFVecSize, typename ActFn,
|
||||
int32_t kThreadsPerBlock>
|
||||
__global__ void moeActivationKernel(InputType const* input, OutputType* output, float const* global_sf,
|
||||
SFType* output_sf, int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles,
|
||||
int32_t const interm_size, int32_t const tile_size)
|
||||
{
|
||||
using ComputeType = float;
|
||||
#ifdef ENABLE_FP4
|
||||
using ElemOutputCopyType = std::conditional_t<std::is_same_v<OutputType, __nv_fp4_e2m1>, uint32_t, ElemCopyType>;
|
||||
#else
|
||||
using ElemOutputCopyType = ElemCopyType;
|
||||
#endif
|
||||
int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
|
||||
// Need int64_t to prevent overflow when computing pointer offsets.
|
||||
int64_t const kCopyPerToken = interm_size / kElemPerCopy;
|
||||
InputType rmem[kElemPerCopy];
|
||||
InputType rmemGate[kElemPerCopy];
|
||||
ActFn act{};
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
|
||||
float global_sf_val = global_sf == nullptr ? 1.0f : global_sf[0];
|
||||
|
||||
int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
|
||||
for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x)
|
||||
{
|
||||
int32_t const tile_idx = permuted_idx / tile_size;
|
||||
if (permuted_idx >= tile_idx_to_mn_limit[tile_idx])
|
||||
{
|
||||
continue;
|
||||
}
|
||||
auto const* src_ptr
|
||||
= reinterpret_cast<ElemCopyType const*>(input) + permuted_idx * kCopyPerToken * (ActFn::IS_GLU ? 2 : 1);
|
||||
auto* dst_ptr = reinterpret_cast<ElemOutputCopyType*>(output) + permuted_idx * kCopyPerToken;
|
||||
for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
|
||||
{
|
||||
*reinterpret_cast<ElemCopyType*>(rmem) = src_ptr[i];
|
||||
if constexpr (ActFn::IS_GLU)
|
||||
{
|
||||
*reinterpret_cast<ElemCopyType*>(rmemGate) = src_ptr[i + kCopyPerToken];
|
||||
#pragma unroll
|
||||
for (int32_t j = 0; j < kElemPerCopy; j++)
|
||||
{
|
||||
rmem[j] = static_cast<InputType>(
|
||||
act(static_cast<ComputeType>(rmemGate[j]), static_cast<ComputeType>(rmem[j])));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
#pragma unroll
|
||||
for (int32_t j = 0; j < kElemPerCopy; j++)
|
||||
{
|
||||
rmem[j] = static_cast<InputType>(act(static_cast<ComputeType>(rmem[j])));
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef ENABLE_FP4
|
||||
if constexpr (std::is_same_v<OutputType, __nv_fp4_e2m1>)
|
||||
{
|
||||
auto* sf_dst_ptr = cvt_quant_get_sf_out_offset<SFType, kSFVecSize / kElemPerCopy>(
|
||||
/* batchIdx= */ std::nullopt, permuted_idx, i, /*numRows=*/std::nullopt, interm_size / kSFVecSize,
|
||||
output_sf, QuantizationSFLayout::SWIZZLED);
|
||||
dst_ptr[i] = cvt_warp_fp16_to_fp4<InputType, kSFVecSize, false>(
|
||||
*reinterpret_cast<PackedVec<InputType>*>(rmem), global_sf_val, sf_dst_ptr);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
dst_ptr[i] = *reinterpret_cast<ElemCopyType*>(rmem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename InputType, typename OutputType, typename SFType>
|
||||
void moeActivation(InputType const* input, OutputType* output, float const* global_sf, SFType* output_sf,
|
||||
int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles,
|
||||
cutlass_kernels::ActivationParams activation_params, int32_t const max_num_permuted_tokens,
|
||||
int32_t const interm_size, int32_t const tile_size, cudaStream_t stream)
|
||||
{
|
||||
int32_t constexpr kThreadsPerBlock = 256;
|
||||
int32_t constexpr kSFVecSize = 16;
|
||||
int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
|
||||
TLLM_CHECK_WITH_INFO(interm_size % kElemPerCopy == 0, "interm_size must be divisible by %d.", kElemPerCopy);
|
||||
|
||||
#ifdef ENABLE_FP4
|
||||
if constexpr (std::is_same_v<InputType, __nv_fp4_e2m1>)
|
||||
{
|
||||
int32_t constexpr kSFMAlignment = 128;
|
||||
int32_t constexpr kSFKAlignment = 4;
|
||||
TLLM_CHECK_WITH_INFO(max_num_permuted_tokens % kSFMAlignment == 0,
|
||||
"max_num_permuted_tokens must be divisible by %d.", kSFMAlignment);
|
||||
TLLM_CHECK_WITH_INFO(interm_size % (kSFVecSize * kSFKAlignment) == 0, "interm_size must be divisible by %d.",
|
||||
kSFVecSize * kSFKAlignment);
|
||||
TLLM_CHECK_WITH_INFO(global_sf != nullptr, "global_sf is required for NVFP4.");
|
||||
TLLM_CHECK_WITH_INFO(output_sf != nullptr, "output_sf is required for NVFP4.");
|
||||
}
|
||||
#endif
|
||||
|
||||
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
|
||||
int32_t const threads = kThreadsPerBlock;
|
||||
|
||||
auto kernel_array
|
||||
= std::array{&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>,
|
||||
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::ReLu>, kThreadsPerBlock>,
|
||||
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>,
|
||||
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>,
|
||||
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>,
|
||||
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, cutlass_kernels::SwigluBiasAdaptor,
|
||||
kThreadsPerBlock>,
|
||||
&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
|
||||
cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::Identity>, kThreadsPerBlock>};
|
||||
|
||||
auto kernel = kernel_array[static_cast<int32_t>(activation_params.activation_type)];
|
||||
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = blocks;
|
||||
config.blockDim = threads;
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config, kernel, input, output, global_sf, output_sf, tile_idx_to_mn_limit,
|
||||
num_non_exiting_tiles, interm_size, tile_size);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MOE_ACTIVATION(InputType, OutputType, SFType) \
|
||||
template void moeActivation<InputType, OutputType, SFType>(InputType const* input, OutputType* output, \
|
||||
float const* global_sf, SFType* output_sf, int32_t const* tile_idx_to_mn_limit, \
|
||||
int32_t const* num_non_exiting_tiles, cutlass_kernels::ActivationParams activation_params, \
|
||||
int32_t const max_num_permuted_tokens, int32_t const interm_size, int32_t const tile_size, \
|
||||
cudaStream_t stream)
|
||||
|
||||
INSTANTIATE_MOE_ACTIVATION(half, half, uint8_t);
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_MOE_ACTIVATION(__nv_bfloat16, __nv_bfloat16, uint8_t);
|
||||
#endif
|
||||
#ifdef ENABLE_FP4
|
||||
INSTANTIATE_MOE_ACTIVATION(half, __nv_fp4_e2m1, uint8_t);
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_MOE_ACTIVATION(__nv_bfloat16, __nv_fp4_e2m1, uint8_t);
|
||||
#endif
|
||||
#endif
|
||||
#undef INSTANTIATE_MOE_ACTIVATION
|
||||
|
||||
} // namespace tensorrt_llm::kernels::cute_dsl
|
||||
41
cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h
Normal file
41
cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h
Normal file
@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h"
|
||||
#include <cstdint>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace tensorrt_llm::kernels::cute_dsl
|
||||
{
|
||||
template <typename InputType, typename SFType>
|
||||
void moePermute(InputType const* input, InputType* permuted_output, SFType const* input_sf, SFType* permuted_sf,
|
||||
int32_t const* tile_idx_to_mn_limit, int32_t const* permuted_idx_to_expanded_idx,
|
||||
int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, int32_t const hidden_size,
|
||||
int32_t const top_k, int32_t const tile_size, cudaStream_t stream);
|
||||
|
||||
template <typename InputType, typename TopKScaleType>
|
||||
void moeUnpermute(InputType const* permuted_input, InputType* output, int32_t const* expanded_idx_to_permuted_idx,
|
||||
TopKScaleType const* topk_scales, int32_t const num_tokens, int32_t const hidden_size, int32_t const top_k,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename InputType, typename OutputType, typename SFType>
|
||||
void moeActivation(InputType const* input, OutputType* output, float const* global_sf, SFType* output_sf,
|
||||
int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles,
|
||||
cutlass_kernels::ActivationParams activation_params, int32_t const max_num_permuted_tokens,
|
||||
int32_t const interm_size, int32_t const tile_size, cudaStream_t stream);
|
||||
|
||||
} // namespace tensorrt_llm::kernels::cute_dsl
|
||||
@ -37,7 +37,6 @@
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
@ -52,6 +51,7 @@
|
||||
#include "tensorrt_llm/common/dataType.h"
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh"
|
||||
#include "tensorrt_llm/kernels/moe_utils.cuh"
|
||||
#include "tensorrt_llm/kernels/preQuantScaleKernel.h"
|
||||
#include "tensorrt_llm/kernels/quantization.cuh"
|
||||
@ -1344,7 +1344,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input)
|
||||
return converter(input);
|
||||
}
|
||||
|
||||
// Duplicated and permutes rows for MoE. In addition, reverse the permutation map to help with finalizing routing.
|
||||
// Duplicated and permutes rows for MoE.
|
||||
|
||||
// "expanded_x_row" simply means that the number of values is num_rows x k. It is "expanded" since we will have to
|
||||
// duplicate some rows in the input matrix to match the dimensions. Duplicates will always get routed to separate
|
||||
@ -1937,56 +1937,6 @@ INSTANTIATE_FINALIZE_MOE_ROUTING(float, float, float);
|
||||
INSTANTIATE_FINALIZE_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16);
|
||||
#endif
|
||||
|
||||
// ============================== Activation Adaptors =================================
|
||||
template <template <class> class ActFn>
|
||||
struct IdentityAdaptor
|
||||
{
|
||||
constexpr static bool IS_GLU = false;
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
float limit = std::numeric_limits<float>::infinity();
|
||||
|
||||
template <class T>
|
||||
__device__ T operator()(T const& x) const
|
||||
{
|
||||
ActFn<T> fn{};
|
||||
return fn(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <template <class> class ActFn>
|
||||
struct GLUAdaptor
|
||||
{
|
||||
constexpr static bool IS_GLU = true;
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
float limit = std::numeric_limits<float>::infinity();
|
||||
|
||||
template <class T>
|
||||
__device__ T operator()(T const& gate, T const& linear) const
|
||||
{
|
||||
ActFn<T> fn{};
|
||||
return fn(gate) * linear;
|
||||
}
|
||||
};
|
||||
|
||||
struct SwigluBiasAdaptor
|
||||
{
|
||||
constexpr static bool IS_GLU = true;
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
float limit = std::numeric_limits<float>::infinity();
|
||||
|
||||
template <class T>
|
||||
__device__ T operator()(T const& gate, T const& linear) const
|
||||
{
|
||||
cutlass::epilogue::thread::Sigmoid<T> fn{};
|
||||
T linear_clamped = cutlass::maximum<T>{}(cutlass::minimum<T>{}(linear, limit), -limit);
|
||||
T gate_clamped = cutlass::minimum<T>{}(gate, limit);
|
||||
return gate_clamped * fn(gate_clamped * alpha) * (linear_clamped + beta);
|
||||
}
|
||||
};
|
||||
|
||||
// ============================== Gated Activation =================================
|
||||
constexpr static int ACTIVATION_THREADS_PER_BLOCK = 256;
|
||||
|
||||
|
||||
@ -0,0 +1,75 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <limits>
|
||||
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
|
||||
namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
{
|
||||
// ============================== Activation Adaptors =================================
|
||||
|
||||
template <template <class> class ActFn>
|
||||
struct IdentityAdaptor
|
||||
{
|
||||
constexpr static bool IS_GLU = false;
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
float limit = std::numeric_limits<float>::infinity();
|
||||
|
||||
template <class T>
|
||||
__device__ T operator()(T const& x) const
|
||||
{
|
||||
ActFn<T> fn{};
|
||||
return fn(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <template <class> class ActFn>
|
||||
struct GLUAdaptor
|
||||
{
|
||||
constexpr static bool IS_GLU = true;
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
float limit = std::numeric_limits<float>::infinity();
|
||||
|
||||
template <class T>
|
||||
__device__ T operator()(T const& gate, T const& linear) const
|
||||
{
|
||||
ActFn<T> fn{};
|
||||
return fn(gate) * linear;
|
||||
}
|
||||
};
|
||||
|
||||
struct SwigluBiasAdaptor
|
||||
{
|
||||
constexpr static bool IS_GLU = true;
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
float limit = std::numeric_limits<float>::infinity();
|
||||
|
||||
template <class T>
|
||||
__device__ T operator()(T const& gate, T const& linear) const
|
||||
{
|
||||
cutlass::epilogue::thread::Sigmoid<T> fn{};
|
||||
T linear_clamped = cutlass::maximum<T>{}(cutlass::minimum<T>{}(linear, limit), -limit);
|
||||
T gate_clamped = cutlass::minimum<T>{}(gate, limit);
|
||||
return gate_clamped * fn(gate_clamped * alpha) * (linear_clamped + beta);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::kernels::cutlass_kernels
|
||||
@ -527,6 +527,10 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesCoo
|
||||
{
|
||||
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
|
||||
}
|
||||
if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert)
|
||||
{
|
||||
params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx;
|
||||
}
|
||||
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
|
||||
{
|
||||
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
|
||||
@ -593,7 +597,8 @@ void run(Data& data, void* stream)
|
||||
TLLM_CHECK_WITH_INFO(data.mPtrTopKWeights != nullptr,
|
||||
"When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for DeepSeek routing.");
|
||||
}
|
||||
if (data.mPtrExpandedIdxToPermutedIdx != nullptr || data.mPtrPermutedIdxToTokenIdx != nullptr)
|
||||
if (data.mPtrExpandedIdxToPermutedIdx != nullptr || data.mPtrPermutedIdxToExpandedIdx != nullptr
|
||||
|| data.mPtrPermutedIdxToTokenIdx != nullptr)
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
(data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) && data.mPtrPermutedIdxSize,
|
||||
"If permuted index is required, `mPtrTopKPacked` or `mPtrTopKIds` is also required");
|
||||
|
||||
@ -474,6 +474,10 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx<BaseType>
|
||||
{
|
||||
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
|
||||
}
|
||||
if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert)
|
||||
{
|
||||
params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx;
|
||||
}
|
||||
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
|
||||
{
|
||||
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
|
||||
@ -840,6 +844,10 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesOff
|
||||
{
|
||||
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
|
||||
}
|
||||
if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert)
|
||||
{
|
||||
params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx;
|
||||
}
|
||||
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
|
||||
{
|
||||
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
|
||||
|
||||
@ -55,6 +55,9 @@ struct DataBase
|
||||
int32_t* mPtrExpandedIdxToPermutedIdx{nullptr};
|
||||
// optional: if `nullptr`, it is not filled
|
||||
// dim: [mTileTokensDim * mTopK + (mNumExperts × mTileTokensDim) - mNumExperts]
|
||||
int32_t* mPtrPermutedIdxToExpandedIdx{nullptr};
|
||||
// optional: if `nullptr`, it is not filled
|
||||
// dim: [mTileTokensDim * mTopK + (mNumExperts × mTileTokensDim) - mNumExperts]
|
||||
// Note: this array (mPtrPermutedIdxToTokenIdx) is uninitialized
|
||||
// Any out-of-bounds values are undefined.
|
||||
int32_t* mPtrPermutedIdxToTokenIdx{nullptr};
|
||||
@ -119,6 +122,7 @@ struct KernelParamsBase
|
||||
int32_t* mPtrExpertCounts = nullptr;
|
||||
int32_t* mPtrPermutedIdxSize = nullptr;
|
||||
int32_t* mPtrExpandedIdxToPermutedIdx = nullptr;
|
||||
int32_t* mPtrPermutedIdxToExpandedIdx = nullptr;
|
||||
int32_t* mPtrPermutedIdxToTokenIdx = nullptr;
|
||||
int32_t* mPtrCtaIdxXyToBatchIdx = nullptr;
|
||||
int32_t* mPtrCtaIdxXyToMnLimit = nullptr;
|
||||
@ -144,6 +148,7 @@ struct KernelParamsBase
|
||||
mPtrExpertCounts = data.mPtrExpertCounts;
|
||||
mPtrPermutedIdxSize = data.mPtrPermutedIdxSize;
|
||||
mPtrExpandedIdxToPermutedIdx = data.mPtrExpandedIdxToPermutedIdx;
|
||||
mPtrPermutedIdxToExpandedIdx = data.mPtrPermutedIdxToExpandedIdx;
|
||||
mPtrPermutedIdxToTokenIdx = data.mPtrPermutedIdxToTokenIdx;
|
||||
mPtrCtaIdxXyToBatchIdx = data.mPtrCtaIdxXyToBatchIdx;
|
||||
mPtrCtaIdxXyToMnLimit = data.mPtrCtaIdxXyToMnLimit;
|
||||
|
||||
@ -348,6 +348,11 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
|
||||
{
|
||||
params.mPtrExpandedIdxToPermutedIdx[tokenIdx] = permutedIdx;
|
||||
}
|
||||
// write out `mPtrPermutedIdxToExpandedIdx` if required
|
||||
if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert)
|
||||
{
|
||||
params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = tokenIdx;
|
||||
}
|
||||
// write out `mPtrPermutedIdxToTokenIdx` if required
|
||||
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert && isTokenRouted)
|
||||
{
|
||||
|
||||
@ -276,8 +276,15 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesBlo
|
||||
int const offsetForExpert = expertScanCounts;
|
||||
int const permutedIdx = isLocalExpert ? offsetForExpert + offsetWithinExpert : int32_t{-1};
|
||||
|
||||
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
|
||||
if (isLocalExpert)
|
||||
if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
|
||||
{
|
||||
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
|
||||
}
|
||||
if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert)
|
||||
{
|
||||
params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx;
|
||||
}
|
||||
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
|
||||
{
|
||||
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
|
||||
}
|
||||
|
||||
@ -80,6 +80,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
|
||||
routingData.mPtrExpertCounts = expertCountHistogram;
|
||||
routingData.mPtrPermutedIdxSize = permutedIdxSize;
|
||||
routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
|
||||
routingData.mPtrPermutedIdxToExpandedIdx = permutedIdxToExpandedIdx;
|
||||
routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
|
||||
routingData.mPtrTopKWeights = expertWeights;
|
||||
|
||||
@ -122,6 +123,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
|
||||
routingData.mPtrExpertCounts = expertCountHistogram;
|
||||
routingData.mPtrPermutedIdxSize = permutedIdxSize;
|
||||
routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
|
||||
routingData.mPtrPermutedIdxToExpandedIdx = permutedIdxToExpandedIdx;
|
||||
routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
|
||||
routingData.mPtrTopKWeights = expertWeights;
|
||||
|
||||
@ -177,6 +179,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
|
||||
routingData.mPtrExpertCounts = expertCountHistogram;
|
||||
routingData.mPtrPermutedIdxSize = permutedIdxSize;
|
||||
routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
|
||||
routingData.mPtrPermutedIdxToExpandedIdx = permutedIdxToExpandedIdx;
|
||||
routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
|
||||
routingData.mPtrTopKWeights = expertWeights;
|
||||
routingData.mPtrTopKIds = expertIds;
|
||||
|
||||
@ -44,6 +44,7 @@ add_library(
|
||||
attentionOp.cpp
|
||||
causalConv1dOp.cpp
|
||||
convertSpecDecodingMaskToPackedMaskOp.cpp
|
||||
cuteDslMoeUtilsOp.cpp
|
||||
cutlassScaledMM.cpp
|
||||
cublasScaledMM.cpp
|
||||
cublasFp4ScaledMM.cpp
|
||||
|
||||
444
cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp
Normal file
444
cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp
Normal file
@ -0,0 +1,444 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cuteDslKernels/moeUtils.h"
|
||||
#include "tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h"
|
||||
#include "tensorrt_llm/thop/thUtils.h"
|
||||
|
||||
#include <cuda_fp4.h>
|
||||
|
||||
namespace torch_ext
|
||||
{
|
||||
// Sort
|
||||
using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::RoutingMethodType;
|
||||
|
||||
std::vector<torch::Tensor> moe_topk_sort_impl(torch::optional<torch::Tensor> const& routing_logits,
|
||||
torch::optional<torch::Tensor> const& routing_bias, torch::optional<torch::Tensor> const& token_selected_experts,
|
||||
torch::optional<torch::Tensor> const& token_final_scales, int64_t const num_experts, int64_t const top_k,
|
||||
std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const local_expert_offset,
|
||||
int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
|
||||
RoutingMethodType const routing_method_type)
|
||||
{
|
||||
int64_t const num_tokens
|
||||
= token_selected_experts.has_value() ? token_selected_experts->size(0) : routing_logits->size(0);
|
||||
int64_t const max_num_padded_tokens
|
||||
= tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::getMaxPermutedPaddedCount(
|
||||
num_tokens, top_k, local_num_experts, tile_tokens_dim);
|
||||
int64_t const max_num_ctas = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::getMaxNumCtasInBatchDim(
|
||||
num_tokens, top_k, local_num_experts, tile_tokens_dim);
|
||||
int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
|
||||
auto const routing_bias_dtype = routing_bias.has_value() ? routing_bias->scalar_type() : torch::kBFloat16;
|
||||
|
||||
auto routing_logits_ptr = routing_logits.has_value() ? routing_logits->data_ptr() : nullptr;
|
||||
auto routing_bias_ptr = routing_bias.has_value() ? routing_bias->data_ptr() : nullptr;
|
||||
auto token_selected_experts_ptr
|
||||
= token_selected_experts.has_value() ? token_selected_experts->data_ptr<int32_t>() : nullptr;
|
||||
auto token_final_scales_ptr = token_final_scales.has_value() ? token_final_scales->data_ptr() : nullptr;
|
||||
|
||||
torch::optional<torch::Tensor> new_token_final_scales;
|
||||
if (token_final_scales_ptr == nullptr)
|
||||
{
|
||||
new_token_final_scales
|
||||
= torch::empty({num_tokens, top_k}, torch::dtype(routing_bias_dtype).device(torch::kCUDA));
|
||||
token_final_scales_ptr = new_token_final_scales->data_ptr();
|
||||
}
|
||||
|
||||
auto expert_indexes = torch::empty({num_tokens, top_k}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
auto expert_count_histogram
|
||||
= torch::empty({size_of_expert_count_histogram}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
auto total_num_padded_tokens = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
auto expanded_idx_to_permuted_idx
|
||||
= torch::empty({num_tokens, top_k}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
auto permuted_idx_to_expanded_idx
|
||||
= torch::empty({max_num_padded_tokens}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
auto num_tokens_per_expert = torch::empty({num_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
auto tile_idx_to_expert_idx = torch::empty({max_num_ctas}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
auto tile_idx_to_mn_limit = torch::empty({max_num_ctas}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
auto num_non_exiting_tiles = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
|
||||
tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim);
|
||||
auto const& stream = at::cuda::getCurrentCUDAStream(
|
||||
routing_logits.has_value() ? routing_logits->get_device() : token_selected_experts->get_device());
|
||||
routing_runner.run(routing_logits_ptr, routing_bias_ptr, num_tokens, num_experts, top_k, n_group.value_or(0),
|
||||
topk_group.value_or(0), local_expert_offset, local_num_experts, routed_scaling_factor.value_or(1.0),
|
||||
expert_indexes.data_ptr<int>(), expert_count_histogram.data_ptr<int>(), total_num_padded_tokens.data_ptr<int>(),
|
||||
expanded_idx_to_permuted_idx.data_ptr<int>(), permuted_idx_to_expanded_idx.data_ptr<int>(),
|
||||
nullptr /*permuted_idx_to_token_idx.data_ptr<int>()*/, token_final_scales_ptr, token_selected_experts_ptr,
|
||||
num_tokens_per_expert.data_ptr<int>(), tile_idx_to_expert_idx.data_ptr<int>(),
|
||||
tile_idx_to_mn_limit.data_ptr<int>(), num_non_exiting_tiles.data_ptr<int>(),
|
||||
batchedGemm::trtllm::gen::Dtype::Void /* dtypeElt */, false /* use_routing_scales_on_input */,
|
||||
false /* use_deep_seek_fp8 */, routing_method_type, stream);
|
||||
|
||||
std::vector<torch::Tensor> results{tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx,
|
||||
permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles};
|
||||
if (new_token_final_scales.has_value())
|
||||
{
|
||||
results.push_back(new_token_final_scales.value());
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> moe_topk_sort(torch::Tensor const& routing_logits,
|
||||
torch::optional<torch::Tensor> const& routing_bias, int64_t const num_experts, int64_t const top_k,
|
||||
std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const local_expert_offset,
|
||||
int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
|
||||
int64_t const routing_method_type)
|
||||
{
|
||||
TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
|
||||
TORCH_CHECK(routing_logits.size(1) == num_experts, "routing_logits.size(1) must be num_experts.");
|
||||
if (routing_bias.has_value())
|
||||
{
|
||||
TORCH_CHECK(routing_bias->dim() == 1, "routing_bias must be 1D.");
|
||||
TORCH_CHECK(routing_bias->size(0) == num_experts, "routing_bias.size(0) must be num_experts.");
|
||||
}
|
||||
return moe_topk_sort_impl(routing_logits, routing_bias, std::nullopt, std::nullopt, num_experts, top_k, n_group,
|
||||
topk_group, local_expert_offset, local_num_experts, routed_scaling_factor, tile_tokens_dim,
|
||||
static_cast<RoutingMethodType>(routing_method_type));
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> moe_sort(torch::Tensor const& token_selected_experts,
|
||||
torch::Tensor const& token_final_scales, int64_t const num_experts, int64_t const top_k,
|
||||
int64_t const local_expert_offset, int64_t const local_num_experts, int64_t const tile_tokens_dim)
|
||||
{
|
||||
TORCH_CHECK(token_selected_experts.dim() == 2, "token_selected_experts must be 2D.");
|
||||
int64_t const num_tokens = token_selected_experts.size(0);
|
||||
TORCH_CHECK(token_selected_experts.size(1) == top_k, "token_selected_experts.size(1) must be top_k.");
|
||||
TORCH_CHECK(token_final_scales.dim() == 2, "token_final_scales must be 2D.");
|
||||
TORCH_CHECK(token_final_scales.size(0) == num_tokens, "token_final_scales.size(0) must be num_tokens.");
|
||||
TORCH_CHECK(token_final_scales.size(1) == top_k, "token_final_scales.size(1) must be top_k.");
|
||||
return moe_topk_sort_impl(std::nullopt, std::nullopt, token_selected_experts, token_final_scales, num_experts,
|
||||
top_k, std::nullopt, std::nullopt, local_expert_offset, local_num_experts, std::nullopt, tile_tokens_dim,
|
||||
RoutingMethodType::Renormalize);
|
||||
}
|
||||
|
||||
// Permute
|
||||
|
||||
std::tuple<torch::Tensor, torch::optional<torch::Tensor>> moe_permute(torch::Tensor const& input,
|
||||
torch::optional<torch::Tensor> const& input_sf, torch::Tensor const& tile_idx_to_mn_limit,
|
||||
torch::Tensor const& permuted_idx_to_expanded_idx, torch::Tensor const& num_non_exiting_tiles,
|
||||
int64_t const tile_tokens_dim, int64_t const top_k)
|
||||
{
|
||||
TORCH_CHECK(input.dim() == 2, "input must be 2D.");
|
||||
int64_t const num_tokens = input.size(0);
|
||||
int64_t const hidden_size = input.scalar_type() == torch::kFloat4_e2m1fn_x2 ? input.size(1) * 2 : input.size(1);
|
||||
|
||||
TORCH_CHECK(tile_idx_to_mn_limit.dim() == 1, "tile_idx_to_mn_limit must be 1D.");
|
||||
TORCH_CHECK(tile_idx_to_mn_limit.scalar_type() == torch::kInt32, "tile_idx_to_mn_limit must be int32.");
|
||||
int64_t const num_tiles = tile_idx_to_mn_limit.size(0);
|
||||
TORCH_CHECK(permuted_idx_to_expanded_idx.dim() == 1, "permuted_idx_to_expanded_idx must be 1D.");
|
||||
int64_t const max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0);
|
||||
TORCH_CHECK(max_num_permuted_tokens == tile_tokens_dim * num_tiles,
|
||||
"max_num_permuted_tokens must be equal to tile_tokens_dim * num_tiles.");
|
||||
TORCH_CHECK(max_num_permuted_tokens >= num_tokens * top_k,
|
||||
"max_num_permuted_tokens must be greater than or equal to num_tokens * top_k.");
|
||||
|
||||
TORCH_CHECK(num_non_exiting_tiles.numel() == 1, "num_non_exiting_tiles must have 1 element.");
|
||||
TORCH_CHECK(num_non_exiting_tiles.scalar_type() == torch::kInt32, "num_non_exiting_tiles must be int32.");
|
||||
|
||||
auto permuted_output = torch::empty(
|
||||
{max_num_permuted_tokens, input.size(1)}, torch::dtype(input.scalar_type()).device(torch::kCUDA));
|
||||
|
||||
void* input_sf_ptr = nullptr;
|
||||
void* permuted_sf_ptr = nullptr;
|
||||
torch::optional<torch::Tensor> permuted_sf;
|
||||
if (input.scalar_type() == torch::kFloat4_e2m1fn_x2)
|
||||
{
|
||||
TORCH_CHECK(input_sf.has_value(), "input_sf is required for NVFP4.");
|
||||
input_sf_ptr = input_sf->data_ptr();
|
||||
int64_t constexpr kSFVecSize = 16;
|
||||
permuted_sf = torch::empty({max_num_permuted_tokens * hidden_size / kSFVecSize},
|
||||
torch::dtype(input_sf->scalar_type()).device(torch::kCUDA));
|
||||
permuted_sf_ptr = permuted_sf->data_ptr();
|
||||
}
|
||||
|
||||
auto const& stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
|
||||
#define DISPATCH_MOE_PERMUTE(InputType, SFType) \
|
||||
tensorrt_llm::kernels::cute_dsl::moePermute<InputType, SFType>(static_cast<InputType*>(input.data_ptr()), \
|
||||
static_cast<InputType*>(permuted_output.data_ptr()), static_cast<SFType*>(input_sf_ptr), \
|
||||
static_cast<SFType*>(permuted_sf_ptr), tile_idx_to_mn_limit.data_ptr<int32_t>(), \
|
||||
permuted_idx_to_expanded_idx.data_ptr<int32_t>(), num_non_exiting_tiles.data_ptr<int32_t>(), \
|
||||
max_num_permuted_tokens, hidden_size, top_k, tile_tokens_dim, stream)
|
||||
|
||||
if (input.scalar_type() == torch::kHalf)
|
||||
{
|
||||
DISPATCH_MOE_PERMUTE(half, uint8_t);
|
||||
}
|
||||
else if (input.scalar_type() == torch::kBFloat16)
|
||||
{
|
||||
DISPATCH_MOE_PERMUTE(__nv_bfloat16, uint8_t);
|
||||
}
|
||||
else if (input.scalar_type() == torch::kFloat8_e4m3fn)
|
||||
{
|
||||
DISPATCH_MOE_PERMUTE(__nv_fp8_e4m3, uint8_t);
|
||||
}
|
||||
else if (input.scalar_type() == torch::kFloat4_e2m1fn_x2)
|
||||
{
|
||||
DISPATCH_MOE_PERMUTE(__nv_fp4_e2m1, uint8_t);
|
||||
}
|
||||
else
|
||||
{
|
||||
TORCH_CHECK(false, "Unsupported input dtype: ", input.scalar_type());
|
||||
}
|
||||
|
||||
#undef DISPATCH_MOE_PERMUTE
|
||||
|
||||
return {permuted_output, permuted_sf};
|
||||
}
|
||||
|
||||
// Unpermute
|
||||
|
||||
torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor const& expanded_idx_to_permuted_idx,
|
||||
torch::Tensor const& topk_scales)
|
||||
{
|
||||
TORCH_CHECK(permuted_input.dim() == 2, "permuted_input must be 2D.");
|
||||
int64_t const max_num_permuted_tokens = permuted_input.size(0);
|
||||
int64_t const hidden_size = permuted_input.size(1);
|
||||
TORCH_CHECK(expanded_idx_to_permuted_idx.dim() == 2, "expanded_idx_to_permuted_idx must be 2D.");
|
||||
int64_t const num_tokens = expanded_idx_to_permuted_idx.size(0);
|
||||
int64_t const top_k = expanded_idx_to_permuted_idx.size(1);
|
||||
TORCH_CHECK(topk_scales.dim() == 2, "topk_scales must be 2D.");
|
||||
TORCH_CHECK(topk_scales.size(0) == num_tokens, "topk_scales.size(0) must be num_tokens.");
|
||||
TORCH_CHECK(topk_scales.size(1) == top_k, "topk_scales.size(1) must be top_k.");
|
||||
|
||||
TORCH_CHECK(max_num_permuted_tokens >= num_tokens * top_k,
|
||||
"max_num_permuted_tokens must be greater than or equal to num_tokens * top_k.");
|
||||
|
||||
auto output
|
||||
= torch::empty({num_tokens, hidden_size}, torch::dtype(permuted_input.scalar_type()).device(torch::kCUDA));
|
||||
auto const& stream = at::cuda::getCurrentCUDAStream(permuted_input.get_device());
|
||||
|
||||
#define DISPATCH_MOE_UNPERMUTE(InputType, TopKScaleType) \
|
||||
tensorrt_llm::kernels::cute_dsl::moeUnpermute<InputType>(static_cast<InputType*>(permuted_input.data_ptr()), \
|
||||
static_cast<InputType*>(output.data_ptr()), expanded_idx_to_permuted_idx.data_ptr<int32_t>(), \
|
||||
static_cast<TopKScaleType*>(topk_scales.data_ptr()), num_tokens, hidden_size, top_k, stream)
|
||||
|
||||
if (permuted_input.scalar_type() == torch::kHalf && topk_scales.scalar_type() == torch::kFloat)
|
||||
{
|
||||
DISPATCH_MOE_UNPERMUTE(half, float);
|
||||
}
|
||||
else if (permuted_input.scalar_type() == torch::kHalf && topk_scales.scalar_type() == torch::kHalf)
|
||||
{
|
||||
DISPATCH_MOE_UNPERMUTE(half, half);
|
||||
}
|
||||
else if (permuted_input.scalar_type() == torch::kBFloat16 && topk_scales.scalar_type() == torch::kFloat)
|
||||
{
|
||||
DISPATCH_MOE_UNPERMUTE(__nv_bfloat16, float);
|
||||
}
|
||||
else if (permuted_input.scalar_type() == torch::kBFloat16 && topk_scales.scalar_type() == torch::kBFloat16)
|
||||
{
|
||||
DISPATCH_MOE_UNPERMUTE(__nv_bfloat16, __nv_bfloat16);
|
||||
}
|
||||
else
|
||||
{
|
||||
TORCH_CHECK(false, "Unsupported input dtype: ", permuted_input.scalar_type(),
|
||||
" and/or topk_scales dtype: ", topk_scales.scalar_type());
|
||||
}
|
||||
|
||||
#undef DISPATCH_MOE_UNPERMUTE
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
// Activation
|
||||
|
||||
torch::Tensor moe_swiglu(torch::Tensor const& input, torch::Tensor const& tile_idx_to_mn_limit,
|
||||
torch::Tensor const& num_non_exiting_tiles, int64_t const tile_tokens_dim)
|
||||
{
|
||||
TORCH_CHECK(input.dim() == 2, "input must be 2D.");
|
||||
TORCH_CHECK(input.size(1) % 2 == 0, "input.size(1) must be even.");
|
||||
int64_t const max_num_permuted_tokens = input.size(0);
|
||||
int64_t const interm_size = input.size(1) / 2;
|
||||
|
||||
TORCH_CHECK(tile_idx_to_mn_limit.dim() == 1, "tile_idx_to_mn_limit must be 1D.");
|
||||
TORCH_CHECK(tile_idx_to_mn_limit.scalar_type() == torch::kInt32, "tile_idx_to_mn_limit must be int32.");
|
||||
int64_t const num_tiles = tile_idx_to_mn_limit.size(0);
|
||||
TORCH_CHECK(max_num_permuted_tokens == tile_tokens_dim * num_tiles,
|
||||
"max_num_permuted_tokens must be equal to tile_tokens_dim * num_tiles.");
|
||||
|
||||
TORCH_CHECK(num_non_exiting_tiles.numel() == 1, "num_non_exiting_tiles must have 1 element.");
|
||||
TORCH_CHECK(num_non_exiting_tiles.scalar_type() == torch::kInt32, "num_non_exiting_tiles must be int32.");
|
||||
|
||||
auto output
|
||||
= torch::empty({max_num_permuted_tokens, interm_size}, torch::dtype(input.scalar_type()).device(torch::kCUDA));
|
||||
tensorrt_llm::kernels::cutlass_kernels::ActivationParams activation_params{
|
||||
tensorrt_llm::kernels::cutlass_kernels::ActivationType::Swiglu};
|
||||
|
||||
auto const& stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
|
||||
#define DISPATCH_MOE_ACTIVATION(InputType, OutputType, SFType) \
|
||||
tensorrt_llm::kernels::cute_dsl::moeActivation<InputType, OutputType, SFType>( \
|
||||
static_cast<InputType*>(input.data_ptr()), static_cast<OutputType*>(output.data_ptr()), nullptr, nullptr, \
|
||||
tile_idx_to_mn_limit.data_ptr<int32_t>(), num_non_exiting_tiles.data_ptr<int32_t>(), activation_params, \
|
||||
max_num_permuted_tokens, interm_size, tile_tokens_dim, stream)
|
||||
|
||||
if (input.scalar_type() == torch::kHalf)
|
||||
{
|
||||
DISPATCH_MOE_ACTIVATION(half, half, uint8_t);
|
||||
}
|
||||
else if (input.scalar_type() == torch::kBFloat16)
|
||||
{
|
||||
DISPATCH_MOE_ACTIVATION(__nv_bfloat16, __nv_bfloat16, uint8_t);
|
||||
}
|
||||
else
|
||||
{
|
||||
TORCH_CHECK(false, "Unsupported input dtype: ", input.scalar_type());
|
||||
}
|
||||
|
||||
#undef DISPATCH_MOE_ACTIVATION
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> moe_swiglu_nvfp4_quantize(torch::Tensor const& input,
|
||||
torch::Tensor const& global_sf, torch::Tensor const& tile_idx_to_mn_limit,
|
||||
torch::Tensor const& num_non_exiting_tiles, int64_t const tile_tokens_dim)
|
||||
{
|
||||
TORCH_CHECK(input.dim() == 2, "input must be 2D.");
|
||||
TORCH_CHECK(input.size(1) % 2 == 0, "input.size(1) must be even.");
|
||||
int64_t const max_num_permuted_tokens = input.size(0);
|
||||
int64_t const interm_size = input.size(1) / 2;
|
||||
|
||||
TORCH_CHECK(tile_idx_to_mn_limit.dim() == 1, "tile_idx_to_mn_limit must be 1D.");
|
||||
TORCH_CHECK(tile_idx_to_mn_limit.scalar_type() == torch::kInt32, "tile_idx_to_mn_limit must be int32.");
|
||||
int64_t const num_tiles = tile_idx_to_mn_limit.size(0);
|
||||
TORCH_CHECK(max_num_permuted_tokens == tile_tokens_dim * num_tiles,
|
||||
"max_num_permuted_tokens must be equal to tile_tokens_dim * num_tiles.");
|
||||
|
||||
TORCH_CHECK(global_sf.numel() == 1, "global_sf must have 1 element.");
|
||||
TORCH_CHECK(global_sf.scalar_type() == torch::kFloat32, "global_sf must be float32.");
|
||||
TORCH_CHECK(num_non_exiting_tiles.numel() == 1, "num_non_exiting_tiles must have 1 element.");
|
||||
TORCH_CHECK(num_non_exiting_tiles.scalar_type() == torch::kInt32, "num_non_exiting_tiles must be int32.");
|
||||
|
||||
auto output = torch::empty(
|
||||
{max_num_permuted_tokens, interm_size / 2}, torch::dtype(torch::kFloat4_e2m1fn_x2).device(torch::kCUDA));
|
||||
int64_t constexpr kSFVecSize = 16;
|
||||
auto output_sf = torch::empty(
|
||||
{max_num_permuted_tokens * interm_size / kSFVecSize}, torch::dtype(torch::kUInt8).device(torch::kCUDA));
|
||||
|
||||
tensorrt_llm::kernels::cutlass_kernels::ActivationParams activation_params{
|
||||
tensorrt_llm::kernels::cutlass_kernels::ActivationType::Swiglu};
|
||||
|
||||
auto const& stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
|
||||
#define DISPATCH_MOE_ACTIVATION(InputType, OutputType, SFType) \
|
||||
tensorrt_llm::kernels::cute_dsl::moeActivation<InputType, OutputType, SFType>( \
|
||||
static_cast<InputType*>(input.data_ptr()), static_cast<OutputType*>(output.data_ptr()), \
|
||||
global_sf.data_ptr<float>(), static_cast<SFType*>(output_sf.data_ptr()), \
|
||||
tile_idx_to_mn_limit.data_ptr<int32_t>(), num_non_exiting_tiles.data_ptr<int32_t>(), activation_params, \
|
||||
max_num_permuted_tokens, interm_size, tile_tokens_dim, stream)
|
||||
|
||||
if (input.scalar_type() == torch::kHalf)
|
||||
{
|
||||
DISPATCH_MOE_ACTIVATION(half, __nv_fp4_e2m1, uint8_t);
|
||||
}
|
||||
else if (input.scalar_type() == torch::kBFloat16)
|
||||
{
|
||||
DISPATCH_MOE_ACTIVATION(__nv_bfloat16, __nv_fp4_e2m1, uint8_t);
|
||||
}
|
||||
else
|
||||
{
|
||||
TORCH_CHECK(false, "Unsupported input dtype: ", input.scalar_type());
|
||||
}
|
||||
|
||||
#undef DISPATCH_MOE_ACTIVATION
|
||||
|
||||
return {output, output_sf};
|
||||
}
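For reference, a minimal shape sketch of what this binding returns (the sizes below are made up for illustration; only the packing rules come from the code above):

# Python-side view of the outputs, assuming illustrative sizes.
max_num_permuted_tokens = 256   # must equal tile_tokens_dim * num_tiles
interm_size = 2048              # the input carries 2 * interm_size columns before SwiGLU
sf_vec_size = 16                # one uint8 scale per 16 intermediate elements
output_shape = (max_num_permuted_tokens, interm_size // 2)                # torch.float4_e2m1fn_x2, 2 values per byte
output_sf_numel = max_num_permuted_tokens * interm_size // sf_vec_size    # flat torch.uint8 tensor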
|
||||
|
||||
torch::Tensor moe_gelu(torch::Tensor const& input, torch::Tensor const& tile_idx_to_mn_limit,
|
||||
torch::Tensor const& num_non_exiting_tiles, int64_t const tile_tokens_dim)
|
||||
{
|
||||
TORCH_CHECK(input.dim() == 2, "input must be 2D.");
|
||||
int64_t const max_num_permuted_tokens = input.size(0);
|
||||
int64_t const interm_size = input.size(1);
|
||||
|
||||
TORCH_CHECK(tile_idx_to_mn_limit.dim() == 1, "tile_idx_to_mn_limit must be 1D.");
|
||||
TORCH_CHECK(tile_idx_to_mn_limit.scalar_type() == torch::kInt32, "tile_idx_to_mn_limit must be int32.");
|
||||
int64_t const num_tiles = tile_idx_to_mn_limit.size(0);
|
||||
TORCH_CHECK(max_num_permuted_tokens == tile_tokens_dim * num_tiles,
|
||||
"max_num_permuted_tokens must be equal to tile_tokens_dim * num_tiles.");
|
||||
|
||||
TORCH_CHECK(num_non_exiting_tiles.numel() == 1, "num_non_exiting_tiles must have 1 element.");
|
||||
TORCH_CHECK(num_non_exiting_tiles.scalar_type() == torch::kInt32, "num_non_exiting_tiles must be int32.");
|
||||
|
||||
auto output
|
||||
= torch::empty({max_num_permuted_tokens, interm_size}, torch::dtype(input.scalar_type()).device(torch::kCUDA));
|
||||
tensorrt_llm::kernels::cutlass_kernels::ActivationParams activation_params{
|
||||
tensorrt_llm::kernels::cutlass_kernels::ActivationType::Gelu};
|
||||
|
||||
auto const& stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
|
||||
#define DISPATCH_MOE_ACTIVATION(InputType, OutputType, SFType) \
|
||||
tensorrt_llm::kernels::cute_dsl::moeActivation<InputType, OutputType, SFType>( \
|
||||
static_cast<InputType*>(input.data_ptr()), static_cast<OutputType*>(output.data_ptr()), nullptr, nullptr, \
|
||||
tile_idx_to_mn_limit.data_ptr<int32_t>(), num_non_exiting_tiles.data_ptr<int32_t>(), activation_params, \
|
||||
max_num_permuted_tokens, interm_size, tile_tokens_dim, stream)
|
||||
|
||||
if (input.scalar_type() == torch::kHalf)
|
||||
{
|
||||
DISPATCH_MOE_ACTIVATION(half, half, uint8_t);
|
||||
}
|
||||
else if (input.scalar_type() == torch::kBFloat16)
|
||||
{
|
||||
DISPATCH_MOE_ACTIVATION(__nv_bfloat16, __nv_bfloat16, uint8_t);
|
||||
}
|
||||
else
|
||||
{
|
||||
TORCH_CHECK(false, "Unsupported input dtype: ", input.scalar_type());
|
||||
}
|
||||
|
||||
#undef DISPATCH_MOE_ACTIVATION
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "moe_topk_sort(Tensor routing_logits, Tensor? routing_bias, int num_experts, int top_k, int? n_group, "
        "int? topk_group, int local_expert_offset, int local_num_experts, float? routed_scaling_factor, int "
        "tile_tokens_dim, int routing_method_type) -> Tensor[]");
    m.def(
        "moe_sort(Tensor token_selected_experts, Tensor token_final_scales, int num_experts, int top_k, "
        "int local_expert_offset, int local_num_experts, int tile_tokens_dim) -> Tensor[]");
    m.def(
        "moe_permute(Tensor input, Tensor? input_sf, Tensor tile_idx_to_mn_limit, Tensor permuted_idx_to_expanded_idx, "
        "Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> (Tensor, Tensor?)");
    m.def("moe_unpermute(Tensor permuted_input, Tensor expanded_idx_to_permuted_idx, Tensor topk_scales) -> Tensor");
    m.def(
        "moe_swiglu(Tensor input, Tensor tile_idx_to_mn_limit, Tensor num_non_exiting_tiles, "
        "int tile_tokens_dim) -> Tensor");
    m.def(
        "moe_swiglu_nvfp4_quantize(Tensor input, Tensor global_sf, Tensor tile_idx_to_mn_limit, Tensor "
        "num_non_exiting_tiles, int tile_tokens_dim) -> (Tensor, Tensor)");
    m.def(
        "moe_gelu(Tensor input, Tensor tile_idx_to_mn_limit, Tensor num_non_exiting_tiles, "
        "int tile_tokens_dim) -> Tensor");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("moe_topk_sort", &torch_ext::moe_topk_sort);
    m.impl("moe_sort", &torch_ext::moe_sort);
    m.impl("moe_permute", &torch_ext::moe_permute);
    m.impl("moe_unpermute", &torch_ext::moe_unpermute);
    m.impl("moe_swiglu", &torch_ext::moe_swiglu);
    m.impl("moe_swiglu_nvfp4_quantize", &torch_ext::moe_swiglu_nvfp4_quantize);
    m.impl("moe_gelu", &torch_ext::moe_gelu);
}
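On the Python side these schemas surface as torch.ops.trtllm.*; a minimal shape-propagation sketch against the fake (meta) kernels registered later in this change might look as follows (sizes are illustrative and assume the extension module has been imported):

import torch

x = torch.empty(256, 4096, dtype=torch.bfloat16, device="meta")
tile_idx_to_mn_limit = torch.empty(2, dtype=torch.int32, device="meta")
num_non_exiting_tiles = torch.empty(1, dtype=torch.int32, device="meta")
# 256 == tile_tokens_dim * num_tiles for tile_tokens_dim=128 and 2 tiles
y = torch.ops.trtllm.moe_swiglu(x, tile_idx_to_mn_limit, num_non_exiting_tiles, 128)
assert y.shape == (256, 2048)  # the fake kernel halves the last dimension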
|
||||
@ -236,8 +236,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
{
|
||||
m.def(
|
||||
"fp4_quantize(Tensor input, Tensor? globalScale, int sfVecSize, bool sfUseUE8M0=False, bool "
|
||||
"isSfSwizzledLayout=True) "
|
||||
"-> (Tensor, Tensor)");
|
||||
"isSfSwizzledLayout=True) -> (Tensor, Tensor)");
|
||||
m.def("calculate_nvfp4_global_scale(Tensor input, Tensor? tokensPerBatch) -> Tensor");
|
||||
}
|
||||
|
||||
|
||||
@ -341,10 +341,8 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
"moe_finalize_scale_op(Tensor gemm2_output, Tensor? biases, Tensor unpermuted_final_scales, Tensor "
|
||||
"unpermuted_row_to_permuted_row, Tensor permuted_row_to_unpermuted_row, Tensor token_selected_experts, Tensor "
|
||||
"expert_first_token_offset_tensor, bool enable_alltoall, SymInt num_rows, SymInt hidden_size, SymInt "
|
||||
"unpadded_hidden_size, int "
|
||||
"experts_per_token, int "
|
||||
"num_experts_per_node, int tp_size, int tp_rank, int ep_size, int ep_rank)"
|
||||
"-> (Tensor)");
|
||||
"unpadded_hidden_size, int experts_per_token, int num_experts_per_node, int tp_size, int tp_rank, int ep_size, "
|
||||
"int ep_rank) -> (Tensor)");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
|
||||
@ -68,7 +68,7 @@ triton==3.5.0; platform_machine == "x86_64"
tiktoken
blobfile
openai-harmony==0.0.4
nvidia-cutlass-dsl==4.2.1; python_version >= "3.10"
nvidia-cutlass-dsl==4.3.0.dev0; python_version >= "3.10"
plotly
numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
partial_json_parser
|
||||
|
||||
@ -932,9 +932,9 @@ class AutoTuner:
|
||||
dynamic_dims = []
|
||||
|
||||
for spec in tuning_config.dynamic_tensor_specs:
|
||||
assert inspect.isfunction(spec.gen_tuning_buckets) or isinstance(spec.gen_tuning_buckets, (list, tuple)), \
|
||||
assert callable(spec.gen_tuning_buckets) or isinstance(spec.gen_tuning_buckets, (list, tuple)), \
|
||||
"The given dynamic dimension must provide a opt value generation function or a list of opt values"
|
||||
if inspect.isfunction(spec.gen_tuning_buckets):
|
||||
if callable(spec.gen_tuning_buckets):
|
||||
if tuning_config.tune_max_num_tokens is None:
|
||||
# Use the current input size as the opt value
|
||||
opt_shapes = spec.gen_tuning_buckets(
|
||||
@ -1067,7 +1067,11 @@ class AutoTuner:
|
||||
# One solution is to manipulate the tensor content to make it more like the real data
# during the tuning process. This can be controlled in the preparation phase by the runner.
|
||||
# It must not use all zero tensors. Otherwise the timing results become unreliable.
|
||||
return torch.randint(-5, 5, shapes, device=device).to(dtype)
|
||||
if dtype == torch.float4_e2m1fn_x2:
|
||||
return torch.randint(-5, 5, shapes,
|
||||
device=device).to(torch.uint8).view(dtype)
|
||||
else:
|
||||
return torch.randint(-5, 5, shapes, device=device).to(dtype)
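The branch above exists because the tuner cannot fill the packed fp4 dtype directly: the random values are materialized as uint8 bytes and then reinterpreted. A hedged standalone illustration of the same trick (outside the tuner, assuming a build that exposes torch.float4_e2m1fn_x2):

import torch

shapes = (128, 64)  # 64 packed bytes per row == 128 fp4 values
raw = torch.randint(-5, 5, shapes, device="cuda").to(torch.uint8)
packed = raw.view(torch.float4_e2m1fn_x2)  # pure reinterpretation of the bytes, no numeric conversion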
|
||||
|
||||
def _prepare_input_tensors(
|
||||
self, profile: OptimizationProfile,
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
|
||||
|
||||
from ..._utils import get_sm_version
|
||||
from .cute_dsl_custom_ops import GroupedGemmInputsHelper
|
||||
|
||||
|
||||
def _register_fake():
|
||||
@ -485,6 +486,183 @@ def _register_fake():
|
||||
return gemm2_output.new_empty((num_rows_val, unpadded_hidden_size_val),
|
||||
dtype=gemm2_output.dtype)
|
||||
|
||||
@torch.library.register_fake("trtllm::moe_topk_sort")
|
||||
def _(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: Optional[torch.Tensor],
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
n_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
routed_scaling_factor: Optional[float],
|
||||
tile_tokens_dim: int,
|
||||
routing_method_type: int,
|
||||
) -> List[torch.Tensor]:
|
||||
helper = GroupedGemmInputsHelper(
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
num_local_experts=local_num_experts,
|
||||
local_expert_offset=local_expert_offset,
|
||||
tile_size=tile_tokens_dim,
|
||||
)
|
||||
num_tokens = routing_logits.size(0)
|
||||
device = routing_logits.device
|
||||
routing_bias_dtype = torch.bfloat16 if routing_bias is None else routing_bias.dtype
|
||||
max_num_tiles = helper.get_max_num_tiles(num_tokens)
|
||||
max_num_permuted_tokens = helper.get_max_num_permuted_tokens(num_tokens)
|
||||
tile_idx_to_expert_idx = torch.empty((max_num_tiles, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
tile_idx_to_mn_limit = torch.empty((max_num_tiles, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
expanded_idx_to_permuted_idx = torch.empty((num_tokens, top_k),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
permuted_idx_to_expanded_idx = torch.empty((max_num_permuted_tokens, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
total_num_padded_tokens = torch.empty((1, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
num_non_exiting_tiles = torch.empty((1, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
new_token_final_scales = torch.empty((num_tokens, top_k),
|
||||
dtype=routing_bias_dtype,
|
||||
device=device)
|
||||
return [
|
||||
tile_idx_to_expert_idx, tile_idx_to_mn_limit,
|
||||
expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx,
|
||||
total_num_padded_tokens, num_non_exiting_tiles,
|
||||
new_token_final_scales
|
||||
]
|
||||
|
||||
@torch.library.register_fake("trtllm::moe_sort")
|
||||
def _(
|
||||
token_selected_experts: torch.Tensor,
|
||||
token_final_scales: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
tile_tokens_dim: int,
|
||||
) -> List[torch.Tensor]:
|
||||
helper = GroupedGemmInputsHelper(
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
num_local_experts=local_num_experts,
|
||||
local_expert_offset=local_expert_offset,
|
||||
tile_size=tile_tokens_dim,
|
||||
)
|
||||
num_tokens = token_selected_experts.size(0)
|
||||
device = token_selected_experts.device
|
||||
max_num_tiles = helper.get_max_num_tiles(num_tokens)
|
||||
max_num_permuted_tokens = helper.get_max_num_permuted_tokens(num_tokens)
|
||||
tile_idx_to_expert_idx = torch.empty((max_num_tiles, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
tile_idx_to_mn_limit = torch.empty((max_num_tiles, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
expanded_idx_to_permuted_idx = torch.empty((num_tokens, top_k),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
permuted_idx_to_expanded_idx = torch.empty((max_num_permuted_tokens, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
total_num_padded_tokens = torch.empty((1, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
num_non_exiting_tiles = torch.empty((1, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
return [
|
||||
tile_idx_to_expert_idx, tile_idx_to_mn_limit,
|
||||
expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx,
|
||||
total_num_padded_tokens, num_non_exiting_tiles
|
||||
]
|
||||
|
||||
@torch.library.register_fake("trtllm::moe_permute")
|
||||
def _(
|
||||
input: torch.Tensor,
|
||||
input_sf: Optional[torch.Tensor],
|
||||
tile_idx_to_mn_limit: torch.Tensor,
|
||||
permuted_idx_to_expanded_idx: torch.Tensor,
|
||||
num_non_exiting_tiles: torch.Tensor,
|
||||
tile_tokens_dim: int,
|
||||
top_k: int,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0)
|
||||
permuted_output = torch.empty((max_num_permuted_tokens, input.size(1)),
|
||||
dtype=input.dtype,
|
||||
device=input.device)
|
||||
if input.dtype == torch.float4_e2m1fn_x2:
|
||||
hidden_size = input.size(1) * 2
|
||||
sf_vec_size = 16
|
||||
permuted_sf = torch.empty(
|
||||
(max_num_permuted_tokens * hidden_size // sf_vec_size, ),
|
||||
dtype=input_sf.dtype,
|
||||
device=input.device)
|
||||
else:
|
||||
permuted_sf = None
|
||||
return permuted_output, permuted_sf
|
||||
|
||||
@torch.library.register_fake("trtllm::moe_unpermute")
|
||||
def _(
|
||||
permuted_input: torch.Tensor,
|
||||
expanded_idx_to_permuted_idx: torch.Tensor,
|
||||
topk_scales: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = expanded_idx_to_permuted_idx.size(0)
|
||||
output = torch.empty((num_tokens, permuted_input.size(1)),
|
||||
dtype=permuted_input.dtype,
|
||||
device=permuted_input.device)
|
||||
return output
|
||||
|
||||
@torch.library.register_fake("trtllm::moe_swiglu")
|
||||
def _(
|
||||
input: torch.Tensor,
|
||||
tile_idx_to_mn_limit: torch.Tensor,
|
||||
num_non_exiting_tiles: torch.Tensor,
|
||||
tile_tokens_dim: int,
|
||||
) -> torch.Tensor:
|
||||
output = torch.empty((input.size(0), input.size(1) // 2),
|
||||
dtype=input.dtype,
|
||||
device=input.device)
|
||||
return output
|
||||
|
||||
@torch.library.register_fake("trtllm::moe_swiglu_nvfp4_quantize")
|
||||
def _(
|
||||
input: torch.Tensor,
|
||||
global_sf: float,
|
||||
tile_idx_to_mn_limit: torch.Tensor,
|
||||
num_non_exiting_tiles: torch.Tensor,
|
||||
tile_tokens_dim: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
max_num_permuted_tokens = input.size(0)
|
||||
interm_size = input.size(1) // 2
|
||||
sf_vec_size = 16
|
||||
output = torch.empty((max_num_permuted_tokens, interm_size // 2),
|
||||
dtype=torch.float4_e2m1fn_x2,
|
||||
device=input.device)
|
||||
output_sf = torch.empty(
|
||||
(max_num_permuted_tokens * interm_size // sf_vec_size, ),
|
||||
dtype=torch.uint8,
|
||||
device=input.device)
|
||||
return output, output_sf
|
||||
|
||||
@torch.library.register_fake("trtllm::moe_gelu")
|
||||
def _(
|
||||
input: torch.Tensor,
|
||||
tile_idx_to_mn_limit: torch.Tensor,
|
||||
num_non_exiting_tiles: torch.Tensor,
|
||||
tile_tokens_dim: int,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(input)
|
||||
|
||||
@torch.library.register_fake("trtllm::allgather_list")
|
||||
def allgather_list(input_list, sizes, group):
|
||||
assert len(input_list) > 0
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
from typing import List, Tuple
|
||||
import itertools
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton # type: ignore[import]
|
||||
|
||||
from tensorrt_llm._utils import get_sm_version
|
||||
from tensorrt_llm.math_utils import pad_up
|
||||
|
||||
from ..._utils import get_sm_version
|
||||
from ...math_utils import pad_up
|
||||
from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
|
||||
OptimizationProfile, TunableRunner, TuningConfig)
|
||||
from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
|
||||
@ -13,20 +12,22 @@ from ..utils import (fp4_scale_infer_shape,
|
||||
get_last_power_of_2_num_tokens_buckets,
|
||||
last_positive_power_of_2)
|
||||
|
||||
try:
|
||||
from cuda.bindings import driver as cuda
|
||||
except ImportError:
|
||||
from cuda import cuda
|
||||
|
||||
if IS_CUTLASS_DSL_AVAILABLE:
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
|
||||
from tensorrt_llm._torch.cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import (
|
||||
from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import (
|
||||
Sm100BlockScaledPersistentDenseGemmKernel,
|
||||
Sm100BlockScaledPersistentDenseGemmKernelWrapper)
|
||||
from tensorrt_llm._torch.cute_dsl_kernels.blackwell.utils import make_ptr
|
||||
|
||||
try:
|
||||
from cuda.bindings import driver as cuda
|
||||
except ImportError:
|
||||
from cuda import cuda
|
||||
from ..cute_dsl_kernels.blackwell.grouped_blockscaled_gemm_persistent import \
|
||||
Sm100BlockScaledPersistentGroupedGemmKernel
|
||||
from ..cute_dsl_kernels.blackwell.utils import make_ptr
|
||||
|
||||
class CuteDSLNVFP4BlackwellLinear(TunableRunner):
|
||||
kernel_dict = dict()
|
||||
@ -46,7 +47,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
|
||||
if get_sm_version() != 100:
|
||||
raise ValueError(
|
||||
f"SM version {get_sm_version()} is not supported for CuteDSLNVFP4BlackwellLinear, it only supports SM 100"
|
||||
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
|
||||
)
|
||||
|
||||
# rewrite the hash function because the value of self.alpha doesn't affect the tactic.
|
||||
@ -209,7 +210,6 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
torch_stream = torch.cuda.current_stream()
|
||||
stream = cuda.CUstream(torch_stream.cuda_stream)
|
||||
|
||||
gemm_wrapper_func = Sm100BlockScaledPersistentDenseGemmKernelWrapper
|
||||
CACHE_KEY = (
|
||||
sf_vec_size,
|
||||
mma_tiler_mn,
|
||||
@ -236,7 +236,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
kernel_sf_n = sf_n
|
||||
|
||||
if CACHE_KEY not in CuteDSLNVFP4BlackwellLinear.kernel_dict:
|
||||
gemm = gemm_wrapper_func(
|
||||
gemm = Sm100BlockScaledPersistentDenseGemmKernelWrapper(
|
||||
sf_vec_size,
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
@ -337,3 +337,563 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
# output is fixed as bf16
|
||||
ret = mat_a.new_empty(shape, dtype=torch.bfloat16)
|
||||
return ret
|
||||
|
||||
class GroupedGemmInputsHelper:
|
||||
|
||||
def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
|
||||
local_expert_offset: int, tile_size: int):
|
||||
self.num_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.num_local_experts = num_local_experts
|
||||
self.local_expert_offset = local_expert_offset
|
||||
self.tile_size = tile_size
|
||||
|
||||
def get_max_num_tiles(self, num_tokens: int) -> int:
|
||||
num_expanded_tokens = num_tokens * self.top_k
|
||||
if num_expanded_tokens <= self.num_local_experts:
|
||||
return num_expanded_tokens
|
||||
return (
|
||||
num_expanded_tokens +
|
||||
(self.tile_size - 1) * self.num_local_experts) // self.tile_size
|
||||
|
||||
def get_max_num_permuted_tokens(self, num_tokens: int) -> int:
|
||||
return self.get_max_num_tiles(num_tokens) * self.tile_size
|
||||
|
||||
def infer_num_tokens(self, max_num_permuted_tokens: int) -> int:
|
||||
max_num_tiles = max_num_permuted_tokens // self.tile_size
|
||||
if max_num_tiles >= self.num_local_experts:
|
||||
return (max_num_permuted_tokens - (self.tile_size - 1) *
|
||||
(self.num_local_experts - 1)) // self.top_k
|
||||
return max_num_tiles // self.top_k
|
||||
|
||||
def gen_tuning_buckets(self, max_num_tokens: int) -> List[int]:
|
||||
buckets = get_last_power_of_2_num_tokens_buckets(
|
||||
self.infer_num_tokens(max_num_tokens))
|
||||
return sorted(
|
||||
list(set(self.get_max_num_permuted_tokens(x) for x in buckets)))
|
||||
|
||||
def map_to_tuning_buckets(self, x: int) -> int:
|
||||
return self.get_max_num_permuted_tokens(
|
||||
last_positive_power_of_2(self.infer_num_tokens(x)))
|
||||
|
||||
def infer_tile_idx_to_group_idx_shape(
|
||||
self, input_shapes: List[torch.Size]) -> int:
|
||||
return input_shapes[0][0] // self.tile_size
|
||||
|
||||
def inputs_pre_hook(self,
|
||||
inputs: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles = inputs
|
||||
num_tokens = self.infer_num_tokens(a.size(0))
|
||||
average_num_tokens_per_expert = num_tokens * self.top_k / self.num_experts
|
||||
balance = 0
|
||||
tile_idx_to_group_idx_list = []
|
||||
for i in range(self.num_local_experts):
|
||||
balance += average_num_tokens_per_expert
|
||||
if balance <= 1e-3:
|
||||
continue
|
||||
curr_num_tokens = int(balance) + 1
|
||||
curr_num_tiles = (curr_num_tokens + self.tile_size -
|
||||
1) // self.tile_size
|
||||
tile_idx_to_group_idx_list.extend([i] * curr_num_tiles)
|
||||
balance -= curr_num_tokens
|
||||
|
||||
num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list)
|
||||
assert 0 < num_non_exiting_tiles_val <= tile_idx_to_group_idx.size(
|
||||
0)
|
||||
|
||||
tile_idx_to_group_idx_list.extend(
|
||||
[int(-1e9)] *
|
||||
(tile_idx_to_group_idx.size(0) - num_non_exiting_tiles_val))
|
||||
tile_idx_to_group_idx = torch.tensor(
|
||||
tile_idx_to_group_idx_list,
|
||||
dtype=tile_idx_to_group_idx.dtype,
|
||||
device=tile_idx_to_group_idx.device)
|
||||
num_non_exiting_tiles = torch.tensor(
|
||||
[num_non_exiting_tiles_val],
|
||||
dtype=num_non_exiting_tiles.dtype,
|
||||
device=num_non_exiting_tiles.device)
|
||||
return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles
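A worked example of the tile bookkeeping implemented by this helper (the configuration is hypothetical; only the arithmetic comes from the methods above):

# Hypothetical EP rank: 256 experts overall, top-8 routing, 32 local experts, 128-token tiles.
helper = GroupedGemmInputsHelper(num_experts=256, top_k=8,
                                 num_local_experts=32,
                                 local_expert_offset=0, tile_size=128)
# 1000 tokens expand to 8000 routed copies, and each local expert can waste at most
# tile_size - 1 padding slots, so the worst case is:
assert helper.get_max_num_tiles(1000) == (8000 + 127 * 32) // 128    # 94 tiles
assert helper.get_max_num_permuted_tokens(1000) == 94 * 128          # 12032 padded rows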
|
||||
|
||||
class Sm100BlockScaledPersistentGroupedGemmRunner(TunableRunner):
|
||||
kernel_cache = dict()
|
||||
tuning_config_cache = dict()
|
||||
|
||||
def __init__(self,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_local_experts: int,
|
||||
local_expert_offset: int,
|
||||
tile_size: int,
|
||||
output_dtype: torch.dtype,
|
||||
scaling_vector_size: int = 16):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.num_local_experts = num_local_experts
|
||||
self.local_expert_offset = local_expert_offset
|
||||
self.tile_size = tile_size
|
||||
|
||||
assert output_dtype == torch.bfloat16
|
||||
self.output_dtype = output_dtype
|
||||
self.scaling_vector_size = scaling_vector_size
|
||||
|
||||
if get_sm_version() != 100:
|
||||
raise ValueError(
|
||||
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
|
||||
)
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
**kwargs,
|
||||
) -> List[Tuple[int, int]]:
|
||||
a, b, *_ = inputs
|
||||
m, k = a.size(0), a.size(1) * 2
|
||||
l, n = b.size(0), b.size(1)
|
||||
|
||||
# TODO: Add full shmoo
|
||||
mma_tiler_mn_candidates = [(128, 128), (128, 256)]
|
||||
cluster_shape_mn_candidates = [(1, 1), (1, 2)]
|
||||
|
||||
valid_tactics = []
|
||||
for mma_tiler_mn, cluster_shape_mn in itertools.product(
|
||||
mma_tiler_mn_candidates, cluster_shape_mn_candidates):
|
||||
if Sm100BlockScaledPersistentGroupedGemmKernel.can_implement(
|
||||
ab_dtype=cutlass.Float4E2M1FN,
|
||||
sf_dtype=cutlass.Float8E4M3FN,
|
||||
sf_vec_size=self.scaling_vector_size,
|
||||
acc_dtype=cutlass.Float32,
|
||||
c_dtype=cutlass.BFloat16,
|
||||
use_2cta_instrs=False,
|
||||
mma_tiler_mn=mma_tiler_mn,
|
||||
cluster_shape_mn=cluster_shape_mn,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
l=l,
|
||||
a_major="k",
|
||||
b_major="k",
|
||||
c_major="n",
|
||||
m_aligned=self.tile_size,
|
||||
):
|
||||
valid_tactics.append((mma_tiler_mn, cluster_shape_mn))
|
||||
|
||||
assert len(valid_tactics) > 0
|
||||
return valid_tactics
|
||||
|
||||
def get_tuning_config(self) -> TuningConfig:
|
||||
key = hash(self)
|
||||
if key not in self.__class__.tuning_config_cache:
|
||||
helper = GroupedGemmInputsHelper(self.num_experts, self.top_k,
|
||||
self.num_local_experts,
|
||||
self.local_expert_offset,
|
||||
self.tile_size)
|
||||
self.__class__.tuning_config_cache[key] = TuningConfig(
|
||||
dynamic_tensor_specs=(DynamicTensorSpec(
|
||||
0, 0, helper.gen_tuning_buckets,
|
||||
helper.map_to_tuning_buckets), ),
|
||||
constraint_specs=(
|
||||
ConstraintSpec(2, 0, fp4_scale_infer_shape),
|
||||
ConstraintSpec(
|
||||
5, 0, helper.infer_tile_idx_to_group_idx_shape)),
|
||||
inputs_pre_hook=helper.inputs_pre_hook,
|
||||
)
|
||||
return self.__class__.tuning_config_cache[key]
|
||||
|
||||
def forward(self, inputs: List[torch.Tensor],
|
||||
tactic: Optional[tuple]) -> torch.Tensor:
|
||||
a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles = inputs
|
||||
assert a.dtype == torch.float4_e2m1fn_x2
|
||||
assert a.dim() == 2
|
||||
assert b.dtype == torch.float4_e2m1fn_x2
|
||||
assert b.dim() == 3
|
||||
assert a_sf.dtype == torch.uint8
|
||||
assert a_sf.dim() == 1
|
||||
assert b_sf.dtype == torch.uint8
|
||||
assert b_sf.dim() == 3
|
||||
assert alpha.dtype == torch.float32
|
||||
assert alpha.dim() == 1
|
||||
|
||||
m, k = a.size(0), a.size(1) * 2
|
||||
l, n = b.size(0), b.size(1)
|
||||
scale_k = k // self.scaling_vector_size
|
||||
assert m % self.tile_size == 0
|
||||
assert k % (self.scaling_vector_size * 4) == 0
|
||||
assert b.size(2) * 2 == k
|
||||
assert a_sf.size(0) == m * scale_k
|
||||
assert b_sf.size(0) == l
|
||||
assert b_sf.size(1) == n
|
||||
assert b_sf.size(2) == scale_k
|
||||
assert alpha.size(0) == l
|
||||
|
||||
num_tiles = m // self.tile_size
|
||||
assert tile_idx_to_group_idx.dtype == torch.int32
|
||||
assert tile_idx_to_group_idx.size() == (num_tiles, )
|
||||
assert num_non_exiting_tiles.dtype == torch.int32
|
||||
assert num_non_exiting_tiles.size() == (1, )
|
||||
|
||||
c = torch.empty(m, n, dtype=self.output_dtype, device=a.device)
|
||||
|
||||
a_ptr = make_ptr(cutlass.Float4E2M1FN,
|
||||
a.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=32)
|
||||
b_ptr = make_ptr(cutlass.Float4E2M1FN,
|
||||
b.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=32)
|
||||
a_sf_ptr = make_ptr(cutlass.Float8E4M3FN,
|
||||
a_sf.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16)
|
||||
b_sf_ptr = make_ptr(cutlass.Float8E4M3FN,
|
||||
b_sf.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16)
|
||||
alpha_ptr = make_ptr(cutlass.Float32, alpha.data_ptr(),
|
||||
cute.AddressSpace.gmem)
|
||||
tile_idx_to_group_idx_ptr = make_ptr(
|
||||
cutlass.Int32, tile_idx_to_group_idx.data_ptr(),
|
||||
cute.AddressSpace.gmem)
|
||||
num_non_exiting_tiles_ptr = make_ptr(
|
||||
cutlass.Int32, num_non_exiting_tiles.data_ptr(),
|
||||
cute.AddressSpace.gmem)
|
||||
c_ptr = make_ptr(cutlass.BFloat16,
|
||||
c.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16)
|
||||
|
||||
torch_stream = torch.cuda.current_stream()
|
||||
stream = cuda.CUstream(torch_stream.cuda_stream)
|
||||
|
||||
if isinstance(tactic, tuple):
|
||||
mma_tiler_mn, cluster_shape_mn = tactic
|
||||
else:
|
||||
mma_tiler_mn, cluster_shape_mn = (128, 128), (1, 1)
|
||||
|
||||
cache_key = (self.scaling_vector_size, self.tile_size, mma_tiler_mn,
|
||||
cluster_shape_mn)
|
||||
if cache_key not in self.__class__.kernel_cache:
|
||||
gemm = Sm100BlockScaledPersistentGroupedGemmKernel(
|
||||
sf_vec_size=self.scaling_vector_size,
|
||||
acc_dtype=cutlass.Float32,
|
||||
use_2cta_instrs=False,
|
||||
mma_tiler_mn=mma_tiler_mn,
|
||||
cluster_shape_mn=cluster_shape_mn,
|
||||
)
|
||||
|
||||
compiled_gemm = cute.compile(
|
||||
gemm.wrapper,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
a_sf_ptr,
|
||||
b_sf_ptr,
|
||||
c_ptr,
|
||||
alpha_ptr,
|
||||
tile_idx_to_group_idx_ptr,
|
||||
num_non_exiting_tiles_ptr,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
l,
|
||||
tile_size=self.tile_size,
|
||||
scaling_vector_size=self.scaling_vector_size,
|
||||
max_active_clusters=16,
|
||||
stream=stream,
|
||||
)
|
||||
self.__class__.kernel_cache[cache_key] = compiled_gemm
|
||||
else:
|
||||
compiled_gemm = self.__class__.kernel_cache[cache_key]
|
||||
|
||||
compiled_gemm(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
a_sf_ptr,
|
||||
b_sf_ptr,
|
||||
c_ptr,
|
||||
alpha_ptr,
|
||||
tile_idx_to_group_idx_ptr,
|
||||
num_non_exiting_tiles_ptr,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
stream=stream,
|
||||
)
|
||||
return c
|
||||
|
||||
@torch.library.custom_op("trtllm::cute_dsl_nvfp4_grouped_gemm_blackwell",
|
||||
mutates_args=(),
|
||||
device_types="cuda")
|
||||
def cute_dsl_nvfp4_grouped_gemm_blackwell(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
alpha: torch.Tensor,
|
||||
tile_idx_to_group_idx: torch.Tensor,
|
||||
num_non_exiting_tiles: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_local_experts: int,
|
||||
local_expert_offset: int,
|
||||
tile_size: int,
|
||||
output_dtype: torch.dtype,
|
||||
scaling_vector_size: int = 16,
|
||||
) -> torch.Tensor:
|
||||
tuner = AutoTuner.get()
|
||||
|
||||
runner = Sm100BlockScaledPersistentGroupedGemmRunner(
|
||||
num_experts, top_k, num_local_experts, local_expert_offset,
|
||||
tile_size, output_dtype, scaling_vector_size)
|
||||
inputs = [
|
||||
input, weight, input_scale, weight_scale, alpha,
|
||||
tile_idx_to_group_idx, num_non_exiting_tiles
|
||||
]
|
||||
|
||||
_, best_tactic = tuner.choose_one(
|
||||
"trtllm::cute_dsl_nvfp4_grouped_gemm_blackwell",
|
||||
[runner],
|
||||
runner.get_tuning_config(),
|
||||
inputs,
|
||||
)
|
||||
output = runner(inputs, tactic=best_tactic)
|
||||
return output
|
||||
|
||||
@torch.library.register_fake(
|
||||
"trtllm::cute_dsl_nvfp4_grouped_gemm_blackwell")
|
||||
def _(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
alpha: torch.Tensor,
|
||||
tile_idx_to_group_idx: torch.Tensor,
|
||||
num_non_exiting_tiles: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_local_experts: int,
|
||||
local_expert_offset: int,
|
||||
tile_size: int,
|
||||
output_dtype: torch.dtype,
|
||||
scaling_vector_size: int = 16,
|
||||
):
|
||||
m = input.size(0)
|
||||
n = weight.size(1)
|
||||
return torch.empty(m, n, dtype=output_dtype, device=input.device)
|
||||
|
||||
class FusedMoEInputsHelper:
|
||||
|
||||
def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
|
||||
local_expert_offset: int):
|
||||
self.num_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.num_local_experts = num_local_experts
|
||||
self.local_expert_offset = local_expert_offset
|
||||
|
||||
def infer_token_selected_experts_shape(
|
||||
self, input_shapes: List[torch.Size]) -> int:
|
||||
return input_shapes[0][0]
|
||||
|
||||
def infer_token_final_scales_shape(
|
||||
self, input_shapes: List[torch.Size]) -> int:
|
||||
return input_shapes[0][0]
|
||||
|
||||
def inputs_pre_hook(self,
|
||||
inputs: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
x, x_sf, token_selected_experts, token_final_scales, *others = inputs
|
||||
num_tokens = token_selected_experts.size(0)
|
||||
new_token_final_scales, new_token_selected_experts = torch.randn(
|
||||
num_tokens,
|
||||
self.num_experts,
|
||||
device=token_selected_experts.device).topk(self.top_k, dim=-1)
|
||||
new_token_selected_experts = new_token_selected_experts.to(
|
||||
token_selected_experts.dtype)
|
||||
new_token_final_scales = new_token_final_scales.softmax(dim=-1).to(
|
||||
token_final_scales.dtype)
|
||||
return x, x_sf, new_token_selected_experts, new_token_final_scales, *others
|
||||
|
||||
class Sm100BlockScaledFusedMoERunner(TunableRunner):
|
||||
tuning_config_cache = dict()
|
||||
|
||||
def __init__(self,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_local_experts: int,
|
||||
local_expert_offset: int,
|
||||
output_dtype: torch.dtype,
|
||||
scaling_vector_size: int = 16):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.num_local_experts = num_local_experts
|
||||
self.local_expert_offset = local_expert_offset
|
||||
|
||||
assert output_dtype == torch.bfloat16
|
||||
self.output_dtype = output_dtype
|
||||
self.scaling_vector_size = scaling_vector_size
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
**kwargs,
|
||||
) -> List[int]:
|
||||
return [128]
|
||||
|
||||
def get_tuning_config(self) -> TuningConfig:
|
||||
key = hash(self)
|
||||
if key not in self.__class__.tuning_config_cache:
|
||||
helper = FusedMoEInputsHelper(self.num_experts, self.top_k,
|
||||
self.num_local_experts,
|
||||
self.local_expert_offset)
|
||||
self.__class__.tuning_config_cache[key] = TuningConfig(
|
||||
dynamic_tensor_specs=(DynamicTensorSpec(
|
||||
0, 0, get_last_power_of_2_num_tokens_buckets,
|
||||
last_positive_power_of_2), ),
|
||||
constraint_specs=(
|
||||
ConstraintSpec(1, 0, fp4_scale_infer_shape),
|
||||
ConstraintSpec(
|
||||
2, 0, helper.infer_token_selected_experts_shape),
|
||||
ConstraintSpec(3, 0,
|
||||
helper.infer_token_final_scales_shape)),
|
||||
inputs_pre_hook=helper.inputs_pre_hook,
|
||||
)
|
||||
return self.__class__.tuning_config_cache[key]
|
||||
|
||||
def forward(self, inputs: List[torch.Tensor],
|
||||
tactic: Optional[int]) -> torch.Tensor:
|
||||
if isinstance(tactic, int):
|
||||
tile_size = tactic
|
||||
else:
|
||||
tile_size = 128
|
||||
|
||||
x, x_sf, token_selected_experts, token_final_scales, gemm1_weight, gemm1_weight_scale, gemm1_alpha, gemm2_input_global_scale, gemm2_weight, gemm2_weight_scale, gemm2_alpha = inputs
|
||||
tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles = torch.ops.trtllm.moe_sort(
|
||||
token_selected_experts=token_selected_experts,
|
||||
token_final_scales=token_final_scales,
|
||||
num_experts=self.num_experts,
|
||||
top_k=self.top_k,
|
||||
local_expert_offset=self.local_expert_offset,
|
||||
local_num_experts=self.num_local_experts,
|
||||
tile_tokens_dim=tile_size,
|
||||
)
|
||||
x, x_sf = torch.ops.trtllm.moe_permute(
|
||||
input=x,
|
||||
input_sf=x_sf,
|
||||
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
|
||||
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
|
||||
num_non_exiting_tiles=num_non_exiting_tiles,
|
||||
tile_tokens_dim=tile_size,
|
||||
top_k=self.top_k,
|
||||
)
|
||||
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
|
||||
input=x.view(torch.float4_e2m1fn_x2),
|
||||
weight=gemm1_weight.view(torch.float4_e2m1fn_x2),
|
||||
input_scale=x_sf.view(torch.uint8),
|
||||
weight_scale=gemm1_weight_scale.view(torch.uint8),
|
||||
alpha=gemm1_alpha,
|
||||
tile_idx_to_group_idx=tile_idx_to_expert_idx,
|
||||
num_non_exiting_tiles=num_non_exiting_tiles,
|
||||
num_experts=self.num_experts,
|
||||
top_k=self.top_k,
|
||||
num_local_experts=self.num_local_experts,
|
||||
local_expert_offset=self.local_expert_offset,
|
||||
tile_size=tile_size,
|
||||
output_dtype=self.output_dtype,
|
||||
)
|
||||
x, x_sf = torch.ops.trtllm.moe_swiglu_nvfp4_quantize(
|
||||
input=x,
|
||||
global_sf=gemm2_input_global_scale,
|
||||
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
|
||||
num_non_exiting_tiles=num_non_exiting_tiles,
|
||||
tile_tokens_dim=tile_size,
|
||||
)
|
||||
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
|
||||
input=x.view(torch.float4_e2m1fn_x2),
|
||||
weight=gemm2_weight.view(torch.float4_e2m1fn_x2),
|
||||
input_scale=x_sf.view(torch.uint8),
|
||||
weight_scale=gemm2_weight_scale.view(torch.uint8),
|
||||
alpha=gemm2_alpha,
|
||||
tile_idx_to_group_idx=tile_idx_to_expert_idx,
|
||||
num_non_exiting_tiles=num_non_exiting_tiles,
|
||||
num_experts=self.num_experts,
|
||||
top_k=self.top_k,
|
||||
num_local_experts=self.num_local_experts,
|
||||
local_expert_offset=self.local_expert_offset,
|
||||
tile_size=tile_size,
|
||||
output_dtype=self.output_dtype,
|
||||
)
|
||||
x = torch.ops.trtllm.moe_unpermute(
|
||||
permuted_input=x,
|
||||
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
|
||||
topk_scales=token_final_scales,
|
||||
)
|
||||
return x
|
||||
|
||||
@torch.library.custom_op("trtllm::cute_dsl_nvfp4_fused_moe_blackwell",
|
||||
mutates_args=(),
|
||||
device_types="cuda")
|
||||
def cute_dsl_nvfp4_fused_moe_blackwell(
|
||||
input: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
token_selected_experts: torch.Tensor,
|
||||
token_final_scales: torch.Tensor,
|
||||
gemm1_weight: torch.Tensor,
|
||||
gemm1_weight_scale: torch.Tensor,
|
||||
gemm1_alpha: torch.Tensor,
|
||||
gemm2_input_global_scale: torch.Tensor,
|
||||
gemm2_weight: torch.Tensor,
|
||||
gemm2_weight_scale: torch.Tensor,
|
||||
gemm2_alpha: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_local_experts: int,
|
||||
local_expert_offset: int,
|
||||
output_dtype: torch.dtype,
|
||||
scaling_vector_size: int = 16,
|
||||
) -> torch.Tensor:
|
||||
tuner = AutoTuner.get()
|
||||
runner = Sm100BlockScaledFusedMoERunner(num_experts, top_k,
|
||||
num_local_experts,
|
||||
local_expert_offset,
|
||||
output_dtype,
|
||||
scaling_vector_size)
|
||||
inputs = [
|
||||
input, input_scale, token_selected_experts, token_final_scales,
|
||||
gemm1_weight, gemm1_weight_scale, gemm1_alpha,
|
||||
gemm2_input_global_scale, gemm2_weight, gemm2_weight_scale,
|
||||
gemm2_alpha
|
||||
]
|
||||
|
||||
_, best_tactic = tuner.choose_one(
|
||||
"trtllm::cute_dsl_nvfp4_fused_moe_blackwell",
|
||||
[runner],
|
||||
runner.get_tuning_config(),
|
||||
inputs,
|
||||
)
|
||||
output = runner(inputs, tactic=best_tactic)
|
||||
return output
|
||||
|
||||
@torch.library.register_fake("trtllm::cute_dsl_nvfp4_fused_moe_blackwell")
|
||||
def _(
|
||||
input: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
token_selected_experts: torch.Tensor,
|
||||
token_final_scales: torch.Tensor,
|
||||
gemm1_weight: torch.Tensor,
|
||||
gemm1_weight_scale: torch.Tensor,
|
||||
gemm1_alpha: torch.Tensor,
|
||||
gemm2_input_global_scale: torch.Tensor,
|
||||
gemm2_weight: torch.Tensor,
|
||||
gemm2_weight_scale: torch.Tensor,
|
||||
gemm2_alpha: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_local_experts: int,
|
||||
local_expert_offset: int,
|
||||
output_dtype: torch.dtype,
|
||||
scaling_vector_size: int = 16,
|
||||
):
|
||||
m, k = input.size(0), input.size(1) * 2
|
||||
return torch.empty(m, k, dtype=output_dtype, device=input.device)
|
||||
|
||||
File diff suppressed because it is too large
@ -701,8 +701,7 @@ class Deepseekv3RoutingImpl():
|
||||
new_mask.scatter_(-1, topk_idx, 1)
|
||||
scores = scores * new_mask
|
||||
score_sum = torch.sum(scores, dim=-1, keepdim=True) + 1e-20
|
||||
scores = scores / score_sum * \
|
||||
self.routed_scaling_factor
|
||||
scores = scores / score_sum * self.routed_scaling_factor
|
||||
topk_values, topk_indices = torch.topk(scores,
|
||||
k=self.top_k,
|
||||
dim=-1,
|
||||
|
||||
@ -6,13 +6,15 @@ import torch.nn.functional as F
|
||||
|
||||
from tensorrt_llm._utils import is_sm_100f
|
||||
|
||||
from ...distributed import allgather
|
||||
from ...model_config import ModelConfig
|
||||
from ...utils import AuxStreamType, Fp4QuantizedTensor
|
||||
from ...utils import AuxStreamType, Fp4QuantizedTensor, ceil_div
|
||||
from .fused_moe_cutlass import CutlassFusedMoE
|
||||
from .quantization import MoEWeightLoadingMode
|
||||
from .routing import BaseMoeRoutingMethod
|
||||
|
||||
|
||||
@torch.compile(options={"max-autotune": True})
|
||||
def swiglu_fused_moe(x):
|
||||
x, gate = x.chunk(2, dim=-1)
|
||||
return F.silu(gate) * x
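Note the chunk convention here: the first half of the last dimension is the value path and the second half is the gate, matching the swiglu_ref helper in the new unit test. A quick sanity sketch (shapes are arbitrary):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
value, gate = x.chunk(2, dim=-1)
assert torch.allclose(swiglu_fused_moe(x), F.silu(gate) * value)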
|
||||
@ -88,6 +90,65 @@ def cute_dsl_fp8_group_blockwise_gemm_ref(
|
||||
return ref
|
||||
|
||||
|
||||
def cute_dsl_nvfp4_grouped_gemm_ref(
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
a_sf: torch.Tensor,
|
||||
b_sf: torch.Tensor,
|
||||
alpha: torch.Tensor,
|
||||
tile_idx_to_group_idx: torch.Tensor,
|
||||
num_non_exiting_tiles: torch.Tensor,
|
||||
tile_size: int,
|
||||
output_dtype: torch.dtype,
|
||||
scaling_vector_size: int = 16,
|
||||
):
|
||||
assert a.dtype == torch.float4_e2m1fn_x2
|
||||
assert a.dim() == 2
|
||||
assert b.dtype == torch.float4_e2m1fn_x2
|
||||
assert b.dim() == 3
|
||||
assert a_sf.dtype == torch.uint8
|
||||
assert a_sf.dim() == 1
|
||||
assert b_sf.dtype == torch.uint8
|
||||
assert b_sf.dim() == 3
|
||||
assert alpha.dtype == torch.float32
|
||||
assert alpha.dim() == 1
|
||||
|
||||
m, k = a.size(0), a.size(1) * 2
|
||||
l, n = b.size(0), b.size(1)
|
||||
scale_k = k // scaling_vector_size
|
||||
assert m % tile_size == 0
|
||||
assert k % (scaling_vector_size * 4) == 0
|
||||
assert b.size(2) * 2 == k
|
||||
assert a_sf.size(0) == m * scale_k
|
||||
assert b_sf.size(0) == l
|
||||
assert b_sf.size(1) == n
|
||||
assert b_sf.size(2) == scale_k
|
||||
assert alpha.size(0) == l
|
||||
|
||||
num_tiles = m // tile_size
|
||||
assert tile_idx_to_group_idx.dtype == torch.int32
|
||||
assert tile_idx_to_group_idx.size() == (num_tiles, )
|
||||
assert num_non_exiting_tiles.dtype == torch.int32
|
||||
assert num_non_exiting_tiles.size() == (1, )
|
||||
|
||||
num_tiles_per_expert = torch.bincount(
|
||||
tile_idx_to_group_idx[:num_non_exiting_tiles[0].item()], minlength=l)
|
||||
offsets = [0] + num_tiles_per_expert.cumsum(dim=0).tolist()
|
||||
|
||||
ref = torch.empty(m, n, dtype=output_dtype, device="cuda")
|
||||
for i, (start, end) in enumerate(zip(offsets[:-1], offsets[1:])):
|
||||
if end <= start:
|
||||
continue
|
||||
a_sliced = a[start * tile_size:end * tile_size]
|
||||
a_sf_sliced = a_sf[start * tile_size * k // scaling_vector_size:end *
|
||||
tile_size * k // scaling_vector_size]
|
||||
ref[start * tile_size:end * tile_size] = torch.ops.trtllm.nvfp4_gemm(
|
||||
a_sliced.view(torch.uint8), b[i].view(torch.uint8), a_sf_sliced,
|
||||
b_sf[i], alpha[i], output_dtype)
|
||||
|
||||
return ref
|
||||
|
||||
|
||||
class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
"""
|
||||
Python Flow of Fused Mixture of Experts (MoE) Layer.
|
||||
@ -140,7 +201,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
def forward_chunk(
|
||||
def forward_chunk_unquantized(
|
||||
self,
|
||||
x: Union[torch.Tensor, Fp4QuantizedTensor],
|
||||
router_logits: torch.Tensor,
|
||||
@ -149,11 +210,24 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
use_dp_padding: Optional[bool] = None,
|
||||
repeating_info: tuple = (True, True),
|
||||
) -> torch.Tensor:
|
||||
if isinstance(x, Fp4QuantizedTensor):
|
||||
assert output_dtype is not None
|
||||
output_dtype = output_dtype
|
||||
else:
|
||||
output_dtype = x.dtype
|
||||
assert not self.has_any_quant
|
||||
return super().forward_chunk(x,
|
||||
router_logits,
|
||||
output_dtype=output_dtype,
|
||||
all_rank_num_tokens=all_rank_num_tokens,
|
||||
use_dp_padding=use_dp_padding,
|
||||
repeating_info=repeating_info)
|
||||
|
||||
def forward_chunk_fp8_block_scales(
|
||||
self,
|
||||
x: Union[torch.Tensor, Fp4QuantizedTensor],
|
||||
router_logits: torch.Tensor,
|
||||
output_dtype: Optional[torch.dtype] = None,
|
||||
all_rank_num_tokens: Optional[List[int]] = None,
|
||||
use_dp_padding: Optional[bool] = None,
|
||||
repeating_info: tuple = (True, True),
|
||||
) -> torch.Tensor:
|
||||
assert self.has_deepseek_fp8_block_scales
|
||||
|
||||
# apply routing
|
||||
token_selected_experts, token_final_scales = self.routing_method.apply(
|
||||
@ -172,17 +246,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
# TODO: remove this once we have correct fusedmoe kernel ready
|
||||
token_final_scales = None
|
||||
|
||||
# quantize inputs
|
||||
use_deepseek_fp8_block_scale = False
|
||||
weight_dtype = self.w3_w1_weight.dtype
|
||||
x_sf = None
|
||||
if self.has_any_quant:
|
||||
if self.has_deepseek_fp8_block_scales:
|
||||
use_deepseek_fp8_block_scale = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}"
|
||||
)
|
||||
|
||||
(
|
||||
permuted_row_to_unpermuted_row_tensor,
|
||||
@ -198,7 +262,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
None, # w3_w1_weight.view(weight_dtype),
|
||||
None, # w2_weight.view(weight_dtype),
|
||||
None, # quant_scales,
|
||||
input_sf=x_sf,
|
||||
input_sf=None,
|
||||
num_experts_on_rank=self.expert_size_per_partition,
|
||||
tp_size=self.tp_size,
|
||||
tp_rank=self.tp_rank,
|
||||
@ -207,7 +271,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
cluster_size=self.cluster_size,
|
||||
cluster_rank=self.cluster_rank,
|
||||
min_latency_mode=False,
|
||||
use_fp8_block_scaling=use_deepseek_fp8_block_scale,
|
||||
use_fp8_block_scaling=True,
|
||||
)
|
||||
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
|
||||
permuted_data_tensor)
|
||||
@ -227,7 +291,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
b_sf=self.quant_scales[1],
|
||||
offset_array=expert_first_token_offset_tensor,
|
||||
)
|
||||
final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op(
|
||||
h4 = torch.ops.trtllm.moe_finalize_scale_op(
|
||||
h3,
|
||||
None, # biases
|
||||
token_final_scales,
|
||||
@ -246,5 +310,164 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
self.ep_size,
|
||||
self.ep_rank,
|
||||
)
|
||||
return h4
|
||||
|
||||
return final_hidden_states
|
||||
def forward_chunk_nvfp4(
|
||||
self,
|
||||
x: Union[torch.Tensor, Fp4QuantizedTensor],
|
||||
router_logits: torch.Tensor,
|
||||
output_dtype: Optional[torch.dtype] = None,
|
||||
all_rank_num_tokens: Optional[List[int]] = None,
|
||||
use_dp_padding: Optional[bool] = None,
|
||||
repeating_info: tuple = (True, True),
|
||||
) -> torch.Tensor:
|
||||
assert self.has_nvfp4
|
||||
|
||||
if isinstance(x, Fp4QuantizedTensor):
|
||||
assert output_dtype is not None
|
||||
else:
|
||||
output_dtype = x.dtype
|
||||
|
||||
# apply routing
|
||||
token_selected_experts, token_final_scales = self.routing_method.apply(
|
||||
router_logits)
|
||||
assert token_selected_experts.shape[
|
||||
1] == self.routing_method.experts_per_token
|
||||
assert token_selected_experts.shape == token_final_scales.shape
|
||||
assert token_selected_experts.shape[0] == router_logits.shape[0]
|
||||
assert token_final_scales.dtype == torch.float32
|
||||
assert token_selected_experts.dtype == torch.int32
|
||||
|
||||
run_post_quant_allgather = self.use_dp and self.parallel_size > 1
|
||||
if run_post_quant_allgather:
|
||||
if isinstance(x, Fp4QuantizedTensor):
|
||||
assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
|
||||
x, x_sf = x.fp4_tensor, x.scaling_factor
|
||||
else:
|
||||
x, x_sf = torch.ops.trtllm.fp4_quantize(
|
||||
x, self.fc31_input_scale, self.scaling_vector_size, False,
|
||||
False)
|
||||
# note: we use uint8 to store 2 fp4 values
|
||||
x_row, x_col = x.size(0), x.size(1) * 2
|
||||
else:
|
||||
if not isinstance(x, Fp4QuantizedTensor):
|
||||
x, x_sf = torch.ops.trtllm.fp4_quantize(
|
||||
x, self.fc31_input_scale, self.scaling_vector_size, False,
|
||||
False)
|
||||
|
||||
if run_post_quant_allgather:
|
||||
# Original allgather logic
|
||||
if x_sf is not None:
|
||||
x_sf = x_sf.view(x_row, ceil_div(x_col,
|
||||
self.scaling_vector_size))
|
||||
assert x_sf.dim(
|
||||
) == 2, "The hidden states scaling factor should be 2D tensor before allgather"
|
||||
|
||||
x, x_sf, token_selected_experts, token_final_scales = allgather(
|
||||
[x, x_sf, token_selected_experts, token_final_scales],
|
||||
self.mapping,
|
||||
dim=0,
|
||||
sizes=None if use_dp_padding else all_rank_num_tokens)
|
||||
|
||||
tile_size = 128
|
||||
tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles = torch.ops.trtllm.moe_sort(
|
||||
token_selected_experts=token_selected_experts,
|
||||
token_final_scales=token_final_scales,
|
||||
num_experts=self.num_slots,
|
||||
top_k=self.routing_method.experts_per_token,
|
||||
local_expert_offset=self.slot_start,
|
||||
local_num_experts=self.expert_size_per_partition,
|
||||
tile_tokens_dim=tile_size,
|
||||
)
|
||||
|
||||
x, x_sf = torch.ops.trtllm.moe_permute(
|
||||
input=x.view(torch.float4_e2m1fn_x2),
|
||||
input_sf=x_sf,
|
||||
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
|
||||
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
|
||||
num_non_exiting_tiles=num_non_exiting_tiles,
|
||||
tile_tokens_dim=tile_size,
|
||||
top_k=self.routing_method.experts_per_token,
|
||||
)
|
||||
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
|
||||
input=x.view(torch.float4_e2m1fn_x2),
|
||||
weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2),
|
||||
input_scale=x_sf.view(torch.uint8),
|
||||
weight_scale=self.quant_scales.fc1_weight_block.view(torch.uint8),
|
||||
alpha=self.quant_scales.fc1_global,
|
||||
tile_idx_to_group_idx=tile_idx_to_expert_idx,
|
||||
num_non_exiting_tiles=num_non_exiting_tiles,
|
||||
num_experts=self.num_slots,
|
||||
top_k=self.routing_method.experts_per_token,
|
||||
num_local_experts=self.expert_size_per_partition,
|
||||
local_expert_offset=self.slot_start,
|
||||
tile_size=tile_size,
|
||||
output_dtype=output_dtype,
|
||||
)
|
||||
x, x_sf = torch.ops.trtllm.moe_swiglu_nvfp4_quantize(
|
||||
input=x,
|
||||
global_sf=self.fc2_input_scale,
|
||||
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
|
||||
num_non_exiting_tiles=num_non_exiting_tiles,
|
||||
tile_tokens_dim=tile_size,
|
||||
)
|
||||
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
|
||||
input=x.view(torch.float4_e2m1fn_x2),
|
||||
weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
|
||||
input_scale=x_sf.view(torch.uint8),
|
||||
weight_scale=self.quant_scales.fc2_weight_block.view(torch.uint8),
|
||||
alpha=self.quant_scales.fc2_global,
|
||||
tile_idx_to_group_idx=tile_idx_to_expert_idx,
|
||||
num_non_exiting_tiles=num_non_exiting_tiles,
|
||||
num_experts=self.num_slots,
|
||||
top_k=self.routing_method.experts_per_token,
|
||||
num_local_experts=self.expert_size_per_partition,
|
||||
local_expert_offset=self.slot_start,
|
||||
tile_size=tile_size,
|
||||
output_dtype=output_dtype,
|
||||
)
|
||||
x = torch.ops.trtllm.moe_unpermute(
|
||||
permuted_input=x,
|
||||
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
|
||||
topk_scales=token_final_scales,
|
||||
)
|
||||
return x
|
||||
|
||||
def forward_chunk(
|
||||
self,
|
||||
x: Union[torch.Tensor, Fp4QuantizedTensor],
|
||||
router_logits: torch.Tensor,
|
||||
output_dtype: Optional[torch.dtype] = None,
|
||||
all_rank_num_tokens: Optional[List[int]] = None,
|
||||
use_dp_padding: Optional[bool] = None,
|
||||
repeating_info: tuple = (True, True),
|
||||
) -> torch.Tensor:
|
||||
if self.has_any_quant:
|
||||
if self.has_nvfp4:
|
||||
return self.forward_chunk_nvfp4(
|
||||
x,
|
||||
router_logits,
|
||||
output_dtype=output_dtype,
|
||||
all_rank_num_tokens=all_rank_num_tokens,
|
||||
use_dp_padding=use_dp_padding,
|
||||
repeating_info=repeating_info)
|
||||
elif self.has_deepseek_fp8_block_scales:
|
||||
return self.forward_chunk_fp8_block_scales(
|
||||
x,
|
||||
router_logits,
|
||||
output_dtype=output_dtype,
|
||||
all_rank_num_tokens=all_rank_num_tokens,
|
||||
use_dp_padding=use_dp_padding,
|
||||
repeating_info=repeating_info)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}"
|
||||
)
|
||||
else:
|
||||
return self.forward_chunk_unquantized(
|
||||
x,
|
||||
router_logits,
|
||||
output_dtype=output_dtype,
|
||||
all_rank_num_tokens=all_rank_num_tokens,
|
||||
use_dp_padding=use_dp_padding,
|
||||
repeating_info=repeating_info)
|
||||
|
||||
@ -303,7 +303,6 @@ class CutlassFusedMoE(MoE):
|
||||
) -> torch.Tensor:
|
||||
if isinstance(x, Fp4QuantizedTensor):
|
||||
assert output_dtype is not None
|
||||
output_dtype = output_dtype
|
||||
else:
|
||||
output_dtype = x.dtype
|
||||
|
||||
|
||||
@ -474,7 +474,6 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
|
||||
) -> torch.Tensor:
|
||||
if isinstance(x, Fp4QuantizedTensor):
|
||||
assert output_dtype is not None
|
||||
output_dtype = output_dtype
|
||||
else:
|
||||
output_dtype = x.dtype
|
||||
|
||||
|
||||
@ -400,7 +400,6 @@ class WideEPMoE(MoE):
|
||||
all_rank_max_num_tokens = max(all_rank_num_tokens)
|
||||
if isinstance(x, Fp4QuantizedTensor):
|
||||
assert output_dtype is not None
|
||||
output_dtype = output_dtype
|
||||
else:
|
||||
output_dtype = x.dtype
|
||||
|
||||
|
||||
@ -1685,13 +1685,15 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
(False, False, False, True),
|
||||
(True, False, True, True), (True, True, True, True)])
|
||||
@parametrize_with_ids("mtp_nextn", [0, 2])
|
||||
@parametrize_with_ids("moe_backend", ["CUTLASS", "TRTLLM"])
|
||||
@parametrize_with_ids("moe_backend", ["CUTLASS", "TRTLLM", "CUTEDSL"])
|
||||
def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
|
||||
torch_compile, mtp_nextn, moe_backend):
|
||||
if moe_backend == "TRTLLM" and (get_sm_version() == 120
|
||||
or get_sm_version() == 121):
|
||||
pytest.skip(
|
||||
"MOE TRTLLM backend does not support SM version 120 or 121")
|
||||
if moe_backend == "CUTEDSL" and get_sm_version() != 100:
|
||||
pytest.skip(f"{moe_backend} backend supports SM 100 only")
|
||||
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
|
||||
torch_compile_config = TorchCompileConfig(
|
||||
@ -1778,7 +1780,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
(2, 2, 1), (1, 4, 1)],
|
||||
ids=["tp4", "ep4", "tp2pp2", "pp4"])
|
||||
@parametrize_with_ids("mtp_nextn", [0, 2])
|
||||
@parametrize_with_ids("moe_backend", ["CUTLASS", "TRTLLM"])
|
||||
@parametrize_with_ids("moe_backend", ["CUTLASS", "TRTLLM", "CUTEDSL"])
|
||||
def test_nvfp4_4gpus(self, fp8kv, attention_dp, cuda_graph,
|
||||
overlap_scheduler, tp_size, pp_size, ep_size,
|
||||
torch_compile, mtp_nextn, moe_backend):
|
||||
@ -1788,6 +1790,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
or get_sm_version() == 121):
|
||||
pytest.skip(
|
||||
"MOE TRTLLM backend does not support SM version 120 or 121")
|
||||
if moe_backend == "CUTEDSL" and get_sm_version() != 100:
|
||||
pytest.skip(f"{moe_backend} backend supports SM 100 only")
|
||||
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
|
||||
# Picewise Cuda Graph cannot be enabled for nvfp4 attention dp.
|
||||
torch_compile_config = TorchCompileConfig(
|
||||
|
||||
@ -27,6 +27,7 @@ l0_b200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=TRTLLM-mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=False] ISOLATION
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=TRTLLM-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTEDSL-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=nvfp4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=none-kv_cache_reuse=True-fp8kv=False-overlap_scheduler=True]
@ -148,4 +149,5 @@ l0_b200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=TRTLLM-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=TRTLLM-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTEDSL-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestSeedOss_36B::test_auto_dtype

@ -36,6 +36,7 @@ l0_dgx_b200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTEDSL-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False]
@ -168,6 +169,7 @@ l0_dgx_b200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTEDSL-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding_4gpus[xgrammar-mtp_nextn=0]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=False]

431
tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py
Normal file
@ -0,0 +1,431 @@
import pytest
import torch

from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import GroupedGemmInputsHelper
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import cute_dsl_nvfp4_grouped_gemm_ref
from tensorrt_llm._torch.utils import unswizzle_sf
from tensorrt_llm._utils import get_sm_version


def swiglu_ref(x: torch.Tensor) -> torch.Tensor:
    x, gate = x.chunk(2, dim=-1)
    return x * torch.nn.functional.silu(gate)


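# Minimal sanity sketch for swiglu_ref, assuming the usual chunk-and-gate
# SwiGLU definition used above: for an even-sized last dimension, the first
# half is gated by SiLU of the second half.
def test_swiglu_ref_matches_manual_split():
    x = torch.randn(4, 8)
    half, gate = x[:, :4], x[:, 4:]
    torch.testing.assert_close(swiglu_ref(x),
                               half * torch.nn.functional.silu(gate))

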
@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("ep_size", [1, 8, 32])
@pytest.mark.parametrize("top_k", [1, 2, 6, 8])
def test_grouped_gemm_inputs_helper(top_k: int, ep_size: int, tile_size: int):
    num_experts = 256
    num_local_experts = num_experts // ep_size

    helper = GroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size)
    max_num_tokens = 8192
    num_tokens_list = list(range(1, max_num_tokens + 1))
    max_num_permuted_tokens_list = [helper.get_max_num_permuted_tokens(x) for x in num_tokens_list]
    num_inferred_tokens_list = [helper.infer_num_tokens(x) for x in max_num_permuted_tokens_list]

    for i in range(max_num_tokens):
        assert num_inferred_tokens_list[i] >= num_tokens_list[i]
        assert num_inferred_tokens_list[i] < num_tokens_list[i] + tile_size
        if i > 0:
            assert num_inferred_tokens_list[i] >= num_inferred_tokens_list[i - 1]

    buckets = helper.gen_tuning_buckets(max_num_permuted_tokens_list[-1])
    assert set([helper.map_to_tuning_buckets(x) for x in max_num_permuted_tokens_list]) == set(
        buckets
    )


@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("ep_size", [1, 8, 32])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024, 8192])
def test_moe_sort(num_tokens: int, top_k: int, ep_size: int, tile_size: int):
    num_experts = 256
    num_local_experts = num_experts // ep_size

    routing_logits = torch.randn(num_tokens, num_experts, device="cuda")
    token_final_scales, token_selected_experts = routing_logits.topk(top_k, dim=-1)
    token_selected_experts = token_selected_experts.to(torch.int32)
    token_final_scales = token_final_scales.softmax(dim=-1).to(torch.bfloat16)

    (
        tile_idx_to_group_idx,
        tile_idx_to_mn_limit,
        expanded_idx_to_permuted_idx,
        permuted_idx_to_expanded_idx,
        total_num_padded_tokens,
        num_non_exiting_tiles,
    ) = torch.ops.trtllm.moe_sort(
        token_selected_experts=token_selected_experts,
        token_final_scales=token_final_scales,
        num_experts=num_experts,
        top_k=top_k,
        local_expert_offset=0,
        local_num_experts=num_local_experts,
        tile_tokens_dim=tile_size,
    )

    num_tokens_per_expert = torch.bincount(token_selected_experts.flatten(), minlength=num_experts)
    num_tokens_per_expert = num_tokens_per_expert[:num_local_experts]
    num_tiles_per_expert = (num_tokens_per_expert + tile_size - 1) // tile_size
    num_tokens_per_expert = num_tokens_per_expert.cpu()
    num_tiles_per_expert = num_tiles_per_expert.cpu()

    helper = GroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size)
    max_num_tiles = helper.get_max_num_tiles(num_tokens)
    max_num_permuted_tokens = helper.get_max_num_permuted_tokens(num_tokens)
    num_valid_tiles = num_tiles_per_expert.sum().item()
    num_valid_permuted_tokens = num_valid_tiles * tile_size
    assert 0 <= num_valid_tiles <= max_num_tiles
    assert 0 <= num_valid_permuted_tokens <= max_num_permuted_tokens

    tile_idx_to_group_idx = tile_idx_to_group_idx.cpu()
    tile_idx_to_mn_limit = tile_idx_to_mn_limit.cpu()
    assert tile_idx_to_group_idx.size() == (max_num_tiles,)
    assert tile_idx_to_mn_limit.size() == (max_num_tiles,)
    tile_idx = 0
    for expert_idx in range(num_local_experts):
        num_remaining_tokens = num_tokens_per_expert[expert_idx].item()
        for i in range(num_tiles_per_expert[expert_idx].item()):
            mn_limit = tile_idx * tile_size
            if i < num_tiles_per_expert[expert_idx].item() - 1:
                assert num_remaining_tokens > tile_size
                num_remaining_tokens -= tile_size
                mn_limit += tile_size
            else:
                assert 0 < num_remaining_tokens <= tile_size
                mn_limit += num_remaining_tokens
            assert tile_idx_to_group_idx[tile_idx].item() == expert_idx
            assert tile_idx_to_mn_limit[tile_idx].item() == mn_limit
            tile_idx += 1

    token_selected_experts = token_selected_experts.cpu()
    expanded_idx_to_permuted_idx = expanded_idx_to_permuted_idx.cpu()
    permuted_idx_to_expanded_idx = permuted_idx_to_expanded_idx.cpu()
    assert expanded_idx_to_permuted_idx.size() == (num_tokens, top_k)
    assert permuted_idx_to_expanded_idx.size() == (max_num_permuted_tokens,)
    for i in range(num_tokens):
        for k in range(top_k):
            expert_idx = token_selected_experts[i, k].item()
            expanded_idx = i * top_k + k
            permuted_idx = expanded_idx_to_permuted_idx[i, k].item()
            if expert_idx >= num_local_experts:
                assert permuted_idx == -1
            else:
                assert permuted_idx >= 0
                assert permuted_idx_to_expanded_idx[permuted_idx].item() == expanded_idx
                tile_idx = permuted_idx // tile_size
                assert tile_idx_to_group_idx[tile_idx].item() == expert_idx

    for i in range(num_valid_permuted_tokens):
        tile_idx = i // tile_size
        if i < tile_idx_to_mn_limit[tile_idx].item():
            expanded_idx = permuted_idx_to_expanded_idx[i].item()
            token_idx = expanded_idx // top_k
            topk_idx = expanded_idx % top_k
            assert expanded_idx_to_permuted_idx[token_idx, topk_idx].item() == i

    assert total_num_padded_tokens.size() == (1,)
    assert total_num_padded_tokens[0].item() == num_valid_permuted_tokens
    assert num_non_exiting_tiles.size() == (1,)
    assert num_non_exiting_tiles[0].item() == num_valid_tiles


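# A compact, vectorized restatement of the tile-metadata check in test_moe_sort
# above: given the per-local-expert token histogram, this sketch reproduces the
# tile_idx_to_group_idx / tile_idx_to_mn_limit layout that moe_sort is expected
# to emit (illustrative only).
def ref_tile_metadata(num_tokens_per_expert: torch.Tensor, tile_size: int):
    # num_tokens_per_expert: 1-D CPU integer tensor of tokens routed to each local expert.
    num_tiles_per_expert = (num_tokens_per_expert + tile_size - 1) // tile_size
    total_tiles = int(num_tiles_per_expert.sum())
    # Expert owning each tile, laid out in expert order; zero-token experts get no tiles.
    tile_idx_to_group_idx = torch.repeat_interleave(
        torch.arange(num_tokens_per_expert.numel()), num_tiles_per_expert)
    # Position of each tile within its expert's run of tiles.
    first_tile_of_expert = num_tiles_per_expert.cumsum(0) - num_tiles_per_expert
    local_tile_idx = torch.arange(total_tiles) - first_tile_of_expert[tile_idx_to_group_idx]
    # Every tile is filled to tile_size except possibly the expert's last one.
    fill = (num_tokens_per_expert[tile_idx_to_group_idx] - local_tile_idx * tile_size).clamp(
        max=tile_size)
    tile_idx_to_mn_limit = torch.arange(total_tiles) * tile_size + fill
    return tile_idx_to_group_idx.to(torch.int32), tile_idx_to_mn_limit.to(torch.int32)

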
@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024])
@pytest.mark.parametrize("dtype", ["bfloat16", "float16", "float8", "float4"])
def test_moe_permute(dtype: str, num_tokens: int, top_k: int, tile_size: int):
    sf_vec_size = 16
    hidden_size = 4096
    num_experts = 256
    num_local_experts = num_experts // 32
    x = torch.randint(-100, 100, (num_tokens, hidden_size), dtype=torch.int32, device="cuda")
    x_sf = None
    if dtype == "float4":
        x = x[:, : hidden_size // 2].to(torch.int8).view(torch.float4_e2m1fn_x2)
        x_sf = torch.randint(
            -100, 100, (num_tokens, hidden_size // sf_vec_size), dtype=torch.int32, device="cuda"
        )
        x_sf = x_sf.to(torch.float8_e4m3fn).view(torch.uint8)
    elif dtype == "float8":
        x = x.to(torch.float8_e4m3fn)
    else:
        x = x.to(getattr(torch, dtype))

    helper = GroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size)
    max_num_tiles = helper.get_max_num_tiles(num_tokens)
    max_num_permuted_tokens = helper.get_max_num_permuted_tokens(num_tokens)
    tile_idx_to_mn_limit = (
        torch.arange(1, max_num_tiles + 1, dtype=torch.int32, device="cuda") * tile_size
    )
    permuted_idx_to_expanded_idx = torch.randint(
        0, num_tokens * top_k, (max_num_permuted_tokens,), dtype=torch.int32, device="cuda"
    )
    num_non_exiting_tiles_val = (num_tokens * top_k + tile_size - 1) // tile_size
    num_non_exiting_tiles = torch.tensor(
        [num_non_exiting_tiles_val], dtype=torch.int32, device="cuda"
    )
    permuted_x, permuted_sf = torch.ops.trtllm.moe_permute(
        x,
        x_sf,
        tile_idx_to_mn_limit,
        permuted_idx_to_expanded_idx,
        num_non_exiting_tiles,
        tile_size,
        top_k,
    )
    if dtype == "float4":
        assert permuted_sf is not None
        permuted_sf = unswizzle_sf(permuted_sf, max_num_permuted_tokens, hidden_size, sf_vec_size)
    else:
        assert permuted_sf is None

    for i in range(max_num_permuted_tokens):
        if i >= num_non_exiting_tiles_val * tile_size:
            break
        expanded_idx = permuted_idx_to_expanded_idx[i].item()
        if expanded_idx < 0:
            continue
        token_idx = expanded_idx // top_k
        if dtype == "float4":
            torch.testing.assert_close(
                permuted_x[i].view(torch.uint8), x[token_idx].view(torch.uint8)
            )
            torch.testing.assert_close(permuted_sf[i], x_sf[token_idx])
        else:
            torch.testing.assert_close(permuted_x[i], x[token_idx])


@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024])
@pytest.mark.parametrize("dtype", ["bfloat16", "float16"])
def test_moe_unpermute(dtype: str, num_tokens: int, top_k: int, tile_size: int):
    dtype = getattr(torch, dtype)
    hidden_size = 4096
    num_experts = 256
    num_local_experts = num_experts // 32
    helper = GroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size)
    max_num_permuted_tokens = helper.get_max_num_permuted_tokens(num_tokens)
    permuted_x = torch.randint(
        -100, 100, (max_num_permuted_tokens, hidden_size), dtype=torch.int32, device="cuda"
    ).to(dtype)

    expanded_idx_to_permuted_idx = torch.randint(
        0, max_num_permuted_tokens, (num_tokens, top_k), dtype=torch.int32, device="cuda"
    )
    topk_scales = torch.randn(num_tokens, top_k, dtype=torch.float32, device="cuda").softmax(dim=-1)
    x = torch.ops.trtllm.moe_unpermute(permuted_x, expanded_idx_to_permuted_idx, topk_scales)

    # Reference: gather each token's top-k permuted rows, weight them by the
    # top-k scales, and reduce over k.
    x_ref = (
        (permuted_x[expanded_idx_to_permuted_idx] * topk_scales.unsqueeze(-1)).sum(dim=1).to(dtype)
    )
    torch.testing.assert_close(x, x_ref)


@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024])
@pytest.mark.parametrize("dtype", ["bfloat16", "float16"])
def test_moe_swiglu(dtype: str, num_tokens: int, top_k: int, tile_size: int):
    dtype = getattr(torch, dtype)
    interm_size = 4096
    num_experts = 256
    num_local_experts = num_experts // 32
    helper = GroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size)
    max_num_tiles = helper.get_max_num_tiles(num_tokens)
    max_num_permuted_tokens = helper.get_max_num_permuted_tokens(num_tokens)

    x = torch.randint(
        -100, 100, (max_num_permuted_tokens, interm_size * 2), dtype=torch.int32, device="cuda"
    ).to(dtype)
    tile_idx_to_mn_limit = (
        torch.arange(1, max_num_tiles + 1, dtype=torch.int32, device="cuda") * tile_size
    )
    num_non_exiting_tiles_val = (num_tokens * top_k + tile_size - 1) // tile_size
    num_non_exiting_tiles = torch.tensor(
        [num_non_exiting_tiles_val], dtype=torch.int32, device="cuda"
    )
    num_permuted_tokens = num_non_exiting_tiles_val * tile_size

    y = torch.ops.trtllm.moe_swiglu(x, tile_idx_to_mn_limit, num_non_exiting_tiles, tile_size)
    y_ref = swiglu_ref(x)
    torch.testing.assert_close(y[:num_permuted_tokens], y_ref[:num_permuted_tokens])


@pytest.mark.skipif(get_sm_version() != 100, reason="This test is only supported on SM 100 GPUs")
@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024])
@pytest.mark.parametrize("dtype", ["bfloat16", "float16"])
def test_moe_swiglu_nvfp4_quantize(dtype: str, num_tokens: int, top_k: int, tile_size: int):
    dtype = getattr(torch, dtype)
    sf_vec_size = 16
    interm_size = 4096
    num_experts = 256
    num_local_experts = num_experts // 32
    helper = GroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size)
    max_num_tiles = helper.get_max_num_tiles(num_tokens)
    max_num_permuted_tokens = helper.get_max_num_permuted_tokens(num_tokens)

    x = torch.randint(
        -100, 100, (max_num_permuted_tokens, interm_size * 2), dtype=torch.int32, device="cuda"
    ).to(dtype)
    tile_idx_to_mn_limit = (
        torch.arange(1, max_num_tiles + 1, dtype=torch.int32, device="cuda") * tile_size
    )
    num_non_exiting_tiles_val = (num_tokens * top_k + tile_size - 1) // tile_size
    num_non_exiting_tiles = torch.tensor(
        [num_non_exiting_tiles_val], dtype=torch.int32, device="cuda"
    )
    num_permuted_tokens = num_non_exiting_tiles_val * tile_size

    # Global scale follows the usual NVFP4 convention: 448 is the FP8 E4M3
    # maximum and 6 is the FP4 E2M1 maximum.
    global_sf = swiglu_ref(x).abs().max().float() / (448 * 6)
    global_sf = 1 / global_sf
    y, y_sf = torch.ops.trtllm.moe_swiglu_nvfp4_quantize(
        x, global_sf, tile_idx_to_mn_limit, num_non_exiting_tiles, tile_size
    )
    y_ref, y_sf_ref = torch.ops.trtllm.fp4_quantize(swiglu_ref(x), global_sf, 16, False)
    match_ratio = (
        y[:num_permuted_tokens].view(torch.uint8) == y_ref[:num_permuted_tokens]
    ).sum().item() / y[:num_permuted_tokens].numel()
    assert match_ratio > 0.999

    num_sf_elements = num_permuted_tokens * interm_size // sf_vec_size
    match_ratio = (
        y_sf[:num_sf_elements] == y_sf_ref[:num_sf_elements]
    ).sum().item() / num_sf_elements
    assert match_ratio > 0.999


@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024])
@pytest.mark.parametrize("dtype", ["bfloat16", "float16"])
def test_moe_gelu(dtype: str, num_tokens: int, top_k: int, tile_size: int):
    dtype = getattr(torch, dtype)
    interm_size = 4096
    num_experts = 256
    num_local_experts = num_experts // 32
    helper = GroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size)
    max_num_tiles = helper.get_max_num_tiles(num_tokens)
    max_num_permuted_tokens = helper.get_max_num_permuted_tokens(num_tokens)

    x = torch.randint(
        -100, 100, (max_num_permuted_tokens, interm_size), dtype=torch.int32, device="cuda"
    ).to(dtype)
    tile_idx_to_mn_limit = (
        torch.arange(1, max_num_tiles + 1, dtype=torch.int32, device="cuda") * tile_size
    )
    num_non_exiting_tiles_val = (num_tokens * top_k + tile_size - 1) // tile_size
    num_non_exiting_tiles = torch.tensor(
        [num_non_exiting_tiles_val], dtype=torch.int32, device="cuda"
    )
    num_permuted_tokens = num_non_exiting_tiles_val * tile_size

    y = torch.ops.trtllm.moe_gelu(x, tile_idx_to_mn_limit, num_non_exiting_tiles, tile_size)
    y_ref = torch.nn.functional.gelu(x)
    torch.testing.assert_close(y[:num_permuted_tokens], y_ref[:num_permuted_tokens])


@pytest.mark.skipif(get_sm_version() != 100, reason="This test is only supported on SM 100 GPUs")
@pytest.mark.parametrize("tile_size", [128])
@pytest.mark.parametrize("ep_size", [1, 8, 32])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024, 8192])
def test_nvfp4_grouped_gemm_blackwell(num_tokens: int, top_k: int, ep_size: int, tile_size: int):
    sf_vec_size = 16
    hidden_size = 4096
    inter_size = 8192
    num_experts = 256
    num_local_experts = num_experts // ep_size

    helper = GroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size)
    max_num_tiles = helper.get_max_num_tiles(num_tokens)
    max_num_permuted_tokens = helper.get_max_num_permuted_tokens(num_tokens)
    routing_logits = torch.randn(num_tokens, num_experts, device="cuda")
    _, token_selected_experts = routing_logits.topk(top_k, dim=-1)
    token_selected_experts = token_selected_experts.to(torch.int32)
    num_tokens_per_expert = torch.bincount(token_selected_experts.flatten(), minlength=num_experts)
    num_tokens_per_expert = num_tokens_per_expert[:num_local_experts]
    num_tiles_per_expert = (num_tokens_per_expert + tile_size - 1) // tile_size
    num_tokens_per_expert = num_tokens_per_expert.cpu()
    num_tiles_per_expert = num_tiles_per_expert.cpu()
    num_valid_tiles = num_tiles_per_expert.sum().item()
    assert 0 <= num_valid_tiles <= max_num_tiles

    num_non_exiting_tiles = torch.tensor([num_valid_tiles], dtype=torch.int32, device="cuda")
    tile_idx_to_group_idx = torch.empty(max_num_tiles, dtype=torch.int32)
    # Note: Fill -2e9 for invalid tiles.
    tile_idx_to_group_idx.fill_(-2e9)
    tile_idx = 0
    for expert_idx in range(num_local_experts):
        for i in range(num_tiles_per_expert[expert_idx].item()):
            tile_idx_to_group_idx[tile_idx] = expert_idx
            tile_idx += 1
    tile_idx_to_group_idx = tile_idx_to_group_idx.cuda()

    a = torch.randint(
        -100, 100, (max_num_permuted_tokens, hidden_size // 2), dtype=torch.int32, device="cuda"
    )
    a = a.to(torch.int8).view(torch.float4_e2m1fn_x2)
    a_sf = torch.randint(
        -100,
        100,
        (max_num_permuted_tokens, hidden_size // sf_vec_size),
        dtype=torch.int32,
        device="cuda",
    )
    a_sf = a_sf.to(torch.float8_e4m3fn).view(torch.uint8).flatten()
    b = torch.randint(
        -100,
        100,
        (num_local_experts, inter_size, hidden_size // 2),
        dtype=torch.int32,
        device="cuda",
    )
    b = b.to(torch.int8).view(torch.float4_e2m1fn_x2)
    b_sf = torch.randint(
        -100,
        100,
        (num_local_experts, inter_size, hidden_size // sf_vec_size),
        dtype=torch.int32,
        device="cuda",
    )
    b_sf = b_sf.to(torch.float8_e4m3fn).view(torch.uint8)
    alpha = torch.ones(num_local_experts, dtype=torch.float32, device="cuda")

    c = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
        a,
        b,
        a_sf,
        b_sf,
        alpha,
        tile_idx_to_group_idx,
        num_non_exiting_tiles,
        num_experts=num_experts,
        top_k=top_k,
        num_local_experts=num_local_experts,
        local_expert_offset=0,
        tile_size=tile_size,
        output_dtype=torch.bfloat16,
        scaling_vector_size=sf_vec_size,
    )
    c_ref = cute_dsl_nvfp4_grouped_gemm_ref(
        a,
        b,
        a_sf,
        b_sf,
        alpha,
        tile_idx_to_group_idx,
        num_non_exiting_tiles,
        tile_size=tile_size,
        output_dtype=torch.bfloat16,
        scaling_vector_size=sf_vec_size,
    )
    torch.testing.assert_close(
        c[: num_valid_tiles * tile_size], c_ref[: num_valid_tiles * tile_size]
    )
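

# Sketch of the NVFP4 dequantization underlying the grouped GEMM test above:
# each byte of a float4_e2m1fn_x2 tensor packs two E2M1 values, and every group
# of sf_vec_size (16) elements along K shares one FP8 E4M3 scale factor. The
# nibble order (low nibble = first element) and the row-major, unswizzled
# scale-factor layout are assumptions for illustration only; the production
# reference (cute_dsl_nvfp4_grouped_gemm_ref) handles the actual layouts.
_E2M1_VALUES = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)


def dequant_nvfp4_rows(packed: torch.Tensor, sf: torch.Tensor, sf_vec_size: int = 16) -> torch.Tensor:
    """packed: [rows, k // 2] float4_e2m1fn_x2; sf: [rows, k // sf_vec_size] FP8 E4M3 viewed as uint8."""
    lut = _E2M1_VALUES.to(packed.device)
    codes = packed.view(torch.uint8).to(torch.int64)
    lo, hi = codes & 0xF, codes >> 4
    # Interleave the two nibbles of each byte back into k contiguous elements per row.
    vals = torch.stack([lut[lo], lut[hi]], dim=-1).flatten(-2)
    scales = sf.view(torch.float8_e4m3fn).to(torch.float32)
    return vals * scales.repeat_interleave(sf_vec_size, dim=-1)


# With operands dequantized this way, the grouped GEMM being tested is expected
# to compute, for each valid tile assigned to local expert e, roughly
# c_tile = (a_tile_deq @ b_e_deq.T) * alpha[e], accumulated in float and cast to
# the requested output dtype.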