mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Refactor] Rename gptq_marlin to marlin to match MoE (#32952)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
+9
-9
@@ -377,7 +377,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# preselected input type pairs and schedules.
|
||||
# Generate sources:
|
||||
set(MARLIN_GEN_SCRIPT
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/marlin/generate_kernels.py)
|
||||
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
|
||||
list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
|
||||
set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
|
||||
@@ -412,7 +412,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
if (MARLIN_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
|
||||
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/marlin/sm80_kernel_*_float16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
@@ -422,7 +422,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
|
||||
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
|
||||
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/marlin/sm80_kernel_*_bfloat16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
|
||||
@@ -434,7 +434,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
if (MARLIN_SM75_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/gptq_marlin/sm75_kernel_*.cu")
|
||||
file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/marlin/sm75_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_SM75_ARCHS}")
|
||||
@@ -446,7 +446,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
if (MARLIN_FP8_ARCHS)
|
||||
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
|
||||
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/marlin/sm89_kernel_*.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_FP8_ARCHS}")
|
||||
@@ -459,10 +459,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
set(MARLIN_SRCS
|
||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
|
||||
"csrc/quantization/marlin/marlin.cu"
|
||||
"csrc/quantization/marlin/marlin_int4_fp8_preprocess.cu"
|
||||
"csrc/quantization/marlin/gptq_marlin_repack.cu"
|
||||
"csrc/quantization/marlin/awq_marlin_repack.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_SRCS}"
|
||||
CUDA_ARCHS "${MARLIN_OTHER_ARCHS}")
|
||||
|
||||
@@ -231,7 +231,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
assert bt.w_tok_s is None
|
||||
assert bt.group_size is not None
|
||||
|
||||
fn = lambda: ops.gptq_marlin_gemm(
|
||||
fn = lambda: ops.marlin_gemm(
|
||||
a=bt.a,
|
||||
c=None,
|
||||
b_q_weight=w_q,
|
||||
|
||||
@@ -239,7 +239,7 @@ def bench_run(
|
||||
"sm_version": sm_version,
|
||||
"CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
|
||||
# Kernels
|
||||
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
|
||||
"marlin_gemm": ops.marlin_gemm,
|
||||
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
|
||||
"gptq_marlin_repack": ops.gptq_marlin_repack,
|
||||
"allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
|
||||
@@ -263,21 +263,21 @@ def bench_run(
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
|
||||
stmt="output = marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_gemm",
|
||||
description="marlin_gemm",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
|
||||
stmt="output = marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_gemm_fp32",
|
||||
description="marlin_gemm_fp32",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "quantization/gptq_marlin/marlin.cuh"
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "quantization/marlin/marlin.cuh"
|
||||
#include "quantization/marlin/marlin_dtypes.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
|
||||
@@ -23,10 +23,10 @@
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "quantization/gptq_marlin/marlin.cuh"
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "quantization/gptq_marlin/dequant.h"
|
||||
#include "quantization/gptq_marlin/marlin_mma.h"
|
||||
#include "quantization/marlin/marlin.cuh"
|
||||
#include "quantization/marlin/marlin_dtypes.cuh"
|
||||
#include "quantization/marlin/dequant.h"
|
||||
#include "quantization/marlin/marlin_mma.h"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <iostream>
|
||||
#include "../gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "../marlin/marlin_dtypes.cuh"
|
||||
using marlin::MarlinScalarType2;
|
||||
|
||||
namespace allspark {
|
||||
|
||||
@@ -46,7 +46,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||
@@ -528,7 +528,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||
@@ -856,5 +856,5 @@ torch::Tensor gptq_marlin_gemm(
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
|
||||
m.impl("marlin_gemm", &marlin_gemm);
|
||||
}
|
||||
@@ -303,9 +303,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
|
||||
|
||||
// gptq_marlin Optimized Quantized GEMM for GPTQ.
|
||||
// Marlin Optimized Quantized GEMM (supports GPTQ, AWQ, FP8, NVFP4, MXFP4).
|
||||
ops.def(
|
||||
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
|
||||
"marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
|
||||
"Tensor? b_bias_or_none,Tensor b_scales, "
|
||||
"Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, "
|
||||
"Tensor? "
|
||||
|
||||
@@ -59,7 +59,7 @@ if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"These tests require gptq_marlin_repack,"
|
||||
"marlin_int4_fp8_preprocess, gptq_marlin_24_gemm,"
|
||||
"or gptq_marlin_gemm which are not supported on ROCm.",
|
||||
"or marlin_gemm which are not supported on ROCm.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
@@ -417,7 +417,7 @@ def marlin_generate_valid_test_cases():
|
||||
),
|
||||
marlin_generate_valid_test_cases(),
|
||||
)
|
||||
def test_gptq_marlin_gemm(
|
||||
def test_marlin_gemm(
|
||||
a_type,
|
||||
b_type,
|
||||
c_type,
|
||||
@@ -511,7 +511,7 @@ def test_gptq_marlin_gemm(
|
||||
|
||||
output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
output = ops.marlin_gemm(
|
||||
a_input,
|
||||
output,
|
||||
marlin_q_w,
|
||||
@@ -646,7 +646,7 @@ def test_marlin_gemm_subset_input():
|
||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||
workspace = marlin_make_workspace_new(a_input.device)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
output = ops.marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
@@ -695,7 +695,7 @@ def test_marlin_gemm_with_bias(size_m):
|
||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||
workspace = marlin_make_workspace_new(a_input.device)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
output = ops.marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
|
||||
+4
-4
@@ -591,8 +591,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
|
||||
) -> torch.Tensor:
|
||||
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
|
||||
|
||||
@register_fake("_C::gptq_marlin_gemm")
|
||||
def _gptq_marlin_gemm_fake(
|
||||
@register_fake("_C::marlin_gemm")
|
||||
def _marlin_gemm_fake(
|
||||
a: torch.Tensor,
|
||||
c: torch.Tensor | None,
|
||||
b_q_weight: torch.Tensor,
|
||||
@@ -1312,7 +1312,7 @@ def marlin_int4_fp8_preprocess(
|
||||
return torch.ops._C.marlin_int4_fp8_preprocess(qweight, qzeros_or_none, inplace)
|
||||
|
||||
|
||||
def gptq_marlin_gemm(
|
||||
def marlin_gemm(
|
||||
a: torch.Tensor,
|
||||
c: torch.Tensor | None,
|
||||
b_q_weight: torch.Tensor,
|
||||
@@ -1333,7 +1333,7 @@ def gptq_marlin_gemm(
|
||||
use_fp32_reduce: bool = False,
|
||||
is_zp_float: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops._C.gptq_marlin_gemm(
|
||||
return torch.ops._C.marlin_gemm(
|
||||
a,
|
||||
c,
|
||||
b_q_weight,
|
||||
|
||||
@@ -563,7 +563,7 @@ def apply_gptq_marlin_linear(
|
||||
|
||||
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
output = ops.marlin_gemm(
|
||||
reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
@@ -628,7 +628,7 @@ def apply_awq_marlin_linear(
|
||||
)
|
||||
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
output = ops.marlin_gemm(
|
||||
reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
|
||||
@@ -121,7 +121,7 @@ def apply_fp4_marlin_linear(
|
||||
|
||||
inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
output = ops.marlin_gemm(
|
||||
a=inputs,
|
||||
c=None,
|
||||
b_q_weight=weight,
|
||||
|
||||
@@ -66,7 +66,7 @@ def apply_fp8_marlin_linear(
|
||||
# inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
|
||||
raise RuntimeError("Marlin W8A8 is not supported.")
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
output = ops.marlin_gemm(
|
||||
a=inputs,
|
||||
c=None,
|
||||
b_q_weight=weight,
|
||||
|
||||
Reference in New Issue
Block a user