mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Perf] Use 2D-grid to eliminate divmod in W8W8 group quant (#42153)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
(cherry picked from commit dd6b3a5ef5)
This commit is contained in:
committed by
khluu
parent
2a2ac21d3d
commit
65df49eba3
@@ -156,6 +156,17 @@ inline int GetGroupsPerBlock(int64_t num_groups) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Chooses kx, the number of scale groups one thread block covers along a row.
//
// Returns the largest of {16, 8, 4} that evenly divides
// padded_groups_per_row. Only these power-of-two candidates are considered,
// so this is NOT the largest divisor <= 16 in general (e.g. 12 -> 4, not 12):
// the caller derives the rows-per-block count as ry = 16 / kx, and the
// register kernel static_asserts that kGroupsPerBlockX is a power of two.
//
// Precondition: padded_groups_per_row is a multiple of 4. The value 4 is the
// unconditional fallback and is not re-validated here; the caller is expected
// to check divisibility by 4 before using the result as a grid divisor.
inline int GetGroupsPerBlockX(int64_t padded_groups_per_row) {
  if (padded_groups_per_row % 16 == 0) {
    return 16;
  }
  if (padded_groups_per_row % 8 == 0) {
    return 8;
  }
  return 4;
}
|
||||
|
||||
void per_token_group_quant_8bit(const torch::stable::Tensor& input,
|
||||
torch::stable::Tensor& output_q,
|
||||
torch::stable::Tensor& output_s,
|
||||
@@ -247,11 +258,11 @@ void per_token_group_quant_8bit(const torch::stable::Tensor& input,
|
||||
//
|
||||
// Constraints: GROUP_SIZE == THREADS_PER_GROUP * VEC_SIZE; for
|
||||
// THREADS_PER_GROUP=8 and bf16/fp16 (VEC_SIZE=16), this means GROUP_SIZE=128.
|
||||
template <typename T, typename DST_DTYPE, int GROUP_SIZE>
|
||||
template <typename T, typename DST_DTYPE, int GROUP_SIZE, int kGroupsPerBlockX,
|
||||
int kRowsPerBlock>
|
||||
__global__ void per_token_group_quant_8bit_packed_register_kernel(
|
||||
const T* __restrict__ input, void* __restrict__ output_q,
|
||||
unsigned int* __restrict__ output_s_packed, const int64_t num_groups_padded,
|
||||
const int groups_per_block, const int padded_groups_per_row,
|
||||
unsigned int* __restrict__ output_s_packed, const int padded_groups_per_row,
|
||||
const int groups_per_row, const int mn, const int output_q_mn_extent,
|
||||
const int tma_aligned_mn, const int64_t num_scale_elems, const float eps,
|
||||
const float min_8bit, const float max_8bit) {
|
||||
@@ -260,27 +271,25 @@ __global__ void per_token_group_quant_8bit_packed_register_kernel(
|
||||
constexpr int VEC_SIZE = 32 / sizeof(T); // 16 for bf16/fp16
|
||||
static_assert(GROUP_SIZE == THREADS_PER_GROUP * VEC_SIZE,
|
||||
"GROUP_SIZE must equal THREADS_PER_GROUP * VEC_SIZE");
|
||||
// Each group's 8 threads must live in a single warp octet so the
|
||||
// 0xffu << (threadIdx.x & 24u) shuffle mask selects exactly the lanes
|
||||
// that share a group. Requires 32 % THREADS_PER_GROUP == 0 and the host
|
||||
// to launch num_threads as a multiple of THREADS_PER_GROUP (which it does
|
||||
// via num_threads = (kx * ry) * THREADS_PER_GROUP).
|
||||
static_assert(32 % THREADS_PER_GROUP == 0,
|
||||
"THREADS_PER_GROUP must divide warp size for the shuffle "
|
||||
"mask to be valid");
|
||||
static_assert(
|
||||
kGroupsPerBlockX > 0 && (kGroupsPerBlockX & (kGroupsPerBlockX - 1)) == 0,
|
||||
"kGroupsPerBlockX must be a positive power of 2");
|
||||
static_assert(kRowsPerBlock > 0, "kRowsPerBlock must be positive");
|
||||
|
||||
const int local_group_id = threadIdx.x / THREADS_PER_GROUP;
|
||||
const int lane_id = threadIdx.x % THREADS_PER_GROUP;
|
||||
|
||||
const int64_t block_group_id = blockIdx.x * groups_per_block;
|
||||
const int64_t global_group_id = block_group_id + local_group_id;
|
||||
if (global_group_id >= num_groups_padded) {
|
||||
const int sf_k_local = local_group_id % kGroupsPerBlockX;
|
||||
const int row_local = local_group_id / kGroupsPerBlockX;
|
||||
const int sf_k_idx = blockIdx.x * kGroupsPerBlockX + sf_k_local;
|
||||
const int mn_idx = blockIdx.y * kRowsPerBlock + row_local;
|
||||
|
||||
if (mn_idx >= tma_aligned_mn) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int sf_k_idx =
|
||||
static_cast<int>(global_group_id % padded_groups_per_row);
|
||||
const int mn_idx = static_cast<int>(global_group_id / padded_groups_per_row);
|
||||
const bool is_valid_group = (mn_idx < mn) && (sf_k_idx < groups_per_row);
|
||||
|
||||
// Load 16 input elements (32 B) into registers as two adjacent uint4
|
||||
@@ -443,34 +452,53 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
|
||||
|
||||
constexpr int THREADS_PER_GROUP = 8;
|
||||
const int64_t padded_groups_per_row = k_num_packed_sfk * 4;
|
||||
const int64_t num_groups_padded = tma_aligned_mn * padded_groups_per_row;
|
||||
const int64_t num_scale_elems = mn + (k_num_packed_sfk - 1) * tma_aligned_mn;
|
||||
const int groups_per_block = GetGroupsPerBlock(num_groups_padded);
|
||||
|
||||
STD_TORCH_CHECK(padded_groups_per_row % 4 == 0,
|
||||
"padded_groups_per_row=", padded_groups_per_row,
|
||||
" is not a multiple of 4.");
|
||||
const int kx = GetGroupsPerBlockX(padded_groups_per_row);
|
||||
const int ry = 16 / kx;
|
||||
const int64_t blocks_x = padded_groups_per_row / kx;
|
||||
const int64_t blocks_y = (tma_aligned_mn + ry - 1) / ry;
|
||||
const int num_threads = (kx * ry) * THREADS_PER_GROUP;
|
||||
// CUDA caps grid.x and grid.y at 2^31 - 1; guard against pathological inputs.
|
||||
STD_TORCH_CHECK(blocks_x <= static_cast<int64_t>(INT32_MAX) &&
|
||||
blocks_y <= static_cast<int64_t>(INT32_MAX),
|
||||
"per_token_group_quant_8bit_packed grid too large: (",
|
||||
blocks_x, ", ", blocks_y, ").");
|
||||
|
||||
auto dst_type = output_q.scalar_type();
|
||||
const int64_t num_blocks = num_groups_padded / groups_per_block;
|
||||
const int num_threads = groups_per_block * THREADS_PER_GROUP;
|
||||
// CUDA caps grid.x at 2^31 - 1; this fits any realistic shape but guard
|
||||
// against pathological inputs.
|
||||
STD_TORCH_CHECK(num_blocks <= static_cast<int64_t>(INT32_MAX),
|
||||
"per_token_group_quant_8bit_packed grid too large: ",
|
||||
num_blocks, " blocks (max ", INT32_MAX, ").");
|
||||
|
||||
#define LAUNCH_REG_KERNEL(T, DST_DTYPE) \
|
||||
do { \
|
||||
dim3 grid(static_cast<unsigned int>(num_blocks)); \
|
||||
dim3 block(num_threads); \
|
||||
per_token_group_quant_8bit_packed_register_kernel<T, DST_DTYPE, 128> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
static_cast<const T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()), \
|
||||
num_groups_padded, groups_per_block, \
|
||||
static_cast<int>(padded_groups_per_row), \
|
||||
static_cast<int>(groups_per_row), static_cast<int>(mn), \
|
||||
static_cast<int>(output_q_mn_extent), \
|
||||
static_cast<int>(tma_aligned_mn), num_scale_elems, \
|
||||
static_cast<float>(eps), static_cast<float>(min_8bit), \
|
||||
static_cast<float>(max_8bit)); \
|
||||
// Launches one instantiation of
// per_token_group_quant_8bit_packed_register_kernel<T, DST_DTYPE, 128, KX, RY>.
//
//   T, DST_DTYPE - input element type and quantized output type.
//   KX, RY       - compile-time values passed as the kernel's
//                  kGroupsPerBlockX and kRowsPerBlock template arguments.
//
// The launch uses a 2-D grid of (blocks_x, blocks_y) blocks with num_threads
// threads per block, no dynamic shared memory, on `stream`. The macro reads
// these names from the enclosing scope: blocks_x, blocks_y, num_threads,
// stream, input, output_q, output_s_packed, padded_groups_per_row,
// groups_per_row, mn, output_q_mn_extent, tma_aligned_mn, num_scale_elems,
// eps, min_8bit and max_8bit. blocks_x / blocks_y are narrowed to
// unsigned int here; the enclosing function checks both against INT32_MAX
// before the launch.
// NOTE(review): no launch-error check is visible inside this macro — confirm
// that launch failures are surfaced elsewhere.
#define LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, KX, RY) \
|
||||
do { \
|
||||
dim3 grid(static_cast<unsigned int>(blocks_x), \
|
||||
static_cast<unsigned int>(blocks_y)); \
|
||||
dim3 block(num_threads); \
|
||||
per_token_group_quant_8bit_packed_register_kernel<T, DST_DTYPE, 128, KX, \
|
||||
RY> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
static_cast<const T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()), \
|
||||
static_cast<int>(padded_groups_per_row), \
|
||||
static_cast<int>(groups_per_row), static_cast<int>(mn), \
|
||||
static_cast<int>(output_q_mn_extent), \
|
||||
static_cast<int>(tma_aligned_mn), num_scale_elems, \
|
||||
static_cast<float>(eps), static_cast<float>(min_8bit), \
|
||||
static_cast<float>(max_8bit)); \
|
||||
} while (0)
|
||||
|
||||
// Maps the runtime value `kx` (read from the enclosing scope) onto a
// compile-time (KX, RY) pair and launches that kernel instantiation through
// LAUNCH_REG_KERNEL_INST. The supported pairs are (16, 1), (8, 2) and (4, 4),
// so KX * RY is always 16 groups per block. Any other kx fails through
// STD_TORCH_CHECK with "Unsupported kx value".
#define LAUNCH_REG_KERNEL(T, DST_DTYPE) \
|
||||
do { \
|
||||
if (kx == 16) { \
|
||||
LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 16, 1); \
|
||||
} else if (kx == 8) { \
|
||||
LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 8, 2); \
|
||||
} else if (kx == 4) { \
|
||||
LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 4, 4); \
|
||||
} else { \
|
||||
STD_TORCH_CHECK(false, "Unsupported kx value ", kx); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||
@@ -488,6 +516,7 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
|
||||
}));
|
||||
|
||||
#undef LAUNCH_REG_KERNEL
|
||||
#undef LAUNCH_REG_KERNEL_INST
|
||||
}
|
||||
|
||||
void per_token_group_quant_fp8(const torch::stable::Tensor& input,
|
||||
|
||||
Reference in New Issue
Block a user