mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Feature] Lazy import for the "mistral" tokenizer module. (#34651)
Signed-off-by: Neil Schemenauer <nas@arctrix.com>
This commit is contained in:
@@ -23,7 +23,7 @@ from vllm.multimodal.cache import MultiModalProcessorOnlyCache
|
||||
from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
|
||||
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
from ....multimodal.utils import random_audio, random_image, random_video
|
||||
from ...registry import (
|
||||
@@ -183,7 +183,7 @@ def get_text_token_prompts(
|
||||
|
||||
text_prompt: str | None
|
||||
token_prompt: list[int]
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
# ChatCompletionRequest only supports ImageChunk natively;
|
||||
# for other modalities (e.g. audio), fall back to the model's
|
||||
# own dummy inputs builder which knows the right placeholders.
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
|
||||
class StreamingReasoningReconstructor:
|
||||
@@ -59,7 +59,7 @@ def run_reasoning_extraction_mistral(
|
||||
request: ChatCompletionRequest | None = None,
|
||||
streaming: bool = False,
|
||||
) -> tuple[str | None, str | None]:
|
||||
assert isinstance(reasoning_parser.model_tokenizer, MistralTokenizer), type(
|
||||
assert is_mistral_tokenizer(reasoning_parser.model_tokenizer), type(
|
||||
reasoning_parser.model_tokenizer
|
||||
)
|
||||
if streaming:
|
||||
@@ -130,7 +130,7 @@ def run_reasoning_extraction_streaming_mistral(
|
||||
model_deltas: list[int],
|
||||
request: ChatCompletionRequest | None = None,
|
||||
) -> StreamingReasoningReconstructor:
|
||||
assert isinstance(reasoning_parser.model_tokenizer, MistralTokenizer), type(
|
||||
assert is_mistral_tokenizer(reasoning_parser.model_tokenizer), type(
|
||||
reasoning_parser.model_tokenizer
|
||||
)
|
||||
request = request or ChatCompletionRequest(messages=[], model="test-model")
|
||||
|
||||
@@ -83,9 +83,9 @@ from vllm.renderers.inputs.preprocess import (
|
||||
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.counter import Counter
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
from vllm.utils.tqdm_utils import maybe_tqdm
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||
@@ -891,7 +891,7 @@ class LLM:
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tools,
|
||||
tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
|
||||
tokenize=is_mistral_tokenizer(renderer.tokenizer),
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -1458,7 +1458,7 @@ class LLM:
|
||||
model_config = self.model_config
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
raise ValueError("Score API is not supported for Mistral tokenizer")
|
||||
|
||||
if len(data_1) == 1:
|
||||
|
||||
@@ -75,16 +75,12 @@ from vllm.parser import ParserManager
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import (
|
||||
MistralTokenizer,
|
||||
maybe_serialize_tool_calls,
|
||||
truncate_tool_call_ids,
|
||||
validate_request_params,
|
||||
)
|
||||
from vllm.tool_parsers import ToolParser
|
||||
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
|
||||
from vllm.tool_parsers.utils import partial_json_loads
|
||||
from vllm.utils.collection_utils import as_list
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
from vllm.utils.mistral import mt as _mt
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -244,18 +240,18 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
tool_parser = self.tool_parser
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
# because of issues with pydantic we need to potentially
|
||||
# re-serialize the tool_calls field of the request
|
||||
# for more info: see comment in `maybe_serialize_tool_calls`
|
||||
maybe_serialize_tool_calls(request) # type: ignore[arg-type]
|
||||
truncate_tool_call_ids(request) # type: ignore[arg-type]
|
||||
validate_request_params(request)
|
||||
_mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type]
|
||||
_mt.truncate_tool_call_ids(request) # type: ignore[arg-type]
|
||||
_mt.validate_request_params(request)
|
||||
|
||||
# Check if tool parsing is unavailable (common condition)
|
||||
tool_parsing_unavailable = (
|
||||
tool_parser is None
|
||||
and not isinstance(tokenizer, MistralTokenizer)
|
||||
and not is_mistral_tokenizer(tokenizer)
|
||||
and not self.use_harmony
|
||||
)
|
||||
|
||||
@@ -639,8 +635,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
request_metadata: RequestResponseMetadata,
|
||||
reasoning_parser: ReasoningParser | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
created_time = int(time.time())
|
||||
chunk_object_type: Final = "chat.completion.chunk"
|
||||
first_iteration = True
|
||||
@@ -955,7 +949,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
)
|
||||
else:
|
||||
# Generate ID based on tokenizer type
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
tool_call_id = MistralToolCall.generate_random_id()
|
||||
else:
|
||||
tool_call_id = make_tool_call_id(
|
||||
@@ -1516,7 +1510,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
tool_parser_cls=self.tool_parser,
|
||||
)
|
||||
tool_call_class = (
|
||||
MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
|
||||
MistralToolCall if is_mistral_tokenizer(tokenizer) else ToolCall
|
||||
)
|
||||
if self.use_harmony:
|
||||
# Harmony models already have parsed content and tool_calls
|
||||
@@ -1951,7 +1945,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# because of issues with pydantic we need to potentially
|
||||
# re-serialize the tool_calls field of the request
|
||||
# for more info: see comment in `maybe_serialize_tool_calls`
|
||||
maybe_serialize_tool_calls(request) # type: ignore[arg-type]
|
||||
_mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type]
|
||||
|
||||
# Add system message.
|
||||
# NOTE: In Chat Completion API, browsing is enabled by default
|
||||
|
||||
@@ -128,6 +128,7 @@ from vllm.utils.async_utils import (
|
||||
collect_from_async_generator,
|
||||
merge_async_iterators,
|
||||
)
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
|
||||
class GenerationError(Exception):
|
||||
@@ -976,15 +977,13 @@ class OpenAIServing:
|
||||
tool_dicts: list[dict[str, Any]] | None = None,
|
||||
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
|
||||
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
renderer = self.renderer
|
||||
|
||||
default_template_kwargs = merge_kwargs(
|
||||
default_template_kwargs,
|
||||
dict(
|
||||
tools=tool_dicts,
|
||||
tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
|
||||
tokenize=is_mistral_tokenizer(renderer.tokenizer),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -41,8 +41,8 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils.async_utils import make_async, merge_async_iterators
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -348,7 +348,7 @@ class ServingScores(OpenAIServing):
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
raise ValueError("MistralTokenizer not supported for cross-encoding")
|
||||
|
||||
model_config = self.model_config
|
||||
|
||||
@@ -26,6 +26,7 @@ from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
|
||||
from vllm.utils.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
@@ -260,10 +261,8 @@ class InputProcessingContext:
|
||||
|
||||
typ = ProcessorMixin
|
||||
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
tokenizer = self.tokenizer
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
tokenizer = tokenizer.transformers_tokenizer
|
||||
|
||||
merged_kwargs = self.get_merged_mm_kwargs(kwargs)
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
from vllm.v1.serial_utils import PydanticMsgspecMixin
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -731,7 +732,6 @@ class SamplingParams(
|
||||
):
|
||||
raise ValueError("structured_outputs.grammar cannot be an empty string")
|
||||
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.v1.structured_output.backend_guidance import (
|
||||
has_guidance_unsupported_json_features,
|
||||
validate_guidance_grammar,
|
||||
@@ -752,7 +752,7 @@ class SamplingParams(
|
||||
# allows <|special_token|> and similar, see
|
||||
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
|
||||
# Without tokenizer these are disallowed in grammars.
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
raise ValueError(
|
||||
"Mistral tokenizer is not supported for the 'guidance' "
|
||||
"structured output backend. Please use ['xgrammar', 'outlines'] "
|
||||
@@ -764,7 +764,7 @@ class SamplingParams(
|
||||
validate_structured_output_request_outlines(self)
|
||||
elif backend == "lm-format-enforcer":
|
||||
# lm format enforcer backend
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
raise ValueError(
|
||||
"Mistral tokenizer is not supported for the 'lm-format-enforcer' "
|
||||
"structured output backend. Please use ['xgrammar', 'outlines'] "
|
||||
@@ -796,7 +796,7 @@ class SamplingParams(
|
||||
schema = so_params.json
|
||||
skip_guidance = has_guidance_unsupported_json_features(schema)
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer) or skip_guidance:
|
||||
if is_mistral_tokenizer(tokenizer) or skip_guidance:
|
||||
# Fall back to outlines if the tokenizer is Mistral
|
||||
# or if schema contains features unsupported by guidance
|
||||
validate_structured_output_request_outlines(self)
|
||||
|
||||
@@ -210,6 +210,8 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
|
||||
|
||||
|
||||
class MistralTokenizer(TokenizerLike):
|
||||
IS_MISTRAL_TOKENIZER = True # used by vllm.utils.mistral
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
|
||||
@@ -22,10 +22,10 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -34,7 +34,7 @@ class Hermes2ProToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
logger.error("Detected Mistral tokenizer when using a Hermes model")
|
||||
self.model_tokenizer = tokenizer.tokenizer
|
||||
|
||||
|
||||
@@ -22,9 +22,9 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tool_parsers import ToolParser
|
||||
from vllm.tool_parsers.utils import extract_intermediate_diff
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -33,7 +33,7 @@ class JambaToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
if is_mistral_tokenizer(self.model_tokenizer):
|
||||
raise ValueError(
|
||||
"Detected a MistralTokenizer tokenizer when using a Jamba model"
|
||||
)
|
||||
|
||||
@@ -25,10 +25,10 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -66,9 +66,7 @@ class MistralToolCall(ToolCall):
|
||||
|
||||
|
||||
def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool:
|
||||
return not (
|
||||
isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
|
||||
)
|
||||
return not (is_mistral_tokenizer(model_tokenizer) and model_tokenizer.version >= 11)
|
||||
|
||||
|
||||
class MistralToolParser(ToolParser):
|
||||
@@ -83,7 +81,7 @@ class MistralToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if not isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
if not is_mistral_tokenizer(self.model_tokenizer):
|
||||
logger.info("Non-Mistral tokenizer detected when using a Mistral model...")
|
||||
|
||||
# initialize properties used for state when parsing tool calls in
|
||||
@@ -115,7 +113,7 @@ class MistralToolParser(ToolParser):
|
||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
request = super().adjust_request(request)
|
||||
if (
|
||||
not isinstance(self.model_tokenizer, MistralTokenizer)
|
||||
not is_mistral_tokenizer(self.model_tokenizer)
|
||||
and request.tools
|
||||
and request.tool_choice != "none"
|
||||
):
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Provides lazy import of the vllm.tokenizers.mistral module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, TypeGuard
|
||||
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# if type checking, eagerly import the module
|
||||
import vllm.tokenizers.mistral as mt
|
||||
else:
|
||||
mt = LazyLoader("mt", globals(), "vllm.tokenizers.mistral")
|
||||
|
||||
|
||||
def is_mistral_tokenizer(obj: TokenizerLike | None) -> TypeGuard[mt.MistralTokenizer]:
|
||||
"""Return true if the tokenizer is a MistralTokenizer instance."""
|
||||
cls = type(obj)
|
||||
# Check for special class attribute, this avoids importing the class to
|
||||
# do an isinstance() check. If the attribute is True, do an isinstance
|
||||
# check to be sure we have the correct type.
|
||||
return bool(
|
||||
getattr(cls, "IS_MISTRAL_TOKENIZER", False)
|
||||
and isinstance(obj, mt.MistralTokenizer)
|
||||
)
|
||||
@@ -10,8 +10,8 @@ import torch
|
||||
import vllm.envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
@@ -38,7 +38,7 @@ class XgrammarBackend(StructuredOutputBackend):
|
||||
self.vllm_config.structured_outputs_config.disable_any_whitespace
|
||||
)
|
||||
|
||||
if isinstance(self.tokenizer, MistralTokenizer):
|
||||
if is_mistral_tokenizer(self.tokenizer):
|
||||
# NOTE: ideally, xgrammar should handle this accordingly.
|
||||
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
|
||||
stop_token_ids = [self.tokenizer.eos_token_id]
|
||||
|
||||
Reference in New Issue
Block a user