mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Fix sarvam forward compatibility with transformers v5 (#38804)
Signed-off-by: vikrantpalle <vikrantpalle@gmail.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import os
|
|||||||
from collections.abc import Callable, Iterator
|
from collections.abc import Callable, Iterator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from functools import cache, partial
|
from functools import cache, partial, wraps
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, TypeAlias
|
from typing import Any, Literal, TypeAlias
|
||||||
@@ -141,6 +141,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
|
|||||||
|
|
||||||
_SPECULATIVE_DECODING_CONFIGS: set[str] = {"eagle", "speculators"}
|
_SPECULATIVE_DECODING_CONFIGS: set[str] = {"eagle", "speculators"}
|
||||||
|
|
||||||
|
_PATCH_HF_VALIDATE_ROPE: set[str] = {"sarvam_mla"}
|
||||||
|
|
||||||
_CONFIG_ATTRS_MAPPING: dict[str, str] = {
|
_CONFIG_ATTRS_MAPPING: dict[str, str] = {
|
||||||
"llm_config": "text_config",
|
"llm_config": "text_config",
|
||||||
}
|
}
|
||||||
@@ -173,6 +175,34 @@ def _mistral_patch_hf_hub_constants() -> Iterator[None]:
|
|||||||
constants.SAFETENSORS_INDEX_FILE = hf_safetensors_index_file
|
constants.SAFETENSORS_INDEX_FILE = hf_safetensors_index_file
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_hf_transformers_validate_rope():
|
||||||
|
"""Transformers v5 moved the ignore_keys option from the method signature of
|
||||||
|
validate_rope and replaced it with the ignore_keys_at_rope_validation parameter
|
||||||
|
in the PreTrainedConfig class. This is a patch to make older versions of
|
||||||
|
validate_rope() with the ignore_keys parameter work with newer versions of
|
||||||
|
hf transformers (from v5 onwards)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if Version(version("transformers")) >= Version("5.0.0"):
|
||||||
|
if hasattr(PretrainedConfig.validate_rope, "__vllm_patched__"):
|
||||||
|
return
|
||||||
|
|
||||||
|
_original_validate_rope = PretrainedConfig.validate_rope
|
||||||
|
|
||||||
|
@wraps(_original_validate_rope)
|
||||||
|
def patched_validate_rope(self, *args, **kwargs):
|
||||||
|
ignore_keys_param = kwargs.pop("ignore_keys", None)
|
||||||
|
original_ignore_keys = self.ignore_keys_at_rope_validation
|
||||||
|
self.ignore_keys_at_rope_validation = (
|
||||||
|
original_ignore_keys or ignore_keys_param
|
||||||
|
)
|
||||||
|
result = _original_validate_rope(self, *args, **kwargs)
|
||||||
|
return result
|
||||||
|
|
||||||
|
patched_validate_rope.__vllm_patched__ = True # type: ignore[attr-defined]
|
||||||
|
PretrainedConfig.validate_rope = patched_validate_rope
|
||||||
|
|
||||||
|
|
||||||
class HFConfigParser(ConfigParserBase):
|
class HFConfigParser(ConfigParserBase):
|
||||||
def parse(
|
def parse(
|
||||||
self,
|
self,
|
||||||
@@ -212,6 +242,9 @@ class HFConfigParser(ConfigParserBase):
|
|||||||
dummy_model_type = hf_overrides(dummy_config).model_type
|
dummy_model_type = hf_overrides(dummy_config).model_type
|
||||||
model_type = dummy_model_type.removeprefix("dummy_")
|
model_type = dummy_model_type.removeprefix("dummy_")
|
||||||
|
|
||||||
|
if model_type in _PATCH_HF_VALIDATE_ROPE:
|
||||||
|
_patch_hf_transformers_validate_rope()
|
||||||
|
|
||||||
if model_type in _SPECULATIVE_DECODING_CONFIGS:
|
if model_type in _SPECULATIVE_DECODING_CONFIGS:
|
||||||
config_class = _CONFIG_REGISTRY[model_type]
|
config_class = _CONFIG_REGISTRY[model_type]
|
||||||
config = config_class.from_pretrained(
|
config = config_class.from_pretrained(
|
||||||
|
|||||||
Reference in New Issue
Block a user