Merge branch 'main' into lwilkinson/kv-layout/bucket-layers-refactor

This commit is contained in:
Matthew Bonanni
2026-06-05 10:19:20 -04:00
66 changed files with 130 additions and 752 deletions
+1 -1
View File
@@ -32,7 +32,7 @@ partial-json-parser # used for parsing partial JSON outputs
pyzmq >= 25.0.0
msgspec
gguf >= 0.17.0
mistral_common[image] >= 1.11.2
mistral_common[image] >= 1.11.3
opencv-python-headless >= 4.13.0 # required for video IO
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
+1 -1
View File
@@ -31,7 +31,7 @@ torchaudio==2.11.0
torchvision==0.26.0
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.11.2 # required for voxtral test
mistral_common[image,audio] >= 1.11.3 # required for voxtral test
num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
opencv-python-headless >= 4.13.0 # required for video test
+1 -1
View File
@@ -409,7 +409,7 @@ mbstrdecoder==1.1.3
# typepy
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.11.2
mistral-common==1.11.3
# via
# -c requirements/common.txt
# -r requirements/test/cuda.in
+1 -1
View File
@@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.11.2 # required for voxtral test
mistral_common[image,audio] >= 1.11.3 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.13.0 # required for video test
datamodel_code_generator # required for minicpm3 test
+1 -1
View File
@@ -30,7 +30,7 @@ tblib # for pickling test exceptions
timm>=1.0.17 # required for internvl and gemma3n-mm test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio]>=1.11.2 # required for voxtral test
mistral_common[image,audio]>=1.11.3 # required for voxtral test
num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
opencv-python-headless>=4.13.0 # required for video test
+1 -1
View File
@@ -512,7 +512,7 @@ mcp==1.27.0
# via -r requirements/test/../common.txt
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.11.2
mistral-common==1.11.3
# via
# -c requirements/common.txt
# -r requirements/test/../common.txt
+1 -1
View File
@@ -266,7 +266,7 @@ mbstrdecoder==1.1.4
# typepy
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.11.2
mistral-common==1.11.3
# via
# -c requirements/common.txt
# -r requirements/test/xpu.in
@@ -100,32 +100,6 @@ def test_fc_layer_quant_config_usage(default_vllm_config, dist_init, device) ->
assert output.shape == (2, output_size)
def test_kv_cache_scale_name_handling():
    """A quant config exposing `get_cache_scale` hands back the scale name."""
    weight_name = "layers.0.self_attn.k_proj.weight"

    # Stand-in for a quant config that supports cache scales.
    quant_config = Mock()
    quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")

    # Same lookup a `load_weights` loop performs for each checkpoint name.
    resolved_name = quant_config.get_cache_scale(weight_name)

    # The hook must have been consulted exactly once, with the weight name,
    # and its return value must come back unchanged.
    quant_config.get_cache_scale.assert_called_once_with(weight_name)
    assert resolved_name == "layers.0.self_attn.kv_scale"
def test_kv_cache_scale_name_no_scale():
    """A quant config without a matching cache scale yields `None`."""
    # Stand-in for a quant config whose `get_cache_scale` finds no match.
    quant_config = Mock()
    quant_config.get_cache_scale = Mock(return_value=None)

    weight_name = "layers.0.mlp.gate_proj.weight"
    resolved_name = quant_config.get_cache_scale(weight_name)

    # Weights that have no cache scale must map to None.
    assert resolved_name is None
def test_maybe_remap_kv_scale_name():
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
@@ -183,33 +157,3 @@ def test_eagle3_lm_head_receives_quant_config():
assert call_kwargs["quant_config"] is mock_quant_config, (
"ParallelLMHead must receive the draft model's quant_config"
)
def test_load_weights_kv_scale_handling():
    """Replay the KV-cache-scale branch of a `load_weights` loop on mocks."""
    scale_key = "layers.0.self_attn.kv_scale"

    kv_scale_param = Mock()
    kv_scale_param.weight_loader = Mock()
    params_dict = {
        "layers.0.self_attn.kv_scale": kv_scale_param,
    }

    mock_quant_config = Mock()
    mock_quant_config.get_cache_scale = Mock(return_value=scale_key)

    # Inputs the load_weights logic would see for one checkpoint entry.
    name = "layers.0.self_attn.k_proj.weight"
    loaded_weight_tensor = torch.tensor([1.0, 2.0])

    # Guard clauses mirror the nested conditions of the loading code.
    if mock_quant_config is None:
        return
    scale_name = mock_quant_config.get_cache_scale(name)
    if not scale_name:
        return

    param = params_dict[scale_name]
    assert param is kv_scale_param

    # A 0-dim tensor is used as-is; otherwise the first element is taken.
    if loaded_weight_tensor.dim() == 0:
        weight_to_load = loaded_weight_tensor
    else:
        weight_to_load = loaded_weight_tensor[0]

    assert scale_name == "layers.0.self_attn.kv_scale"
    assert weight_to_load == loaded_weight_tensor[0]
@@ -162,7 +162,13 @@ class QuantizationConfig(ABC):
"""
raise NotImplementedError
def get_cache_scale(self, name: str) -> str | None:
def get_cache_scale_mapper(self) -> "WeightsMapper | None":
    """Return the mapper from checkpoint KV-cache scale names to vLLM names.

    When a mapper is returned, `AutoWeightsLoader` applies it to the weight
    stream automatically, so individual model `load_weights` methods do not
    need to know about KV-cache scales.

    Returns:
        `None` in this base implementation, meaning no renaming is applied.
    """
    return None
def apply_vllm_mapper( # noqa: B027
+11 -18
View File
@@ -207,25 +207,18 @@ class Fp8Config(QuantizationConfig):
return Fp8KVCacheMethod(self)
return None
def get_cache_scale(self, name: str) -> str | None:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
param name expected by vLLM
def get_cache_scale_mapper(self) -> "WeightsMapper":
"""Map compressed-tensors KV-cache scale names to vLLM names."""
from vllm.model_executor.models.utils import WeightsMapper
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None
return WeightsMapper(
orig_to_new_suffix={
".k_proj.output_scale": ".attn.k_scale",
".v_proj.output_scale": ".attn.v_scale",
".q_proj.output_scale": ".attn.q_scale",
".self_attn.prob_output_scale": ".self_attn.attn.prob_scale",
}
)
class CopyNumelCounter(TorchDispatchMode):
@@ -15,6 +15,30 @@ from vllm.v1.kv_cache_interface import kv_cache_uses_per_token_head_scales
logger = init_logger(__name__)
class KVCacheScaleParameter(torch.nn.Parameter):
"""Scalar parameter for KV-cache scales.
Initialized to -1.0 (an invalid sentinel) so call sites just write
`KVCacheScaleParameter()`. The `weight_loader` accepts shape `()` or
`(1,)` and rejects anything else — per-head scales go through a separate
path (compressed-tensors' `_tp_aware_loader`), not this one. Per-instance
overrides still work because instance attribute assignment shadows this
class-level loader.
"""
def __new__(cls) -> "KVCacheScaleParameter":
return super().__new__(cls, torch.tensor(-1.0), requires_grad=False)
@staticmethod
def weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
if loaded_weight.numel() != 1:
raise ValueError(
f"KV-cache scale expects a scalar weight, got shape "
f"{tuple(loaded_weight.shape)}"
)
param.data.copy_(loaded_weight.reshape(()))
class BaseKVCacheMethod(QuantizeMethodBase):
"""
Quant method that adds `_k_scale` and `_v_scale` attributes to the
@@ -37,11 +61,11 @@ class BaseKVCacheMethod(QuantizeMethodBase):
# Initialize the Q and KV cache scales to -1.0, an invalid value.
# If the q and k/v_scales appear in the checkpoint, it will be
# overwritten when loading weights.
layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
layer.q_scale = KVCacheScaleParameter()
layer.k_scale = KVCacheScaleParameter()
layer.v_scale = KVCacheScaleParameter()
# Initialize P = softmax(QK^T) scales
layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
layer.prob_scale = KVCacheScaleParameter()
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
    """Unconditionally raise; this method is not meant to be invoked.

    Raises:
        RuntimeError: Always, naming the concrete class in the message.
    """
    message = f"{self.__class__.__name__}.apply should not be called."
    raise RuntimeError(message)
@@ -646,26 +646,16 @@ class QuarkConfig(QuantizationConfig):
return scheme
def get_cache_scale(self, name: str) -> str | None:
"""
Check whether the param name matches the format for k/v cache scales
in quark. If this is the case, return its equivalent param name
expected by vLLM
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None
def get_cache_scale_mapper(self) -> "WeightsMapper":
    """Build the suffix mapper from Quark KV-cache scale names to vLLM names."""
    # Checkpoint suffix -> suffix of the corresponding vLLM parameter.
    suffix_renames = {
        ".k_proj.output_scale": ".attn.k_scale",
        ".v_proj.output_scale": ".attn.v_scale",
        ".q_proj.output_scale": ".attn.q_scale",
        ".self_attn.prob_output_scale": ".self_attn.attn.prob_scale",
    }
    return WeightsMapper(orig_to_new_suffix=suffix_renames)
class QuarkLinearMethod(LinearMethodBase):
@@ -1541,6 +1541,11 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
if no remapping is needed.
None: If the remapped name is not found in params_dict.
"""
# Already in vLLM's expected form (e.g. weights pre-renamed by a
# `WeightsMapper` from the quant config). Skip the regex remap, which
# would otherwise double-apply the `.attn` prefix and drop the weight.
if name in params_dict:
return name
if name.endswith(".kv_scale"):
logger.warning_once(
"DEPRECATED. Found kv_scale in the checkpoint. "
-12
View File
@@ -430,18 +430,6 @@ class ApertusModel(nn.Module, EagleModelMixin):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
-12
View File
@@ -293,18 +293,6 @@ class ArceeModel(nn.Module, EagleModelMixin):
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name or "zero_point" in name:
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
if remapped_name is None:
-12
View File
@@ -363,18 +363,6 @@ class AriaTextModel(LlamaModel, SupportsQuant):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
-12
View File
@@ -464,18 +464,6 @@ class Cohere2MoeModel(nn.Module):
if "rotary_emb.inv_freq" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
@@ -150,18 +150,6 @@ class CohereEagleModel(nn.Module):
if "rotary_emb.inv_freq" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
-13
View File
@@ -352,19 +352,6 @@ class CohereModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
-13
View File
@@ -394,19 +394,6 @@ class DbrxModel(nn.Module):
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if name.endswith(("w1", "w2", "v1")):
name = name + "_weight"
for param_name, weight_name in expert_params_mapping:
@@ -284,19 +284,6 @@ class DeepseekV2Eagle3Model(nn.Module):
if "midlayer." in name:
name = name.replace("midlayer.", "layers.0.")
# Handle kv cache quantization scales
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Remapping the name FP8 kv-scale
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
-12
View File
@@ -391,18 +391,6 @@ class ExaoneModel(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
-12
View File
@@ -389,18 +389,6 @@ class Exaone4Model(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
-12
View File
@@ -374,18 +374,6 @@ class ExaoneMoeModel(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
-10
View File
@@ -328,16 +328,6 @@ class Gemma2Model(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
-11
View File
@@ -386,17 +386,6 @@ class Gemma3Model(nn.Module):
):
loaded_weight -= 1
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Check if this is a scale parameter that needs remapping first
if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
# Try to remap the scale name first
-10
View File
@@ -1056,16 +1056,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
):
name = f"self_decoder.{name}"
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
-10
View File
@@ -1409,16 +1409,6 @@ class Gemma4Model(nn.Module, EagleModelMixin):
params_dict.update(dict(self.named_buffers()))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
if remapped_name is not None and remapped_name in params_dict:
-12
View File
@@ -258,18 +258,6 @@ class Glm4Model(LlamaModel):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
name = maybe_remap_kv_scale_name(name, params_dict)
+4 -13
View File
@@ -166,6 +166,10 @@ class GlmOcrMTP(nn.Module, SupportsPP):
return self.model.compute_logits(hidden_states, spec_step_idx)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if self.quant_config is not None and (
cache_scale_mapper := self.quant_config.get_cache_scale_mapper()
):
weights = cache_scale_mapper.apply(weights)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
@@ -189,19 +193,6 @@ class GlmOcrMTP(nn.Module, SupportsPP):
name = self._rewrite_spec_layer_name(spec_layer, name)
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
name = maybe_remap_kv_scale_name(name, params_dict)
-13
View File
@@ -254,19 +254,6 @@ class GPTJModel(nn.Module):
if "attn.bias" in name or "attn.masked_bias" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
-46
View File
@@ -635,52 +635,6 @@ class GptOssModel(nn.Module, EagleModelMixin):
moe_quant_method = _get_moe_weight_dtype(layer_id=layer_id)
def kv_cache_scale_loader(
    quant_config: QuantizationConfig,
    name: str,
    params_dict: dict[str, typing.Any],
    weight: torch.Tensor,
    default_weight_loader: Callable[..., None],
    loaded_params: set[str],
) -> tuple[bool, set[str]]:
    """
    Load a KV cache output scale if `name` refers to one.

    Returns:
        Tuple of (bool, set):
        - bool: True if a KV-cache scale was loaded and recorded
        - set: `loaded_params`, with the scale name added when True
    """
    # Without a quant config there is nothing to remap.
    if quant_config is None:
        return False, loaded_params

    # Ask the quant config whether this name is an explicit KV output scale.
    scale_name = quant_config.get_cache_scale(name)
    if not scale_name:
        return False, loaded_params

    target = params_dict[scale_name]
    loader = getattr(target, "weight_loader", default_weight_loader)
    if weight.numel() != 1:
        raise ValueError(
            f"KV cache scale '{scale_name}' is expected to be a "
            f"scalar, but got a tensor of shape {weight.shape}."
        )

    # Hand the loader a 0-dim tensor regardless of the checkpoint's shape.
    loader(target, weight.flatten()[0])
    loaded_params.add(scale_name)
    return True, loaded_params
load_kv_cache_scale_completed, loaded_params = kv_cache_scale_loader(
self.quant_config,
name,
params_dict,
loaded_weight,
default_weight_loader,
loaded_params,
)
if load_kv_cache_scale_completed:
continue
if (
all(key in name for key in ["input_scale", "mlp.experts"])
and expert_id is not None
-12
View File
@@ -334,18 +334,6 @@ class GraniteModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
-13
View File
@@ -365,19 +365,6 @@ class GraniteMoeModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
@@ -495,18 +495,6 @@ class GraniteMoeHybridModel(nn.Module):
if "A_log" in n:
n = n.replace("A_log", "A")
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(n)
):
# Loading kv cache quantization scales
loaded_weight = p
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
_load(scale_name, loaded_weight)
loaded_params.add(scale_name)
continue
if _load_quant_expert(n, p):
continue
-12
View File
@@ -548,18 +548,6 @@ class Grok1Model(nn.Module):
for old_pattern, new_pattern in self.weight_name_remapping.items():
if old_pattern in name:
name = name.replace(old_pattern, new_pattern)
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
-9
View File
@@ -771,15 +771,6 @@ class HunYuanModel(nn.Module, EagleModelMixin):
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
continue
is_found = False
for param_name, weight_name, shard_id in stacked_params_mapping:
-11
View File
@@ -531,17 +531,6 @@ class HYV3Model(nn.Module):
for name, loaded_weight in weights:
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
+4 -8
View File
@@ -264,6 +264,10 @@ class HYV3MTP(nn.Module):
return torch.concat((q, k, v))
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
if self.quant_config is not None and (
cache_scale_mapper := self.quant_config.get_cache_scale_mapper()
):
weights = cache_scale_mapper.apply(weights)
cla_factor = _get_cla_factor(self.config)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@@ -336,14 +340,6 @@ class HYV3MTP(nn.Module):
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
-12
View File
@@ -395,18 +395,6 @@ class HyperCLOVAXModel(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
@@ -476,18 +476,6 @@ class IQuestLoopCoderModel(nn.Module):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if "gate_projections" in name:
continue
-10
View File
@@ -386,16 +386,6 @@ class Jais2Model(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
-15
View File
@@ -790,21 +790,6 @@ class KeyeSiglipVisionModel(nn.Module):
continue
if "head.mlp" in name or "head.probe" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(
param,
"weight_loader",
default_weight_loader,
)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (
param_name,
weight_name,
-14
View File
@@ -724,20 +724,6 @@ class LagunaModel(nn.Module, EagleModelMixin):
loaded_params.add(name)
continue
# Handle KV cache quantization scales
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
assert loaded_weight.numel() == 1, (
f"KV scale numel {loaded_weight.numel()} != 1"
)
loaded_weight = loaded_weight.squeeze()
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Handle stacked params (QKV, gate_up for
# non-expert layers and shared_expert)
for param_name, weight_name, shard_id in stacked_params_mapping:
-12
View File
@@ -451,18 +451,6 @@ class LlamaModel(nn.Module, EagleModelMixin):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
name = maybe_remap_kv_scale_name(name, params_dict)
+3 -18
View File
@@ -588,21 +588,6 @@ class Llama4Model(LlamaModel):
fused_experts_params = True
expert_params_mapping = expert_params_mapping_fused
# If kv cache quantization scales exist and the weight name
# corresponds to one of the kv cache quantization scales, load
# them.
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Iterate over stacked_params_mapping to check if the current weight
# is one of the stacked parameters. If so, load the weight with the
# corresponding shard id. Note that MoE weights are handled
@@ -625,9 +610,9 @@ class Llama4Model(LlamaModel):
if is_pp_missing_parameter(name, self):
continue
# Remap kv cache scale names for ModelOpt checkpoints.
# TODO: ModelOpt should implement get_cache_scale() such that
# kv cache scale name remapping can be done there.
# Remap kv cache scale names for any checkpoint format the
# quant config's `get_cache_scale_mapper` does not cover
# (idempotent for names already renamed by the mapper).
if name.endswith("scale"):
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
-13
View File
@@ -127,19 +127,6 @@ class LlamaModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
# Handle kv cache quantization scales
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Remapping the name of FP8 kv-scale or zero point.
if "scale" in name or "zero_point" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
@@ -267,19 +267,6 @@ class LlamaModel(nn.Module):
for name, loaded_weight in weights:
if "midlayer." in name:
name = name.replace("midlayer.", "layers.0.")
# Handle kv cache quantization scales
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Remapping the name of FP8 kv-scale or zero point.
if "scale" in name or "zero_point" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
-12
View File
@@ -104,18 +104,6 @@ class MiMoModel(Qwen2Model):
continue
if "rotary_emb.inv_freq" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
-16
View File
@@ -555,22 +555,6 @@ class MiMoV2Model(nn.Module):
if "mtp" in name:
continue
if self.quant_config is not None:
cache_scale_name = self.quant_config.get_cache_scale(name)
if cache_scale_name is not None and cache_scale_name in params_dict:
param = params_dict[cache_scale_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
kv_scale = loaded_weight
if kv_scale.dim() > 0 and kv_scale.numel() > 1:
kv_scale = kv_scale.view(-1)[0]
weight_loader(param, kv_scale)
loaded_params.add(cache_scale_name)
continue
expert_matched = False
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
if weight_name not in name:
-13
View File
@@ -388,19 +388,6 @@ class MixtralModel(nn.Module):
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
-12
View File
@@ -375,18 +375,6 @@ class NemotronModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
@@ -334,18 +334,6 @@ class DeciModel(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
-12
View File
@@ -390,18 +390,6 @@ class OuroModel(nn.Module):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
@@ -944,21 +944,6 @@ class SiglipVisionModel(nn.Module):
continue
if "packing_position_embedding" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(
param,
"weight_loader",
default_weight_loader,
)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (
param_name,
weight_name,
-13
View File
@@ -537,19 +537,6 @@ class PhiMoEModel(nn.Module):
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
-12
View File
@@ -439,18 +439,6 @@ class Qwen2Model(nn.Module, EagleModelMixin):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
@@ -469,17 +469,6 @@ class DFlashQwen3Model(nn.Module):
for name, loaded_weight in weights:
if "midlayer." in name:
name = name.replace("midlayer.", "layers.0.")
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
-13
View File
@@ -552,19 +552,6 @@ class Qwen3MoeModel(nn.Module, EagleModelMixin):
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
assert loaded_weight.numel() == 1, (
f"KV scale numel {loaded_weight.numel()} != 1"
)
loaded_weight = loaded_weight.squeeze()
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name or "zero_point" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
-10
View File
@@ -350,16 +350,6 @@ class Rnj1Model(nn.Module):
):
loaded_weight -= 1
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
if remapped_name is not None and remapped_name in params_dict:
-12
View File
@@ -376,18 +376,6 @@ class SeedOssModel(nn.Module):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
-12
View File
@@ -360,18 +360,6 @@ class SolarModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
+15
View File
@@ -52,6 +52,7 @@ class WeightsMapper:
def __or__(self, other: "WeightsMapper") -> "WeightsMapper":
"""Combine two `WeightsMapper`s by merging their mappings."""
return WeightsMapper(
orig_to_new_regex={**self.orig_to_new_regex, **other.orig_to_new_regex},
orig_to_new_substr={**self.orig_to_new_substr, **other.orig_to_new_substr},
orig_to_new_prefix={**self.orig_to_new_prefix, **other.orig_to_new_prefix},
orig_to_new_suffix={**self.orig_to_new_suffix, **other.orig_to_new_suffix},
@@ -343,6 +344,20 @@ class AutoWeightsLoader:
*,
mapper: WeightsMapper | None = None,
) -> set[str]:
# Many models store quant_config in the base model instead of the causal model.
# We look at the causal model's direct children for this reason.
modules = (self.module, *self.module.children())
iterator = (m.quant_config for m in modules if hasattr(m, "quant_config"))
quant_config = next(iterator, None)
cache_scale_mapper = (
quant_config.get_cache_scale_mapper() if quant_config is not None else None
)
if cache_scale_mapper is not None:
mapper = (
mapper | cache_scale_mapper
if mapper is not None
else cache_scale_mapper
)
if mapper is not None:
weights = mapper.apply(weights)
# filter out weights with first-prefix/substr to skip in name
+30 -14
View File
@@ -330,10 +330,34 @@ class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC):
device=hidden_states.device,
)
# Metadata-independent input GEMMs + RMSNorm stay in the captured
# graph; the metadata-dependent rest (q up-proj + kv-insert, indexer,
# compressor, MLA attention) runs in the eager break.
qr_kv, kv_score, indexer_kv_score, indexer_weights = (
self.attn_gemm_parallel_execute(hidden_states)
)
qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
qr, kv = fused_q_kv_rmsnorm(
qr,
kv,
self.q_norm.weight.data,
self.kv_norm.weight.data,
self.eps,
)
# attention_impl is wrapped with @eager_break_during_capture: this is
# where the breakable cudagraph capture breaks (the attention op runs
# eagerly between captured graph segments).
self.attention_impl(hidden_states, positions, o_padded)
self.attention_impl(
hidden_states,
qr,
kv,
kv_score,
indexer_kv_score,
indexer_weights,
positions,
o_padded,
)
o = o_padded[:, : self.n_local_heads, :]
# Inverse-RoPE + wo_a + wo_b output projection (platform-specific).
@@ -403,25 +427,17 @@ class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC):
def attention_impl(
self,
hidden_states: torch.Tensor,
qr: torch.Tensor,
kv: torch.Tensor,
kv_score: torch.Tensor,
indexer_kv_score: torch.Tensor,
indexer_weights: torch.Tensor,
positions: torch.Tensor,
out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place
) -> None:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
qr_kv, kv_score, indexer_kv_score, indexer_weights = (
self.attn_gemm_parallel_execute(hidden_states)
)
qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
qr, kv = fused_q_kv_rmsnorm(
qr,
kv,
self.q_norm.weight.data,
self.kv_norm.weight.data,
self.eps,
)
# wq_b + kv_insert (+ MLA compressor when an indexer is present) ride
# on the default stream so q stays on its consumer stream (forward_mqa
# downstream reads q on default). Indexer/compressor go on aux for
+5
View File
@@ -92,6 +92,11 @@ class CpuPlatform(Platform):
return meminfo.total_memory
@classmethod
def mem_get_info(cls) -> tuple[int, int]:
    """Return ``(available_memory, total_memory)`` for the CPU platform.

    Both values are taken from a single ``get_memory_node_info()`` call,
    so they come from the same reading.

    NOTE(review): the units are presumably bytes and the ordering
    presumably mirrors ``torch.cuda.mem_get_info`` (free, total) -- confirm
    against ``get_memory_node_info`` and the callers.
    """
    node_info = get_memory_node_info()
    available = node_info.available_memory
    total = node_info.total_memory
    return available, total
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
@@ -90,7 +90,7 @@ MODEL_TO_TAG_STYLE: dict[str, CohereTagStyle] = {
tools=COMMAND_A_TOOLS_TAG,
),
"Cohere2MoeForCausalLM": CohereTagStyle(
json_tags=(COMMAND_A_JSON_TAG,),
json_tags=(COMMAND_A_JSON_TAG, COMMAND_A_PLUS_JSON_TAG),
tools=COMMAND_A_TOOLS_TAG,
),
}