From b4e2621de80e410b9822f498b165ac6446890b7d Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Fri, 15 May 2026 11:58:13 +0200 Subject: [PATCH] add mxfp4 repacking --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 161 +++++++++++++++++- .../vulkan-shaders/dequant_funcs.glsl | 16 ++ .../vulkan-shaders/dequant_funcs_cm2.glsl | 17 +- .../vulkan-shaders/mul_mat_vecq_funcs.glsl | 8 + .../vulkan-shaders/mul_mm_cm2.comp | 5 +- .../vulkan-shaders/mul_mm_funcs.glsl | 6 + .../vulkan-shaders/mul_mmq_funcs.glsl | 8 + .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 +- 8 files changed, 215 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8680afded8..06cfefdaab 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7315,8 +7315,30 @@ static size_t ggml_vk_repack_q4_0_size_tensor(const ggml_tensor * tensor) { return ggml_vk_repack_q4_0_size(ggml_vk_get_num_blocks(tensor)); } +static size_t ggml_vk_repack_mxfp4_scale_offset(size_t n_blocks) { + return GGML_PAD(n_blocks * 16, VULKAN_REPACK_ALIGNMENT); +} + +static size_t ggml_vk_repack_mxfp4_size(size_t n_blocks) { + return ggml_vk_repack_mxfp4_scale_offset(n_blocks) + n_blocks * 1; +} + +static size_t ggml_vk_repack_mxfp4_scale_offset_tensor(const ggml_tensor * tensor) { + return ggml_vk_repack_mxfp4_scale_offset(ggml_vk_get_num_blocks(tensor)); +} + +static size_t ggml_vk_repack_mxfp4_size_tensor(const ggml_tensor * tensor) { + return ggml_vk_repack_mxfp4_size(ggml_vk_get_num_blocks(tensor)); +} + static size_t ggml_vk_repack_size_tensor(const ggml_tensor * tensor) { - return tensor->type == GGML_TYPE_Q4_0 ? ggml_vk_repack_q4_0_size_tensor(tensor) : ggml_nbytes(tensor); + if (tensor->type == GGML_TYPE_Q4_0) { + return ggml_vk_repack_q4_0_size_tensor(tensor); + } + if (tensor->type == GGML_TYPE_MXFP4) { + return ggml_vk_repack_mxfp4_size_tensor(tensor); + } + return ggml_nbytes(tensor); } static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) { @@ -7960,7 +7982,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); } - const uint32_t deltas_offset = src0->type == GGML_TYPE_Q4_0 ? ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2 : 0; + uint32_t deltas_offset = 0; + if (src0->type == GGML_TYPE_Q4_0) { + deltas_offset = ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2; + } else if (src0->type == GGML_TYPE_MXFP4) { + deltas_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(src0); + } // compute ggml_vk_matmul( @@ -8262,7 +8289,12 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1])); - const uint32_t deltas_offset = src0->type == GGML_TYPE_Q4_0 ? ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2 : 0; + uint32_t deltas_offset = 0; + if (src0->type == GGML_TYPE_Q4_0) { + deltas_offset = ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2; + } else if (src0->type == GGML_TYPE_MXFP4) { + deltas_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(src0); + } uint32_t base_work_group_y = 0; while (base_work_group_y < ne12 * ne13) { @@ -8818,7 +8850,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); } - const uint32_t deltas_offset = src0->type == GGML_TYPE_Q4_0 ? ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2 : 0; + uint32_t deltas_offset = 0; + if (src0->type == GGML_TYPE_Q4_0) { + deltas_offset = ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2; + } else if (src0->type == GGML_TYPE_MXFP4) { + deltas_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(src0); + } // compute ggml_vk_matmul_id( @@ -9043,7 +9080,12 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1; } - const uint32_t deltas_offset = src0->type == GGML_TYPE_Q4_0 ? ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2 : 0; + uint32_t deltas_offset = 0; + if (src0->type == GGML_TYPE_Q4_0) { + deltas_offset = ggml_vk_repack_q4_0_delta_offset_tensor(src0) / 2; + } else if (src0->type == GGML_TYPE_MXFP4) { + deltas_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(src0); + } // Loop over the batch dimension for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) { @@ -13862,6 +13904,27 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml return; } + if (tensor->type == GGML_TYPE_MXFP4) { + const size_t repacked_size = ggml_vk_repack_mxfp4_size_tensor(tensor); + const size_t scales_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(tensor); + + void * data_repacked = malloc(repacked_size); + uint8_t * quants = (uint8_t *)data_repacked; + uint8_t * scales = (uint8_t *)data_repacked + scales_offset; + + const block_mxfp4 * src = (const block_mxfp4 *)data; + + for (size_t i = 0; i < ggml_vk_get_num_blocks(tensor); i++) { + memcpy(quants + 16 * i, src[i].qs, 16); + scales[i] = src[i].e; + } + + ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data_repacked, repacked_size); + + free(data_repacked); + return; + } + ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } @@ -13910,6 +13973,27 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons return; } + if (tensor->type == GGML_TYPE_MXFP4) { + const size_t repacked_size = ggml_vk_repack_mxfp4_size_tensor(tensor); + const size_t scales_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(tensor); + + void * data_repacked = malloc(repacked_size); + uint8_t * quants = (uint8_t *)data_repacked; + uint8_t * scales = (uint8_t *)data_repacked + scales_offset; + + ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data_repacked, repacked_size); + + block_mxfp4 * dst = (block_mxfp4 *)data; + + for (size_t i = 0; i < ggml_vk_get_num_blocks(tensor); i++) { + memcpy(dst[i].qs, quants + 16 * i, 16); + dst[i].e = scales[i]; + } + + free(data_repacked); + return; + } + ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } @@ -14009,6 +14093,12 @@ static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_typ return ggml_vk_repack_q4_0_size(num_blocks_per_row * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]); } + if (tensor->type == GGML_TYPE_MXFP4) { + const size_t num_blocks_per_row = tensor->ne[0] / ggml_blck_size(tensor->type); + + return ggml_vk_repack_mxfp4_size(num_blocks_per_row * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]); + } + return ggml_nbytes(tensor); UNUSED(buft); @@ -14159,6 +14249,27 @@ static void ggml_backend_vk_set_tensor_2d_async(ggml_backend_t backend, ggml_ten return; } + if (tensor->type == GGML_TYPE_MXFP4) { + const size_t repacked_size = ggml_vk_repack_mxfp4_size_tensor(tensor); + const size_t scales_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(tensor); + + void * data_repacked = malloc(repacked_size); + uint8_t * quants = (uint8_t *)data_repacked; + uint8_t * scales = (uint8_t *)data_repacked + scales_offset; + + const block_mxfp4 * src = (const block_mxfp4 *)data; + + for (size_t i = 0; i < ggml_vk_get_num_blocks(tensor); i++) { + memcpy(quants + 16 * i, src[i].qs, 16); + scales[i] = src[i].e; + } + + ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data_repacked, repacked_size); + + free(data_repacked); + return; + } + vk_context cpy_ctx; if (ctx->device->async_use_transfer_queue) { @@ -14248,6 +14359,27 @@ static void ggml_backend_vk_get_tensor_2d_async(ggml_backend_t backend, const gg return; } + if (tensor->type == GGML_TYPE_MXFP4) { + const size_t repacked_size = ggml_vk_repack_mxfp4_size_tensor(tensor); + const size_t scales_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(tensor); + + void * data_repacked = malloc(repacked_size); + uint8_t * quants = (uint8_t *)data_repacked; + uint8_t * scales = (uint8_t *)data_repacked + scales_offset; + + ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data_repacked, repacked_size); + + block_mxfp4 * dst = (block_mxfp4 *)data; + + for (size_t i = 0; i < ggml_vk_get_num_blocks(tensor); i++) { + memcpy(dst[i].qs, quants + 16 * i, 16); + dst[i].e = scales[i]; + } + + free(data_repacked); + return; + } + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; @@ -16789,6 +16921,25 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * dst[i].d = deltas[i]; } + free(data_repacked); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (srci->type == GGML_TYPE_MXFP4) { + const size_t repacked_size = ggml_vk_repack_mxfp4_size_tensor(srci); + const size_t scales_offset = ggml_vk_repack_mxfp4_scale_offset_tensor(srci); + + void * data_repacked = malloc(repacked_size); + uint8_t * quants = (uint8_t *)data_repacked; + uint8_t * scales = (uint8_t *)data_repacked + scales_offset; + + ggml_vk_buffer_read(buffer_gpu, offset, data_repacked, repacked_size); + + block_mxfp4 * dst = (block_mxfp4 *)srci_clone->data; + + for (size_t i = 0; i < ggml_vk_get_num_blocks(srci); i++) { + memcpy(dst[i].qs, quants + 16 * i, 16); + dst[i].e = scales[i]; + } + free(data_repacked); memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); } else { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 01c78c4e33..954e56cc19 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -450,6 +450,17 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { #endif #if defined(DATA_A_MXFP4) +#if defined(A_TYPE_REPACKED) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_quants[(a_offset + ib) * 16 + iqs]); + return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5; +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_quants16[(a_offset + ib) * 8 + iqs/2]); + return vec4(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[(vui >> 4) & 0xF], + kvalues_mxfp4[(vui >> 8) & 0xF], kvalues_mxfp4[vui >> 12]) * 0.5; +} +#else vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5; @@ -460,6 +471,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { return vec4(v0.x, v0.y, v1.x, v1.y); } #endif +#endif #if defined(DATA_A_NVFP4) vec2 dequantize(uint ib, uint iqs, uint a_offset) { @@ -514,7 +526,11 @@ vec2 get_dm(uint ib, uint a_offset) { #if defined(DATA_A_MXFP4) vec2 get_dm(uint ib, uint a_offset) { +#if defined(A_TYPE_REPACKED) + return vec2(e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + a_offset + ib])), 0); +#else return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0); +#endif } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 2c485bc842..2f0c5ad1d7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -691,18 +691,33 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor #endif #if defined(DATA_A_MXFP4) +#ifdef A_TYPE_REPACKED +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufMXFP4 { + uint16_t qs[8]; +}; +#else layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 { block_mxfp4 block; }; +#endif float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { - const float d = e8m0_to_fp32(bl.block.e); const uint idx = coordInBlock[1]; +#ifdef A_TYPE_REPACKED + const uint ib = pos_a + blockCoords[0] * (p.stride_a / QUANT_K) + blockCoords[1]; + const float d = e8m0_to_fp32(data_a_scales[p.deltas_offset + ib]); + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = uint32_t(bl.qs[(iqs & 0xE) >> 1]); + qs >>= ((iqs & 1) * 8 + shift); +#else + const float d = e8m0_to_fp32(bl.block.e); const uint iqs = idx & 0xF; const uint shift = (idx & 0x10) >> 2; uint32_t qs = bl.block.qs[iqs]; qs >>= shift; +#endif qs &= 0xF; float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5); return ret; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index ed7b911576..390967c3bd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -22,7 +22,11 @@ FLOAT_TYPEV2 get_dm(uint ib) { #if defined(DATA_A_MXFP4) FLOAT_TYPE get_dm(uint ib) { +#if defined(A_TYPE_REPACKED) + return FLOAT_TYPE(e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + ib]))); +#else return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); +#endif } #endif @@ -123,10 +127,14 @@ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const i #if defined(DATA_A_MXFP4) // 1-byte loads for mxfp4 blocks (17 bytes) i32vec2 repack(uint ib, uint iqs) { +#if defined(A_TYPE_REPACKED) + const uint32_t qs = data_a_quants32[ib * 4 + iqs]; +#else const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], data_a[ib].qs[iqs * 4 + 1], data_a[ib].qs[iqs * 4 + 2], data_a[ib].qs[iqs * 4 + 3])); +#endif const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 3bb3c4af72..096538a6af 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -68,9 +68,10 @@ layout (push_constant) uniform parameter #ifdef A_TYPE_REPACKED -struct block_q4_0_quants { uint16_t qs[8]; }; -layout (binding = 0) readonly buffer A {block_q4_0_quants data_a[];}; +struct block_repacked_quants { uint16_t qs[8]; }; +layout (binding = 0) readonly buffer A {block_repacked_quants data_a[];}; layout (binding = 0) readonly buffer A_DELTAS {float16_t data_a_deltas[];}; +layout (binding = 0) readonly buffer A_SCALES {uint8_t data_a_scales[];}; #else layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index be1397952f..cc7d526654 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -500,9 +500,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint ib = idx / 8; const uint iqs = (idx & 0x07) * 2; +#if defined(A_TYPE_REPACKED) + const float d = e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + ib])) * 0.5; + const uint vui = uint(data_a_quants[ib * 16 + iqs]); + const uint vui2 = uint(data_a_quants[ib * 16 + iqs + 1]); +#else const float d = e8m0_to_fp32(data_a[ib].e) * 0.5; const uint vui = uint(data_a[ib].qs[iqs]); const uint vui2 = uint(data_a[ib].qs[iqs+1]); +#endif buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d, kvalues_mxfp4[vui2 & 0xF] * d); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index d131fbb80a..0e9e3d6912 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -155,10 +155,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { #if defined(DATA_A_MXFP4) // 1-byte loads for mxfp4 blocks (17 bytes) void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { +#if defined(A_TYPE_REPACKED) + const uint32_t qs = data_a_quants32[ib * 4 + iqs]; +#else const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], data_a[ib].qs[iqs * 4 + 1], data_a[ib].qs[iqs * 4 + 2], data_a[ib].qs[iqs * 4 + 3])); +#endif const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); @@ -167,7 +171,11 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])); if (iqs == 0) { +#if defined(A_TYPE_REPACKED) + buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + ib])) * 0.5); +#else buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5); +#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index fd7a2d3eb3..48240bdda5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -565,7 +565,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } std::map mm_base_dict = base_dict; - if (tname == "q4_0") { + if (tname == "q4_0" || tname == "mxfp4") { mm_base_dict["A_TYPE_REPACKED"] = "1"; } @@ -671,7 +671,7 @@ void process_shaders() { for (const auto& tname : type_names) { std::map mmv_base_dict = base_dict; - if (tname == "q4_0") { + if (tname == "q4_0" || tname == "mxfp4") { mmv_base_dict["A_TYPE_REPACKED"] = "1"; }