[Feature] Lazy import for the "mistral" tokenizer module. (#34651)

Signed-off-by: Neil Schemenauer <nas@arctrix.com>
This commit is contained in:
Neil Schemenauer
2026-02-23 00:43:01 -08:00
committed by GitHub
parent e631f8e78e
commit 54e2f83d0a
14 changed files with 68 additions and 48 deletions
@@ -23,7 +23,7 @@ from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.mistral import is_mistral_tokenizer
from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import (
@@ -183,7 +183,7 @@ def get_text_token_prompts(
text_prompt: str | None
token_prompt: list[int]
if isinstance(tokenizer, MistralTokenizer):
if is_mistral_tokenizer(tokenizer):
# ChatCompletionRequest only supports ImageChunk natively;
# for other modalities (e.g. audio), fall back to the model's
# own dummy inputs builder which knows the right placeholders.
+3 -3
View File
@@ -4,7 +4,7 @@
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning import ReasoningParser
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.mistral import is_mistral_tokenizer
class StreamingReasoningReconstructor:
@@ -59,7 +59,7 @@ def run_reasoning_extraction_mistral(
request: ChatCompletionRequest | None = None,
streaming: bool = False,
) -> tuple[str | None, str | None]:
assert isinstance(reasoning_parser.model_tokenizer, MistralTokenizer), type(
assert is_mistral_tokenizer(reasoning_parser.model_tokenizer), type(
reasoning_parser.model_tokenizer
)
if streaming:
@@ -130,7 +130,7 @@ def run_reasoning_extraction_streaming_mistral(
model_deltas: list[int],
request: ChatCompletionRequest | None = None,
) -> StreamingReasoningReconstructor:
assert isinstance(reasoning_parser.model_tokenizer, MistralTokenizer), type(
assert is_mistral_tokenizer(reasoning_parser.model_tokenizer), type(
reasoning_parser.model_tokenizer
)
request = request or ChatCompletionRequest(messages=[], model="test-model")
+3 -3
View File
@@ -83,9 +83,9 @@ from vllm.renderers.inputs.preprocess import (
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils.counter import Counter
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.tqdm_utils import maybe_tqdm
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor
@@ -891,7 +891,7 @@ class LLM:
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
tokenize=is_mistral_tokenizer(renderer.tokenizer),
),
),
)
@@ -1458,7 +1458,7 @@ class LLM:
model_config = self.model_config
tokenizer = self.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer):
if is_mistral_tokenizer(tokenizer):
raise ValueError("Score API is not supported for Mistral tokenizer")
if len(data_1) == 1:
@@ -75,16 +75,12 @@ from vllm.parser import ParserManager
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import (
MistralTokenizer,
maybe_serialize_tool_calls,
truncate_tool_call_ids,
validate_request_params,
)
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.tool_parsers.utils import partial_json_loads
from vllm.utils.collection_utils import as_list
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.mistral import mt as _mt
logger = init_logger(__name__)
@@ -244,18 +240,18 @@ class OpenAIServingChat(OpenAIServing):
tool_parser = self.tool_parser
if isinstance(tokenizer, MistralTokenizer):
if is_mistral_tokenizer(tokenizer):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) # type: ignore[arg-type]
truncate_tool_call_ids(request) # type: ignore[arg-type]
validate_request_params(request)
_mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type]
_mt.truncate_tool_call_ids(request) # type: ignore[arg-type]
_mt.validate_request_params(request)
# Check if tool parsing is unavailable (common condition)
tool_parsing_unavailable = (
tool_parser is None
and not isinstance(tokenizer, MistralTokenizer)
and not is_mistral_tokenizer(tokenizer)
and not self.use_harmony
)
@@ -639,8 +635,6 @@ class OpenAIServingChat(OpenAIServing):
request_metadata: RequestResponseMetadata,
reasoning_parser: ReasoningParser | None = None,
) -> AsyncGenerator[str, None]:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True
@@ -955,7 +949,7 @@ class OpenAIServingChat(OpenAIServing):
)
else:
# Generate ID based on tokenizer type
if isinstance(tokenizer, MistralTokenizer):
if is_mistral_tokenizer(tokenizer):
tool_call_id = MistralToolCall.generate_random_id()
else:
tool_call_id = make_tool_call_id(
@@ -1516,7 +1510,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parser_cls=self.tool_parser,
)
tool_call_class = (
MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
MistralToolCall if is_mistral_tokenizer(tokenizer) else ToolCall
)
if self.use_harmony:
# Harmony models already have parsed content and tool_calls
@@ -1951,7 +1945,7 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) # type: ignore[arg-type]
_mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type]
# Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default
+2 -3
View File
@@ -128,6 +128,7 @@ from vllm.utils.async_utils import (
collect_from_async_generator,
merge_async_iterators,
)
from vllm.utils.mistral import is_mistral_tokenizer
class GenerationError(Exception):
@@ -976,15 +977,13 @@ class OpenAIServing:
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
from vllm.tokenizers.mistral import MistralTokenizer
renderer = self.renderer
default_template_kwargs = merge_kwargs(
default_template_kwargs,
dict(
tools=tool_dicts,
tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
tokenize=is_mistral_tokenizer(renderer.tokenizer),
),
)
+2 -2
View File
@@ -41,8 +41,8 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async, merge_async_iterators
from vllm.utils.mistral import is_mistral_tokenizer
logger = init_logger(__name__)
@@ -348,7 +348,7 @@ class ServingScores(OpenAIServing):
trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
tokenizer = self.renderer.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer):
if is_mistral_tokenizer(tokenizer):
raise ValueError("MistralTokenizer not supported for cross-encoding")
model_config = self.model_config
+2 -3
View File
@@ -26,6 +26,7 @@ from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
from vllm.utils.jsontree import JSONTree, json_map_leaves
from vllm.utils.mistral import is_mistral_tokenizer
if TYPE_CHECKING:
from transformers.configuration_utils import PretrainedConfig
@@ -260,10 +261,8 @@ class InputProcessingContext:
typ = ProcessorMixin
from vllm.tokenizers.mistral import MistralTokenizer
tokenizer = self.tokenizer
if isinstance(tokenizer, MistralTokenizer):
if is_mistral_tokenizer(tokenizer):
tokenizer = tokenizer.transformers_tokenizer
merged_kwargs = self.get_merged_mm_kwargs(kwargs)
+4 -4
View File
@@ -16,6 +16,7 @@ from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.v1.serial_utils import PydanticMsgspecMixin
logger = init_logger(__name__)
@@ -731,7 +732,6 @@ class SamplingParams(
):
raise ValueError("structured_outputs.grammar cannot be an empty string")
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.v1.structured_output.backend_guidance import (
has_guidance_unsupported_json_features,
validate_guidance_grammar,
@@ -752,7 +752,7 @@ class SamplingParams(
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
if isinstance(tokenizer, MistralTokenizer):
if is_mistral_tokenizer(tokenizer):
raise ValueError(
"Mistral tokenizer is not supported for the 'guidance' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
@@ -764,7 +764,7 @@ class SamplingParams(
validate_structured_output_request_outlines(self)
elif backend == "lm-format-enforcer":
# lm format enforcer backend
if isinstance(tokenizer, MistralTokenizer):
if is_mistral_tokenizer(tokenizer):
raise ValueError(
"Mistral tokenizer is not supported for the 'lm-format-enforcer' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
@@ -796,7 +796,7 @@ class SamplingParams(
schema = so_params.json
skip_guidance = has_guidance_unsupported_json_features(schema)
if isinstance(tokenizer, MistralTokenizer) or skip_guidance:
if is_mistral_tokenizer(tokenizer) or skip_guidance:
# Fall back to outlines if the tokenizer is Mistral
# or if schema contains features unsupported by guidance
validate_structured_output_request_outlines(self)
+2
View File
@@ -210,6 +210,8 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
class MistralTokenizer(TokenizerLike):
IS_MISTRAL_TOKENIZER = True # used by vllm.utils.mistral
@classmethod
def from_pretrained(
cls,
+2 -2
View File
@@ -22,10 +22,10 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.utils.mistral import is_mistral_tokenizer
logger = init_logger(__name__)
@@ -34,7 +34,7 @@ class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if isinstance(tokenizer, MistralTokenizer):
if is_mistral_tokenizer(tokenizer):
logger.error("Detected Mistral tokenizer when using a Hermes model")
self.model_tokenizer = tokenizer.tokenizer
+2 -2
View File
@@ -22,9 +22,9 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.utils import extract_intermediate_diff
from vllm.utils.mistral import is_mistral_tokenizer
logger = init_logger(__name__)
@@ -33,7 +33,7 @@ class JambaToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):
if is_mistral_tokenizer(self.model_tokenizer):
raise ValueError(
"Detected a MistralTokenizer tokenizer when using a Jamba model"
)
+4 -6
View File
@@ -25,10 +25,10 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.utils.mistral import is_mistral_tokenizer
logger = init_logger(__name__)
@@ -66,9 +66,7 @@ class MistralToolCall(ToolCall):
def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool:
return not (
isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
)
return not (is_mistral_tokenizer(model_tokenizer) and model_tokenizer.version >= 11)
class MistralToolParser(ToolParser):
@@ -83,7 +81,7 @@ class MistralToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if not isinstance(self.model_tokenizer, MistralTokenizer):
if not is_mistral_tokenizer(self.model_tokenizer):
logger.info("Non-Mistral tokenizer detected when using a Mistral model...")
# initialize properties used for state when parsing tool calls in
@@ -115,7 +113,7 @@ class MistralToolParser(ToolParser):
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if (
not isinstance(self.model_tokenizer, MistralTokenizer)
not is_mistral_tokenizer(self.model_tokenizer)
and request.tools
and request.tool_choice != "none"
):
+28
View File
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Provides lazy import of the vllm.tokenizers.mistral module."""
from __future__ import annotations
from typing import TYPE_CHECKING, TypeGuard
from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import LazyLoader
if TYPE_CHECKING:
# if type checking, eagerly import the module
import vllm.tokenizers.mistral as mt
else:
mt = LazyLoader("mt", globals(), "vllm.tokenizers.mistral")
def is_mistral_tokenizer(obj: TokenizerLike | None) -> TypeGuard[mt.MistralTokenizer]:
"""Return true if the tokenizer is a MistralTokenizer instance."""
cls = type(obj)
# Check for special class attribute, this avoids importing the class to
# do an isinstance() check. If the attribute is True, do an isinstance
# check to be sure we have the correct type.
return bool(
getattr(cls, "IS_MISTRAL_TOKENIZER", False)
and isinstance(obj, mt.MistralTokenizer)
)
@@ -10,8 +10,8 @@ import torch
import vllm.envs
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.import_utils import LazyLoader
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.v1.structured_output.backend_types import (
StructuredOutputBackend,
StructuredOutputGrammar,
@@ -38,7 +38,7 @@ class XgrammarBackend(StructuredOutputBackend):
self.vllm_config.structured_outputs_config.disable_any_whitespace
)
if isinstance(self.tokenizer, MistralTokenizer):
if is_mistral_tokenizer(self.tokenizer):
# NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
stop_token_ids = [self.tokenizer.eos_token_id]