[Perf] Use 2D-grid to eliminate divmod in W8W8 group quant (#42153)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
(cherry picked from commit dd6b3a5ef5)
This commit is contained in:
Jiahan Chang (Cyrus)
2026-05-12 22:01:30 +08:00
committed by khluu
parent 2a2ac21d3d
commit 65df49eba3
@@ -156,6 +156,17 @@ inline int GetGroupsPerBlock(int64_t num_groups) {
return 1;
}
// Chooses kx, the number of scale groups one block covers along a row.
// Returns 16 when padded_groups_per_row is a multiple of 16, otherwise 8 when
// it is a multiple of 8, otherwise 4. The fallback of 4 is returned without
// verifying that 4 divides the input, so this is not "the largest divisor
// <= 16" in general (e.g. 12 yields 4). The matching rows-per-block value is
// ry = 16 / kx.
inline int GetGroupsPerBlockX(int64_t padded_groups_per_row) {
  // Try the power-of-two widths from widest to narrowest; 4 is the floor.
  for (int width = 16; width > 4; width /= 2) {
    if (padded_groups_per_row % width == 0) {
      return width;
    }
  }
  return 4;
}
void per_token_group_quant_8bit(const torch::stable::Tensor& input,
torch::stable::Tensor& output_q,
torch::stable::Tensor& output_s,
@@ -247,11 +258,11 @@ void per_token_group_quant_8bit(const torch::stable::Tensor& input,
//
// Constraints: GROUP_SIZE % (THREADS_PER_GROUP * VEC_SIZE) == 0; for
// THREADS_PER_GROUP=8 and bf16/fp16 (VEC_SIZE=16), this means GROUP_SIZE=128.
template <typename T, typename DST_DTYPE, int GROUP_SIZE>
template <typename T, typename DST_DTYPE, int GROUP_SIZE, int kGroupsPerBlockX,
int kRowsPerBlock>
__global__ void per_token_group_quant_8bit_packed_register_kernel(
const T* __restrict__ input, void* __restrict__ output_q,
unsigned int* __restrict__ output_s_packed, const int64_t num_groups_padded,
const int groups_per_block, const int padded_groups_per_row,
unsigned int* __restrict__ output_s_packed, const int padded_groups_per_row,
const int groups_per_row, const int mn, const int output_q_mn_extent,
const int tma_aligned_mn, const int64_t num_scale_elems, const float eps,
const float min_8bit, const float max_8bit) {
@@ -260,27 +271,25 @@ __global__ void per_token_group_quant_8bit_packed_register_kernel(
constexpr int VEC_SIZE = 32 / sizeof(T); // 16 for bf16/fp16
static_assert(GROUP_SIZE == THREADS_PER_GROUP * VEC_SIZE,
"GROUP_SIZE must equal THREADS_PER_GROUP * VEC_SIZE");
// Each group's 8 threads must live in a single warp octet so the
// 0xffu << (threadIdx.x & 24u) shuffle mask selects exactly the lanes
// that share a group. Requires 32 % THREADS_PER_GROUP == 0 and the host
// to launch num_threads as a multiple of THREADS_PER_GROUP (which it does
// via num_threads = (kx * ry) * THREADS_PER_GROUP).
static_assert(32 % THREADS_PER_GROUP == 0,
"THREADS_PER_GROUP must divide warp size for the shuffle "
"mask to be valid");
static_assert(
kGroupsPerBlockX > 0 && (kGroupsPerBlockX & (kGroupsPerBlockX - 1)) == 0,
"kGroupsPerBlockX must be a positive power of 2");
static_assert(kRowsPerBlock > 0, "kRowsPerBlock must be positive");
const int local_group_id = threadIdx.x / THREADS_PER_GROUP;
const int lane_id = threadIdx.x % THREADS_PER_GROUP;
const int64_t block_group_id = blockIdx.x * groups_per_block;
const int64_t global_group_id = block_group_id + local_group_id;
if (global_group_id >= num_groups_padded) {
const int sf_k_local = local_group_id % kGroupsPerBlockX;
const int row_local = local_group_id / kGroupsPerBlockX;
const int sf_k_idx = blockIdx.x * kGroupsPerBlockX + sf_k_local;
const int mn_idx = blockIdx.y * kRowsPerBlock + row_local;
if (mn_idx >= tma_aligned_mn) {
return;
}
const int sf_k_idx =
static_cast<int>(global_group_id % padded_groups_per_row);
const int mn_idx = static_cast<int>(global_group_id / padded_groups_per_row);
const bool is_valid_group = (mn_idx < mn) && (sf_k_idx < groups_per_row);
// Load 16 input elements (32 B) into registers as two adjacent uint4
@@ -443,34 +452,53 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
constexpr int THREADS_PER_GROUP = 8;
const int64_t padded_groups_per_row = k_num_packed_sfk * 4;
const int64_t num_groups_padded = tma_aligned_mn * padded_groups_per_row;
const int64_t num_scale_elems = mn + (k_num_packed_sfk - 1) * tma_aligned_mn;
const int groups_per_block = GetGroupsPerBlock(num_groups_padded);
STD_TORCH_CHECK(padded_groups_per_row % 4 == 0,
"padded_groups_per_row=", padded_groups_per_row,
" is not a multiple of 4.");
const int kx = GetGroupsPerBlockX(padded_groups_per_row);
const int ry = 16 / kx;
const int64_t blocks_x = padded_groups_per_row / kx;
const int64_t blocks_y = (tma_aligned_mn + ry - 1) / ry;
const int num_threads = (kx * ry) * THREADS_PER_GROUP;
// CUDA caps grid.x at 2^31 - 1, but grid.y (and grid.z) at 65535. Rows are
// mapped to grid.y here, so validate blocks_y against the smaller limit;
// otherwise an oversized shape passes this check and the launch itself fails
// with an invalid-configuration error.
constexpr int64_t kMaxGridDimX = static_cast<int64_t>(INT32_MAX);
constexpr int64_t kMaxGridDimY = 65535;
STD_TORCH_CHECK(blocks_x <= kMaxGridDimX && blocks_y <= kMaxGridDimY,
                "per_token_group_quant_8bit_packed grid too large: (",
                blocks_x, ", ", blocks_y, ").");
auto dst_type = output_q.scalar_type();
const int64_t num_blocks = num_groups_padded / groups_per_block;
const int num_threads = groups_per_block * THREADS_PER_GROUP;
// CUDA caps grid.x at 2^31 - 1; this fits any realistic shape but guard
// against pathological inputs.
STD_TORCH_CHECK(num_blocks <= static_cast<int64_t>(INT32_MAX),
"per_token_group_quant_8bit_packed grid too large: ",
num_blocks, " blocks (max ", INT32_MAX, ").");
#define LAUNCH_REG_KERNEL(T, DST_DTYPE) \
do { \
dim3 grid(static_cast<unsigned int>(num_blocks)); \
dim3 block(num_threads); \
per_token_group_quant_8bit_packed_register_kernel<T, DST_DTYPE, 128> \
<<<grid, block, 0, stream>>>( \
static_cast<const T*>(input.data_ptr()), output_q.data_ptr(), \
reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()), \
num_groups_padded, groups_per_block, \
static_cast<int>(padded_groups_per_row), \
static_cast<int>(groups_per_row), static_cast<int>(mn), \
static_cast<int>(output_q_mn_extent), \
static_cast<int>(tma_aligned_mn), num_scale_elems, \
static_cast<float>(eps), static_cast<float>(min_8bit), \
static_cast<float>(max_8bit)); \
// Launches one (KX, RY) instantiation of the packed register kernel with
// GROUP_SIZE fixed at 128. The launch uses a 2D grid of blocks_x by blocks_y
// blocks and num_threads threads per block on `stream`. All operands other
// than the macro parameters are read from the enclosing function scope.
#define LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, KX, RY)                         \
  do {                                                                       \
    const T* reg_in_ptr = static_cast<const T*>(input.data_ptr());           \
    unsigned int* reg_scale_ptr =                                            \
        reinterpret_cast<unsigned int*>(output_s_packed.data_ptr());         \
    const dim3 reg_grid(static_cast<unsigned int>(blocks_x),                 \
                        static_cast<unsigned int>(blocks_y));                \
    const dim3 reg_block(num_threads);                                       \
    per_token_group_quant_8bit_packed_register_kernel<T, DST_DTYPE, 128, KX, \
                                                      RY>                    \
        <<<reg_grid, reg_block, 0, stream>>>(                                \
            reg_in_ptr, output_q.data_ptr(), reg_scale_ptr,                  \
            static_cast<int>(padded_groups_per_row),                         \
            static_cast<int>(groups_per_row), static_cast<int>(mn),          \
            static_cast<int>(output_q_mn_extent),                            \
            static_cast<int>(tma_aligned_mn), num_scale_elems,               \
            static_cast<float>(eps), static_cast<float>(min_8bit),           \
            static_cast<float>(max_8bit));                                   \
  } while (0)
// Maps the runtime tile width kx onto the matching compile-time
// (kGroupsPerBlockX, kRowsPerBlock) instantiation. Each supported pair
// multiplies to 16 groups per block; any other kx raises an error.
#define LAUNCH_REG_KERNEL(T, DST_DTYPE)                        \
  do {                                                         \
    switch (kx) {                                              \
      case 16:                                                 \
        LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 16, 1);           \
        break;                                                 \
      case 8:                                                  \
        LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 8, 2);            \
        break;                                                 \
      case 4:                                                  \
        LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 4, 4);            \
        break;                                                 \
      default:                                                 \
        STD_TORCH_CHECK(false, "Unsupported kx value ", kx);   \
        break;                                                 \
    }                                                          \
  } while (0)
VLLM_STABLE_DISPATCH_HALF_TYPES(
@@ -488,6 +516,7 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
}));
#undef LAUNCH_REG_KERNEL
#undef LAUNCH_REG_KERNEL_INST
}
void per_token_group_quant_fp8(const torch::stable::Tensor& input,