mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Merge branch 'main' into lwilkinson/kv-layout/bucket-layers-refactor
This commit is contained in:
@@ -32,7 +32,7 @@ partial-json-parser # used for parsing partial JSON outputs
|
||||
pyzmq >= 25.0.0
|
||||
msgspec
|
||||
gguf >= 0.17.0
|
||||
mistral_common[image] >= 1.11.2
|
||||
mistral_common[image] >= 1.11.3
|
||||
opencv-python-headless >= 4.13.0 # required for video IO
|
||||
pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
|
||||
@@ -31,7 +31,7 @@ torchaudio==2.11.0
|
||||
torchvision==0.26.0
|
||||
transformers_stream_generator # required for qwen-vl test
|
||||
matplotlib # required for qwen-vl test
|
||||
mistral_common[image,audio] >= 1.11.2 # required for voxtral test
|
||||
mistral_common[image,audio] >= 1.11.3 # required for voxtral test
|
||||
num2words # required for smolvlm test
|
||||
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
|
||||
opencv-python-headless >= 4.13.0 # required for video test
|
||||
|
||||
@@ -409,7 +409,7 @@ mbstrdecoder==1.1.3
|
||||
# typepy
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistral-common==1.11.2
|
||||
mistral-common==1.11.3
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# -r requirements/test/cuda.in
|
||||
|
||||
@@ -23,7 +23,7 @@ jiwer # required for audio tests
|
||||
timm # required for internvl test
|
||||
transformers_stream_generator # required for qwen-vl test
|
||||
matplotlib # required for qwen-vl test
|
||||
mistral_common[image,audio] >= 1.11.2 # required for voxtral test
|
||||
mistral_common[image,audio] >= 1.11.3 # required for voxtral test
|
||||
num2words # required for smolvlm test
|
||||
opencv-python-headless >= 4.13.0 # required for video test
|
||||
datamodel_code_generator # required for minicpm3 test
|
||||
|
||||
@@ -30,7 +30,7 @@ tblib # for pickling test exceptions
|
||||
timm>=1.0.17 # required for internvl and gemma3n-mm test
|
||||
transformers_stream_generator # required for qwen-vl test
|
||||
matplotlib # required for qwen-vl test
|
||||
mistral_common[image,audio]>=1.11.2 # required for voxtral test
|
||||
mistral_common[image,audio]>=1.11.3 # required for voxtral test
|
||||
num2words # required for smolvlm test
|
||||
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
|
||||
opencv-python-headless>=4.13.0 # required for video test
|
||||
|
||||
@@ -512,7 +512,7 @@ mcp==1.27.0
|
||||
# via -r requirements/test/../common.txt
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistral-common==1.11.2
|
||||
mistral-common==1.11.3
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# -r requirements/test/../common.txt
|
||||
|
||||
@@ -266,7 +266,7 @@ mbstrdecoder==1.1.4
|
||||
# typepy
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistral-common==1.11.2
|
||||
mistral-common==1.11.3
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# -r requirements/test/xpu.in
|
||||
|
||||
@@ -100,32 +100,6 @@ def test_fc_layer_quant_config_usage(default_vllm_config, dist_init, device) ->
|
||||
assert output.shape == (2, output_size)
|
||||
|
||||
|
||||
def test_kv_cache_scale_name_handling():
    """A quant config exposing `get_cache_scale` yields the mapped scale name."""
    expected_scale = "layers.0.self_attn.kv_scale"
    weight_name = "layers.0.self_attn.k_proj.weight"

    # Stand-in for a quant config that supports cache scales.
    quant_config = Mock()
    quant_config.get_cache_scale = Mock(return_value=expected_scale)

    # Mirrors the lookup performed by the condition check in load_weights.
    resolved = quant_config.get_cache_scale(weight_name)

    # The lookup must have been issued exactly once with the weight name and
    # must hand back the configured scale name.
    quant_config.get_cache_scale.assert_called_once_with(weight_name)
    assert resolved == expected_scale
|
||||
|
||||
|
||||
def test_kv_cache_scale_name_no_scale():
    """A quant config with no matching cache scale reports `None`."""
    weight_name = "layers.0.mlp.gate_proj.weight"

    # Stand-in quant config whose get_cache_scale always answers None.
    quant_config = Mock()
    quant_config.get_cache_scale = Mock(return_value=None)

    # Weights that carry no cache scale must resolve to None.
    resolved = quant_config.get_cache_scale(weight_name)
    assert resolved is None
|
||||
|
||||
|
||||
def test_maybe_remap_kv_scale_name():
|
||||
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
|
||||
|
||||
@@ -183,33 +157,3 @@ def test_eagle3_lm_head_receives_quant_config():
|
||||
assert call_kwargs["quant_config"] is mock_quant_config, (
|
||||
"ParallelLMHead must receive the draft model's quant_config"
|
||||
)
|
||||
|
||||
|
||||
def test_load_weights_kv_scale_handling():
    """Replay the KV-cache scale branch of load_weights against mocks."""
    expected_scale = "layers.0.self_attn.kv_scale"

    scale_param = Mock()
    scale_param.weight_loader = Mock()
    params_dict = {expected_scale: scale_param}

    quant_config = Mock()
    quant_config.get_cache_scale = Mock(return_value=expected_scale)

    # Inputs for the load_weights logic that handles KV cache scales.
    name = "layers.0.self_attn.k_proj.weight"
    checkpoint_tensor = torch.tensor([1.0, 2.0])

    # Guard clauses: nothing to verify without a quant config or a scale name.
    if quant_config is None:
        return
    scale_name = quant_config.get_cache_scale(name)
    if not scale_name:
        return

    assert params_dict[scale_name] is scale_param

    # A non-scalar checkpoint tensor contributes only its first element.
    if checkpoint_tensor.dim() == 0:
        weight_to_load = checkpoint_tensor
    else:
        weight_to_load = checkpoint_tensor[0]

    assert scale_name == expected_scale
    assert weight_to_load == checkpoint_tensor[0]
|
||||
|
||||
@@ -162,7 +162,13 @@ class QuantizationConfig(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_cache_scale(self, name: str) -> str | None:
|
||||
def get_cache_scale_mapper(self) -> "WeightsMapper | None":
    """Return the renaming rules for KV-cache scale weights, if any.

    The mapper translates checkpoint KV-cache scale names into the scale
    names vLLM uses. When one is returned, `AutoWeightsLoader` applies it
    to the weight stream automatically, so individual model `load_weights`
    methods do not need to know about KV-cache scales.

    This default implementation declares no mapping.
    """
    return None
|
||||
|
||||
def apply_vllm_mapper( # noqa: B027
|
||||
|
||||
@@ -207,25 +207,18 @@ class Fp8Config(QuantizationConfig):
|
||||
return Fp8KVCacheMethod(self)
|
||||
return None
|
||||
|
||||
def get_cache_scale(self, name: str) -> str | None:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in compressed-tensors. If this is the case, return its equivalent
|
||||
param name expected by vLLM
|
||||
def get_cache_scale_mapper(self) -> "WeightsMapper":
|
||||
"""Map compressed-tensors KV-cache scale names to vLLM names."""
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
|
||||
:param name: param name
|
||||
:return: matching param name for KV cache scale in vLLM
|
||||
"""
|
||||
if name.endswith(".output_scale") and ".k_proj" in name:
|
||||
return name.replace(".k_proj.output_scale", ".attn.k_scale")
|
||||
if name.endswith(".output_scale") and ".v_proj" in name:
|
||||
return name.replace(".v_proj.output_scale", ".attn.v_scale")
|
||||
if name.endswith(".output_scale") and ".q_proj" in name:
|
||||
return name.replace(".q_proj.output_scale", ".attn.q_scale")
|
||||
if name.endswith("self_attn.prob_output_scale"):
|
||||
return name.replace(".prob_output_scale", ".attn.prob_scale")
|
||||
# If no matches, return None
|
||||
return None
|
||||
return WeightsMapper(
|
||||
orig_to_new_suffix={
|
||||
".k_proj.output_scale": ".attn.k_scale",
|
||||
".v_proj.output_scale": ".attn.v_scale",
|
||||
".q_proj.output_scale": ".attn.q_scale",
|
||||
".self_attn.prob_output_scale": ".self_attn.attn.prob_scale",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class CopyNumelCounter(TorchDispatchMode):
|
||||
|
||||
@@ -15,6 +15,30 @@ from vllm.v1.kv_cache_interface import kv_cache_uses_per_token_head_scales
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KVCacheScaleParameter(torch.nn.Parameter):
    """Scalar parameter for KV-cache scales.

    `KVCacheScaleParameter()` with no arguments yields a 0-dim parameter
    with `requires_grad=False`, initialized to -1.0 (an invalid sentinel),
    so call sites just write `KVCacheScaleParameter()`.

    The optional `data` / `requires_grad` arguments exist so that
    `torch.nn.Parameter.__deepcopy__`, which rebuilds the parameter through
    `type(self)(data, requires_grad)`, keeps working for this subclass
    (e.g. when a module holding these scales is deep-copied).

    The `weight_loader` accepts any single-element tensor (shape `()`,
    `(1,)`, ...) and rejects anything else — per-head scales go through a
    separate path (compressed-tensors' `_tp_aware_loader`), not this one.
    Per-instance overrides still work because instance attribute assignment
    shadows this class-level loader.
    """

    def __new__(
        cls,
        data: "torch.Tensor | None" = None,
        requires_grad: bool = False,
    ) -> "KVCacheScaleParameter":
        """Create the parameter.

        Args:
            data: Initial value. Defaults to the scalar sentinel -1.0 when
                omitted, which is the only form regular call sites use.
            requires_grad: Whether the parameter is trainable; defaults to
                False.
        """
        if data is None:
            data = torch.tensor(-1.0)
        return super().__new__(cls, data, requires_grad=requires_grad)

    @staticmethod
    def weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
        """Copy a single-element checkpoint tensor into `param`.

        Args:
            param: Destination scale parameter.
            loaded_weight: Checkpoint tensor holding exactly one element.

        Raises:
            ValueError: If `loaded_weight` does not hold exactly one element.
        """
        if loaded_weight.numel() != 1:
            raise ValueError(
                f"KV-cache scale expects a scalar weight, got shape "
                f"{tuple(loaded_weight.shape)}"
            )
        # Collapse to 0-dim so the copy matches the scalar parameter's shape.
        param.data.copy_(loaded_weight.reshape(()))
|
||||
|
||||
|
||||
class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
"""
|
||||
Quant method that adds `_k_scale` and `_v_scale` attributes to the
|
||||
@@ -37,11 +61,11 @@ class BaseKVCacheMethod(QuantizeMethodBase):
|
||||
# Initialize the Q and KV cache scales to -1.0, an invalid value.
|
||||
# If the q and k/v_scales appear in the checkpoint, it will be
|
||||
# overwritten when loading weights.
|
||||
layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
layer.q_scale = KVCacheScaleParameter()
|
||||
layer.k_scale = KVCacheScaleParameter()
|
||||
layer.v_scale = KVCacheScaleParameter()
|
||||
# Initialize P = softmax(QK^T) scales
|
||||
layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||
layer.prob_scale = KVCacheScaleParameter()
|
||||
|
||||
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
||||
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
|
||||
|
||||
@@ -646,26 +646,16 @@ class QuarkConfig(QuantizationConfig):
|
||||
|
||||
return scheme
|
||||
|
||||
def get_cache_scale(self, name: str) -> str | None:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in quark. If this is the case, return its equivalent param name
|
||||
expected by vLLM
|
||||
|
||||
:param name: param name
|
||||
:return: matching param name for KV cache scale in vLLM
|
||||
"""
|
||||
if name.endswith(".output_scale") and ".k_proj" in name:
|
||||
return name.replace(".k_proj.output_scale", ".attn.k_scale")
|
||||
if name.endswith(".output_scale") and ".v_proj" in name:
|
||||
return name.replace(".v_proj.output_scale", ".attn.v_scale")
|
||||
if name.endswith(".output_scale") and ".q_proj" in name:
|
||||
return name.replace(".q_proj.output_scale", ".attn.q_scale")
|
||||
if name.endswith("self_attn.prob_output_scale"):
|
||||
return name.replace(".prob_output_scale", ".attn.prob_scale")
|
||||
|
||||
# If no matches, return None
|
||||
return None
|
||||
def get_cache_scale_mapper(self) -> "WeightsMapper":
    """Build the mapper from Quark KV-cache scale names to vLLM names."""
    # Checkpoint suffix -> vLLM suffix for the q/k/v and prob scales.
    suffix_renames = {
        ".k_proj.output_scale": ".attn.k_scale",
        ".v_proj.output_scale": ".attn.v_scale",
        ".q_proj.output_scale": ".attn.q_scale",
        ".self_attn.prob_output_scale": ".self_attn.attn.prob_scale",
    }
    return WeightsMapper(orig_to_new_suffix=suffix_renames)
|
||||
|
||||
|
||||
class QuarkLinearMethod(LinearMethodBase):
|
||||
|
||||
@@ -1541,6 +1541,11 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
|
||||
if no remapping is needed.
|
||||
None: If the remapped name is not found in params_dict.
|
||||
"""
|
||||
# Already in vLLM's expected form (e.g. weights pre-renamed by a
|
||||
# `WeightsMapper` from the quant config). Skip the regex remap, which
|
||||
# would otherwise double-apply the `.attn` prefix and drop the weight.
|
||||
if name in params_dict:
|
||||
return name
|
||||
if name.endswith(".kv_scale"):
|
||||
logger.warning_once(
|
||||
"DEPRECATED. Found kv_scale in the checkpoint. "
|
||||
|
||||
@@ -430,18 +430,6 @@ class ApertusModel(nn.Module, EagleModelMixin):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name or "zero_point" in name:
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
|
||||
@@ -293,18 +293,6 @@ class ArceeModel(nn.Module, EagleModelMixin):
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
continue
|
||||
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
if "scale" in name or "zero_point" in name:
|
||||
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if remapped_name is None:
|
||||
|
||||
@@ -363,18 +363,6 @@ class AriaTextModel(LlamaModel, SupportsQuant):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -464,18 +464,6 @@ class Cohere2MoeModel(nn.Module):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
for param_name, shard_name, shard_id in stacked_params_mapping:
|
||||
if shard_name not in name:
|
||||
continue
|
||||
|
||||
@@ -150,18 +150,6 @@ class CohereEagleModel(nn.Module):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -352,19 +352,6 @@ class CohereModel(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
for param_name, shard_name, shard_id in stacked_params_mapping:
|
||||
if shard_name not in name:
|
||||
continue
|
||||
|
||||
@@ -394,19 +394,6 @@ class DbrxModel(nn.Module):
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
if name.endswith(("w1", "w2", "v1")):
|
||||
name = name + "_weight"
|
||||
for param_name, weight_name in expert_params_mapping:
|
||||
|
||||
@@ -284,19 +284,6 @@ class DeepseekV2Eagle3Model(nn.Module):
|
||||
if "midlayer." in name:
|
||||
name = name.replace("midlayer.", "layers.0.")
|
||||
|
||||
# Handle kv cache quantization scales
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
# Remapping the name FP8 kv-scale
|
||||
if "scale" in name:
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
|
||||
@@ -391,18 +391,6 @@ class ExaoneModel(nn.Module):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -389,18 +389,6 @@ class Exaone4Model(nn.Module):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -374,18 +374,6 @@ class ExaoneMoeModel(nn.Module):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -328,16 +328,6 @@ class Gemma2Model(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, shard_name, shard_id in stacked_params_mapping:
|
||||
if shard_name not in name:
|
||||
continue
|
||||
|
||||
@@ -386,17 +386,6 @@ class Gemma3Model(nn.Module):
|
||||
):
|
||||
loaded_weight -= 1
|
||||
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
# Check if this is a scale parameter that needs remapping first
|
||||
if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
|
||||
# Try to remap the scale name first
|
||||
|
||||
@@ -1056,16 +1056,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
):
|
||||
name = f"self_decoder.{name}"
|
||||
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, shard_name, shard_id in stacked_params_mapping:
|
||||
if shard_name not in name:
|
||||
continue
|
||||
|
||||
@@ -1409,16 +1409,6 @@ class Gemma4Model(nn.Module, EagleModelMixin):
|
||||
params_dict.update(dict(self.named_buffers()))
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
|
||||
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if remapped_name is not None and remapped_name in params_dict:
|
||||
|
||||
@@ -258,18 +258,6 @@ class Glm4Model(LlamaModel):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name or "zero_point" in name:
|
||||
# Remapping the name of FP8 kv-scale or zero point.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
|
||||
@@ -166,6 +166,10 @@ class GlmOcrMTP(nn.Module, SupportsPP):
|
||||
return self.model.compute_logits(hidden_states, spec_step_idx)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
if self.quant_config is not None and (
|
||||
cache_scale_mapper := self.quant_config.get_cache_scale_mapper()
|
||||
):
|
||||
weights = cache_scale_mapper.apply(weights)
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@@ -189,19 +193,6 @@ class GlmOcrMTP(nn.Module, SupportsPP):
|
||||
|
||||
name = self._rewrite_spec_layer_name(spec_layer, name)
|
||||
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
if "scale" in name or "zero_point" in name:
|
||||
# Remapping the name of FP8 kv-scale or zero point.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
|
||||
@@ -254,19 +254,6 @@ class GPTJModel(nn.Module):
|
||||
if "attn.bias" in name or "attn.masked_bias" in name:
|
||||
continue
|
||||
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -635,52 +635,6 @@ class GptOssModel(nn.Module, EagleModelMixin):
|
||||
|
||||
moe_quant_method = _get_moe_weight_dtype(layer_id=layer_id)
|
||||
|
||||
def kv_cache_scale_loader(
    quant_config: QuantizationConfig,
    name: str,
    params_dict: dict[str, typing.Any],
    weight: torch.Tensor,
    default_weight_loader: Callable[..., None],
    loaded_params: set[str],
) -> tuple[bool, set[str]]:
    """
    Load a KV cache output scale if `name` refers to one.

    Returns:
        Tuple of (bool, set):
        - bool: True if KV-cache scale was loaded into loaded_params
        - set: Updated set of loaded_params if True else the original set

    Raises:
        ValueError: If the matched scale weight is not a single element.
    """
    # Without a quant config there is no explicit cached KV output scale.
    if quant_config is None:
        return False, loaded_params

    scale_name = quant_config.get_cache_scale(name)
    if not scale_name:
        return False, loaded_params

    param = params_dict[scale_name]
    loader = getattr(param, "weight_loader", default_weight_loader)
    if weight.numel() != 1:
        raise ValueError(
            f"KV cache scale '{scale_name}' is expected to be a "
            f"scalar, but got a tensor of shape {weight.shape}."
        )

    # Hand the loader a 0-dim value regardless of the stored shape.
    loader(param, weight.flatten()[0])
    loaded_params.add(scale_name)
    return True, loaded_params
|
||||
|
||||
load_kv_cache_scale_completed, loaded_params = kv_cache_scale_loader(
|
||||
self.quant_config,
|
||||
name,
|
||||
params_dict,
|
||||
loaded_weight,
|
||||
default_weight_loader,
|
||||
loaded_params,
|
||||
)
|
||||
if load_kv_cache_scale_completed:
|
||||
continue
|
||||
|
||||
if (
|
||||
all(key in name for key in ["input_scale", "mlp.experts"])
|
||||
and expert_id is not None
|
||||
|
||||
@@ -334,18 +334,6 @@ class GraniteModel(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -365,19 +365,6 @@ class GraniteMoeModel(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -495,18 +495,6 @@ class GraniteMoeHybridModel(nn.Module):
|
||||
if "A_log" in n:
|
||||
n = n.replace("A_log", "A")
|
||||
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(n)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
loaded_weight = p
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
_load(scale_name, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
if _load_quant_expert(n, p):
|
||||
continue
|
||||
|
||||
|
||||
@@ -548,18 +548,6 @@ class Grok1Model(nn.Module):
|
||||
for old_pattern, new_pattern in self.weight_name_remapping.items():
|
||||
if old_pattern in name:
|
||||
name = name.replace(old_pattern, new_pattern)
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
|
||||
@@ -771,15 +771,6 @@ class HunYuanModel(nn.Module, EagleModelMixin):
|
||||
# processed with quantization, LoRA, fine-tuning, etc.
|
||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
continue
|
||||
|
||||
is_found = False
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
|
||||
@@ -531,17 +531,6 @@ class HYV3Model(nn.Module):
|
||||
for name, loaded_weight in weights:
|
||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name:
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
|
||||
@@ -264,6 +264,10 @@ class HYV3MTP(nn.Module):
|
||||
return torch.concat((q, k, v))
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
if self.quant_config is not None and (
|
||||
cache_scale_mapper := self.quant_config.get_cache_scale_mapper()
|
||||
):
|
||||
weights = cache_scale_mapper.apply(weights)
|
||||
cla_factor = _get_cla_factor(self.config)
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
@@ -336,14 +340,6 @@ class HYV3MTP(nn.Module):
|
||||
continue
|
||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
continue
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is None:
|
||||
continue
|
||||
|
||||
@@ -395,18 +395,6 @@ class HyperCLOVAXModel(nn.Module):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name or "zero_point" in name:
|
||||
# Remapping the name of FP8 kv-scale or zero point.
|
||||
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
|
||||
@@ -476,18 +476,6 @@ class IQuestLoopCoderModel(nn.Module):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if "gate_projections" in name:
|
||||
continue
|
||||
|
||||
@@ -386,16 +386,6 @@ class Jais2Model(nn.Module):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name:
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
|
||||
@@ -790,21 +790,6 @@ class KeyeSiglipVisionModel(nn.Module):
|
||||
continue
|
||||
if "head.mlp" in name or "head.probe" in name:
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(
|
||||
param,
|
||||
"weight_loader",
|
||||
default_weight_loader,
|
||||
)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for (
|
||||
param_name,
|
||||
weight_name,
|
||||
|
||||
@@ -724,20 +724,6 @@ class LagunaModel(nn.Module, EagleModelMixin):
|
||||
loaded_params.add(name)
|
||||
continue
|
||||
|
||||
# Handle KV cache quantization scales
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
assert loaded_weight.numel() == 1, (
|
||||
f"KV scale numel {loaded_weight.numel()} != 1"
|
||||
)
|
||||
loaded_weight = loaded_weight.squeeze()
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
# Handle stacked params (QKV, gate_up for
|
||||
# non-expert layers and shared_expert)
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
|
||||
@@ -451,18 +451,6 @@ class LlamaModel(nn.Module, EagleModelMixin):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name or "zero_point" in name:
|
||||
# Remapping the name of FP8 kv-scale or zero point.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
|
||||
@@ -588,21 +588,6 @@ class Llama4Model(LlamaModel):
|
||||
fused_experts_params = True
|
||||
expert_params_mapping = expert_params_mapping_fused
|
||||
|
||||
# If kv cache quantization scales exist and the weight name
|
||||
# corresponds to one of the kv cache quantization scales, load
|
||||
# them.
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
# Iterate over stacked_params_mapping to check if the current weight
|
||||
# is one of the stacked parameters. If so, load the weight with the
|
||||
# corresponding shard id. Note that MoE weights are handled
|
||||
@@ -625,9 +610,9 @@ class Llama4Model(LlamaModel):
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
# Remap kv cache scale names for ModelOpt checkpoints.
|
||||
# TODO: ModelOpt should implement get_cache_scale() such that
|
||||
# kv cache scale name remapping can be done there.
|
||||
# Remap kv cache scale names for any checkpoint format the
|
||||
# quant config's `get_cache_scale_mapper` does not cover
|
||||
# (idempotent for names already renamed by the mapper).
|
||||
if name.endswith("scale"):
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
|
||||
@@ -127,19 +127,6 @@ class LlamaModel(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
# Handle kv cache quantization scales
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale or zero point.
|
||||
if "scale" in name or "zero_point" in name:
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
|
||||
@@ -267,19 +267,6 @@ class LlamaModel(nn.Module):
|
||||
for name, loaded_weight in weights:
|
||||
if "midlayer." in name:
|
||||
name = name.replace("midlayer.", "layers.0.")
|
||||
# Handle kv cache quantization scales
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale or zero point.
|
||||
if "scale" in name or "zero_point" in name:
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
|
||||
@@ -104,18 +104,6 @@ class MiMoModel(Qwen2Model):
|
||||
continue
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -555,22 +555,6 @@ class MiMoV2Model(nn.Module):
|
||||
if "mtp" in name:
|
||||
continue
|
||||
|
||||
if self.quant_config is not None:
|
||||
cache_scale_name = self.quant_config.get_cache_scale(name)
|
||||
if cache_scale_name is not None and cache_scale_name in params_dict:
|
||||
param = params_dict[cache_scale_name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
|
||||
kv_scale = loaded_weight
|
||||
if kv_scale.dim() > 0 and kv_scale.numel() > 1:
|
||||
kv_scale = kv_scale.view(-1)[0]
|
||||
|
||||
weight_loader(param, kv_scale)
|
||||
loaded_params.add(cache_scale_name)
|
||||
continue
|
||||
|
||||
expert_matched = False
|
||||
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
|
||||
if weight_name not in name:
|
||||
|
||||
@@ -388,19 +388,6 @@ class MixtralModel(nn.Module):
|
||||
loaded_params: set[str] = set()
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
for name, loaded_weight in weights:
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -375,18 +375,6 @@ class NemotronModel(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -334,18 +334,6 @@ class DeciModel(nn.Module):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name or "zero_point" in name:
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
|
||||
@@ -390,18 +390,6 @@ class OuroModel(nn.Module):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -944,21 +944,6 @@ class SiglipVisionModel(nn.Module):
|
||||
continue
|
||||
if "packing_position_embedding" in name:
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(
|
||||
param,
|
||||
"weight_loader",
|
||||
default_weight_loader,
|
||||
)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for (
|
||||
param_name,
|
||||
weight_name,
|
||||
|
||||
@@ -537,19 +537,6 @@ class PhiMoEModel(nn.Module):
|
||||
loaded_params: set[str] = set()
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
for name, loaded_weight in weights:
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -439,18 +439,6 @@ class Qwen2Model(nn.Module, EagleModelMixin):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -469,17 +469,6 @@ class DFlashQwen3Model(nn.Module):
|
||||
for name, loaded_weight in weights:
|
||||
if "midlayer." in name:
|
||||
name = name.replace("midlayer.", "layers.0.")
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name:
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
|
||||
@@ -552,19 +552,6 @@ class Qwen3MoeModel(nn.Module, EagleModelMixin):
|
||||
loaded_params: set[str] = set()
|
||||
expert_params_mapping = self.get_expert_mapping()
|
||||
for name, loaded_weight in weights:
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
assert loaded_weight.numel() == 1, (
|
||||
f"KV scale numel {loaded_weight.numel()} != 1"
|
||||
)
|
||||
loaded_weight = loaded_weight.squeeze()
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name or "zero_point" in name:
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
|
||||
@@ -350,16 +350,6 @@ class Rnj1Model(nn.Module):
|
||||
):
|
||||
loaded_weight -= 1
|
||||
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
|
||||
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if remapped_name is not None and remapped_name in params_dict:
|
||||
|
||||
@@ -376,18 +376,6 @@ class SeedOssModel(nn.Module):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -360,18 +360,6 @@ class SolarModel(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
@@ -52,6 +52,7 @@ class WeightsMapper:
|
||||
def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
|
||||
"""Combine two `WeightsMapper`s by merging their mappings."""
|
||||
return WeightsMapper(
|
||||
orig_to_new_regex={**self.orig_to_new_regex, **other.orig_to_new_regex},
|
||||
orig_to_new_substr={**self.orig_to_new_substr, **other.orig_to_new_substr},
|
||||
orig_to_new_prefix={**self.orig_to_new_prefix, **other.orig_to_new_prefix},
|
||||
orig_to_new_suffix={**self.orig_to_new_suffix, **other.orig_to_new_suffix},
|
||||
@@ -343,6 +344,20 @@ class AutoWeightsLoader:
|
||||
*,
|
||||
mapper: WeightsMapper | None = None,
|
||||
) -> set[str]:
|
||||
# Many models store quant_config in the base model instead of the causal model.
|
||||
# We look at the causal model's direct children for this reason.
|
||||
modules = (self.module, *self.module.children())
|
||||
iterator = (m.quant_config for m in modules if hasattr(m, "quant_config"))
|
||||
quant_config = next(iterator, None)
|
||||
cache_scale_mapper = (
|
||||
quant_config.get_cache_scale_mapper() if quant_config is not None else None
|
||||
)
|
||||
if cache_scale_mapper is not None:
|
||||
mapper = (
|
||||
mapper | cache_scale_mapper
|
||||
if mapper is not None
|
||||
else cache_scale_mapper
|
||||
)
|
||||
if mapper is not None:
|
||||
weights = mapper.apply(weights)
|
||||
# filter out weights with first-prefix/substr to skip in name
|
||||
|
||||
@@ -330,10 +330,34 @@ class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC):
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
# Metadata-independent input GEMMs + RMSNorm stay in the captured
|
||||
# graph; the metadata-dependent rest (q up-proj + kv-insert, indexer,
|
||||
# compressor, MLA attention) runs in the eager break.
|
||||
qr_kv, kv_score, indexer_kv_score, indexer_weights = (
|
||||
self.attn_gemm_parallel_execute(hidden_states)
|
||||
)
|
||||
qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
|
||||
qr, kv = fused_q_kv_rmsnorm(
|
||||
qr,
|
||||
kv,
|
||||
self.q_norm.weight.data,
|
||||
self.kv_norm.weight.data,
|
||||
self.eps,
|
||||
)
|
||||
|
||||
# attention_impl is wrapped with @eager_break_during_capture: this is
|
||||
# where the breakable cudagraph capture breaks (the attention op runs
|
||||
# eagerly between captured graph segments).
|
||||
self.attention_impl(hidden_states, positions, o_padded)
|
||||
self.attention_impl(
|
||||
hidden_states,
|
||||
qr,
|
||||
kv,
|
||||
kv_score,
|
||||
indexer_kv_score,
|
||||
indexer_weights,
|
||||
positions,
|
||||
o_padded,
|
||||
)
|
||||
o = o_padded[:, : self.n_local_heads, :]
|
||||
|
||||
# Inverse-RoPE + wo_a + wo_b output projection (platform-specific).
|
||||
@@ -403,25 +427,17 @@ class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC):
|
||||
def attention_impl(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
qr: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
kv_score: torch.Tensor,
|
||||
indexer_kv_score: torch.Tensor,
|
||||
indexer_weights: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place
|
||||
) -> None:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
qr_kv, kv_score, indexer_kv_score, indexer_weights = (
|
||||
self.attn_gemm_parallel_execute(hidden_states)
|
||||
)
|
||||
|
||||
qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
|
||||
qr, kv = fused_q_kv_rmsnorm(
|
||||
qr,
|
||||
kv,
|
||||
self.q_norm.weight.data,
|
||||
self.kv_norm.weight.data,
|
||||
self.eps,
|
||||
)
|
||||
|
||||
# wq_b + kv_insert (+ MLA compressor when an indexer is present) ride
|
||||
# on the default stream so q stays on its consumer stream (forward_mqa
|
||||
# downstream reads q on default). Indexer/compressor go on aux for
|
||||
|
||||
@@ -92,6 +92,11 @@ class CpuPlatform(Platform):
|
||||
|
||||
return meminfo.total_memory
|
||||
|
||||
# Return a two-element tuple of memory figures for this platform, both taken
# from the object returned by get_memory_node_info().
# NOTE(review): units are not visible here (presumably bytes) -- confirm
# against get_memory_node_info().
@classmethod
|
||||
def mem_get_info(cls) -> tuple[int, int]:
|
||||
# Fetch the info object once so both fields come from the same result.
meminfo = get_memory_node_info()
|
||||
# Tuple order is (available_memory, total_memory).
return meminfo.available_memory, meminfo.total_memory
|
||||
|
||||
@classmethod
|
||||
def set_device(cls, device: torch.device) -> None:
|
||||
"""
|
||||
|
||||
@@ -90,7 +90,7 @@ MODEL_TO_TAG_STYLE: dict[str, CohereTagStyle] = {
|
||||
tools=COMMAND_A_TOOLS_TAG,
|
||||
),
|
||||
"Cohere2MoeForCausalLM": CohereTagStyle(
|
||||
json_tags=(COMMAND_A_JSON_TAG,),
|
||||
json_tags=(COMMAND_A_JSON_TAG, COMMAND_A_PLUS_JSON_TAG),
|
||||
tools=COMMAND_A_TOOLS_TAG,
|
||||
),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user