mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Add LM head quantization support for ModelOpt (#42124)
Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,34 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
|
||||
def test_nemotron_h_lm_head_receives_quant_config():
|
||||
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
|
||||
|
||||
mock_quant_config = Mock()
|
||||
|
||||
mock_hf_config = Mock()
|
||||
mock_hf_config.vocab_size = 128
|
||||
mock_hf_config.hidden_size = 64
|
||||
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.model_config.hf_config = mock_hf_config
|
||||
mock_vllm_config.model_config.dtype = None
|
||||
mock_vllm_config.scheduler_config = Mock()
|
||||
mock_vllm_config.quant_config = mock_quant_config
|
||||
|
||||
with (
|
||||
patch("vllm.model_executor.models.nemotron_h.NemotronHModel") as MockModel,
|
||||
patch("vllm.model_executor.models.nemotron_h.ParallelLMHead") as MockLMHead,
|
||||
patch("vllm.model_executor.models.nemotron_h.LogitsProcessor"),
|
||||
):
|
||||
MockModel.return_value.make_empty_intermediate_tensors = Mock()
|
||||
MockModel.return_value.has_moe = False
|
||||
|
||||
NemotronHForCausalLM(vllm_config=mock_vllm_config)
|
||||
|
||||
MockLMHead.assert_called_once()
|
||||
call_kwargs = MockLMHead.call_args.kwargs
|
||||
assert call_kwargs["quant_config"] is mock_quant_config
|
||||
@@ -0,0 +1,78 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
|
||||
def test_qwen3_5_lm_head_receives_quant_config():
|
||||
from vllm.model_executor.models.qwen3_5 import Qwen3_5ForCausalLMBase
|
||||
|
||||
mock_quant_config = Mock()
|
||||
|
||||
mock_hf_config = Mock()
|
||||
mock_hf_config.tie_word_embeddings = False
|
||||
mock_hf_config.vocab_size = 128
|
||||
mock_hf_config.hidden_size = 64
|
||||
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.model_config.hf_text_config = mock_hf_config
|
||||
mock_vllm_config.cache_config.mamba_cache_mode = "align"
|
||||
mock_vllm_config.scheduler_config = Mock()
|
||||
mock_vllm_config.quant_config = mock_quant_config
|
||||
mock_vllm_config.lora_config = None
|
||||
|
||||
mock_pp_group = Mock()
|
||||
mock_pp_group.is_last_rank = True
|
||||
|
||||
with (
|
||||
patch("vllm.model_executor.models.qwen3_5.Qwen3_5Model") as MockModel,
|
||||
patch("vllm.model_executor.models.qwen3_5.ParallelLMHead") as MockLMHead,
|
||||
patch("vllm.model_executor.models.qwen3_5.LogitsProcessor"),
|
||||
patch(
|
||||
"vllm.model_executor.models.qwen3_5.get_pp_group",
|
||||
return_value=mock_pp_group,
|
||||
),
|
||||
):
|
||||
MockModel.return_value.make_empty_intermediate_tensors = Mock()
|
||||
|
||||
Qwen3_5ForCausalLMBase(vllm_config=mock_vllm_config)
|
||||
|
||||
MockLMHead.assert_called_once()
|
||||
call_kwargs = MockLMHead.call_args.kwargs
|
||||
assert call_kwargs["quant_config"] is mock_quant_config
|
||||
|
||||
|
||||
def test_qwen3_5_mtp_lm_head_receives_quant_config():
|
||||
from vllm.config import CompilationMode
|
||||
from vllm.model_executor.models.qwen3_5_mtp import Qwen3_5MTP
|
||||
|
||||
mock_quant_config = Mock()
|
||||
|
||||
mock_hf_config = Mock()
|
||||
mock_hf_config.tie_word_embeddings = False
|
||||
mock_hf_config.vocab_size = 128
|
||||
mock_hf_config.hidden_size = 64
|
||||
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.model_config.hf_text_config = mock_hf_config
|
||||
mock_vllm_config.cache_config.mamba_cache_mode = "align"
|
||||
mock_vllm_config.compilation_config.mode = CompilationMode.NONE
|
||||
mock_vllm_config.quant_config = mock_quant_config
|
||||
|
||||
mock_pp_group = Mock()
|
||||
mock_pp_group.is_last_rank = True
|
||||
|
||||
with (
|
||||
patch("vllm.model_executor.models.qwen3_5_mtp.Qwen3_5MultiTokenPredictor"),
|
||||
patch("vllm.model_executor.models.qwen3_5_mtp.ParallelLMHead") as MockLMHead,
|
||||
patch("vllm.model_executor.models.qwen3_5_mtp.LogitsProcessor"),
|
||||
patch(
|
||||
"vllm.model_executor.models.qwen3_5_mtp.get_pp_group",
|
||||
return_value=mock_pp_group,
|
||||
),
|
||||
):
|
||||
Qwen3_5MTP(vllm_config=mock_vllm_config)
|
||||
|
||||
MockLMHead.assert_called_once()
|
||||
call_kwargs = MockLMHead.call_args.kwargs
|
||||
assert call_kwargs["quant_config"] is mock_quant_config
|
||||
@@ -7,13 +7,24 @@ Run `pytest tests/quantization/test_modelopt.py`.
|
||||
|
||||
import os
|
||||
from typing import Any, NoReturn
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm.config.model import ModelConfig
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptFp8Config,
|
||||
ModelOptMixedPrecisionConfig,
|
||||
ModelOptNvFp4Config,
|
||||
ModelOptNvFp4LinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
@@ -44,6 +55,87 @@ def _snapshot_download_or_skip(model_id: str) -> str:
|
||||
_skip(f"Failed to download {model_id} from the HF Hub: {e}")
|
||||
|
||||
|
||||
def _mock_lm_head() -> Mock:
|
||||
lm_head = Mock(spec=ParallelLMHead)
|
||||
lm_head.__class__ = ParallelLMHead
|
||||
return lm_head
|
||||
|
||||
|
||||
def _mixed_precision_config(quantized_layers: dict) -> ModelOptMixedPrecisionConfig:
|
||||
return ModelOptMixedPrecisionConfig(
|
||||
kv_cache_quant_method=None,
|
||||
exclude_modules=[],
|
||||
quantized_layers=quantized_layers,
|
||||
fp8_config=ModelOptFp8Config(
|
||||
quant_method="FP8",
|
||||
is_checkpoint_fp8_serialized=True,
|
||||
kv_cache_quant_method=None,
|
||||
exclude_modules=[],
|
||||
),
|
||||
nvfp4_config=ModelOptNvFp4Config(
|
||||
is_checkpoint_nvfp4_serialized=True,
|
||||
kv_cache_quant_algo=None,
|
||||
exclude_modules=[],
|
||||
),
|
||||
w4a16_nvfp4_config=ModelOptNvFp4Config(
|
||||
quant_method="W4A16_NVFP4",
|
||||
is_checkpoint_nvfp4_serialized=True,
|
||||
kv_cache_quant_algo=None,
|
||||
exclude_modules=[],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_modelopt_nvfp4_quantizes_parallel_lm_head():
|
||||
config = ModelOptNvFp4Config(
|
||||
is_checkpoint_nvfp4_serialized=True,
|
||||
kv_cache_quant_algo=None,
|
||||
exclude_modules=[],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"vllm.model_executor.layers.quantization.modelopt.init_nvfp4_linear_kernel"
|
||||
):
|
||||
method = config.get_quant_method(_mock_lm_head(), prefix="lm_head")
|
||||
|
||||
assert isinstance(method, ModelOptNvFp4LinearMethod)
|
||||
|
||||
|
||||
def test_modelopt_nvfp4_leaves_excluded_parallel_lm_head_unquantized():
|
||||
config = ModelOptNvFp4Config(
|
||||
is_checkpoint_nvfp4_serialized=True,
|
||||
kv_cache_quant_algo=None,
|
||||
exclude_modules=["lm_head"],
|
||||
)
|
||||
|
||||
method = config.get_quant_method(_mock_lm_head(), prefix="lm_head")
|
||||
|
||||
assert isinstance(method, UnquantizedLinearMethod)
|
||||
|
||||
|
||||
def test_modelopt_mixed_precision_quantizes_parallel_lm_head():
|
||||
config = _mixed_precision_config(
|
||||
{"lm_head": {"quant_algo": "NVFP4", "group_size": 16}}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"vllm.model_executor.layers.quantization.modelopt.init_nvfp4_linear_kernel"
|
||||
):
|
||||
method = config.get_quant_method(_mock_lm_head(), prefix="lm_head")
|
||||
|
||||
assert isinstance(method, ModelOptNvFp4LinearMethod)
|
||||
|
||||
|
||||
def test_vocab_parallel_embedding_weight_loader_accepts_scalar_scale():
|
||||
holder = Mock()
|
||||
scale = torch.nn.Parameter(torch.empty(1))
|
||||
loaded_scale = torch.tensor(2.0)
|
||||
|
||||
VocabParallelEmbedding.weight_loader(holder, scale, loaded_scale)
|
||||
|
||||
assert torch.equal(scale, loaded_scale.reshape(1))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("modelopt"),
|
||||
reason="ModelOpt FP8 is not supported on this GPU type.",
|
||||
|
||||
@@ -81,6 +81,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
@@ -182,7 +183,7 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
||||
|
||||
# handle exclusion
|
||||
if self.is_layer_excluded(prefix):
|
||||
if isinstance(layer, LinearBase):
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
return UnquantizedLinearMethod()
|
||||
return None
|
||||
|
||||
@@ -195,7 +196,7 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
# now, the layer is quantized, handle it here
|
||||
if isinstance(layer, LinearBase):
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
quant_method = self.LinearMethodCls(self)
|
||||
if getattr(quant_method, "backend", "") == "marlin":
|
||||
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
|
||||
@@ -2371,13 +2372,13 @@ class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
|
||||
|
||||
# Excluded layers
|
||||
if self.is_layer_excluded(prefix):
|
||||
if isinstance(layer, LinearBase):
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
return UnquantizedLinearMethod()
|
||||
return None
|
||||
|
||||
quant_algo = self._resolve_quant_algo(prefix)
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if isinstance(layer, (LinearBase, ParallelLMHead)):
|
||||
if quant_algo == "FP8":
|
||||
return ModelOptFp8LinearMethod(self.fp8_config)
|
||||
if quant_algo == "NVFP4":
|
||||
|
||||
@@ -290,6 +290,7 @@ class VocabParallelEmbedding(PluggableLayer):
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
# Divide the weight matrix along the vocabulary dimension.
|
||||
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
|
||||
self.num_embeddings_per_partition = divide(
|
||||
@@ -438,6 +439,12 @@ class VocabParallelEmbedding(PluggableLayer):
|
||||
# If parameter does not have output dim, then it should
|
||||
# be copied onto all gpus (e.g. g_idx for act_order gptq).
|
||||
if output_dim is None:
|
||||
if (
|
||||
loaded_weight.ndim == 0
|
||||
and param.data.ndim == 1
|
||||
and param.data.numel() == 1
|
||||
):
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
assert param.data.shape == loaded_weight.shape
|
||||
param.data.copy_(loaded_weight)
|
||||
return
|
||||
|
||||
@@ -875,6 +875,7 @@ class NemotronHForCausalLM(
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=self.quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
|
||||
|
||||
@@ -477,6 +477,7 @@ class Qwen3_5ForCausalLMBase(
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=self.quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -381,6 +381,7 @@ class Qwen3_5MTP(nn.Module, SupportsMultiModal):
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=self.quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user