Add LM head quantization support for ModelOpt (#42124)

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
This commit is contained in:
Wei-Ming Chen
2026-05-26 09:21:05 -07:00
committed by GitHub
parent c8414a8271
commit 6f5b533241
8 changed files with 220 additions and 5 deletions
@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import Mock, patch
def test_nemotron_h_lm_head_receives_quant_config():
    """Check that NemotronHForCausalLM forwards the quant config to its LM head.

    The backbone model, ``ParallelLMHead`` and ``LogitsProcessor`` are all
    patched inside the ``nemotron_h`` module, so only the constructor wiring
    is exercised.
    """
    from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM

    quant_cfg = Mock()
    hf_cfg = Mock(vocab_size=128, hidden_size=64)

    vllm_cfg = Mock()
    vllm_cfg.model_config.hf_config = hf_cfg
    vllm_cfg.model_config.dtype = None
    vllm_cfg.scheduler_config = Mock()
    vllm_cfg.quant_config = quant_cfg

    with (
        patch("vllm.model_executor.models.nemotron_h.NemotronHModel") as MockModel,
        patch("vllm.model_executor.models.nemotron_h.ParallelLMHead") as MockLMHead,
        patch("vllm.model_executor.models.nemotron_h.LogitsProcessor"),
    ):
        backbone = MockModel.return_value
        backbone.make_empty_intermediate_tensors = Mock()
        backbone.has_moe = False
        NemotronHForCausalLM(vllm_config=vllm_cfg)

    # The patched class keeps its call record after the patch is undone.
    MockLMHead.assert_called_once()
    lm_head_kwargs = MockLMHead.call_args.kwargs
    assert lm_head_kwargs["quant_config"] is quant_cfg
@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import Mock, patch
def test_qwen3_5_lm_head_receives_quant_config():
    """Check that Qwen3_5ForCausalLMBase forwards the quant config to its LM head.

    Word embeddings are untied and the pipeline-parallel group is mocked to
    report the last rank; the test then expects exactly one
    ``ParallelLMHead`` construction.
    """
    from vllm.model_executor.models.qwen3_5 import Qwen3_5ForCausalLMBase

    quant_cfg = Mock()
    hf_text_cfg = Mock(tie_word_embeddings=False, vocab_size=128, hidden_size=64)

    vllm_cfg = Mock()
    vllm_cfg.model_config.hf_text_config = hf_text_cfg
    vllm_cfg.cache_config.mamba_cache_mode = "align"
    vllm_cfg.scheduler_config = Mock()
    vllm_cfg.quant_config = quant_cfg
    vllm_cfg.lora_config = None

    last_rank_group = Mock(is_last_rank=True)

    with (
        patch("vllm.model_executor.models.qwen3_5.Qwen3_5Model") as MockModel,
        patch("vllm.model_executor.models.qwen3_5.ParallelLMHead") as MockLMHead,
        patch("vllm.model_executor.models.qwen3_5.LogitsProcessor"),
        patch(
            "vllm.model_executor.models.qwen3_5.get_pp_group",
            return_value=last_rank_group,
        ),
    ):
        MockModel.return_value.make_empty_intermediate_tensors = Mock()
        Qwen3_5ForCausalLMBase(vllm_config=vllm_cfg)

    # The patched class keeps its call record after the patch is undone.
    MockLMHead.assert_called_once()
    lm_head_kwargs = MockLMHead.call_args.kwargs
    assert lm_head_kwargs["quant_config"] is quant_cfg
def test_qwen3_5_mtp_lm_head_receives_quant_config():
    """Check that Qwen3_5MTP forwards the quant config to its LM head.

    The multi-token predictor, ``ParallelLMHead`` and ``LogitsProcessor`` are
    patched in the ``qwen3_5_mtp`` module, compilation mode is set to
    ``CompilationMode.NONE``, and the pipeline-parallel group is mocked to
    report the last rank.
    """
    from vllm.config import CompilationMode
    from vllm.model_executor.models.qwen3_5_mtp import Qwen3_5MTP

    quant_cfg = Mock()
    hf_text_cfg = Mock(tie_word_embeddings=False, vocab_size=128, hidden_size=64)

    vllm_cfg = Mock()
    vllm_cfg.model_config.hf_text_config = hf_text_cfg
    vllm_cfg.cache_config.mamba_cache_mode = "align"
    vllm_cfg.compilation_config.mode = CompilationMode.NONE
    vllm_cfg.quant_config = quant_cfg

    last_rank_group = Mock(is_last_rank=True)

    with (
        patch("vllm.model_executor.models.qwen3_5_mtp.Qwen3_5MultiTokenPredictor"),
        patch("vllm.model_executor.models.qwen3_5_mtp.ParallelLMHead") as MockLMHead,
        patch("vllm.model_executor.models.qwen3_5_mtp.LogitsProcessor"),
        patch(
            "vllm.model_executor.models.qwen3_5_mtp.get_pp_group",
            return_value=last_rank_group,
        ),
    ):
        Qwen3_5MTP(vllm_config=vllm_cfg)

    # The patched class keeps its call record after the patch is undone.
    MockLMHead.assert_called_once()
    lm_head_kwargs = MockLMHead.call_args.kwargs
    assert lm_head_kwargs["quant_config"] is quant_cfg
+93 -1
View File
@@ -7,13 +7,24 @@ Run `pytest tests/quantization/test_modelopt.py`.
import os
from typing import Any, NoReturn
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch
import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from vllm.config.model import ModelConfig
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptFp8Config,
ModelOptMixedPrecisionConfig,
ModelOptNvFp4Config,
ModelOptNvFp4LinearMethod,
)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
@pytest.fixture(scope="function", autouse=True)
@@ -44,6 +55,87 @@ def _snapshot_download_or_skip(model_id: str) -> str:
_skip(f"Failed to download {model_id} from the HF Hub: {e}")
def _mock_lm_head() -> Mock:
    """Return a ``Mock`` stand-in that presents itself as a ``ParallelLMHead``.

    The mock is spec'd on ``ParallelLMHead`` and its ``__class__`` is set to
    that class, so ``isinstance(..., ParallelLMHead)`` checks in the code
    under test accept it without constructing a real layer.
    """
    head_stub = Mock(spec=ParallelLMHead)
    head_stub.__class__ = ParallelLMHead
    return head_stub
def _mixed_precision_config(quantized_layers: dict) -> ModelOptMixedPrecisionConfig:
    """Build a mixed-precision ModelOpt config for the given layer map.

    Args:
        quantized_layers: Passed through unchanged as the ``quantized_layers``
            argument of ``ModelOptMixedPrecisionConfig``.

    Returns:
        A config with FP8, NVFP4 and W4A16_NVFP4 sub-configs, no KV-cache
        quantization and no excluded modules.
    """
    # Sub-configs are created in the same order the constructor receives them.
    fp8 = ModelOptFp8Config(
        quant_method="FP8",
        is_checkpoint_fp8_serialized=True,
        kv_cache_quant_method=None,
        exclude_modules=[],
    )
    nvfp4 = ModelOptNvFp4Config(
        is_checkpoint_nvfp4_serialized=True,
        kv_cache_quant_algo=None,
        exclude_modules=[],
    )
    w4a16_nvfp4 = ModelOptNvFp4Config(
        quant_method="W4A16_NVFP4",
        is_checkpoint_nvfp4_serialized=True,
        kv_cache_quant_algo=None,
        exclude_modules=[],
    )
    return ModelOptMixedPrecisionConfig(
        kv_cache_quant_method=None,
        exclude_modules=[],
        quantized_layers=quantized_layers,
        fp8_config=fp8,
        nvfp4_config=nvfp4,
        w4a16_nvfp4_config=w4a16_nvfp4,
    )
def test_modelopt_nvfp4_quantizes_parallel_lm_head():
    """A non-excluded LM head stand-in is given the NVFP4 linear method."""
    nvfp4_config = ModelOptNvFp4Config(
        is_checkpoint_nvfp4_serialized=True,
        kv_cache_quant_algo=None,
        exclude_modules=[],
    )
    lm_head = _mock_lm_head()

    # init_nvfp4_linear_kernel is patched out for the duration of the lookup.
    with patch(
        "vllm.model_executor.layers.quantization.modelopt.init_nvfp4_linear_kernel"
    ):
        quant_method = nvfp4_config.get_quant_method(lm_head, prefix="lm_head")

    assert isinstance(quant_method, ModelOptNvFp4LinearMethod)
def test_modelopt_nvfp4_leaves_excluded_parallel_lm_head_unquantized():
    """An LM head listed in ``exclude_modules`` gets the unquantized method."""
    nvfp4_config = ModelOptNvFp4Config(
        is_checkpoint_nvfp4_serialized=True,
        kv_cache_quant_algo=None,
        exclude_modules=["lm_head"],
    )
    lm_head = _mock_lm_head()

    quant_method = nvfp4_config.get_quant_method(lm_head, prefix="lm_head")

    assert isinstance(quant_method, UnquantizedLinearMethod)
def test_modelopt_mixed_precision_quantizes_parallel_lm_head():
    """A mixed-precision config mapping ``lm_head`` to NVFP4 quantizes it."""
    layer_map = {"lm_head": {"quant_algo": "NVFP4", "group_size": 16}}
    mixed_config = _mixed_precision_config(layer_map)
    lm_head = _mock_lm_head()

    # init_nvfp4_linear_kernel is patched out for the duration of the lookup.
    with patch(
        "vllm.model_executor.layers.quantization.modelopt.init_nvfp4_linear_kernel"
    ):
        quant_method = mixed_config.get_quant_method(lm_head, prefix="lm_head")

    assert isinstance(quant_method, ModelOptNvFp4LinearMethod)
def test_vocab_parallel_embedding_weight_loader_accepts_scalar_scale():
    """A 0-dim loaded scale is accepted for a one-element 1-D parameter.

    ``weight_loader`` is called unbound, with a plain ``Mock`` standing in
    for the layer instance.
    """
    checkpoint_scale = torch.tensor(2.0)
    target_param = torch.nn.Parameter(torch.empty(1))
    layer_stub = Mock()

    VocabParallelEmbedding.weight_loader(layer_stub, target_param, checkpoint_scale)

    expected = checkpoint_scale.reshape(1)
    assert torch.equal(target_param, expected)
@pytest.mark.skipif(
not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.",
@@ -81,6 +81,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
requantize_with_max_scale,
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ChannelQuantScaleParameter,
@@ -182,7 +183,7 @@ class ModelOptQuantConfigBase(QuantizationConfig):
# handle exclusion
if self.is_layer_excluded(prefix):
if isinstance(layer, LinearBase):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
return None
@@ -195,7 +196,7 @@ class ModelOptQuantConfigBase(QuantizationConfig):
return UnquantizedLinearMethod()
# now, the layer is quantized, handle it here
if isinstance(layer, LinearBase):
if isinstance(layer, (LinearBase, ParallelLMHead)):
quant_method = self.LinearMethodCls(self)
if getattr(quant_method, "backend", "") == "marlin":
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
@@ -2371,13 +2372,13 @@ class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
# Excluded layers
if self.is_layer_excluded(prefix):
if isinstance(layer, LinearBase):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
return None
quant_algo = self._resolve_quant_algo(prefix)
if isinstance(layer, LinearBase):
if isinstance(layer, (LinearBase, ParallelLMHead)):
if quant_algo == "FP8":
return ModelOptFp8LinearMethod(self.fp8_config)
if quant_algo == "NVFP4":
@@ -290,6 +290,7 @@ class VocabParallelEmbedding(PluggableLayer):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Divide the weight matrix along the vocabulary dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(
@@ -438,6 +439,12 @@ class VocabParallelEmbedding(PluggableLayer):
# If parameter does not have output dim, then it should
# be copied onto all gpus (e.g. g_idx for act_order gptq).
if output_dim is None:
if (
loaded_weight.ndim == 0
and param.data.ndim == 1
and param.data.numel() == 1
):
loaded_weight = loaded_weight.reshape(1)
assert param.data.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
return
+1
View File
@@ -875,6 +875,7 @@ class NemotronHForCausalLM(
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
+1
View File
@@ -477,6 +477,7 @@ class Qwen3_5ForCausalLMBase(
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
@@ -381,6 +381,7 @@ class Qwen3_5MTP(nn.Module, SupportsMultiModal):
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else: