bridge rust tool parser to python

Signed-off-by: Bugen Zhao <i@bugenzhao.com>
This commit is contained in:
Bugen Zhao
2026-06-05 05:57:20 +00:00
parent 9729e05917
commit 53275a22d6
14 changed files with 939 additions and 1643 deletions
+86
View File
@@ -3458,6 +3458,75 @@ version = "0.1.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f"
[[package]]
name = "pyo3"
version = "0.28.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91fd8e38a3b50ed1167fb981cd6fd60147e091784c427b8f7183a7ee32c31c12"
dependencies = [
"libc",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
]
[[package]]
name = "pyo3-build-config"
version = "0.28.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e368e7ddfdeb98c9bca7f8383be1648fd84ab466bf2bc015e94008db6d35611e"
dependencies = [
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.28.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f29e10af80b1f7ccaf7f69eace800a03ecd13e883acfacc1e5d0988605f651e"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.28.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df6e520eff47c45997d2fc7dd8214b25dd1310918bbb2642156ef66a67f29813"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.117",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.28.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4cdc218d835738f81c2338f822078af45b4afdf8b2e33cbb5916f108b813acb"
dependencies = [
"heck",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.117",
]
[[package]]
name = "pythonize"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b79f670c9626c8b651c0581011b57b6ba6970bb69faf01a7c4c0cfc81c43f95"
dependencies = [
"pyo3",
"serde",
"serde_json",
]
[[package]]
name = "qoi"
version = "0.4.1"
@@ -4669,6 +4738,12 @@ dependencies = [
"libc",
]
[[package]]
name = "target-lexicon"
version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca"
[[package]]
name = "task-local"
version = "0.1.1"
@@ -5901,6 +5976,17 @@ dependencies = [
"winnow",
]
[[package]]
name = "vllm-tool-parser-py"
version = "0.1.0"
dependencies = [
"pyo3",
"pythonize",
"serde_json",
"thiserror-ext",
"vllm-tool-parser",
]
[[package]]
name = "walkdir"
version = "2.5.0"
+3
View File
@@ -12,6 +12,7 @@ members = [
"src/text",
"src/tokenizer",
"src/tool-parser",
"src/tool-parser/python",
]
resolver = "3"
@@ -60,6 +61,8 @@ prometheus-client = "0.24.0"
prometheus-client-derive-encode = "0.5.0"
prost = "0.14.3"
prost-types = "0.14.3"
pyo3 = "0.28.3"
pythonize = "0.28.0"
rand = "0.9.2"
reasoning-parser = "1.2.2"
reqwest = { version = "0.12.8", default-features = false, features = ["rustls-tls"] }
+23
View File
@@ -0,0 +1,23 @@
[package]
name = "vllm-tool-parser-py"
version.workspace = true
edition.workspace = true
license.workspace = true
[lib]
name = "_rust_tool_parser"
crate-type = ["cdylib", "rlib"]
[features]
default = []
extension-module = ["pyo3/extension-module"]
[dependencies]
pyo3.workspace = true
pythonize = { workspace = true, features = ["serde_json"] }
serde_json.workspace = true
thiserror-ext.workspace = true
vllm-tool-parser.workspace = true
[lints]
workspace = true
+400
View File
@@ -0,0 +1,400 @@
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyModule};
use pythonize::{depythonize, pythonize};
use serde_json::Value;
use thiserror_ext::AsReport as _;
use vllm_tool_parser::{
DeepSeekV3ToolParser, DeepSeekV4ToolParser, DeepSeekV31ToolParser, DeepSeekV32ToolParser,
Gemma4ToolParser, Glm45MoeToolParser, Glm47MoeToolParser, HermesToolParser, HyV3ToolParser,
KimiK2ToolParser, Llama3JsonToolParser, MinimaxM2ToolParser, MinimaxM3ToolParser,
MistralToolParser, Qwen3CoderToolParser, Qwen3XmlToolParser, Tool, ToolCallDelta, ToolParser,
ToolParserOutput,
};
#[pyclass(name = "Tool", module = "vllm._rust_tool_parser", skip_from_py_object)]
#[derive(Clone)]
struct PyTool {
inner: Tool,
}
#[pymethods]
impl PyTool {
#[new]
#[pyo3(signature = (name, description, parameters, strict=None))]
fn new(
name: String,
description: Option<String>,
parameters: &Bound<'_, PyAny>,
strict: Option<bool>,
) -> PyResult<Self> {
let parameters = depythonize::<Value>(parameters).map_err(|error| {
PyValueError::new_err(format!(
"failed to convert tool parameters from Python to JSON: {error}"
))
})?;
Ok(Self {
inner: Tool {
name,
description,
parameters,
strict,
},
})
}
#[getter]
fn name(&self) -> &str {
&self.inner.name
}
#[getter]
fn description(&self) -> Option<&str> {
self.inner.description.as_deref()
}
#[getter]
fn parameters(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
pythonize(py, &self.inner.parameters).map(Bound::unbind).map_err(|error| {
PyValueError::new_err(format!(
"failed to convert tool parameters from JSON to Python: {error}"
))
})
}
#[getter]
fn strict(&self) -> Option<bool> {
self.inner.strict
}
}
#[pyclass(
name = "ToolCallDelta",
module = "vllm._rust_tool_parser",
skip_from_py_object
)]
#[derive(Clone)]
struct PyToolCallDelta {
inner: ToolCallDelta,
}
#[pymethods]
impl PyToolCallDelta {
#[new]
#[pyo3(signature = (tool_index, name, arguments))]
fn new(tool_index: usize, name: Option<String>, arguments: String) -> Self {
Self {
inner: ToolCallDelta {
tool_index,
name,
arguments,
},
}
}
#[getter]
fn tool_index(&self) -> usize {
self.inner.tool_index
}
#[getter]
fn name(&self) -> Option<&str> {
self.inner.name.as_deref()
}
#[getter]
fn arguments(&self) -> &str {
&self.inner.arguments
}
}
#[pyclass(
name = "ToolParserOutput",
module = "vllm._rust_tool_parser",
skip_from_py_object
)]
#[derive(Clone)]
struct PyToolParserOutput {
inner: ToolParserOutput,
}
impl PyToolParserOutput {
fn from_inner(inner: ToolParserOutput) -> Self {
Self { inner }
}
}
#[pymethods]
impl PyToolParserOutput {
#[new]
#[pyo3(signature = (normal_text="", calls=None))]
fn new(py: Python<'_>, normal_text: &str, calls: Option<Vec<Py<PyToolCallDelta>>>) -> Self {
let calls = calls
.unwrap_or_default()
.iter()
.map(|call| call.borrow(py).inner.clone())
.collect();
Self {
inner: ToolParserOutput {
normal_text: normal_text.to_owned(),
calls,
},
}
}
#[getter]
fn normal_text(&self) -> &str {
&self.inner.normal_text
}
#[getter]
fn calls(&self) -> Vec<PyToolCallDelta> {
self.inner
.calls
.iter()
.cloned()
.map(|inner| PyToolCallDelta { inner })
.collect()
}
fn append(&mut self, other: PyRef<'_, PyToolParserOutput>) {
self.inner.append(other.inner.clone());
}
fn coalesce_calls(&self) -> Self {
Self::from_inner(self.inner.clone().coalesce_calls())
}
}
#[pyclass(name = "ToolParser", module = "vllm._rust_tool_parser", unsendable)]
struct PyToolParser {
parser: Box<dyn ToolParser>,
}
impl PyToolParser {
fn from_name(name: &str, tools: &[Tool]) -> PyResult<Self> {
let parser = match name {
"deepseek_v3" => DeepSeekV3ToolParser::create(tools),
"deepseek_v31" => DeepSeekV31ToolParser::create(tools),
"deepseek_v32" => DeepSeekV32ToolParser::create(tools),
"deepseek_v4" => DeepSeekV4ToolParser::create(tools),
"gemma4" => Gemma4ToolParser::create(tools),
"glm45" => Glm45MoeToolParser::create(tools),
"glm47" => Glm47MoeToolParser::create(tools),
"hermes" => HermesToolParser::create(tools),
"hy_v3" => HyV3ToolParser::create(tools),
"kimi_k2" => KimiK2ToolParser::create(tools),
"llama3_json" | "llama4_json" => Llama3JsonToolParser::create(tools),
"minimax_m2" => MinimaxM2ToolParser::create(tools),
"minimax_m3" => MinimaxM3ToolParser::create(tools),
"mistral" => MistralToolParser::create(tools),
"qwen3_xml" => Qwen3XmlToolParser::create(tools),
"qwen3_coder" => Qwen3CoderToolParser::create(tools),
_ => {
return Err(PyValueError::new_err(format!(
"unsupported tool parser `{name}`"
)));
}
}
.map_err(|error| PyValueError::new_err(error.to_report_string()))?;
Ok(Self { parser })
}
fn parse_into_output(&mut self, chunk: &str, output: &mut PyToolParserOutput) -> PyResult<()> {
self.parser
.parse_into(chunk, &mut output.inner)
.map_err(|error| PyValueError::new_err(error.to_report_string()))
}
}
#[pymethods]
impl PyToolParser {
#[new]
fn new(py: Python<'_>, parser_name: &str, tools: Vec<Py<PyTool>>) -> PyResult<Self> {
let tools = tools.iter().map(|tool| tool.borrow(py).inner.clone()).collect::<Vec<_>>();
Self::from_name(parser_name, &tools)
}
fn parse_into(
&mut self,
chunk: &str,
mut output: PyRefMut<'_, PyToolParserOutput>,
) -> PyResult<()> {
self.parse_into_output(chunk, &mut output)
}
fn finish(&mut self) -> PyResult<PyToolParserOutput> {
self.parser
.finish()
.map(PyToolParserOutput::from_inner)
.map_err(|error| PyValueError::new_err(error.to_report_string()))
}
fn reset(&mut self) -> String {
self.parser.reset()
}
fn preserve_special_tokens(&self) -> bool {
self.parser.preserve_special_tokens()
}
}
#[pymodule]
fn _rust_tool_parser(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyTool>()?;
m.add_class::<PyToolCallDelta>()?;
m.add_class::<PyToolParserOutput>()?;
m.add_class::<PyToolParser>()?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
const NS: &str = "]<]minimax[>[";
fn with_python<R>(f: impl for<'py> FnOnce(Python<'py>) -> R) -> R {
Python::initialize();
Python::attach(f)
}
fn tool_schema() -> Value {
json!({
"type": "object",
"properties": {
"user_id": {"type": "integer"},
"shipping": {
"type": "object",
"properties": {
"city": {"type": "string"},
"zip": {"type": "integer"}
}
}
}
})
}
fn build_call() -> String {
format!(
"{NS}<tool_call>\n\
{NS}<invoke name=\"create_order\">\
{NS}<user_id>42{NS}</user_id>\
{NS}<shipping>\
{NS}<city>Singapore{NS}</city>\
{NS}<zip>018956{NS}</zip>\
{NS}</shipping>\
{NS}</invoke>\n\
{NS}</tool_call>"
)
}
fn make_py_tool(py: Python<'_>) -> PyResult<Py<PyTool>> {
let parameters = pythonize(py, &tool_schema()).map_err(|error| {
PyValueError::new_err(format!(
"failed to convert test schema from JSON to Python: {error}"
))
})?;
Py::new(
py,
PyTool::new(
"create_order".to_owned(),
Some("Create an order".to_owned()),
&parameters,
None,
)?,
)
}
#[test]
fn tool_round_trips_typed_fields() {
with_python(|py| {
let tool = make_py_tool(py)?;
let borrowed = tool.borrow(py);
assert_eq!(borrowed.name(), "create_order");
assert_eq!(borrowed.description(), Some("Create an order"));
assert_eq!(borrowed.strict(), None);
let parameters = borrowed.parameters(py)?;
let parameters = depythonize::<Value>(parameters.bind(py))?;
assert_eq!(parameters, tool_schema());
PyResult::Ok(())
})
.unwrap();
}
#[test]
fn output_append_and_coalesce_calls() {
with_python(|py| {
let first = Py::new(
py,
PyToolCallDelta::new(0, Some("create_order".to_owned()), "{\"a\"".to_owned()),
)?;
let second = Py::new(py, PyToolCallDelta::new(0, None, ":1}".to_owned()))?;
let mut output = PyToolParserOutput::new(py, "text", Some(vec![first]));
let other = Py::new(py, PyToolParserOutput::new(py, "", Some(vec![second])))?;
output.append(other.borrow(py));
let coalesced = output.coalesce_calls();
assert_eq!(coalesced.normal_text(), "text");
let calls = coalesced.calls();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].tool_index(), 0);
assert_eq!(calls[0].name(), Some("create_order"));
assert_eq!(calls[0].arguments(), "{\"a\":1}");
PyResult::Ok(())
})
.unwrap();
}
#[test]
fn parser_parse_finish_and_preserve_special_tokens() {
with_python(|py| {
let tool = make_py_tool(py)?;
let mut parser = PyToolParser::new(py, "minimax_m3", vec![tool])?;
assert!(!parser.preserve_special_tokens());
let mut output = PyToolParserOutput::new(py, "", None);
parser.parse_into_output(&build_call(), &mut output)?;
let finish = Py::new(py, parser.finish()?)?;
output.append(finish.borrow(py));
let output = output.coalesce_calls();
assert_eq!(output.normal_text(), "");
let calls = output.calls();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].name(), Some("create_order"));
assert_eq!(
serde_json::from_str::<Value>(calls[0].arguments()).unwrap(),
json!({
"user_id": 42,
"shipping": {
"city": "Singapore",
"zip": 18956
}
})
);
assert_eq!(parser.reset(), "");
PyResult::Ok(())
})
.unwrap();
}
#[test]
fn parser_errors_for_unknown_name() {
with_python(|py| {
let tool = make_py_tool(py)?;
let error = match PyToolParser::new(py, "missing", vec![tool]) {
Ok(_) => panic!("missing parser name unexpectedly succeeded"),
Err(error) => error,
};
let message = format!("{error}");
assert!(message.contains("unsupported tool parser `missing`"));
PyResult::Ok(())
})
.unwrap();
}
}
+91 -21
View File
@@ -36,6 +36,8 @@ ROOT_DIR = Path(__file__).parent
logger = logging.getLogger(__name__)
PRECOMPILED_RUST_FRONTEND_PATH = ROOT_DIR / "vllm" / "vllm-rs"
PRECOMPILED_RUST_EXTENSION_GLOB = "_rust_*.so"
PRECOMPILED_RUST_EXTENSION_MEMBER_REGEX = re.compile(r"vllm/_rust_[^/]*\.so$")
# cannot import envs directly because it depends on vllm,
# which is not installed yet
@@ -54,6 +56,59 @@ def should_require_rust_frontend() -> bool:
return value.lower() not in ("", "0", "false", "no")
# Rust frontend binary, built via setuptools-rust and installed into the
# package directory alongside the Python modules.
# TODO: we may use `RustBin` to directly install it into `bin` directory, but this
# requires extra work on using precompiled binaries.
rust_extensions = [
RustExtension(
target="vllm.vllm-rs",
path="rust/src/cmd/Cargo.toml",
args=["--bin", "vllm-rs"],
features=["native-tls-vendored"],
binding=Binding.Exec,
optional=not should_require_rust_frontend(),
),
RustExtension(
target="vllm._rust_tool_parser",
path="rust/src/tool-parser/python/Cargo.toml",
features=["extension-module"],
binding=Binding.PyO3,
optional=not should_require_rust_frontend(),
),
]
def get_precompiled_rust_extension_paths() -> list[Path]:
return sorted((ROOT_DIR / "vllm").glob(PRECOMPILED_RUST_EXTENSION_GLOB))
def get_expected_rust_extension_module_names() -> list[str]:
"""Return configured PyO3 Rust extension module names under ``vllm``."""
module_names = []
for rust_extension in rust_extensions:
if rust_extension.binding != Binding.PyO3:
continue
for target_name in rust_extension.target.values():
if target_name.startswith("vllm._rust_"):
module_names.append(target_name.rsplit(".", 1)[-1])
return module_names
def get_missing_precompiled_rust_extension_modules() -> list[str]:
missing = []
for module_name in get_expected_rust_extension_module_names():
if not list((ROOT_DIR / "vllm").glob(f"{module_name}*.so")):
missing.append(module_name)
return missing
def has_precompiled_rust_extensions() -> bool:
return not get_missing_precompiled_rust_extension_modules()
if sys.platform.startswith("darwin") and VLLM_TARGET_DEVICE != "cpu":
logger.warning("VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS")
VLLM_TARGET_DEVICE = "cpu"
@@ -421,19 +476,33 @@ class precompiled_build_ext(build_ext):
class precompiled_build_rust(build_rust):
"""Skips local Rust builds when the precompiled wheel already ships vllm-rs."""
"""Skips local Rust builds when all precompiled Rust artifacts are present."""
def run(self) -> None:
if PRECOMPILED_RUST_FRONTEND_PATH.exists():
if (
PRECOMPILED_RUST_FRONTEND_PATH.exists()
and has_precompiled_rust_extensions()
):
logger.info(
"Skipping local Rust build: using precompiled %s",
"Skipping local Rust build: using precompiled %s and %s",
PRECOMPILED_RUST_FRONTEND_PATH,
get_precompiled_rust_extension_paths(),
)
return
missing = []
if not PRECOMPILED_RUST_FRONTEND_PATH.exists():
missing.append(str(PRECOMPILED_RUST_FRONTEND_PATH))
missing_rust_extensions = get_missing_precompiled_rust_extension_modules()
if missing_rust_extensions:
missing.extend(
str(ROOT_DIR / "vllm" / f"{module_name}*.so")
for module_name in missing_rust_extensions
)
logger.warning(
"Precompiled wheel did not provide %s; falling back to local Rust build.",
PRECOMPILED_RUST_FRONTEND_PATH,
"Precompiled wheel did not provide all Rust artifacts (%s); "
"falling back to local Rust build.",
", ".join(missing),
)
super().run()
@@ -756,6 +825,14 @@ class precompiled_wheel_utils:
if member.filename in exact_members:
file_members.append(member)
continue
if (
extract_rust_frontend
and PRECOMPILED_RUST_EXTENSION_MEMBER_REGEX.match(
member.filename
)
):
file_members.append(member)
continue
if not extract_extensions:
continue
@@ -1127,6 +1204,10 @@ if PRECOMPILED_RUST_FRONTEND_PATH.exists():
vllm_files = package_data.setdefault("vllm", [])
if "vllm-rs" not in vllm_files:
vllm_files.append("vllm-rs")
vllm_files = package_data.setdefault("vllm", [])
for rust_extension_path in get_precompiled_rust_extension_paths():
if rust_extension_path.name not in vllm_files:
vllm_files.append(rust_extension_path.name)
if _no_device():
ext_modules = []
@@ -1139,24 +1220,13 @@ else:
if USE_PRECOMPILED_EXTENSIONS
else cmake_build_ext,
}
if USE_PRECOMPILED_RUST_FRONTEND or PRECOMPILED_RUST_FRONTEND_PATH.exists():
if (
USE_PRECOMPILED_RUST_FRONTEND
or PRECOMPILED_RUST_FRONTEND_PATH.exists()
or has_precompiled_rust_extensions()
):
cmdclass["build_rust"] = precompiled_build_rust
# Rust frontend binary, built via setuptools-rust and installed into the
# package directory alongside the Python modules.
# TODO: we may use `RustBin` to directly install it into `bin` directory, but this
# requires extra work on using precompiled binaries.
rust_extensions = [
RustExtension(
target="vllm.vllm-rs",
path="rust/src/cmd/Cargo.toml",
args=["--bin", "vllm-rs"],
features=["native-tls-vendored"],
binding=Binding.Exec,
optional=not should_require_rust_frontend(),
),
]
setup(
# static metadata should rather go in pyproject.toml
version=get_vllm_version(),
@@ -257,5 +257,5 @@ def test_streaming_nested_tool_call(parser):
assert json.loads(tool_calls[0]["arguments"]) == json.loads(
parser.streamed_args_for_tool[0]
)
assert parser.prev_tool_call_arr[0]["arguments"]["items"][1]["qty"] == 5
assert json.loads(parser.prev_tool_call_arr[0]["arguments"])["items"][1]["qty"] == 5
assert results[-1].content == ""
-3
View File
@@ -1,3 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# mypy: ignore-errors
-73
View File
@@ -1,73 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# mypy: ignore-errors
# ruff: noqa
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, TypeVar
from typing_extensions import Self
ParserInputs = TypeVar("ParserInputs", bound=Sequence)
ParserOutput = TypeVar("ParserOutput")
TextFn = Callable[[], str]
FunctionCallDict = Dict[Literal["name", "arguments"], str]
class Parser(ABC, Generic[ParserInputs, ParserOutput]):
@abstractmethod
def update(self, inputs: ParserInputs):
pass
def parse(self, inputs: ParserInputs) -> ParserOutput:
self.update(inputs)
final = self.get_final()
assert final is not None
return final
@abstractmethod
def get_delta(self) -> Optional[ParserOutput]:
pass
@abstractmethod
def get_final(self) -> Optional[ParserOutput]:
pass
@abstractmethod
def reset(self):
pass
def stop(self):
pass
def _get_sub_parsers(self) -> Optional[Dict[str, Self]]:
"""should only be used for debugging."""
pass
def stringify_function_calls(self, function_calls: List[FunctionCallDict]) -> str:
raise NotImplementedError()
class PatternMismatched(Exception):
def __init__(
self,
*,
offset: int,
expected: Any,
actual: Any,
reason: str,
context_fn: Optional[TextFn] = None,
):
self.offset = offset
self.expected = expected
self.actual = actual
self.reason = reason
self.context_fn = context_fn
def __str__(self):
main_message = f"at offset {self.offset}, {self.reason}, expected {self.expected!r} but got {self.actual!r}"
if self.context_fn is None:
return main_message
else:
return f"around context:\n{self.context_fn()}\n\n{main_message}"
-340
View File
@@ -1,340 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# mypy: ignore-errors
# ruff: noqa
import json
import math
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Iterator, List, Optional, Set
import jsonschema
import jsonschema.exceptions
from typing_extensions import Self
# fmt: off
NULL_STRINGS = (
"null", "Null", "NULL",
# It's bad, especially when we need enum string `none` etc.
# "none", "null", "nil",
# "None", "Null", "Nil",
# "NONE", "NULL", "NIL",
)
# fmt: on
@dataclass(slots=True)
class ConversionResult:
value: Any
confidence: int
class AtomDataType(int, Enum):
none = auto()
string = auto()
integer = auto()
float = auto()
boolean = auto()
# complex types:
array = auto()
object = auto()
other = auto()
@classmethod
def from_string(cls, string: Optional[str]) -> Self:
if string is None:
return cls.string
string = string.lower()
if string == "none" or string == "null" or string == "nil":
return cls.none
elif string == "str" or string == "string" or string == "text":
return cls.string
elif string == "int" or string == "integer":
return cls.integer
elif string == "float" or string == "number":
return cls.float
elif string == "bool" or string == "boolean":
return cls.boolean
elif string == "array" or string == "list":
return cls.array
elif (
string == "object"
or string == "dict"
or string == "dictionary"
or string == "map"
):
return cls.object
else:
return cls.other
def is_complex_type(self) -> bool:
return self == self.array or self == self.object or self == self.other
@classmethod
def from_example(cls, example: Any) -> Self:
if example is None:
return cls.none
elif isinstance(example, str):
return cls.string
# NOTE: Order matters. True is instance of int.
elif isinstance(example, bool):
return cls.boolean
elif isinstance(example, int):
return cls.integer
elif isinstance(example, float):
return cls.float
elif isinstance(example, list):
return cls.array
elif isinstance(example, dict):
return cls.object
else:
return cls.other
@classmethod
def iter_candidates_from_schema(cls, input: Any) -> Iterator[Self]:
if isinstance(input, dict):
if "type" in input:
type_value = input["type"]
if isinstance(type_value, str):
yield cls.from_string(type_value)
elif isinstance(type_value, list):
for each in type_value:
yield cls.from_string(each)
else:
yield cls.other
elif ("properties" in input and isinstance(input["properties"], dict)) or (
"additionalProperties" in input
and input["additionalProperties"] is not False
):
yield cls.object
elif "items" in input or "prefixItems" in input:
yield cls.array
else:
nothing_found = True
if "const" in input:
nothing_found = False
yield cls.from_example(input["const"])
elif (
"enum" in input
and isinstance(input["enum"], list)
and input["enum"]
):
for each in input["enum"]:
nothing_found = False
yield cls.from_example(each)
else:
for choice_field in ("anyOf", "oneOf", "allOf"):
if choice_field in input:
choices = input[choice_field]
if isinstance(choices, list):
for choice in choices:
for each in cls.iter_candidates_from_schema(choice):
nothing_found = False
yield each
if nothing_found:
yield cls.string
else:
yield cls.string
def convert_with_confidence(self, string: str) -> Optional[ConversionResult]:
if self == AtomDataType.none:
stripped = string.strip()
if stripped:
if len(stripped) == 4 and stripped.lower() == "null":
return ConversionResult(value=None, confidence=10)
return ConversionResult(value=None, confidence=1)
else:
return ConversionResult(value=None, confidence=3)
elif self == AtomDataType.string:
return ConversionResult(value=string, confidence=2)
elif self == AtomDataType.integer:
try:
return ConversionResult(value=int(string), confidence=10)
except ValueError:
return None
elif self == AtomDataType.float:
try:
value = float(string)
except ValueError:
return None
# NOTE: Python is evil.
# In Python, we have
# json.dumps(math.nan) -> "NaN"
# json.dumps(math.inf) -> "Infinity"
# json.dumps(-math.inf) -> "-Infinity"
# However, in other languages, you cannot deserialize them back.
# No perfect solution here.
if math.isnan(value):
value = "NaN"
elif math.isinf(value):
value = "Infinity" if value > 0 else "-Infinity"
elif value.is_integer():
value = int(value)
return ConversionResult(value=value, confidence=10)
elif self == AtomDataType.boolean:
stripped = string.strip()
if len(stripped) > 5:
return ConversionResult(value=False, confidence=1)
lower = stripped.lower()
if lower == "true":
return ConversionResult(value=True, confidence=10)
elif lower == "false":
return ConversionResult(value=False, confidence=10)
elif lower == "yes" or lower == "on" or lower == "1":
return ConversionResult(value=True, confidence=9)
elif lower == "no" or lower == "off" or lower == "0":
return ConversionResult(value=False, confidence=9)
else:
return ConversionResult(value=False, confidence=1)
else:
try:
return ConversionResult(value=json.loads(string), confidence=10)
except json.JSONDecodeError:
return None
@dataclass(slots=True)
class FunctionCallParameterDataType:
candidates: List[AtomDataType]
schema: Any
streaming: bool
@classmethod
def get_schema_of_parameter(cls, schema: Any, parameter_name: str) -> Self:
"""
$ref etc. not supported.
"""
if not isinstance(schema, dict):
return cls(candidates=[AtomDataType.string], schema=None, streaming=True)
property = None
parameters = schema.get("parameters")
if isinstance(parameters, dict):
properties = parameters.get("properties")
if isinstance(properties, dict):
property = properties.get(parameter_name)
return cls.from_property(property)
@classmethod
def from_property(cls, property: Any) -> Self:
candidate_set: Set[AtomDataType] = set()
candidates: List[AtomDataType] = []
# Guaranteed to be non-empty.
for each in AtomDataType.iter_candidates_from_schema(property):
if each not in candidate_set:
candidate_set.add(each)
candidates.append(each)
# The parameter is streaming if and only if there is only one candidate, string.
# No matter how complicated the schema is, if it can only be a string,
streaming = len(candidates) == 1 and candidates[0] == AtomDataType.string
return cls(candidates=candidates, schema=property, streaming=streaming)
def get_data_type_of_property(self, key: str) -> Optional[Self]:
if isinstance(self.schema, dict):
property = self.get_property(self.schema, key)
if property:
return self.from_property(property)
property_choices = []
for choice_field in ("anyOf", "oneOf", "allOf"):
if choice_field in self.schema:
choices = self.schema[choice_field]
if isinstance(choices, list):
for choice in choices:
property = self.get_property(choice, key)
if property:
property_choices.append(property)
if property_choices:
return self.from_property(
{
"anyOf": property_choices,
}
)
@staticmethod
def get_property(schema: Any, key: str) -> Any:
if "properties" in schema:
properties = schema["properties"]
if isinstance(properties, dict):
property = properties.get(key)
if property:
return property
if "additionalProperties" in schema:
additional_properties = schema["additionalProperties"]
if isinstance(additional_properties, dict):
return additional_properties
def get_data_type_of_item(self, *, index: int = 0) -> Optional[Self]:
if isinstance(self.schema, dict):
item = self.get_item(self.schema, index=index)
if item:
return self.from_property(item)
item_choices = []
for choice_field in ("anyOf", "oneOf", "allOf"):
if choice_field in self.schema:
choices = self.schema[choice_field]
if isinstance(choices, list):
for choice in choices:
item = self.get_item(choice, index=index)
if item:
item_choices.append(item)
if item_choices:
return self.from_property(
{
"anyOf": item_choices,
}
)
@staticmethod
def get_item(schema: Any, *, index: int = 0) -> Any:
items = schema.get("items")
if items:
return items
if "prefixItems" in schema:
prefix_items = schema["prefixItems"]
additional_items = schema.get("additionalItems")
if isinstance(prefix_items, list):
if index < len(prefix_items):
return prefix_items[index]
elif additional_items:
return additional_items
elif len(prefix_items):
return prefix_items[-1]
def convert(self, string: str, *, always_nullable: bool = False) -> Any:
if always_nullable:
stripped = string.strip()
if len(stripped) == 4 and stripped.lower() == "null":
return None
if not string:
if AtomDataType.object in self.candidates:
return {}
elif AtomDataType.array in self.candidates:
return []
elif AtomDataType.none in self.candidates:
return None
else:
return ""
converted_list: List[ConversionResult] = []
has_schema = bool(self.schema)
for candidate in self.candidates:
converted = candidate.convert_with_confidence(string)
if converted is None:
continue
if has_schema:
try:
jsonschema.validate(schema=self.schema, instance=converted.value)
# Ensure it beats the invalid ones.
converted.confidence += 10
except jsonschema.exceptions.SchemaError:
# The schema is invalid, ignore it.
has_schema = False
except Exception:
# Validation failed.
pass
converted_list.append(converted)
if len(converted_list) == 1:
return converted_list[0].value
elif len(converted_list):
return max(converted_list, key=lambda x: x.confidence).value
else:
return string
-9
View File
@@ -1,9 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# mypy: ignore-errors
# ruff: noqa
from typing import Callable, Optional, Tuple
IsPendingAndText = Tuple[bool, str]
CountConsumedTokensFn = Callable[[Optional[IsPendingAndText]], int]
-463
View File
@@ -1,463 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# mypy: ignore-errors
# ruff: noqa
import functools
import json
from abc import ABC, abstractmethod
from collections import deque
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum, auto
from typing import (
Any,
Callable,
Deque,
Dict,
Generator,
Iterator,
List,
Optional,
Set,
Tuple,
Union,
)
from typing_extensions import Self
from .data_type import NULL_STRINGS, FunctionCallParameterDataType
from .base import Parser, ParserOutput, PatternMismatched
from .decoder import CountConsumedTokensFn
OutputKey = Union[str, Callable[[Any], Dict]]
class GeneratorParser(Parser[str, dict], ABC):
"""
TODO:
1. 改成接 StrOrSpecialToken.
2. 报错时支持 row:col.
3. 需要 self.generator.close() 吗?
"""
@abstractmethod
def _process(self) -> Generator[ParserOutput, None, None]:
pass
def __init__(self):
self._count_consumed_tokens_fn: Optional[CountConsumedTokensFn] = None
self._full_text: List[str] = []
self._input_length = 0
self._input_buffer: Deque[str] = deque()
self._delta: Optional[Dict[str, Any]] = None
self._final: Optional[Dict[str, Any]] = None
self._generator = self._process()
next(self._generator)
def __del__(self):
# Necessary when GC is disabled (gc.disable()),
# otherwise the cyclic reference will never be collected.
if hasattr(self, "_generator"):
self._generator.close()
del self._generator
def reset(self):
self._full_text.clear()
self._input_length = 0
self._input_buffer.clear()
self._delta = None
self._final = None
# don't worry, a closed generator can be closed again.
self._generator.close()
self._generator = self._process()
next(self._generator)
def update(self, input: str):
self._full_text.append(input)
total = len(input)
self._input_length += total
self._input_buffer.extend(input)
# at most #input times,
# but early stop if the input buffer is already empty.
for _ in range(total):
if not self._input_buffer:
break
try:
next(self._generator)
except StopIteration:
if self._input_buffer:
raise self._error(
expected=None,
actual=self._input_buffer[0],
reason="pattern exhausted",
)
def _count_consumed_tokens(self) -> Optional[int]:
if self._count_consumed_tokens_fn is None:
return
total_tokens_count = self._count_consumed_tokens_fn(None)
if self._input_buffer:
pending_text_count = len(self._input_buffer)
consumed_text_count = self._input_length - pending_text_count
# If total is 0, we cannot use `consumed = total - pending`.
# Have to tokenize the consumed text, maybe again.
is_pending = (total_tokens_count > 0) and (
pending_text_count <= consumed_text_count
)
if is_pending:
text = "".join(self._input_buffer)
else:
text = "".join(self._full_text)[:consumed_text_count]
return self._count_consumed_tokens_fn((is_pending, text))
elif total_tokens_count:
return total_tokens_count
else:
# Only for non-streaming mode, should be called only once.
# In streaming mode, we should feed in tokens instead of text,
# thus `total_tokens_count` must be positive.
text = "".join(self._full_text)
return self._count_consumed_tokens_fn((False, text))
def get_delta(self) -> Optional[Dict[str, Any]]:
if self._delta:
delta = _get_dict_with_joint_strings(self._delta)
self._delta = None
return delta
def get_final(self) -> Optional[Dict[str, Any]]:
return _get_dict_with_joint_strings(self._final)
def _peek(self, index: int) -> Generator[None, None, str]:
while index >= len(self._input_buffer):
yield
return self._input_buffer[index]
def _read_one(self) -> str:
# intentionally unchecked, will raise IndexError if the input buffer is empty.
return self._input_buffer.popleft()
def _consume(self, length: int):
for _ in range(length):
self._input_buffer.popleft()
def _try(self, fn, *args, **kwargs):
try:
result = yield from fn(*args, **kwargs)
return Tried(successful=True, result=result)
except PatternMismatched as e:
# Clear traceback to break traceback → frame → f_locals → self
# circular reference chain. The traceback is not needed for
# backtracking logic, only offset/expected/actual matter.
e.__traceback__ = None
return Tried(successful=False, error=e)
def _append(self, key: OutputKey, value: Any):
if isinstance(key, str):
delta = {key: value}
else:
delta = key(value)
self._append_delta(delta)
def _append_delta(self, delta: Dict[str, Any]):
self._delta = _merge_dicts(delta, self._delta)
self._final = _merge_dicts(delta, self._final)
def _get_context(self, offset: int) -> str:
# NOTE: `self._error` is frequently triggered with backtracking,
# thus we cannot afford to join the full text every time.
# A partial function should work, which is only called when the error is stringified.
full_text = "".join(self._full_text)
return f"{full_text[:offset]}💥{full_text[offset:]}"
def _error(self, *, expected: Any, actual: Any, reason: str) -> PatternMismatched:
offset = self._input_length - len(self._input_buffer)
# Use weakref to avoid context_fn → partial → self circular reference.
# _full_text is a plain list, capturing it directly avoids holding self.
full_text_ref = self._full_text
return PatternMismatched(
offset=offset,
expected=expected,
actual=actual,
reason=reason,
context_fn=functools.partial(
_get_context_standalone, full_text_ref, offset
),
)
def _take_any(
self,
*,
until: Optional[Union[str, Tuple[str, ...]]] = None,
key: Optional[OutputKey] = None,
should_consume_suffix: bool = True,
) -> Generator[str, None, None]:
if until is None:
# Never ends without `until`, thus no need to collect the values.
while True:
yield from self._peek(0)
self._append(key, self._read_one())
else:
values = []
while True:
tried = yield from self._literal(
until,
should_consume=should_consume_suffix,
should_raise=False,
)
if tried is not None:
break
value = self._read_one()
values.append(value)
if key is not None:
self._append(key, value)
return "".join(values)
def _take_data_type_as_json(
self,
*,
until: Union[str, Tuple[str, ...]],
key: OutputKey,
data_type: FunctionCallParameterDataType,
always_nullable: bool = False,
should_consume_suffix: bool = True,
):
"""
Unlike `.take_any`, `until` and `key` is required here, because
- it is unnecessary and a little bit annoying to maintain the streaming state
for most data types other than `str`.
- if you don't need the output, just use `.take_any`.
NOTE: if there are various acceptable data types, we choose the FIRST convertible
one, indicating, if the first choice is `str`, the following choices will be
IGNORED, as everything could be a string.
TODO: if there are acceptable data types before `str`, we should peek until they
are excluded and then choose `str`, instead of choosing at the end.
"""
if not data_type.streaming:
value = yield from self._take_any(
until=until, should_consume_suffix=should_consume_suffix
)
value = data_type.convert(value, always_nullable=always_nullable)
self._append(key, json_dumps(value))
return value
else:
if always_nullable:
# very tricky.
# even if the data type is string, we still need to check if it's null.
tried = yield from self._literal(
_string_cartesian_product(NULL_STRINGS, until),
should_consume=should_consume_suffix,
should_raise=False,
)
if tried is not None:
if not should_consume_suffix:
# NOTE: length of "null" (case-insensitive) is 4.
self._consume(4)
self._append(key, "null")
return None
values = []
self._append(key, '"')
while True:
tried = yield from self._literal(
until, should_consume=should_consume_suffix, should_raise=False
)
if tried is not None:
break
value = self._read_one()
values.append(value)
self._append(key, json_dumps(value)[1:-1])
self._append(key, '"')
return "".join(values)
def _literal(
self,
target: Union[str, Tuple[str, ...]],
*,
should_raise: bool = True,
should_consume: bool = True,
) -> Generator[Optional[str], None, None]:
"""
NOTE: stops at the first matched target, even if it is a substring of another later target.
"""
if isinstance(target, str):
target = (target,)
char_index = 0
target_index = 0
target_index_max = len(target) - 1
while True:
ith = yield from self._peek(char_index)
if ith != target[target_index][char_index]:
already_matched_and_ith = target[target_index][:char_index] + ith
prefix_matching_result = PrefixMatchingResult.mismatch
target_index += 1
while target_index <= target_index_max:
prefix_matching_result = PrefixMatchingResult.check(
target[target_index], already_matched_and_ith
)
if prefix_matching_result != PrefixMatchingResult.mismatch:
break
target_index += 1
if prefix_matching_result == PrefixMatchingResult.substring_of_prefix:
break
elif prefix_matching_result == PrefixMatchingResult.mismatch:
if should_raise:
if len(target) == 1:
raise self._error(
expected=target[0][char_index:],
actual=ith,
reason=f"matching {target[0]!r}",
)
else:
raise self._error(
expected=target,
actual=ith,
reason="matching any of the list",
)
else:
return
char_index += 1
if char_index == len(target[target_index]):
break
if should_consume:
self._consume(len(target[target_index]))
return target[target_index]
def _whitespace(self) -> Iterator[None]:
total = 0
while True:
peek = yield from self._peek(total)
if peek.isspace():
total += 1
else:
self._consume(total)
break
@dataclass(slots=True)
class Tried:
successful: bool
result: Any = None
error: Optional[PatternMismatched] = None
def default_tool_call_output_key(tool_call_index: int, function: Dict) -> Dict:
return {
"tool_calls": [
{
"index": tool_call_index,
"function": function,
}
]
}
def _get_context_standalone(full_text_parts: List[str], offset: int) -> str:
"""Standalone version of _get_context that doesn't capture self."""
full_text = "".join(full_text_parts)
return f"{full_text[:offset]}💥{full_text[offset:]}"
def _merge_dicts(
src: Dict[str, Any],
dst: Optional[Dict[str, Any]],
whitelist: Optional[Set[str]] = None,
) -> Dict[str, Any]:
if dst is None:
return deepcopy(src)
for key, val_src in src.items():
if whitelist is not None and key in whitelist:
continue
val_dst = dst.get(key)
if val_dst is None:
dst[key] = deepcopy(val_src)
elif isinstance(val_src, dict):
dst[key] = _merge_dicts(val_src, val_dst)
elif isinstance(val_src, list):
# NOTE: don't just `val_dst.extend(val_src)` here, as we need a carefully constructed delta object.
# Assume that all elements have `index: int`, indicating its position in the final merged list.
if len(val_src) == len(val_dst) and all(
val_src[i]["index"] == val_dst[i]["index"] for i in range(len(val_src))
):
for i in range(len(val_src)):
_merge_dicts(val_src[i], val_dst[i], whitelist={"index"})
elif len(val_src):
index_to_item = {item["index"]: item for item in val_dst}
for item_src in val_src:
idx = item_src["index"]
item_dst = index_to_item.get(idx)
if item_dst is None:
index_to_item[idx] = deepcopy(item_src)
else:
_merge_dicts(item_src, item_dst, whitelist={"index"})
dst[key] = sorted(
index_to_item.values(), key=lambda item: item["index"]
)
elif isinstance(val_src, str):
if isinstance(val_dst, list):
dst[key].append(val_src)
else:
# NOTE: to avoid O(n²) string concatenation, we store one string as a list of substrings.
# Later, join the substrings back by `_get_dict_with_joint_strings`.
dst[key] = [val_dst, val_src]
else:
raise TypeError(
f"key {key} has unsupported type {type(val_src)} of {val_src!r} against {val_dst!r}"
)
return dst
def _get_dict_with_joint_strings(
data: Optional[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
if data is None:
return
updated = {}
for key, val in data.items():
if isinstance(val, list):
if val:
if isinstance(val[0], str):
updated[key] = "".join(val)
elif isinstance(val[0], dict):
updated[key] = [_get_dict_with_joint_strings(_) for _ in val]
else:
updated[key] = val
else:
updated[key] = val
elif isinstance(val, dict):
updated[key] = _get_dict_with_joint_strings(val)
else:
updated[key] = val
return updated
def json_dumps(data: Any) -> str:
return json.dumps(data, ensure_ascii=False)
class PrefixMatchingResult(int, Enum):
startswith = auto()
substring_of_prefix = auto()
mismatch = auto()
@classmethod
def check(cls, string: str, prefix: str) -> Self:
if not prefix:
return cls.startswith
for i in range(len(prefix)):
if i >= len(string):
return cls.substring_of_prefix
if string[i] != prefix[i]:
return cls.mismatch
return cls.startswith
@functools.cache
def _string_cartesian_product(
prefix: Tuple[str, ...], suffix: Union[str, Tuple[str, ...]]
) -> Tuple[str, ...]:
if isinstance(suffix, str):
return tuple(_ + suffix for _ in prefix)
else:
return tuple(_prefix + _suffix for _prefix in prefix for _suffix in suffix)
-435
View File
@@ -1,435 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# mypy: ignore-errors
# ruff: noqa
import json
from dataclasses import dataclass
from typing import Any, Callable, Dict, Generator, List, Optional, Union
from typing_extensions import Self
from .data_type import AtomDataType, FunctionCallParameterDataType
from .base import FunctionCallDict
from .generator import (
GeneratorParser,
default_tool_call_output_key,
json_dumps,
)
class M3TextParser(GeneratorParser):
"""
Parser for MiniMax M3 models.
M3 uses a namespace token `]<]minimax[>[` as delimiter before each tag.
Parameters use actual XML tag names (not `<parameter name="...">`), and can be nested.
Complex arguments, as nested XML tags, are buffered and emitted as complete JSON;
Simple arguments, are streamed character-by-character if possible.
Example raw output::
]<]minimax[>[<tool_call>
]<]minimax[>[<invoke name="func1">]<]minimax[>[<p1>value1]<]minimax[>[</p1>]<]minimax[>[<p2>]<]minimax[>[<item>]<]minimax[>[<k>val]<]minimax[>[</k>]<]minimax[>[</item>]<]minimax[>[</p2>]<]minimax[>[</invoke>
]<]minimax[>[</tool_call>
"""
def __init__(
self,
*,
with_reasoning: bool = True,
reasoning_prefix: str = "<mm:think>",
reasoning_suffix: str = "</mm:think>",
functions: Optional[Dict] = None,
tool_call_xml_tag_name: str = "tool_call",
tool_call_namespace_token: str = "]<]minimax[>[",
always_nullable: bool = True,
reasoning_field: str = "reasoning",
content_field: str = "content",
tool_call_output_key: Callable[
[int, Dict], Dict
] = default_tool_call_output_key,
):
self.with_reasoning = with_reasoning
self._reasoning_tokens: Optional[int] = None
self._reasoning_prefix = reasoning_prefix
self._reasoning_suffix = reasoning_suffix
self._reasoning_suffix_without_newline = reasoning_suffix.lstrip()
self._functions = functions
self._tool_call_namespace_token = tool_call_namespace_token
self._tool_call_start = (
f"{self._tool_call_namespace_token}<{tool_call_xml_tag_name}>"
)
self._tool_call_end = (
f"{self._tool_call_namespace_token}</{tool_call_xml_tag_name}>"
)
self._invoke_prefix = f'{self._tool_call_namespace_token}<invoke name="'
self._invoke_suffix = '">'
self._end_of_invoke = f"{self._tool_call_namespace_token}</invoke>"
self._parameter_prefix = f"{self._tool_call_namespace_token}<"
self._parameter_suffix = f"{self._tool_call_namespace_token}</"
self._always_nullable = always_nullable
self._reasoning_field = reasoning_field
self._content_field = content_field
self._tool_call_output_key = tool_call_output_key
super().__init__()
def _get_function(self, function_name: str) -> Optional[Dict]:
if isinstance(self._functions, dict):
return self._functions.get(function_name)
def count_reasoning_tokens(self) -> Optional[int]:
if self.with_reasoning:
if self._reasoning_tokens is None:
return self._count_consumed_tokens()
else:
return self._reasoning_tokens
def _process(self) -> Generator[dict, None, None]:
if self.with_reasoning:
if self._reasoning_prefix:
with_reasoning = yield from self._literal(
self._reasoning_prefix, should_raise=False
)
if with_reasoning:
yield from self._take_any(
until=self._reasoning_suffix, key=self._reasoning_field
)
self._reasoning_tokens = self._count_consumed_tokens()
else:
self._reasoning_tokens = 0
# If reasoning is disabled, we may find the suffix at the beginning without the prefix.
yield from self._literal(
self._reasoning_suffix_without_newline,
should_raise=False,
)
else:
yield from self._take_any(
until=self._reasoning_suffix, key=self._reasoning_field
)
self._reasoning_tokens = self._count_consumed_tokens()
if not self._functions:
yield from self._take_any(key=self._content_field)
else:
yield from self._take_any(
until=self._tool_call_start,
key=self._content_field,
should_consume_suffix=False,
)
yield from self._literal(self._tool_call_start)
yield from self._literal("\n", should_raise=False)
# NOTE: Only ONE `<tool_call>` block is supported by design.
# Multiple parallel calls must share a single wrapper and use
# multiple `<invoke>` tags inside it. A second `<tool_call>` after
# the first `</tool_call>` will cause `update()` to raise
# PatternMismatched (pattern exhausted).
tool_call_index = 0
while True:
if tool_call_index:
yield from self._literal("\n", should_raise=False)
tried = yield from self._literal(
(self._invoke_prefix, self._tool_call_end),
should_raise=False,
)
if tried is None or tried == self._tool_call_end:
break
function_name = yield from self._take_any(until=self._invoke_suffix)
self._append_delta(
self._tool_call_output_key(
tool_call_index, {"name": function_name, "arguments": "{"}
)
)
function = self._get_function(function_name)
is_first_parameter = True
while True:
tried = yield from self._literal(
(self._end_of_invoke, self._parameter_prefix),
should_raise=False,
)
if tried == self._end_of_invoke:
self._append_delta(
self._tool_call_output_key(
tool_call_index, {"arguments": "}"}
)
)
break
if tried is None:
# Consume and ignore, hoping it may recover after this invoke ends.
yield from self._take_any(until=self._end_of_invoke)
self._append_delta(
self._tool_call_output_key(
tool_call_index, {"arguments": "}"}
)
)
break
parameter_name = yield from self._take_any(until=">")
parameter_name_to_arguments = "{}{}: ".format(
"" if is_first_parameter else ", ",
json_dumps(parameter_name),
)
self._append_delta(
self._tool_call_output_key(
tool_call_index,
{"arguments": parameter_name_to_arguments},
)
)
parameter_suffix = "{}{}>".format(
self._parameter_suffix, parameter_name
)
parameter_data_type = (
FunctionCallParameterDataType.get_schema_of_parameter(
function, parameter_name
)
)
tried = yield from self._literal(
self._parameter_prefix,
should_raise=False,
should_consume=False,
)
if tried is not None:
param_body_str = yield from self._take_any(
until=parameter_suffix
)
if param_body_str:
# nested XML -> object
# NOTE: The namespace token has the highest semantic
# priority. Once the model emits `]<]minimax[>[<` here,
# we MUST treat the body as nested XML, even if the
# schema says this parameter should be a primitive.
# The model is asserting "this is a JSON level
# transition" — schema mismatches are reported back via
# the agent loop, not silently rewritten here.
param = self._parse_parameter(
param_body_str, parameter_data_type
)
else:
param = parameter_data_type.convert("")
self._append_delta(
self._tool_call_output_key(
tool_call_index,
{"arguments": json_dumps(param)},
)
)
else:
# no more nested XML -> string / number / boolean
yield from self._take_data_type_as_json(
until=parameter_suffix,
key=lambda value: self._tool_call_output_key(
tool_call_index, {"arguments": value}
),
data_type=parameter_data_type,
always_nullable=False,
should_consume_suffix=True,
)
is_first_parameter = False
tool_call_index += 1
def _parse_parameter(
self, body: str, parameter_data_type: FunctionCallParameterDataType
) -> dict:
chunks = body.split(self._tool_call_namespace_token)
# NOTE: Array detection is intentionally a strict "schema says array
# AND first child is <item>" check. We do NOT promote a uniform
# `<x><x>...` body to an array on schema mismatch — leave it as
# `{"x": [...]}` so the agent loop can spot and correct the model.
if (
AtomDataType.array in parameter_data_type.candidates
and len(chunks) > 1
and chunks[1].startswith("<item>")
):
root = []
else:
root = {}
stack: List[_StackItem] = [
_StackItem(tag=None, value=root, texts=None, data_type=parameter_data_type)
]
# Ignore the first chunk inside the parameter.
# It should be empty, since we've tried `self._parameter_prefix` and failed before entering this function.
for chunk_index in range(1, len(chunks)):
chunk = chunks[chunk_index]
# There are 7 = 3 + 3 + 1 non-empty categories of chunks.
if chunk.startswith("</"):
gt_offset = chunk.find(">", 2)
if gt_offset == -1:
# 1. `</tag`
tag = chunk[2:]
value = None
elif gt_offset == len(chunk) - 1:
# 2. `</tag>`
tag = chunk[2:-1]
value = None
else:
# 3. `</tag>value`
tag = chunk[2:gt_offset]
value = chunk[gt_offset + 1 :]
while len(stack) > 1:
item = stack.pop()
stack[-1].append(item)
if item.tag == tag:
break
if value:
stack[-1].append_text(value)
elif chunk.startswith("<"):
gt_offset = chunk.find(">", 1)
if gt_offset == -1:
# 4. `<tag`
tag = chunk[1:]
value = None
elif gt_offset == len(chunk) - 1:
# 5. `<tag>`
tag = chunk[1:-1]
value = None
else:
# 6. `<tag>value`
tag = chunk[1:gt_offset]
value = chunk[gt_offset + 1 :]
sub_data_type = stack[-1].get_data_type_of_property(tag)
if (
sub_data_type
and AtomDataType.array in sub_data_type.candidates
and len(chunks) > chunk_index + 1
and chunks[chunk_index + 1].startswith("<item>")
):
sub = []
elif sub_data_type and AtomDataType.object in sub_data_type.candidates:
sub = {}
else:
sub = None
stack.append(
_StackItem(
tag=tag,
value=sub,
texts=[value] if value else None,
data_type=sub_data_type,
)
)
elif chunk:
# 7. `value`
stack[-1].append_text(chunk)
while len(stack) > 1:
item = stack.pop()
stack[-1].append(item)
return stack[0].get_value()
def stringify_function_calls(self, function_calls: List[FunctionCallDict]) -> str:
if not function_calls:
return ""
parts = [self._tool_call_start + "\n"]
for function_call in function_calls:
parts.append(
self._invoke_prefix + function_call["name"] + self._invoke_suffix
)
arguments = function_call["arguments"]
if isinstance(arguments, str):
arguments = json.loads(arguments)
parts.append(self._stringify_parameter(arguments))
parts.append(self._end_of_invoke + "\n")
parts.append(self._tool_call_end)
return "".join(parts)
def _stringify_parameter(self, parameter: Any) -> str:
# NOTE: null values are simply ignored.
# This is limited due to the training philosophy of MiniMax M3.
# In the training process, the model will NOT see any null values in tool calls.
# Thus, we have to skip them here, in order to avoid OOD.
# This cost is unwillingly accepted:
# `["a", null, "c"]` becomes `<items>a</items><items>c</items>`,
# even the size of the array is changed.
if isinstance(parameter, dict):
return "".join(
f"{self._tool_call_namespace_token}<{key}>{self._stringify_parameter(value)}{self._tool_call_namespace_token}</{key}>"
for key, value in parameter.items()
if value is not None
)
elif isinstance(parameter, list):
return "".join(
f"{self._tool_call_namespace_token}<item>{self._stringify_parameter(value)}{self._tool_call_namespace_token}</item>"
for value in parameter
if value is not None
)
elif isinstance(parameter, str):
return parameter
elif parameter is None:
# should be unreachable
return ""
else:
return json.dumps(parameter, ensure_ascii=False)
@dataclass(slots=True)
class _StackItem:
tag: Optional[str]
value: Optional[Union[Dict, List]]
texts: Optional[List[str]]
data_type: Optional[FunctionCallParameterDataType]
def get_value(self) -> Any:
if self.value is None:
if self.texts:
value = "".join(self.texts)
else:
value = ""
if self.data_type:
value = self.data_type.convert(value)
return value
elif self.texts and isinstance(self.value, dict):
extra_text_key = "$text"
while extra_text_key in self.value:
extra_text_key = "$" + extra_text_key
self.value[extra_text_key] = "".join(self.texts)
return self.value
else:
return self.value
def append(self, item: Self) -> None:
if self.value is None:
self.value = {item.tag: item.get_value()}
elif isinstance(self.value, dict):
if item.tag in self.value:
# NOTE: Duplicate tag inside an object is collapsed into a
# list to preserve all values, even if the schema declares
# the key as a singleton. We don't drop the data and we
# don't try to "fix" the schema mismatch silently — the agent
# loop should surface the inconsistency back to the model.
value = self.value[item.tag]
if isinstance(value, list):
value.append(item.get_value())
else:
self.value[item.tag] = [value, item.get_value()]
else:
self.value[item.tag] = item.get_value()
elif isinstance(self.value, list):
# We expect `item.tag` to be `"item"`, but if it's not, we should still accept it.
self.value.append(item.get_value())
else:
raise NotImplementedError()
def append_text(self, value: str) -> None:
if isinstance(self.value, list):
if self.data_type:
data_type = self.data_type.get_data_type_of_item(index=len(self.value))
if data_type:
value = data_type.convert(value)
self.value.append(value)
elif self.texts is None:
self.texts = [value]
else:
self.texts.append(value)
def get_data_type_of_property(
self, tag: str
) -> Optional[FunctionCallParameterDataType]:
if self.data_type:
if isinstance(self.value, list):
# We expect `tag` to be `"item"`, but if it's not, we should still accept it.
return self.data_type.get_data_type_of_item(index=len(self.value))
elif isinstance(self.value, dict):
return self.data_type.get_data_type_of_property(tag)
else:
return None
+6 -298
View File
@@ -1,311 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from contextlib import suppress
from typing import Any
from openai.types.responses.function_tool import FunctionTool
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers._llm_nom.m3_text import M3TextParser
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
logger = init_logger(__name__)
from vllm.tool_parsers.rust_tool_parser import RustToolParser
class MinimaxM3ToolParser(ToolParser):
"""Adapter from the vendored MiniMax M3 parser to vLLM ToolParser.
class MinimaxM3ToolParser(RustToolParser):
"""Adapter from the Rust MiniMax M3 parser to vLLM ToolParser.
The real M3 grammar lives in ``_llm_nom.m3_text.M3TextParser``. This
class keeps only the vLLM-specific bridge work:
- convert vLLM tool definitions into the function schema shape expected by
the vendored parser;
- translate parser ``content`` / ``tool_calls`` deltas into vLLM protocol
objects; and
- maintain vLLM streaming bookkeeping used by finish-reason handling.
The real M3 grammar lives in the Rust tool-parser crate. This class only
configures the generic Rust bridge with the MiniMax M3 parser name.
M3 is not M2 with renamed tags: it prefixes each structural tag with the
MiniMax namespace marker, allows multiple ``<invoke>`` tags in one wrapper,
and represents nested arguments with parameter-name XML tags.
"""
# M3 emits its own XML-like tool-call format from the chat template. For
# required/named tool_choice, do not let the serving layer force JSON guided
# output; parse the M3 syntax through this parser instead.
supports_required_and_named = False
rust_parser_name = "minimax_m3"
tool_call_start_token = "]<]minimax[>[<tool_call>"
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self._parser: M3TextParser | None = None
self._error: Exception | None = None
self._tool_call_ids: dict[int, str] = {}
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
logger.debug(
"vLLM successfully imported tool parser %s", self.__class__.__name__
)
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
"""Adjust generation options for MiniMax M3 tool-call syntax.
Required/named tool choice must skip ``super().adjust_request()``
because the base implementation would install JSON structured output
constraints. M3 needs to preserve and generate its namespace-tagged
syntax, so we only ensure special-token text is not stripped.
"""
if request.tools:
tool_choice = getattr(request, "tool_choice", None)
if tool_choice == "required" or isinstance(
tool_choice, ChatCompletionNamedToolChoiceParam
):
if hasattr(request, "skip_special_tokens"):
request.skip_special_tokens = False
return request
request = super().adjust_request(request)
if (
request.tools
and getattr(request, "tool_choice", None) != "none"
and hasattr(request, "skip_special_tokens")
):
request.skip_special_tokens = False
return request
def _functions(self) -> dict[str, dict[str, Any]] | None:
"""Build the function map consumed by ``M3TextParser``."""
if not self.tools:
return None
functions: dict[str, dict[str, Any]] = {}
for tool in self.tools:
if isinstance(tool, FunctionTool):
name = tool.name
parameters = tool.parameters
elif isinstance(tool, ChatCompletionToolsParam):
name = tool.function.name
parameters = tool.function.parameters
else:
continue
functions[name] = {"parameters": parameters}
return functions
def _new_parser(self) -> M3TextParser:
"""Create a fresh vendored parser with the current tool schemas."""
return M3TextParser(
with_reasoning=False,
reasoning_prefix="",
functions=self._functions(),
)
def _get_parser(self) -> M3TextParser:
if self._parser is None:
self._parser = self._new_parser()
return self._parser
def _reset_streaming_state(self) -> None:
"""Reset parser state for a new request on a reused parser instance."""
self._parser = self._new_parser()
self._error = None
self._tool_call_ids.clear()
self.prev_tool_call_arr.clear()
self.streamed_args_for_tool.clear()
self.current_tool_id = -1
self.current_tool_name_sent = False
def _ensure_tool_state(self, index: int) -> None:
"""Grow vLLM streaming state arrays to contain ``index``."""
while len(self.prev_tool_call_arr) <= index:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= index:
self.streamed_args_for_tool.append("")
def _record_delta(
self, index: int, name: str | None, arguments: str | None
) -> str | None:
"""Mirror a vendored-parser delta into vLLM streaming bookkeeping.
``prev_tool_call_arr`` and ``streamed_args_for_tool`` are read later by
the chat serving layer to decide the final ``tool_calls`` finish reason
and to flush any remaining argument bytes.
"""
tool_call_id = None
self._ensure_tool_state(index)
if name is not None:
tool_call_id = make_tool_call_id()
self._tool_call_ids[index] = tool_call_id
self.prev_tool_call_arr[index] = {"name": name, "arguments": {}}
self.current_tool_name_sent = True
if arguments is not None:
self.streamed_args_for_tool[index] += arguments
with suppress(json.JSONDecodeError):
self.prev_tool_call_arr[index]["arguments"] = json.loads(
self.streamed_args_for_tool[index]
)
self.current_tool_id = index
return tool_call_id
def _delta_message_from_parser_delta(
self, parser_delta: dict[str, Any] | None
) -> DeltaMessage | None:
"""Translate one ``M3TextParser`` delta into a vLLM ``DeltaMessage``."""
if parser_delta is None:
return None
normal_text = parser_delta.get("content") or None
tool_calls: list[DeltaToolCall] = []
for tool_call in parser_delta.get("tool_calls", []):
func = tool_call.get("function", {})
index = tool_call.get("index", 0)
name = func.get("name")
arguments = func.get("arguments")
if name is None and arguments is None:
continue
tool_call_id = self._record_delta(index, name, arguments)
tool_calls.append(
DeltaToolCall(
index=index,
id=tool_call_id,
type="function" if name is not None else None,
function=DeltaFunctionCall(
name=name,
arguments=arguments,
),
)
)
if normal_text is None and not tool_calls:
return None
return DeltaMessage(content=normal_text, tool_calls=tool_calls)
def _parse_complete(self, model_output: str) -> dict[str, Any] | None:
"""Parse complete model output with a throwaway parser instance."""
parser = self._new_parser()
try:
parser.update(model_output)
except Exception:
logger.exception("Error parsing MiniMax M3 tool call output.")
return None
return parser.get_delta() or parser.get_final()
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""Extract tool calls from complete model output (non-streaming)."""
if self.tool_call_start_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
parsed = self._parse_complete(model_output)
if parsed is None:
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
tool_calls: list[ToolCall] = []
self.prev_tool_call_arr.clear()
for parsed_tool_call in parsed.get("tool_calls", []):
func = parsed_tool_call.get("function", {})
name = func.get("name")
arguments = func.get("arguments", "{}")
if name is None:
continue
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(name=name, arguments=arguments),
)
)
try:
args = json.loads(arguments)
except json.JSONDecodeError:
args = arguments
self.prev_tool_call_arr.append({"name": name, "arguments": args})
if not tool_calls:
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
content = parsed.get("content") or None
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content,
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int], # pylint: disable=unused-argument
current_token_ids: Sequence[int], # pylint: disable=unused-argument
delta_token_ids: Sequence[int],
request: ChatCompletionRequest, # pylint: disable=unused-argument
) -> DeltaMessage | None:
"""Extract tool calls from streaming model output.
``M3TextParser`` owns the incremental buffer, so this adapter feeds only
the newest text delta. It returns an empty final content delta on EOS
after a tool call so the serving layer reaches its finish-reason path.
"""
if not previous_text:
self._reset_streaming_state()
if self._error is not None:
return None
try:
self._get_parser().update(delta_text)
except Exception as error:
self._error = error
logger.exception("Error parsing MiniMax M3 streaming tool call output.")
parser_delta = self._get_parser().get_delta()
delta_message = self._delta_message_from_parser_delta(parser_delta)
if delta_message is not None:
return delta_message
if not delta_text and delta_token_ids and self.prev_tool_call_arr:
return DeltaMessage(content="")
return None
+329
View File
@@ -0,0 +1,329 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any
from openai.types.responses.function_tool import FunctionTool
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
logger = init_logger(__name__)
def _rust_tool_parser_module() -> Any:
try:
from vllm import _rust_tool_parser
except ImportError as exc:
raise RuntimeError(
"Rust tool parsing requires the vllm._rust_tool_parser PyO3 "
"extension. Rebuild vLLM with Rust extensions enabled."
) from exc
return _rust_tool_parser
class RustToolParser(ToolParser):
"""Adapter from an opaque Rust parser to the vLLM ToolParser API.
Subclasses provide only model-specific configuration: the exact Rust parser
name and an optional tool-call start marker for fast complete-output
rejection.
This class keeps the vLLM-specific bridge work:
- convert vLLM tool definitions into the Rust ``Tool`` shape;
- translate typed Rust parser outputs into vLLM protocol objects; and
- maintain vLLM streaming bookkeeping used by finish-reason handling.
The parser grammar and incremental parser state stay in Rust.
"""
# Rust-backed parsers are opaque to Python by default. Do not use vLLM's
# standard JSON required/named handling; let the Rust parser consume the
# model's native tool-call syntax.
supports_required_and_named = False
rust_parser_name: str
tool_call_start_token: str | None = None
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self._parser: Any | None = None
self._error: Exception | None = None
self._finished = False
self._tool_call_ids: dict[int, str] = {}
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
logger.debug(
"vLLM successfully imported tool parser %s", self.__class__.__name__
)
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
"""Adjust request options without installing Python-side constraints.
Rust-backed parsers are treated as source-of-truth opaque parsers. The
bridge intentionally avoids ``super().adjust_request()`` so Python does
not install JSON schema guidance or structural-tag constraints that may
conflict with the Rust parser's native grammar.
"""
if self._get_parser().preserve_special_tokens():
request.skip_special_tokens = False
return request
def _rust_tools(self) -> list[Any]:
"""Build Rust ``Tool`` objects from vLLM tool definitions."""
if not self.tools:
return []
tools: list[Any] = []
for tool in self.tools:
if isinstance(tool, FunctionTool):
name = tool.name
description = tool.description
parameters = tool.parameters or {}
strict = getattr(tool, "strict", None)
elif isinstance(tool, ChatCompletionToolsParam):
name = tool.function.name
description = tool.function.description
parameters = tool.function.parameters or {}
strict = getattr(tool.function, "strict", None)
else:
continue
tools.append(
_rust_tool_parser_module().Tool(name, description, parameters, strict)
)
return tools
def _new_parser(self) -> Any:
"""Create a fresh Rust parser with the current tool schemas."""
return _rust_tool_parser_module().ToolParser(
self.rust_parser_name, self._rust_tools()
)
def _get_parser(self) -> Any:
if self._parser is None:
self._parser = self._new_parser()
return self._parser
def _reset_streaming_state(self) -> None:
"""Reset parser state for a new request on a reused parser instance."""
self._parser = self._new_parser()
self._error = None
self._finished = False
self._tool_call_ids.clear()
self.prev_tool_call_arr.clear()
self.streamed_args_for_tool.clear()
self.current_tool_id = -1
self.current_tool_name_sent = False
def _ensure_tool_state(self, index: int) -> None:
"""Grow vLLM streaming state arrays to contain ``index``."""
while len(self.prev_tool_call_arr) <= index:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= index:
self.streamed_args_for_tool.append("")
def _record_delta(
self, index: int, name: str | None, arguments: str | None
) -> str | None:
"""Mirror a Rust parser delta into vLLM streaming bookkeeping.
``prev_tool_call_arr`` and ``streamed_args_for_tool`` are read later by
the chat serving layer to decide the final ``tool_calls`` finish reason
and to flush any remaining argument bytes.
"""
tool_call_id = None
self._ensure_tool_state(index)
if name is not None:
tool_call_id = make_tool_call_id()
self._tool_call_ids[index] = tool_call_id
self.prev_tool_call_arr[index] = {"name": name, "arguments": {}}
self.current_tool_name_sent = True
if arguments is not None:
self.streamed_args_for_tool[index] += arguments
self.prev_tool_call_arr[index]["arguments"] = (
self.streamed_args_for_tool[index]
)
self.current_tool_id = index
return tool_call_id
def _delta_message_from_parser_output(
self, parser_output: Any | None
) -> DeltaMessage | None:
"""Translate one Rust parser output into a vLLM ``DeltaMessage``."""
if parser_output is None:
return None
normal_text = parser_output.normal_text or None
tool_calls: list[DeltaToolCall] = []
for tool_call in parser_output.calls:
index = tool_call.tool_index
name = tool_call.name
arguments: str | None = tool_call.arguments
if name is None and arguments is None:
continue
tool_call_id = self._record_delta(index, name, arguments)
tool_calls.append(
DeltaToolCall(
index=index,
id=tool_call_id,
type="function" if name is not None else None,
function=DeltaFunctionCall(
name=name,
arguments=arguments,
),
)
)
if normal_text is None and not tool_calls:
return None
return DeltaMessage(content=normal_text, tool_calls=tool_calls)
def _parse_complete(self, model_output: str) -> Any | None:
"""Parse complete model output with a throwaway Rust parser instance."""
parser = self._new_parser()
output = _rust_tool_parser_module().ToolParserOutput()
try:
parser.parse_into(model_output, output)
output.append(parser.finish())
except Exception:
logger.exception(
"Error parsing %s tool call output.", self.rust_parser_name
)
return None
return output.coalesce_calls()
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""Extract tool calls from complete model output (non-streaming)."""
if (
self.tool_call_start_token is not None
and self.tool_call_start_token not in model_output
):
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
parsed = self._parse_complete(model_output)
if parsed is None:
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
tool_calls: list[ToolCall] = []
self.prev_tool_call_arr.clear()
for parsed_tool_call in parsed.calls:
name = parsed_tool_call.name
arguments = parsed_tool_call.arguments or "{}"
if name is None:
continue
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(name=name, arguments=arguments),
)
)
self.prev_tool_call_arr.append({"name": name, "arguments": arguments})
if not tool_calls:
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
content = parsed.normal_text or None
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content,
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int], # pylint: disable=unused-argument
current_token_ids: Sequence[int], # pylint: disable=unused-argument
delta_token_ids: Sequence[int],
request: ChatCompletionRequest, # pylint: disable=unused-argument
) -> DeltaMessage | None:
"""Extract tool calls from streaming model output.
The Rust parser owns the incremental buffer, so this adapter feeds only
the newest text delta. On EOS, it calls ``finish()`` once to flush any
complete buffered tool call. It returns an empty final content delta
after a tool call so the serving layer reaches its finish-reason path.
"""
if not previous_text:
self._reset_streaming_state()
if self._error is not None:
return None
parser_output = _rust_tool_parser_module().ToolParserOutput()
try:
self._get_parser().parse_into(delta_text, parser_output)
except Exception as error:
self._error = error
logger.exception(
"Error parsing %s streaming tool call output.",
self.rust_parser_name,
)
delta_message = self._delta_message_from_parser_output(parser_output)
if delta_message is not None:
return delta_message
if not delta_text and delta_token_ids and not self._finished:
try:
finish_output = self._get_parser().finish()
self._finished = True
except Exception as error:
self._error = error
logger.exception(
"Error finishing %s streaming tool parser.",
self.rust_parser_name,
)
finish_output = None
delta_message = self._delta_message_from_parser_output(finish_output)
if delta_message is not None:
return delta_message
if not delta_text and delta_token_ids and self.prev_tool_call_arr:
return DeltaMessage(content="")
return None