[Refactor] Rename gptq_marlin to marlin to match MoE (#32952)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-01-23 16:48:12 -05:00
committed by GitHub
parent 6cc6d92be5
commit 4561f13985
24 changed files with 40 additions and 40 deletions
+9 -9
View File
@@ -377,7 +377,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# preselected input type pairs and schedules.
# Generate sources:
set(MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/marlin/generate_kernels.py)
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
@@ -412,7 +412,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
if (MARLIN_ARCHS)
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu")
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/marlin/sm80_kernel_*_float16.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}")
@@ -422,7 +422,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu")
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/marlin/sm80_kernel_*_bfloat16.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
@@ -434,7 +434,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
if (MARLIN_SM75_ARCHS)
file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/gptq_marlin/sm75_kernel_*.cu")
file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/marlin/sm75_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_SM75_ARCHS}")
@@ -446,7 +446,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
if (MARLIN_FP8_ARCHS)
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu")
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/marlin/sm89_kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_FP8_ARCHS}")
@@ -459,10 +459,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MARLIN_SRCS
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
"csrc/quantization/marlin/marlin.cu"
"csrc/quantization/marlin/marlin_int4_fp8_preprocess.cu"
"csrc/quantization/marlin/gptq_marlin_repack.cu"
"csrc/quantization/marlin/awq_marlin_repack.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_SRCS}"
CUDA_ARCHS "${MARLIN_OTHER_ARCHS}")
+1 -1
View File
@@ -231,7 +231,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
assert bt.w_tok_s is None
assert bt.group_size is not None
fn = lambda: ops.gptq_marlin_gemm(
fn = lambda: ops.marlin_gemm(
a=bt.a,
c=None,
b_q_weight=w_q,
+5 -5
View File
@@ -239,7 +239,7 @@ def bench_run(
"sm_version": sm_version,
"CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
# Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
"marlin_gemm": ops.marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
"gptq_marlin_repack": ops.gptq_marlin_repack,
"allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
@@ -263,21 +263,21 @@ def bench_run(
results.append(
benchmark.Timer(
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
stmt="output = marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm",
description="marlin_gemm",
).blocked_autorange(min_run_time=min_run_time)
)
results.append(
benchmark.Timer(
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
stmt="output = marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm_fp32",
description="marlin_gemm_fp32",
).blocked_autorange(min_run_time=min_run_time)
)
+2 -2
View File
@@ -3,8 +3,8 @@
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/marlin/marlin.cuh"
#include "quantization/marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
+4 -4
View File
@@ -23,10 +23,10 @@
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "quantization/gptq_marlin/marlin_mma.h"
#include "quantization/marlin/marlin.cuh"
#include "quantization/marlin/marlin_dtypes.cuh"
#include "quantization/marlin/dequant.h"
#include "quantization/marlin/marlin_mma.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@@ -7,7 +7,7 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <iostream>
#include "../gptq_marlin/marlin_dtypes.cuh"
#include "../marlin/marlin_dtypes.cuh"
using marlin::MarlinScalarType2;
namespace allspark {
@@ -46,7 +46,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
} // namespace marlin
torch::Tensor gptq_marlin_gemm(
torch::Tensor marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
@@ -528,7 +528,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
} // namespace marlin
torch::Tensor gptq_marlin_gemm(
torch::Tensor marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
@@ -856,5 +856,5 @@ torch::Tensor gptq_marlin_gemm(
#endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
m.impl("marlin_gemm", &marlin_gemm);
}
+2 -2
View File
@@ -303,9 +303,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
// gptq_marlin Optimized Quantized GEMM for GPTQ.
// Marlin Optimized Quantized GEMM (supports GPTQ, AWQ, FP8, NVFP4, MXFP4).
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor? b_bias_or_none,Tensor b_scales, "
"Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, "
"Tensor? "
@@ -59,7 +59,7 @@ if current_platform.is_rocm():
pytest.skip(
"These tests require gptq_marlin_repack,"
"marlin_int4_fp8_preprocess, gptq_marlin_24_gemm,"
"or gptq_marlin_gemm which are not supported on ROCm.",
"or marlin_gemm which are not supported on ROCm.",
allow_module_level=True,
)
@@ -417,7 +417,7 @@ def marlin_generate_valid_test_cases():
),
marlin_generate_valid_test_cases(),
)
def test_gptq_marlin_gemm(
def test_marlin_gemm(
a_type,
b_type,
c_type,
@@ -511,7 +511,7 @@ def test_gptq_marlin_gemm(
output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)
output = ops.gptq_marlin_gemm(
output = ops.marlin_gemm(
a_input,
output,
marlin_q_w,
@@ -646,7 +646,7 @@ def test_marlin_gemm_subset_input():
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = marlin_make_workspace_new(a_input.device)
output = ops.gptq_marlin_gemm(
output = ops.marlin_gemm(
a_input,
None,
marlin_q_w,
@@ -695,7 +695,7 @@ def test_marlin_gemm_with_bias(size_m):
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = marlin_make_workspace_new(a_input.device)
output = ops.gptq_marlin_gemm(
output = ops.marlin_gemm(
a_input,
None,
marlin_q_w,
+4 -4
View File
@@ -591,8 +591,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@register_fake("_C::gptq_marlin_gemm")
def _gptq_marlin_gemm_fake(
@register_fake("_C::marlin_gemm")
def _marlin_gemm_fake(
a: torch.Tensor,
c: torch.Tensor | None,
b_q_weight: torch.Tensor,
@@ -1312,7 +1312,7 @@ def marlin_int4_fp8_preprocess(
return torch.ops._C.marlin_int4_fp8_preprocess(qweight, qzeros_or_none, inplace)
def gptq_marlin_gemm(
def marlin_gemm(
a: torch.Tensor,
c: torch.Tensor | None,
b_q_weight: torch.Tensor,
@@ -1333,7 +1333,7 @@ def gptq_marlin_gemm(
use_fp32_reduce: bool = False,
is_zp_float: bool = False,
) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(
return torch.ops._C.marlin_gemm(
a,
c,
b_q_weight,
@@ -563,7 +563,7 @@ def apply_gptq_marlin_linear(
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
output = ops.marlin_gemm(
reshaped_x,
None,
weight,
@@ -628,7 +628,7 @@ def apply_awq_marlin_linear(
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
output = ops.marlin_gemm(
reshaped_x,
None,
weight,
@@ -121,7 +121,7 @@ def apply_fp4_marlin_linear(
inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
output = ops.gptq_marlin_gemm(
output = ops.marlin_gemm(
a=inputs,
c=None,
b_q_weight=weight,
@@ -66,7 +66,7 @@ def apply_fp8_marlin_linear(
# inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
raise RuntimeError("Marlin W8A8 is not supported.")
output = ops.gptq_marlin_gemm(
output = ops.marlin_gemm(
a=inputs,
c=None,
b_q_weight=weight,