Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[https://nvbugs/5378031] [feat] W4A8 AWQ MoE supports Per Expert Pre-quant Scale Factor for PyT backend (#7286)
Signed-off-by: Min Yu <171526537+yumin066@users.noreply.github.com>
Parent: e75b4f9f65
Commit: 0a0159fdd8
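Summary of the change: the W4A8 AWQ MoE path previously shared one pre-quant scale vector across all experts (fc31_act_scale of shape (1, hidden_size), fc2_act_scale of shape (1, intermediate_size, 1)). This commit gives each local expert its own row (a leading expert_size_per_partition dimension), threads expert_first_token_offset through applyPrequantScale, and has the CUDA kernels binary-search that offset array to find the expert owning each permuted row so they can apply that expert's scales. fc31_alpha and fc2_alpha likewise become per-expert products of weight_scale_2 and input_scale.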
@@ -859,7 +859,7 @@ private:
     T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
         int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
-        cudaStream_t stream);
+        cudaStream_t stream, int64_t* expert_first_token_offset = nullptr, int const num_experts_per_node = 0);

     MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
     std::unique_ptr<DeepSeekBlockScaleGemmRunner> blockscale_gemm_runner_;

@@ -52,6 +52,7 @@
 #include "tensorrt_llm/common/dataType.h"
 #include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
+#include "tensorrt_llm/kernels/moe_utils.cuh"
 #include "tensorrt_llm/kernels/preQuantScaleKernel.h"
 #include "tensorrt_llm/kernels/quantization.cuh"

@@ -897,27 +898,6 @@ void threeStepBuildExpertMapsSortFirstToken(int const* token_selected_experts, i
 }

 // ============================== Infer GEMM sizes =================================
-// TODO Could linear search be better for small # experts
-template <class T>
-__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target)
-{
-    int64_t low = 0, high = arr_length - 1, target_location = -1;
-    while (low <= high)
-    {
-        int64_t mid = (low + high) / 2;
-
-        if (sorted_indices[mid] >= target)
-        {
-            high = mid - 1;
-        }
-        else
-        {
-            low = mid + 1;
-            target_location = mid;
-        }
-    }
-    return target_location + 1;
-}

 template <class T>
 using sizeof_bits = cutlass::sizeof_bits<typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;
@@ -1508,6 +1488,9 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
         static_assert(!is_nvfp4 && !is_mxfp8, "NVFP4 and MXFP8 are not supported for AWQ");
         static_assert(!std::is_same_v<InputActivationsType, ExpandedActivationsType>,
             "Input and output types must be different for AWQ");
+        int64_t expert = findTotalEltsLessThanTarget(
+                             expert_first_token_offset, num_experts_per_node, (int64_t) permuted_row + 1)
+            - 1;
         for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
         {
             auto frag_elems = source_row_ptr[elem_index];
@@ -1515,7 +1498,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
             CUTLASS_PRAGMA_UNROLL
             for (int e = 0; e < ELEM_PER_THREAD; e++)
             {
-                frag_elems[e] = frag_elems[e] * prequant_scales[elem_index * ELEM_PER_THREAD + e];
+                frag_elems[e]
+                    = frag_elems[e] * prequant_scales[expert * hidden_size + elem_index * ELEM_PER_THREAD + e];
             }

             dest_row_ptr[elem_index] = arrayConvert<DataElem, OutputElem>(frag_elems);
@@ -2918,7 +2902,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena
 template <class T, class WeightType, class OutputType, class InputType, class ScaleBiasType, class Enable>
 T const* CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Enable>::applyPrequantScale(
     void* smoothed_act, void const* permuted_data, void const* prequant_scales, int64_t const* num_valid_tokens_ptr,
-    int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream)
+    int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream,
+    int64_t* expert_first_token_offset, int const num_experts_per_node)
 {
     T const* gemm_input;
     bool use_prequant_scale_kernel = use_awq && !std::is_same_v<T, WeightType>;
@@ -2928,10 +2913,20 @@ T const* CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType,
             (!std::is_same_v<T, WeightType>), "Prequant scales are only used for different weight/activation type!");
         if constexpr (!std::is_same_v<T, WeightType>)
         {
-            tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<UnfusedGemmOutputType, T>(
-                reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
-                reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
-                num_valid_tokens_ptr, stream);
+            if (expert_first_token_offset != nullptr)
+            {
+                tensorrt_llm::kernels::apply_per_channel_scale_per_expert_kernel_launcher<UnfusedGemmOutputType, T>(
+                    reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
+                    reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
+                    expert_first_token_offset, num_experts_per_node, num_valid_tokens_ptr, stream);
+            }
+            else
+            {
+                tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<UnfusedGemmOutputType, T>(
+                    reinterpret_cast<T*>(smoothed_act), reinterpret_cast<UnfusedGemmOutputType const*>(permuted_data),
+                    reinterpret_cast<UnfusedGemmOutputType const*>(prequant_scales), expanded_num_rows, seq_len,
+                    num_valid_tokens_ptr, stream);
+            }
         }
         gemm_input = reinterpret_cast<T const*>(smoothed_act);
     }
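Note: the defaulted parameters added in the header (expert_first_token_offset = nullptr, num_experts_per_node = 0) keep existing call sites working. The per-expert launcher is only taken when the runner actually passes the offsets; otherwise the original per-channel kernel runs unchanged.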
@@ -3740,7 +3735,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
     }

     auto gemm2_input = applyPrequantScale(smoothed_act_, fc1_result_, quant_params.groupwise.fc2.act_scales,
-        num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream);
+        num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream, expert_first_token_offset_,
+        num_experts_per_node);
     sync_check_cuda_error(stream);
     Self::gemm2(moe_gemm_runner_, blockscale_gemm_runner, gemm2_input, fc2_result_, final_output,
         expert_first_token_offset_, gemm2_tma_ws_input, fc2_expert_weights, fc2_expert_biases, fc2_int_scales,
cpp/tensorrt_llm/kernels/moe_utils.cuh (new file, +48)
@@ -0,0 +1,48 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+
+// TODO Could linear search be better for small # experts
+template <class T>
+__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target)
+{
+    int64_t low = 0, high = arr_length - 1, target_location = -1;
+    while (low <= high)
+    {
+        int64_t mid = (low + high) / 2;
+
+        if (sorted_indices[mid] >= target)
+        {
+            high = mid - 1;
+        }
+        else
+        {
+            low = mid + 1;
+            target_location = mid;
+        }
+    }
+    return target_location + 1;
+}
+
+} // namespace kernels
+} // namespace tensorrt_llm
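This header hoists findTotalEltsLessThanTarget out of moe_kernels.cu (removed above) so preQuantScaleKernel.cu can share it. The helper returns how many entries of a sorted array are strictly less than target, so findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, row + 1) - 1 yields the index of the expert owning permuted row `row`. A minimal host-side sketch of that lookup (the offsets, row values, and main() harness are made up for illustration; the search itself mirrors the __device__ helper, minus the CUDA qualifiers):

// Host-side sketch of the row -> expert lookup the kernels perform.
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Returns the number of entries in sorted_indices that are strictly less than target.
template <class T>
int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target)
{
    int64_t low = 0, high = arr_length - 1, target_location = -1;
    while (low <= high)
    {
        int64_t mid = (low + high) / 2;
        if (sorted_indices[mid] >= target)
        {
            high = mid - 1;
        }
        else
        {
            low = mid + 1;
            target_location = mid;
        }
    }
    return target_location + 1;
}

int main()
{
    // Four experts (hypothetical token distribution): expert 0 owns permuted
    // rows [0,3), expert 1 owns none, expert 2 owns [3,7), expert 3 owns [7,10).
    int64_t const expert_first_token_offset[] = {0, 3, 3, 7, 10};
    int64_t const num_experts_per_node = 4;

    for (int64_t row : {0, 2, 3, 6, 7, 9})
    {
        // Same call shape as expandInputRowsKernel / apply_per_channel_scale:
        // count offsets below row + 1, then step back one to get the owner.
        int64_t expert
            = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, row + 1) - 1;
        std::printf("row %lld -> expert %lld\n", (long long) row, (long long) expert);
    }
    return 0; // prints experts 0, 0, 2, 2, 3, 3
}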
@@ -14,6 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include "tensorrt_llm/kernels/moe_utils.cuh"
 #include "tensorrt_llm/kernels/preQuantScaleKernel.h"

 namespace tensorrt_llm
@@ -41,7 +42,7 @@ struct Vec2Type<__nv_bfloat16>

 template <typename T_in, typename T_out, int kProcessRows, typename AccessType>
 __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows,
-    int cols, int64_t const* num_valid_tokens_ptr)
+    int cols, int64_t const* num_valid_tokens_ptr, int64_t* expert_first_token_offset, int const num_experts_per_node)
 {
     static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
     T_in scale[kElems], act_vec[kElems];
@@ -53,11 +54,19 @@ __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_
         return;
     act += row_offset * kProcessRows * cols;
     smoothed_act += row_offset * kProcessRows * cols;
-    *reinterpret_cast<AccessType*>(scale) = reinterpret_cast<AccessType const*>(per_channel_scale)[col_offset];
 #pragma unroll
     for (int i = 0; i < kProcessRows; ++i)
     {
         *reinterpret_cast<AccessType*>(act_vec) = reinterpret_cast<AccessType const*>(act + i * cols)[col_offset];
+        int expert = 0;
+        if (expert_first_token_offset != nullptr)
+        {
+            expert = findTotalEltsLessThanTarget(
+                         expert_first_token_offset, num_experts_per_node, (int64_t) row_offset * kProcessRows + i + 1)
+                - 1;
+        }
+        *reinterpret_cast<AccessType*>(scale)
+            = reinterpret_cast<AccessType const*>(per_channel_scale)[expert * cols / kElems + col_offset];
         if constexpr ((std::is_same_v<T_in, half>
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
             || std::is_same_v<T_in, __nv_bfloat16>
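The scale load now treats per_channel_scale as a (num_experts, cols) matrix read through AccessType vectors: expert e's row starts at vector index e * cols / kElems, and col_offset picks this thread's vector within that row. For example, with T_in = half and AccessType = float4, kElems = 16 / 2 = 8, so for cols = 4096 and expert = 2 the load starts at vector index 1024, i.e. element 8192 = 2 * cols. With expert_first_token_offset == nullptr, expert stays 0 and the original single-row behavior falls out. The lookup also moved inside the kProcessRows loop because consecutive rows handled by one thread can belong to different experts.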
@@ -98,13 +107,14 @@ __global__ void apply_per_channel_scale(T_out* smoothed_act, T_in const* act, T_

 template <typename T_in, typename T_out, int kProcessRows, typename AccessType = float4>
 void apply_per_channel_scale_kernel_launcher_(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale,
-    int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0)
+    int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0,
+    int64_t* expert_first_token_offset = nullptr, int const num_experts_per_node = 0)
 {
     static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
     dim3 block(128);
     dim3 grid((rows + kProcessRows - 1) / kProcessRows, (cols / kElems + block.x - 1) / block.x);
-    apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType>
-        <<<grid, block, 0, stream>>>(smoothed_act, act, per_channel_scale, rows, cols, num_valid_tokens_ptr);
+    apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType><<<grid, block, 0, stream>>>(smoothed_act, act,
+        per_channel_scale, rows, cols, num_valid_tokens_ptr, expert_first_token_offset, num_experts_per_node);
 }

 template <typename T_in, typename T_out>
@@ -134,6 +144,34 @@ void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* ac
     }
 }

+template <typename T_in, typename T_out>
+void apply_per_channel_scale_per_expert_kernel_launcher(T_out* smoothed_act, T_in const* act,
+    T_in const* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset,
+    int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
+{
+    uint64_t elems = static_cast<uint64_t>(rows) * static_cast<uint64_t>(cols);
+    if (elems < 2048 * 2048)
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 1, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+    else if (elems < 4096 * 4096)
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 4, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+    else if (elems < 8192 * 8192)
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 8, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+    else
+    {
+        apply_per_channel_scale_kernel_launcher_<T_in, T_out, 16, float4>(smoothed_act, act, per_channel_scale, rows,
+            cols, num_valid_tokens_ptr, stream, expert_first_token_offset, num_experts_per_node);
+    }
+}
+
 #define INSTANTIATE_PREQUANT_SCALE(T_in, T_out) \
     template void apply_per_channel_scale_kernel_launcher<T_in, T_out>(T_out * smoothed_act, const T_in* act, \
         const T_in* per_channel_scale, int rows, int cols, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
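The new per-expert launcher reuses the existing size heuristic: bigger activations amortize more rows per thread (kProcessRows of 1, 4, 8, or 16 as elems crosses the 2048², 4096², and 8192² thresholds). As a worked example, rows = 1024 and cols = 4096 give elems = 4,194,304, which is not below 2048 * 2048, so the kProcessRows = 4 tier fires; for half inputs (kElems = 8), apply_per_channel_scale_kernel_launcher_ then launches a ((1024 + 3) / 4, (512 + 127) / 128) = (256, 4) grid of 128-thread blocks.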
@@ -150,5 +188,22 @@ INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16, __nv_fp8_e4m3);
 #endif
 #endif

+#define INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(T_in, T_out) \
+    template void apply_per_channel_scale_per_expert_kernel_launcher<T_in, T_out>(T_out * smoothed_act, \
+        const T_in* act, const T_in* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset, \
+        int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream)
+
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(half, half);
+#if defined(ENABLE_FP8)
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(half, __nv_fp8_e4m3);
+#endif
+
+#if defined(ENABLE_BF16)
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(__nv_bfloat16, __nv_bfloat16);
+#if defined(ENABLE_FP8)
+INSTANTIATE_PREQUANT_SCALE_PER_EXPERT(__nv_bfloat16, __nv_fp8_e4m3);
+#endif
+#endif
+
 } // namespace kernels
 } // namespace tensorrt_llm
@@ -39,5 +39,10 @@ template <typename T_in, typename T_out = T_in>
 void apply_per_channel_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale,
     int rows, int cols, int64_t const* num_valid_tokens_ptr = nullptr, cudaStream_t stream = 0);

+template <typename T_in, typename T_out = T_in>
+void apply_per_channel_scale_per_expert_kernel_launcher(T_out* smoothed_act, T_in const* act,
+    T_in const* per_channel_scale, int rows, int cols, int64_t* expert_first_token_offset,
+    int const num_experts_per_node, int64_t const* num_valid_tokens_ptr, cudaStream_t stream);
+
 } // namespace kernels
 } // namespace tensorrt_llm
@@ -914,14 +914,18 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
                                      module.intermediate_size_per_partition // 2)

         # Multiply act with reciprocal of per-channel pre_quant_scale * per-tensor input_scale
-        fc31_act_scale = nn.Parameter(torch.empty(1,
-                                                  module.hidden_size,
-                                                  dtype=module.dtype),
+        fc31_act_scale = nn.Parameter(torch.empty(
+            module.expert_size_per_partition,
+            module.hidden_size,
+            dtype=module.dtype),
                                       requires_grad=False)
         module.register_parameter("fc31_act_scale", fc31_act_scale)

         fc2_act_scale = nn.Parameter(torch.empty(
-            1, module.intermediate_size_per_partition, 1, dtype=module.dtype),
+            module.expert_size_per_partition,
+            module.intermediate_size_per_partition,
+            1,
+            dtype=module.dtype),
                                      requires_grad=False)
         module.register_parameter("fc2_act_scale", fc2_act_scale)

@@ -1129,15 +1133,29 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
                                   device=self.device)
                 for expert_id in module.initial_local_expert_ids
             ]
-            all_w3_w1_pre_quant_scales_max = torch.max(
-                torch.stack(all_w3_pre_quant_scales +
-                            all_w1_pre_quant_scales).to(module.dtype),
+            all_w3_w1_pre_quant_scales_greater = torch.max(
+                torch.stack([
+                    torch.stack(all_w3_pre_quant_scales),
+                    torch.stack(all_w1_pre_quant_scales)
+                ]).to(module.dtype),
                 dim=0,
             ).values.permute(1, 0)
+
+            all_w3_w1_input_scales_greater = torch.max(
+                torch.stack([
+                    torch.stack(all_w3_input_scales),
+                    torch.stack(all_w1_input_scales)
+                ]).to(module.dtype),
+                dim=0,
+            ).values
+
+            all_w3_w1_pre_quant_scales_div_input_scales = (
+                all_w3_w1_pre_quant_scales_greater *
+                (1 / all_w3_w1_input_scales_greater.reshape(
+                    1, module.expert_size_per_partition).float()))
+
             module.fc31_act_scale.data.copy_(
-                torch.ones_like(module.fc31_act_scale, device=self.device) *
-                (all_w3_w1_pre_quant_scales_max) *
-                (1 / all_w3_w1_input_scales_max))
+                all_w3_w1_pre_quant_scales_div_input_scales.permute(1, 0))
             # In vanilla ckpt (at least from ModelOpt), per-tensor weight_scale_2 is separately stored
             all_w3_weight_scale_2 = [
                 load_weight_shard(weights[f"{expert_id}.w3.weight_scale_2"],
@@ -1149,13 +1167,21 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
                                   device=self.device)
                 for expert_id in module.initial_local_expert_ids
             ]
-            all_w3_w1_weight_scale_2_max = torch.max(
-                torch.stack(all_w3_weight_scale_2 + all_w1_weight_scale_2).to(
-                    module.dtype),
-                dim=0,
-            ).values
-            module.fc31_alpha.data.copy_(all_w3_w1_weight_scale_2_max.float() *
-                                         all_w3_w1_input_scales_max.float())
+            all_w3_w1_weight_scale_2 = torch.stack([
+                torch.stack(all_w3_weight_scale_2),
+                torch.stack(all_w1_weight_scale_2)
+            ]).to(module.dtype)
+            all_w3_w1_weight_scale_2_greater = torch.max(
+                all_w3_w1_weight_scale_2, dim=0).values
+
+            all_w3_w1_weight_scale_2_mul_input_scales = (
+                all_w3_w1_weight_scale_2_greater.reshape(
+                    module.expert_size_per_partition, 1).float() *
+                all_w3_w1_input_scales_greater.reshape(
+                    module.expert_size_per_partition, 1).float())
+            module.fc31_alpha.data.copy_(
+                all_w3_w1_weight_scale_2_mul_input_scales.reshape(
+                    module.expert_size_per_partition, 1).float())

             # Per-group weight_scale
             all_w3_scales = [
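Net effect of the two hunks above: for each local expert e, fc31_act_scale[e] = max(w3_pre_quant_scale[e], w1_pre_quant_scale[e]) / max(w3_input_scale[e], w1_input_scale[e]), and fc31_alpha[e] = max(w3_weight_scale_2[e], w1_weight_scale_2[e]) * max(w3_input_scale[e], w1_input_scale[e]); the old code collapsed each of these maxima across all experts into one shared value before broadcasting. The w2/fc2 hunks below repeat the pattern, minus the w3/w1 pairwise max, since fc2 has a single weight matrix per expert.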
@@ -1183,7 +1209,11 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
             w3_w1_scales = all_w3_w1_scales.to(torch.bfloat16).view(
                 module.dtype)
             if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
-                w3_w1_scales /= all_w3_w1_weight_scale_2_max.float()
+                w3_w1_scales = w3_w1_scales.permute(1, 2, 0)
+                w3_w1_scales /= all_w3_w1_weight_scale_2_greater.reshape(
+                    module.expert_size_per_partition).float()
+                w3_w1_scales = w3_w1_scales.permute(2, 0, 1)
+
             w3_w1_s_shape = w3_w1_scales.shape
             w3_w1_scales_interleaved = w3_w1_scales.reshape(
                 w3_w1_s_shape[0], w3_w1_s_shape[1],
@@ -1223,23 +1253,31 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
                                   device=self.device)
                 for expert_id in module.initial_local_expert_ids
             ]
-            all_w2_pre_quant_scales_max = torch.max(
-                torch.stack(all_w2_pre_quant_scales).to(module.dtype),
-                dim=0).values
+            all_w2_pre_quant_scales = torch.stack(all_w2_pre_quant_scales).to(
+                module.dtype)
+            all_w2_input_scales = torch.stack(all_w2_input_scales).to(
+                module.dtype)
+            all_w2_pre_quant_scales_div_input_scales = (
+                all_w2_pre_quant_scales.permute(1, 0) *
+                (1 / (all_w2_input_scales.reshape(
+                    module.expert_size_per_partition).float()))).permute(1, 0)
             module.fc2_act_scale.data.copy_(
-                torch.ones_like(module.fc2_act_scale, device=self.device) *
-                (all_w2_pre_quant_scales_max.unsqueeze(-1)) *
-                (1 / all_w2_input_scales_max))
+                all_w2_pre_quant_scales_div_input_scales.reshape(
+                    module.fc2_act_scale.shape))
             # In vanilla ckpt (at least from ModelOpt), per-tensor weight_scale_2 is separately stored
             all_w2_weight_scale_2 = [
                 load_weight_shard(weights[f"{expert_id}.w2.weight_scale_2"],
                                   device=self.device)
                 for expert_id in module.initial_local_expert_ids
             ]
-            all_w2_weight_scale_2_max = torch.stack(all_w2_weight_scale_2).to(
-                module.dtype).max()
-            module.fc2_alpha.data.copy_(all_w2_weight_scale_2_max.float() *
-                                        all_w2_input_scales_max.float())
+            all_w2_weight_scale_2 = torch.stack(all_w2_weight_scale_2).to(
+                module.dtype)
+            all_w2_weight_scale_2_mul_input_scales = (
+                all_w2_weight_scale_2.reshape(module.expert_size_per_partition,
+                                              1) *
+                all_w2_input_scales.reshape(module.expert_size_per_partition,
+                                            1))
+            module.fc2_alpha.data.copy_(all_w2_weight_scale_2_mul_input_scales)

             # Per-group weight_scale
             all_w2_scales = [
@@ -1258,7 +1296,11 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
                 module.dtype)

             if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
-                w2_scales /= all_w2_weight_scale_2_max.float()
+                w2_scales = w2_scales.permute(1, 2, 0)
+                all_w2_weight_scale_2 = all_w2_weight_scale_2.reshape(
+                    module.expert_size_per_partition)
+                w2_scales /= (all_w2_weight_scale_2.float())
+                w2_scales = w2_scales.permute(2, 0, 1)
             w2_s_shape = w2_scales.shape
             w2_scales_interleaved = w2_scales.reshape(
                 w2_s_shape[0], w2_s_shape[1],
@@ -1612,15 +1612,14 @@ def test_fused_moe_w4afp8(dtype, weight_loading_mode):
                            dtype=torch.int8).cuda()

     # The pre-quant scale to be multiplied with the input activation.
-    w1_pre_quant_scale = torch.ones(HIDDEN_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
-    w2_pre_quant_scale = torch.ones(INTERMEDIATE_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
-    w3_pre_quant_scale = torch.ones(HIDDEN_SIZE,
-                                    dtype=dtype,
-                                    device="cuda")
+    # Use random pre-quant scales [0.95, 1.05] instead of fixed 1.0 to ensure the kernel handles
+    # non-uniform pre-quant scaling factors correctly
+    w1_pre_quant_scale = torch.rand(
+        HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
+    w2_pre_quant_scale = torch.rand(
+        INTERMEDIATE_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95
+    w3_pre_quant_scale = torch.rand(
+        HIDDEN_SIZE, dtype=dtype, device="cuda") * 0.1 + 0.95

     # The weight scale to dequantize int4 weights (by multiplication).
     w1_scale = torch.randn(