mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Quantization][Deprecation] Remove Marlin 24 (#32688)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -458,7 +458,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
set(MARLIN_SRCS
|
||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/marlin.cu"
|
||||
"csrc/quantization/marlin/marlin_int4_fp8_preprocess.cu"
|
||||
"csrc/quantization/marlin/gptq_marlin_repack.cu"
|
||||
|
||||
@@ -6,12 +6,6 @@ import torch.utils.benchmark as benchmark
|
||||
from benchmark_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL,
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
|
||||
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||
ALLSPARK_SUPPORTED_QUANT_TYPES,
|
||||
@@ -34,9 +28,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
awq_marlin_quantize,
|
||||
marlin_quantize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||
marlin_24_quantize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
gptq_pack,
|
||||
gptq_quantize_weights,
|
||||
@@ -78,14 +69,7 @@ def bench_run(
|
||||
if size_k % group_size != 0:
|
||||
return
|
||||
|
||||
marlin_24_supported = (
|
||||
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
|
||||
)
|
||||
repack_supported = (
|
||||
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||
and group_size in MARLIN_SUPPORTED_GROUP_SIZES
|
||||
)
|
||||
repack_supported = group_size in MARLIN_SUPPORTED_GROUP_SIZES
|
||||
allspark_supported = (
|
||||
quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
|
||||
and group_size == -1
|
||||
@@ -126,14 +110,6 @@ def bench_run(
|
||||
marlin_sort_indices,
|
||||
)
|
||||
|
||||
def gen_marlin_24_params():
|
||||
marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None
|
||||
if marlin_24_supported:
|
||||
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
|
||||
marlin_24_quantize(b, quant_type, group_size)
|
||||
)
|
||||
return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s)
|
||||
|
||||
def gen_repack_params():
|
||||
q_w_gptq = None
|
||||
repack_sort_indices = None
|
||||
@@ -188,9 +164,6 @@ def bench_run(
|
||||
marlin_g_idx,
|
||||
marlin_sort_indices,
|
||||
) = gen_marlin_params()
|
||||
marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = (
|
||||
gen_marlin_24_params()
|
||||
)
|
||||
q_w_gptq, repack_sort_indices = gen_repack_params()
|
||||
qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = (
|
||||
gen_allspark_params()
|
||||
@@ -200,9 +173,6 @@ def bench_run(
|
||||
marlin_workspace = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
||||
)
|
||||
marlin_24_workspace = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
)
|
||||
|
||||
globals = {
|
||||
# Gen params
|
||||
@@ -222,12 +192,6 @@ def bench_run(
|
||||
"marlin_sort_indices": marlin_sort_indices,
|
||||
"marlin_workspace": marlin_workspace,
|
||||
"is_k_full": is_k_full,
|
||||
# Marlin_24 params
|
||||
"marlin_24_w_ref": marlin_24_w_ref,
|
||||
"marlin_24_q_w_comp": marlin_24_q_w_comp,
|
||||
"marlin_24_meta": marlin_24_meta,
|
||||
"marlin_24_s": marlin_24_s,
|
||||
"marlin_24_workspace": marlin_24_workspace,
|
||||
# GPTQ params
|
||||
"q_w_gptq": q_w_gptq,
|
||||
"repack_sort_indices": repack_sort_indices,
|
||||
@@ -240,7 +204,6 @@ def bench_run(
|
||||
"CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
|
||||
# Kernels
|
||||
"marlin_gemm": ops.marlin_gemm,
|
||||
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
|
||||
"gptq_marlin_repack": ops.gptq_marlin_repack,
|
||||
"allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
|
||||
}
|
||||
@@ -281,17 +244,6 @@ def bench_run(
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
if marlin_24_supported:
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_24_gemm",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
if repack_supported:
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
|
||||
@@ -1,203 +0,0 @@
|
||||
Contains code from https://github.com/IST-DASLab/Sparse-Marlin/
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -1,51 +0,0 @@
|
||||
/*
|
||||
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
|
||||
* Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace marlin_24 {
|
||||
|
||||
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
|
||||
// for instance as inputs to tensor core operations. Consequently, all
|
||||
// corresponding index accesses must be compile-time constants, which is why we
|
||||
// extensively use `#pragma unroll` throughout the kernel code to guarantee
|
||||
// this.
|
||||
template <typename T, int n>
|
||||
struct Vec {
|
||||
T elems[n];
|
||||
__device__ T& operator[](int i) { return elems[i]; }
|
||||
};
|
||||
|
||||
template <int M_, int N_, int K_>
|
||||
struct ShapeBase {
|
||||
static constexpr int M = M_, N = N_, K = K_;
|
||||
};
|
||||
|
||||
using I4 = Vec<int, 4>;
|
||||
|
||||
// Matrix fragments for tensor core instructions; their precise layout is
|
||||
// documented here:
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
|
||||
using FragA = Vec<half2, 4>;
|
||||
using FragB = Vec<half2, 2>;
|
||||
using FragM = Vec<uint, 1>;
|
||||
using FragC = Vec<float, 4>;
|
||||
using FragS = Vec<half2, 1>; // quantization scales
|
||||
|
||||
} // namespace marlin_24
|
||||
@@ -1,136 +0,0 @@
|
||||
/*
|
||||
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
|
||||
* Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "base.h"
|
||||
|
||||
namespace marlin_24 {
|
||||
// Predicated asynchronous global->shared copy; used for inputs A where we apply
|
||||
// predication to handle batchsizes that are not multiples of 16.
|
||||
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
|
||||
const void* glob_ptr,
|
||||
bool pred = true,
|
||||
const bool zfill = false) {
|
||||
const int BYTES = 16;
|
||||
int src_in_bytes = (zfill ? 0 : BYTES);
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
// Asynchronous global->shared copy
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
// Async copy fence.
|
||||
__device__ inline void cp_async_fence() {
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
}
|
||||
|
||||
// Wait until at most `n` async copy stages are still pending.
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
|
||||
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
|
||||
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
|
||||
: "r"(smem));
|
||||
}
|
||||
|
||||
__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
|
||||
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
|
||||
: "=r"(a[0]), "=r"(a[1])
|
||||
: "r"(smem));
|
||||
}
|
||||
|
||||
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
|
||||
// memory, directly in tensor core layout.
|
||||
__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
|
||||
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
|
||||
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
|
||||
: "r"(smem));
|
||||
}
|
||||
|
||||
// Wait until barrier reaches `count`, then lock for current threadblock.
|
||||
__device__ inline void barrier_acquire(int* lock, int count) {
|
||||
if (threadIdx.x == 0) {
|
||||
int state = -1;
|
||||
do
|
||||
// Guarantee that subsequent writes by this threadblock will be visible
|
||||
// globally.
|
||||
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
|
||||
: "=r"(state)
|
||||
: "l"(lock));
|
||||
while (state != count);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Release barrier and increment visitation count.
|
||||
__device__ inline void barrier_release(int* lock, bool reset = false) {
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
if (reset) {
|
||||
lock[0] = 0;
|
||||
return;
|
||||
}
|
||||
int val = 1;
|
||||
// Make sure that all writes since acquiring this barrier are visible
|
||||
// globally, while releasing the barrier.
|
||||
asm volatile("fence.acq_rel.gpu;\n");
|
||||
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
|
||||
:
|
||||
: "l"(lock), "r"(val));
|
||||
}
|
||||
}
|
||||
} // namespace marlin_24
|
||||
@@ -1,191 +0,0 @@
|
||||
/*
|
||||
* Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
|
||||
* Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "base.h"
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
namespace marlin_24 {
|
||||
|
||||
// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
|
||||
// is not supported. On later versions of CUDA the version without ordered
|
||||
// metadata results in the following warning:
|
||||
// | Advisory: Modifier ‘.sp::ordered_metadata’ should be used on instruction
|
||||
// | ‘mma’ instead of modifier ‘.sp’ as it is expected to have substantially
|
||||
// | reduced performance on some future architectures
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
|
||||
#define MMA_SP_INST \
|
||||
"mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||
#else
|
||||
#define MMA_SP_INST "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
|
||||
#endif
|
||||
|
||||
// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
|
||||
// output/accumulation.
|
||||
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
|
||||
const FragA& frag_b, FragC& frag_c, FragM& frag_m,
|
||||
const int psel) {
|
||||
const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
|
||||
const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
|
||||
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
|
||||
const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
|
||||
|
||||
float* c = reinterpret_cast<float*>(&frag_c);
|
||||
if (psel == 0) {
|
||||
asm volatile(MMA_SP_INST
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x0;\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
|
||||
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
|
||||
"f"(c[2]), "f"(c[3]), "r"(e[0]));
|
||||
asm volatile(MMA_SP_INST
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x0;\n"
|
||||
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
|
||||
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
|
||||
"f"(c[6]), "f"(c[7]), "r"(e[0]));
|
||||
} else {
|
||||
asm volatile(MMA_SP_INST
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x1;\n"
|
||||
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
|
||||
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
|
||||
"f"(c[2]), "f"(c[3]), "r"(e[0]));
|
||||
asm volatile(MMA_SP_INST
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
|
||||
"{%12,%13,%14,%15}, %16, 0x1;\n"
|
||||
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
|
||||
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
|
||||
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
|
||||
"f"(c[6]), "f"(c[7]), "r"(e[0]));
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup-table based 3-input logical operation; explicitly used for
|
||||
// dequantization as the compiler does not seem to automatically recognize it in
|
||||
// all cases.
|
||||
template <int lut>
|
||||
__device__ inline int lop3(int a, int b, int c) {
|
||||
int res;
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
||||
return res;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
|
||||
float c3) {
|
||||
uint2 r;
|
||||
asm("{\n\t"
|
||||
".reg .f16 a, b, c, d; \n\t"
|
||||
"cvt.rn.f16.f32 a, %2; \n\t"
|
||||
"cvt.rn.f16.f32 b, %3; \n\t"
|
||||
"cvt.rn.f16.f32 c, %4; \n\t"
|
||||
"cvt.rn.f16.f32 d, %5; \n\t"
|
||||
"mov.b32 %0, {a, b}; \n\t"
|
||||
"mov.b32 %1, {c, d}; \n\t"
|
||||
"}"
|
||||
: "=r"(r.x), "=r"(r.y)
|
||||
: "f"(c0), "f"(c1), "f"(c2), "f"(c3));
|
||||
return r;
|
||||
}
|
||||
|
||||
// Constructs destination register by taking bytes from 2 sources (based on
|
||||
// mask)
|
||||
template <int start_byte, int mask>
|
||||
__device__ inline uint32_t prmt(uint32_t a) {
|
||||
uint32_t res;
|
||||
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "n"(start_byte), "n"(mask));
|
||||
return res;
|
||||
}
|
||||
|
||||
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
|
||||
// values. We mostly follow the strategy in the link below, with some small
|
||||
// changes:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
__device__ inline FragB dequant_4bit(int q) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd480d480;
|
||||
|
||||
FragB frag_b;
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&SUB));
|
||||
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&MUL),
|
||||
*reinterpret_cast<const half2*>(&ADD));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
|
||||
// values. We mostly follow the strategy in the link below, with some small
|
||||
// changes:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
__device__ inline FragB dequant_8bit(int q) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
|
||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
|
||||
|
||||
FragB frag_b;
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// Multiply dequantized values by the corresponding quantization scale; used
|
||||
// only for grouped quantization.
|
||||
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
|
||||
half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
|
||||
frag_b[0] = __hmul2(frag_b[0], s);
|
||||
frag_b[1] = __hmul2(frag_b[1], s);
|
||||
}
|
||||
|
||||
__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
|
||||
FragS& s0, float* c4, float* c5, float* c6,
|
||||
float* c7, FragS& s1) {
|
||||
*c0 = __fmul_rn(*c0, __half2float(s0[0].x));
|
||||
*c1 = __fmul_rn(*c1, __half2float(s0[0].y));
|
||||
*c2 = __fmul_rn(*c2, __half2float(s0[1].x));
|
||||
*c3 = __fmul_rn(*c3, __half2float(s0[1].y));
|
||||
|
||||
*c4 = __fmul_rn(*c4, __half2float(s1[0].x));
|
||||
*c5 = __fmul_rn(*c5, __half2float(s1[0].y));
|
||||
*c6 = __fmul_rn(*c6, __half2float(s1[1].x));
|
||||
*c7 = __fmul_rn(*c7, __half2float(s1[1].y));
|
||||
}
|
||||
|
||||
} // namespace marlin_24
|
||||
File diff suppressed because it is too large
Load Diff
@@ -259,14 +259,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// custom types:
|
||||
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
|
||||
|
||||
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
|
||||
ops.def(
|
||||
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
|
||||
"Tensor b_scales, Tensor workspace, "
|
||||
"int b_q_type, "
|
||||
"SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
|
||||
ops.def(
|
||||
"machete_supported_schedules("
|
||||
|
||||
@@ -58,17 +58,6 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None):
|
||||
)
|
||||
)
|
||||
|
||||
if is_quant_method_supported("gptq_marlin_24"):
|
||||
TEST_MODELS.append(
|
||||
(
|
||||
"alexm-nm/tinyllama-24-marlin24-4bit-g128",
|
||||
{
|
||||
"quantization": "gptq_marlin_24",
|
||||
"allow_deprecated_quantization": True,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
|
||||
TEST_MODELS.append(
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"})
|
||||
|
||||
@@ -10,15 +10,9 @@ import itertools
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
||||
from tests.kernels.utils import opcheck
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL,
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
|
||||
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_quant_int8,
|
||||
)
|
||||
@@ -36,15 +30,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
marlin_quant_fp8_torch,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace,
|
||||
awq_marlin_quantize,
|
||||
get_weight_perm,
|
||||
marlin_quantize,
|
||||
marlin_weights,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||
marlin_24_quantize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
awq_pack,
|
||||
gptq_pack,
|
||||
@@ -57,9 +47,7 @@ from vllm.scalar_type import scalar_types
|
||||
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"These tests require gptq_marlin_repack,"
|
||||
"marlin_int4_fp8_preprocess, gptq_marlin_24_gemm,"
|
||||
"or marlin_gemm which are not supported on ROCm.",
|
||||
"These tests require marlin, which is not supported on ROCm.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
@@ -71,9 +59,6 @@ USE_FP32_REDUCE_OPTS = [True]
|
||||
MARLIN_K_CHUNKS = [128]
|
||||
MARLIN_N_CHUNKS = [64, 256]
|
||||
|
||||
MARLIN_24_K_CHUNKS = [128]
|
||||
MARLIN_24_N_CHUNKS = [512]
|
||||
|
||||
MARLIN_REPACK_NK_FACTORS = [
|
||||
(4, 8),
|
||||
(7, 5),
|
||||
@@ -538,96 +523,6 @@ def test_marlin_gemm(
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
# TODO: find better way to test this?
|
||||
@torch.compile(fullgraph=True)
|
||||
def marlin_24_gemm_tester(
|
||||
a_input,
|
||||
marlin_24_q_w_comp,
|
||||
marlin_24_meta,
|
||||
marlin_24_s,
|
||||
scratch,
|
||||
quant_type,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
):
|
||||
return ops.gptq_marlin_24_gemm(
|
||||
a_input,
|
||||
marlin_24_q_w_comp,
|
||||
marlin_24_meta,
|
||||
marlin_24_s,
|
||||
scratch,
|
||||
quant_type,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
||||
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
(w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
|
||||
b_weight, quant_type, group_size
|
||||
)
|
||||
|
||||
workspace_24 = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
)
|
||||
|
||||
output_ref = torch.matmul(a_input, w_24_ref)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.gptq_marlin_24_gemm,
|
||||
(
|
||||
a_input,
|
||||
marlin_24_q_w_comp,
|
||||
marlin_24_meta,
|
||||
marlin_24_s,
|
||||
workspace_24.scratch,
|
||||
quant_type.id,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
output = marlin_24_gemm_tester(
|
||||
a_input,
|
||||
marlin_24_q_w_comp,
|
||||
marlin_24_meta,
|
||||
marlin_24_s,
|
||||
workspace_24.scratch,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
def test_marlin_gemm_subset_input():
|
||||
quant_type = scalar_types.uint4b8
|
||||
group_size = 128
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Compare the outputs of a GPTQ model to a Marlin_24 model.
|
||||
|
||||
Note: GPTQ and Marlin_24 do not have bitwise correctness.
|
||||
As a result, in this test, we just confirm that the top selected tokens of the
|
||||
Marlin/GPTQ models are in the top 3 selections of each other.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import check_logprobs_close
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelPair:
|
||||
model_marlin: str
|
||||
model_gptq: str
|
||||
|
||||
|
||||
model_pairs = [
|
||||
# 4-bit, group_size == 128
|
||||
ModelPair(
|
||||
model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-g128",
|
||||
model_gptq="alexm-nm/tinyllama-24-gptq-4bit-g128",
|
||||
),
|
||||
# # 4-bit, group_size == channelwise
|
||||
# ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-channelwise",
|
||||
# model_gptq="alexm-nm/tinyllama-24-gptq-4bit-channelwise"),
|
||||
# 8-bit, group_size == 128
|
||||
ModelPair(
|
||||
model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-g128",
|
||||
model_gptq="alexm-nm/tinyllama-24-gptq-8bit-g128",
|
||||
),
|
||||
# # 8-bit, group_size == channelwise
|
||||
# ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-channelwise",
|
||||
# model_gptq="alexm-nm/tinyllama-24-gptq-8bit-channelwise"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=2)
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin_24")
|
||||
or current_platform.is_rocm()
|
||||
or not current_platform.is_cuda(),
|
||||
reason="Marlin24 is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("model_pair", model_pairs)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [8])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model_pair: ModelPair,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
with vllm_runner(
|
||||
model_pair.model_marlin,
|
||||
dtype=dtype,
|
||||
quantization="gptq_marlin_24",
|
||||
allow_deprecated_quantization=True,
|
||||
) as marlin_24_model:
|
||||
marlin_24_outputs = marlin_24_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
with vllm_runner(
|
||||
model_pair.model_gptq, dtype=dtype, quantization="gptq"
|
||||
) as gptq_model:
|
||||
gptq_outputs = gptq_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=gptq_outputs,
|
||||
outputs_1_lst=marlin_24_outputs,
|
||||
name_0="gptq",
|
||||
name_1="marlin_24",
|
||||
)
|
||||
@@ -17,7 +17,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
||||
CompressedTensorsW4A4Fp4,
|
||||
CompressedTensorsW4A8Fp8,
|
||||
CompressedTensorsW4A16Fp4,
|
||||
CompressedTensorsW4A16Sparse24,
|
||||
CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8,
|
||||
@@ -307,28 +306,6 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
|
||||
)
|
||||
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
|
||||
model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
|
||||
with vllm_runner(model_path, enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
layer = model.model.layers[0]
|
||||
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
|
||||
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
||||
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
|
||||
assert qkv_proj.weight_packed.dtype is torch.int32
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
output = llm.generate_greedy("Hello my name is", max_tokens=4)
|
||||
assert output
|
||||
|
||||
|
||||
def test_compressed_tensors_fp8(vllm_runner):
|
||||
model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
|
||||
with vllm_runner(model_path, enforce_eager=True) as llm:
|
||||
|
||||
+143
-150
@@ -499,6 +499,23 @@ def awq_dequantize(
|
||||
return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "awq_dequantize"):
|
||||
|
||||
@register_fake("_C::awq_dequantize")
|
||||
def _awq_dequantize_fake(
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
zeros: torch.Tensor,
|
||||
split_k_iters: torch.SymInt,
|
||||
thx: int,
|
||||
thy: int,
|
||||
) -> torch.Tensor:
|
||||
in_c = qweight.size(0)
|
||||
qout_c = qweight.size(1)
|
||||
out_c = qout_c * 8
|
||||
return torch.empty((in_c, out_c), dtype=scales.dtype, device=scales.device)
|
||||
|
||||
|
||||
def awq_gemm(
|
||||
input: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
@@ -513,6 +530,24 @@ def awq_gemm(
|
||||
return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "awq_gemm"):
|
||||
|
||||
@register_fake("_C::awq_gemm")
|
||||
def _awq_gemm_fake(
|
||||
input: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: torch.Tensor,
|
||||
split_k_iters: torch.SymInt,
|
||||
) -> torch.Tensor:
|
||||
num_in_feats = input.size(0)
|
||||
return torch.empty(
|
||||
(split_k_iters, num_in_feats, qweight.size(1) * 8),
|
||||
dtype=input.dtype,
|
||||
device=input.device,
|
||||
).sum(0)
|
||||
|
||||
|
||||
# gptq
|
||||
def gptq_gemm(
|
||||
a: torch.Tensor,
|
||||
@@ -558,152 +593,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None
|
||||
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
|
||||
|
||||
|
||||
# marlin_24
|
||||
def gptq_marlin_24_gemm(
|
||||
a: torch.Tensor,
|
||||
b_q_weight: torch.Tensor,
|
||||
b_meta: torch.Tensor,
|
||||
b_scales: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
b_q_type: ScalarType,
|
||||
size_m: int,
|
||||
size_n: int,
|
||||
size_k: int,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops._C.gptq_marlin_24_gemm(
|
||||
a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k
|
||||
)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
|
||||
@register_fake("_C::gptq_marlin_24_gemm")
|
||||
def _gptq_marlin_24_gemm_fake(
|
||||
a: torch.Tensor,
|
||||
b_q_weight: torch.Tensor,
|
||||
b_meta: torch.Tensor,
|
||||
b_scales: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
b_q_type: ScalarType,
|
||||
size_m: torch.SymInt,
|
||||
size_n: torch.SymInt,
|
||||
size_k: torch.SymInt,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
|
||||
|
||||
@register_fake("_C::marlin_gemm")
|
||||
def _marlin_gemm_fake(
|
||||
a: torch.Tensor,
|
||||
c: torch.Tensor | None,
|
||||
b_q_weight: torch.Tensor,
|
||||
b_bias: torch.Tensor | None,
|
||||
b_scales: torch.Tensor,
|
||||
a_scales: torch.Tensor | None,
|
||||
global_scale: torch.Tensor | None,
|
||||
b_zeros: torch.Tensor | None,
|
||||
g_idx: torch.Tensor | None,
|
||||
perm: torch.Tensor | None,
|
||||
workspace: torch.Tensor,
|
||||
b_q_type_id: int,
|
||||
size_m: torch.SymInt,
|
||||
size_n: torch.SymInt,
|
||||
size_k: torch.SymInt,
|
||||
is_k_full: bool = True,
|
||||
use_atomic_add: bool = False,
|
||||
use_fp32_reduce: bool = False,
|
||||
is_zp_float: bool = False,
|
||||
) -> torch.Tensor:
|
||||
dtype = a.dtype
|
||||
if dtype not in [torch.half, torch.bfloat16]:
|
||||
dtype = b_scales.dtype
|
||||
return torch.empty((size_m, size_n), device=a.device, dtype=dtype)
|
||||
|
||||
@register_fake("_C::awq_dequantize")
|
||||
def _awq_dequantize_fake(
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
zeros: torch.Tensor,
|
||||
split_k_iters: torch.SymInt,
|
||||
thx: int,
|
||||
thy: int,
|
||||
) -> torch.Tensor:
|
||||
in_c = qweight.size(0)
|
||||
qout_c = qweight.size(1)
|
||||
out_c = qout_c * 8
|
||||
return torch.empty((in_c, out_c), dtype=scales.dtype, device=scales.device)
|
||||
|
||||
@register_fake("_C::awq_gemm")
|
||||
def _awq_gemm_fake(
|
||||
input: torch.Tensor,
|
||||
qweight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: torch.Tensor,
|
||||
split_k_iters: torch.SymInt,
|
||||
) -> torch.Tensor:
|
||||
num_in_feats = input.size(0)
|
||||
return torch.empty(
|
||||
(split_k_iters, num_in_feats, qweight.size(1) * 8),
|
||||
dtype=input.dtype,
|
||||
device=input.device,
|
||||
).sum(0)
|
||||
|
||||
@register_fake("_C::machete_mm")
|
||||
def machete_mm_fake(
|
||||
a: torch.Tensor,
|
||||
# b_q Should be the tensor returned by machete_prepack_B
|
||||
b_q: torch.Tensor,
|
||||
b_type: ScalarType,
|
||||
out_type: torch.dtype | None = None,
|
||||
b_group_scales: torch.Tensor | None = None,
|
||||
b_group_zeros: torch.Tensor | None = None,
|
||||
b_group_size: int | None = None,
|
||||
b_channel_scales: torch.Tensor | None = None,
|
||||
a_token_scales: torch.Tensor | None = None,
|
||||
schedule: str | None = None,
|
||||
) -> torch.Tensor:
|
||||
m = a.size(0)
|
||||
n = b_q.size(1)
|
||||
return torch.empty((m, n), device=a.device, dtype=a.dtype)
|
||||
|
||||
@register_fake("_C::machete_prepack_B")
|
||||
def machete_prepack_B_fake(
|
||||
b_q_weight: torch.Tensor,
|
||||
a_type: torch.dtype,
|
||||
b_type: ScalarType,
|
||||
group_scales_type: torch.dtype | None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format)
|
||||
|
||||
@register_fake("_C::cutlass_w4a8_mm")
|
||||
def cutlass_w4a8_mm_fake(
|
||||
a: torch.Tensor,
|
||||
# b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b
|
||||
b_q: torch.Tensor,
|
||||
b_group_scales: torch.Tensor,
|
||||
b_group_size: int,
|
||||
b_channel_scales: torch.Tensor,
|
||||
a_token_scales: torch.Tensor,
|
||||
out_type: torch.dtype | None = None,
|
||||
maybe_schedule: str | None = None,
|
||||
) -> torch.Tensor:
|
||||
m = a.size(0)
|
||||
n = b_q.size(1)
|
||||
out_dtype = out_type if out_type is not None else torch.bfloat16
|
||||
return torch.empty((m, n), device=a.device, dtype=out_dtype)
|
||||
|
||||
@register_fake("_C::cutlass_pack_scale_fp8")
|
||||
def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(scales, memory_format=torch.contiguous_format)
|
||||
|
||||
@register_fake("_C::cutlass_encode_and_reorder_int4b")
|
||||
def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(b, memory_format=torch.contiguous_format)
|
||||
|
||||
@register_fake("_C::cutlass_encode_and_reorder_int4b_grouped")
|
||||
def cutlass_encode_and_reorder_int4b_grouped_fake(b: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(b, memory_format=torch.contiguous_format)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
|
||||
|
||||
@register_fake("_C::allspark_w8a16_gemm")
|
||||
@@ -1356,6 +1245,36 @@ def marlin_gemm(
|
||||
)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "marlin_gemm"):
|
||||
|
||||
@register_fake("_C::marlin_gemm")
|
||||
def _marlin_gemm_fake(
|
||||
a: torch.Tensor,
|
||||
c: torch.Tensor | None,
|
||||
b_q_weight: torch.Tensor,
|
||||
b_bias: torch.Tensor | None,
|
||||
b_scales: torch.Tensor,
|
||||
a_scales: torch.Tensor | None,
|
||||
global_scale: torch.Tensor | None,
|
||||
b_zeros: torch.Tensor | None,
|
||||
g_idx: torch.Tensor | None,
|
||||
perm: torch.Tensor | None,
|
||||
workspace: torch.Tensor,
|
||||
b_q_type_id: int,
|
||||
size_m: torch.SymInt,
|
||||
size_n: torch.SymInt,
|
||||
size_k: torch.SymInt,
|
||||
is_k_full: bool = True,
|
||||
use_atomic_add: bool = False,
|
||||
use_fp32_reduce: bool = False,
|
||||
is_zp_float: bool = False,
|
||||
) -> torch.Tensor:
|
||||
dtype = a.dtype
|
||||
if dtype not in [torch.half, torch.bfloat16]:
|
||||
dtype = b_scales.dtype
|
||||
return torch.empty((size_m, size_n), device=a.device, dtype=dtype)
|
||||
|
||||
|
||||
# machete
|
||||
def machete_supported_schedules(
|
||||
a_type: torch.dtype,
|
||||
@@ -1404,6 +1323,27 @@ def machete_mm(
|
||||
)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "machete_mm"):
|
||||
|
||||
@register_fake("_C::machete_mm")
|
||||
def machete_mm_fake(
|
||||
a: torch.Tensor,
|
||||
# b_q Should be the tensor returned by machete_prepack_B
|
||||
b_q: torch.Tensor,
|
||||
b_type: ScalarType,
|
||||
out_type: torch.dtype | None = None,
|
||||
b_group_scales: torch.Tensor | None = None,
|
||||
b_group_zeros: torch.Tensor | None = None,
|
||||
b_group_size: int | None = None,
|
||||
b_channel_scales: torch.Tensor | None = None,
|
||||
a_token_scales: torch.Tensor | None = None,
|
||||
schedule: str | None = None,
|
||||
) -> torch.Tensor:
|
||||
m = a.size(0)
|
||||
n = b_q.size(1)
|
||||
return torch.empty((m, n), device=a.device, dtype=a.dtype)
|
||||
|
||||
|
||||
def machete_prepack_B(
|
||||
b_q_weight: torch.Tensor,
|
||||
a_type: torch.dtype,
|
||||
@@ -1415,6 +1355,18 @@ def machete_prepack_B(
|
||||
)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "machete_prepack_B"):
|
||||
|
||||
@register_fake("_C::machete_prepack_B")
|
||||
def machete_prepack_B_fake(
|
||||
b_q_weight: torch.Tensor,
|
||||
a_type: torch.dtype,
|
||||
b_type: ScalarType,
|
||||
group_scales_type: torch.dtype | None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format)
|
||||
|
||||
|
||||
# CUTLASS W4A8
|
||||
def cutlass_w4a8_mm(
|
||||
a: torch.Tensor,
|
||||
@@ -1439,14 +1391,48 @@ def cutlass_w4a8_mm(
|
||||
)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "cutlass_w4a8_mm"):
|
||||
|
||||
@register_fake("_C::cutlass_w4a8_mm")
|
||||
def cutlass_w4a8_mm_fake(
|
||||
a: torch.Tensor,
|
||||
# b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b
|
||||
b_q: torch.Tensor,
|
||||
b_group_scales: torch.Tensor,
|
||||
b_group_size: int,
|
||||
b_channel_scales: torch.Tensor,
|
||||
a_token_scales: torch.Tensor,
|
||||
out_type: torch.dtype | None = None,
|
||||
maybe_schedule: str | None = None,
|
||||
) -> torch.Tensor:
|
||||
m = a.size(0)
|
||||
n = b_q.size(1)
|
||||
out_dtype = out_type if out_type is not None else torch.bfloat16
|
||||
return torch.empty((m, n), device=a.device, dtype=out_dtype)
|
||||
|
||||
|
||||
def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops._C.cutlass_pack_scale_fp8(scales)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "cutlass_pack_scale_fp8"):
|
||||
|
||||
@register_fake("_C::cutlass_pack_scale_fp8")
|
||||
def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(scales, memory_format=torch.contiguous_format)
|
||||
|
||||
|
||||
def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops._C.cutlass_encode_and_reorder_int4b(b)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "cutlass_encode_and_reorder_int4b"):
|
||||
|
||||
@register_fake("_C::cutlass_encode_and_reorder_int4b")
|
||||
def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(b, memory_format=torch.contiguous_format)
|
||||
|
||||
|
||||
def cutlass_w4a8_moe_mm(
|
||||
out_tensors: torch.Tensor,
|
||||
a_tensors: torch.Tensor,
|
||||
@@ -1519,6 +1505,17 @@ def cutlass_encode_and_reorder_int4b_grouped(
|
||||
return torch.ops._C.cutlass_encode_and_reorder_int4b_grouped(b_tensors)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "cutlass_encode_and_reorder_int4b_grouped"):
|
||||
|
||||
@register_fake("_C::cutlass_encode_and_reorder_int4b_grouped")
|
||||
def cutlass_encode_and_reorder_int4b_grouped_fake(b: torch.Tensor) -> torch.Tensor:
|
||||
return torch.empty_like(b, memory_format=torch.contiguous_format)
|
||||
|
||||
|
||||
def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops._C.permute_cols(a, perm)
|
||||
|
||||
|
||||
if hasattr(torch.ops._C, "permute_cols"):
|
||||
|
||||
@register_fake("_C::permute_cols")
|
||||
@@ -1526,10 +1523,6 @@ if hasattr(torch.ops._C, "permute_cols"):
|
||||
return torch.empty_like(a)
|
||||
|
||||
|
||||
def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops._C.permute_cols(a, perm)
|
||||
|
||||
|
||||
# fp4
|
||||
def scaled_fp4_quant(
|
||||
input: torch.Tensor,
|
||||
|
||||
@@ -890,7 +890,6 @@ class ModelConfig:
|
||||
# `override_quantization_method` method) must be checked in order
|
||||
# of preference (this is particularly important for GPTQ).
|
||||
overrides = [
|
||||
"gptq_marlin_24",
|
||||
"gptq_marlin",
|
||||
"awq_marlin",
|
||||
"ipex",
|
||||
|
||||
@@ -18,7 +18,6 @@ QuantizationMethods = Literal[
|
||||
"modelopt",
|
||||
"modelopt_fp4",
|
||||
"gguf",
|
||||
"gptq_marlin_24",
|
||||
"gptq_marlin",
|
||||
"awq_marlin",
|
||||
"gptq",
|
||||
@@ -41,7 +40,6 @@ DEPRECATED_QUANTIZATION_METHODS = [
|
||||
"ptpc_fp8",
|
||||
"fbgemm_fp8",
|
||||
"fp_quant",
|
||||
"gptq_marlin_24",
|
||||
"experts_int8",
|
||||
"ipex",
|
||||
"petit_nvfp4",
|
||||
@@ -122,7 +120,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
from .gguf import GGUFConfig
|
||||
from .gptq import GPTQConfig
|
||||
from .gptq_marlin import GPTQMarlinConfig
|
||||
from .gptq_marlin_24 import GPTQMarlin24Config
|
||||
from .inc import INCConfig
|
||||
from .ipex_quant import IPEXConfig
|
||||
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
|
||||
@@ -140,7 +137,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"modelopt": ModelOptFp8Config,
|
||||
"modelopt_fp4": ModelOptNvFp4Config,
|
||||
"gguf": GGUFConfig,
|
||||
"gptq_marlin_24": GPTQMarlin24Config,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"awq_marlin": AWQMarlinConfig,
|
||||
"gptq": GPTQConfig,
|
||||
|
||||
@@ -40,7 +40,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
||||
CompressedTensorsMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS,
|
||||
WNA16_SUPPORTED_BITS,
|
||||
CompressedTensors24,
|
||||
CompressedTensorsScheme,
|
||||
@@ -49,7 +48,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW4A8Int,
|
||||
CompressedTensorsW4A16Fp4,
|
||||
CompressedTensorsW4A16Mxfp4,
|
||||
CompressedTensorsW4A16Sparse24,
|
||||
CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8,
|
||||
@@ -610,29 +608,19 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
actorder=weight_quant.actorder,
|
||||
)
|
||||
|
||||
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
if (
|
||||
format == CompressionFormat.marlin_24.value
|
||||
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS
|
||||
):
|
||||
assert weight_quant.symmetric
|
||||
return CompressedTensorsW4A16Sparse24(
|
||||
strategy=weight_quant.strategy,
|
||||
num_bits=weight_quant.num_bits,
|
||||
group_size=weight_quant.group_size,
|
||||
)
|
||||
if (
|
||||
format == CompressionFormat.pack_quantized.value
|
||||
and weight_quant.num_bits in WNA16_SUPPORTED_BITS
|
||||
):
|
||||
return CompressedTensorsWNA16(
|
||||
num_bits=weight_quant.num_bits,
|
||||
strategy=weight_quant.strategy,
|
||||
symmetric=weight_quant.symmetric,
|
||||
group_size=weight_quant.group_size,
|
||||
actorder=weight_quant.actorder,
|
||||
layer_name=layer_name,
|
||||
)
|
||||
if (
|
||||
self._is_wNa16_group_channel(weight_quant, input_quant)
|
||||
and (format == CompressionFormat.pack_quantized.value)
|
||||
and (weight_quant.num_bits in WNA16_SUPPORTED_BITS)
|
||||
):
|
||||
return CompressedTensorsWNA16(
|
||||
num_bits=weight_quant.num_bits,
|
||||
strategy=weight_quant.strategy,
|
||||
symmetric=weight_quant.symmetric,
|
||||
group_size=weight_quant.group_size,
|
||||
actorder=weight_quant.actorder,
|
||||
layer_name=layer_name,
|
||||
)
|
||||
|
||||
act_quant_format = is_activation_quantization_format(format)
|
||||
if act_quant_format:
|
||||
|
||||
@@ -5,10 +5,6 @@ from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
|
||||
from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8
|
||||
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
|
||||
from .compressed_tensors_w4a16_24 import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS,
|
||||
CompressedTensorsW4A16Sparse24,
|
||||
)
|
||||
from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4
|
||||
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
@@ -23,11 +19,9 @@ __all__ = [
|
||||
"CompressedTensorsScheme",
|
||||
"CompressedTensorsWNA16",
|
||||
"CompressedTensorsW8A16Fp8",
|
||||
"CompressedTensorsW4A16Sparse24",
|
||||
"CompressedTensorsW8A8Int8",
|
||||
"CompressedTensorsW8A8Fp8",
|
||||
"WNA16_SUPPORTED_BITS",
|
||||
"W4A16SPARSE24_SUPPORTED_BITS",
|
||||
"CompressedTensors24",
|
||||
"CompressedTensorsW4A16Fp4",
|
||||
"CompressedTensorsW4A16Mxfp4",
|
||||
|
||||
-176
@@ -1,176 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL,
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
__all__ = ["CompressedTensorsW4A16Sparse24"]
|
||||
W4A16SPARSE24_SUPPORTED_TYPES_MAP = {
|
||||
4: scalar_types.uint4b8,
|
||||
}
|
||||
W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
|
||||
def __init__(self, strategy: str, num_bits: int, group_size: int | None = None):
|
||||
self.strategy = strategy
|
||||
self.group_size = group_size
|
||||
self.tile_size = 16
|
||||
|
||||
if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits}. "
|
||||
f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}"
|
||||
)
|
||||
|
||||
self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]
|
||||
|
||||
if self.strategy == "group" and self.group_size is None:
|
||||
raise ValueError("group_size must be given when using strategy group")
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# ampere + up
|
||||
return 80
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_packed = Parameter(layer.weight_packed.data, requires_grad=False)
|
||||
layer.scale_packed = Parameter(layer.scale_packed.data, requires_grad=False)
|
||||
layer.meta = Parameter(layer.meta.data, requires_grad=False)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
assert params_dtype == torch.float16, (
|
||||
"float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501
|
||||
)
|
||||
|
||||
pack_factor = 32 // self.quant_type.size_bits
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.tile_size // 2,
|
||||
output_size_per_partition * self.tile_size // pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=pack_factor,
|
||||
marlin_tile_size=self.tile_size,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
input_groups = (
|
||||
1
|
||||
if self.group_size is None
|
||||
else input_size_per_partition // self.group_size
|
||||
)
|
||||
|
||||
weight_scale_args = {
|
||||
"data": torch.empty(
|
||||
input_groups,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
|
||||
if self.group_size is not None:
|
||||
scales = GroupQuantScaleParameter(
|
||||
output_dim=1, input_dim=0, **weight_scale_args
|
||||
)
|
||||
else:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
|
||||
|
||||
weight_shape = BasevLLMParameter(
|
||||
data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
|
||||
)
|
||||
|
||||
meta = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // 8 // 2 // 2,
|
||||
output_size_per_partition * 2,
|
||||
dtype=torch.int16,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=1,
|
||||
marlin_tile_size=2,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_packed", qweight)
|
||||
layer.register_parameter("weight_shape", weight_shape)
|
||||
layer.register_parameter("scale_packed", scales)
|
||||
layer.register_parameter("meta", meta)
|
||||
|
||||
max_workspace_size = (
|
||||
output_size_per_partition // GPTQ_MARLIN_24_MIN_THREAD_N
|
||||
) * GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
|
||||
workspace = Parameter(
|
||||
torch.zeros(max_workspace_size, dtype=torch.int), requires_grad=False
|
||||
)
|
||||
layer.workspace = workspace
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.weight_packed
|
||||
meta = layer.meta
|
||||
scales = layer.scale_packed
|
||||
workspace = layer.workspace
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
size_m = x_2d.shape[0]
|
||||
size_k = x_2d.shape[1]
|
||||
size_n = scales.shape[1]
|
||||
|
||||
output_2d = ops.gptq_marlin_24_gemm(
|
||||
x_2d,
|
||||
qweight,
|
||||
meta,
|
||||
scales,
|
||||
workspace,
|
||||
self.quant_type,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
)
|
||||
|
||||
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
@@ -1,320 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import (
|
||||
QuantizationConfig,
|
||||
QuantizationMethods,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
GPTQ_MARLIN_24_TILE = 16
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N = 128
|
||||
GPTQ_MARLIN_24_MIN_THREAD_K = 128
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL = 64
|
||||
|
||||
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
|
||||
|
||||
|
||||
class GPTQMarlin24Config(QuantizationConfig):
|
||||
"""Config class for Marlin24."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
quant_type = {
|
||||
4: scalar_types.uint4b8,
|
||||
8: scalar_types.uint8b128,
|
||||
}.get(weight_bits)
|
||||
|
||||
self.group_size = group_size
|
||||
|
||||
# Verify
|
||||
if quant_type is None or quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
|
||||
raise ValueError(
|
||||
f"Marlin_24 does not support quant_type = {quant_type}. "
|
||||
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
|
||||
"are supported."
|
||||
)
|
||||
if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
|
||||
raise ValueError(
|
||||
f"Marlin_24 does not support group_size = {self.group_size}. "
|
||||
f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported."
|
||||
)
|
||||
|
||||
self.quant_type = quant_type
|
||||
|
||||
# 4 Bits packed into 32 bit datatype.
|
||||
self.pack_factor = 32 // self.quant_type.size_bits
|
||||
|
||||
# Tile size used by marlin kernels.
|
||||
self.tile_size = 16
|
||||
|
||||
# Min out_features dim
|
||||
self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N
|
||||
|
||||
# Min in_features dim
|
||||
self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K
|
||||
|
||||
# Max parallel problems to solve at once (improves large
|
||||
# batch performance)
|
||||
self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
|
||||
# Permutation length used by the marlin kernels.
|
||||
self.perm_len = 1024
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "Marlin24Config(quant_type={}, group_size={})".format(
|
||||
self.quant_type, self.group_size
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq_marlin_24"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
return cls(weight_bits, group_size)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> QuantizationMethods | None:
|
||||
is_marlin_24_format = hf_quant_cfg.get("checkpoint_format") == "marlin_24"
|
||||
|
||||
is_valid_user_quant = (
|
||||
user_quant is None or user_quant == "gptq" or user_quant == "gptq_marlin_24"
|
||||
)
|
||||
|
||||
if is_marlin_24_format and is_valid_user_quant:
|
||||
msg = "The model is serialized in {} format. Using {} kernel.".format(
|
||||
cls.get_name(), cls.get_name()
|
||||
)
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["GPTQMarlin24LinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return GPTQMarlin24LinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class GPTQMarlin24LinearMethod(LinearMethodBase):
|
||||
"""Linear method for Marlin24.
|
||||
|
||||
Args:
|
||||
quant_config: The Marlin24 quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlin24Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
if params_dtype != torch.float16:
|
||||
raise ValueError(
|
||||
f"The params dtype must be float16, but got {params_dtype}"
|
||||
)
|
||||
|
||||
# Validate output_size_per_partition
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.min_n_threads != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f"min_n_threads = {self.quant_config.min_n_threads}."
|
||||
)
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f"pack_factor = {self.quant_config.pack_factor}."
|
||||
)
|
||||
|
||||
# Validate input_size_per_partition
|
||||
if input_size_per_partition % self.quant_config.min_k_threads != 0:
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"min_k_threads = {self.quant_config.min_k_threads}."
|
||||
)
|
||||
if (
|
||||
self.quant_config.group_size != -1
|
||||
and input_size_per_partition % self.quant_config.group_size != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"group_size = {self.quant_config.group_size}."
|
||||
)
|
||||
|
||||
# Check that we have at least 4 tiles horizontally in the shard
|
||||
num_tiles_per_perm = self.quant_config.perm_len // (
|
||||
self.quant_config.tile_size**2
|
||||
)
|
||||
if output_size_per_partition % num_tiles_per_perm != 0:
|
||||
raise ValueError("Each permutation group must reside on the same gpu")
|
||||
|
||||
# Quantized 4Bit weights packed into Int32.
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.tile_size // 2,
|
||||
output_size_per_partition
|
||||
* self.quant_config.tile_size
|
||||
// self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
marlin_tile_size=self.quant_config.tile_size,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
# Meta
|
||||
meta = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // 8 // 2 // 2,
|
||||
output_size_per_partition * 2,
|
||||
device="cuda",
|
||||
dtype=torch.int16,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=1,
|
||||
marlin_tile_size=2,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
# Determine if channelwise or not
|
||||
input_groups = (
|
||||
1
|
||||
if self.quant_config.group_size == -1
|
||||
else input_size_per_partition // self.quant_config.group_size
|
||||
)
|
||||
|
||||
weight_scale_args = {
|
||||
"data": torch.empty(
|
||||
input_groups,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
if input_groups == 1:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(
|
||||
output_dim=1, input_dim=0, **weight_scale_args
|
||||
)
|
||||
|
||||
# Allocate workspace (Used for internal locking mechanism)
|
||||
max_workspace_size = (
|
||||
output_size_per_partition // self.quant_config.min_n_threads
|
||||
) * self.quant_config.max_parallel
|
||||
|
||||
workspace = BasevLLMParameter(
|
||||
data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("B_24", qweight)
|
||||
layer.register_parameter("B_meta", meta)
|
||||
layer.register_parameter("s", scales)
|
||||
layer.register_parameter("workspace", workspace)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# required by torch.compile
|
||||
layer.B_24 = Parameter(layer.B_24.data, requires_grad=False)
|
||||
layer.s = Parameter(layer.s.data, requires_grad=False)
|
||||
layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False)
|
||||
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.B_24
|
||||
meta = layer.B_meta
|
||||
scales = layer.s
|
||||
workspace = layer.workspace
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
size_m = x_2d.shape[0]
|
||||
size_k = x_2d.shape[1]
|
||||
size_n = scales.shape[1]
|
||||
|
||||
output_2d = ops.gptq_marlin_24_gemm(
|
||||
x_2d,
|
||||
qweight,
|
||||
meta,
|
||||
scales,
|
||||
workspace,
|
||||
self.quant_config.quant_type,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
)
|
||||
|
||||
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
@@ -1,467 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utility functions used for tests and benchmarks"""
|
||||
|
||||
import random
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from vllm.scalar_type import ScalarType
|
||||
|
||||
from .marlin_utils_test import marlin_weights
|
||||
from .quant_utils import gptq_quantize_weights
|
||||
|
||||
|
||||
# This is PyTorch implementation of main part of reorder_meta()
|
||||
# function, from tools/util/include/cutlass/util/host_reorder.h file
|
||||
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
|
||||
# GEMM decides upon layout of this matrix, and at the moment for the
|
||||
# sparse GEMM executed on tensor cores, this is layout described by
|
||||
# ColumnMajorInterleaved<2> data structure, in
|
||||
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
|
||||
# reordering of meta matrix into meta_reordered matrix calculated
|
||||
# according to these segments of CUTLASS code is re-implemented here.
|
||||
# Note that this calculation produces offsets for scattering metadata
|
||||
# matrix elements into reordered metadata matrix elements (or,
|
||||
# equivalently, for gathering reordered metadata matrix element back
|
||||
# into metadata matrix elements).
|
||||
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
|
||||
dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
|
||||
dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
|
||||
|
||||
# Reorder the rows, then swizzle the 2x2 blocks.
|
||||
group_x = 64
|
||||
group_y = 32 if meta_dtype.itemsize == 2 else 16
|
||||
|
||||
dst_rows = (
|
||||
dst_rows // group_x * group_x
|
||||
+ (dst_rows % 2) * 2
|
||||
+ (dst_rows % 8) // 4
|
||||
+ ((dst_rows % group_y) % 4) // 2 * 32
|
||||
+ ((dst_rows % group_x) // 8) * 4
|
||||
)
|
||||
|
||||
topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
|
||||
bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
|
||||
dst_rows += topright - bottomleft
|
||||
dst_cols -= topright - bottomleft
|
||||
|
||||
# Assumed that meta tensor is to be stored in CUTLASS
|
||||
# InterleavedColumnMajor layout, and reverse engineered
|
||||
# corresponding code to store values into this tensor.
|
||||
interleave = 2
|
||||
cols_maj = dst_cols // interleave
|
||||
cols_min = dst_cols % interleave
|
||||
return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
|
||||
|
||||
|
||||
# This function converts dense matrix into sparse semi-structured
|
||||
# representation, producing "compressed" matrix, in the layout used by
|
||||
# CUTLASS backend, and corresponding metadata matrix.
|
||||
def sparse_semi_structured_from_dense_cutlass(dense):
|
||||
if dense.dim() != 2:
|
||||
raise RuntimeError(
|
||||
f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501
|
||||
)
|
||||
|
||||
m, k = dense.shape
|
||||
device = dense.device
|
||||
|
||||
meta_dtype = torch.int8
|
||||
if dense.dtype == torch.int8:
|
||||
meta_dtype = torch.int32
|
||||
elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
|
||||
meta_dtype = torch.int16
|
||||
else:
|
||||
raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
|
||||
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
|
||||
if quadbits_per_meta_elem not in (4, 8):
|
||||
raise RuntimeError("Invalid number of elements per meta element calculated")
|
||||
|
||||
if meta_dtype == torch.int32:
|
||||
if m % 16 != 0:
|
||||
raise RuntimeError(
|
||||
f"Number of rows of dense matrix {m} must be divisible by 16"
|
||||
)
|
||||
else:
|
||||
if m % 32 != 0:
|
||||
raise RuntimeError(
|
||||
f"Number of rows of dense matrix {m} must be divisible by 32"
|
||||
)
|
||||
if k % (4 * quadbits_per_meta_elem) != 0:
|
||||
raise RuntimeError(
|
||||
f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501
|
||||
)
|
||||
|
||||
if dense.dtype != torch.float:
|
||||
ksparse = 4
|
||||
dense_4 = dense.view(-1, k // ksparse, ksparse)
|
||||
m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
|
||||
else:
|
||||
ksparse = 2
|
||||
dense_2 = dense.view(-1, k // ksparse, ksparse)
|
||||
m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
|
||||
meta_ncols = k // (ksparse * quadbits_per_meta_elem)
|
||||
|
||||
# Encoding quadruples of True/False values as follows:
|
||||
# [True, True, False, False] -> 0b0100
|
||||
# [True, False, True, False] -> 0b1000
|
||||
# [False, True, True, False] -> 0b1001
|
||||
# [True, False, False, True ] -> 0b1100
|
||||
# [False, True, False, True ] -> 0b1101
|
||||
# [False, False, True, True ] -> 0b1110
|
||||
# Thus, lower two bits in the encoding are index of the True value
|
||||
# at the lowest index in the quadruple, and the higher two bits in
|
||||
# the encoding are index of the other True value in the quadruple.
|
||||
# In case there are less than two True values, than False value or
|
||||
# values at some index or indices are considered True for the
|
||||
# encoding. In case there are more than two True values, then the
|
||||
# excess True value(s) at some indices are considered False for
|
||||
# the encoding. The exact encodings used for these cases are as
|
||||
# follows:
|
||||
# [False, False, False, False] -> 0b1110
|
||||
# [False, False, False, True ] -> 0b1110
|
||||
# [False, False, True, False] -> 0b1110
|
||||
# [False, True, False, False] -> 0b1001
|
||||
# [False, True, True, True ] -> 0b1101
|
||||
# [True, False, False, False] -> 0b1000
|
||||
# [True, False, True, True ] -> 0b1100
|
||||
# [True, True, False, True ] -> 0b0100
|
||||
# [True, True, True, False] -> 0b0100
|
||||
# [True, True, True, True ] -> 0b0100
|
||||
# These particular encodings are chosen, with the help of Espresso
|
||||
# logic minimizer software, for the purpose of minimization of
|
||||
# corresponding Boolean functions, that translate non-zero flags
|
||||
# into encoding bits. Note also possible choices for the first
|
||||
# and last of these encodings were limited only to (0b0100,
|
||||
# 0b1110), in order to produce valid encodings for 1:2 sparsity
|
||||
# case.
|
||||
|
||||
expr0 = m0 & m1
|
||||
expr1 = ~m0 & m1
|
||||
expr2 = ~m0 & ~m1
|
||||
bit0 = expr1
|
||||
bit1 = expr2
|
||||
bit2 = expr0 | expr2 | m3
|
||||
bit3 = expr1 | ~m1
|
||||
idxs0 = bit0 | (bit1.to(torch.int64) << 1)
|
||||
idxs1 = bit2 | (bit3.to(torch.int64) << 1)
|
||||
|
||||
if dense.dtype != torch.float:
|
||||
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
|
||||
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
|
||||
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
|
||||
else:
|
||||
sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined]
|
||||
|
||||
meta_4 = idxs0 | (idxs1 << 2)
|
||||
meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
|
||||
|
||||
if quadbits_per_meta_elem == 4:
|
||||
meta = (
|
||||
meta_n[:, :, 0]
|
||||
| (meta_n[:, :, 1] << 4)
|
||||
| (meta_n[:, :, 2] << 8)
|
||||
| (meta_n[:, :, 3] << 12)
|
||||
)
|
||||
elif quadbits_per_meta_elem == 8:
|
||||
meta = (
|
||||
meta_n[:, :, 0]
|
||||
| (meta_n[:, :, 1] << 4)
|
||||
| (meta_n[:, :, 2] << 8)
|
||||
| (meta_n[:, :, 3] << 12)
|
||||
| (meta_n[:, :, 4] << 16)
|
||||
| (meta_n[:, :, 5] << 20)
|
||||
| (meta_n[:, :, 6] << 24)
|
||||
| (meta_n[:, :, 7] << 28)
|
||||
)
|
||||
|
||||
# Reorder meta tensor elements.
|
||||
meta_reordered = meta.new_empty((m * meta_ncols,)) # type: ignore[possibly-undefined]
|
||||
meta_offsets = _calculate_meta_reordering_scatter_offsets(
|
||||
m, meta_ncols, meta_dtype, device
|
||||
)
|
||||
meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
|
||||
|
||||
return (sparse, meta_reordered.view(m, meta_ncols))
|
||||
|
||||
|
||||
# This function performs reverse of the function above - it
|
||||
# reconstructs dense matrix from a pair of "compressed" matrix, given
|
||||
# in the layout used by CUTLASS backend, and accompanying metadata
|
||||
# matrix.
|
||||
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
|
||||
if sparse.dim() != 2:
|
||||
raise RuntimeError(
|
||||
f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501
|
||||
)
|
||||
|
||||
m, k = sparse.shape
|
||||
device = sparse.device
|
||||
|
||||
if meta_reordered.dim() != 2:
|
||||
raise RuntimeError(
|
||||
f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501
|
||||
)
|
||||
if meta_reordered.device != device:
|
||||
raise RuntimeError(
|
||||
f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501
|
||||
)
|
||||
|
||||
meta_dtype = meta_reordered.dtype
|
||||
if meta_dtype not in (torch.int16, torch.int32):
|
||||
raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
|
||||
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
|
||||
|
||||
ksparse = 4 if sparse.dtype != torch.float else 2
|
||||
|
||||
meta_nrows, meta_ncols = meta_reordered.shape
|
||||
if meta_nrows != m:
|
||||
raise RuntimeError(
|
||||
f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501
|
||||
)
|
||||
if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
|
||||
raise RuntimeError(
|
||||
f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501
|
||||
"expected according to the number of columns of meta matrix"
|
||||
)
|
||||
|
||||
# Undo meta tensor elements reordering.
|
||||
meta_offsets = _calculate_meta_reordering_scatter_offsets(
|
||||
m, meta_ncols, meta_dtype, device
|
||||
)
|
||||
meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
|
||||
|
||||
# Unpack sparse tensor back to original dense tensor, using
|
||||
# information provided by meta tensor. Note that torch.float
|
||||
# datatype is handled pretty much the same as
|
||||
# torch.half/torch.bfloat16, as metadata for a pair of torch.float
|
||||
# value is encoded as if underlying 8 bytes contain four
|
||||
# torch.half/torch.bfloat16 values, where either first two or last
|
||||
# two are zeros.
|
||||
meta_2 = torch.empty(
|
||||
(m, meta_ncols, 2 * quadbits_per_meta_elem),
|
||||
dtype=meta_dtype,
|
||||
device=device,
|
||||
)
|
||||
if quadbits_per_meta_elem == 4:
|
||||
meta_2[:, :, 0] = meta & 0b11
|
||||
meta_2[:, :, 1] = (meta >> 2) & 0b11
|
||||
meta_2[:, :, 2] = (meta >> 4) & 0b11
|
||||
meta_2[:, :, 3] = (meta >> 6) & 0b11
|
||||
meta_2[:, :, 4] = (meta >> 8) & 0b11
|
||||
meta_2[:, :, 5] = (meta >> 10) & 0b11
|
||||
meta_2[:, :, 6] = (meta >> 12) & 0b11
|
||||
meta_2[:, :, 7] = (meta >> 14) & 0b11
|
||||
elif quadbits_per_meta_elem == 8:
|
||||
meta_2[:, :, 0] = meta & 0b11
|
||||
meta_2[:, :, 1] = (meta >> 2) & 0b11
|
||||
meta_2[:, :, 2] = (meta >> 4) & 0b11
|
||||
meta_2[:, :, 3] = (meta >> 6) & 0b11
|
||||
meta_2[:, :, 4] = (meta >> 8) & 0b11
|
||||
meta_2[:, :, 5] = (meta >> 10) & 0b11
|
||||
meta_2[:, :, 6] = (meta >> 12) & 0b11
|
||||
meta_2[:, :, 7] = (meta >> 14) & 0b11
|
||||
meta_2[:, :, 8] = (meta >> 16) & 0b11
|
||||
meta_2[:, :, 9] = (meta >> 18) & 0b11
|
||||
meta_2[:, :, 10] = (meta >> 20) & 0b11
|
||||
meta_2[:, :, 11] = (meta >> 22) & 0b11
|
||||
meta_2[:, :, 12] = (meta >> 24) & 0b11
|
||||
meta_2[:, :, 13] = (meta >> 26) & 0b11
|
||||
meta_2[:, :, 14] = (meta >> 28) & 0b11
|
||||
meta_2[:, :, 15] = (meta >> 30) & 0b11
|
||||
|
||||
dense_offsets = meta_2.view(-1) + (
|
||||
torch.arange(0, 2 * m * k // ksparse, device=device) * 4
|
||||
).view(-1, 1).repeat(1, 2).view(-1)
|
||||
|
||||
dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
|
||||
if sparse.dtype != torch.float:
|
||||
# dense.scatter_(0, dense_offsets, sparse.view(-1))
|
||||
dense.scatter_(0, dense_offsets, sparse.reshape(-1))
|
||||
else:
|
||||
dense.view(torch.half).scatter_(
|
||||
0, dense_offsets, sparse.view(torch.half).view(-1)
|
||||
)
|
||||
|
||||
return dense.view(m, 2 * k)
|
||||
|
||||
|
||||
def mask_creator(tensor):
|
||||
"""
|
||||
Class for creating N:M sparsity masks.
|
||||
Masks will be created using the N:M ratio, where for every block of
|
||||
M weights, N will be pruned based on ranked weight value. Each mask
|
||||
will correspond to the given tensor.
|
||||
|
||||
:param N: The number of weights in a group to keep
|
||||
:param M: The size of a weight group
|
||||
"""
|
||||
N = 2
|
||||
M = 4
|
||||
|
||||
mask = None
|
||||
# for i, tensor in enumerate(tensors):
|
||||
if tensor.numel() % M != 0:
|
||||
raise ValueError(
|
||||
f"Tensor of size {tensor.shape} can't be evenly divided into {M} groups"
|
||||
)
|
||||
|
||||
num_groups = tensor.numel() // M
|
||||
|
||||
# N:M sparsity for linear layers
|
||||
tensor_temp = tensor.detach().abs().reshape(num_groups, M)
|
||||
index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]
|
||||
|
||||
w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
|
||||
mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def inject_24(w, size_k, size_n):
|
||||
assert w.shape == (size_k, size_n)
|
||||
|
||||
mask = mask_creator(w.t()).t().cuda().bool()
|
||||
|
||||
return (mask * w).contiguous(), mask.contiguous()
|
||||
|
||||
|
||||
def check_24(w, num_rows_to_sample=50, _verbose=False):
|
||||
BLOCK_SIZE = 4
|
||||
MAX_NON_ZEROS = 2
|
||||
|
||||
w = w.t().contiguous()
|
||||
|
||||
print("check_24: w.shape = {}".format(w.shape))
|
||||
|
||||
num_rows, num_cols = w.shape
|
||||
sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
|
||||
if _verbose:
|
||||
print(f"Sampled row idxs = {sampled_row_idxs}")
|
||||
|
||||
total_segments = 0
|
||||
non_24_segments = 0
|
||||
for i in sampled_row_idxs:
|
||||
for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
|
||||
total_segments += 1
|
||||
block = w[i, j : j + BLOCK_SIZE]
|
||||
num_nonzero = torch.count_nonzero(block)
|
||||
if num_nonzero > MAX_NON_ZEROS:
|
||||
print("i = {} j = {} block = {}".format(i, j, block))
|
||||
non_24_segments += 1
|
||||
|
||||
print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
|
||||
|
||||
|
||||
def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
|
||||
assert q_24.shape == (size_k, size_n)
|
||||
|
||||
# Remove bias to normalize over 0
|
||||
q_24_no_zp = q_24 - wtype.bias
|
||||
|
||||
# Compress
|
||||
q_24_no_zp = q_24_no_zp.t().contiguous()
|
||||
q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp)
|
||||
q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
|
||||
|
||||
# Restore bias
|
||||
q_24_comp = q_24_no_zp_comp + wtype.bias
|
||||
|
||||
# Resize meta to its actual shape (without moving any data)
|
||||
meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
|
||||
|
||||
return q_24_comp, meta
|
||||
|
||||
|
||||
def get_scale_perms_24():
|
||||
scale_perm: list[int] = []
|
||||
for i in range(8):
|
||||
scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
|
||||
scale_perm_single: list[int] = []
|
||||
for i in range(8):
|
||||
scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
|
||||
return scale_perm, scale_perm_single
|
||||
|
||||
|
||||
def get_weight_perm_24(num_bits: int):
|
||||
perm_list: list[int] = []
|
||||
for i in range(32):
|
||||
perm1: list[int] = []
|
||||
col = i // 4
|
||||
col_o = col // 2
|
||||
for block in [0, 1]:
|
||||
for row in [
|
||||
2 * (i % 4),
|
||||
2 * (i % 4) + 1,
|
||||
2 * (i % 4 + 4),
|
||||
2 * (i % 4 + 4) + 1,
|
||||
]:
|
||||
perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
|
||||
for j in range(4):
|
||||
perm_list.extend([p + 1 * j for p in perm1])
|
||||
perm = numpy.array(perm_list)
|
||||
|
||||
if num_bits == 4:
|
||||
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
|
||||
elif num_bits == 8:
|
||||
interleave = numpy.array([0, 2, 1, 3])
|
||||
else:
|
||||
raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
|
||||
|
||||
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
|
||||
perm = torch.from_numpy(perm)
|
||||
return perm
|
||||
|
||||
|
||||
def marlin_permute_scales_24(
|
||||
s: torch.Tensor, size_k: int, size_n: int, group_size: int
|
||||
) -> torch.Tensor:
|
||||
scale_perm, scale_perm_single = get_scale_perms_24()
|
||||
if group_size < size_k and group_size != -1:
|
||||
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
|
||||
else:
|
||||
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
|
||||
s = s.reshape((-1, size_n)).contiguous()
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def marlin_24_quantize(
|
||||
w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
):
|
||||
size_k, size_n = w.shape
|
||||
|
||||
# Normalize group_size
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
assert group_size <= size_k
|
||||
|
||||
# Inject 2:4 sparsity
|
||||
w_24, mask_24 = inject_24(w, size_k, size_n)
|
||||
|
||||
# Quantize
|
||||
w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
|
||||
w_24, quant_type, group_size, act_order=False
|
||||
)
|
||||
|
||||
# Compress quantized weight
|
||||
q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type)
|
||||
size_k_comp = size_k // 2
|
||||
|
||||
# Reformat to marlin
|
||||
weight_perm = get_weight_perm_24(quant_type.size_bits)
|
||||
marlin_24_q_w_comp = marlin_weights(
|
||||
q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm
|
||||
)
|
||||
marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)
|
||||
|
||||
# Create result
|
||||
res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
|
||||
for i in range(len(res_list)):
|
||||
res_list[i] = res_list[i].to(w.device)
|
||||
|
||||
return res_list
|
||||
Reference in New Issue
Block a user