[TRTLLM-3330][feat] Support DeepSeek-R1 W4A8 on Hopper (#4123)

Support DeepSeek-R1 W4A8 on Hopper

Co-authored-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Co-authored-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>

commit 20b42912ce
parent bb17649517
@@ -102,8 +102,9 @@ enum class CutlassTileConfigSM90
    CtaShape128x128x128B,
    CtaShape128x256x128B,

    // CTA configs for M=128
    // CTA configs for M=256
    CtaShape256x128x128B,
    CtaShape256x256x128B,
};

enum class CutlassTileConfigSM100
@@ -204,7 +205,9 @@ enum class TileShape
    TileShape_128x32x128,
    TileShape_128x64x128,
    TileShape_128x128x128,
    TileShape_128x256x128
    TileShape_128x256x128,
    TileShape_256x128x128,
    TileShape_256x256x128
};

template <TileShape Shape_MNK>
@@ -255,6 +258,14 @@ constexpr auto get_tile_shape()
    {
        return cute::Shape<_128, _256, _128>{};
    }
    else if constexpr (Shape_MNK == TileShape::TileShape_256x128x128)
    {
        return cute::Shape<_256, _128, _128>{};
    }
    else if constexpr (Shape_MNK == TileShape::TileShape_256x256x128)
    {
        return cute::Shape<_256, _256, _128>{};
    }
}

static auto get_tile_shape_name(TileShape Shape_MNK)
@@ -303,6 +314,14 @@ static auto get_tile_shape_name(TileShape Shape_MNK)
    {
        return "128x256x128";
    }
    else if (Shape_MNK == TileShape::TileShape_256x128x128)
    {
        return "256x128x128";
    }
    else if (Shape_MNK == TileShape::TileShape_256x256x128)
    {
        return "256x256x128";
    }
    return "Unknown shape";
}

@ -225,9 +225,19 @@ std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(CutlassGemmConfig::C
|
||||
#else
|
||||
if (config & CutlassGemmConfig::GROUPED_GEMM)
|
||||
{
|
||||
return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B};
|
||||
if (config & CutlassGemmConfig::WEIGHT_ONLY)
|
||||
{
|
||||
return {CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B,
|
||||
CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B};
|
||||
}
|
||||
else
|
||||
{
|
||||
return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -240,6 +250,19 @@ std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(CutlassGemmConfig::C
|
||||
#endif
|
||||
}
|
||||
|
||||
bool sm90_supports_coop(CutlassTileConfigSM90 const tile)
|
||||
{
|
||||
#ifdef FAST_BUILD
|
||||
return false;
|
||||
#else
|
||||
std::set<CutlassTileConfigSM90> valid_tiles{CutlassTileConfigSM90::CtaShape128x16x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B,
|
||||
CutlassTileConfigSM90::CtaShape256x128x128B, CutlassTileConfigSM90::CtaShape256x256x128B};
|
||||
return valid_tiles.count(tile) == 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve
|
||||
// compilation speed.
|
||||
bool sm90_supports_mcast_along_m(CutlassTileConfigSM90 const tile)
|
||||
@ -275,37 +298,65 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm90(CutlassGemmConfig::Can
|
||||
std::vector<CutlassGemmConfig> candidate_configs;
|
||||
for (auto const& tile_config : tiles)
|
||||
{
|
||||
CutlassGemmConfig config(
|
||||
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
|
||||
candidate_configs.push_back(config);
|
||||
|
||||
bool const has_m_mcast = sm90_supports_mcast_along_m(tile_config);
|
||||
bool const has_n_mcast = sm90_supports_mcast_along_n(tile_config);
|
||||
if (has_m_mcast)
|
||||
bool const has_w4afp8 = (config & CutlassGemmConfig::WEIGHT_ONLY) && (config & CutlassGemmConfig::GROUPED_GEMM);
|
||||
if (has_w4afp8)
|
||||
{
|
||||
CutlassGemmConfig config(
|
||||
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1);
|
||||
candidate_configs.push_back(config);
|
||||
bool const has_coop_supported = sm90_supports_coop(tile_config);
|
||||
std::set<MainloopScheduleType> mainloop_schedules{MainloopScheduleType::PINGPONG};
|
||||
if (has_coop_supported)
|
||||
{
|
||||
mainloop_schedules.insert(MainloopScheduleType::COOPERATIVE);
|
||||
}
|
||||
auto const epilogue_schedule = EpilogueScheduleType::AUTO;
|
||||
for (auto const& mainloop_schedule : mainloop_schedules)
|
||||
{
|
||||
CutlassGemmConfig candidate(
|
||||
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x1x1);
|
||||
candidate_configs.push_back(candidate);
|
||||
candidate = CutlassGemmConfig(
|
||||
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x1x1);
|
||||
candidate_configs.push_back(candidate);
|
||||
candidate = CutlassGemmConfig(
|
||||
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x2x1);
|
||||
candidate_configs.push_back(candidate);
|
||||
candidate = CutlassGemmConfig(
|
||||
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x2x1);
|
||||
candidate_configs.push_back(candidate);
|
||||
}
|
||||
}
|
||||
|
||||
if (has_n_mcast)
|
||||
else
|
||||
{
|
||||
CutlassGemmConfig config(
|
||||
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1);
|
||||
candidate_configs.push_back(config);
|
||||
}
|
||||
CutlassGemmConfig candidate(
|
||||
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
|
||||
candidate_configs.push_back(candidate);
|
||||
if (has_m_mcast)
|
||||
{
|
||||
CutlassGemmConfig candidate(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
|
||||
ClusterShape::ClusterShape_2x1x1);
|
||||
candidate_configs.push_back(candidate);
|
||||
}
|
||||
|
||||
if (has_m_mcast && has_n_mcast)
|
||||
{
|
||||
CutlassGemmConfig config(
|
||||
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x2x1);
|
||||
candidate_configs.push_back(config);
|
||||
if (has_n_mcast)
|
||||
{
|
||||
CutlassGemmConfig candidate(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
|
||||
ClusterShape::ClusterShape_1x2x1);
|
||||
candidate_configs.push_back(candidate);
|
||||
}
|
||||
|
||||
if (has_m_mcast && has_n_mcast)
|
||||
{
|
||||
CutlassGemmConfig candidate(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
|
||||
ClusterShape::ClusterShape_2x2x1);
|
||||
candidate_configs.push_back(candidate);
|
||||
}
|
||||
}
|
||||
}
|
||||
// add cuda kernel profiler to tactics for weight-only plugins
|
||||
if (config & CutlassGemmConfig::WEIGHT_ONLY)
|
||||
{
|
||||
if (tiles.size() > 0)
|
||||
if (tiles.size() > 0 && !(config & CutlassGemmConfig::GROUPED_GEMM))
|
||||
{
|
||||
CutlassGemmConfig CudaKernelConfig(
|
||||
tiles[0], MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
|
||||
|
||||
@ -206,38 +206,47 @@ const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zer
|
||||
);
|
||||
"""
|
||||
elif operation.gemm_kind == GemmKind.Grouped:
|
||||
# Similar to MixedInput above, we must modify the tags for grouped gemm as CUTLASS library does not have the updated schedules
|
||||
assert operation.mainloop_schedule in [
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
|
||||
]
|
||||
assert operation.epi_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
|
||||
kernel_sched.replace("::Kernel", "::KernelGrouped")
|
||||
epi_sched += "Grouped"
|
||||
if operation.act_type != operation.weight_type:
|
||||
# Mixed MoE GEMM
|
||||
weight_tag = DataTypeTag[operation.weight_type]
|
||||
instantiation = f"""
|
||||
template void sm90_generic_mixed_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {out_tag},
|
||||
{epi_tag}, {cute_cta_shape}, {cute_cga_shape}, {kernel_sched}, {epi_sched}, {quant_op}> (
|
||||
GroupedGemmInput<{act_tag}, {weight_tag}, {out_tag}, {out_tag}>inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size);
|
||||
"""
|
||||
else:
|
||||
# Similar to MixedInput above, we must modify the tags for grouped gemm as CUTLASS library does not have the updated schedules
|
||||
assert operation.mainloop_schedule in [
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
|
||||
]
|
||||
assert operation.epi_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
|
||||
kernel_sched.replace("::Kernel", "::KernelGrouped")
|
||||
epi_sched += "Grouped"
|
||||
|
||||
# arch_tag = f"cutlass::arch::Sm{operation.arch}"
|
||||
arch_tag = f"Sm{operation.arch}"
|
||||
weight_tag = CudaTypeName[operation.weight_type]
|
||||
assert operation.epi_fusion is not None
|
||||
epi_fusion = EpiFusion[operation.epi_fusion]
|
||||
# arch_tag = f"cutlass::arch::Sm{operation.arch}"
|
||||
arch_tag = f"Sm{operation.arch}"
|
||||
weight_tag = CudaTypeName[operation.weight_type]
|
||||
assert operation.epi_fusion is not None
|
||||
epi_fusion = EpiFusion[operation.epi_fusion]
|
||||
|
||||
epi_fusion = epi_fusion.split(':')[-1]
|
||||
epi_tag = epi_tag.split(':')[-1]
|
||||
epi_fusion = epi_fusion.split(':')[-1]
|
||||
epi_tag = epi_tag.split(':')[-1]
|
||||
|
||||
guard_map = {
|
||||
e2m1: "defined(ENABLE_FP4)",
|
||||
DataType.e4m3: "defined(ENABLE_FP8)",
|
||||
DataType.bf16: "defined(ENABLE_BF16)"
|
||||
}
|
||||
guard = guard_map[
|
||||
operation.act_type] if operation.act_type in guard_map else "1"
|
||||
# TODO Revert this once compiler bug is fixed so we can use template instead of macro again
|
||||
# instantiation = f"""
|
||||
# template void tma_warp_specialized_generic_moe_gemm_kernelLauncher<{arch_tag}, {act_tag}, {weight_tag}, {out_tag},
|
||||
# {epi_tag}, {epi_fusion}, {cute_cta_shape}, {cute_cga_shape}, false>
|
||||
# (TmaWarpSpecializedGroupedGemmInput, int, int, cudaStream_t, int*, size_t*);
|
||||
# """
|
||||
instantiation = f"""
|
||||
guard_map = {
|
||||
e2m1: "defined(ENABLE_FP4)",
|
||||
DataType.e4m3: "defined(ENABLE_FP8)",
|
||||
DataType.bf16: "defined(ENABLE_BF16)"
|
||||
}
|
||||
guard = guard_map[
|
||||
operation.act_type] if operation.act_type in guard_map else "1"
|
||||
# TODO Revert this once compiler bug is fixed so we can use template instead of macro again
|
||||
# instantiation = f"""
|
||||
# template void tma_warp_specialized_generic_moe_gemm_kernelLauncher<{arch_tag}, {act_tag}, {weight_tag}, {out_tag},
|
||||
# {epi_tag}, {epi_fusion}, {cute_cta_shape}, {cute_cga_shape}, false>
|
||||
# (TmaWarpSpecializedGroupedGemmInput, int, int, cudaStream_t, int*, size_t*);
|
||||
# """
|
||||
instantiation = f"""
|
||||
#if {guard}\n
|
||||
INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag},
|
||||
{epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, false);\n
|
||||
@ -510,9 +519,57 @@ def generate_sm90_grouped_gemm_operations(is_arch_enabled):
|
||||
return operations
|
||||
|
||||
|
||||
def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled):
|
||||
if not is_arch_enabled:
|
||||
return []
|
||||
arch = 90
|
||||
supported_dtypes = [
|
||||
(DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
|
||||
(DataType.e4m3, DataType.u4, DataType.bf16, DataType.bf16,
|
||||
DataType.bf16),
|
||||
]
|
||||
|
||||
quant_ops = [TrtLlm_QuantOp.finegrained_scale_only]
|
||||
|
||||
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
|
||||
|
||||
M_TILES = [64, 128] # Currently M tile must be 128 for Grouped GEMM
|
||||
N_TILES = [16, 32, 64, 128]
|
||||
K_TILES = [128, 256, 512]
|
||||
cta_shapes_mnk = list(product(M_TILES, N_TILES, K_TILES))
|
||||
|
||||
warp_shape = [0, 0, 0] # ignored except for naming
|
||||
stages = 0 # auto
|
||||
|
||||
cga_shapes = product([1, 2], [1, 2], [1])
|
||||
|
||||
partial_args = product(supported_dtypes, quant_ops, epi_tags,
|
||||
cta_shapes_mnk, cga_shapes)
|
||||
|
||||
operations = list()
|
||||
for dtype_combo, quant_op, epi_tag, cta_shape_mnk, cga_shape in partial_args:
|
||||
use_coop = cta_shape_mnk[0] >= 128
|
||||
mainloop_schedules = [
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.TmaWarpSpecializedPingpong
|
||||
] if use_coop else [KernelScheduleType.TmaWarpSpecializedPingpong]
|
||||
epi_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
for mainloop_schedule in mainloop_schedules:
|
||||
if (cta_shape_mnk[0] == 128 and cta_shape_mnk[1] == 128
|
||||
and mainloop_schedule
|
||||
== KernelScheduleType.TmaWarpSpecializedCooperative):
|
||||
continue
|
||||
moe_gemm_operation = TrtLlm_GemmLauncher(GemmKind.Grouped, arch, *dtype_combo, quant_op, epi_tag, cta_shape_mnk, \
|
||||
warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule)
|
||||
operations.append(moe_gemm_operation)
|
||||
return operations
|
||||
|
||||
|
||||
def generate_sm90_operations(is_arch_enabled):
|
||||
operations = generate_sm90_mixed_gemm_operations()
|
||||
operations.extend(generate_sm90_grouped_gemm_operations(is_arch_enabled))
|
||||
operations.extend(
|
||||
generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled))
|
||||
return operations
|
||||
|
||||
|
||||
@ -699,6 +756,7 @@ if __name__ == "__main__":
|
||||
fpA_intB_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
|
||||
# moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl"
|
||||
moe_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl"
|
||||
moe_mixed_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl"
|
||||
# sm80_moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl"
|
||||
sm80_moe_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl"
|
||||
|
||||
@ -725,20 +783,29 @@ if __name__ == "__main__":
|
||||
is_internal = op.gemm_kind == GemmKind.Grouped
|
||||
return is_internal != args.internal
|
||||
|
||||
def is_mixed_dtype_grouped(op):
|
||||
if isinstance(op, GemmSm80LauncherConfig):
|
||||
return False
|
||||
return (op.act_type != op.weight_type) and (op.gemm_kind
|
||||
== GemmKind.Grouped)
|
||||
|
||||
op_groups = dict()
|
||||
for op in operations:
|
||||
if should_skip(op):
|
||||
continue
|
||||
dict_key = (op.gemm_kind, op.arch, op.cta_shape[0])
|
||||
dict_key = (op.gemm_kind, op.arch, op.cta_shape[0],
|
||||
is_mixed_dtype_grouped(op))
|
||||
op_group = op_groups.get(dict_key, list())
|
||||
op_group.append(op)
|
||||
op_groups[dict_key] = op_group
|
||||
|
||||
file_counter = 1
|
||||
for key, value in op_groups.items():
|
||||
gemm_kind, _, _ = key
|
||||
gemm_kind, _, _, is_mixed_dtype_grouped = key
|
||||
out_file = os.path.join(
|
||||
output_dir, GemmKindNames[gemm_kind],
|
||||
f"cutlass_kernel_file_{file_counter}.generated.cu")
|
||||
write_file(inl_map[key[:2]], value, out_file)
|
||||
inl_file = [moe_mixed_gemm_inl
|
||||
] if is_mixed_dtype_grouped else inl_map[key[:2]]
|
||||
write_file(inl_file, value, out_file)
|
||||
file_counter += 1
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bcea962f0c9ab2efb4cd6171be456b1c7f68e31d2a257c4eee6b3e9f5e560904
|
||||
size 47910208
|
||||
oid sha256:723334a99a2f23bd16f50e69c2a3f21a06a06a41d0eb2ebe100337cbc0907c1a
|
||||
size 52931896
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
a6bcad94c12cb55cbabc5fb30e7a4adb9e6906cc52cb285a9dd42aa71f7760e3 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
commit 9f0fabbb7f7f678fe34bb0eeed756869676d9304
|
||||
26da3daf623e613ccb02765cfb19b4ad9e19888a2fbfb0be7d2cfb96735c8a13 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
commit 1c6e86206675670e0f37ebf40dcc2562d56039ca
|
||||
|
||||
@ -211,10 +211,32 @@ struct TmaWarpSpecializedGroupedGemmInput
|
||||
LayoutScalingFactorsA* fp4_block_scaling_factors_stride_A = nullptr;
|
||||
LayoutScalingFactorsB* fp4_block_scaling_factors_stride_B = nullptr;
|
||||
|
||||
struct INT4GroupwiseParams
|
||||
{
|
||||
constexpr static int group_size = 128; // Unused, hard-coded to 128
|
||||
bool enabled = false;
|
||||
using SFA = __nv_bfloat16;
|
||||
using SFB = __nv_bfloat16;
|
||||
using ProblemShapeInt = cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;
|
||||
using LayoutSFA = typename cutlass::layout::ColumnMajor;
|
||||
using LayoutSFB = typename cutlass::layout::ColumnMajor; // Unused
|
||||
using StrideSFA = cute::Stride<cute::Int<1>, int64_t, int64_t>;
|
||||
using StrideSFB = cute::Stride<cute::Int<1>, int64_t, int64_t>; // Unused
|
||||
StrideSFA* stride_s_a = nullptr;
|
||||
StrideSFB* stride_s_b = nullptr; // Unused
|
||||
const SFA** ptr_s_a = nullptr;
|
||||
const SFA** ptr_z_a = nullptr; // Unused
|
||||
const SFB** ptr_s_b = nullptr; // Unused
|
||||
const SFB** ptr_z_b = nullptr; // Unused
|
||||
ProblemShapeInt shape{};
|
||||
};
|
||||
|
||||
INT4GroupwiseParams int4_groupwise_params;
|
||||
|
||||
uint8_t* gemm_workspace = nullptr;
|
||||
size_t gemm_workspace_size = 0;
|
||||
|
||||
static std::array<size_t, 14> workspaceBuffers(int num_experts);
|
||||
static std::array<size_t, 17> workspaceBuffers(int num_experts);
|
||||
|
||||
static size_t workspaceSize(int num_experts);
|
||||
|
||||
@ -248,9 +270,13 @@ public:
|
||||
MoeGemmRunner();
|
||||
|
||||
#if defined(ENABLE_FP8)
|
||||
static constexpr bool use_fp8 = std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
|
||||
static constexpr bool use_fp8 = (std::is_same_v<T, __nv_fp8_e4m3>
|
||||
|| std::is_same_v<T, __nv_fp8_e5m2>) &&!std::is_same_v<WeightType, cutlass::uint4b_t>;
|
||||
static constexpr bool use_w4afp8
|
||||
= std::is_same_v<T, __nv_fp8_e4m3> && std::is_same_v<WeightType, cutlass::uint4b_t>;
|
||||
#else
|
||||
static constexpr bool use_fp8 = false;
|
||||
static constexpr bool use_w4afp8 = false;
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_FP4)
|
||||
|
||||
@ -558,6 +558,7 @@ private:
|
||||
static TmaWarpSpecializedGroupedGemmInput computeStridesTmaWarpSpecialized(int64_t const* expert_first_token_offset,
|
||||
TmaWarpSpecializedGroupedGemmInput layout_info, int64_t num_tokens, int64_t expanded_num_tokens, int64_t gemm_n,
|
||||
int64_t gemm_k, int const num_experts_per_node, T const* in, WeightType const* weights,
|
||||
TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::SFA const* w4a8_weight_scale_flat,
|
||||
float const* fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_scale_flat,
|
||||
QuantParams::FP4Inputs::GemmInputs fp4_inputs, T const* bias, UnfusedGemmOutputType* output,
|
||||
cudaStream_t stream);
|
||||
@ -622,6 +623,7 @@ private:
|
||||
QuantParams& quant_params, cudaStream_t stream);
|
||||
|
||||
T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
|
||||
int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr,
|
||||
int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream);
|
||||
|
||||
CubKeyValueSorter sorter_;
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e47924159db0476a3d8026e27514eea07e4b4db690d6f334ef05c41a235014cf
|
||||
size 47540524
|
||||
oid sha256:a70cc6672a1a083c7bca7f8efe4331b84535e79a609cbae7a8b289a4cbb3725b
|
||||
size 52534712
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
a4cd97f177fc4c582d8ab2dfd10b7428b33154ddbd4d9f734cb561ba15e552e7 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
commit 9f0fabbb7f7f678fe34bb0eeed756869676d9304
|
||||
bc4a32343119b018d87c5716986055ce5f149e0c8d695fcd591581310c1f6066 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
commit 1c6e86206675670e0f37ebf40dcc2562d56039ca
|
||||
|
||||
@ -148,5 +148,128 @@ INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16, __nv_fp8_e4m3);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename T_in, typename T_out, int kProcessRows, typename AccessType>
|
||||
__global__ void apply_per_expert_scale(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale,
|
||||
int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols)
|
||||
{
|
||||
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
|
||||
T_in act_vec[kElems];
|
||||
int col_offset = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int row_offset = blockIdx.y;
|
||||
int expert_idx = permuted_token_selected_experts[row_offset];
|
||||
T_in scale = per_expert_scale[expert_idx];
|
||||
if (col_offset * kElems >= cols || row_offset * kProcessRows >= rows)
|
||||
return;
|
||||
if (num_valid_tokens_ptr && (row_offset * kProcessRows >= *num_valid_tokens_ptr))
|
||||
return;
|
||||
act += row_offset * kProcessRows * cols;
|
||||
smoothed_act += row_offset * kProcessRows * cols;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kProcessRows; ++i)
|
||||
{
|
||||
*reinterpret_cast<AccessType*>(act_vec) = reinterpret_cast<AccessType const*>(act + i * cols)[col_offset];
|
||||
if constexpr ((std::is_same_v<T_in, half>
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|
||||
|| std::is_same_v<T_in, __nv_bfloat16>
|
||||
#endif
|
||||
) &&(kElems % 2 == 0))
|
||||
{
|
||||
using Vec2 = typename Vec2Type<T_in>::type;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kElems; j += 2)
|
||||
{
|
||||
if constexpr (std::is_same_v<T_in, half>)
|
||||
{
|
||||
*reinterpret_cast<Vec2*>(act_vec + j)
|
||||
= __hmul2(*reinterpret_cast<Vec2*>(act_vec + j), __half2half2(scale));
|
||||
}
|
||||
else
|
||||
{
|
||||
*reinterpret_cast<Vec2*>(act_vec + j)
|
||||
= __hmul2(*reinterpret_cast<Vec2*>(act_vec + j), __bfloat162bfloat162(scale));
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kElems; ++j)
|
||||
{
|
||||
act_vec[j] = static_cast<T_in>(static_cast<float>(act_vec[j]) * static_cast<float>(scale));
|
||||
}
|
||||
}
|
||||
if constexpr (std::is_same_v<T_in, T_out>)
|
||||
{
|
||||
reinterpret_cast<AccessType*>(smoothed_act + i * cols)[col_offset]
|
||||
= *reinterpret_cast<AccessType*>(act_vec);
|
||||
}
|
||||
else
|
||||
{
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kElems; ++j)
|
||||
{
|
||||
(smoothed_act + i * cols)[col_offset * kElems + j] = static_cast<T_out>(act_vec[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T_in, typename T_out, int kProcessRows, typename AccessType = float4>
|
||||
void apply_per_expert_scale_kernel_launcher_(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale,
|
||||
int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols,
|
||||
cudaStream_t stream = 0)
|
||||
{
|
||||
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
|
||||
dim3 block(128);
|
||||
dim3 grid((cols / kElems + block.x - 1) / block.x, (rows + kProcessRows - 1) / kProcessRows);
|
||||
apply_per_expert_scale<T_in, T_out, kProcessRows, AccessType><<<grid, block, 0, stream>>>(
|
||||
smoothed_act, act, per_expert_scale, permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols);
|
||||
}
|
||||
|
||||
template <typename T_in, typename T_out>
|
||||
void apply_per_expert_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale,
|
||||
int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
int elems = rows * cols;
|
||||
if (elems < 2048 * 2048)
|
||||
{
|
||||
apply_per_expert_scale_kernel_launcher_<T_in, T_out, 1, float4>(smoothed_act, act, per_expert_scale,
|
||||
permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream);
|
||||
}
|
||||
else if (elems < 4096 * 4096)
|
||||
{
|
||||
apply_per_expert_scale_kernel_launcher_<T_in, T_out, 4, float4>(smoothed_act, act, per_expert_scale,
|
||||
permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream);
|
||||
}
|
||||
else if (elems < 8192 * 8192)
|
||||
{
|
||||
apply_per_expert_scale_kernel_launcher_<T_in, T_out, 8, float4>(smoothed_act, act, per_expert_scale,
|
||||
permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_per_expert_scale_kernel_launcher_<T_in, T_out, 16, float4>(smoothed_act, act, per_expert_scale,
|
||||
permuted_token_selected_experts, num_valid_tokens_ptr, rows, cols, stream);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_PEREXPERT_SCALE(T_in, T_out) \
|
||||
template void apply_per_expert_scale_kernel_launcher<T_in, T_out>(T_out * smoothed_act, T_in const* act, \
|
||||
T_in const* per_expert_scale, int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, \
|
||||
int rows, int cols, cudaStream_t stream)
|
||||
|
||||
INSTANTIATE_PEREXPERT_SCALE(half, half);
|
||||
#if defined(ENABLE_FP8)
|
||||
INSTANTIATE_PEREXPERT_SCALE(half, __nv_fp8_e4m3);
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_BF16)
|
||||
INSTANTIATE_PEREXPERT_SCALE(__nv_bfloat16, __nv_bfloat16);
|
||||
#if defined(ENABLE_FP8)
|
||||
INSTANTIATE_PEREXPERT_SCALE(__nv_bfloat16, __nv_fp8_e4m3);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -39,5 +39,10 @@ template <typename T_in, typename T_out = T_in>
|
||||
void apply_per_channel_scale_kernel_launcher(
|
||||
T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, int cols, cudaStream_t stream = 0);
|
||||
|
||||
template <typename T_in, typename T_out = T_in>
|
||||
void apply_per_expert_scale_kernel_launcher(T_out* smoothed_act, T_in const* act, T_in const* per_expert_scale,
|
||||
int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr, int rows, int cols,
|
||||
cudaStream_t stream = 0);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -109,6 +109,7 @@ public:
|
||||
case at::ScalarType::Bool: return IBuffer::DataType::kBOOL;
|
||||
case at::ScalarType::Float8_e4m3fn: return IBuffer::DataType::kFP8;
|
||||
case at::ScalarType::BFloat16: return IBuffer::DataType::kBF16;
|
||||
case at::ScalarType::QUInt4x2: return IBuffer::DataType::kINT4;
|
||||
default: TLLM_THROW("unsupported data type");
|
||||
}
|
||||
}
|
||||
|
||||
@ -84,12 +84,13 @@ public:
|
||||
};
|
||||
|
||||
FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype,
|
||||
bool use_fp8_block_scaling)
|
||||
bool use_fp8_block_scaling, bool use_w4a8_group_scaling)
|
||||
{
|
||||
mActivationDtype = activation_dtype;
|
||||
mWeightDtype = weight_dtype;
|
||||
mOutputDtype = output_dtype;
|
||||
mUseFp8BlockScaling = use_fp8_block_scaling;
|
||||
mUseW4A8GroupScaling = use_w4a8_group_scaling;
|
||||
mInnerDimMultiplier = 1;
|
||||
|
||||
// keep consistent with cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp
|
||||
@ -132,6 +133,44 @@ public:
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if (isInt4Quant())
|
||||
{
|
||||
mInnerDimMultiplier = 2;
|
||||
if (mActivationDtype == c10::ScalarType::Half)
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
if (mUseW4A8GroupScaling)
|
||||
{
|
||||
mKernelRunner
|
||||
= std::make_unique<kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>>();
|
||||
}
|
||||
else
|
||||
{
|
||||
mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
|
||||
}
|
||||
#else
|
||||
mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
|
||||
#endif
|
||||
}
|
||||
#ifdef ENABLE_BF16
|
||||
else if (mActivationDtype == c10::ScalarType::BFloat16)
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
if (mUseW4A8GroupScaling)
|
||||
{
|
||||
mKernelRunner = std::make_unique<
|
||||
kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>>();
|
||||
}
|
||||
else
|
||||
{
|
||||
mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
|
||||
}
|
||||
#else
|
||||
mKernelRunner = std::make_shared<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if (!mKernelRunner)
|
||||
{
|
||||
C10_THROW_ERROR_FORMATTED(Error,
|
||||
@ -350,6 +389,7 @@ public:
|
||||
int64_t const num_rows = input.sizes()[0];
|
||||
int64_t const hidden_size = fc2_expert_weights.sizes()[1];
|
||||
int64_t const inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier;
|
||||
int64_t const group_size = isInt4Quant() ? 128 : -1;
|
||||
int const num_experts = static_cast<int>(fc2_expert_weights.sizes()[0] * ep_size);
|
||||
|
||||
// Get specific profile configs according to the profile_id.
|
||||
@ -371,14 +411,14 @@ public:
|
||||
static_cast<int>(tp_rank), static_cast<int>(ep_size), static_cast<int>(ep_rank),
|
||||
static_cast<int>(cluster_size), static_cast<int>(cluster_rank));
|
||||
|
||||
int const GROUP_SIZE = -1;
|
||||
bool const USE_BIAS = false;
|
||||
bool const USE_LORA = false;
|
||||
mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
|
||||
tensorrt_llm::runtime::TorchUtils::dataType(mActivationDtype),
|
||||
tensorrt_llm::runtime::TorchUtils::dataType(
|
||||
mUseW4A8GroupScaling ? at::ScalarType::Float8_e4m3fn : mActivationDtype),
|
||||
tensorrt_llm::runtime::TorchUtils::dataType(mWeightDtype),
|
||||
tensorrt_llm::runtime::TorchUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
|
||||
hidden_size, inter_size, GROUP_SIZE, tensorrt_llm::ActivationType::Swiglu, USE_BIAS, USE_LORA,
|
||||
hidden_size, inter_size, group_size, tensorrt_llm::ActivationType::Swiglu, USE_BIAS, USE_LORA,
|
||||
min_latency_mode, parallelism_config);
|
||||
|
||||
freeProfileWorkspace();
|
||||
@ -412,6 +452,7 @@ private:
|
||||
char* mProfileWorkspace = nullptr;
|
||||
|
||||
bool mUseFp8BlockScaling = false;
|
||||
bool mUseW4A8GroupScaling = false;
|
||||
|
||||
using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
|
||||
std::vector<Profile> mAllProfiles;
|
||||
@ -457,10 +498,9 @@ private:
|
||||
int num_experts, int experts_per_token, tensorrt_llm::ActivationType activation_type,
|
||||
kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
|
||||
{
|
||||
size_t moe_workspace_size
|
||||
= mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts, experts_per_token,
|
||||
activation_type, parallelismConfig, /* use_lora */ false, mUseFp8BlockScaling, min_latency_mode,
|
||||
/* hasExpertPrequantScales */ false);
|
||||
size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts,
|
||||
experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseFp8BlockScaling,
|
||||
min_latency_mode, mUseW4A8GroupScaling);
|
||||
size_t src_to_dest_map_size = experts_per_token * num_rows * sizeof(int);
|
||||
|
||||
std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
|
||||
@ -561,6 +601,28 @@ private:
|
||||
return kernels::QuantParams::FP8BlockScaling(
|
||||
static_cast<float const*>(fc1_scales.data_ptr()), static_cast<float const*>(fc2_scales.data_ptr()));
|
||||
}
|
||||
else if (isInt4Quant())
|
||||
{
|
||||
TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for INT4 quantization");
|
||||
TORCH_CHECK(quant_scales.value().size() == 8, "Expecting 8 quant scales for INT4 quantization");
|
||||
auto& fc1_weight_scales = quant_scales.value()[0];
|
||||
auto& fc2_weight_scales = quant_scales.value()[1];
|
||||
auto& fc1_act_scales = quant_scales.value()[2];
|
||||
auto& fc2_act_scales = quant_scales.value()[3];
|
||||
auto& fc1_weight_zeros = quant_scales.value()[4];
|
||||
auto& fc2_weight_zeros = quant_scales.value()[5];
|
||||
auto& fc1_alpha = quant_scales.value()[6];
|
||||
auto& fc2_alpha = quant_scales.value()[7];
|
||||
int group_size = 128;
|
||||
return kernels::QuantParams::GroupWise(group_size, static_cast<void const*>(fc1_weight_scales.data_ptr()),
|
||||
static_cast<void const*>(fc2_weight_scales.data_ptr()),
|
||||
static_cast<void const*>(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() : nullptr),
|
||||
static_cast<void const*>(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() : nullptr),
|
||||
static_cast<void const*>(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr),
|
||||
static_cast<void const*>(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr),
|
||||
static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
|
||||
static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
|
||||
}
|
||||
else
|
||||
{
|
||||
return kernels::QuantParams{};
|
||||
@ -577,6 +639,16 @@ private:
|
||||
{
|
||||
return mWeightDtype == c10::ScalarType::Long;
|
||||
}
|
||||
|
||||
bool isInt4Quant() const
|
||||
{
|
||||
return mWeightDtype == c10::ScalarType::QUInt4x2;
|
||||
}
|
||||
|
||||
bool isW4AFp8Quant() const
|
||||
{
|
||||
return mActivationDtype == c10::ScalarType::Float8_e4m3fn && isInt4Quant();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace torch_ext
|
||||
@ -584,7 +656,7 @@ private:
|
||||
TORCH_LIBRARY(trtllm, m)
|
||||
{
|
||||
m.class_<torch_ext::FusedMoeRunner>("FusedMoeRunner")
|
||||
.def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool>())
|
||||
.def(torch::init<c10::ScalarType, c10::ScalarType, c10::ScalarType, bool, bool>())
|
||||
.def("run_gemm_profile", &torch_ext::FusedMoeRunner::runGemmProfile)
|
||||
.def("get_tactic_num", &torch_ext::FusedMoeRunner::getTacticNum)
|
||||
.def("run_moe", &torch_ext::FusedMoeRunner::runMoe)
|
||||
|
||||
@@ -32,21 +32,22 @@ Please refer to [this guide](https://nvidia.github.io/TensorRT-LLM/installation/
- [DeepGEMM](#deepgemm)
- [FlashMLA](#flashmla)
- [FP8 KV Cache and MLA](#fp8-kv-cache-and-mla)
- [W4AFP8](#w4afp8)
- [Notes and Troubleshooting](#notes-and-troubleshooting)


## Hardware Requirements

DeepSeek-V3 has 671B parameters, which needs about 671GB of GPU memory for the FP8 weights, plus additional memory for activation tensors and KV cache.
The minimum hardware requirements for running DeepSeek V3/R1 FP8&FP4 are listed as follows.
The minimum hardware requirements for running DeepSeek V3/R1 at FP8/FP4/W4A8 are listed as follows.

| GPU | DeepSeek-V3/R1 FP8 | DeepSeek-V3/R1 FP4 |
| -------- | ------- | -- |
| H100 80GB | 16 | N/A |
| H20 141GB | 8 | N/A |
| H20 96GB | 8 | N/A |
| H200 | 8 | N/A |
| B200/GB200| Not supported yet, WIP | 4 (8 GPUs is recommended for best perf) |
| GPU | DeepSeek-V3/R1 FP8 | DeepSeek-V3/R1 FP4 | DeepSeek-V3/R1 W4A8 |
| -------- | ------- | -- | -- |
| H100 80GB | 16 | N/A | 8 |
| H20 141GB | 8 | N/A | 4 |
| H20 96GB | 8 | N/A | 4 |
| H200 | 8 | N/A | 4 |
| B200/GB200| Not supported yet, WIP | 4 (8 GPUs is recommended for best perf) | Not supported yet, WIP |

Ampere architecture (SM80 & SM86) is not supported.

@@ -566,6 +567,74 @@ pytorch_backend_config:
  # ...
```
### W4AFP8

TensorRT-LLM supports W(INT)4-A(FP)8 for DeepSeek on __Hopper__. For MoE, activations and weights are quantized at per-tensor and per-group (1x128) granularity respectively, and FP8 block scaling is preserved for dense layers.

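To make the recipe concrete, here is a minimal sketch of the arithmetic involved (illustrative only; the actual conversion is implemented in `examples/quantization/quantize_mixed_precision_moe.py`, added by this commit): each 1x128 weight group gets one INT4 scale derived from the group's absolute maximum, and each MoE activation tensor gets one FP8 scale derived from its calibrated amax.

```python
import torch

def int4_groupwise_quantize(weight: torch.Tensor, group_size: int = 128):
    """Symmetric per-group (1x128) INT4 quantization of an [N, K] weight (sketch)."""
    n, k = weight.shape
    grouped = weight.float().view(n, k // group_size, group_size)
    # One scale per group, mapping the group amax onto the INT4 range [-8, 7].
    scales = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    quant = torch.clamp(torch.round(grouped / scales), min=-8, max=7).to(torch.int8)
    return quant.view(n, k), scales.squeeze(-1)

def fp8_per_tensor_act_scale(act_amax: torch.Tensor) -> torch.Tensor:
    """Per-tensor FP8 activation scale: map the calibrated amax onto the E4M3 maximum (448)."""
    return act_amax / 448.0
```
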
We provide a pre-quantized checkpoint for DeepSeek-R1 W4AFP8 at the [HF model hub](https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8).

```bash
python quickstart_advanced.py --model_dir <W4AFP8 Checkpoint> --tp_size 8
```

Or you can follow the steps below to generate one yourself.

#### Activation calibration

[ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is used to calibrate the activations of the MoE layers. We provide a calibrated file at the [HF model hub](https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main/act_scales.safetensors), or you can run the following commands to generate it yourself.

```bash
# Make sure there are enough GPU resources (8x H200) to run the following commands
PATH_OF_DEEPSEEK_R1=/llm-models/DeepSeek-R1/DeepSeek-R1

# Install ModelOpt from source
git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer/ && cd modelopt
pip install "nvidia-modelopt[all]" -U --extra-index-url https://pypi.nvidia.com

# Clone the DeepSeek-V3 (base model of R1) GitHub repository for FP8 inference
git clone https://github.com/deepseek-ai/DeepSeek-V3.git && cd DeepSeek-V3 && git checkout 1398800

# Convert the HF checkpoint to a specific format for DeepSeek
python inference/convert.py --hf-ckpt-path $PATH_OF_DEEPSEEK_R1 --save-path ds_r1 --n-experts 256 --model-parallel 8 && cd ..

# Do per-tensor FP8 calibration
torchrun --nproc-per-node 8 --master_port=12346 ptq.py --model_path DeepSeek-V3/ds_r1 --config DeepSeek-V3/inference/configs/config_671B.json --quant_cfg FP8_DEFAULT_CFG --output_path ds_r1_fp8_per_tensor_calibration && cd ../..
```

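If you want to sanity-check the calibration result before assembling the checkpoint, the per-tensor input scales in the pre-extracted `act_scales.safetensors` can be listed with a few lines of Python (an illustrative snippet; the key layout shown is the one produced by `get_scales_from_amax` in the conversion script below):

```python
from safetensors import safe_open

# Each MoE expert projection carries one per-tensor input scale, e.g.
#   model.layers.10.mlp.experts.3.w1.input_scale
with safe_open("act_scales.safetensors", framework="pt") as f:
    for name in list(f.keys())[:5]:
        print(name, f.get_tensor(name).shape)
```
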
#### Weight quantization and assembling

You can run the following bash script to quantize the weights and generate the full checkpoint.

```bash
#!/bin/bash
HF_MODEL_DIR=/models/DeepSeek-R1/DeepSeek-R1/
OUTPUT_DIR=/workspace/ckpt/
# Safetensors or ModelOpt exported FP8 checkpoint path is accepted
# e.g. ACT_SCALES=ds_r1_fp8_per_tensor_calibration
ACT_SCALES=/workspace/act_scales.safetensors

if [ ! -d "convert_logs" ]; then
    mkdir convert_logs
fi

pids=()
for i in 0 1 2 3 4 5 6 7
do
    python examples/quantization/quantize_mixed_precision_moe.py --model_dir $HF_MODEL_DIR --output_dir $OUTPUT_DIR --act_scales $ACT_SCALES --parts 9 --rank $i > convert_logs/log_$i 2>&1 &
    pids+=($!)
done

python examples/quantization/quantize_mixed_precision_moe.py --model_dir $HF_MODEL_DIR --output_dir $OUTPUT_DIR --act_scales $ACT_SCALES --parts 9 --rank 8 > convert_logs/log_8 2>&1
pids+=($!)

for pid in ${pids[@]}; do
    wait $pid
done

echo "All processes completed!"
```

The converted checkpoint can then be used as `<YOUR_MODEL_DIR>` and consumed by the other commands.
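For instance, reusing the quickstart command from the beginning of this section with the assembled checkpoint directory (assuming the `OUTPUT_DIR` from the script above):

```bash
python quickstart_advanced.py --model_dir /workspace/ckpt/ --tp_size 8
```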


## Notes and Troubleshooting

- **Model Directory:** Update `<YOUR_MODEL_DIR>` with the actual path where the model weights reside.

examples/quantization/quantize_mixed_precision_moe.py (new file, 305 lines)
@ -0,0 +1,305 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
from safetensors.torch import safe_open, save_file
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_dir',
|
||||
type=str,
|
||||
required=True,
|
||||
help='HF checkpoint path')
|
||||
parser.add_argument('--output_dir',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Save path')
|
||||
parser.add_argument(
|
||||
'--act_scales',
|
||||
type=str,
|
||||
required=True,
|
||||
help=
|
||||
'ModelOpt calibrated checkpoint dir or extracted safetensors for activation scales'
|
||||
)
|
||||
    parser.add_argument('--parts',
                        type=int,
                        default=1,
                        help='divide all safetensors into parts')
    parser.add_argument('--rank',
                        type=int,
                        default=0,
                        help='which part to quantize')
    args = parser.parse_args()
    return args
|
||||
|
||||
|
||||
def load_and_preprocess_state_dict(modelopt_state_root, world_size=8):
|
||||
state_dict_list = []
|
||||
# load amax from state dict
|
||||
for rank in range(world_size):
|
||||
state_dict_list.append(
|
||||
torch.load(
|
||||
f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt",
|
||||
map_location="cuda:0"))
|
||||
    # calculate the max across all TP ranks
    merged_state_dict = state_dict_list[0]
    for rank in range(world_size):
        for key, amax in state_dict_list[rank].items():
            if key in merged_state_dict:
                amax = torch.max(amax.to(0), merged_state_dict[key].to(0))
            merged_state_dict[key] = amax.to(0)
|
||||
|
||||
    mapping = {
        "ffn.shared_experts.w1": "mlp.shared_experts.gate_proj",
        "ffn.shared_experts.w2": "mlp.shared_experts.down_proj",
        "ffn.shared_experts.w3": "mlp.shared_experts.up_proj",
        "ffn.shared_experts": "mlp.shared_experts",
        "ffn.w1": "mlp.gate_proj",
        "ffn.w2": "mlp.down_proj",
        "ffn.w3": "mlp.up_proj",
        "head": "lm_head",
        "attn": "self_attn",
    }
|
||||
new_dict = {}
|
||||
for k, v in merged_state_dict.items():
|
||||
new_key = k.replace("layers", "model.layers")
|
||||
for original_pattern, replace_pattern in mapping.items():
|
||||
new_key = new_key.replace(original_pattern, replace_pattern)
|
||||
# ffn.experts.xx.w1/w2/w3- > mlp.experts.xx.gate_proj/down_proj/up_proj
|
||||
new_key = re.sub(r"ffn\.experts\.(\d+)\.w1",
|
||||
r"mlp.experts.\1.gate_proj", new_key)
|
||||
new_key = re.sub(r"ffn\.experts\.(\d+)\.w2",
|
||||
r"mlp.experts.\1.down_proj", new_key)
|
||||
new_key = re.sub(r"ffn\.experts\.(\d+)\.w3", r"mlp.experts.\1.up_proj",
|
||||
new_key)
|
||||
new_dict[new_key] = v
|
||||
|
||||
merged_state_dict.clear()
|
||||
merged_state_dict.update(new_dict)
|
||||
|
||||
# set amax for modules to be fused and make sure they share the same input
|
||||
for key, amax in merged_state_dict.items():
|
||||
if "up_proj" in key:
|
||||
gate_proj_key = key.replace("up_proj", "gate_proj")
|
||||
if "weight_quantizer" in key:
|
||||
fused_amax = torch.max(amax, merged_state_dict[gate_proj_key])
|
||||
merged_state_dict[key] = fused_amax
|
||||
merged_state_dict[gate_proj_key] = fused_amax
|
||||
elif "input_quantizer" in key:
|
||||
assert amax == merged_state_dict[gate_proj_key]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return merged_state_dict
|
||||
|
||||
|
||||
def get_scales_from_amax(start_layer, end_layer, renamed_state_dict):
|
||||
weight_name_dict = {"gate_proj": 1, "down_proj": 2, "up_proj": 3}
|
||||
scales = {}
|
||||
for layer_idx in range(start_layer, end_layer):
|
||||
amax_keys_per_layer = [
|
||||
x for x in renamed_state_dict.keys()
|
||||
if (x.startswith(f'model.layers.{layer_idx}.mlp.experts.')
|
||||
and x.endswith(".input_quantizer._amax"))
|
||||
]
|
||||
for k in amax_keys_per_layer:
|
||||
expert_idx = int(k.split('.')[5])
|
||||
weight_idx = weight_name_dict[k.split('.')[6]]
|
||||
val = renamed_state_dict[k]
|
||||
scales[
|
||||
f'model.layers.{layer_idx}.mlp.experts.{expert_idx}.w{weight_idx}.input_scale'] = val.unsqueeze(
|
||||
0) / 448
|
||||
|
||||
return scales
|
||||
|
||||
|
||||
def quantize_fp8_block_scale_to_int4(fp8_tensor, fp8_scale):
|
||||
group_size = 128
|
||||
blocked_tensor = fp8_tensor.view(fp8_tensor.shape[0] // 128, 128,
|
||||
fp8_tensor.shape[1] // 128,
|
||||
128).to(torch.float32)
|
||||
dequant_tensor = (blocked_tensor *
|
||||
fp8_scale.unsqueeze(1).unsqueeze(3)).view(
|
||||
fp8_tensor.shape[0],
|
||||
fp8_tensor.shape[1] // group_size,
|
||||
group_size).to(torch.bfloat16).to(torch.float32)
|
||||
scale_tensor = torch.abs(dequant_tensor).max(dim=2).values / 7
|
||||
quant_tensor = torch.clamp(torch.round(
|
||||
(dequant_tensor / scale_tensor.unsqueeze(-1))),
|
||||
min=-8,
|
||||
max=7)
|
||||
quant_tensor = quant_tensor.to(torch.int8)
|
||||
return quant_tensor.view(fp8_tensor.shape), scale_tensor
|
||||
|
||||
|
||||
def main(args):
|
||||
model_dir = args.model_dir
|
||||
output_dir = args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
num_gpus = torch.cuda.device_count()
|
||||
torch.cuda.set_device(args.rank % num_gpus)
|
||||
|
||||
model_index_file = os.path.join(model_dir, "model.safetensors.index.json")
|
||||
with open(model_index_file, "r") as f:
|
||||
model_index = json.load(f)
|
||||
weight_map = model_index["weight_map"]
|
||||
|
||||
processed_files = {}
|
||||
for tensor_name in list(weight_map.keys()):
|
||||
if tensor_name not in weight_map:
|
||||
continue
|
||||
file_name = weight_map[tensor_name]
|
||||
if file_name in processed_files:
|
||||
continue
|
||||
processed_files[file_name] = safe_open(os.path.join(
|
||||
model_dir, file_name),
|
||||
"pt",
|
||||
device="cuda")
|
||||
|
||||
with open(os.path.join(model_dir, "config.json"), 'r') as file:
|
||||
config = json.load(file)
|
||||
|
||||
num_layer = config['num_hidden_layers']
|
||||
part_layer = (num_layer + args.parts - 1) // args.parts
|
||||
start_layer = args.rank * part_layer
|
||||
end_layer = min(num_layer, args.rank * part_layer + part_layer)
|
||||
|
||||
def get_tensor(name):
|
||||
if name not in weight_map:
|
||||
return None
|
||||
ff = weight_map[name]
|
||||
safetensors_loader = processed_files[ff]
|
||||
return safetensors_loader.get_tensor(name).cuda()
|
||||
|
||||
def get_file_name(layer):
|
||||
rank = layer // part_layer
|
||||
return "model-%05d-of-%05d.safetensors" % (rank, args.parts)
|
||||
|
||||
new_safetensors = {}
|
||||
new_json = {}
|
||||
new_json['weight_map'] = {}
|
||||
new_json['metadata'] = {}
|
||||
for key in tqdm(list(weight_map.keys())):
|
||||
if "mlp.experts" in key and (key.endswith("weight")
|
||||
or key.endswith("weight_scale_inv")):
|
||||
if key.endswith("weight_scale_inv"):
|
||||
continue
|
||||
if args.rank == 0:
|
||||
layer = int(key.split(".")[2])
|
||||
new_json['weight_map'][key] = get_file_name(layer)
|
||||
new_json['weight_map'][key.replace(
|
||||
"weight", "weight_scale_inv")] = get_file_name(layer)
|
||||
if int(key.split(".")[2]) < start_layer or int(
|
||||
key.split(".")[2]) >= end_layer:
|
||||
continue
|
||||
fp8_tensor = get_tensor(key)
|
||||
fp8_scale = get_tensor(key.replace("weight", "weight_scale_inv"))
|
||||
quant_tensor, scale_tensor = quantize_fp8_block_scale_to_int4(
|
||||
fp8_tensor, fp8_scale)
|
||||
|
||||
packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
|
||||
packed_tensor = packer(quant_tensor.cpu().contiguous())
|
||||
new_safetensors.update({key: packed_tensor})
|
||||
new_safetensors.update({
|
||||
key.replace("weight", "weight_scale_inv"):
|
||||
scale_tensor.contiguous()
|
||||
})
|
||||
else:
|
||||
name = key.split(".")
|
||||
if args.rank == 0:
|
||||
if len(name) < 3 or not name[2].isdigit():
|
||||
new_safetensors.update({key: get_tensor(key)})
|
||||
new_json['weight_map'][key] = get_file_name(0)
|
||||
continue
|
||||
|
||||
file_name = get_file_name(int(name[2]))
|
||||
new_json['weight_map'][key] = file_name
|
||||
|
||||
if len(name) < 3 or not name[2].isdigit() or (int(
|
||||
name[2]) < start_layer or int(name[2]) >= end_layer):
|
||||
continue
|
||||
new_safetensors.update({key: get_tensor(key)})
|
||||
|
||||
if args.rank == 0:
|
||||
if os.path.isdir(args.act_scales):
|
||||
# Extract activation scales
|
||||
renamed_state_dict = load_and_preprocess_state_dict(
|
||||
modelopt_state_root=args.act_scales, world_size=8)
|
||||
get_scales_from_amax(start_layer=start_layer,
|
||||
end_layer=end_layer,
|
||||
renamed_state_dict=renamed_state_dict)
|
||||
else:
|
||||
input_scales = safe_open(args.act_scales, "pt")
|
||||
for k in input_scales.keys():
|
||||
new_safetensors.update({k: input_scales.get_tensor(k)})
|
||||
new_json['weight_map'][k] = "input_scales.safetensors"
|
||||
|
||||
file_name = get_file_name(start_layer)
|
||||
print(f'saving to {file_name}...')
|
||||
save_file(new_safetensors, os.path.join(output_dir, file_name))
|
||||
with open(os.path.join(output_dir, "model.safetensors.index.json"),
|
||||
"w") as f:
|
||||
json.dump(new_json, f)
|
||||
|
||||
names = [
|
||||
"configuration_deepseek.py", "generation_config.json",
|
||||
"modeling_deepseek.py", "tokenizer.json", "tokenizer_config.json"
|
||||
]
|
||||
for name in names:
|
||||
shutil.copy(os.path.join(model_dir, name), output_dir)
|
||||
shutil.copy(args.act_scales, output_dir)
|
||||
|
||||
# config.json
|
||||
del config['quantization_config']
|
||||
with open(os.path.join(output_dir, "config.json"), 'w') as file:
|
||||
json.dump(config, file, indent=4)
|
||||
|
||||
# quant_cfg.json
|
||||
attn_names = ["fused_a", "q_b_proj", "kv_b_proj", "o_proj"]
|
||||
mlp_names = ["gate_up_proj", "down_proj"]
|
||||
fp8_block_scale = {"quant_algo": "FP8_BLOCK_SCALES"}
|
||||
w4a8_awq = {"quant_algo": "W4A8_AWQ"}
|
||||
quant_cfg = {}
|
||||
quant_cfg["quant_algo"] = "MIXED_PRECISION"
|
||||
quant_cfg["kv_cache_quant_algo"] = None
|
||||
quant_cfg["quantized_layers"] = {}
|
||||
for l in range(61):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n1 in attn_names:
|
||||
quant_cfg["quantized_layers"][
|
||||
f"{prefix}.self_attn.{n1}"] = fp8_block_scale
|
||||
for n2 in mlp_names:
|
||||
quant_cfg["quantized_layers"][
|
||||
f"{prefix}.mlp.shared_experts.{n2}"] = fp8_block_scale
|
||||
if l < 3:
|
||||
for n3 in mlp_names:
|
||||
quant_cfg["quantized_layers"][
|
||||
f"{prefix}.mlp.{n3}"] = fp8_block_scale
|
||||
else:
|
||||
quant_cfg["quantized_layers"][
|
||||
f"{prefix}.mlp.experts"] = w4a8_awq
|
||||
with open(os.path.join(output_dir, "quant_cfg.json"), 'w') as file:
|
||||
json.dump(quant_cfg, file, indent=4)
|
||||
|
||||
# hf_quant_config.json
|
||||
hf_quant_config = {}
|
||||
hf_quant_config['quantization'] = {}
|
||||
hf_quant_config['quantization']["quant_algo"] = "MIXED_PRECISION"
|
||||
hf_quant_config['quantization']["kv_cache_quant_algo"] = None
|
||||
with open(os.path.join(output_dir, "hf_quant_config.json"),
|
||||
'w') as file:
|
||||
json.dump(hf_quant_config, file, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
@ -33,6 +33,7 @@ class MoERunner(TunableRunner):
|
||||
cluster_size: int,
|
||||
cluster_rank: int,
|
||||
use_fp8_block_scaling: bool,
|
||||
use_w4a8_group_scaling: bool,
|
||||
):
|
||||
self.x_dtype = x_dtype
|
||||
self.weight_dtype = weight_dtype
|
||||
@ -45,14 +46,16 @@ class MoERunner(TunableRunner):
|
||||
self.cluster_size = cluster_size
|
||||
self.cluster_rank = cluster_rank
|
||||
self.use_fp8_block_scaling = use_fp8_block_scaling
|
||||
self.use_w4a8_group_scaling = use_w4a8_group_scaling
|
||||
|
||||
instance_key = (x_dtype, weight_dtype, output_dtype,
|
||||
use_fp8_block_scaling)
|
||||
use_fp8_block_scaling, use_w4a8_group_scaling)
|
||||
|
||||
if instance_key not in MoERunner._runner_dict:
|
||||
MoERunner._runner_dict[
|
||||
instance_key] = torch.classes.trtllm.FusedMoeRunner(
|
||||
x_dtype, weight_dtype, output_dtype, use_fp8_block_scaling)
|
||||
x_dtype, weight_dtype, output_dtype, use_fp8_block_scaling,
|
||||
use_w4a8_group_scaling)
|
||||
self._fused_moe_runner = MoERunner._runner_dict[instance_key]
|
||||
self._is_nvfp4 = weight_dtype == torch.int64
|
||||
|
||||
@ -121,6 +124,7 @@ def fused_moe(
|
||||
cluster_size: int = 1,
|
||||
cluster_rank: int = 0,
|
||||
use_fp8_block_scaling: bool = False,
|
||||
use_w4a8_group_scaling: bool = False,
|
||||
min_latency_mode: bool = False,
|
||||
) -> List[torch.Tensor]:
|
||||
|
||||
@ -151,6 +155,7 @@ def fused_moe(
|
||||
cluster_size=cluster_size,
|
||||
cluster_rank=cluster_rank,
|
||||
use_fp8_block_scaling=use_fp8_block_scaling,
|
||||
use_w4a8_group_scaling=use_w4a8_group_scaling,
|
||||
)
|
||||
|
||||
_, gemm_tactic_1 = tuner.choose_one(
|
||||
@ -208,6 +213,7 @@ def _(
|
||||
cluster_size: int = 1,
|
||||
cluster_rank: int = 0,
|
||||
use_fp8_block_scaling: bool = False,
|
||||
use_w4a8_group_scaling: bool = False,
|
||||
min_latency_mode: bool = False,
|
||||
):
|
||||
seq_len = input.shape[0]
|
||||
|
||||
@ -418,7 +418,7 @@ class DecoderModelForCausalLM(nn.Module,
|
||||
if weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
|
||||
for n, q in quant_config_dict.items():
|
||||
# gate_proj and up_proj share the same quant config
|
||||
if prefix_name + '.gate_proj' in n:
|
||||
if prefix_name + '.gate_proj' in n or prefix_name + '.gate_up_proj' in n:
|
||||
module.quant_config = q
|
||||
break
|
||||
elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
|
||||
@ -438,7 +438,13 @@ class DecoderModelForCausalLM(nn.Module,
|
||||
if name + '.q_proj' in n:
|
||||
module.quant_config = q
|
||||
break
|
||||
# TODO: support MLA
|
||||
elif hasattr(module, 'fused_a'):
|
||||
# DeepseekV3Attention
|
||||
for n, q in quant_config_dict.items():
|
||||
# reuse q_proj quant config as the attention quant config
|
||||
if name + '.fused_a' in n:
|
||||
module.quant_config = q
|
||||
break
|
||||
|
||||
# 2. skip quant for modules in QuantConfig.exclude_modules.
|
||||
# kv_cache_quant_algo takes precedence over exclude_modules.
|
||||
|
||||
@ -8,6 +8,7 @@ import torch
from torch import nn

from tensorrt_llm._mnnvl_utils import MnnvlMoe, MoEAlltoallInfo
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.quantization.utils.fp4_utils import (
reorder_rows_for_gated_act_gemm, shuffle_matrix_a, shuffle_matrix_sf_a)

@ -398,7 +399,9 @@ class FusedMoE(nn.Module):
exclude_kv_cache=True):
if not (self.quant_config.quant_mode.has_nvfp4()
| self.quant_config.quant_mode.has_fp8_block_scales()
| self.quant_config.quant_mode.has_fp8_qdq()):
| self.quant_config.quant_mode.has_fp8_qdq()
| self.quant_config.quant_mode.
is_int4_weight_only_per_group()):
raise ValueError(
f"unsupported quantization mode: {self.quant_config.quant_mode}"
)
@ -428,6 +431,17 @@ class FusedMoE(nn.Module):
fc2_weight_block=self.w2_weight_scale,
fc2_global=self.fc2_alpha,
)
elif self.has_w4afp8:
self.quant_scales = FusedMoEQuantScalesW4A8(
scale_1_interleaved=self.fc31_weight_scale,
scale_2_interleaved=self.fc2_weight_scale,
pre_quant_scale_1=self.fc31_act_scale,
pre_quant_scale_2=self.fc2_act_scale,
zero_1=torch.Tensor(),
zero_2=torch.Tensor(),
alpha_1=self.fc31_alpha,
alpha_2=self.fc2_alpha,
)

def is_trtllm(self):
return self.moe_backend == "TRTLLM" and self.has_any_quant
@ -458,6 +472,23 @@ class FusedMoE(nn.Module):
fc2_global=self.fc2_alpha.narrow(0, expert_start,
expert_end - expert_start),
)
elif self.has_w4afp8:
return FusedMoEQuantScalesW4A8(
scale_1_interleaved=self.fc31_weight_scale.narrow(
0, expert_start, expert_end - expert_start),
scale_2_interleaved=self.fc2_weight_scale.narrow(
0, expert_start, expert_end - expert_start),
pre_quant_scale_1=self.fc31_act_scale.narrow(
0, expert_start, expert_end - expert_start),
pre_quant_scale_2=self.fc2_act_scale.narrow(
0, expert_start, expert_end - expert_start),
zero_1=torch.Tensor(),
zero_2=torch.Tensor(),
alpha_1=self.fc31_alpha.narrow(0, expert_start,
expert_end - expert_start),
alpha_2=self.fc2_alpha.narrow(0, expert_start,
expert_end - expert_start),
)
else:
return self.quant_scales

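The branch above slices every W4A8 scale tensor along the expert dimension with narrow() when only a contiguous range of experts is needed. A shape-only sketch with made-up tensor sizes:

import torch

# Hypothetical scale tensors for 8 experts; only the expert dimension matters.
fc31_weight_scale = torch.randn(8, 4, 16)
fc31_act_scale = torch.randn(8, 1)

expert_start, expert_end = 2, 5   # this rank serves experts [2, 5)
# narrow(dim, start, length) returns a view over the local experts only.
local_weight_scale = fc31_weight_scale.narrow(0, expert_start,
                                              expert_end - expert_start)
local_act_scale = fc31_act_scale.narrow(0, expert_start,
                                        expert_end - expert_start)
assert local_weight_scale.shape == (3, 4, 16)
assert local_act_scale.shape == (3, 1)
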
@ -479,6 +510,7 @@ class FusedMoE(nn.Module):
self.has_fp8_qdq = False
self.has_fp8_block_scales = False
self.has_nvfp4 = False
self.has_w4afp8 = False
if self.quant_config and self.quant_config.quant_mode.has_any_quant(
exclude_kv_cache=True):
qc = self.quant_config
@ -534,6 +566,91 @@ class FusedMoE(nn.Module):
requires_grad=False)
self.register_parameter("w2_weight_scaling_factor",
w2_weight_scaling_factor)
elif qc.quant_mode.is_int4_weight_only_per_group():
self.has_w4afp8 = True
self.sm_version = get_sm_version()
if self.sm_version == 89:
self.interleave = [1, 1]
elif self.sm_version == 90:
self.interleave = []
for k_shape in [
self.hidden_size,
self.intermediate_size_per_partition
]:
if k_shape % 512 == 0:
self.interleave.append(4)
elif k_shape % 256 == 0:
self.interleave.append(2)
elif k_shape % 128 == 0:
self.interleave.append(1)
else:
raise NotImplementedError(
f"K shape is required to be multiple of 128, received {k_shape}."
)
else:
raise NotImplementedError(
f"W4AFP8 MoE is unsupported on SM{self.sm_version}.")
weight_dtype = torch.int8
w3_w1_weight_shape = (self.expert_size_per_partition,
self.intermediate_size_per_partition * 2,
self.hidden_size // 2)
w2_weight_shape = (self.expert_size_per_partition,
self.hidden_size,
self.intermediate_size_per_partition // 2)

fc31_act_scale = nn.Parameter(torch.empty(
self.expert_size_per_partition,
1,
dtype=self.dtype,
device=device),
requires_grad=False)
self.register_parameter("fc31_act_scale", fc31_act_scale)

fc2_act_scale = nn.Parameter(torch.empty(
self.expert_size_per_partition,
1,
dtype=self.dtype,
device=device),
requires_grad=False)
self.register_parameter("fc2_act_scale", fc2_act_scale)

# col parallel
fc31_weight_scale = nn.Parameter(
torch.empty(self.expert_size_per_partition,
self.hidden_size // (128 * self.interleave[0]),
self.intermediate_size_per_partition * 2 *
self.interleave[0],
dtype=self.dtype,
device=device),
requires_grad=False)
self.register_parameter("fc31_weight_scale", fc31_weight_scale)

# row parallel
fc2_weight_scale = nn.Parameter(
torch.empty(self.expert_size_per_partition,
self.intermediate_size_per_partition //
(128 * self.interleave[1]),
self.hidden_size * self.interleave[1],
dtype=self.dtype,
device=device),
requires_grad=False)
self.register_parameter("fc2_weight_scale", fc2_weight_scale)

fc31_alpha = nn.Parameter(torch.empty(
self.expert_size_per_partition,
1,
dtype=torch.float32,
device=device),
requires_grad=False)
self.register_parameter("fc31_alpha", fc31_alpha)

fc2_alpha = nn.Parameter(torch.empty(
self.expert_size_per_partition,
1,
dtype=torch.float32,
device=device),
requires_grad=False)
self.register_parameter("fc2_alpha", fc2_alpha)
elif qc.quant_mode.has_nvfp4():
self.has_nvfp4 = True
if self.is_trtllm():
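On SM90 the interleave factors chosen above also fix the layout of the registered group scales (group size 128). A small sketch with hypothetical sizes, reproducing the selection rule and the resulting fc31/fc2 scale shapes:

def pick_interleave(k_shape):
    # Mirrors the divisibility rule above: prefer the largest factor whose
    # 128-element groups still divide K evenly.
    for divisor, factor in ((512, 4), (256, 2), (128, 1)):
        if k_shape % divisor == 0:
            return factor
    raise NotImplementedError(
        f"K shape is required to be multiple of 128, received {k_shape}.")

# Hypothetical per-partition sizes, not taken from any real checkpoint.
hidden_size, intermediate_size, num_experts = 7168, 2048, 8
interleave = [pick_interleave(hidden_size),        # K of the fused gate/up GEMM
              pick_interleave(intermediate_size)]  # K of the down-projection GEMM

fc31_weight_scale_shape = (num_experts,
                           hidden_size // (128 * interleave[0]),
                           intermediate_size * 2 * interleave[0])
fc2_weight_scale_shape = (num_experts,
                          intermediate_size // (128 * interleave[1]),
                          hidden_size * interleave[1])
print(interleave, fc31_weight_scale_shape, fc2_weight_scale_shape)
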
@ -701,6 +818,8 @@ class FusedMoE(nn.Module):
output_dtype = x.dtype

use_fp8_block_scaling = False
use_w4a8_group_scaling = False
weight_dtype = self.w3_w1_weight.dtype

token_selected_experts, token_final_scales = self.routing_method.apply(
router_logits)
@ -749,6 +868,9 @@ class FusedMoE(nn.Module):

elif self.has_fp8_block_scales:
use_fp8_block_scaling = True
elif self.has_w4afp8:
use_w4a8_group_scaling = True
weight_dtype = torch.quint4x2
else:
raise ValueError(
f"unsupported quantization mode: {self.quant_config.quant_mode}"
@ -801,8 +923,8 @@ class FusedMoE(nn.Module):
x,
token_selected_experts,
token_final_scales,
w3_w1_weight,
w2_weight,
w3_w1_weight.view(weight_dtype),
w2_weight.view(weight_dtype),
output_dtype,
quant_scales=quant_scales,
input_sf=x_sf,
@ -813,6 +935,7 @@ class FusedMoE(nn.Module):
cluster_size=cluster_size,
cluster_rank=cluster_rank,
use_fp8_block_scaling=use_fp8_block_scaling,
use_w4a8_group_scaling=use_w4a8_group_scaling,
min_latency_mode=cutlass_min_latency_mode,
)

@ -1151,6 +1274,18 @@ class FusedMoE(nn.Module):
w31_weight_shard)
w31_weight_shard = shuffle_matrix_a(w31_weight_shard,
epilogue_tile_m)
if self.has_w4afp8 and self.sm_version == 89:
import tensorrt_llm.quantization.functional
preprocessor = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm
packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
w31_weight_shard = packer(
unpacker(w31_weight_shard.cpu()).T.contiguous()).to(
w31_weight_shard.device)
w31_weight_shard = preprocessor(w31_weight_shard,
torch.quint4x2,
torch.float8_e4m3fn,
89).view(dst_w3_w1_weight.shape)

dst_w3_w1_weight.copy_(w31_weight_shard.view(
dst_w3_w1_weight.dtype))
@ -1166,6 +1301,19 @@ class FusedMoE(nn.Module):
epilogue_tile_m = 128
w2_weight_shard = shuffle_matrix_a(w2_weight_shard,
epilogue_tile_m)

if self.has_w4afp8 and self.sm_version == 89:
import tensorrt_llm.quantization.functional
preprocessor = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm
packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
w2_weight_shard = packer(
unpacker(w2_weight_shard.cpu()).T.contiguous()).to(
w2_weight_shard.device)
w2_weight_shard = preprocessor(w2_weight_shard, torch.quint4x2,
torch.float8_e4m3fn,
89).view(dst_w2_weight.shape)

dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype))

# Use multi-threading to load expert weights in parallel.
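On SM89 (Ada) both weight shards go through the same repack: unpack the int4 payload to int8, transpose, repack, then reorder for the mixed-dtype GEMM. A sketch of that flow with a made-up shard shape, reusing only the TRT-LLM ops already referenced above (the production code additionally views the result back to the destination shard's shape):

import torch
import tensorrt_llm.quantization.functional

packer = torch.ops.trtllm.pack_int8_tensor_to_packed_int4
unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
preprocess = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm

# Hypothetical packed shard: two int4 values per int8 byte, so K // 2 columns.
intermediate_size, hidden_size = 640, 768
shard = torch.randint(-128, 127,
                      (intermediate_size * 2, hidden_size // 2),
                      dtype=torch.int8)

# Unpack to one int8 per int4 value, move K to the leading dim, repack, then
# let the preprocessor produce the layout the W4A8 grouped GEMM expects.
repacked = packer(unpacker(shard.cpu()).T.contiguous())
prepared = preprocess(repacked, torch.quint4x2, torch.float8_e4m3fn, 89)
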
@ -1220,6 +1368,8 @@ class FusedMoE(nn.Module):
self._load_nvfp4_scales(weights)
elif self.quant_config.quant_mode.has_fp8_block_scales():
self._load_fp8_block_scales_scales(weights)
elif self.quant_config.quant_mode.is_int4_weight_only_per_group():
self._load_int4_groupwise_scales(weights)
else:
raise ValueError(
f"unsupported quantization mode: {self.quant_config.quant_mode}"
@ -1549,6 +1699,86 @@ class FusedMoE(nn.Module):
self.fc31_scale_c.data.copy_(self.fc2_input_scale.data *
self.fc31_alpha.data)

def _load_int4_groupwise_scales(self, weights: Dict):
# fc31 scales
assert (len(self.interleave) == 2)
all_w3_input_scales = [
load_weight_shard(weights[f"{expert_id}.w3.input_scale"])
for expert_id in range(self.expert_start, self.expert_end)
]
all_w1_input_scales = [
load_weight_shard(weights[f"{expert_id}.w1.input_scale"])
for expert_id in range(self.expert_start, self.expert_end)
]
all_w3_w1_input_scales = torch.max(torch.stack(all_w3_input_scales),
torch.stack(all_w1_input_scales))
all_w3_w1_input_scales = torch.ones_like(
all_w3_w1_input_scales) * all_w3_w1_input_scales.max()
self.fc31_act_scale.data.copy_(1 / all_w3_w1_input_scales)
self.fc31_alpha.data.copy_(all_w3_w1_input_scales.float())

all_w3_scales = [
load_weight_shard(weights[f"{expert_id}.w3.weight_scale_inv"],
self.tp_size, self.tp_rank,
TensorParallelMode.COLUMN)
for expert_id in range(self.expert_start, self.expert_end)
]
all_w1_scales = [
load_weight_shard(weights[f"{expert_id}.w1.weight_scale_inv"],
self.tp_size, self.tp_rank,
TensorParallelMode.COLUMN)
for expert_id in range(self.expert_start, self.expert_end)
]
all_w3_w1_scales = torch.cat(
[torch.stack(all_w3_scales),
torch.stack(all_w1_scales)], dim=-2)
if self.sm_version == 89:
w3_w1_scales = all_w3_w1_scales.to(torch.float16).view(self.dtype)
else:
w3_w1_scales = all_w3_w1_scales.to(torch.bfloat16).view(self.dtype)
w3_w1_s_shape = w3_w1_scales.shape
w3_w1_scales_interleaved = w3_w1_scales.reshape(
w3_w1_s_shape[0], w3_w1_s_shape[1],
(w3_w1_s_shape[2] // self.interleave[0]), self.interleave[0])
w3_w1_scales_interleaved = w3_w1_scales_interleaved.permute(0, 2, 1, 3)
w3_w1_scales_interleaved = w3_w1_scales_interleaved.reshape(
w3_w1_s_shape[0], w3_w1_s_shape[2] // self.interleave[0],
w3_w1_s_shape[1] * self.interleave[0])
self.fc31_weight_scale.data.copy_(w3_w1_scales_interleaved.contiguous())

# fc2 scales
all_w2_input_scales = [
load_weight_shard(weights[f"{expert_id}.w2.input_scale"])
for expert_id in range(self.expert_start, self.expert_end)
]
all_w2_input_scales = torch.stack(all_w2_input_scales).to(self.dtype)
all_w2_input_scales = torch.ones_like(
all_w2_input_scales) * all_w2_input_scales.max()
self.fc2_act_scale.data.copy_(1 / all_w2_input_scales)
self.fc2_alpha.data.copy_(all_w2_input_scales.float())

all_w2_scales = [
load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"],
self.tp_size, self.tp_rank,
TensorParallelMode.ROW)
for expert_id in range(self.expert_start, self.expert_end)
]
if self.sm_version == 89:
w2_scales = torch.stack(all_w2_scales).to(torch.float16).view(
self.dtype)
else:
w2_scales = torch.stack(all_w2_scales).to(torch.bfloat16).view(
self.dtype)
w2_s_shape = w2_scales.shape
w2_scales_interleaved = w2_scales.reshape(
w2_s_shape[0], w2_s_shape[1], (w2_s_shape[2] // self.interleave[1]),
self.interleave[1])
w2_scales_interleaved = w2_scales_interleaved.permute(0, 2, 1, 3)
w2_scales_interleaved = w2_scales_interleaved.reshape(
w2_s_shape[0], w2_s_shape[2] // self.interleave[1],
w2_s_shape[1] * self.interleave[1])
self.fc2_weight_scale.data.copy_(w2_scales_interleaved.contiguous())


class FusedMoEQuantScalesFP8(NamedTuple):
fc1_dequant: torch.Tensor
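The reshape / permute / reshape applied to both scale tensors above is a pure layout transform: a [num_experts, N, num_groups] group-scale tensor becomes [num_experts, num_groups // interleave, N * interleave]. A shape-only sketch with toy sizes:

import torch

# Toy sizes chosen only to make the shapes visible; interleave=4 corresponds to
# the K % 512 == 0 case selected earlier.
E, N, num_groups, interleave = 2, 6, 8, 4
scales = torch.arange(E * N * num_groups, dtype=torch.float32).reshape(
    E, N, num_groups)

s = scales.reshape(E, N, num_groups // interleave, interleave)
s = s.permute(0, 2, 1, 3)
interleaved = s.reshape(E, num_groups // interleave, N * interleave)

assert interleaved.shape == (E, 2, 24)
# The first `interleave` groups of output column 0 stay adjacent in the last dim.
assert torch.equal(interleaved[0, 0, :interleave], scales[0, 0, :interleave])
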
@ -1572,3 +1802,14 @@ class FusedMoEQuantScalesNVFP4(NamedTuple):
class FusedMoEQuantScalesFP8BlockScales(NamedTuple):
fc_weight_scales: torch.Tensor
proj_weight_scales: torch.Tensor


class FusedMoEQuantScalesW4A8(NamedTuple):
scale_1_interleaved: torch.Tensor
scale_2_interleaved: torch.Tensor
pre_quant_scale_1: torch.Tensor
pre_quant_scale_2: torch.Tensor
zero_1: torch.Tensor
zero_2: torch.Tensor
alpha_1: torch.Tensor
alpha_2: torch.Tensor

@ -500,8 +500,8 @@ class MOEWeightWrapper(Module):
else:
self.register_parameter('zero', None)
if groupwise_quant_algo & GroupwiseQuantAlgo.PRE_QUANT_SCALE:
self.prequant_scaling_factor = Parameter(shape=(1, in_features),
dtype=dtype)
self.prequant_scaling_factor = Parameter(
shape=(experts_per_node, 1), dtype=dtype)
else:
self.register_parameter('prequant_scaling_factor', None)
if groupwise_quant_algo & GroupwiseQuantAlgo.W4A8_ALPHA:

@ -4,7 +4,8 @@ from typing import Dict, List, Optional
import pytest
import torch
import torch.nn as nn
from utils.util import skip_pre_blackwell, skip_pre_hopper
from utils.util import (skip_neither_ada_nor_hopper_unittest,
skip_pre_blackwell, skip_pre_hopper)

from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.model_config import ModelConfig
@ -271,6 +272,136 @@ def test_fused_moe_nvfp4(dtype):
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)


@skip_neither_ada_nor_hopper_unittest
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe_w4afp8(dtype):

SEQ_LEN = 4
HIDDEN_SIZE = 768
INTERMEDIATE_SIZE = 640
SCALING_GROUP_SIZE = 128
NUM_EXPERTS = 3
TOP_K = 2
routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda()

affine_coeff = 0.005

weights = {}
for expert_id in range(NUM_EXPERTS):
w1_weight = torch.randint(-128,
127, (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2),
dtype=torch.int8).cuda()
w2_weight = torch.randint(-128,
127, (HIDDEN_SIZE, INTERMEDIATE_SIZE // 2),
dtype=torch.int8).cuda()
w3_weight = torch.randint(-128,
127, (INTERMEDIATE_SIZE, HIDDEN_SIZE // 2),
dtype=torch.int8).cuda()

w1_scale = torch.randn(
(INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE),
dtype=dtype).cuda() * affine_coeff
w2_scale = torch.randn(
(HIDDEN_SIZE, INTERMEDIATE_SIZE // SCALING_GROUP_SIZE),
dtype=dtype).cuda() * affine_coeff
w3_scale = torch.randn(
(INTERMEDIATE_SIZE, HIDDEN_SIZE // SCALING_GROUP_SIZE),
dtype=dtype).cuda() * affine_coeff

w1_input = torch.randn(1, dtype=torch.float32).cuda() * 0.02
w2_input = w1_input
w3_input = w1_input

weights[f"{expert_id}.w1.weight"] = w1_weight
weights[f"{expert_id}.w2.weight"] = w2_weight
weights[f"{expert_id}.w3.weight"] = w3_weight
weights[f"{expert_id}.w1.weight_scale_inv"] = w1_scale
weights[f"{expert_id}.w2.weight_scale_inv"] = w2_scale
weights[f"{expert_id}.w3.weight_scale_inv"] = w3_scale
weights[f"{expert_id}.w1.input_scale"] = w1_input
weights[f"{expert_id}.w2.input_scale"] = w2_input
weights[f"{expert_id}.w3.input_scale"] = w3_input

quant_config = QuantConfig(quant_algo=QuantAlgo.W4A8_AWQ)
fused_moe = FusedMoE(num_experts=NUM_EXPERTS,
routing_method=routing_method,
hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE,
dtype=dtype,
reduce_results=False,
model_config=ModelConfig(quant_config=quant_config))
fused_moe.load_weights([weights])
fused_moe.cuda()

def ref():
results = torch.zeros_like(x)
selected_experts, final_scales = routing_method.apply(router_logits)
unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
for e_idx in range(NUM_EXPERTS):
mask = selected_experts == e_idx
activated_tokens = mask.sum(1).bool()
act = x[activated_tokens, :]
if act.shape[0] == 0:
continue
final_scale = (final_scales *
mask).sum(1)[activated_tokens].unsqueeze(1)

# weights
w1 = weights[f"{e_idx}.w1.weight"]
w1 = unpacker(w1.cpu()).T.contiguous().cuda()
w2 = weights[f"{e_idx}.w2.weight"]
w2 = unpacker(w2.cpu()).T.contiguous().cuda()
w3 = weights[f"{e_idx}.w3.weight"]
w3 = unpacker(w3.cpu()).T.contiguous().cuda()
w3_w1 = torch.cat([w3, w1], dim=-1)

# scales
s1 = weights[f"{e_idx}.w1.weight_scale_inv"].T.contiguous().cuda()
s2 = weights[f"{e_idx}.w2.weight_scale_inv"].T.contiguous().cuda()
s3 = weights[f"{e_idx}.w3.weight_scale_inv"].T.contiguous().cuda()
s3_s1 = torch.cat([s3, s1], dim=-1)

# prequant / alpha
p1 = weights[f"{e_idx}.w1.input_scale"].cuda()
p2 = weights[f"{e_idx}.w2.input_scale"].cuda()
p3 = weights[f"{e_idx}.w3.input_scale"].cuda()
p3_p1 = max(p1, p3)

act = torch.clamp((act / p3_p1), -448.0,
448.0).to(torch.float8_e4m3fn).to(dtype)
w3_w1 = (w3_w1.float() *
s3_s1.repeat_interleave(128, dim=0).float()).to(dtype)
fc1 = torch.matmul(act, w3_w1) * p3_p1
fc1, gate = fc1.chunk(2, dim=-1)
fc1 = fc1 * torch.nn.functional.silu(gate)

act = torch.clamp((fc1 / p2), -448.0,
448.0).to(torch.float8_e4m3fn).to(dtype)
w2 = (w2.float() *
s2.repeat_interleave(128, dim=0).float()).to(dtype)
fc2 = torch.matmul(act, w2) * p2
results[activated_tokens, :] += (fc2 * final_scale).to(
results.dtype)
return results

AutoTuner.get().clear_cache()
with torch.inference_mode(), autotune():
fused_moe.forward(x, router_logits)

torch.cuda.synchronize()
with torch.inference_mode():
output = fused_moe.forward(x, router_logits)
ref_output = ref()

# compare
torch.cuda.synchronize()
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)


class RefGatedMLPFusedMoE(nn.Module):

def __init__(self,

@ -154,14 +154,14 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase):
2**31, (num_experts, n, k // num_weights_in_32_bits),
dtype=torch.int32,
device="cuda")
pre_quant_scale_1 = torch.randn(1,
k,
dtype=activation_dtype,
device="cuda")
pre_quant_scale_2 = torch.randn(1,
n,
dtype=activation_dtype,
device="cuda")
pre_quant_scale_1 = torch.ones(num_experts,
1,
dtype=activation_dtype,
device="cuda")
pre_quant_scale_2 = torch.ones(num_experts,
1,
dtype=activation_dtype,
device="cuda")
scale_1 = torch.randn(num_experts,
k // group_size,
n * 2,
@ -182,14 +182,8 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase):
k,
dtype=activation_dtype,
device="cuda") * 0.01
alpha_1 = torch.randn(num_experts,
1,
dtype=torch.float32,
device="cuda")
alpha_2 = torch.randn(num_experts,
1,
dtype=torch.float32,
device="cuda")
alpha_1 = 1 / pre_quant_scale_1.float()
alpha_2 = 1 / pre_quant_scale_2.float()

preprocessor = tensorrt_llm.quantization.functional.preprocess_weights_for_mixed_gemm
unpacker = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8
@ -242,8 +236,9 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase):
input = inputs_merged[i, :]
fc1_qd = ref_weight_1[expert].cuda().float()
if has_pre_quant:
input = input * pre_quant_scale_1.squeeze()
input = input * pre_quant_scale_1[expert]
if has_alpha:
input[input > 448.0] = 448.0
input = input.to(torch.float8_e4m3fn).float()
fc1_qd = fc1_qd.to(torch.float8_e4m3fn).float()
fc1 = torch.matmul(input, fc1_qd) * alpha_1[expert]
@ -253,8 +248,9 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase):
fc1 = fc1 * torch.nn.functional.silu(gate)
fc2_qd = ref_weight_2[expert].cuda().float()
if has_pre_quant:
fc1 = fc1 * pre_quant_scale_2.squeeze()
fc1 = fc1 * pre_quant_scale_2[expert]
if has_alpha:
fc1[fc1 > 448.0] = 448.0
fc1 = fc1.to(torch.float8_e4m3fn).float()
fc2_qd = fc2_qd.to(torch.float8_e4m3fn).float()
final = torch.matmul(fc1, fc2_qd) * alpha_2[expert]
@ -282,11 +278,6 @@ class TestMoEWeightOnlyGroupWiseQuantMatmul(unittest.TestCase):
name_func=unittest_name_func)
@skip_non_ada_unittest
def test_moe_w4a8(self, m, n, k, experts, dtype, has_pre_quant, has_zero):
# Skip specific problematic case
if m == 1 and n == 14336 and k == 4096 and experts == 8 and dtype == "bfloat16" and has_pre_quant and not has_zero:
self.skipTest(
"Skipping problematic case test_moe_w4a8_1_14336_4096_8_bfloat16_True_False"
)

self._woq_moe_groupwise_matmul(m, n, k, experts, dtype, torch.quint4x2,
has_pre_quant, has_zero, True)