Refactor CT NVFP4 linear to use a single class (#42443)

This commit is contained in:
Dipika Sikka
2026-06-04 08:25:08 -04:00
committed by GitHub
parent 4b87b3e845
commit e68988a248
6 changed files with 55 additions and 162 deletions
+5 -17
View File
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsW4A4Fp4,
CompressedTensorsW4A4Mxfp4,
CompressedTensorsW4A8Fp8,
CompressedTensorsW4A16Fp4,
CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8,
CompressedTensorsW8A8Mxfp8,
@@ -37,9 +36,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target,
)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
cutlass_fp4_supported,
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
@@ -376,13 +372,12 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
@pytest.mark.parametrize(
"args",
[
# TODO: Enable once model is available again
# ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", CompressedTensorsW4A16Fp4),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", True),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", False),
],
)
def test_compressed_tensors_nvfp4(vllm_runner, args):
model, scheme = args
model, use_a16 = args
with vllm_runner(model, enforce_eager=True) as llm:
def check_model(model):
@@ -390,15 +385,8 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
if (
isinstance(qkv_proj.scheme, scheme)
or isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
and not cutlass_fp4_supported()
):
assert True
else:
raise AssertionError("FP4 Scheme Mismatch")
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A4Fp4)
assert qkv_proj.scheme.use_a16 == use_a16
assert qkv_proj.scheme.group_size == 16
llm.apply_model(check_model)