mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
bridge rust tool parser to python
Signed-off-by: Bugen Zhao <i@bugenzhao.com>
This commit is contained in:
Generated
+86
@@ -3458,6 +3458,75 @@ version = "0.1.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f"
|
||||
|
||||
[[package]]
|
||||
name = "pyo3"
|
||||
version = "0.28.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "91fd8e38a3b50ed1167fb981cd6fd60147e091784c427b8f7183a7ee32c31c12"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"once_cell",
|
||||
"portable-atomic",
|
||||
"pyo3-build-config",
|
||||
"pyo3-ffi",
|
||||
"pyo3-macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-build-config"
|
||||
version = "0.28.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e368e7ddfdeb98c9bca7f8383be1648fd84ab466bf2bc015e94008db6d35611e"
|
||||
dependencies = [
|
||||
"target-lexicon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-ffi"
|
||||
version = "0.28.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f29e10af80b1f7ccaf7f69eace800a03ecd13e883acfacc1e5d0988605f651e"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"pyo3-build-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-macros"
|
||||
version = "0.28.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df6e520eff47c45997d2fc7dd8214b25dd1310918bbb2642156ef66a67f29813"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"pyo3-macros-backend",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-macros-backend"
|
||||
version = "0.28.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c4cdc218d835738f81c2338f822078af45b4afdf8b2e33cbb5916f108b813acb"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"pyo3-build-config",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pythonize"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b79f670c9626c8b651c0581011b57b6ba6970bb69faf01a7c4c0cfc81c43f95"
|
||||
dependencies = [
|
||||
"pyo3",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "qoi"
|
||||
version = "0.4.1"
|
||||
@@ -4669,6 +4738,12 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "target-lexicon"
|
||||
version = "0.13.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca"
|
||||
|
||||
[[package]]
|
||||
name = "task-local"
|
||||
version = "0.1.1"
|
||||
@@ -5901,6 +5976,17 @@ dependencies = [
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "vllm-tool-parser-py"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"pyo3",
|
||||
"pythonize",
|
||||
"serde_json",
|
||||
"thiserror-ext",
|
||||
"vllm-tool-parser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "walkdir"
|
||||
version = "2.5.0"
|
||||
|
||||
@@ -12,6 +12,7 @@ members = [
|
||||
"src/text",
|
||||
"src/tokenizer",
|
||||
"src/tool-parser",
|
||||
"src/tool-parser/python",
|
||||
]
|
||||
resolver = "3"
|
||||
|
||||
@@ -60,6 +61,8 @@ prometheus-client = "0.24.0"
|
||||
prometheus-client-derive-encode = "0.5.0"
|
||||
prost = "0.14.3"
|
||||
prost-types = "0.14.3"
|
||||
pyo3 = "0.28.3"
|
||||
pythonize = "0.28.0"
|
||||
rand = "0.9.2"
|
||||
reasoning-parser = "1.2.2"
|
||||
reqwest = { version = "0.12.8", default-features = false, features = ["rustls-tls"] }
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
[package]
|
||||
name = "vllm-tool-parser-py"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "_rust_tool_parser"
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
|
||||
[features]
|
||||
default = []
|
||||
extension-module = ["pyo3/extension-module"]
|
||||
|
||||
[dependencies]
|
||||
pyo3.workspace = true
|
||||
pythonize = { workspace = true, features = ["serde_json"] }
|
||||
serde_json.workspace = true
|
||||
thiserror-ext.workspace = true
|
||||
vllm-tool-parser.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
@@ -0,0 +1,400 @@
|
||||
use pyo3::exceptions::PyValueError;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyAny, PyModule};
|
||||
use pythonize::{depythonize, pythonize};
|
||||
use serde_json::Value;
|
||||
use thiserror_ext::AsReport as _;
|
||||
use vllm_tool_parser::{
|
||||
DeepSeekV3ToolParser, DeepSeekV4ToolParser, DeepSeekV31ToolParser, DeepSeekV32ToolParser,
|
||||
Gemma4ToolParser, Glm45MoeToolParser, Glm47MoeToolParser, HermesToolParser, HyV3ToolParser,
|
||||
KimiK2ToolParser, Llama3JsonToolParser, MinimaxM2ToolParser, MinimaxM3ToolParser,
|
||||
MistralToolParser, Qwen3CoderToolParser, Qwen3XmlToolParser, Tool, ToolCallDelta, ToolParser,
|
||||
ToolParserOutput,
|
||||
};
|
||||
|
||||
#[pyclass(name = "Tool", module = "vllm._rust_tool_parser", skip_from_py_object)]
|
||||
#[derive(Clone)]
|
||||
struct PyTool {
|
||||
inner: Tool,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyTool {
|
||||
#[new]
|
||||
#[pyo3(signature = (name, description, parameters, strict=None))]
|
||||
fn new(
|
||||
name: String,
|
||||
description: Option<String>,
|
||||
parameters: &Bound<'_, PyAny>,
|
||||
strict: Option<bool>,
|
||||
) -> PyResult<Self> {
|
||||
let parameters = depythonize::<Value>(parameters).map_err(|error| {
|
||||
PyValueError::new_err(format!(
|
||||
"failed to convert tool parameters from Python to JSON: {error}"
|
||||
))
|
||||
})?;
|
||||
Ok(Self {
|
||||
inner: Tool {
|
||||
name,
|
||||
description,
|
||||
parameters,
|
||||
strict,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn name(&self) -> &str {
|
||||
&self.inner.name
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn description(&self) -> Option<&str> {
|
||||
self.inner.description.as_deref()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn parameters(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
|
||||
pythonize(py, &self.inner.parameters).map(Bound::unbind).map_err(|error| {
|
||||
PyValueError::new_err(format!(
|
||||
"failed to convert tool parameters from JSON to Python: {error}"
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn strict(&self) -> Option<bool> {
|
||||
self.inner.strict
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(
|
||||
name = "ToolCallDelta",
|
||||
module = "vllm._rust_tool_parser",
|
||||
skip_from_py_object
|
||||
)]
|
||||
#[derive(Clone)]
|
||||
struct PyToolCallDelta {
|
||||
inner: ToolCallDelta,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyToolCallDelta {
|
||||
#[new]
|
||||
#[pyo3(signature = (tool_index, name, arguments))]
|
||||
fn new(tool_index: usize, name: Option<String>, arguments: String) -> Self {
|
||||
Self {
|
||||
inner: ToolCallDelta {
|
||||
tool_index,
|
||||
name,
|
||||
arguments,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn tool_index(&self) -> usize {
|
||||
self.inner.tool_index
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn name(&self) -> Option<&str> {
|
||||
self.inner.name.as_deref()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn arguments(&self) -> &str {
|
||||
&self.inner.arguments
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(
|
||||
name = "ToolParserOutput",
|
||||
module = "vllm._rust_tool_parser",
|
||||
skip_from_py_object
|
||||
)]
|
||||
#[derive(Clone)]
|
||||
struct PyToolParserOutput {
|
||||
inner: ToolParserOutput,
|
||||
}
|
||||
|
||||
impl PyToolParserOutput {
|
||||
fn from_inner(inner: ToolParserOutput) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyToolParserOutput {
|
||||
#[new]
|
||||
#[pyo3(signature = (normal_text="", calls=None))]
|
||||
fn new(py: Python<'_>, normal_text: &str, calls: Option<Vec<Py<PyToolCallDelta>>>) -> Self {
|
||||
let calls = calls
|
||||
.unwrap_or_default()
|
||||
.iter()
|
||||
.map(|call| call.borrow(py).inner.clone())
|
||||
.collect();
|
||||
Self {
|
||||
inner: ToolParserOutput {
|
||||
normal_text: normal_text.to_owned(),
|
||||
calls,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn normal_text(&self) -> &str {
|
||||
&self.inner.normal_text
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn calls(&self) -> Vec<PyToolCallDelta> {
|
||||
self.inner
|
||||
.calls
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|inner| PyToolCallDelta { inner })
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn append(&mut self, other: PyRef<'_, PyToolParserOutput>) {
|
||||
self.inner.append(other.inner.clone());
|
||||
}
|
||||
|
||||
fn coalesce_calls(&self) -> Self {
|
||||
Self::from_inner(self.inner.clone().coalesce_calls())
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(name = "ToolParser", module = "vllm._rust_tool_parser", unsendable)]
|
||||
struct PyToolParser {
|
||||
parser: Box<dyn ToolParser>,
|
||||
}
|
||||
|
||||
impl PyToolParser {
|
||||
fn from_name(name: &str, tools: &[Tool]) -> PyResult<Self> {
|
||||
let parser = match name {
|
||||
"deepseek_v3" => DeepSeekV3ToolParser::create(tools),
|
||||
"deepseek_v31" => DeepSeekV31ToolParser::create(tools),
|
||||
"deepseek_v32" => DeepSeekV32ToolParser::create(tools),
|
||||
"deepseek_v4" => DeepSeekV4ToolParser::create(tools),
|
||||
"gemma4" => Gemma4ToolParser::create(tools),
|
||||
"glm45" => Glm45MoeToolParser::create(tools),
|
||||
"glm47" => Glm47MoeToolParser::create(tools),
|
||||
"hermes" => HermesToolParser::create(tools),
|
||||
"hy_v3" => HyV3ToolParser::create(tools),
|
||||
"kimi_k2" => KimiK2ToolParser::create(tools),
|
||||
"llama3_json" | "llama4_json" => Llama3JsonToolParser::create(tools),
|
||||
"minimax_m2" => MinimaxM2ToolParser::create(tools),
|
||||
"minimax_m3" => MinimaxM3ToolParser::create(tools),
|
||||
"mistral" => MistralToolParser::create(tools),
|
||||
"qwen3_xml" => Qwen3XmlToolParser::create(tools),
|
||||
"qwen3_coder" => Qwen3CoderToolParser::create(tools),
|
||||
_ => {
|
||||
return Err(PyValueError::new_err(format!(
|
||||
"unsupported tool parser `{name}`"
|
||||
)));
|
||||
}
|
||||
}
|
||||
.map_err(|error| PyValueError::new_err(error.to_report_string()))?;
|
||||
|
||||
Ok(Self { parser })
|
||||
}
|
||||
|
||||
fn parse_into_output(&mut self, chunk: &str, output: &mut PyToolParserOutput) -> PyResult<()> {
|
||||
self.parser
|
||||
.parse_into(chunk, &mut output.inner)
|
||||
.map_err(|error| PyValueError::new_err(error.to_report_string()))
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyToolParser {
|
||||
#[new]
|
||||
fn new(py: Python<'_>, parser_name: &str, tools: Vec<Py<PyTool>>) -> PyResult<Self> {
|
||||
let tools = tools.iter().map(|tool| tool.borrow(py).inner.clone()).collect::<Vec<_>>();
|
||||
Self::from_name(parser_name, &tools)
|
||||
}
|
||||
|
||||
fn parse_into(
|
||||
&mut self,
|
||||
chunk: &str,
|
||||
mut output: PyRefMut<'_, PyToolParserOutput>,
|
||||
) -> PyResult<()> {
|
||||
self.parse_into_output(chunk, &mut output)
|
||||
}
|
||||
|
||||
fn finish(&mut self) -> PyResult<PyToolParserOutput> {
|
||||
self.parser
|
||||
.finish()
|
||||
.map(PyToolParserOutput::from_inner)
|
||||
.map_err(|error| PyValueError::new_err(error.to_report_string()))
|
||||
}
|
||||
|
||||
fn reset(&mut self) -> String {
|
||||
self.parser.reset()
|
||||
}
|
||||
|
||||
fn preserve_special_tokens(&self) -> bool {
|
||||
self.parser.preserve_special_tokens()
|
||||
}
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn _rust_tool_parser(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyTool>()?;
|
||||
m.add_class::<PyToolCallDelta>()?;
|
||||
m.add_class::<PyToolParserOutput>()?;
|
||||
m.add_class::<PyToolParser>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
const NS: &str = "]<]minimax[>[";
|
||||
|
||||
fn with_python<R>(f: impl for<'py> FnOnce(Python<'py>) -> R) -> R {
|
||||
Python::initialize();
|
||||
Python::attach(f)
|
||||
}
|
||||
|
||||
fn tool_schema() -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user_id": {"type": "integer"},
|
||||
"shipping": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"zip": {"type": "integer"}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn build_call() -> String {
|
||||
format!(
|
||||
"{NS}<tool_call>\n\
|
||||
{NS}<invoke name=\"create_order\">\
|
||||
{NS}<user_id>42{NS}</user_id>\
|
||||
{NS}<shipping>\
|
||||
{NS}<city>Singapore{NS}</city>\
|
||||
{NS}<zip>018956{NS}</zip>\
|
||||
{NS}</shipping>\
|
||||
{NS}</invoke>\n\
|
||||
{NS}</tool_call>"
|
||||
)
|
||||
}
|
||||
|
||||
fn make_py_tool(py: Python<'_>) -> PyResult<Py<PyTool>> {
|
||||
let parameters = pythonize(py, &tool_schema()).map_err(|error| {
|
||||
PyValueError::new_err(format!(
|
||||
"failed to convert test schema from JSON to Python: {error}"
|
||||
))
|
||||
})?;
|
||||
Py::new(
|
||||
py,
|
||||
PyTool::new(
|
||||
"create_order".to_owned(),
|
||||
Some("Create an order".to_owned()),
|
||||
¶meters,
|
||||
None,
|
||||
)?,
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_round_trips_typed_fields() {
|
||||
with_python(|py| {
|
||||
let tool = make_py_tool(py)?;
|
||||
let borrowed = tool.borrow(py);
|
||||
assert_eq!(borrowed.name(), "create_order");
|
||||
assert_eq!(borrowed.description(), Some("Create an order"));
|
||||
assert_eq!(borrowed.strict(), None);
|
||||
|
||||
let parameters = borrowed.parameters(py)?;
|
||||
let parameters = depythonize::<Value>(parameters.bind(py))?;
|
||||
assert_eq!(parameters, tool_schema());
|
||||
PyResult::Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_append_and_coalesce_calls() {
|
||||
with_python(|py| {
|
||||
let first = Py::new(
|
||||
py,
|
||||
PyToolCallDelta::new(0, Some("create_order".to_owned()), "{\"a\"".to_owned()),
|
||||
)?;
|
||||
let second = Py::new(py, PyToolCallDelta::new(0, None, ":1}".to_owned()))?;
|
||||
let mut output = PyToolParserOutput::new(py, "text", Some(vec![first]));
|
||||
let other = Py::new(py, PyToolParserOutput::new(py, "", Some(vec![second])))?;
|
||||
output.append(other.borrow(py));
|
||||
|
||||
let coalesced = output.coalesce_calls();
|
||||
assert_eq!(coalesced.normal_text(), "text");
|
||||
let calls = coalesced.calls();
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].tool_index(), 0);
|
||||
assert_eq!(calls[0].name(), Some("create_order"));
|
||||
assert_eq!(calls[0].arguments(), "{\"a\":1}");
|
||||
PyResult::Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parser_parse_finish_and_preserve_special_tokens() {
|
||||
with_python(|py| {
|
||||
let tool = make_py_tool(py)?;
|
||||
let mut parser = PyToolParser::new(py, "minimax_m3", vec![tool])?;
|
||||
assert!(!parser.preserve_special_tokens());
|
||||
|
||||
let mut output = PyToolParserOutput::new(py, "", None);
|
||||
parser.parse_into_output(&build_call(), &mut output)?;
|
||||
let finish = Py::new(py, parser.finish()?)?;
|
||||
output.append(finish.borrow(py));
|
||||
let output = output.coalesce_calls();
|
||||
|
||||
assert_eq!(output.normal_text(), "");
|
||||
let calls = output.calls();
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].name(), Some("create_order"));
|
||||
assert_eq!(
|
||||
serde_json::from_str::<Value>(calls[0].arguments()).unwrap(),
|
||||
json!({
|
||||
"user_id": 42,
|
||||
"shipping": {
|
||||
"city": "Singapore",
|
||||
"zip": 18956
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
assert_eq!(parser.reset(), "");
|
||||
PyResult::Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parser_errors_for_unknown_name() {
|
||||
with_python(|py| {
|
||||
let tool = make_py_tool(py)?;
|
||||
let error = match PyToolParser::new(py, "missing", vec![tool]) {
|
||||
Ok(_) => panic!("missing parser name unexpectedly succeeded"),
|
||||
Err(error) => error,
|
||||
};
|
||||
let message = format!("{error}");
|
||||
assert!(message.contains("unsupported tool parser `missing`"));
|
||||
PyResult::Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -36,6 +36,8 @@ ROOT_DIR = Path(__file__).parent
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRECOMPILED_RUST_FRONTEND_PATH = ROOT_DIR / "vllm" / "vllm-rs"
|
||||
PRECOMPILED_RUST_EXTENSION_GLOB = "_rust_*.so"
|
||||
PRECOMPILED_RUST_EXTENSION_MEMBER_REGEX = re.compile(r"vllm/_rust_[^/]*\.so$")
|
||||
|
||||
# cannot import envs directly because it depends on vllm,
|
||||
# which is not installed yet
|
||||
@@ -54,6 +56,59 @@ def should_require_rust_frontend() -> bool:
|
||||
return value.lower() not in ("", "0", "false", "no")
|
||||
|
||||
|
||||
# Rust frontend binary, built via setuptools-rust and installed into the
|
||||
# package directory alongside the Python modules.
|
||||
# TODO: we may use `RustBin` to directly install it into `bin` directory, but this
|
||||
# requires extra work on using precompiled binaries.
|
||||
rust_extensions = [
|
||||
RustExtension(
|
||||
target="vllm.vllm-rs",
|
||||
path="rust/src/cmd/Cargo.toml",
|
||||
args=["--bin", "vllm-rs"],
|
||||
features=["native-tls-vendored"],
|
||||
binding=Binding.Exec,
|
||||
optional=not should_require_rust_frontend(),
|
||||
),
|
||||
RustExtension(
|
||||
target="vllm._rust_tool_parser",
|
||||
path="rust/src/tool-parser/python/Cargo.toml",
|
||||
features=["extension-module"],
|
||||
binding=Binding.PyO3,
|
||||
optional=not should_require_rust_frontend(),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_precompiled_rust_extension_paths() -> list[Path]:
|
||||
return sorted((ROOT_DIR / "vllm").glob(PRECOMPILED_RUST_EXTENSION_GLOB))
|
||||
|
||||
|
||||
def get_expected_rust_extension_module_names() -> list[str]:
|
||||
"""Return configured PyO3 Rust extension module names under ``vllm``."""
|
||||
module_names = []
|
||||
for rust_extension in rust_extensions:
|
||||
if rust_extension.binding != Binding.PyO3:
|
||||
continue
|
||||
|
||||
for target_name in rust_extension.target.values():
|
||||
if target_name.startswith("vllm._rust_"):
|
||||
module_names.append(target_name.rsplit(".", 1)[-1])
|
||||
|
||||
return module_names
|
||||
|
||||
|
||||
def get_missing_precompiled_rust_extension_modules() -> list[str]:
|
||||
missing = []
|
||||
for module_name in get_expected_rust_extension_module_names():
|
||||
if not list((ROOT_DIR / "vllm").glob(f"{module_name}*.so")):
|
||||
missing.append(module_name)
|
||||
return missing
|
||||
|
||||
|
||||
def has_precompiled_rust_extensions() -> bool:
|
||||
return not get_missing_precompiled_rust_extension_modules()
|
||||
|
||||
|
||||
if sys.platform.startswith("darwin") and VLLM_TARGET_DEVICE != "cpu":
|
||||
logger.warning("VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS")
|
||||
VLLM_TARGET_DEVICE = "cpu"
|
||||
@@ -421,19 +476,33 @@ class precompiled_build_ext(build_ext):
|
||||
|
||||
|
||||
class precompiled_build_rust(build_rust):
|
||||
"""Skips local Rust builds when the precompiled wheel already ships vllm-rs."""
|
||||
"""Skips local Rust builds when all precompiled Rust artifacts are present."""
|
||||
|
||||
def run(self) -> None:
|
||||
if PRECOMPILED_RUST_FRONTEND_PATH.exists():
|
||||
if (
|
||||
PRECOMPILED_RUST_FRONTEND_PATH.exists()
|
||||
and has_precompiled_rust_extensions()
|
||||
):
|
||||
logger.info(
|
||||
"Skipping local Rust build: using precompiled %s",
|
||||
"Skipping local Rust build: using precompiled %s and %s",
|
||||
PRECOMPILED_RUST_FRONTEND_PATH,
|
||||
get_precompiled_rust_extension_paths(),
|
||||
)
|
||||
return
|
||||
|
||||
missing = []
|
||||
if not PRECOMPILED_RUST_FRONTEND_PATH.exists():
|
||||
missing.append(str(PRECOMPILED_RUST_FRONTEND_PATH))
|
||||
missing_rust_extensions = get_missing_precompiled_rust_extension_modules()
|
||||
if missing_rust_extensions:
|
||||
missing.extend(
|
||||
str(ROOT_DIR / "vllm" / f"{module_name}*.so")
|
||||
for module_name in missing_rust_extensions
|
||||
)
|
||||
logger.warning(
|
||||
"Precompiled wheel did not provide %s; falling back to local Rust build.",
|
||||
PRECOMPILED_RUST_FRONTEND_PATH,
|
||||
"Precompiled wheel did not provide all Rust artifacts (%s); "
|
||||
"falling back to local Rust build.",
|
||||
", ".join(missing),
|
||||
)
|
||||
super().run()
|
||||
|
||||
@@ -756,6 +825,14 @@ class precompiled_wheel_utils:
|
||||
if member.filename in exact_members:
|
||||
file_members.append(member)
|
||||
continue
|
||||
if (
|
||||
extract_rust_frontend
|
||||
and PRECOMPILED_RUST_EXTENSION_MEMBER_REGEX.match(
|
||||
member.filename
|
||||
)
|
||||
):
|
||||
file_members.append(member)
|
||||
continue
|
||||
|
||||
if not extract_extensions:
|
||||
continue
|
||||
@@ -1127,6 +1204,10 @@ if PRECOMPILED_RUST_FRONTEND_PATH.exists():
|
||||
vllm_files = package_data.setdefault("vllm", [])
|
||||
if "vllm-rs" not in vllm_files:
|
||||
vllm_files.append("vllm-rs")
|
||||
vllm_files = package_data.setdefault("vllm", [])
|
||||
for rust_extension_path in get_precompiled_rust_extension_paths():
|
||||
if rust_extension_path.name not in vllm_files:
|
||||
vllm_files.append(rust_extension_path.name)
|
||||
|
||||
if _no_device():
|
||||
ext_modules = []
|
||||
@@ -1139,24 +1220,13 @@ else:
|
||||
if USE_PRECOMPILED_EXTENSIONS
|
||||
else cmake_build_ext,
|
||||
}
|
||||
if USE_PRECOMPILED_RUST_FRONTEND or PRECOMPILED_RUST_FRONTEND_PATH.exists():
|
||||
if (
|
||||
USE_PRECOMPILED_RUST_FRONTEND
|
||||
or PRECOMPILED_RUST_FRONTEND_PATH.exists()
|
||||
or has_precompiled_rust_extensions()
|
||||
):
|
||||
cmdclass["build_rust"] = precompiled_build_rust
|
||||
|
||||
# Rust frontend binary, built via setuptools-rust and installed into the
|
||||
# package directory alongside the Python modules.
|
||||
# TODO: we may use `RustBin` to directly install it into `bin` directory, but this
|
||||
# requires extra work on using precompiled binaries.
|
||||
rust_extensions = [
|
||||
RustExtension(
|
||||
target="vllm.vllm-rs",
|
||||
path="rust/src/cmd/Cargo.toml",
|
||||
args=["--bin", "vllm-rs"],
|
||||
features=["native-tls-vendored"],
|
||||
binding=Binding.Exec,
|
||||
optional=not should_require_rust_frontend(),
|
||||
),
|
||||
]
|
||||
|
||||
setup(
|
||||
# static metadata should rather go in pyproject.toml
|
||||
version=get_vllm_version(),
|
||||
|
||||
@@ -257,5 +257,5 @@ def test_streaming_nested_tool_call(parser):
|
||||
assert json.loads(tool_calls[0]["arguments"]) == json.loads(
|
||||
parser.streamed_args_for_tool[0]
|
||||
)
|
||||
assert parser.prev_tool_call_arr[0]["arguments"]["items"][1]["qty"] == 5
|
||||
assert json.loads(parser.prev_tool_call_arr[0]["arguments"])["items"][1]["qty"] == 5
|
||||
assert results[-1].content == ""
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# mypy: ignore-errors
|
||||
@@ -1,73 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# mypy: ignore-errors
|
||||
# ruff: noqa
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, TypeVar
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
ParserInputs = TypeVar("ParserInputs", bound=Sequence)
|
||||
ParserOutput = TypeVar("ParserOutput")
|
||||
TextFn = Callable[[], str]
|
||||
FunctionCallDict = Dict[Literal["name", "arguments"], str]
|
||||
|
||||
|
||||
class Parser(ABC, Generic[ParserInputs, ParserOutput]):
|
||||
@abstractmethod
|
||||
def update(self, inputs: ParserInputs):
|
||||
pass
|
||||
|
||||
def parse(self, inputs: ParserInputs) -> ParserOutput:
|
||||
self.update(inputs)
|
||||
final = self.get_final()
|
||||
assert final is not None
|
||||
return final
|
||||
|
||||
@abstractmethod
|
||||
def get_delta(self) -> Optional[ParserOutput]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_final(self) -> Optional[ParserOutput]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
def _get_sub_parsers(self) -> Optional[Dict[str, Self]]:
|
||||
"""should only be used for debugging."""
|
||||
pass
|
||||
|
||||
def stringify_function_calls(self, function_calls: List[FunctionCallDict]) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PatternMismatched(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
offset: int,
|
||||
expected: Any,
|
||||
actual: Any,
|
||||
reason: str,
|
||||
context_fn: Optional[TextFn] = None,
|
||||
):
|
||||
self.offset = offset
|
||||
self.expected = expected
|
||||
self.actual = actual
|
||||
self.reason = reason
|
||||
self.context_fn = context_fn
|
||||
|
||||
def __str__(self):
|
||||
main_message = f"at offset {self.offset}, {self.reason}, expected {self.expected!r} but got {self.actual!r}"
|
||||
if self.context_fn is None:
|
||||
return main_message
|
||||
else:
|
||||
return f"around context:\n{self.context_fn()}\n\n{main_message}"
|
||||
@@ -1,340 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# mypy: ignore-errors
|
||||
# ruff: noqa
|
||||
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Iterator, List, Optional, Set
|
||||
|
||||
import jsonschema
|
||||
import jsonschema.exceptions
|
||||
from typing_extensions import Self
|
||||
|
||||
# fmt: off
|
||||
NULL_STRINGS = (
|
||||
"null", "Null", "NULL",
|
||||
# It's bad, especially when we need enum string `none` etc.
|
||||
# "none", "null", "nil",
|
||||
# "None", "Null", "Nil",
|
||||
# "NONE", "NULL", "NIL",
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ConversionResult:
|
||||
value: Any
|
||||
confidence: int
|
||||
|
||||
|
||||
class AtomDataType(int, Enum):
|
||||
none = auto()
|
||||
string = auto()
|
||||
integer = auto()
|
||||
float = auto()
|
||||
boolean = auto()
|
||||
# complex types:
|
||||
array = auto()
|
||||
object = auto()
|
||||
other = auto()
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, string: Optional[str]) -> Self:
|
||||
if string is None:
|
||||
return cls.string
|
||||
string = string.lower()
|
||||
if string == "none" or string == "null" or string == "nil":
|
||||
return cls.none
|
||||
elif string == "str" or string == "string" or string == "text":
|
||||
return cls.string
|
||||
elif string == "int" or string == "integer":
|
||||
return cls.integer
|
||||
elif string == "float" or string == "number":
|
||||
return cls.float
|
||||
elif string == "bool" or string == "boolean":
|
||||
return cls.boolean
|
||||
elif string == "array" or string == "list":
|
||||
return cls.array
|
||||
elif (
|
||||
string == "object"
|
||||
or string == "dict"
|
||||
or string == "dictionary"
|
||||
or string == "map"
|
||||
):
|
||||
return cls.object
|
||||
else:
|
||||
return cls.other
|
||||
|
||||
def is_complex_type(self) -> bool:
|
||||
return self == self.array or self == self.object or self == self.other
|
||||
|
||||
@classmethod
|
||||
def from_example(cls, example: Any) -> Self:
|
||||
if example is None:
|
||||
return cls.none
|
||||
elif isinstance(example, str):
|
||||
return cls.string
|
||||
# NOTE: Order matters. True is instance of int.
|
||||
elif isinstance(example, bool):
|
||||
return cls.boolean
|
||||
elif isinstance(example, int):
|
||||
return cls.integer
|
||||
elif isinstance(example, float):
|
||||
return cls.float
|
||||
elif isinstance(example, list):
|
||||
return cls.array
|
||||
elif isinstance(example, dict):
|
||||
return cls.object
|
||||
else:
|
||||
return cls.other
|
||||
|
||||
@classmethod
|
||||
def iter_candidates_from_schema(cls, input: Any) -> Iterator[Self]:
|
||||
if isinstance(input, dict):
|
||||
if "type" in input:
|
||||
type_value = input["type"]
|
||||
if isinstance(type_value, str):
|
||||
yield cls.from_string(type_value)
|
||||
elif isinstance(type_value, list):
|
||||
for each in type_value:
|
||||
yield cls.from_string(each)
|
||||
else:
|
||||
yield cls.other
|
||||
elif ("properties" in input and isinstance(input["properties"], dict)) or (
|
||||
"additionalProperties" in input
|
||||
and input["additionalProperties"] is not False
|
||||
):
|
||||
yield cls.object
|
||||
elif "items" in input or "prefixItems" in input:
|
||||
yield cls.array
|
||||
else:
|
||||
nothing_found = True
|
||||
if "const" in input:
|
||||
nothing_found = False
|
||||
yield cls.from_example(input["const"])
|
||||
elif (
|
||||
"enum" in input
|
||||
and isinstance(input["enum"], list)
|
||||
and input["enum"]
|
||||
):
|
||||
for each in input["enum"]:
|
||||
nothing_found = False
|
||||
yield cls.from_example(each)
|
||||
else:
|
||||
for choice_field in ("anyOf", "oneOf", "allOf"):
|
||||
if choice_field in input:
|
||||
choices = input[choice_field]
|
||||
if isinstance(choices, list):
|
||||
for choice in choices:
|
||||
for each in cls.iter_candidates_from_schema(choice):
|
||||
nothing_found = False
|
||||
yield each
|
||||
if nothing_found:
|
||||
yield cls.string
|
||||
else:
|
||||
yield cls.string
|
||||
|
||||
def convert_with_confidence(self, string: str) -> Optional[ConversionResult]:
|
||||
if self == AtomDataType.none:
|
||||
stripped = string.strip()
|
||||
if stripped:
|
||||
if len(stripped) == 4 and stripped.lower() == "null":
|
||||
return ConversionResult(value=None, confidence=10)
|
||||
return ConversionResult(value=None, confidence=1)
|
||||
else:
|
||||
return ConversionResult(value=None, confidence=3)
|
||||
elif self == AtomDataType.string:
|
||||
return ConversionResult(value=string, confidence=2)
|
||||
elif self == AtomDataType.integer:
|
||||
try:
|
||||
return ConversionResult(value=int(string), confidence=10)
|
||||
except ValueError:
|
||||
return None
|
||||
elif self == AtomDataType.float:
|
||||
try:
|
||||
value = float(string)
|
||||
except ValueError:
|
||||
return None
|
||||
# NOTE: Python is evil.
|
||||
# In Python, we have
|
||||
# json.dumps(math.nan) -> "NaN"
|
||||
# json.dumps(math.inf) -> "Infinity"
|
||||
# json.dumps(-math.inf) -> "-Infinity"
|
||||
# However, in other languages, you cannot deserialize them back.
|
||||
# No perfect solution here.
|
||||
if math.isnan(value):
|
||||
value = "NaN"
|
||||
elif math.isinf(value):
|
||||
value = "Infinity" if value > 0 else "-Infinity"
|
||||
elif value.is_integer():
|
||||
value = int(value)
|
||||
return ConversionResult(value=value, confidence=10)
|
||||
elif self == AtomDataType.boolean:
|
||||
stripped = string.strip()
|
||||
if len(stripped) > 5:
|
||||
return ConversionResult(value=False, confidence=1)
|
||||
lower = stripped.lower()
|
||||
if lower == "true":
|
||||
return ConversionResult(value=True, confidence=10)
|
||||
elif lower == "false":
|
||||
return ConversionResult(value=False, confidence=10)
|
||||
elif lower == "yes" or lower == "on" or lower == "1":
|
||||
return ConversionResult(value=True, confidence=9)
|
||||
elif lower == "no" or lower == "off" or lower == "0":
|
||||
return ConversionResult(value=False, confidence=9)
|
||||
else:
|
||||
return ConversionResult(value=False, confidence=1)
|
||||
else:
|
||||
try:
|
||||
return ConversionResult(value=json.loads(string), confidence=10)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class FunctionCallParameterDataType:
|
||||
candidates: List[AtomDataType]
|
||||
schema: Any
|
||||
streaming: bool
|
||||
|
||||
@classmethod
|
||||
def get_schema_of_parameter(cls, schema: Any, parameter_name: str) -> Self:
|
||||
"""
|
||||
$ref etc. not supported.
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return cls(candidates=[AtomDataType.string], schema=None, streaming=True)
|
||||
property = None
|
||||
parameters = schema.get("parameters")
|
||||
if isinstance(parameters, dict):
|
||||
properties = parameters.get("properties")
|
||||
if isinstance(properties, dict):
|
||||
property = properties.get(parameter_name)
|
||||
return cls.from_property(property)
|
||||
|
||||
@classmethod
|
||||
def from_property(cls, property: Any) -> Self:
|
||||
candidate_set: Set[AtomDataType] = set()
|
||||
candidates: List[AtomDataType] = []
|
||||
# Guaranteed to be non-empty.
|
||||
for each in AtomDataType.iter_candidates_from_schema(property):
|
||||
if each not in candidate_set:
|
||||
candidate_set.add(each)
|
||||
candidates.append(each)
|
||||
# The parameter is streaming if and only if there is only one candidate, string.
|
||||
# No matter how complicated the schema is, if it can only be a string,
|
||||
streaming = len(candidates) == 1 and candidates[0] == AtomDataType.string
|
||||
return cls(candidates=candidates, schema=property, streaming=streaming)
|
||||
|
||||
def get_data_type_of_property(self, key: str) -> Optional[Self]:
|
||||
if isinstance(self.schema, dict):
|
||||
property = self.get_property(self.schema, key)
|
||||
if property:
|
||||
return self.from_property(property)
|
||||
property_choices = []
|
||||
for choice_field in ("anyOf", "oneOf", "allOf"):
|
||||
if choice_field in self.schema:
|
||||
choices = self.schema[choice_field]
|
||||
if isinstance(choices, list):
|
||||
for choice in choices:
|
||||
property = self.get_property(choice, key)
|
||||
if property:
|
||||
property_choices.append(property)
|
||||
if property_choices:
|
||||
return self.from_property(
|
||||
{
|
||||
"anyOf": property_choices,
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_property(schema: Any, key: str) -> Any:
|
||||
if "properties" in schema:
|
||||
properties = schema["properties"]
|
||||
if isinstance(properties, dict):
|
||||
property = properties.get(key)
|
||||
if property:
|
||||
return property
|
||||
if "additionalProperties" in schema:
|
||||
additional_properties = schema["additionalProperties"]
|
||||
if isinstance(additional_properties, dict):
|
||||
return additional_properties
|
||||
|
||||
def get_data_type_of_item(self, *, index: int = 0) -> Optional[Self]:
|
||||
if isinstance(self.schema, dict):
|
||||
item = self.get_item(self.schema, index=index)
|
||||
if item:
|
||||
return self.from_property(item)
|
||||
item_choices = []
|
||||
for choice_field in ("anyOf", "oneOf", "allOf"):
|
||||
if choice_field in self.schema:
|
||||
choices = self.schema[choice_field]
|
||||
if isinstance(choices, list):
|
||||
for choice in choices:
|
||||
item = self.get_item(choice, index=index)
|
||||
if item:
|
||||
item_choices.append(item)
|
||||
if item_choices:
|
||||
return self.from_property(
|
||||
{
|
||||
"anyOf": item_choices,
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_item(schema: Any, *, index: int = 0) -> Any:
|
||||
items = schema.get("items")
|
||||
if items:
|
||||
return items
|
||||
if "prefixItems" in schema:
|
||||
prefix_items = schema["prefixItems"]
|
||||
additional_items = schema.get("additionalItems")
|
||||
if isinstance(prefix_items, list):
|
||||
if index < len(prefix_items):
|
||||
return prefix_items[index]
|
||||
elif additional_items:
|
||||
return additional_items
|
||||
elif len(prefix_items):
|
||||
return prefix_items[-1]
|
||||
|
||||
def convert(self, string: str, *, always_nullable: bool = False) -> Any:
|
||||
if always_nullable:
|
||||
stripped = string.strip()
|
||||
if len(stripped) == 4 and stripped.lower() == "null":
|
||||
return None
|
||||
if not string:
|
||||
if AtomDataType.object in self.candidates:
|
||||
return {}
|
||||
elif AtomDataType.array in self.candidates:
|
||||
return []
|
||||
elif AtomDataType.none in self.candidates:
|
||||
return None
|
||||
else:
|
||||
return ""
|
||||
converted_list: List[ConversionResult] = []
|
||||
has_schema = bool(self.schema)
|
||||
for candidate in self.candidates:
|
||||
converted = candidate.convert_with_confidence(string)
|
||||
if converted is None:
|
||||
continue
|
||||
if has_schema:
|
||||
try:
|
||||
jsonschema.validate(schema=self.schema, instance=converted.value)
|
||||
# Ensure it beats the invalid ones.
|
||||
converted.confidence += 10
|
||||
except jsonschema.exceptions.SchemaError:
|
||||
# The schema is invalid, ignore it.
|
||||
has_schema = False
|
||||
except Exception:
|
||||
# Validation failed.
|
||||
pass
|
||||
converted_list.append(converted)
|
||||
if len(converted_list) == 1:
|
||||
return converted_list[0].value
|
||||
elif len(converted_list):
|
||||
return max(converted_list, key=lambda x: x.confidence).value
|
||||
else:
|
||||
return string
|
||||
@@ -1,9 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# mypy: ignore-errors
|
||||
# ruff: noqa
|
||||
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
IsPendingAndText = Tuple[bool, str]
|
||||
CountConsumedTokensFn = Callable[[Optional[IsPendingAndText]], int]
|
||||
@@ -1,463 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# mypy: ignore-errors
|
||||
# ruff: noqa
|
||||
|
||||
import functools
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Deque,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from .data_type import NULL_STRINGS, FunctionCallParameterDataType
|
||||
from .base import Parser, ParserOutput, PatternMismatched
|
||||
from .decoder import CountConsumedTokensFn
|
||||
|
||||
OutputKey = Union[str, Callable[[Any], Dict]]
|
||||
|
||||
|
||||
class GeneratorParser(Parser[str, dict], ABC):
|
||||
"""
|
||||
TODO:
|
||||
1. 改成接 StrOrSpecialToken.
|
||||
2. 报错时支持 row:col.
|
||||
3. 需要 self.generator.close() 吗?
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _process(self) -> Generator[ParserOutput, None, None]:
|
||||
pass
|
||||
|
||||
def __init__(self):
|
||||
self._count_consumed_tokens_fn: Optional[CountConsumedTokensFn] = None
|
||||
self._full_text: List[str] = []
|
||||
self._input_length = 0
|
||||
self._input_buffer: Deque[str] = deque()
|
||||
self._delta: Optional[Dict[str, Any]] = None
|
||||
self._final: Optional[Dict[str, Any]] = None
|
||||
self._generator = self._process()
|
||||
next(self._generator)
|
||||
|
||||
def __del__(self):
|
||||
# Necessary when GC is disabled (gc.disable()),
|
||||
# otherwise the cyclic reference will never be collected.
|
||||
if hasattr(self, "_generator"):
|
||||
self._generator.close()
|
||||
del self._generator
|
||||
|
||||
def reset(self):
|
||||
self._full_text.clear()
|
||||
self._input_length = 0
|
||||
self._input_buffer.clear()
|
||||
self._delta = None
|
||||
self._final = None
|
||||
# don't worry, a closed generator can be closed again.
|
||||
self._generator.close()
|
||||
self._generator = self._process()
|
||||
next(self._generator)
|
||||
|
||||
def update(self, input: str):
|
||||
self._full_text.append(input)
|
||||
total = len(input)
|
||||
self._input_length += total
|
||||
self._input_buffer.extend(input)
|
||||
# at most #input times,
|
||||
# but early stop if the input buffer is already empty.
|
||||
for _ in range(total):
|
||||
if not self._input_buffer:
|
||||
break
|
||||
try:
|
||||
next(self._generator)
|
||||
except StopIteration:
|
||||
if self._input_buffer:
|
||||
raise self._error(
|
||||
expected=None,
|
||||
actual=self._input_buffer[0],
|
||||
reason="pattern exhausted",
|
||||
)
|
||||
|
||||
def _count_consumed_tokens(self) -> Optional[int]:
|
||||
if self._count_consumed_tokens_fn is None:
|
||||
return
|
||||
total_tokens_count = self._count_consumed_tokens_fn(None)
|
||||
if self._input_buffer:
|
||||
pending_text_count = len(self._input_buffer)
|
||||
consumed_text_count = self._input_length - pending_text_count
|
||||
# If total is 0, we cannot use `consumed = total - pending`.
|
||||
# Have to tokenize the consumed text, maybe again.
|
||||
is_pending = (total_tokens_count > 0) and (
|
||||
pending_text_count <= consumed_text_count
|
||||
)
|
||||
if is_pending:
|
||||
text = "".join(self._input_buffer)
|
||||
else:
|
||||
text = "".join(self._full_text)[:consumed_text_count]
|
||||
return self._count_consumed_tokens_fn((is_pending, text))
|
||||
elif total_tokens_count:
|
||||
return total_tokens_count
|
||||
else:
|
||||
# Only for non-streaming mode, should be called only once.
|
||||
# In streaming mode, we should feed in tokens instead of text,
|
||||
# thus `total_tokens_count` must be positive.
|
||||
text = "".join(self._full_text)
|
||||
return self._count_consumed_tokens_fn((False, text))
|
||||
|
||||
def get_delta(self) -> Optional[Dict[str, Any]]:
|
||||
if self._delta:
|
||||
delta = _get_dict_with_joint_strings(self._delta)
|
||||
self._delta = None
|
||||
return delta
|
||||
|
||||
def get_final(self) -> Optional[Dict[str, Any]]:
|
||||
return _get_dict_with_joint_strings(self._final)
|
||||
|
||||
def _peek(self, index: int) -> Generator[None, None, str]:
|
||||
while index >= len(self._input_buffer):
|
||||
yield
|
||||
return self._input_buffer[index]
|
||||
|
||||
def _read_one(self) -> str:
|
||||
# intentionally unchecked, will raise IndexError if the input buffer is empty.
|
||||
return self._input_buffer.popleft()
|
||||
|
||||
def _consume(self, length: int):
|
||||
for _ in range(length):
|
||||
self._input_buffer.popleft()
|
||||
|
||||
def _try(self, fn, *args, **kwargs):
|
||||
try:
|
||||
result = yield from fn(*args, **kwargs)
|
||||
return Tried(successful=True, result=result)
|
||||
except PatternMismatched as e:
|
||||
# Clear traceback to break traceback → frame → f_locals → self
|
||||
# circular reference chain. The traceback is not needed for
|
||||
# backtracking logic, only offset/expected/actual matter.
|
||||
e.__traceback__ = None
|
||||
return Tried(successful=False, error=e)
|
||||
|
||||
def _append(self, key: OutputKey, value: Any):
|
||||
if isinstance(key, str):
|
||||
delta = {key: value}
|
||||
else:
|
||||
delta = key(value)
|
||||
self._append_delta(delta)
|
||||
|
||||
def _append_delta(self, delta: Dict[str, Any]):
|
||||
self._delta = _merge_dicts(delta, self._delta)
|
||||
self._final = _merge_dicts(delta, self._final)
|
||||
|
||||
def _get_context(self, offset: int) -> str:
|
||||
# NOTE: `self._error` is frequently triggered with backtracking,
|
||||
# thus we cannot afford to join the full text every time.
|
||||
# A partial function should work, which is only called when the error is stringified.
|
||||
full_text = "".join(self._full_text)
|
||||
return f"{full_text[:offset]}💥{full_text[offset:]}"
|
||||
|
||||
def _error(self, *, expected: Any, actual: Any, reason: str) -> PatternMismatched:
|
||||
offset = self._input_length - len(self._input_buffer)
|
||||
# Use weakref to avoid context_fn → partial → self circular reference.
|
||||
# _full_text is a plain list, capturing it directly avoids holding self.
|
||||
full_text_ref = self._full_text
|
||||
return PatternMismatched(
|
||||
offset=offset,
|
||||
expected=expected,
|
||||
actual=actual,
|
||||
reason=reason,
|
||||
context_fn=functools.partial(
|
||||
_get_context_standalone, full_text_ref, offset
|
||||
),
|
||||
)
|
||||
|
||||
def _take_any(
|
||||
self,
|
||||
*,
|
||||
until: Optional[Union[str, Tuple[str, ...]]] = None,
|
||||
key: Optional[OutputKey] = None,
|
||||
should_consume_suffix: bool = True,
|
||||
) -> Generator[str, None, None]:
|
||||
if until is None:
|
||||
# Never ends without `until`, thus no need to collect the values.
|
||||
while True:
|
||||
yield from self._peek(0)
|
||||
self._append(key, self._read_one())
|
||||
else:
|
||||
values = []
|
||||
while True:
|
||||
tried = yield from self._literal(
|
||||
until,
|
||||
should_consume=should_consume_suffix,
|
||||
should_raise=False,
|
||||
)
|
||||
if tried is not None:
|
||||
break
|
||||
value = self._read_one()
|
||||
values.append(value)
|
||||
if key is not None:
|
||||
self._append(key, value)
|
||||
return "".join(values)
|
||||
|
||||
def _take_data_type_as_json(
|
||||
self,
|
||||
*,
|
||||
until: Union[str, Tuple[str, ...]],
|
||||
key: OutputKey,
|
||||
data_type: FunctionCallParameterDataType,
|
||||
always_nullable: bool = False,
|
||||
should_consume_suffix: bool = True,
|
||||
):
|
||||
"""
|
||||
Unlike `.take_any`, `until` and `key` is required here, because
|
||||
- it is unnecessary and a little bit annoying to maintain the streaming state
|
||||
for most data types other than `str`.
|
||||
- if you don't need the output, just use `.take_any`.
|
||||
|
||||
NOTE: if there are various acceptable data types, we choose the FIRST convertible
|
||||
one, indicating, if the first choice is `str`, the following choices will be
|
||||
IGNORED, as everything could be a string.
|
||||
|
||||
TODO: if there are acceptable data types before `str`, we should peek until they
|
||||
are excluded and then choose `str`, instead of choosing at the end.
|
||||
"""
|
||||
if not data_type.streaming:
|
||||
value = yield from self._take_any(
|
||||
until=until, should_consume_suffix=should_consume_suffix
|
||||
)
|
||||
value = data_type.convert(value, always_nullable=always_nullable)
|
||||
self._append(key, json_dumps(value))
|
||||
return value
|
||||
else:
|
||||
if always_nullable:
|
||||
# very tricky.
|
||||
# even if the data type is string, we still need to check if it's null.
|
||||
tried = yield from self._literal(
|
||||
_string_cartesian_product(NULL_STRINGS, until),
|
||||
should_consume=should_consume_suffix,
|
||||
should_raise=False,
|
||||
)
|
||||
if tried is not None:
|
||||
if not should_consume_suffix:
|
||||
# NOTE: length of "null" (case-insensitive) is 4.
|
||||
self._consume(4)
|
||||
self._append(key, "null")
|
||||
return None
|
||||
values = []
|
||||
self._append(key, '"')
|
||||
while True:
|
||||
tried = yield from self._literal(
|
||||
until, should_consume=should_consume_suffix, should_raise=False
|
||||
)
|
||||
if tried is not None:
|
||||
break
|
||||
value = self._read_one()
|
||||
values.append(value)
|
||||
self._append(key, json_dumps(value)[1:-1])
|
||||
self._append(key, '"')
|
||||
return "".join(values)
|
||||
|
||||
def _literal(
|
||||
self,
|
||||
target: Union[str, Tuple[str, ...]],
|
||||
*,
|
||||
should_raise: bool = True,
|
||||
should_consume: bool = True,
|
||||
) -> Generator[Optional[str], None, None]:
|
||||
"""
|
||||
NOTE: stops at the first matched target, even if it is a substring of another later target.
|
||||
"""
|
||||
if isinstance(target, str):
|
||||
target = (target,)
|
||||
char_index = 0
|
||||
target_index = 0
|
||||
target_index_max = len(target) - 1
|
||||
while True:
|
||||
ith = yield from self._peek(char_index)
|
||||
if ith != target[target_index][char_index]:
|
||||
already_matched_and_ith = target[target_index][:char_index] + ith
|
||||
prefix_matching_result = PrefixMatchingResult.mismatch
|
||||
target_index += 1
|
||||
while target_index <= target_index_max:
|
||||
prefix_matching_result = PrefixMatchingResult.check(
|
||||
target[target_index], already_matched_and_ith
|
||||
)
|
||||
if prefix_matching_result != PrefixMatchingResult.mismatch:
|
||||
break
|
||||
target_index += 1
|
||||
if prefix_matching_result == PrefixMatchingResult.substring_of_prefix:
|
||||
break
|
||||
elif prefix_matching_result == PrefixMatchingResult.mismatch:
|
||||
if should_raise:
|
||||
if len(target) == 1:
|
||||
raise self._error(
|
||||
expected=target[0][char_index:],
|
||||
actual=ith,
|
||||
reason=f"matching {target[0]!r}",
|
||||
)
|
||||
else:
|
||||
raise self._error(
|
||||
expected=target,
|
||||
actual=ith,
|
||||
reason="matching any of the list",
|
||||
)
|
||||
else:
|
||||
return
|
||||
char_index += 1
|
||||
if char_index == len(target[target_index]):
|
||||
break
|
||||
if should_consume:
|
||||
self._consume(len(target[target_index]))
|
||||
return target[target_index]
|
||||
|
||||
def _whitespace(self) -> Iterator[None]:
|
||||
total = 0
|
||||
while True:
|
||||
peek = yield from self._peek(total)
|
||||
if peek.isspace():
|
||||
total += 1
|
||||
else:
|
||||
self._consume(total)
|
||||
break
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Tried:
|
||||
successful: bool
|
||||
result: Any = None
|
||||
error: Optional[PatternMismatched] = None
|
||||
|
||||
|
||||
def default_tool_call_output_key(tool_call_index: int, function: Dict) -> Dict:
|
||||
return {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": tool_call_index,
|
||||
"function": function,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def _get_context_standalone(full_text_parts: List[str], offset: int) -> str:
|
||||
"""Standalone version of _get_context that doesn't capture self."""
|
||||
full_text = "".join(full_text_parts)
|
||||
return f"{full_text[:offset]}💥{full_text[offset:]}"
|
||||
|
||||
|
||||
def _merge_dicts(
|
||||
src: Dict[str, Any],
|
||||
dst: Optional[Dict[str, Any]],
|
||||
whitelist: Optional[Set[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if dst is None:
|
||||
return deepcopy(src)
|
||||
for key, val_src in src.items():
|
||||
if whitelist is not None and key in whitelist:
|
||||
continue
|
||||
val_dst = dst.get(key)
|
||||
if val_dst is None:
|
||||
dst[key] = deepcopy(val_src)
|
||||
elif isinstance(val_src, dict):
|
||||
dst[key] = _merge_dicts(val_src, val_dst)
|
||||
elif isinstance(val_src, list):
|
||||
# NOTE: don't just `val_dst.extend(val_src)` here, as we need a carefully constructed delta object.
|
||||
# Assume that all elements have `index: int`, indicating its position in the final merged list.
|
||||
if len(val_src) == len(val_dst) and all(
|
||||
val_src[i]["index"] == val_dst[i]["index"] for i in range(len(val_src))
|
||||
):
|
||||
for i in range(len(val_src)):
|
||||
_merge_dicts(val_src[i], val_dst[i], whitelist={"index"})
|
||||
elif len(val_src):
|
||||
index_to_item = {item["index"]: item for item in val_dst}
|
||||
for item_src in val_src:
|
||||
idx = item_src["index"]
|
||||
item_dst = index_to_item.get(idx)
|
||||
if item_dst is None:
|
||||
index_to_item[idx] = deepcopy(item_src)
|
||||
else:
|
||||
_merge_dicts(item_src, item_dst, whitelist={"index"})
|
||||
dst[key] = sorted(
|
||||
index_to_item.values(), key=lambda item: item["index"]
|
||||
)
|
||||
elif isinstance(val_src, str):
|
||||
if isinstance(val_dst, list):
|
||||
dst[key].append(val_src)
|
||||
else:
|
||||
# NOTE: to avoid O(n²) string concatenation, we store one string as a list of substrings.
|
||||
# Later, join the substrings back by `_get_dict_with_joint_strings`.
|
||||
dst[key] = [val_dst, val_src]
|
||||
else:
|
||||
raise TypeError(
|
||||
f"key {key} has unsupported type {type(val_src)} of {val_src!r} against {val_dst!r}"
|
||||
)
|
||||
return dst
|
||||
|
||||
|
||||
def _get_dict_with_joint_strings(
|
||||
data: Optional[Dict[str, Any]],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if data is None:
|
||||
return
|
||||
updated = {}
|
||||
for key, val in data.items():
|
||||
if isinstance(val, list):
|
||||
if val:
|
||||
if isinstance(val[0], str):
|
||||
updated[key] = "".join(val)
|
||||
elif isinstance(val[0], dict):
|
||||
updated[key] = [_get_dict_with_joint_strings(_) for _ in val]
|
||||
else:
|
||||
updated[key] = val
|
||||
else:
|
||||
updated[key] = val
|
||||
elif isinstance(val, dict):
|
||||
updated[key] = _get_dict_with_joint_strings(val)
|
||||
else:
|
||||
updated[key] = val
|
||||
return updated
|
||||
|
||||
|
||||
def json_dumps(data: Any) -> str:
|
||||
return json.dumps(data, ensure_ascii=False)
|
||||
|
||||
|
||||
class PrefixMatchingResult(int, Enum):
|
||||
startswith = auto()
|
||||
substring_of_prefix = auto()
|
||||
mismatch = auto()
|
||||
|
||||
@classmethod
|
||||
def check(cls, string: str, prefix: str) -> Self:
|
||||
if not prefix:
|
||||
return cls.startswith
|
||||
for i in range(len(prefix)):
|
||||
if i >= len(string):
|
||||
return cls.substring_of_prefix
|
||||
if string[i] != prefix[i]:
|
||||
return cls.mismatch
|
||||
return cls.startswith
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _string_cartesian_product(
|
||||
prefix: Tuple[str, ...], suffix: Union[str, Tuple[str, ...]]
|
||||
) -> Tuple[str, ...]:
|
||||
if isinstance(suffix, str):
|
||||
return tuple(_ + suffix for _ in prefix)
|
||||
else:
|
||||
return tuple(_prefix + _suffix for _prefix in prefix for _suffix in suffix)
|
||||
@@ -1,435 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# mypy: ignore-errors
|
||||
# ruff: noqa
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Union
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from .data_type import AtomDataType, FunctionCallParameterDataType
|
||||
from .base import FunctionCallDict
|
||||
from .generator import (
|
||||
GeneratorParser,
|
||||
default_tool_call_output_key,
|
||||
json_dumps,
|
||||
)
|
||||
|
||||
|
||||
class M3TextParser(GeneratorParser):
|
||||
"""
|
||||
Parser for MiniMax M3 models.
|
||||
|
||||
M3 uses a namespace token `]<]minimax[>[` as delimiter before each tag.
|
||||
Parameters use actual XML tag names (not `<parameter name="...">`), and can be nested.
|
||||
Complex arguments, as nested XML tags, are buffered and emitted as complete JSON;
|
||||
Simple arguments, are streamed character-by-character if possible.
|
||||
|
||||
Example raw output::
|
||||
|
||||
]<]minimax[>[<tool_call>
|
||||
]<]minimax[>[<invoke name="func1">]<]minimax[>[<p1>value1]<]minimax[>[</p1>]<]minimax[>[<p2>]<]minimax[>[<item>]<]minimax[>[<k>val]<]minimax[>[</k>]<]minimax[>[</item>]<]minimax[>[</p2>]<]minimax[>[</invoke>
|
||||
]<]minimax[>[</tool_call>
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
with_reasoning: bool = True,
|
||||
reasoning_prefix: str = "<mm:think>",
|
||||
reasoning_suffix: str = "</mm:think>",
|
||||
functions: Optional[Dict] = None,
|
||||
tool_call_xml_tag_name: str = "tool_call",
|
||||
tool_call_namespace_token: str = "]<]minimax[>[",
|
||||
always_nullable: bool = True,
|
||||
reasoning_field: str = "reasoning",
|
||||
content_field: str = "content",
|
||||
tool_call_output_key: Callable[
|
||||
[int, Dict], Dict
|
||||
] = default_tool_call_output_key,
|
||||
):
|
||||
self.with_reasoning = with_reasoning
|
||||
self._reasoning_tokens: Optional[int] = None
|
||||
self._reasoning_prefix = reasoning_prefix
|
||||
self._reasoning_suffix = reasoning_suffix
|
||||
self._reasoning_suffix_without_newline = reasoning_suffix.lstrip()
|
||||
self._functions = functions
|
||||
self._tool_call_namespace_token = tool_call_namespace_token
|
||||
self._tool_call_start = (
|
||||
f"{self._tool_call_namespace_token}<{tool_call_xml_tag_name}>"
|
||||
)
|
||||
self._tool_call_end = (
|
||||
f"{self._tool_call_namespace_token}</{tool_call_xml_tag_name}>"
|
||||
)
|
||||
self._invoke_prefix = f'{self._tool_call_namespace_token}<invoke name="'
|
||||
self._invoke_suffix = '">'
|
||||
self._end_of_invoke = f"{self._tool_call_namespace_token}</invoke>"
|
||||
self._parameter_prefix = f"{self._tool_call_namespace_token}<"
|
||||
self._parameter_suffix = f"{self._tool_call_namespace_token}</"
|
||||
self._always_nullable = always_nullable
|
||||
self._reasoning_field = reasoning_field
|
||||
self._content_field = content_field
|
||||
self._tool_call_output_key = tool_call_output_key
|
||||
super().__init__()
|
||||
|
||||
def _get_function(self, function_name: str) -> Optional[Dict]:
|
||||
if isinstance(self._functions, dict):
|
||||
return self._functions.get(function_name)
|
||||
|
||||
def count_reasoning_tokens(self) -> Optional[int]:
|
||||
if self.with_reasoning:
|
||||
if self._reasoning_tokens is None:
|
||||
return self._count_consumed_tokens()
|
||||
else:
|
||||
return self._reasoning_tokens
|
||||
|
||||
def _process(self) -> Generator[dict, None, None]:
|
||||
if self.with_reasoning:
|
||||
if self._reasoning_prefix:
|
||||
with_reasoning = yield from self._literal(
|
||||
self._reasoning_prefix, should_raise=False
|
||||
)
|
||||
if with_reasoning:
|
||||
yield from self._take_any(
|
||||
until=self._reasoning_suffix, key=self._reasoning_field
|
||||
)
|
||||
self._reasoning_tokens = self._count_consumed_tokens()
|
||||
else:
|
||||
self._reasoning_tokens = 0
|
||||
# If reasoning is disabled, we may find the suffix at the beginning without the prefix.
|
||||
yield from self._literal(
|
||||
self._reasoning_suffix_without_newline,
|
||||
should_raise=False,
|
||||
)
|
||||
else:
|
||||
yield from self._take_any(
|
||||
until=self._reasoning_suffix, key=self._reasoning_field
|
||||
)
|
||||
self._reasoning_tokens = self._count_consumed_tokens()
|
||||
|
||||
if not self._functions:
|
||||
yield from self._take_any(key=self._content_field)
|
||||
else:
|
||||
yield from self._take_any(
|
||||
until=self._tool_call_start,
|
||||
key=self._content_field,
|
||||
should_consume_suffix=False,
|
||||
)
|
||||
yield from self._literal(self._tool_call_start)
|
||||
yield from self._literal("\n", should_raise=False)
|
||||
|
||||
# NOTE: Only ONE `<tool_call>` block is supported by design.
|
||||
# Multiple parallel calls must share a single wrapper and use
|
||||
# multiple `<invoke>` tags inside it. A second `<tool_call>` after
|
||||
# the first `</tool_call>` will cause `update()` to raise
|
||||
# PatternMismatched (pattern exhausted).
|
||||
tool_call_index = 0
|
||||
while True:
|
||||
if tool_call_index:
|
||||
yield from self._literal("\n", should_raise=False)
|
||||
tried = yield from self._literal(
|
||||
(self._invoke_prefix, self._tool_call_end),
|
||||
should_raise=False,
|
||||
)
|
||||
if tried is None or tried == self._tool_call_end:
|
||||
break
|
||||
|
||||
function_name = yield from self._take_any(until=self._invoke_suffix)
|
||||
self._append_delta(
|
||||
self._tool_call_output_key(
|
||||
tool_call_index, {"name": function_name, "arguments": "{"}
|
||||
)
|
||||
)
|
||||
function = self._get_function(function_name)
|
||||
is_first_parameter = True
|
||||
while True:
|
||||
tried = yield from self._literal(
|
||||
(self._end_of_invoke, self._parameter_prefix),
|
||||
should_raise=False,
|
||||
)
|
||||
if tried == self._end_of_invoke:
|
||||
self._append_delta(
|
||||
self._tool_call_output_key(
|
||||
tool_call_index, {"arguments": "}"}
|
||||
)
|
||||
)
|
||||
break
|
||||
if tried is None:
|
||||
# Consume and ignore, hoping it may recover after this invoke ends.
|
||||
yield from self._take_any(until=self._end_of_invoke)
|
||||
self._append_delta(
|
||||
self._tool_call_output_key(
|
||||
tool_call_index, {"arguments": "}"}
|
||||
)
|
||||
)
|
||||
break
|
||||
parameter_name = yield from self._take_any(until=">")
|
||||
parameter_name_to_arguments = "{}{}: ".format(
|
||||
"" if is_first_parameter else ", ",
|
||||
json_dumps(parameter_name),
|
||||
)
|
||||
self._append_delta(
|
||||
self._tool_call_output_key(
|
||||
tool_call_index,
|
||||
{"arguments": parameter_name_to_arguments},
|
||||
)
|
||||
)
|
||||
parameter_suffix = "{}{}>".format(
|
||||
self._parameter_suffix, parameter_name
|
||||
)
|
||||
parameter_data_type = (
|
||||
FunctionCallParameterDataType.get_schema_of_parameter(
|
||||
function, parameter_name
|
||||
)
|
||||
)
|
||||
tried = yield from self._literal(
|
||||
self._parameter_prefix,
|
||||
should_raise=False,
|
||||
should_consume=False,
|
||||
)
|
||||
if tried is not None:
|
||||
param_body_str = yield from self._take_any(
|
||||
until=parameter_suffix
|
||||
)
|
||||
if param_body_str:
|
||||
# nested XML -> object
|
||||
# NOTE: The namespace token has the highest semantic
|
||||
# priority. Once the model emits `]<]minimax[>[<` here,
|
||||
# we MUST treat the body as nested XML, even if the
|
||||
# schema says this parameter should be a primitive.
|
||||
# The model is asserting "this is a JSON level
|
||||
# transition" — schema mismatches are reported back via
|
||||
# the agent loop, not silently rewritten here.
|
||||
param = self._parse_parameter(
|
||||
param_body_str, parameter_data_type
|
||||
)
|
||||
else:
|
||||
param = parameter_data_type.convert("")
|
||||
self._append_delta(
|
||||
self._tool_call_output_key(
|
||||
tool_call_index,
|
||||
{"arguments": json_dumps(param)},
|
||||
)
|
||||
)
|
||||
else:
|
||||
# no more nested XML -> string / number / boolean
|
||||
yield from self._take_data_type_as_json(
|
||||
until=parameter_suffix,
|
||||
key=lambda value: self._tool_call_output_key(
|
||||
tool_call_index, {"arguments": value}
|
||||
),
|
||||
data_type=parameter_data_type,
|
||||
always_nullable=False,
|
||||
should_consume_suffix=True,
|
||||
)
|
||||
is_first_parameter = False
|
||||
tool_call_index += 1
|
||||
|
||||
def _parse_parameter(
|
||||
self, body: str, parameter_data_type: FunctionCallParameterDataType
|
||||
) -> dict:
|
||||
chunks = body.split(self._tool_call_namespace_token)
|
||||
# NOTE: Array detection is intentionally a strict "schema says array
|
||||
# AND first child is <item>" check. We do NOT promote a uniform
|
||||
# `<x><x>...` body to an array on schema mismatch — leave it as
|
||||
# `{"x": [...]}` so the agent loop can spot and correct the model.
|
||||
if (
|
||||
AtomDataType.array in parameter_data_type.candidates
|
||||
and len(chunks) > 1
|
||||
and chunks[1].startswith("<item>")
|
||||
):
|
||||
root = []
|
||||
else:
|
||||
root = {}
|
||||
stack: List[_StackItem] = [
|
||||
_StackItem(tag=None, value=root, texts=None, data_type=parameter_data_type)
|
||||
]
|
||||
|
||||
# Ignore the first chunk inside the parameter.
|
||||
# It should be empty, since we've tried `self._parameter_prefix` and failed before entering this function.
|
||||
for chunk_index in range(1, len(chunks)):
|
||||
chunk = chunks[chunk_index]
|
||||
# There are 7 = 3 + 3 + 1 non-empty categories of chunks.
|
||||
if chunk.startswith("</"):
|
||||
gt_offset = chunk.find(">", 2)
|
||||
if gt_offset == -1:
|
||||
# 1. `</tag`
|
||||
tag = chunk[2:]
|
||||
value = None
|
||||
elif gt_offset == len(chunk) - 1:
|
||||
# 2. `</tag>`
|
||||
tag = chunk[2:-1]
|
||||
value = None
|
||||
else:
|
||||
# 3. `</tag>value`
|
||||
tag = chunk[2:gt_offset]
|
||||
value = chunk[gt_offset + 1 :]
|
||||
while len(stack) > 1:
|
||||
item = stack.pop()
|
||||
stack[-1].append(item)
|
||||
if item.tag == tag:
|
||||
break
|
||||
if value:
|
||||
stack[-1].append_text(value)
|
||||
elif chunk.startswith("<"):
|
||||
gt_offset = chunk.find(">", 1)
|
||||
if gt_offset == -1:
|
||||
# 4. `<tag`
|
||||
tag = chunk[1:]
|
||||
value = None
|
||||
elif gt_offset == len(chunk) - 1:
|
||||
# 5. `<tag>`
|
||||
tag = chunk[1:-1]
|
||||
value = None
|
||||
else:
|
||||
# 6. `<tag>value`
|
||||
tag = chunk[1:gt_offset]
|
||||
value = chunk[gt_offset + 1 :]
|
||||
sub_data_type = stack[-1].get_data_type_of_property(tag)
|
||||
if (
|
||||
sub_data_type
|
||||
and AtomDataType.array in sub_data_type.candidates
|
||||
and len(chunks) > chunk_index + 1
|
||||
and chunks[chunk_index + 1].startswith("<item>")
|
||||
):
|
||||
sub = []
|
||||
elif sub_data_type and AtomDataType.object in sub_data_type.candidates:
|
||||
sub = {}
|
||||
else:
|
||||
sub = None
|
||||
stack.append(
|
||||
_StackItem(
|
||||
tag=tag,
|
||||
value=sub,
|
||||
texts=[value] if value else None,
|
||||
data_type=sub_data_type,
|
||||
)
|
||||
)
|
||||
elif chunk:
|
||||
# 7. `value`
|
||||
stack[-1].append_text(chunk)
|
||||
|
||||
while len(stack) > 1:
|
||||
item = stack.pop()
|
||||
stack[-1].append(item)
|
||||
|
||||
return stack[0].get_value()
|
||||
|
||||
def stringify_function_calls(self, function_calls: List[FunctionCallDict]) -> str:
|
||||
if not function_calls:
|
||||
return ""
|
||||
parts = [self._tool_call_start + "\n"]
|
||||
for function_call in function_calls:
|
||||
parts.append(
|
||||
self._invoke_prefix + function_call["name"] + self._invoke_suffix
|
||||
)
|
||||
arguments = function_call["arguments"]
|
||||
if isinstance(arguments, str):
|
||||
arguments = json.loads(arguments)
|
||||
parts.append(self._stringify_parameter(arguments))
|
||||
parts.append(self._end_of_invoke + "\n")
|
||||
parts.append(self._tool_call_end)
|
||||
return "".join(parts)
|
||||
|
||||
def _stringify_parameter(self, parameter: Any) -> str:
|
||||
# NOTE: null values are simply ignored.
|
||||
# This is limited due to the training philosophy of MiniMax M3.
|
||||
# In the training process, the model will NOT see any null values in tool calls.
|
||||
# Thus, we have to skip them here, in order to avoid OOD.
|
||||
# This cost is unwillingly accepted:
|
||||
# `["a", null, "c"]` becomes `<items>a</items><items>c</items>`,
|
||||
# even the size of the array is changed.
|
||||
if isinstance(parameter, dict):
|
||||
return "".join(
|
||||
f"{self._tool_call_namespace_token}<{key}>{self._stringify_parameter(value)}{self._tool_call_namespace_token}</{key}>"
|
||||
for key, value in parameter.items()
|
||||
if value is not None
|
||||
)
|
||||
elif isinstance(parameter, list):
|
||||
return "".join(
|
||||
f"{self._tool_call_namespace_token}<item>{self._stringify_parameter(value)}{self._tool_call_namespace_token}</item>"
|
||||
for value in parameter
|
||||
if value is not None
|
||||
)
|
||||
elif isinstance(parameter, str):
|
||||
return parameter
|
||||
elif parameter is None:
|
||||
# should be unreachable
|
||||
return ""
|
||||
else:
|
||||
return json.dumps(parameter, ensure_ascii=False)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _StackItem:
|
||||
tag: Optional[str]
|
||||
value: Optional[Union[Dict, List]]
|
||||
texts: Optional[List[str]]
|
||||
data_type: Optional[FunctionCallParameterDataType]
|
||||
|
||||
def get_value(self) -> Any:
|
||||
if self.value is None:
|
||||
if self.texts:
|
||||
value = "".join(self.texts)
|
||||
else:
|
||||
value = ""
|
||||
if self.data_type:
|
||||
value = self.data_type.convert(value)
|
||||
return value
|
||||
elif self.texts and isinstance(self.value, dict):
|
||||
extra_text_key = "$text"
|
||||
while extra_text_key in self.value:
|
||||
extra_text_key = "$" + extra_text_key
|
||||
self.value[extra_text_key] = "".join(self.texts)
|
||||
return self.value
|
||||
else:
|
||||
return self.value
|
||||
|
||||
def append(self, item: Self) -> None:
|
||||
if self.value is None:
|
||||
self.value = {item.tag: item.get_value()}
|
||||
elif isinstance(self.value, dict):
|
||||
if item.tag in self.value:
|
||||
# NOTE: Duplicate tag inside an object is collapsed into a
|
||||
# list to preserve all values, even if the schema declares
|
||||
# the key as a singleton. We don't drop the data and we
|
||||
# don't try to "fix" the schema mismatch silently — the agent
|
||||
# loop should surface the inconsistency back to the model.
|
||||
value = self.value[item.tag]
|
||||
if isinstance(value, list):
|
||||
value.append(item.get_value())
|
||||
else:
|
||||
self.value[item.tag] = [value, item.get_value()]
|
||||
else:
|
||||
self.value[item.tag] = item.get_value()
|
||||
elif isinstance(self.value, list):
|
||||
# We expect `item.tag` to be `"item"`, but if it's not, we should still accept it.
|
||||
self.value.append(item.get_value())
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def append_text(self, value: str) -> None:
|
||||
if isinstance(self.value, list):
|
||||
if self.data_type:
|
||||
data_type = self.data_type.get_data_type_of_item(index=len(self.value))
|
||||
if data_type:
|
||||
value = data_type.convert(value)
|
||||
self.value.append(value)
|
||||
elif self.texts is None:
|
||||
self.texts = [value]
|
||||
else:
|
||||
self.texts.append(value)
|
||||
|
||||
def get_data_type_of_property(
|
||||
self, tag: str
|
||||
) -> Optional[FunctionCallParameterDataType]:
|
||||
if self.data_type:
|
||||
if isinstance(self.value, list):
|
||||
# We expect `tag` to be `"item"`, but if it's not, we should still accept it.
|
||||
return self.data_type.get_data_type_of_item(index=len(self.value))
|
||||
elif isinstance(self.value, dict):
|
||||
return self.data_type.get_data_type_of_property(tag)
|
||||
else:
|
||||
return None
|
||||
@@ -1,311 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from contextlib import suppress
|
||||
from typing import Any
|
||||
|
||||
from openai.types.responses.function_tool import FunctionTool
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers._llm_nom.m3_text import M3TextParser
|
||||
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
from vllm.tool_parsers.rust_tool_parser import RustToolParser
|
||||
|
||||
|
||||
class MinimaxM3ToolParser(ToolParser):
|
||||
"""Adapter from the vendored MiniMax M3 parser to vLLM ToolParser.
|
||||
class MinimaxM3ToolParser(RustToolParser):
|
||||
"""Adapter from the Rust MiniMax M3 parser to vLLM ToolParser.
|
||||
|
||||
The real M3 grammar lives in ``_llm_nom.m3_text.M3TextParser``. This
|
||||
class keeps only the vLLM-specific bridge work:
|
||||
- convert vLLM tool definitions into the function schema shape expected by
|
||||
the vendored parser;
|
||||
- translate parser ``content`` / ``tool_calls`` deltas into vLLM protocol
|
||||
objects; and
|
||||
- maintain vLLM streaming bookkeeping used by finish-reason handling.
|
||||
The real M3 grammar lives in the Rust tool-parser crate. This class only
|
||||
configures the generic Rust bridge with the MiniMax M3 parser name.
|
||||
|
||||
M3 is not M2 with renamed tags: it prefixes each structural tag with the
|
||||
MiniMax namespace marker, allows multiple ``<invoke>`` tags in one wrapper,
|
||||
and represents nested arguments with parameter-name XML tags.
|
||||
"""
|
||||
|
||||
# M3 emits its own XML-like tool-call format from the chat template. For
|
||||
# required/named tool_choice, do not let the serving layer force JSON guided
|
||||
# output; parse the M3 syntax through this parser instead.
|
||||
supports_required_and_named = False
|
||||
|
||||
rust_parser_name = "minimax_m3"
|
||||
tool_call_start_token = "]<]minimax[>[<tool_call>"
|
||||
|
||||
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
|
||||
super().__init__(tokenizer, tools)
|
||||
self._parser: M3TextParser | None = None
|
||||
self._error: Exception | None = None
|
||||
self._tool_call_ids: dict[int, str] = {}
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction."
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"vLLM successfully imported tool parser %s", self.__class__.__name__
|
||||
)
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest | ResponsesRequest
|
||||
) -> ChatCompletionRequest | ResponsesRequest:
|
||||
"""Adjust generation options for MiniMax M3 tool-call syntax.
|
||||
|
||||
Required/named tool choice must skip ``super().adjust_request()``
|
||||
because the base implementation would install JSON structured output
|
||||
constraints. M3 needs to preserve and generate its namespace-tagged
|
||||
syntax, so we only ensure special-token text is not stripped.
|
||||
"""
|
||||
if request.tools:
|
||||
tool_choice = getattr(request, "tool_choice", None)
|
||||
if tool_choice == "required" or isinstance(
|
||||
tool_choice, ChatCompletionNamedToolChoiceParam
|
||||
):
|
||||
if hasattr(request, "skip_special_tokens"):
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
request = super().adjust_request(request)
|
||||
if (
|
||||
request.tools
|
||||
and getattr(request, "tool_choice", None) != "none"
|
||||
and hasattr(request, "skip_special_tokens")
|
||||
):
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def _functions(self) -> dict[str, dict[str, Any]] | None:
|
||||
"""Build the function map consumed by ``M3TextParser``."""
|
||||
if not self.tools:
|
||||
return None
|
||||
|
||||
functions: dict[str, dict[str, Any]] = {}
|
||||
for tool in self.tools:
|
||||
if isinstance(tool, FunctionTool):
|
||||
name = tool.name
|
||||
parameters = tool.parameters
|
||||
elif isinstance(tool, ChatCompletionToolsParam):
|
||||
name = tool.function.name
|
||||
parameters = tool.function.parameters
|
||||
else:
|
||||
continue
|
||||
functions[name] = {"parameters": parameters}
|
||||
return functions
|
||||
|
||||
def _new_parser(self) -> M3TextParser:
|
||||
"""Create a fresh vendored parser with the current tool schemas."""
|
||||
return M3TextParser(
|
||||
with_reasoning=False,
|
||||
reasoning_prefix="",
|
||||
functions=self._functions(),
|
||||
)
|
||||
|
||||
def _get_parser(self) -> M3TextParser:
|
||||
if self._parser is None:
|
||||
self._parser = self._new_parser()
|
||||
return self._parser
|
||||
|
||||
def _reset_streaming_state(self) -> None:
|
||||
"""Reset parser state for a new request on a reused parser instance."""
|
||||
self._parser = self._new_parser()
|
||||
self._error = None
|
||||
self._tool_call_ids.clear()
|
||||
self.prev_tool_call_arr.clear()
|
||||
self.streamed_args_for_tool.clear()
|
||||
self.current_tool_id = -1
|
||||
self.current_tool_name_sent = False
|
||||
|
||||
def _ensure_tool_state(self, index: int) -> None:
|
||||
"""Grow vLLM streaming state arrays to contain ``index``."""
|
||||
while len(self.prev_tool_call_arr) <= index:
|
||||
self.prev_tool_call_arr.append({})
|
||||
while len(self.streamed_args_for_tool) <= index:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
def _record_delta(
|
||||
self, index: int, name: str | None, arguments: str | None
|
||||
) -> str | None:
|
||||
"""Mirror a vendored-parser delta into vLLM streaming bookkeeping.
|
||||
|
||||
``prev_tool_call_arr`` and ``streamed_args_for_tool`` are read later by
|
||||
the chat serving layer to decide the final ``tool_calls`` finish reason
|
||||
and to flush any remaining argument bytes.
|
||||
"""
|
||||
tool_call_id = None
|
||||
self._ensure_tool_state(index)
|
||||
|
||||
if name is not None:
|
||||
tool_call_id = make_tool_call_id()
|
||||
self._tool_call_ids[index] = tool_call_id
|
||||
self.prev_tool_call_arr[index] = {"name": name, "arguments": {}}
|
||||
self.current_tool_name_sent = True
|
||||
|
||||
if arguments is not None:
|
||||
self.streamed_args_for_tool[index] += arguments
|
||||
with suppress(json.JSONDecodeError):
|
||||
self.prev_tool_call_arr[index]["arguments"] = json.loads(
|
||||
self.streamed_args_for_tool[index]
|
||||
)
|
||||
self.current_tool_id = index
|
||||
|
||||
return tool_call_id
|
||||
|
||||
def _delta_message_from_parser_delta(
|
||||
self, parser_delta: dict[str, Any] | None
|
||||
) -> DeltaMessage | None:
|
||||
"""Translate one ``M3TextParser`` delta into a vLLM ``DeltaMessage``."""
|
||||
if parser_delta is None:
|
||||
return None
|
||||
|
||||
normal_text = parser_delta.get("content") or None
|
||||
tool_calls: list[DeltaToolCall] = []
|
||||
for tool_call in parser_delta.get("tool_calls", []):
|
||||
func = tool_call.get("function", {})
|
||||
index = tool_call.get("index", 0)
|
||||
name = func.get("name")
|
||||
arguments = func.get("arguments")
|
||||
if name is None and arguments is None:
|
||||
continue
|
||||
|
||||
tool_call_id = self._record_delta(index, name, arguments)
|
||||
tool_calls.append(
|
||||
DeltaToolCall(
|
||||
index=index,
|
||||
id=tool_call_id,
|
||||
type="function" if name is not None else None,
|
||||
function=DeltaFunctionCall(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if normal_text is None and not tool_calls:
|
||||
return None
|
||||
return DeltaMessage(content=normal_text, tool_calls=tool_calls)
|
||||
|
||||
def _parse_complete(self, model_output: str) -> dict[str, Any] | None:
|
||||
"""Parse complete model output with a throwaway parser instance."""
|
||||
parser = self._new_parser()
|
||||
try:
|
||||
parser.update(model_output)
|
||||
except Exception:
|
||||
logger.exception("Error parsing MiniMax M3 tool call output.")
|
||||
return None
|
||||
return parser.get_delta() or parser.get_final()
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""Extract tool calls from complete model output (non-streaming)."""
|
||||
if self.tool_call_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
|
||||
parsed = self._parse_complete(model_output)
|
||||
if parsed is None:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
|
||||
tool_calls: list[ToolCall] = []
|
||||
self.prev_tool_call_arr.clear()
|
||||
for parsed_tool_call in parsed.get("tool_calls", []):
|
||||
func = parsed_tool_call.get("function", {})
|
||||
name = func.get("name")
|
||||
arguments = func.get("arguments", "{}")
|
||||
if name is None:
|
||||
continue
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(name=name, arguments=arguments),
|
||||
)
|
||||
)
|
||||
try:
|
||||
args = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
args = arguments
|
||||
self.prev_tool_call_arr.append({"name": name, "arguments": args})
|
||||
|
||||
if not tool_calls:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
|
||||
content = parsed.get("content") or None
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content,
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int], # pylint: disable=unused-argument
|
||||
current_token_ids: Sequence[int], # pylint: disable=unused-argument
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest, # pylint: disable=unused-argument
|
||||
) -> DeltaMessage | None:
|
||||
"""Extract tool calls from streaming model output.
|
||||
|
||||
``M3TextParser`` owns the incremental buffer, so this adapter feeds only
|
||||
the newest text delta. It returns an empty final content delta on EOS
|
||||
after a tool call so the serving layer reaches its finish-reason path.
|
||||
"""
|
||||
if not previous_text:
|
||||
self._reset_streaming_state()
|
||||
|
||||
if self._error is not None:
|
||||
return None
|
||||
|
||||
try:
|
||||
self._get_parser().update(delta_text)
|
||||
except Exception as error:
|
||||
self._error = error
|
||||
logger.exception("Error parsing MiniMax M3 streaming tool call output.")
|
||||
|
||||
parser_delta = self._get_parser().get_delta()
|
||||
delta_message = self._delta_message_from_parser_delta(parser_delta)
|
||||
if delta_message is not None:
|
||||
return delta_message
|
||||
|
||||
if not delta_text and delta_token_ids and self.prev_tool_call_arr:
|
||||
return DeltaMessage(content="")
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,329 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from openai.types.responses.function_tool import FunctionTool
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _rust_tool_parser_module() -> Any:
|
||||
try:
|
||||
from vllm import _rust_tool_parser
|
||||
except ImportError as exc:
|
||||
raise RuntimeError(
|
||||
"Rust tool parsing requires the vllm._rust_tool_parser PyO3 "
|
||||
"extension. Rebuild vLLM with Rust extensions enabled."
|
||||
) from exc
|
||||
return _rust_tool_parser
|
||||
|
||||
|
||||
class RustToolParser(ToolParser):
|
||||
"""Adapter from an opaque Rust parser to the vLLM ToolParser API.
|
||||
|
||||
Subclasses provide only model-specific configuration: the exact Rust parser
|
||||
name and an optional tool-call start marker for fast complete-output
|
||||
rejection.
|
||||
|
||||
This class keeps the vLLM-specific bridge work:
|
||||
- convert vLLM tool definitions into the Rust ``Tool`` shape;
|
||||
- translate typed Rust parser outputs into vLLM protocol objects; and
|
||||
- maintain vLLM streaming bookkeeping used by finish-reason handling.
|
||||
|
||||
The parser grammar and incremental parser state stay in Rust.
|
||||
"""
|
||||
|
||||
# Rust-backed parsers are opaque to Python by default. Do not use vLLM's
|
||||
# standard JSON required/named handling; let the Rust parser consume the
|
||||
# model's native tool-call syntax.
|
||||
supports_required_and_named = False
|
||||
|
||||
rust_parser_name: str
|
||||
tool_call_start_token: str | None = None
|
||||
|
||||
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
|
||||
super().__init__(tokenizer, tools)
|
||||
self._parser: Any | None = None
|
||||
self._error: Exception | None = None
|
||||
self._finished = False
|
||||
self._tool_call_ids: dict[int, str] = {}
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction."
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"vLLM successfully imported tool parser %s", self.__class__.__name__
|
||||
)
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest | ResponsesRequest
|
||||
) -> ChatCompletionRequest | ResponsesRequest:
|
||||
"""Adjust request options without installing Python-side constraints.
|
||||
|
||||
Rust-backed parsers are treated as source-of-truth opaque parsers. The
|
||||
bridge intentionally avoids ``super().adjust_request()`` so Python does
|
||||
not install JSON schema guidance or structural-tag constraints that may
|
||||
conflict with the Rust parser's native grammar.
|
||||
"""
|
||||
if self._get_parser().preserve_special_tokens():
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def _rust_tools(self) -> list[Any]:
|
||||
"""Build Rust ``Tool`` objects from vLLM tool definitions."""
|
||||
if not self.tools:
|
||||
return []
|
||||
|
||||
tools: list[Any] = []
|
||||
for tool in self.tools:
|
||||
if isinstance(tool, FunctionTool):
|
||||
name = tool.name
|
||||
description = tool.description
|
||||
parameters = tool.parameters or {}
|
||||
strict = getattr(tool, "strict", None)
|
||||
elif isinstance(tool, ChatCompletionToolsParam):
|
||||
name = tool.function.name
|
||||
description = tool.function.description
|
||||
parameters = tool.function.parameters or {}
|
||||
strict = getattr(tool.function, "strict", None)
|
||||
else:
|
||||
continue
|
||||
tools.append(
|
||||
_rust_tool_parser_module().Tool(name, description, parameters, strict)
|
||||
)
|
||||
return tools
|
||||
|
||||
def _new_parser(self) -> Any:
|
||||
"""Create a fresh Rust parser with the current tool schemas."""
|
||||
return _rust_tool_parser_module().ToolParser(
|
||||
self.rust_parser_name, self._rust_tools()
|
||||
)
|
||||
|
||||
def _get_parser(self) -> Any:
|
||||
if self._parser is None:
|
||||
self._parser = self._new_parser()
|
||||
return self._parser
|
||||
|
||||
def _reset_streaming_state(self) -> None:
|
||||
"""Reset parser state for a new request on a reused parser instance."""
|
||||
self._parser = self._new_parser()
|
||||
self._error = None
|
||||
self._finished = False
|
||||
self._tool_call_ids.clear()
|
||||
self.prev_tool_call_arr.clear()
|
||||
self.streamed_args_for_tool.clear()
|
||||
self.current_tool_id = -1
|
||||
self.current_tool_name_sent = False
|
||||
|
||||
def _ensure_tool_state(self, index: int) -> None:
|
||||
"""Grow vLLM streaming state arrays to contain ``index``."""
|
||||
while len(self.prev_tool_call_arr) <= index:
|
||||
self.prev_tool_call_arr.append({})
|
||||
while len(self.streamed_args_for_tool) <= index:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
def _record_delta(
|
||||
self, index: int, name: str | None, arguments: str | None
|
||||
) -> str | None:
|
||||
"""Mirror a Rust parser delta into vLLM streaming bookkeeping.
|
||||
|
||||
``prev_tool_call_arr`` and ``streamed_args_for_tool`` are read later by
|
||||
the chat serving layer to decide the final ``tool_calls`` finish reason
|
||||
and to flush any remaining argument bytes.
|
||||
"""
|
||||
tool_call_id = None
|
||||
self._ensure_tool_state(index)
|
||||
|
||||
if name is not None:
|
||||
tool_call_id = make_tool_call_id()
|
||||
self._tool_call_ids[index] = tool_call_id
|
||||
self.prev_tool_call_arr[index] = {"name": name, "arguments": {}}
|
||||
self.current_tool_name_sent = True
|
||||
|
||||
if arguments is not None:
|
||||
self.streamed_args_for_tool[index] += arguments
|
||||
self.prev_tool_call_arr[index]["arguments"] = (
|
||||
self.streamed_args_for_tool[index]
|
||||
)
|
||||
self.current_tool_id = index
|
||||
|
||||
return tool_call_id
|
||||
|
||||
def _delta_message_from_parser_output(
|
||||
self, parser_output: Any | None
|
||||
) -> DeltaMessage | None:
|
||||
"""Translate one Rust parser output into a vLLM ``DeltaMessage``."""
|
||||
if parser_output is None:
|
||||
return None
|
||||
|
||||
normal_text = parser_output.normal_text or None
|
||||
tool_calls: list[DeltaToolCall] = []
|
||||
for tool_call in parser_output.calls:
|
||||
index = tool_call.tool_index
|
||||
name = tool_call.name
|
||||
arguments: str | None = tool_call.arguments
|
||||
if name is None and arguments is None:
|
||||
continue
|
||||
|
||||
tool_call_id = self._record_delta(index, name, arguments)
|
||||
tool_calls.append(
|
||||
DeltaToolCall(
|
||||
index=index,
|
||||
id=tool_call_id,
|
||||
type="function" if name is not None else None,
|
||||
function=DeltaFunctionCall(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if normal_text is None and not tool_calls:
|
||||
return None
|
||||
return DeltaMessage(content=normal_text, tool_calls=tool_calls)
|
||||
|
||||
def _parse_complete(self, model_output: str) -> Any | None:
|
||||
"""Parse complete model output with a throwaway Rust parser instance."""
|
||||
parser = self._new_parser()
|
||||
output = _rust_tool_parser_module().ToolParserOutput()
|
||||
try:
|
||||
parser.parse_into(model_output, output)
|
||||
output.append(parser.finish())
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error parsing %s tool call output.", self.rust_parser_name
|
||||
)
|
||||
return None
|
||||
return output.coalesce_calls()
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""Extract tool calls from complete model output (non-streaming)."""
|
||||
if (
|
||||
self.tool_call_start_token is not None
|
||||
and self.tool_call_start_token not in model_output
|
||||
):
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
|
||||
parsed = self._parse_complete(model_output)
|
||||
if parsed is None:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
|
||||
tool_calls: list[ToolCall] = []
|
||||
self.prev_tool_call_arr.clear()
|
||||
for parsed_tool_call in parsed.calls:
|
||||
name = parsed_tool_call.name
|
||||
arguments = parsed_tool_call.arguments or "{}"
|
||||
if name is None:
|
||||
continue
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(name=name, arguments=arguments),
|
||||
)
|
||||
)
|
||||
self.prev_tool_call_arr.append({"name": name, "arguments": arguments})
|
||||
|
||||
if not tool_calls:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
|
||||
content = parsed.normal_text or None
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content,
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int], # pylint: disable=unused-argument
|
||||
current_token_ids: Sequence[int], # pylint: disable=unused-argument
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest, # pylint: disable=unused-argument
|
||||
) -> DeltaMessage | None:
|
||||
"""Extract tool calls from streaming model output.
|
||||
|
||||
The Rust parser owns the incremental buffer, so this adapter feeds only
|
||||
the newest text delta. On EOS, it calls ``finish()`` once to flush any
|
||||
complete buffered tool call. It returns an empty final content delta
|
||||
after a tool call so the serving layer reaches its finish-reason path.
|
||||
"""
|
||||
if not previous_text:
|
||||
self._reset_streaming_state()
|
||||
|
||||
if self._error is not None:
|
||||
return None
|
||||
|
||||
parser_output = _rust_tool_parser_module().ToolParserOutput()
|
||||
try:
|
||||
self._get_parser().parse_into(delta_text, parser_output)
|
||||
except Exception as error:
|
||||
self._error = error
|
||||
logger.exception(
|
||||
"Error parsing %s streaming tool call output.",
|
||||
self.rust_parser_name,
|
||||
)
|
||||
|
||||
delta_message = self._delta_message_from_parser_output(parser_output)
|
||||
if delta_message is not None:
|
||||
return delta_message
|
||||
|
||||
if not delta_text and delta_token_ids and not self._finished:
|
||||
try:
|
||||
finish_output = self._get_parser().finish()
|
||||
self._finished = True
|
||||
except Exception as error:
|
||||
self._error = error
|
||||
logger.exception(
|
||||
"Error finishing %s streaming tool parser.",
|
||||
self.rust_parser_name,
|
||||
)
|
||||
finish_output = None
|
||||
delta_message = self._delta_message_from_parser_output(finish_output)
|
||||
if delta_message is not None:
|
||||
return delta_message
|
||||
|
||||
if not delta_text and delta_token_ids and self.prev_tool_call_arr:
|
||||
return DeltaMessage(content="")
|
||||
return None
|
||||
Reference in New Issue
Block a user