mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Fix sarvam forward compatibility with transformers v5 (#38804)
Signed-off-by: vikrantpalle <vikrantpalle@gmail.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import os
|
||||
from collections.abc import Callable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import asdict
|
||||
from functools import cache, partial
|
||||
from functools import cache, partial, wraps
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, TypeAlias
|
||||
@@ -141,6 +141,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
|
||||
|
||||
_SPECULATIVE_DECODING_CONFIGS: set[str] = {"eagle", "speculators"}
|
||||
|
||||
_PATCH_HF_VALIDATE_ROPE: set[str] = {"sarvam_mla"}
|
||||
|
||||
_CONFIG_ATTRS_MAPPING: dict[str, str] = {
|
||||
"llm_config": "text_config",
|
||||
}
|
||||
@@ -173,6 +175,34 @@ def _mistral_patch_hf_hub_constants() -> Iterator[None]:
|
||||
constants.SAFETENSORS_INDEX_FILE = hf_safetensors_index_file
|
||||
|
||||
|
||||
def _patch_hf_transformers_validate_rope():
|
||||
"""Transformers v5 moved the ignore_keys option from the method signature of
|
||||
validate_rope and replaced it with the ignore_keys_at_rope_validation parameter
|
||||
in the PreTrainedConfig class. This is a patch to make older versions of
|
||||
validate_rope() with the ignore_keys parameter work with newer versions of
|
||||
hf transformers (from v5 onwards)
|
||||
"""
|
||||
|
||||
if Version(version("transformers")) >= Version("5.0.0"):
|
||||
if hasattr(PretrainedConfig.validate_rope, "__vllm_patched__"):
|
||||
return
|
||||
|
||||
_original_validate_rope = PretrainedConfig.validate_rope
|
||||
|
||||
@wraps(_original_validate_rope)
|
||||
def patched_validate_rope(self, *args, **kwargs):
|
||||
ignore_keys_param = kwargs.pop("ignore_keys", None)
|
||||
original_ignore_keys = self.ignore_keys_at_rope_validation
|
||||
self.ignore_keys_at_rope_validation = (
|
||||
original_ignore_keys or ignore_keys_param
|
||||
)
|
||||
result = _original_validate_rope(self, *args, **kwargs)
|
||||
return result
|
||||
|
||||
patched_validate_rope.__vllm_patched__ = True # type: ignore[attr-defined]
|
||||
PretrainedConfig.validate_rope = patched_validate_rope
|
||||
|
||||
|
||||
class HFConfigParser(ConfigParserBase):
|
||||
def parse(
|
||||
self,
|
||||
@@ -212,6 +242,9 @@ class HFConfigParser(ConfigParserBase):
|
||||
dummy_model_type = hf_overrides(dummy_config).model_type
|
||||
model_type = dummy_model_type.removeprefix("dummy_")
|
||||
|
||||
if model_type in _PATCH_HF_VALIDATE_ROPE:
|
||||
_patch_hf_transformers_validate_rope()
|
||||
|
||||
if model_type in _SPECULATIVE_DECODING_CONFIGS:
|
||||
config_class = _CONFIG_REGISTRY[model_type]
|
||||
config = config_class.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user