[https://nvbugs/5378031] [feat] W4A8 AWQ MoE supports Per Expert Pre-quant Scale Factor for PyT backend (#7286)

Signed-off-by: Min Yu <171526537+yumin066@users.noreply.github.com>
This commit is contained in:
Min Yu 2025-10-16 11:07:48 +08:00 committed by GitHub
parent e75b4f9f65
commit 0a0159fdd8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 217 additions and 72 deletions

View File

@ -859,7 +859,7 @@ private:
T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
cudaStream_t stream);
cudaStream_t stream, int64_t* expert_first_token_offset = nullptr, int const num_experts_per_node = 0);
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
std::unique_ptr<DeepSeekBlockScaleGemmRunner> blockscale_gemm_runner_;

View File

@ -52,6 +52,7 @@
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/moe_utils.cuh"
#include "tensorrt_llm/kernels/preQuantScaleKernel.h"
#include "tensorrt_llm/kernels/quantization.cuh"
@ -897,27 +898,6 @@ void threeStepBuildExpertMapsSortFirstToken(int const* token_selected_experts, i
}
// ============================== Infer GEMM sizes =================================
// TODO Could linear search be better for small # experts
template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target)
{
int64_t low = 0, high = arr_length - 1, target_location = -1;
while (low <= high)
{
int64_t mid = (low + high) / 2;
if (sorted_indices[mid] >= target)
{
high = mid - 1;
}
else
{
low = mid + 1;
target_location = mid;
}
}
return target_location + 1;
}
template <class T>
using sizeof_bits = cutlass::sizeof_bits<typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;
@ -1508,6 +1488,9 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
static_assert(!is_nvfp4 && !is_mxfp8, "NVFP4 and MXFP8 are not supported for AWQ");
static_assert(!std::is_same_v<InputActivationsType, ExpandedActivationsType>,
"Input and output types must be different for AWQ");
int64_t expert = findTotalEltsLessThanTarget(
expert_first_token_offset, num_experts_per_node, (int64_t) permuted_row + 1)
- 1;
for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
{
auto frag_elems = source_row_ptr[elem_index];
@ -1515,7 +1498,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
CUTLASS_PRAGMA_UNROLL
for (int e = 0; e < ELEM_PER_THREAD; e++)
{
frag_elems[e] = frag_elems[e] * prequant_scales[elem_index * ELEM_PER_THREAD + e];
frag_elems[e]
= frag_elems[e] * prequant_scales[expert * hidden_size + elem_index * ELEM_PER_THREAD + e];
}
dest_row_ptr[elem_index] = arrayConvert<DataElem, OutputElem>(frag_elems);
@ -2918,7 +2902,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena
template <class T, class WeightType, class OutputType, class InputType, class ScaleBiasType, class Enable>
T const* CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Enable>::applyPrequantScale(
void* smoothed_act, void const* permuted_data, void const* prequant_scales, int64_t const* num_valid_tokens_ptr,
int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream)
int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream,
int64_t* expert_first_token_offset, int const num_experts_per_node)
{
T const* gemm_input;
bool use_prequant_scale_kernel = use_awq && !std::is_same_v<T, WeightType>;
@ -2928,10 +2913,20 @@ T const* CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType,
(!std::is_same_v<T, WeightType>), "Prequant scales are only used for different weight/activation type!");
if constexpr (!std::is_same_v<T, WeightType>)
{
tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<UnfusedGemmOutputType, T>(
reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
num_valid_tokens_ptr, stream);
if (expert_first_token_offset != nullptr)
{
tensorrt_llm::kernels::apply_per_channel_scale_per_expert_kernel_launcher<UnfusedGemmOutputType, T>(
reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
expert_first_token_offset, num_experts_per_node, num_valid_tokens_ptr, stream);
}
else
{
tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<UnfusedGemmOutputType, T>(
reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
num_valid_tokens_ptr, stream);
}
}
gemm_input = reinterpret_cast<T const*>(smoothed_act);
}
@ -3740,7 +3735,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
}
auto gemm2_input = applyPrequantScale(smoothed_act_, fc1_result_, quant_params.groupwise.fc2.act_scales,
num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream);
num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream, expert_first_token_offset_,
num_experts_per_node);
sync_check_cuda_error(stream);
Self::gemm2(moe_gemm_runner_, blockscale_gemm_runner, gemm2_input, fc2_result_, final_output,
expert_first_token_offset_, gemm2_tma_ws_input, fc2_expert_weights, fc2_expert_biases, fc2_int_scales,

View File

@ -0,0 +1,48 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace tensorrt_llm
{
namespace kernels
{
// TODO Could linear search be better for small # experts
template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target)
{
int64_t low = 0, high = arr_length - 1, target_location = -1;
while (low <= high)
{
int64_t mid = (low + high) / 2;
if (sorted_indices[mid] >= target)
{
high = mid - 1;
}
else
{
low = mid + 1;
target_location = mid;
}
}
return target_location + 1;
}
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/moe_utils.cuh"
#include "tensorrt_llm/kernels/preQuantScaleKernel.h"
namespace tensorrt_llm
@ -41,7 +42,7 @@ struct Vec2Type<__nv_bfloat16>
template <typename T_in, typename T_out, int kProcessRows, typename AccessType>
__global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows,
int cols, int64_t const* num_valid_tokens_ptr)
int cols, int64_t const* num_valid_tokens_ptr, int64_t* expert_first_token_offset, int const num_experts_per_node)
{
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
T_in scale[kElems], act_vec[kElems];
@ -53,11 +54,19 @@ __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_
return;
act += row_offset * kProcessRows * cols;
smoothed_act += row_offset * kProcessRows * cols;
*reinterpret_cast<AccessType*>(scale) = reinterpret_cast<AccessType const*>(per_channel_scale)[col_offset];
#pragma unroll
for (int i = 0; i < kProcessRows; ++i)
{
*reinterpret_cast<AccessType*>(act_vec) = reinterpret_cast<AccessType const*>(act + i * cols)[col_offset];
int expert = 0;
if (expert_first_token_offset != nullptr)
{
expert = findTotalEltsLessThanTarget(
expert_first_token_offset, num_experts_per_node, (int64_t) row_offset * kProcessRows + i + 1)
- 1;
}
*reinterpret_cast<AccessType*>(scale)
= reinterpret_cast<AccessType const*>(per_channel_scale)[expert * cols / kElems + col_offset];
if constexpr ((std::is_same_v<T_in, half>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|| std::is_same_v<T_in, __nv_bfloat16>
@ -98,13 +107,14 @@ __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_
template <typename T_in, typename T_out, int kProcessRows, typename AccessType = float4>
void apply_per_channel_scale_kernel_launcher_(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale,
int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0)
int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0,
int64_t* expert_first_token_offset = nullptr, int const num_experts_per_node = 0)
{
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
dim3 block(128);
dim3 grid((rows + kProcessRows - 1) / kProcessRows, (cols / kElems + block.x - 1) / block.x);
apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType>
<<<grid, block, 0, stream>>>(smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr);
apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType><<<grid, block, 0, stream>>>(smoothed_act, act,
per_channel_scale, rows, cols, num_valid_tokens_ptr, expert_first_token_offset, num_experts_per_node);
}
template <typename T_in, typename T_out>
@ -134,6 +144,34 @@ void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* ac
}
}
template <typename T_in, typename T_out>
void apply_per_channel_scale_per_expert_kernel_launcher(T_out* smoothed_act, T_in const* act,
T_in const* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset,
int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
{
uint64_t elems = static_cast<uint64_t>(rows) * static_cast<uint64_t>(cols);
if (elems < 2048 * 2048)
{
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 1, float4>(smoothed_act, act, per_channel_scale, rows,
cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
}
else if (elems < 4096 * 4096)
{
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 4, float4>(smoothed_act, act, per_channel_scale, rows,
cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
}
else if (elems < 8192 * 8192)
{
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 8, float4>(smoothed_act, act, per_channel_scale, rows,
cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
}
else
{
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 16, float4>(smoothed_act, act, per_channel_scale, rows,
cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
}
}
#define INSTANTIATE_PREQUANT_SCALE(T_in, T_out) \
template void apply_per_channel_scale_kernel_launcher<T_in, T_out>(T_out * smoothed_act, const T_in* act, \
const T_in* per_channel_scale, int rows, int cols, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
@ -150,5 +188,22 @@ INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16, __nv_fp8_e4m3);
#endif
#endif
#define INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(T_in, T_out) \
template void apply_per_channel_scale_per_expert_kernel_launcher<T_in, T_out>(T_out * smoothed_act, \
const T_in* act, const T_in* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset, \
int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(half, half);
#if defined(ENABLE_FP8)
INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(half, __nv_fp8_e4m3);
#endif
#if defined(ENABLE_BF16)
INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(__nv_bfloat16, __nv_bfloat16);
#if defined(ENABLE_FP8)
INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(__nv_bfloat16, __nv_fp8_e4m3);
#endif
#endif
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -39,5 +39,10 @@ template <typename T_in, typename T_out = T_in>
void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale,
int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0);
template <typename T_in, typename T_out = T_in>
void apply_per_channel_scale_per_expert_kernel_launcher(T_out* smoothed_act, T_in const* act,
T_in const* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset,
int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -914,14 +914,18 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
module.intermediate_size_per_partition // 2)
# Multiply act with reciprocal of per-channel pre_quant_scale * per-tensor input_scale
fc31_act_scale = nn.Parameter(torch.empty(1,
module.hidden_size,
dtype=module.dtype),
fc31_act_scale = nn.Parameter(torch.empty(
module.expert_size_per_partition,
module.hidden_size,
dtype=module.dtype),
requires_grad=False)
module.register_parameter("fc31_act_scale", fc31_act_scale)
fc2_act_scale = nn.Parameter(torch.empty(
1, module.intermediate_size_per_partition, 1, dtype=module.dtype),
module.expert_size_per_partition,
module.intermediate_size_per_partition,
1,
dtype=module.dtype),
requires_grad=False)
module.register_parameter("fc2_act_scale", fc2_act_scale)
@ -1129,15 +1133,29 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
device=self.device)
for expert_id in module.initial_local_expert_ids
]
all_w3_w1_pre_quant_scales_max = torch.max(
torch.stack(all_w3_pre_quant_scales +
all_w1_pre_quant_scales).to(module.dtype),
all_w3_w1_pre_quant_scales_greater = torch.max(
torch.stack([
torch.stack(all_w3_pre_quant_scales),
torch.stack(all_w1_pre_quant_scales)
]).to(module.dtype),
dim=0,
).values.permute(1, 0)
all_w3_w1_input_scales_greater = torch.max(
torch.stack([
torch.stack(all_w3_input_scales),
torch.stack(all_w1_input_scales)
]).to(module.dtype),
dim=0,
).values
all_w3_w1_pre_quant_scales_div_input_scales = (
all_w3_w1_pre_quant_scales_greater *
(1 / all_w3_w1_input_scales_greater.reshape(
1, module.expert_size_per_partition).float()))
module.fc31_act_scale.data.copy_(
torch.ones_like(module.fc31_act_scale, device=self.device) *
(all_w3_w1_pre_quant_scales_max) *
(1 / all_w3_w1_input_scales_max))
all_w3_w1_pre_quant_scales_div_input_scales.permute(1, 0))
# In vanilla ckpt (at least from ModelOpt), per-tensor weight_scale_2 is separately stored
all_w3_weight_scale_2 = [
load_weight_shard(weights[f"{expert_id}.w3.weight_scale_2"],
@ -1149,13 +1167,21 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
device=self.device)
for expert_id in module.initial_local_expert_ids
]
all_w3_w1_weight_scale_2_max = torch.max(
torch.stack(all_w3_weight_scale_2 + all_w1_weight_scale_2).to(
module.dtype),
dim=0,
).values
module.fc31_alpha.data.copy_(all_w3_w1_weight_scale_2_max.float() *
all_w3_w1_input_scales_max.float())
all_w3_w1_weight_scale_2 = torch.stack([
torch.stack(all_w3_weight_scale_2),
torch.stack(all_w1_weight_scale_2)
]).to(module.dtype)
all_w3_w1_weight_scale_2_greater = torch.max(
all_w3_w1_weight_scale_2, dim=0).values
all_w3_w1_weight_scale_2_mul_input_scales = (
all_w3_w1_weight_scale_2_greater.reshape(
module.expert_size_per_partition, 1).float() *
all_w3_w1_input_scales_greater.reshape(
module.expert_size_per_partition, 1).float())
module.fc31_alpha.data.copy_(
all_w3_w1_weight_scale_2_mul_input_scales.reshape(
module.expert_size_per_partition, 1).float())
# Per-group weight_scale
all_w3_scales = [
@ -1183,7 +1209,11 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
w3_w1_scales = all_w3_w1_scales.to(torch.bfloat16).view(
module.dtype)
if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
w3_w1_scales /= all_w3_w1_weight_scale_2_max.float()
w3_w1_scales = w3_w1_scales.permute(1, 2, 0)
w3_w1_scales /= all_w3_w1_weight_scale_2_greater.reshape(
module.expert_size_per_partition).float()
w3_w1_scales = w3_w1_scales.permute(2, 0, 1)
w3_w1_s_shape = w3_w1_scales.shape
w3_w1_scales_interleaved = w3_w1_scales.reshape(
w3_w1_s_shape[0], w3_w1_s_shape[1],
@ -1223,23 +1253,31 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
device=self.device)
for expert_id in module.initial_local_expert_ids
]
all_w2_pre_quant_scales_max = torch.max(
torch.stack(all_w2_pre_quant_scales).to(module.dtype),
dim=0).values
all_w2_pre_quant_scales = torch.stack(all_w2_pre_quant_scales).to(
module.dtype)
all_w2_input_scales = torch.stack(all_w2_input_scales).to(
module.dtype)
all_w2_pre_quant_scales_div_input_scales = (
all_w2_pre_quant_scales.permute(1, 0) *
(1 / (all_w2_input_scales.reshape(
module.expert_size_per_partition).float()))).permute(1, 0)
module.fc2_act_scale.data.copy_(
torch.ones_like(module.fc2_act_scale, device=self.device) *
(all_w2_pre_quant_scales_max.unsqueeze(-1)) *
(1 / all_w2_input_scales_max))
all_w2_pre_quant_scales_div_input_scales.reshape(
module.fc2_act_scale.shape))
# In vanilla ckpt (at least from ModelOpt), per-tensor weight_scale_2 is separately stored
all_w2_weight_scale_2 = [
load_weight_shard(weights[f"{expert_id}.w2.weight_scale_2"],
device=self.device)
for expert_id in module.initial_local_expert_ids
]
all_w2_weight_scale_2_max = torch.stack(all_w2_weight_scale_2).to(
module.dtype).max()
module.fc2_alpha.data.copy_(all_w2_weight_scale_2_max.float() *
all_w2_input_scales_max.float())
all_w2_weight_scale_2 = torch.stack(all_w2_weight_scale_2).to(
module.dtype)
all_w2_weight_scale_2_mul_input_scales = (
all_w2_weight_scale_2.reshape(module.expert_size_per_partition,
1) *
all_w2_input_scales.reshape(module.expert_size_per_partition,
1))
module.fc2_alpha.data.copy_(all_w2_weight_scale_2_mul_input_scales)
# Per-group weight_scale
all_w2_scales = [
@ -1258,7 +1296,11 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
module.dtype)
if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
w2_scales /= all_w2_weight_scale_2_max.float()
w2_scales = w2_scales.permute(1, 2, 0)
all_w2_weight_scale_2 = all_w2_weight_scale_2.reshape(
module.expert_size_per_partition)
w2_scales /= (all_w2_weight_scale_2.float())
w2_scales = w2_scales.permute(2, 0, 1)
w2_s_shape = w2_scales.shape
w2_scales_interleaved = w2_scales.reshape(
w2_s_shape[0], w2_s_shape[1],

View File

@ -1612,15 +1612,14 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode):
dtype=torch.int8).cuda()
# The pre-quant scale to be multiplied with the input activation.
w1_pre_quant_scale = torch.ones(HIDDEN_SIZE,
dtype=dtype,
device="cuda")
w2_pre_quant_scale = torch.ones(INTERMEDIATE_SIZE,
dtype=dtype,
device="cuda")
w3_pre_quant_scale = torch.ones(HIDDEN_SIZE,
dtype=dtype,
device="cuda")
# Use random pre-quant scales [0.95, 1.05] instead of fixed 1.0 to ensure the kernel handles
# non-uniform pre-quant scaling factors correctly
w1_pre_quant_scale = torch.rand(
HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
w2_pre_quant_scale = torch.rand(
INTERMEDIATE_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
w3_pre_quant_scale = torch.rand(
HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
# The weight scale to dequantize int4 weights (by multiplication).
w1_scale = torch.randn(