add mxfp4 repacking

This commit is contained in:
Ruben Ortlam
2026-05-15 11:58:13 +02:00
parent b1243aa933
commit b4e2621de8
8 changed files with 215 additions and 10 deletions
+156 -5
View File
@@ -7315,8 +7315,30 @@ static size_t ggml_vk_repack_q4_0_size_tensor(const ggml_tensor * tensor) {
return ggml_vk_repack_q4_0_size(ggml_vk_get_num_blocks(tensor));
}
static size_t ggml_vk_repack_mxfp4_scale_offset(size_t n_blocks) {
return GGML_PAD(n_blocks * 16, VULKAN_REPACK_ALIGNMENT);
}
static size_t ggml_vk_repack_mxfp4_size(size_t n_blocks) {
return ggml_vk_repack_mxfp4_scale_offset(n_blocks) + n_blocks * 1;
}
static size_t ggml_vk_repack_mxfp4_scale_offset_tensor(const ggml_tensor * tensor) {
return ggml_vk_repack_mxfp4_scale_offset(ggml_vk_get_num_blocks(tensor));
}
static size_t ggml_vk_repack_mxfp4_size_tensor(const ggml_tensor * tensor) {
return ggml_vk_repack_mxfp4_size(ggml_vk_get_num_blocks(tensor));
}
static size_t ggml_vk_repack_size_tensor(const ggml_tensor * tensor) {
return tensor->type == GGML_TYPE_Q4_0 ? ggml_vk_repack_q4_0_size_tensor(tensor) : ggml_nbytes(tensor);
if (tensor->type == GGML_TYPE_Q4_0) {
return ggml_vk_repack_q4_0_size_tensor(tensor);
}
if (tensor->type == GGML_TYPE_MXFP4) {
return ggml_vk_repack_mxfp4_size_tensor(tensor);
}
return ggml_nbytes(tensor);
}
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) {
@@ -7960,7 +7982,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}
const uint32_t deltas_offset = src0->type == GGML_TYPE_Q4_0 ? ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2 : 0;
uint32_t deltas_offset = 0;
if (src0->type == GGML_TYPE_Q4_0) {
deltas_offset = ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2;
} else if (src0->type == GGML_TYPE_MXFP4) {
deltas_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(src0);
}
// compute
ggml_vk_matmul(
@@ -8262,7 +8289,12 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
const uint32_t deltas_offset = src0->type == GGML_TYPE_Q4_0 ? ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2 : 0;
uint32_t deltas_offset = 0;
if (src0->type == GGML_TYPE_Q4_0) {
deltas_offset = ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2;
} else if (src0->type == GGML_TYPE_MXFP4) {
deltas_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(src0);
}
uint32_t base_work_group_y = 0;
while (base_work_group_y < ne12 * ne13) {
@@ -8818,7 +8850,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}
const uint32_t deltas_offset = src0->type == GGML_TYPE_Q4_0 ? ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2 : 0;
uint32_t deltas_offset = 0;
if (src0->type == GGML_TYPE_Q4_0) {
deltas_offset = ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2;
} else if (src0->type == GGML_TYPE_MXFP4) {
deltas_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(src0);
}
// compute
ggml_vk_matmul_id(
@@ -9043,7 +9080,12 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
}
const uint32_t deltas_offset = src0->type == GGML_TYPE_Q4_0 ? ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2 : 0;
uint32_t deltas_offset = 0;
if (src0->type == GGML_TYPE_Q4_0) {
deltas_offset = ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2;
} else if (src0->type == GGML_TYPE_MXFP4) {
deltas_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(src0);
}
// Loop over the batch dimension
for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
@@ -13862,6 +13904,27 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml
return;
}
if (tensor->type == GGML_TYPE_MXFP4) {
const size_t repacked_size = ggml_vk_repack_mxfp4_size_tensor(tensor);
const size_t scales_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(tensor);
void * data_repacked = malloc(repacked_size);
uint8_t * quants = (uint8_t *)data_repacked;
uint8_t * scales = (uint8_t *)data_repacked + scales_offset;
const block_mxfp4 * src = (const block_mxfp4 *)data;
for (size_t i = 0; i < ggml_vk_get_num_blocks(tensor); i++) {
memcpy(quants + 16 * i, src[i].qs, 16);
scales[i] = src[i].e;
}
ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data_repacked, repacked_size);
free(data_repacked);
return;
}
ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
}
@@ -13910,6 +13973,27 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons
return;
}
if (tensor->type == GGML_TYPE_MXFP4) {
const size_t repacked_size = ggml_vk_repack_mxfp4_size_tensor(tensor);
const size_t scales_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(tensor);
void * data_repacked = malloc(repacked_size);
uint8_t * quants = (uint8_t *)data_repacked;
uint8_t * scales = (uint8_t *)data_repacked + scales_offset;
ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data_repacked, repacked_size);
block_mxfp4 * dst = (block_mxfp4 *)data;
for (size_t i = 0; i < ggml_vk_get_num_blocks(tensor); i++) {
memcpy(dst[i].qs, quants + 16 * i, 16);
dst[i].e = scales[i];
}
free(data_repacked);
return;
}
ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
}
@@ -14009,6 +14093,12 @@ static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_typ
return ggml_vk_repack_q4_0_size(num_blocks_per_row * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]);
}
if (tensor->type == GGML_TYPE_MXFP4) {
const size_t num_blocks_per_row = tensor->ne[0] / ggml_blck_size(tensor->type);
return ggml_vk_repack_mxfp4_size(num_blocks_per_row * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]);
}
return ggml_nbytes(tensor);
UNUSED(buft);
@@ -14159,6 +14249,27 @@ static void ggml_backend_vk_set_tensor_2d_async(ggml_backend_t backend, ggml_ten
return;
}
if (tensor->type == GGML_TYPE_MXFP4) {
const size_t repacked_size = ggml_vk_repack_mxfp4_size_tensor(tensor);
const size_t scales_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(tensor);
void * data_repacked = malloc(repacked_size);
uint8_t * quants = (uint8_t *)data_repacked;
uint8_t * scales = (uint8_t *)data_repacked + scales_offset;
const block_mxfp4 * src = (const block_mxfp4 *)data;
for (size_t i = 0; i < ggml_vk_get_num_blocks(tensor); i++) {
memcpy(quants + 16 * i, src[i].qs, 16);
scales[i] = src[i].e;
}
ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data_repacked, repacked_size);
free(data_repacked);
return;
}
vk_context cpy_ctx;
if (ctx->device->async_use_transfer_queue) {
@@ -14248,6 +14359,27 @@ static void ggml_backend_vk_get_tensor_2d_async(ggml_backend_t backend, const gg
return;
}
if (tensor->type == GGML_TYPE_MXFP4) {
const size_t repacked_size = ggml_vk_repack_mxfp4_size_tensor(tensor);
const size_t scales_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(tensor);
void * data_repacked = malloc(repacked_size);
uint8_t * quants = (uint8_t *)data_repacked;
uint8_t * scales = (uint8_t *)data_repacked + scales_offset;
ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data_repacked, repacked_size);
block_mxfp4 * dst = (block_mxfp4 *)data;
for (size_t i = 0; i < ggml_vk_get_num_blocks(tensor); i++) {
memcpy(dst[i].qs, quants + 16 * i, 16);
dst[i].e = scales[i];
}
free(data_repacked);
return;
}
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
@@ -16789,6 +16921,25 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
dst[i].d = deltas[i];
}
free(data_repacked);
memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
} else if (srci->type == GGML_TYPE_MXFP4) {
const size_t repacked_size = ggml_vk_repack_mxfp4_size_tensor(srci);
const size_t scales_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(srci);
void * data_repacked = malloc(repacked_size);
uint8_t * quants = (uint8_t *)data_repacked;
uint8_t * scales = (uint8_t *)data_repacked + scales_offset;
ggml_vk_buffer_read(buffer_gpu, offset, data_repacked, repacked_size);
block_mxfp4 * dst = (block_mxfp4 *)srci_clone->data;
for (size_t i = 0; i < ggml_vk_get_num_blocks(srci); i++) {
memcpy(dst[i].qs, quants + 16 * i, 16);
dst[i].e = scales[i];
}
free(data_repacked);
memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
} else {
@@ -450,6 +450,17 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
#endif
#if defined(DATA_A_MXFP4)
#if defined(A_TYPE_REPACKED)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_quants[(a_offset + ib) * 16 + iqs]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_quants16[(a_offset + ib) * 8 + iqs/2]);
return vec4(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[(vui >> 4) & 0xF],
kvalues_mxfp4[(vui >> 8) & 0xF], kvalues_mxfp4[vui >> 12]) * 0.5;
}
#else
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
@@ -460,6 +471,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
return vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#endif
#if defined(DATA_A_NVFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
@@ -514,7 +526,11 @@ vec2 get_dm(uint ib, uint a_offset) {
#if defined(DATA_A_MXFP4)
vec2 get_dm(uint ib, uint a_offset) {
#if defined(A_TYPE_REPACKED)
return vec2(e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + a_offset + ib])), 0);
#else
return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);
#endif
}
#endif
@@ -691,18 +691,33 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
#endif
#if defined(DATA_A_MXFP4)
#ifdef A_TYPE_REPACKED
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufMXFP4 {
uint16_t qs[8];
};
#else
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
block_mxfp4 block;
};
#endif
float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float d = e8m0_to_fp32(bl.block.e);
const uint idx = coordInBlock[1];
#ifdef A_TYPE_REPACKED
const uint ib = pos_a + blockCoords[0] * (p.stride_a / QUANT_K) + blockCoords[1];
const float d = e8m0_to_fp32(data_a_scales[p.deltas_offset + ib]);
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = uint32_t(bl.qs[(iqs & 0xE) >> 1]);
qs >>= ((iqs & 1) * 8 + shift);
#else
const float d = e8m0_to_fp32(bl.block.e);
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
#endif
qs &= 0xF;
float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
return ret;
@@ -22,7 +22,11 @@ FLOAT_TYPEV2 get_dm(uint ib) {
#if defined(DATA_A_MXFP4)
FLOAT_TYPE get_dm(uint ib) {
#if defined(A_TYPE_REPACKED)
return FLOAT_TYPE(e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + ib])));
#else
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
#endif
}
#endif
@@ -123,10 +127,14 @@ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const i
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
#if defined(A_TYPE_REPACKED)
const uint32_t qs = data_a_quants32[ib * 4 + iqs];
#else
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
#endif
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
@@ -68,9 +68,10 @@ layout (push_constant) uniform parameter
#ifdef A_TYPE_REPACKED
struct block_q4_0_quants { uint16_t qs[8]; };
layout (binding = 0) readonly buffer A {block_q4_0_quants data_a[];};
struct block_repacked_quants { uint16_t qs[8]; };
layout (binding = 0) readonly buffer A {block_repacked_quants data_a[];};
layout (binding = 0) readonly buffer A_DELTAS {float16_t data_a_deltas[];};
layout (binding = 0) readonly buffer A_SCALES {uint8_t data_a_scales[];};
#else
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#endif
@@ -500,9 +500,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;
#if defined(A_TYPE_REPACKED)
const float d = e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + ib])) * 0.5;
const uint vui = uint(data_a_quants[ib * 16 + iqs]);
const uint vui2 = uint(data_a_quants[ib * 16 + iqs + 1]);
#else
const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
const uint vui = uint(data_a[ib].qs[iqs]);
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
#endif
buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d,
kvalues_mxfp4[vui2 & 0xF] * d);
@@ -155,10 +155,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#if defined(A_TYPE_REPACKED)
const uint32_t qs = data_a_quants32[ib * 4 + iqs];
#else
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
#endif
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
@@ -167,7 +171,11 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));
if (iqs == 0) {
#if defined(A_TYPE_REPACKED)
buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + ib])) * 0.5);
#else
buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
#endif
}
}
@@ -565,7 +565,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
std::map<std::string, std::string> mm_base_dict = base_dict;
if (tname == "q4_0") {
if (tname == "q4_0" || tname == "mxfp4") {
mm_base_dict["A_TYPE_REPACKED"] = "1";
}
@@ -671,7 +671,7 @@ void process_shaders() {
for (const auto& tname : type_names) {
std::map<std::string, std::string> mmv_base_dict = base_dict;
if (tname == "q4_0") {
if (tname == "q4_0" || tname == "mxfp4") {
mmv_base_dict["A_TYPE_REPACKED"] = "1";
}