Fix sarvam forward compatibility with transformers v5 (#38804)

Signed-off-by: vikrantpalle <vikrantpalle@gmail.com>
This commit is contained in:
Flame
2026-06-05 21:21:44 +05:30
committed by GitHub
parent aa6fb8a329
commit 91e17d4315
+34 -1
View File
@@ -5,7 +5,7 @@ import os
from collections.abc import Callable, Iterator from collections.abc import Callable, Iterator
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import asdict from dataclasses import asdict
from functools import cache, partial from functools import cache, partial, wraps
from importlib.metadata import version from importlib.metadata import version
from pathlib import Path from pathlib import Path
from typing import Any, Literal, TypeAlias from typing import Any, Literal, TypeAlias
@@ -141,6 +141,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
_SPECULATIVE_DECODING_CONFIGS: set[str] = {"eagle", "speculators"} _SPECULATIVE_DECODING_CONFIGS: set[str] = {"eagle", "speculators"}
_PATCH_HF_VALIDATE_ROPE: set[str] = {"sarvam_mla"}
_CONFIG_ATTRS_MAPPING: dict[str, str] = { _CONFIG_ATTRS_MAPPING: dict[str, str] = {
"llm_config": "text_config", "llm_config": "text_config",
} }
@@ -173,6 +175,34 @@ def _mistral_patch_hf_hub_constants() -> Iterator[None]:
constants.SAFETENSORS_INDEX_FILE = hf_safetensors_index_file constants.SAFETENSORS_INDEX_FILE = hf_safetensors_index_file
def _patch_hf_transformers_validate_rope():
    """Let legacy ``validate_rope(ignore_keys=...)`` calls work on transformers v5.

    Transformers v5 removed the ``ignore_keys`` option from the signature of
    ``validate_rope`` and replaced it with the
    ``ignore_keys_at_rope_validation`` attribute on the config. Model configs
    written against the older API still pass ``ignore_keys=...``; this patch
    intercepts that keyword, folds it into the attribute, and then calls the
    original method without it.

    The patch replaces ``PretrainedConfig.validate_rope`` for the whole
    process and is applied at most once (guarded by the ``__vllm_patched__``
    marker on the wrapper). On transformers versions below 5.0.0 this
    function does nothing.
    """
    # NOTE(review): assuming ``Version`` is ``packaging.version.Version``,
    # pre-releases such as "5.0.0rc1" sort below "5.0.0" and are therefore
    # left unpatched -- confirm this is intended.
    if Version(version("transformers")) < Version("5.0.0"):
        return

    # Idempotence guard: never wrap an already-wrapped method.
    if getattr(PretrainedConfig.validate_rope, "__vllm_patched__", False):
        return

    original_validate_rope = PretrainedConfig.validate_rope

    @wraps(original_validate_rope)
    def patched_validate_rope(self, *args, **kwargs):
        # The v5 method no longer accepts ``ignore_keys``; strip it so the
        # underlying call does not fail on an unexpected keyword.
        legacy_ignore_keys = kwargs.pop("ignore_keys", None)

        if legacy_ignore_keys:
            # Merge rather than pick one side: keys the config already
            # ignores and keys requested by the legacy caller must both be
            # honoured. A new set is built so that a set shared at class
            # level is never mutated in place.
            configured_keys = getattr(self, "ignore_keys_at_rope_validation", None)
            self.ignore_keys_at_rope_validation = set(configured_keys or ()) | set(
                legacy_ignore_keys
            )
        # When no legacy keys were supplied the attribute is left untouched,
        # so an existing (possibly empty) value is not replaced with None.

        return original_validate_rope(self, *args, **kwargs)

    patched_validate_rope.__vllm_patched__ = True  # type: ignore[attr-defined]
    PretrainedConfig.validate_rope = patched_validate_rope
class HFConfigParser(ConfigParserBase): class HFConfigParser(ConfigParserBase):
def parse( def parse(
self, self,
@@ -212,6 +242,9 @@ class HFConfigParser(ConfigParserBase):
dummy_model_type = hf_overrides(dummy_config).model_type dummy_model_type = hf_overrides(dummy_config).model_type
model_type = dummy_model_type.removeprefix("dummy_") model_type = dummy_model_type.removeprefix("dummy_")
if model_type in _PATCH_HF_VALIDATE_ROPE:
_patch_hf_transformers_validate_rope()
if model_type in _SPECULATIVE_DECODING_CONFIGS: if model_type in _SPECULATIVE_DECODING_CONFIGS:
config_class = _CONFIG_REGISTRY[model_type] config_class = _CONFIG_REGISTRY[model_type]
config = config_class.from_pretrained( config = config_class.from_pretrained(