[TRTLLM-3330][feat] Support DeepSeek-R1 W4A8 on Hopper (#4123)

Support DeepSeek-R1 W4A8 on Hopper

Co-authored-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Co-authored-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Author: Barry Kang, 2025-05-14 15:48:07 +08:00 (committed by GitHub)
parent bb17649517
commit 20b42912ce
21 changed files with 1232 additions and 117 deletions


@ -102,8 +102,9 @@ enum class CutlassTileConfigSM90
CtaShape128x128x128B,
CtaShape128x256x128B,
// CTA configs for M=128
// CTA configs for M=256
CtaShape256x128x128B,
CtaShape256x256x128B,
};
enum class CutlassTileConfigSM100
@ -204,7 +205,9 @@ enum class TileShape
TileShape_128x32x128,
TileShape_128x64x128,
TileShape_128x128x128,
TileShape_128x256x128
TileShape_128x256x128,
TileShape_256x128x128,
TileShape_256x256x128
};
template <TileShape Shape_MNK>
@ -255,6 +258,14 @@ constexpr auto get_tile_shape()
{
return cute::Shape<_128, _256, _128>{};
}
else if constexpr (Shape_MNK == TileShape::TileShape_256x128x128)
{
return cute::Shape<_256, _128, _128>{};
}
else if constexpr (Shape_MNK == TileShape::TileShape_256x256x128)
{
return cute::Shape<_256, _256, _128>{};
}
}
static auto get_tile_shape_name(TileShape Shape_MNK)
@ -303,6 +314,14 @@ static auto get_tile_shape_name(TileShape Shape_MNK)
{
return "128x256x128";
}
else if (Shape_MNK == TileShape::TileShape_256x128x128)
{
return "256x128x128";
}
else if (Shape_MNK == TileShape::TileShape_256x256x128)
{
return "256x256x128";
}
return "Unknown shape";
}


@ -225,9 +225,19 @@ std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(CutlassGemmConfig::C
#else
if (config & CutlassGemmConfig::GROUPED_GEMM)
{
return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B,
CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B};
if (config & CutlassGemmConfig::WEIGHT_ONLY)
{
return {CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B,
CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B,
CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B,
CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B};
}
else
{
return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B,
CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B};
}
}
else
{
@ -240,6 +250,19 @@ std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(CutlassGemmConfig::C
#endif
}
bool sm90_supports_coop(CutlassTileConfigSM90 const tile)
{
#ifdef FAST_BUILD
return false;
#else
std::set<CutlassTileConfigSM90> valid_tiles{CutlassTileConfigSM90::CtaShape128x16x128B,
CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B,
CutlassTileConfigSM90::CtaShape256x128x128B, CutlassTileConfigSM90::CtaShape256x256x128B};
return valid_tiles.count(tile) == 1;
#endif
}
// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve
// compilation speed.
bool sm90_supports_mcast_along_m(CutlassTileConfigSM90 const tile)
@ -275,37 +298,65 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm90(CutlassGemmConfig::Can
std::vector<CutlassGemmConfig> candidate_configs;
for (auto const& tile_config : tiles)
{
CutlassGemmConfig config(
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(config);
bool const has_m_mcast = sm90_supports_mcast_along_m(tile_config);
bool const has_n_mcast = sm90_supports_mcast_along_n(tile_config);
if (has_m_mcast)
bool const has_w4afp8 = (config & CutlassGemmConfig::WEIGHT_ONLY) && (config & CutlassGemmConfig::GROUPED_GEMM);
if (has_w4afp8)
{
CutlassGemmConfig config(
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1);
candidate_configs.push_back(config);
bool const has_coop_supported = sm90_supports_coop(tile_config);
std::set<MainloopScheduleType> mainloop_schedules{MainloopScheduleType::PINGPONG};
if (has_coop_supported)
{
mainloop_schedules.insert(MainloopScheduleType::COOPERATIVE);
}
auto const epilogue_schedule = EpilogueScheduleType::AUTO;
for (auto const& mainloop_schedule : mainloop_schedules)
{
CutlassGemmConfig candidate(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x1x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x2x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x2x1);
candidate_configs.push_back(candidate);
}
}
if (has_n_mcast)
else
{
CutlassGemmConfig config(
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1);
candidate_configs.push_back(config);
}
CutlassGemmConfig candidate(
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(candidate);
if (has_m_mcast)
{
CutlassGemmConfig candidate(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_2x1x1);
candidate_configs.push_back(candidate);
}
if (has_m_mcast && has_n_mcast)
{
CutlassGemmConfig config(
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x2x1);
candidate_configs.push_back(config);
if (has_n_mcast)
{
CutlassGemmConfig candidate(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_1x2x1);
candidate_configs.push_back(candidate);
}
if (has_m_mcast && has_n_mcast)
{
CutlassGemmConfig candidate(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_2x2x1);
candidate_configs.push_back(candidate);
}
}
}
// add cuda kernel profiler to tactics for weight-only plugins
if (config & CutlassGemmConfig::WEIGHT_ONLY)
{
if (tiles.size() > 0)
if (tiles.size() > 0 && !(config & CutlassGemmConfig::GROUPED_GEMM))
{
CutlassGemmConfig CudaKernelConfig(
tiles[0], MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);


@ -206,38 +206,47 @@ const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zer
);
"""
elif operation.gemm_kind == GemmKind.Grouped:
# Similar to MixedInput above, we must modify the tags for grouped gemm as CUTLASS library does not have the updated schedules
assert operation.mainloop_schedule in [
KernelScheduleType.TmaWarpSpecializedCooperative,
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
]
assert operation.epi_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
kernel_sched.replace("::Kernel", "::KernelGrouped")
epi_sched += "Grouped"
if operation.act_type != operation.weight_type:
# Mixed MoE GEMM
weight_tag = DataTypeTag[operation.weight_type]
instantiation = f"""
template void sm90_generic_mixed_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {out_tag},
{epi_tag}, {cute_cta_shape}, {cute_cga_shape}, {kernel_sched}, {epi_sched}, {quant_op}> (
GroupedGemmInput<{act_tag}, {weight_tag}, {out_tag}, {out_tag}>inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size);
"""
else:
# Similar to MixedInput above, we must modify the tags for grouped gemm as CUTLASS library does not have the updated schedules
assert operation.mainloop_schedule in [
KernelScheduleType.TmaWarpSpecializedCooperative,
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
]
assert operation.epi_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
kernel_sched.replace("::Kernel", "::KernelGrouped")
epi_sched += "Grouped"
# arch_tag = f"cutlass::arch::Sm{operation.arch}"
arch_tag = f"Sm{operation.arch}"
weight_tag = CudaTypeName[operation.weight_type]
assert operation.epi_fusion is not None
epi_fusion = EpiFusion[operation.epi_fusion]
# arch_tag = f"cutlass::arch::Sm{operation.arch}"
arch_tag = f"Sm{operation.arch}"
weight_tag = CudaTypeName[operation.weight_type]
assert operation.epi_fusion is not None
epi_fusion = EpiFusion[operation.epi_fusion]
epi_fusion = epi_fusion.split(':')[-1]
epi_tag = epi_tag.split(':')[-1]
epi_fusion = epi_fusion.split(':')[-1]
epi_tag = epi_tag.split(':')[-1]
guard_map = {
e2m1: "defined(ENABLE_FP4)",
DataType.e4m3: "defined(ENABLE_FP8)",
DataType.bf16: "defined(ENABLE_BF16)"
}
guard = guard_map[
operation.act_type] if operation.act_type in guard_map else "1"
# TODO Revert this once compiler bug is fixed so we can use template instead of macro again
# instantiation = f"""
# template void tma_warp_specialized_generic_moe_gemm_kernelLauncher<{arch_tag}, {act_tag}, {weight_tag}, {out_tag},
# {epi_tag}, {epi_fusion}, {cute_cta_shape}, {cute_cga_shape}, false>
# (TmaWarpSpecializedGroupedGemmInput, int, int, cudaStream_t, int*, size_t*);
# """
instantiation = f"""
guard_map = {
e2m1: "defined(ENABLE_FP4)",
DataType.e4m3: "defined(ENABLE_FP8)",
DataType.bf16: "defined(ENABLE_BF16)"
}
guard = guard_map[
operation.act_type] if operation.act_type in guard_map else "1"
# TODO Revert this once compiler bug is fixed so we can use template instead of macro again
# instantiation = f"""
# template void tma_warp_specialized_generic_moe_gemm_kernelLauncher<{arch_tag}, {act_tag}, {weight_tag}, {out_tag},
# {epi_tag}, {epi_fusion}, {cute_cta_shape}, {cute_cga_shape}, false>
# (TmaWarpSpecializedGroupedGemmInput, int, int, cudaStream_t, int*, size_t*);
# """
instantiation = f"""
#if {guard}\n
INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag},
{epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, false);\n
@ -510,9 +519,57 @@ def generate_sm90_grouped_gemm_operations(is_arch_enabled):
return operations
def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled):
if not is_arch_enabled:
return []
arch = 90
supported_dtypes = [
(DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
(DataType.e4m3, DataType.u4, DataType.bf16, DataType.bf16,
DataType.bf16),
]
quant_ops = [TrtLlm_QuantOp.finegrained_scale_only]
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
M_TILES = [64, 128]  # M tiles supported by the SM90 W4A8 grouped GEMM
N_TILES = [16, 32, 64, 128]
K_TILES = [128, 256, 512]
cta_shapes_mnk = list(product(M_TILES, N_TILES, K_TILES))
warp_shape = [0, 0, 0] # ignored except for naming
stages = 0 # auto
cga_shapes = product([1, 2], [1, 2], [1])
partial_args = product(supported_dtypes, quant_ops, epi_tags,
cta_shapes_mnk, cga_shapes)
operations = list()
for dtype_combo, quant_op, epi_tag, cta_shape_mnk, cga_shape in partial_args:
use_coop = cta_shape_mnk[0] >= 128
mainloop_schedules = [
KernelScheduleType.TmaWarpSpecializedCooperative,
KernelScheduleType.TmaWarpSpecializedPingpong
] if use_coop else [KernelScheduleType.TmaWarpSpecializedPingpong]
epi_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative
for mainloop_schedule in mainloop_schedules:
if (cta_shape_mnk[0] == 128 and cta_shape_mnk[1] == 128
and mainloop_schedule
== KernelScheduleType.TmaWarpSpecializedCooperative):
continue
moe_gemm_operation = TrtLlm_GemmLauncher(GemmKind.Grouped, arch, *dtype_combo, quant_op, epi_tag, cta_shape_mnk, \
warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule)
operations.append(moe_gemm_operation)
return operations
def generate_sm90_operations(is_arch_enabled):
operations = generate_sm90_mixed_gemm_operations()
operations.extend(generate_sm90_grouped_gemm_operations(is_arch_enabled))
operations.extend(
generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled))
return operations
@ -699,6 +756,7 @@ if __name__ == "__main__":
fpA_intB_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
# moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl"
moe_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl"
moe_mixed_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl"
# sm80_moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl"
sm80_moe_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl"
@ -725,20 +783,29 @@ if __name__ == "__main__":
is_internal = op.gemm_kind == GemmKind.Grouped
return is_internal != args.internal
def is_mixed_dtype_grouped(op):
if isinstance(op, GemmSm80LauncherConfig):
return False
return (op.act_type != op.weight_type) and (op.gemm_kind
== GemmKind.Grouped)
op_groups = dict()
for op in operations:
if should_skip(op):
continue
dict_key = (op.gemm_kind, op.arch, op.cta_shape[0])
dict_key = (op.gemm_kind, op.arch, op.cta_shape[0],
is_mixed_dtype_grouped(op))
op_group = op_groups.get(dict_key, list())
op_group.append(op)
op_groups[dict_key] = op_group
file_counter = 1
for key, value in op_groups.items():
gemm_kind, _, _ = key
gemm_kind, _, _, is_mixed_dtype_grouped = key
out_file = os.path.join(
output_dir, GemmKindNames[gemm_kind],
f"cutlass_kernel_file_{file_counter}.generated.cu")
write_file(inl_map[key[:2]], value, out_file)
inl_file = [moe_mixed_gemm_inl
] if is_mixed_dtype_grouped else inl_map[key[:2]]
write_file(inl_file, value, out_file)
file_counter += 1


@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bcea962f0c9ab2efb4cd6171be456b1c7f68e31d2a257c4eee6b3e9f5e560904
size 47910208
oid sha256:723334a99a2f23bd16f50e69c2a3f21a06a06a41d0eb2ebe100337cbc0907c1a
size 52931896


@ -1,2 +1,2 @@
a6bcad94c12cb55cbabc5fb30e7a4adb9e6906cc52cb285a9dd42aa71f7760e3 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 9f0fabbb7f7f678fe34bb0eeed756869676d9304
26da3daf623e613ccb02765cfb19b4ad9e19888a2fbfb0be7d2cfb96735c8a13 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 1c6e86206675670e0f37ebf40dcc2562d56039ca


@ -211,10 +211,32 @@ struct TmaWarpSpecializedGroupedGemmInput
LayoutScalingFactorsA* fp4_block_scaling_factors_stride_A = nullptr;
LayoutScalingFactorsB* fp4_block_scaling_factors_stride_B = nullptr;
struct INT4GroupwiseParams
{
constexpr static int group_size = 128; // Unused, hard-coded to 128
bool enabled = false;
using SFA = __nv_bfloat16;
using SFB = __nv_bfloat16;
using ProblemShapeInt = cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;
using LayoutSFA = typename cutlass::layout::ColumnMajor;
using LayoutSFB = typename cutlass::layout::ColumnMajor; // Unused
using StrideSFA = cute::Stride<cute::Int<1>, int64_t, int64_t>;
using StrideSFB = cute::Stride<cute::Int<1>, int64_t, int64_t>; // Unused
StrideSFA* stride_s_a = nullptr;
StrideSFB* stride_s_b = nullptr; // Unused
const SFA** ptr_s_a = nullptr;
const SFA** ptr_z_a = nullptr; // Unused
const SFB** ptr_s_b = nullptr; // Unused
const SFB** ptr_z_b = nullptr; // Unused
ProblemShapeInt shape{};
};
INT4GroupwiseParams int4_groupwise_params;
uint8_t* gemm_workspace = nullptr;
size_t gemm_workspace_size = 0;
static std::array<size_t, 14> workspaceBuffers(int num_experts);
static std::array<size_t, 17> workspaceBuffers(int num_experts);
static size_t workspaceSize(int num_experts);
@ -248,9 +270,13 @@ public:
MoeGemmRunner();
#if defined(ENABLE_FP8)
static constexpr bool use_fp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
static constexpr bool use_fp8 = (std::is_same_v<T, __nv_fp8_e4m3>
|| std::is_same_v<T, __nv_fp8_e5m2>) &&!std::is_same_v<WeightType, cutlass::uint4b_t>;
static constexpr bool use_w4afp8
= std::is_same_v<T, __nv_fp8_e4m3> && std::is_same_v<WeightType, cutlass::uint4b_t>;
#else
static constexpr bool use_fp8 = false;
static constexpr bool use_w4afp8 = false;
#endif
#if defined(ENABLE_FP4)


@ -558,6 +558,7 @@ private:
static TmaWarpSpecializedGroupedGemmInput computeStridesTmaWarpSpecialized(int64_t const* expert_first_token_offset,
TmaWarpSpecializedGroupedGemmInput layout_info, int64_t num_tokens, int64_t expanded_num_tokens, int64_t gemm_n,
int64_t gemm_k, int const num_experts_per_node, T const* in, WeightType const* weights,
TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const* w4a8_weight_scale_flat,
float const* fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_scale_flat,
QuantParams::FP4Inputs::GemmInputs fp4_inputs, T const* bias, UnfusedGemmOutputType* output,
cudaStream_t stream);
@ -622,6 +623,7 @@ private:
QuantParams& quant_params, cudaStream_t stream);
T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr,
int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream);
CubKeyValueSorter sorter_;


@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e47924159db0476a3d8026e27514eea07e4b4db690d6f334ef05c41a235014cf
size 47540524
oid sha256:a70cc6672a1a083c7bca7f8efe4331b84535e79a609cbae7a8b289a4cbb3725b
size 52534712


@ -1,2 +1,2 @@
a4cd97f177fc4c582d8ab2dfd10b7428b33154ddbd4d9f734cb561ba15e552e7 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 9f0fabbb7f7f678fe34bb0eeed756869676d9304
bc4a32343119b018d87c5716986055ce5f149e0c8d695fcd591581310c1f6066 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 1c6e86206675670e0f37ebf40dcc2562d56039ca


@ -148,5 +148,128 @@ INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16, __nv_fp8_e4m3);
#endif
#endif
template <typename T_in, typename T_out, int kProcessRows, typename AccessType>
__global__ void apply_per_expert_scale(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale,
int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols)
{
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
T_in act_vec[kElems];
int col_offset = blockIdx.x * blockDim.x + threadIdx.x;
int row_offset = blockIdx.y;
int expert_idx = permuted_token_selected_experts[row_offset];
T_in scale = per_expert_scale[expert_idx];
if (col_offset * kElems >= cols || row_offset * kProcessRows >= rows)
return;
if (num_valid_tokens_ptr && (row_offset * kProcessRows >= *num_valid_tokens_ptr))
return;
act += row_offset * kProcessRows * cols;
smoothed_act += row_offset * kProcessRows * cols;
#pragma unroll
for (int i = 0; i < kProcessRows; ++i)
{
*reinterpret_cast<AccessType*>(act_vec) = reinterpret_cast<AccessType const*>(act + i * cols)[col_offset];
if constexpr ((std::is_same_v<T_in, half>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|| std::is_same_v<T_in, __nv_bfloat16>
#endif
) &&(kElems % 2 == 0))
{
using Vec2 = typename Vec2Type<T_in>::type;
#pragma unroll
for (int j = 0; j < kElems; j += 2)
{
if constexpr (std::is_same_v<T_in, half>)
{
*reinterpret_cast<Vec2*>(act_vec + j)
= __hmul2(*reinterpret_cast<Vec2*>(act_vec + j), __half2half2(scale));
}
else
{
*reinterpret_cast<Vec2*>(act_vec + j)
= __hmul2(*reinterpret_cast<Vec2*>(act_vec + j), __bfloat162bfloat162(scale));
}
}
}
else
{
#pragma unroll
for (int j = 0; j < kElems; ++j)
{
act_vec[j] = static_cast<T_in>(static_cast<float>(act_vec[j]) * static_cast<float>(scale));
}
}
if constexpr (std::is_same_v<T_in, T_out>)
{
reinterpret_cast<AccessType*>(smoothed_act + i * cols)[col_offset]
= *reinterpret_cast<AccessType*>(act_vec);
}
else
{
#pragma unroll
for (int j = 0; j < kElems; ++j)
{
(smoothed_act + i * cols)[col_offset * kElems + j] = static_cast<T_out>(act_vec[j]);
}
}
}
}
template <typename T_in, typename T_out, int kProcessRows, typename AccessType = float4>
void apply_per_expert_scale_kernel_launcher_(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale,
int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols,
cudaStream_t stream = 0)
{
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
dim3 block(128);
dim3 grid((cols / kElems + block.x - 1) / block.x, (rows + kProcessRows - 1) / kProcessRows);
apply_per_expert_scale<T_in, T_out, kProcessRows, AccessType><<<grid, block, 0, stream>>>(
smoothed_act, act, per_expert_scale, permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols);
}
template <typename T_in, typename T_out>
void apply_per_expert_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale,
int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols,
cudaStream_t stream)
{
int elems = rows * cols;
if (elems < 2048 * 2048)
{
apply_per_expert_scale_kernel_launcher_<T_in, T_out, 1, float4>(smoothed_act, act, per_expert_scale,
permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream);
}
else if (elems < 4096 * 4096)
{
apply_per_expert_scale_kernel_launcher_<T_in, T_out, 4, float4>(smoothed_act, act, per_expert_scale,
permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream);
}
else if (elems < 8192 * 8192)
{
apply_per_expert_scale_kernel_launcher_<T_in, T_out, 8, float4>(smoothed_act, act, per_expert_scale,
permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream);
}
else
{
apply_per_expert_scale_kernel_launcher_<T_in, T_out, 16, float4>(smoothed_act, act, per_expert_scale,
permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream);
}
}
#define INSTANTIATE_PEREXPERT_SCALE(T_in, T_out) \
template void apply_per_expert_scale_kernel_launcher<T_in, T_out>(T_out * smoothed_act, T_in const* act, \
T_in const* per_expert_scale, int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, \
int rows, int cols, cudaStream_t stream)
INSTANTIATE_PEREXPERT_SCALE(half, half);
#if defined(ENABLE_FP8)
INSTANTIATE_PEREXPERT_SCALE(half, __nv_fp8_e4m3);
#endif
#if defined(ENABLE_BF16)
INSTANTIATE_PEREXPERT_SCALE(__nv_bfloat16, __nv_bfloat16);
#if defined(ENABLE_FP8)
INSTANTIATE_PEREXPERT_SCALE(__nv_bfloat16, __nv_fp8_e4m3);
#endif
#endif
} // namespace kernels
} // namespace tensorrt_llm


@ -39,5 +39,10 @@ template <typename T_in, typename T_out = T_in>
void apply_per_channel_scale_kernel_launcher(
T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, int cols, cudaStream_t stream = 0);
template <typename T_in, typename T_out = T_in>
void apply_per_expert_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale,
int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols,
cudaStream_t stream = 0);
} // namespace kernels
} // namespace tensorrt_llm


@ -109,6 +109,7 @@ public:
case at::ScalarType::Bool: return IBuffer::DataType::kBOOL;
case at::ScalarType::Float8_e4m3fn: return IBuffer::DataType::kFP8;
case at::ScalarType::BFloat16: return IBuffer::DataType::kBF16;
case at::ScalarType::QUInt4x2: return IBuffer::DataType::kINT4;
default: TLLM_THROW("unsupported data type");
}
}


@ -84,12 +84,13 @@ public:
};
FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype,
bool use_fp8_block_scaling)
bool use_fp8_block_scaling, bool use_w4a8_group_scaling)
{
mActivationDtype = activation_dtype;
mWeightDtype = weight_dtype;
mOutputDtype = output_dtype;
mUseFp8BlockScaling = use_fp8_block_scaling;
mUseW4A8GroupScaling = use_w4a8_group_scaling;
mInnerDimMultiplier = 1;
// keep consistent with cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp
@ -132,6 +133,44 @@ public:
}
}
#endif
if (isInt4Quant())
{
mInnerDimMultiplier = 2;
if (mActivationDtype == c10::ScalarType::Half)
{
#ifdef ENABLE_FP8
if (mUseW4A8GroupScaling)
{
mKernelRunner
= std::make_unique<kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>>();
}
else
{
mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
}
#else
mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
#endif
}
#ifdef ENABLE_BF16
else if (mActivationDtype == c10::ScalarType::BFloat16)
{
#ifdef ENABLE_FP8
if (mUseW4A8GroupScaling)
{
mKernelRunner = std::make_unique<
kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>>();
}
else
{
mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
}
#else
mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
#endif
}
#endif
}
if (!mKernelRunner)
{
C10_THROW_ERROR_FORMATTED(Error,
@ -350,6 +389,7 @@ public:
int64_t const num_rows = input.sizes()[0];
int64_t const hidden_size = fc2_expert_weights.sizes()[1];
int64_t const inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
int64_t const group_size = isInt4Quant() ? 128 : -1;
int const num_experts = static_cast<int>(fc2_expert_weights.sizes()[0] * ep_size);
// Get specific profile configs according to the profile_id.
@ -371,14 +411,14 @@ public:
static_cast<int>(tp_rank), static_cast<int>(ep_size), static_cast<int>(ep_rank),
static_cast<int>(cluster_size), static_cast<int>(cluster_rank));
int const GROUP_SIZE = -1;
bool const USE_BIAS = false;
bool const USE_LORA = false;
mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
tensorrt_llm::runtime::TorchUtils::dataType(mActivationDtype),
tensorrt_llm::runtime::TorchUtils::dataType(
mUseW4A8GroupScaling ? at::ScalarType::Float8_e4m3fn : mActivationDtype),
tensorrt_llm::runtime::TorchUtils::dataType(mWeightDtype),
tensorrt_llm::runtime::TorchUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
hidden_size, inter_size, GROUP_SIZE, tensorrt_llm::ActivationType::Swiglu, USE_BIAS, USE_LORA,
hidden_size, inter_size, group_size, tensorrt_llm::ActivationType::Swiglu, USE_BIAS, USE_LORA,
min_latency_mode, parallelism_config);
freeProfileWorkspace();
@ -412,6 +452,7 @@ private:
char* mProfileWorkspace = nullptr;
bool mUseFp8BlockScaling = false;
bool mUseW4A8GroupScaling = false;
using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
std::vector<Profile> mAllProfiles;
@ -457,10 +498,9 @@ private:
int num_experts, int experts_per_token, tensorrt_llm::ActivationType activation_type,
kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
{
size_t moe_workspace_size
= mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts, experts_per_token,
activation_type, parallelismConfig, /* use_lora */ false, mUseFp8BlockScaling, min_latency_mode,
/* hasExpertPrequantScales */ false);
size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts,
experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseFp8BlockScaling,
min_latency_mode, mUseW4A8GroupScaling);
size_t src_to_dest_map_size = experts_per_token * num_rows * sizeof(int);
std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
@ -561,6 +601,28 @@ private:
return kernels::QuantParams::FP8BlockScaling(
static_cast<float const*>(fc1_scales.data_ptr()), static_cast<float const*>(fc2_scales.data_ptr()));
}
else if (isInt4Quant())
{
TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for INT4 quantization");
TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for INT4 quantization");
auto& fc1_weight_scales = quant_scales.value()[0];
auto& fc2_weight_scales = quant_scales.value()[1];
auto& fc1_act_scales = quant_scales.value()[2];
auto& fc2_act_scales = quant_scales.value()[3];
auto& fc1_weight_zeros = quant_scales.value()[4];
auto& fc2_weight_zeros = quant_scales.value()[5];
auto& fc1_alpha = quant_scales.value()[6];
auto& fc2_alpha = quant_scales.value()[7];
int group_size = 128;
return kernels::QuantParams::GroupWise(group_size, static_cast<void const*>(fc1_weight_scales.data_ptr()),
static_cast<void const*>(fc2_weight_scales.data_ptr()),
static_cast<void const*>(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr),
static_cast<void const*>(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr),
static_cast<void const*>(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr),
static_cast<void const*>(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr),
static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
}
else
{
return kernels::QuantParams{};
@ -577,6 +639,16 @@ private:
{
return mWeightDtype == c10::ScalarType::Long;
}
bool isInt4Quant() const
{
return mWeightDtype == c10::ScalarType::QUInt4x2;
}
bool isW4AFp8Quant() const
{
return mActivationDtype == c10::ScalarType::Float8_e4m3fn && isInt4Quant();
}
};
} // namespace torch_ext
@ -584,7 +656,7 @@ private:
TORCH_LIBRARY(trtllm, m)
{
m.class_<torch_ext::FusedMoeRunner>("FusedMoeRunner")
.def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool>())
.def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool>())
.def("run_gemm_profile", &torch_ext::FusedMoeRunner::runGemmProfile)
.def("get_tactic_num", &torch_ext::FusedMoeRunner::getTacticNum)
.def("run_moe", &torch_ext::FusedMoeRunner::runMoe)


@ -32,21 +32,22 @@ Please refer to [this guide](https://nvidia.github.io/TensorRT-LLM/installation/
- [DeepGEMM](#deepgemm)
- [FlashMLA](#flashmla)
- [FP8 KV Cache and MLA](#fp8-kv-cache-and-mla)
- [W4AFP8](#w4afp8)
- [Notes and Troubleshooting](#notes-and-troubleshooting)
## Hardware Requirements
DeepSeek-V3 has 671B parameters, which need about 671 GB of GPU memory for the FP8 weights, plus additional memory for activation tensors and the KV cache.
The minimum hardware requirements for running DeepSeek V3/R1 FP8&FP4 are listed as follows.
The minimum hardware requirements for running DeepSeek V3/R1 at FP8/FP4/W4A8 are listed as follows.
| GPU | DeepSeek-V3/R1 FP8 | DeepSeek-V3/R1 FP4 |
| -------- | ------- | -- |
| H100 80GB | 16 | N/A |
| H20 141GB | 8 | N/A |
| H20 96GB | 8 | N/A |
| H200 | 8 | N/A |
| B200/GB200| Not supported yet, WIP | 4 (8 GPUs is recommended for best perf) |
| GPU | DeepSeek-V3/R1 FP8 | DeepSeek-V3/R1 FP4 | DeepSeek-V3/R1 W4A8 |
| -------- | ------- | -- | -- |
| H100 80GB | 16 | N/A | 8 |
| H20 141GB | 8 | N/A | 4 |
| H20 96GB | 8 | N/A | 4 |
| H200 | 8 | N/A | 4 |
| B200/GB200| Not supported yet, WIP | 4 (8 GPUs is recommended for best perf) | Not supported yet, WIP |
Ampere architecture (SM80 & SM86) is not supported.
@ -566,6 +567,74 @@ pytorch_backend_config:
# ...
```
### W4AFP8
TensorRT-LLM supports W(INT)4-A(FP)8 for DeepSeek on __Hopper__. For MoE layers, activations and weights are quantized at per-tensor and per-group (1x128) granularity respectively, while FP8 block scaling is preserved for dense layers.
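As a rough illustration of that MoE scheme (not the exact CUTLASS kernel math), here is a minimal PyTorch sketch of how a per-group INT4 weight and a per-tensor FP8 activation scale combine in a reference GEMM; the function and tensor names are illustrative, and symmetric INT4 with group size 128 is assumed:
```python
import torch

GROUP_SIZE = 128  # 1x128 groups along the K dimension

def w4a8_reference_gemm(x_bf16, w_int4, w_group_scale, act_scale):
    """Reference-only sketch: x_bf16 [tokens, k], w_int4 [n, k] (int8 holding
    values in [-8, 7]), w_group_scale [n, k // GROUP_SIZE], act_scale scalar."""
    n, k = w_int4.shape
    # Dequantize weights group by group: w ~= int4_value * per_group_scale
    w = w_int4.to(torch.float32).view(n, k // GROUP_SIZE, GROUP_SIZE)
    w = (w * w_group_scale.to(torch.float32).unsqueeze(-1)).view(n, k)
    # Scale activations into FP8 range with the single per-tensor scale,
    # then undo it after the matmul (the fused kernel folds this into alpha).
    x_fp8 = (x_bf16.to(torch.float32) * act_scale).to(torch.float8_e4m3fn)
    y = x_fp8.to(torch.float32) @ w.t()
    return (y / act_scale).to(x_bf16.dtype)
```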
We provide a pre-quantized checkpoint for DeepSeek-R1 W4AFP8 at [HF model hub](https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8).
```bash
python quickstart_advanced.py --model_dir <W4AFP8 Checkpoint> --tp_size 8
```
Or you can follow the steps below to generate one yourself.
#### Activation calibration
[ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is used to calibrate the activations of the MoE layers. We provide a calibrated scales file at the [HF model hub](https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main/act_scales.safetensors), or you can run the following commands to generate one yourself.
```bash
# Make sure there are enough GPU resources (8x H200) to run the following commands
PATH_OF_DEEPSEEK_R1=/llm-models/DeepSeek-R1/DeepSeek-R1
# Clone ModelOpt (for its calibration scripts) and install it
git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer/ && cd TensorRT-Model-Optimizer
pip install "nvidia-modelopt[all]" -U --extra-index-url https://pypi.nvidia.com
# Clone the DeepSeek-V3 (base model of R1) GitHub repository for FP8 inference
git clone https://github.com/deepseek-ai/DeepSeek-V3.git && cd DeepSeek-V3 && git checkout 1398800
# Convert the HF checkpoint to a specific format for DeepSeek
python inference/convert.py --hf-ckpt-path $PATH_OF_DEEPSEEK_R1 --save-path ds_r1 --n-experts 256 --model-parallel 8 && cd ..
# Do per-tensor fp8 calibration
torchrun --nproc-per-node 8 --master_port=12346 ptq.py --model_path DeepSeek-V3/ds_r1 --config DeepSeek-V3/inference/configs/config_671B.json --quant_cfg FP8_DEFAULT_CFG --output_path ds_r1_fp8_per_tensor_calibration && cd ../..
```
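After calibration, a quick sanity check of the extracted scales (a hypothetical snippet; the file name is an example, and the key layout `model.layers.<l>.mlp.experts.<e>.w<i>.input_scale` is the one produced by the conversion helpers in `quantize_mixed_precision_moe.py`):
```python
from safetensors import safe_open

# Peek at a few per-expert activation scales from the calibration output.
with safe_open("act_scales.safetensors", framework="pt") as f:
    keys = sorted(f.keys())
    print(f"{len(keys)} activation scale tensors")
    for k in keys[:4]:
        t = f.get_tensor(k)
        print(k, tuple(t.shape), t.dtype)
```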
#### Weight quantization and assembling
You can run the following bash script to quantize the weights and assemble the full checkpoint.
```bash
#!/bin/bash
HF_MODEL_DIR=/models/DeepSeek-R1/DeepSeek-R1/
OUTPUT_DIR=/workspace/ckpt/
# Either the extracted safetensors file or the ModelOpt calibration output directory is accepted
# e.g. ACT_SCALES=ds_r1_fp8_per_tensor_calibration
ACT_SCALES=/workspace/act_scales.safetensors
if [ ! -d "convert_logs" ]; then
mkdir convert_logs
fi
pids=()
for i in 0 1 2 3 4 5 6 7
do
python examples/quantization/quantize_mixed_precision_moe.py --model_dir $HF_MODEL_DIR --output_dir $OUTPUT_DIR --act_scales $ACT_SCALES --parts 9 --rank $i > convert_logs/log_$i 2>&1 &
pids+=($!)
done
python examples/quantization/quantize_mixed_precision_moe.py --model_dir $HF_MODEL_DIR --output_dir $OUTPUT_DIR --act_scales $ACT_SCALES --parts 9 --rank 8 > convert_logs/log_8 2>&1
pids+=($!)
for pid in ${pids[@]}; do
wait $pid
done
echo "All processes completed!"
```
The converted checkpoint can then be used as `<YOUR_MODEL_DIR>` in the other commands in this guide.
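For reference, a small sketch (paths and layer indices are examples) of what the assembled checkpoint's `quant_cfg.json` contains: FP8 block scaling for attention, dense, and shared-expert projections, and W4A8 for the routed experts of layers 3 and above, as written by `quantize_mixed_precision_moe.py`:
```python
import json
import os

ckpt_dir = "/workspace/ckpt"  # example OUTPUT_DIR from the script above
with open(os.path.join(ckpt_dir, "quant_cfg.json")) as f:
    quant_cfg = json.load(f)

assert quant_cfg["quant_algo"] == "MIXED_PRECISION"
layers = quant_cfg["quantized_layers"]
# Attention and shared-expert projections keep FP8 block scaling.
print(layers["model.layers.10.self_attn.o_proj"])  # {'quant_algo': 'FP8_BLOCK_SCALES'}
# Routed experts in layers >= 3 use W4A8 group-wise quantization.
print(layers["model.layers.10.mlp.experts"])        # {'quant_algo': 'W4A8_AWQ'}
```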
## Notes and Troubleshooting
- **Model Directory:** Update `<YOUR_MODEL_DIR>` with the actual path where the model weights reside.


@ -0,0 +1,305 @@
import argparse
import json
import os
import re
import shutil
import torch
from safetensors.torch import safe_open, save_file
from tqdm import tqdm
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir',
type=str,
required=True,
help='HF checkpoint path')
parser.add_argument('--output_dir',
type=str,
required=True,
help='Save path')
parser.add_argument(
'--act_scales',
type=str,
required=True,
help=
'ModelOpt calibrated checkpoint dir or extracted safetensors for activation scales'
)
parser.add_argument('--parts',
type=int,
default=1,
help='divide all safetensors into parts')
parser.add_argument('--rank',
type=int,
default=0,
help='which part to quantize')
args = parser.parse_args()
return args
def load_and_preprocess_state_dict(modelopt_state_root, world_size=8):
state_dict_list = []
# load amax from state dict
for rank in range(world_size):
state_dict_list.append(
torch.load(
f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt",
map_location="cuda:0"))
# calculate the max across all TP ranks
merged_state_dict = state_dict_list[0]
for rank in range(world_size):
for key, amax in state_dict_list[rank].items():
if key in merged_state_dict:
amax = torch.max(amax.to(0), merged_state_dict[key].to(0))
merged_state_dict[key] = amax.to(0)
mapping = {
"ffn.shared_experts.w1": "mlp.shared_experts.gate_proj",
"ffn.shared_experts.w2": "mlp.shared_experts.down_proj",
"ffn.shared_experts.w3": "mlp.shared_experts.up_proj",
"ffn.shared_experts": "mlp.shared_experts",
"ffn.shared_experts": "mlp.shared_experts",
"ffn.shared_experts": "mlp.shared_experts",
"ffn.w1": "mlp.gate_proj",
"ffn.w2": "mlp.down_proj",
"ffn.w3": "mlp.up_proj",
"head": "lm_head",
"attn": "self_attn",
}
new_dict = {}
for k, v in merged_state_dict.items():
new_key = k.replace("layers", "model.layers")
for original_pattern, replace_pattern in mapping.items():
new_key = new_key.replace(original_pattern, replace_pattern)
# ffn.experts.xx.w1/w2/w3- > mlp.experts.xx.gate_proj/down_proj/up_proj
new_key = re.sub(r"ffn\.experts\.(\d+)\.w1",
r"mlp.experts.\1.gate_proj", new_key)
new_key = re.sub(r"ffn\.experts\.(\d+)\.w2",
r"mlp.experts.\1.down_proj", new_key)
new_key = re.sub(r"ffn\.experts\.(\d+)\.w3", r"mlp.experts.\1.up_proj",
new_key)
new_dict[new_key] = v
merged_state_dict.clear()
merged_state_dict.update(new_dict)
# set amax for modules to be fused and make sure they share the same input
for key, amax in merged_state_dict.items():
if "up_proj" in key:
gate_proj_key = key.replace("up_proj", "gate_proj")
if "weight_quantizer" in key:
fused_amax = torch.max(amax, merged_state_dict[gate_proj_key])
merged_state_dict[key] = fused_amax
merged_state_dict[gate_proj_key] = fused_amax
elif "input_quantizer" in key:
assert amax == merged_state_dict[gate_proj_key]
else:
raise NotImplementedError
return merged_state_dict
def get_scales_from_amax(start_layer, end_layer, renamed_state_dict):
weight_name_dict = {"gate_proj": 1, "down_proj": 2, "up_proj": 3}
scales = {}
for layer_idx in range(start_layer, end_layer):
amax_keys_per_layer = [
x for x in renamed_state_dict.keys()
if (x.startswith(f'model.layers.{layer_idx}.mlp.experts.')
and x.endswith(".input_quantizer._amax"))
]
for k in amax_keys_per_layer:
expert_idx = int(k.split('.')[5])
weight_idx = weight_name_dict[k.split('.')[6]]
val = renamed_state_dict[k]
scales[
f'model.layers.{layer_idx}.mlp.experts.{expert_idx}.w{weight_idx}.input_scale'] = val.unsqueeze(
0) / 448
return scales
def quantize_fp8_block_scale_to_int4(fp8_tensor, fp8_scale):
group_size = 128
blocked_tensor = fp8_tensor.view(fp8_tensor.shape[0] // 128, 128,
fp8_tensor.shape[1] // 128,
128).to(torch.float32)
dequant_tensor = (blocked_tensor *
fp8_scale.unsqueeze(1).unsqueeze(3)).view(
fp8_tensor.shape[0],
fp8_tensor.shape[1] // group_size,
group_size).to(torch.bfloat16).to(torch.float32)
scale_tensor = torch.abs(dequant_tensor).max(dim=2).values / 7
quant_tensor = torch.clamp(torch.round(
(dequant_tensor / scale_tensor.unsqueeze(-1))),
min=-8,
max=7)
quant_tensor = quant_tensor.to(torch.int8)
return quant_tensor.view(fp8_tensor.shape), scale_tensor
def main(args):
model_dir = args.model_dir
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
if torch.cuda.is_available():
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(args.rank % num_gpus)
model_index_file = os.path.join(model_dir, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]
processed_files = {}
for tensor_name in list(weight_map.keys()):
if tensor_name not in weight_map:
continue
file_name = weight_map[tensor_name]
if file_name in processed_files:
continue
processed_files[file_name] = safe_open(os.path.join(
model_dir, file_name),
"pt",
device="cuda")
with open(os.path.join(model_dir, "config.json"), 'r') as file:
config = json.load(file)
num_layer = config['num_hidden_layers']
part_layer = (num_layer + args.parts - 1) // args.parts
start_layer = args.rank * part_layer
end_layer = min(num_layer, args.rank * part_layer + part_layer)
def get_tensor(name):
if name not in weight_map:
return None
ff = weight_map[name]
safetensors_loader = processed_files[ff]
return safetensors_loader.get_tensor(name).cuda()
def get_file_name(layer):
rank = layer // part_layer
return "model-%05d-of-%05d.safetensors" % (rank, args.parts)
new_safetensors = {}
new_json = {}
new_json['weight_map'] = {}
new_json['metadata'] = {}
for key in tqdm(list(weight_map.keys())):
if "mlp.experts" in key and (key.endswith("weight")
or key.endswith("weight_scale_inv")):
if key.endswith("weight_scale_inv"):
continue
if args.rank == 0:
layer = int(key.split(".")[2])
new_json['weight_map'][key] = get_file_name(layer)
new_json['weight_map'][key.replace(
"weight", "weight_scale_inv")] = get_file_name(layer)
if int(key.split(".")[2]) < start_layer or int(
key.split(".")[2]) >= end_layer:
continue
fp8_tensor = get_tensor(key)
fp8_scale = get_tensor(key.replace("weight", "weight_scale_inv"))
quant_tensor, scale_tensor = quantize_fp8_block_scale_to_int4(
fp8_tensor, fp8_scale)
packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
packed_tensor = packer(quant_tensor.cpu().contiguous())
new_safetensors.update({key: packed_tensor})
new_safetensors.update({
key.replace("weight", "weight_scale_inv"):
scale_tensor.contiguous()
})
else:
name = key.split(".")
if args.rank == 0:
if len(name) < 3 or not name[2].isdigit():
new_safetensors.update({key: get_tensor(key)})
new_json['weight_map'][key] = get_file_name(0)
continue
file_name = get_file_name(int(name[2]))
new_json['weight_map'][key] = file_name
if len(name) < 3 or not name[2].isdigit() or (int(
name[2]) < start_layer or int(name[2]) >= end_layer):
continue
new_safetensors.update({key: get_tensor(key)})
if args.rank == 0:
if os.path.isdir(args.act_scales):
# Extract activation scales
renamed_state_dict = load_and_preprocess_state_dict(
modelopt_state_root=args.act_scales, world_size=8)
get_scales_from_amax(start_layer=start_layer,
end_layer=end_layer,
renamed_state_dict=renamed_state_dict)
else:
input_scales = safe_open(args.act_scales, "pt")
for k in input_scales.keys():
new_safetensors.update({k: input_scales.get_tensor(k)})
new_json['weight_map'][k] = "input_scales.safetensors"
file_name = get_file_name(start_layer)
print(f'saving to {file_name}...')
save_file(new_safetensors, os.path.join(output_dir, file_name))
with open(os.path.join(output_dir, "model.safetensors.index.json"),
"w") as f:
json.dump(new_json, f)
names = [
"configuration_deepseek.py", "generation_config.json",
"modeling_deepseek.py", "tokenizer.json", "tokenizer_config.json"
]
for name in names:
shutil.copy(os.path.join(model_dir, name), output_dir)
shutil.copy(args.act_scales, output_dir)
# config.json
del config['quantization_config']
with open(os.path.join(output_dir, "config.json"), 'w') as file:
json.dump(config, file, indent=4)
# quant_cfg.json
attn_names = ["fused_a", "q_b_proj", "kv_b_proj", "o_proj"]
mlp_names = ["gate_up_proj", "down_proj"]
fp8_block_scale = {"quant_algo": "FP8_BLOCK_SCALES"}
w4a8_awq = {"quant_algo": "W4A8_AWQ"}
quant_cfg = {}
quant_cfg["quant_algo"] = "MIXED_PRECISION"
quant_cfg["kv_cache_quant_algo"] = None
quant_cfg["quantized_layers"] = {}
for l in range(61):
prefix = f"model.layers.{l}"
for n1 in attn_names:
quant_cfg["quantized_layers"][
f"{prefix}.self_attn.{n1}"] = fp8_block_scale
for n2 in mlp_names:
quant_cfg["quantized_layers"][
f"{prefix}.mlp.shared_experts.{n2}"] = fp8_block_scale
if l < 3:
for n3 in mlp_names:
quant_cfg["quantized_layers"][
f"{prefix}.mlp.{n3}"] = fp8_block_scale
else:
quant_cfg["quantized_layers"][
f"{prefix}.mlp.experts"] = w4a8_awq
with open(os.path.join(output_dir, "quant_cfg.json"), 'w') as file:
json.dump(quant_cfg, file, indent=4)
# hf_quant_config.json
hf_quant_config = {}
hf_quant_config['quantization'] = {}
hf_quant_config['quantization']["quant_algo"] = "MIXED_PRECISION"
hf_quant_config['quantization']["kv_cache_quant_algo"] = None
with open(os.path.join(output_dir, "hf_quant_config.json"),
'w') as file:
json.dump(hf_quant_config, file, indent=4)
if __name__ == "__main__":
args = parse_args()
main(args)


@ -33,6 +33,7 @@ class MoERunner(TunableRunner):
cluster_size: int,
cluster_rank: int,
use_fp8_block_scaling: bool,
use_w4a8_group_scaling: bool,
):
self.x_dtype = x_dtype
self.weight_dtype = weight_dtype
@ -45,14 +46,16 @@ class MoERunner(TunableRunner):
self.cluster_size = cluster_size
self.cluster_rank = cluster_rank
self.use_fp8_block_scaling = use_fp8_block_scaling
self.use_w4a8_group_scaling = use_w4a8_group_scaling
instance_key = (x_dtype, weight_dtype, output_dtype,
use_fp8_block_scaling)
use_fp8_block_scaling, use_w4a8_group_scaling)
if instance_key not in MoERunner._runner_dict:
MoERunner._runner_dict[
instance_key] = torch.classes.trtllm.FusedMoeRunner(
x_dtype, weight_dtype, output_dtype, use_fp8_block_scaling)
x_dtype, weight_dtype, output_dtype, use_fp8_block_scaling,
use_w4a8_group_scaling)
self._fused_moe_runner = MoERunner._runner_dict[instance_key]
self._is_nvfp4 = weight_dtype == torch.int64
@ -121,6 +124,7 @@ def fused_moe(
cluster_size: int = 1,
cluster_rank: int = 0,
use_fp8_block_scaling: bool = False,
use_w4a8_group_scaling: bool = False,
min_latency_mode: bool = False,
) -> List[torch.Tensor]:
@ -151,6 +155,7 @@ def fused_moe(
cluster_size=cluster_size,
cluster_rank=cluster_rank,
use_fp8_block_scaling=use_fp8_block_scaling,
use_w4a8_group_scaling=use_w4a8_group_scaling,
)
_, gemm_tactic_1 = tuner.choose_one(
@ -208,6 +213,7 @@ def _(
cluster_size: int = 1,
cluster_rank: int = 0,
use_fp8_block_scaling: bool = False,
use_w4a8_group_scaling: bool = False,
min_latency_mode: bool = False,
):
seq_len = input.shape[0]


@ -418,7 +418,7 @@ class DecoderModelForCausalLM(nn.Module,
if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
for n, q in quant_config_dict.items():
# gate_proj and up_proj share the same quant config
if prefix_name + '.gate_proj' in n:
if prefix_name + '.gate_proj' in n or prefix_name + '.gate_up_proj' in n:
module.quant_config = q
break
elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
@ -438,7 +438,13 @@ class DecoderModelForCausalLM(nn.Module,
if name + '.q_proj' in n:
module.quant_config = q
break
# TODO: support MLA
elif hasattr(module, 'fused_a'):
# DeepseekV3Attention
for n, q in quant_config_dict.items():
# reuse q_proj quant config as the attention quant config
if name + '.fused_a' in n:
module.quant_config = q
break
# 2. skip quant for modules in QuantConfig.exclude_modules.
# kv_cache_quant_algo takes precedence over exclude_modules.


@ -8,6 +8,7 @@ import torch
from torch import nn
from tensorrt_llm._mnnvl_utils import MnnvlMoe, MoEAlltoallInfo
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.quantization.utils.fp4_utils import (
reorder_rows_for_gated_act_gemm, shuffle_matrix_a, shuffle_matrix_sf_a)
@ -398,7 +399,9 @@ class FusedMoE(nn.Module):
exclude_kv_cache=True):
if not (self.quant_config.quant_mode.has_nvfp4()
| self.quant_config.quant_mode.has_fp8_block_scales()
| self.quant_config.quant_mode.has_fp8_qdq()):
| self.quant_config.quant_mode.has_fp8_qdq()
| self.quant_config.quant_mode.
is_int4_weight_only_per_group()):
raise ValueError(
f"unsupported quantization mode: {self.quant_config.quant_mode}"
)
@ -428,6 +431,17 @@ class FusedMoE(nn.Module):
fc2_weight_block=self.w2_weight_scale,
fc2_global=self.fc2_alpha,
)
elif self.has_w4afp8:
self.quant_scales = FusedMoEQuantScalesW4A8(
scale_1_interleaved=self.fc31_weight_scale,
scale_2_interleaved=self.fc2_weight_scale,
pre_quant_scale_1=self.fc31_act_scale,
pre_quant_scale_2=self.fc2_act_scale,
zero_1=torch.Tensor(),
zero_2=torch.Tensor(),
alpha_1=self.fc31_alpha,
alpha_2=self.fc2_alpha,
)
def is_trtllm(self):
return self.moe_backend == "TRTLLM" and self.has_any_quant
@ -458,6 +472,23 @@ class FusedMoE(nn.Module):
fc2_global=self.fc2_alpha.narrow(0, expert_start,
expert_end - expert_start),
)
elif self.has_w4afp8:
return FusedMoEQuantScalesW4A8(
scale_1_interleaved=self.fc31_weight_scale.narrow(
0, expert_start, expert_end - expert_start),
scale_2_interleaved=self.fc2_weight_scale.narrow(
0, expert_start, expert_end - expert_start),
pre_quant_scale_1=self.fc31_act_scale.narrow(
0, expert_start, expert_end - expert_start),
pre_quant_scale_2=self.fc2_act_scale.narrow(
0, expert_start, expert_end - expert_start),
zero_1=torch.Tensor(),
zero_2=torch.Tensor(),
alpha_1=self.fc31_alpha.narrow(0, expert_start,
expert_end - expert_start),
alpha_2=self.fc2_alpha.narrow(0, expert_start,
expert_end - expert_start),
)
else:
return self.quant_scales
@ -479,6 +510,7 @@ class FusedMoE(nn.Module):
self.has_fp8_qdq = False
self.has_fp8_block_scales = False
self.has_nvfp4 = False
self.has_w4afp8 = False
if self.quant_config and self.quant_config.quant_mode.has_any_quant(
exclude_kv_cache=True):
qc = self.quant_config
@ -534,6 +566,91 @@ class FusedMoE(nn.Module):
requires_grad=False)
self.register_parameter("w2_weight_scaling_factor",
w2_weight_scaling_factor)
elif qc.quant_mode.is_int4_weight_only_per_group():
self.has_w4afp8 = True
self.sm_version = get_sm_version()
if self.sm_version == 89:
self.interleave = [1, 1]
elif self.sm_version == 90:
self.interleave = []
for k_shape in [
self.hidden_size,
self.intermediate_size_per_partition
]:
if k_shape % 512 == 0:
self.interleave.append(4)
elif k_shape % 256 == 0:
self.interleave.append(2)
elif k_shape % 128 == 0:
self.interleave.append(1)
else:
raise NotImplementedError(
f"K shape is required to be multiple of 128, received {k_shape}."
)
else:
raise NotImplementedError(
f"W4AFP8 MoE is unsupported on SM{self.sm_version}.")
weight_dtype = torch.int8
w3_w1_weight_shape = (self.expert_size_per_partition,
self.intermediate_size_per_partition * 2,
self.hidden_size // 2)
w2_weight_shape = (self.expert_size_per_partition,
self.hidden_size,
self.intermediate_size_per_partition // 2)
fc31_act_scale = nn.Parameter(torch.empty(
self.expert_size_per_partition,
1,
dtype=self.dtype,
device=device),
requires_grad=False)
self.register_parameter("fc31_act_scale", fc31_act_scale)
fc2_act_scale = nn.Parameter(torch.empty(
self.expert_size_per_partition,
1,
dtype=self.dtype,
device=device),
requires_grad=False)
self.register_parameter("fc2_act_scale", fc2_act_scale)
# col parallel
fc31_weight_scale = nn.Parameter(
torch.empty(self.expert_size_per_partition,
self.hidden_size // (128 * self.interleave[0]),
self.intermediate_size_per_partition * 2 *
self.interleave[0],
dtype=self.dtype,
device=device),
requires_grad=False)
self.register_parameter("fc31_weight_scale", fc31_weight_scale)
# row parallel
fc2_weight_scale = nn.Parameter(
torch.empty(self.expert_size_per_partition,
self.intermediate_size_per_partition //
(128 * self.interleave[1]),
self.hidden_size * self.interleave[1],
dtype=self.dtype,
device=device),
requires_grad=False)
self.register_parameter("fc2_weight_scale", fc2_weight_scale)
fc31_alpha = nn.Parameter(torch.empty(
self.expert_size_per_partition,
1,
dtype=torch.float32,
device=device),
requires_grad=False)
self.register_parameter("fc31_alpha", fc31_alpha)
fc2_alpha = nn.Parameter(torch.empty(
self.expert_size_per_partition,
1,
dtype=torch.float32,
device=device),
requires_grad=False)
self.register_parameter("fc2_alpha", fc2_alpha)
elif qc.quant_mode.has_nvfp4():
self.has_nvfp4 = True
if self.is_trtllm():
@ -701,6 +818,8 @@ class FusedMoE(nn.Module):
output_dtype = x.dtype
use_fp8_block_scaling = False
use_w4a8_group_scaling = False
weight_dtype = self.w3_w1_weight.dtype
token_selected_experts, token_final_scales = self.routing_method.apply(
router_logits)
@ -749,6 +868,9 @@ class FusedMoE(nn.Module):
elif self.has_fp8_block_scales:
use_fp8_block_scaling = True
elif self.has_w4afp8:
use_w4a8_group_scaling = True
weight_dtype = torch.quint4x2
else:
raise ValueError(
f"unsupported quantization mode: {self.quant_config.quant_mode}"
@ -801,8 +923,8 @@ class FusedMoE(nn.Module):
x,
token_selected_experts,
token_final_scales,
w3_w1_weight,
w2_weight,
w3_w1_weight.view(weight_dtype),
w2_weight.view(weight_dtype),
output_dtype,
quant_scales=quant_scales,
input_sf=x_sf,
@ -813,6 +935,7 @@ class FusedMoE(nn.Module):
cluster_size=cluster_size,
cluster_rank=cluster_rank,
use_fp8_block_scaling=use_fp8_block_scaling,
use_w4a8_group_scaling=use_w4a8_group_scaling,
min_latency_mode=cutlass_min_latency_mode,
)
@ -1151,6 +1274,18 @@ class FusedMoE(nn.Module):
w31_weight_shard)
w31_weight_shard = shuffle_matrix_a(w31_weight_shard,
epilogue_tile_m)
if self.has_w4afp8 and self.sm_version == 89:
import tensorrt_llm.quantization.functional
preprocessor = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm
packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
w31_weight_shard = packer(
unpacker(w31_weight_shard.cpu()).T.contiguous()).to(
w31_weight_shard.device)
w31_weight_shard = preprocessor(w31_weight_shard,
torch.quint4x2,
torch.float8_e4m3fn,
89).view(dst_w3_w1_weight.shape)
dst_w3_w1_weight.copy_(w31_weight_shard.view(
dst_w3_w1_weight.dtype))
@ -1166,6 +1301,19 @@ class FusedMoE(nn.Module):
epilogue_tile_m = 128
w2_weight_shard = shuffle_matrix_a(w2_weight_shard,
epilogue_tile_m)
if self.has_w4afp8 and self.sm_version == 89:
import tensorrt_llm.quantization.functional
preprocessor = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm
packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
w2_weight_shard = packer(
unpacker(w2_weight_shard.cpu()).T.contiguous()).to(
w2_weight_shard.device)
w2_weight_shard = preprocessor(w2_weight_shard, torch.quint4x2,
torch.float8_e4m3fn,
89).view(dst_w2_weight.shape)
dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype))
# Use multi-threading to load expert weights in parallel.
@ -1220,6 +1368,8 @@ class FusedMoE(nn.Module):
self._load_nvfp4_scales(weights)
elif self.quant_config.quant_mode.has_fp8_block_scales():
self._load_fp8_block_scales_scales(weights)
elif self.quant_config.quant_mode.is_int4_weight_only_per_group():
self._load_int4_groupwise_scales(weights)
else:
raise ValueError(
f"unsupported quantization mode: {self.quant_config.quant_mode}"
@ -1549,6 +1699,86 @@ class FusedMoE(nn.Module):
self.fc31_scale_c.data.copy_(self.fc2_input_scale.data *
self.fc31_alpha.data)
def _load_int4_groupwise_scales(self, weights: Dict):
# fc31 scales
assert (len(self.interleave) == 2)
all_w3_input_scales = [
load_weight_shard(weights[f"{expert_id}.w3.input_scale"])
for expert_id in range(self.expert_start, self.expert_end)
]
all_w1_input_scales = [
load_weight_shard(weights[f"{expert_id}.w1.input_scale"])
for expert_id in range(self.expert_start, self.expert_end)
]
all_w3_w1_input_scales = torch.max(torch.stack(all_w3_input_scales),
torch.stack(all_w1_input_scales))
all_w3_w1_input_scales = torch.ones_like(
all_w3_w1_input_scales) * all_w3_w1_input_scales.max()
self.fc31_act_scale.data.copy_(1 / all_w3_w1_input_scales)
self.fc31_alpha.data.copy_(all_w3_w1_input_scales.float())
all_w3_scales = [
load_weight_shard(weights[f"{expert_id}.w3.weight_scale_inv"],
self.tp_size, self.tp_rank,
TensorParallelMode.COLUMN)
for expert_id in range(self.expert_start, self.expert_end)
]
all_w1_scales = [
load_weight_shard(weights[f"{expert_id}.w1.weight_scale_inv"],
self.tp_size, self.tp_rank,
TensorParallelMode.COLUMN)
for expert_id in range(self.expert_start, self.expert_end)
]
all_w3_w1_scales = torch.cat(
[torch.stack(all_w3_scales),
torch.stack(all_w1_scales)], dim=-2)
if self.sm_version == 89:
w3_w1_scales = all_w3_w1_scales.to(torch.float16).view(self.dtype)
else:
w3_w1_scales = all_w3_w1_scales.to(torch.bfloat16).view(self.dtype)
w3_w1_s_shape = w3_w1_scales.shape
w3_w1_scales_interleaved = w3_w1_scales.reshape(
w3_w1_s_shape[0], w3_w1_s_shape[1],
(w3_w1_s_shape[2] // self.interleave[0]), self.interleave[0])
w3_w1_scales_interleaved = w3_w1_scales_interleaved.permute(0, 2, 1, 3)
w3_w1_scales_interleaved = w3_w1_scales_interleaved.reshape(
w3_w1_s_shape[0], w3_w1_s_shape[2] // self.interleave[0],
w3_w1_s_shape[1] * self.interleave[0])
self.fc31_weight_scale.data.copy_(w3_w1_scales_interleaved.contiguous())
# fc2 scales
all_w2_input_scales = [
load_weight_shard(weights[f"{expert_id}.w2.input_scale"])
for expert_id in range(self.expert_start, self.expert_end)
]
all_w2_input_scales = torch.stack(all_w2_input_scales).to(self.dtype)
all_w2_input_scales = torch.ones_like(
all_w2_input_scales) * all_w2_input_scales.max()
self.fc2_act_scale.data.copy_(1 / all_w2_input_scales)
self.fc2_alpha.data.copy_(all_w2_input_scales.float())
all_w2_scales = [
load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"],
self.tp_size, self.tp_rank,
TensorParallelMode.ROW)
for expert_id in range(self.expert_start, self.expert_end)
]
if self.sm_version == 89:
w2_scales = torch.stack(all_w2_scales).to(torch.float16).view(
self.dtype)
else:
w2_scales = torch.stack(all_w2_scales).to(torch.bfloat16).view(
self.dtype)
w2_s_shape = w2_scales.shape
w2_scales_interleaved = w2_scales.reshape(
w2_s_shape[0], w2_s_shape[1], (w2_s_shape[2] // self.interleave[1]),
self.interleave[1])
w2_scales_interleaved = w2_scales_interleaved.permute(0, 2, 1, 3)
w2_scales_interleaved = w2_scales_interleaved.reshape(
w2_s_shape[0], w2_s_shape[2] // self.interleave[1],
w2_s_shape[1] * self.interleave[1])
self.fc2_weight_scale.data.copy_(w2_scales_interleaved.contiguous())
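Both the fc31 and fc2 branches above apply the same reshape/permute to the group scales. A standalone restatement of that transform, with an illustrative helper name and the shape convention read off the code above:

import torch

def interleave_group_scales(scales: torch.Tensor, interleave: int) -> torch.Tensor:
    # scales: [num_experts, out_features, num_groups]. Packs `interleave`
    # consecutive scaling groups of every output row next to each other,
    # which is the layout the copies into fc31_weight_scale / fc2_weight_scale
    # above appear to target.
    e, n, g = scales.shape
    out = scales.reshape(e, n, g // interleave, interleave)
    out = out.permute(0, 2, 1, 3)
    return out.reshape(e, g // interleave, n * interleave).contiguous()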
class FusedMoEQuantScalesFP8(NamedTuple):
fc1_dequant: torch.Tensor
@ -1572,3 +1802,14 @@ class FusedMoEQuantScalesNVFP4(NamedTuple):
class FusedMoEQuantScalesFP8BlockScales(NamedTuple):
fc_weight_scales: torch.Tensor
proj_weight_scales: torch.Tensor
class FusedMoEQuantScalesW4A8(NamedTuple):
scale_1_interleaved: torch.Tensor
scale_2_interleaved: torch.Tensor
pre_quant_scale_1: torch.Tensor
pre_quant_scale_2: torch.Tensor
zero_1: torch.Tensor
zero_2: torch.Tensor
alpha_1: torch.Tensor
alpha_2: torch.Tensor
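How the buffers populated in `_load_int4_groupwise_scales` would plausibly map onto this tuple; the field pairing is inferred from the names and is an assumption, not something this hunk spells out:

def make_w4a8_scales(moe):
    # Zero-point buffers are not shown in this hunk, so None stands in here.
    return FusedMoEQuantScalesW4A8(
        scale_1_interleaved=moe.fc31_weight_scale,
        scale_2_interleaved=moe.fc2_weight_scale,
        pre_quant_scale_1=moe.fc31_act_scale,
        pre_quant_scale_2=moe.fc2_act_scale,
        zero_1=None,
        zero_2=None,
        alpha_1=moe.fc31_alpha,
        alpha_2=moe.fc2_alpha,
    )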

View File

@ -500,8 +500,8 @@ class MOEWeightWrapper(Module):
else:
self.register_parameter('zero', None)
if groupwise_quant_algo & GroupwiseQuantAlgo.PRE_QUANT_SCALE:
self.prequant_scaling_factor = Parameter(shape=(1, in_features),
dtype=dtype)
self.prequant_scaling_factor = Parameter(
shape=(experts_per_node, 1), dtype=dtype)
else:
self.register_parameter('prequant_scaling_factor', None)
if groupwise_quant_algo & GroupwiseQuantAlgo.W4A8_ALPHA:

View File

@ -4,7 +4,8 @@ from typing import Dict, List, Optional
import pytest
import torch
import torch.nn as nn
from utils.util import skip_pre_blackwell, skip_pre_hopper
from utils.util import (skip_neither_ada_nor_hopper_unittest,
skip_pre_blackwell, skip_pre_hopper)
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.model_config import ModelConfig
@ -271,6 +272,136 @@ def test_fused_moe_nvfp4(dtype):
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
@skip_neither_ada_nor_hopper_unittest
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe_w4afp8(dtype):
SEQ_LEN = 4
HIDDEN_SIZE = 768
INTERMEDIATE_SIZE = 640
SCALING_GROUP_SIZE = 128
NUM_EXPERTS = 3
TOP_K = 2
routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda()
affine_coeff = 0.005
weights = {}
for expert_id in range(NUM_EXPERTS):
w1_weight = torch.randint(-128,
127, (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2),
dtype=torch.int8).cuda()
w2_weight = torch.randint(-128,
127, (HIDDEN_SIZE, INTERMEDIATE_SIZE // 2),
dtype=torch.int8).cuda()
w3_weight = torch.randint(-128,
127, (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2),
dtype=torch.int8).cuda()
w1_scale = torch.randn(
(INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE),
dtype=dtype).cuda() * affine_coeff
w2_scale = torch.randn(
(HIDDEN_SIZE, INTERMEDIATE_SIZE // SCALING_GROUP_SIZE),
dtype=dtype).cuda() * affine_coeff
w3_scale = torch.randn(
(INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE),
dtype=dtype).cuda() * affine_coeff
w1_input = torch.randn(1, dtype=torch.float32).cuda() * 0.02
w2_input = w1_input
w3_input = w1_input
weights[f"{expert_id}.w1.weight"] = w1_weight
weights[f"{expert_id}.w2.weight"] = w2_weight
weights[f"{expert_id}.w3.weight"] = w3_weight
weights[f"{expert_id}.w1.weight_scale_inv"] = w1_scale
weights[f"{expert_id}.w2.weight_scale_inv"] = w2_scale
weights[f"{expert_id}.w3.weight_scale_inv"] = w3_scale
weights[f"{expert_id}.w1.input_scale"] = w1_input
weights[f"{expert_id}.w2.input_scale"] = w2_input
weights[f"{expert_id}.w3.input_scale"] = w3_input
quant_config = QuantConfig(quant_algo=QuantAlgo.W4A8_AWQ)
fused_moe = FusedMoE(num_experts=NUM_EXPERTS,
routing_method=routing_method,
hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE,
dtype=dtype,
reduce_results=False,
model_config=ModelConfig(quant_config=quant_config))
fused_moe.load_weights([weights])
fused_moe.cuda()
def ref():
results = torch.zeros_like(x)
selected_experts, final_scales = routing_method.apply(router_logits)
unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
for e_idx in range(NUM_EXPERTS):
mask = selected_experts == e_idx
activated_tokens = mask.sum(1).bool()
act = x[activated_tokens, :]
if act.shape[0] == 0:
continue
final_scale = (final_scales *
mask).sum(1)[activated_tokens].unsqueeze(1)
# weights
w1 = weights[f"{e_idx}.w1.weight"]
w1 = unpacker(w1.cpu()).T.contiguous().cuda()
w2 = weights[f"{e_idx}.w2.weight"]
w2 = unpacker(w2.cpu()).T.contiguous().cuda()
w3 = weights[f"{e_idx}.w3.weight"]
w3 = unpacker(w3.cpu()).T.contiguous().cuda()
w3_w1 = torch.cat([w3, w1], dim=-1)
# scales
s1 = weights[f"{e_idx}.w1.weight_scale_inv"].T.contiguous().cuda()
s2 = weights[f"{e_idx}.w2.weight_scale_inv"].T.contiguous().cuda()
s3 = weights[f"{e_idx}.w3.weight_scale_inv"].T.contiguous().cuda()
s3_s1 = torch.cat([s3, s1], dim=-1)
# prequant / alpha
p1 = weights[f"{e_idx}.w1.input_scale"].cuda()
p2 = weights[f"{e_idx}.w2.input_scale"].cuda()
p3 = weights[f"{e_idx}.w3.input_scale"].cuda()
p3_p1 = max(p1, p3)
act = torch.clamp((act / p3_p1), -448.0,
448.0).to(torch.float8_e4m3fn).to(dtype)
w3_w1 = (w3_w1.float() *
s3_s1.repeat_interleave(128, dim=0).float()).to(dtype)
fc1 = torch.matmul(act, w3_w1) * p3_p1
fc1, gate = fc1.chunk(2, dim=-1)
fc1 = fc1 * torch.nn.functional.silu(gate)
act = torch.clamp((fc1 / p2), -448.0,
448.0).to(torch.float8_e4m3fn).to(dtype)
w2 = (w2.float() *
s2.repeat_interleave(128, dim=0).float()).to(dtype)
fc2 = torch.matmul(act, w2) * p2
results[activated_tokens, :] += (fc2 * final_scale).to(
results.dtype)
return results
AutoTuner.get().clear_cache()
with torch.inference_mode(), autotune():
fused_moe.forward(x, router_logits)
torch.cuda.synchronize()
with torch.inference_mode():
output = fused_moe.forward(x, router_logits)
ref_output = ref()
# Compare the fused kernel output against the per-expert dequantized reference.
torch.cuda.synchronize()
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
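The reference path in the test above reduces to two FP8-emulated GEMMs per expert with groupwise int4 dequantization. A condensed restatement of that per-expert math (illustrative names; it mirrors `ref()`, not the fused kernel):

import torch

def w4a8_expert_ref(x, w3_w1_int8, s3_s1, w2_int8, s2, p31, p2, group=128):
    # FP8 (e4m3) fake quantization: divide by the input scale, clamp to the
    # e4m3 max of 448, and round-trip through float8.
    def fp8(t, scale):
        return torch.clamp(t / scale, -448.0, 448.0).to(torch.float8_e4m3fn).to(t.dtype)

    # Groupwise dequant: each scale row covers `group` consecutive input channels.
    w3_w1 = (w3_w1_int8.float() * s3_s1.repeat_interleave(group, dim=0).float()).to(x.dtype)
    w2 = (w2_int8.float() * s2.repeat_interleave(group, dim=0).float()).to(x.dtype)

    fc1 = torch.matmul(fp8(x, p31), w3_w1) * p31
    fc1, gate = fc1.chunk(2, dim=-1)
    hidden = fc1 * torch.nn.functional.silu(gate)
    return torch.matmul(fp8(hidden, p2), w2) * p2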
class RefGatedMLPFusedMoE(nn.Module):
def __init__(self,

View File

@ -154,14 +154,14 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase):
2**31, (num_experts, n, k // num_weights_in_32_bits),
dtype=torch.int32,
device="cuda")
pre_quant_scale_1 = torch.randn(1,
k,
dtype=activation_dtype,
device="cuda")
pre_quant_scale_2 = torch.randn(1,
n,
dtype=activation_dtype,
device="cuda")
pre_quant_scale_1 = torch.ones(num_experts,
1,
dtype=activation_dtype,
device="cuda")
pre_quant_scale_2 = torch.ones(num_experts,
1,
dtype=activation_dtype,
device="cuda")
scale_1 = torch.randn(num_experts,
k // group_size,
n * 2,
@ -182,14 +182,8 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase):
k,
dtype=activation_dtype,
device="cuda") * 0.01
alpha_1 = torch.randn(num_experts,
1,
dtype=torch.float32,
device="cuda")
alpha_2 = torch.randn(num_experts,
1,
dtype=torch.float32,
device="cuda")
alpha_1 = 1 / pre_quant_scale_1.float()
alpha_2 = 1 / pre_quant_scale_2.float()
preprocessor = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm
unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
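With the change above, alpha is the reciprocal of the per-expert pre-quant scale, so the scaling applied before the FP8 cast is undone after each matmul. A minimal check of that invariant (expert count illustrative):

import torch
pre_quant_scale = torch.ones(8, 1, dtype=torch.bfloat16, device="cuda")
alpha = 1 / pre_quant_scale.float()
assert torch.allclose(alpha * pre_quant_scale.float(), torch.ones_like(alpha))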
@ -242,8 +236,9 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase):
input = inputs_merged[i, :]
fc1_qd = ref_weight_1[expert].cuda().float()
if has_pre_quant:
input = input * pre_quant_scale_1.squeeze()
input = input * pre_quant_scale_1[expert]
if has_alpha:
input[input > 448.0] = 448.0
input = input.to(torch.float8_e4m3fn).float()
fc1_qd = fc1_qd.to(torch.float8_e4m3fn).float()
fc1 = torch.matmul(input, fc1_qd) * alpha_1[expert]
@ -253,8 +248,9 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase):
fc1 = fc1 * torch.nn.functional.silu(gate)
fc2_qd = ref_weight_2[expert].cuda().float()
if has_pre_quant:
fc1 = fc1 * pre_quant_scale_2.squeeze()
fc1 = fc1 * pre_quant_scale_2[expert]
if has_alpha:
fc1[fc1 > 448.0] = 448.0
fc1 = fc1.to(torch.float8_e4m3fn).float()
fc2_qd = fc2_qd.to(torch.float8_e4m3fn).float()
final = torch.matmul(fc1, fc2_qd) * alpha_2[expert]
@ -282,11 +278,6 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase):
name_func=unittest_name_func)
@skip_non_ada_unittest
def test_moe_w4a8(self, m, n, k, experts, dtype, has_pre_quant, has_zero):
# Skip specific problematic case
if m == 1 and n == 14336 and k == 4096 and experts == 8 and dtype == "bfloat16" and has_pre_quant and not has_zero:
self.skipTest(
"Skipping problematic case test_moe_w4a8_1_14336_4096_8_bfloat16_True_False"
)
self._woq_moe_groupwise_matmul(m, n, k, experts, dtype, torch.quint4x2,
has_pre_quant, has_zero, True)