[ROCm][CI] Add missing quantization methods and fix online quant test failures (#39801)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-04-27 15:08:57 -05:00
committed by GitHub
parent c8bbe05189
commit 5e2c37facd
9 changed files with 50 additions and 23 deletions
+2 -2
View File
@@ -54,13 +54,13 @@ MODEL_ARG_EXPTYPES = [
(
"TheBloke/OpenHermes-2.5-Mistral-7B-AWQ",
None,
"awq_marlin" if current_platform.is_cuda() else "awq",
"awq_marlin" if current_platform.is_cuda_alike() else "awq",
),
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"),
(
"TheBloke/OpenHermes-2.5-Mistral-7B-AWQ",
"marlin",
"awq_marlin" if current_platform.is_cuda() else "ERROR",
"awq_marlin" if current_platform.is_cuda_alike() else "ERROR",
),
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"),
]
+1 -1
View File
@@ -43,7 +43,7 @@ TEST_CONFIGS = {
"amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8": {"arc_challenge": 0.52, "mmlu": 0.72},
# Non-mixed-precision (PTQ) model
# - Reference for pipeline compatibility verification -> no conflicts or breakages
"amd/Llama-2-70b-chat-hf-FP8-MLPerf-fp8_attn_quark_format": {
"amd/Llama-2-70b-chat-hf_FP8_MLPerf_V2": {
"arc_challenge": 0.53,
"mmlu": 0.61,
},
+4 -1
View File
@@ -352,7 +352,10 @@ def generate_rotation_matrix(d: int, seed: int, device: str = "cpu") -> torch.Te
gen = torch.Generator(device="cpu")
gen.manual_seed(seed)
G = torch.randn(d, d, generator=gen, device="cpu", dtype=torch.float32)
Q, R = torch.linalg.qr(G)
# torch.linalg.qr on CPU requires LAPACK, which some torch wheels
# (ROCm) ship without. Run QR on the accelerator instead when available.
qr_device = "cuda" if torch.cuda.is_available() else "cpu"
Q, R = torch.linalg.qr(G.to(qr_device))
diag_sign = torch.sign(torch.diag(R))
diag_sign[diag_sign == 0] = 1.0
Q = Q * diag_sign.unsqueeze(0)
+4 -5
View File
@@ -80,14 +80,13 @@ def _test_online_quant_peak_mem_impl(
print(f"GPU memory used after loading weights: {model_memory_gib} GiB")
print(f"Peak GPU memory usage while loading weights: {peak_memory_gib} GiB")
# model specific, allenai/OLMoE-1B-7B-0125-Instruct fp8 online quant
# uses 6.65 GiB for weight loading (bf16 checkpoint is ~12.89 GiB)
expected_model_memory_gib = 6.7
# for allenai/OLMoE-1B-7B-0125-Instruct the number we see today is 9.06
# GiB, which is 1.36x above model_memory_gib. A slightly higher number is
# expected as when we load and quantize weights in a streaming fashion we
# need to have individual weights in bf16 + fp8 alive at the same time.
# GiB on CUDA, which is 1.36x above model_memory_gib. A slightly higher
# number is expected as when we load and quantize weights in a streaming
# fashion we need to have individual weights in bf16 + fp8 alive at the
# same time.
expected_peak_memory_gib = expected_model_memory_gib * 1.4
assert model_memory_gib < expected_model_memory_gib, (
@@ -144,13 +144,21 @@ class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
# On CUDA platforms, please validate that torch._scaled_mm supports
# rowwise scaled GEMM before using it.
# torch._scaled_mm rowwise requires scale_a = (m, 1), scale_b = (1, n).
# CompressedTensors stores weight_scale as (n, 1), so `.t()` yields (1, n).
# ModelOpt FP8_PER_CHANNEL_PER_TOKEN stores it as 1-D (n,); reshape to
# (1, n) so both paths satisfy the rowwise contract.
scale_b = Bs.view(1, -1) if Bs.dim() == 1 else Bs.t()
if As.dim() == 1:
As = As.view(-1, 1)
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(
A,
B,
out_dtype=out_dtype,
scale_a=As,
scale_b=Bs.t(),
scale_b=scale_b,
bias=bias,
)
@@ -1025,9 +1025,15 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
get_current_vllm_config().model_config.hf_config, "model_type", None
)
# TODO(aiter): extend once rocm_aiter_fused_experts gains dispatch
# for the other OCP MX schemes. Today its CK MoE kernel only has an
# entry for `w_mxfp4` (w4a16); mixed schemes like `w_mxfp4_a_mxfp6_*`
# fall through to QuantMethod.NO and raise "Unsupported kernel config
# for moe heuristic dispatch".
_AITER_NATIVE_OCP_MX_SCHEMES = ("w_mxfp4",)
self.emulate = (
not current_platform.supports_mx()
or not self.ocp_mx_scheme.startswith("w_mxfp4")
or self.ocp_mx_scheme not in _AITER_NATIVE_OCP_MX_SCHEMES
) and (
self.mxfp4_backend is Mxfp4MoeBackend.NONE or not self.use_rocm_aiter_moe
)
@@ -376,11 +376,15 @@ class QuarkOCP_MX(QuarkScheme):
dq_w = self.dequant_func(layer.weight, layer.weight_scale, x.dtype)
qdq_x = self.quant_dequant_func(x)
return F.linear(qdq_x, dq_w, bias)
else:
return torch.ops.vllm.gemm_with_dynamic_quant(
x,
layer.weight,
layer.weight_scale,
self.rocm_use_aiter_fp4_asm_gemm,
self.out_dtype,
)
y = torch.ops.vllm.gemm_with_dynamic_quant(
x,
layer.weight,
layer.weight_scale,
self.rocm_use_aiter_fp4_asm_gemm,
self.out_dtype,
)
# gemm_with_dynamic_quant has no bias argument; add it here so the
# native path matches F.linear (e.g. qkv_proj with qkv_bias=True).
if bias is not None:
y = y + bias
return y
@@ -65,7 +65,7 @@ class BaseModelLoader(ABC):
# Log peak GPU memory after loading weights. This is needed
# to have test coverage on peak memory for online quantization.
if current_platform.is_cuda():
if current_platform.is_cuda_alike():
peak_memory = torch.accelerator.max_memory_allocated()
logger.debug_once(
"Peak GPU memory after loading weights: %s GiB",
+10 -3
View File
@@ -414,10 +414,17 @@ class RocmPlatform(Platform):
"gguf",
"quark",
"mxfp4",
"gpt_oss_mxfp4",
"mxfp8",
"torchao",
"bitsandbytes",
"modelopt",
"modelopt_fp4",
"modelopt_mxfp8",
"modelopt_mixed",
"fp8_per_tensor",
"fp8_per_block",
"online",
"gpt_oss_mxfp4",
]
@classmethod
@@ -785,9 +792,9 @@ class RocmPlatform(Platform):
def get_current_memory_usage(
cls, device: torch.types.Device | None = None
) -> float:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(device)
free_mem, total_mem = torch.cuda.mem_get_info(device)
return total_mem - free_mem
return torch.cuda.max_memory_allocated(device)
@classmethod
def get_device_communicator_cls(cls) -> str: