refactor & simplify

Signed-off-by: Bugen Zhao <i@bugenzhao.com>
This commit is contained in:
Bugen Zhao
2026-06-05 06:39:10 +00:00
parent 53275a22d6
commit 76c973e13c
2 changed files with 87 additions and 120 deletions
+86 -119
View File
@@ -1,22 +1,50 @@
//! Thin PyO3 bindings for `vllm_tool_parser`.
//!
//! This crate exposes the Rust tool parser trait and data shapes to Python
//! while keeping parser state, grammar, and schema-aware argument conversion in
//! Rust. Python callers should use this module as a typed bridge and keep any
//! vLLM protocol adaptation outside the binding.
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyModule};
use pythonize::{depythonize, pythonize};
use serde_json::Value;
use thiserror_ext::AsReport as _;
use vllm_tool_parser::{
DeepSeekV3ToolParser, DeepSeekV4ToolParser, DeepSeekV31ToolParser, DeepSeekV32ToolParser,
Gemma4ToolParser, Glm45MoeToolParser, Glm47MoeToolParser, HermesToolParser, HyV3ToolParser,
KimiK2ToolParser, Llama3JsonToolParser, MinimaxM2ToolParser, MinimaxM3ToolParser,
MistralToolParser, Qwen3CoderToolParser, Qwen3XmlToolParser, Tool, ToolCallDelta, ToolParser,
ToolParserOutput,
};
use vllm_tool_parser::{Tool, ToolCallDelta, ToolParser, ToolParserOutput};
macro_rules! tool_parser_factory {
($($parser:ident),+ $(,)?) => {
fn create_tool_parser(
name: &str,
tools: &[Tool],
) -> PyResult<Box<dyn ToolParser>> {
match name {
$(
stringify!($parser) => {
<vllm_tool_parser::$parser as ToolParser>::create(tools)
}
)+
_ => {
return Err(PyValueError::new_err(format!(
"unsupported tool parser `{name}`"
)));
}
}
.map_err(|error| PyValueError::new_err(error.to_report_string()))
}
};
}
// Export a tool parser to Python by registering it here.
tool_parser_factory! {
DeepSeekV4ToolParser,
MinimaxM3ToolParser,
}
#[pyclass(name = "Tool", module = "vllm._rust_tool_parser", skip_from_py_object)]
#[derive(Clone)]
struct PyTool {
inner: Tool,
}
struct PyTool(Tool);
#[pymethods]
impl PyTool {
@@ -33,29 +61,27 @@ impl PyTool {
"failed to convert tool parameters from Python to JSON: {error}"
))
})?;
Ok(Self {
inner: Tool {
name,
description,
parameters,
strict,
},
})
Ok(Self(Tool {
name,
description,
parameters,
strict,
}))
}
#[getter]
fn name(&self) -> &str {
&self.inner.name
&self.0.name
}
#[getter]
fn description(&self) -> Option<&str> {
self.inner.description.as_deref()
self.0.description.as_deref()
}
#[getter]
fn parameters(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
pythonize(py, &self.inner.parameters).map(Bound::unbind).map_err(|error| {
pythonize(py, &self.0.parameters).map(Bound::unbind).map_err(|error| {
PyValueError::new_err(format!(
"failed to convert tool parameters from JSON to Python: {error}"
))
@@ -64,7 +90,7 @@ impl PyTool {
#[getter]
fn strict(&self) -> Option<bool> {
self.inner.strict
self.0.strict
}
}
@@ -74,37 +100,33 @@ impl PyTool {
skip_from_py_object
)]
#[derive(Clone)]
struct PyToolCallDelta {
inner: ToolCallDelta,
}
struct PyToolCallDelta(ToolCallDelta);
#[pymethods]
impl PyToolCallDelta {
#[new]
#[pyo3(signature = (tool_index, name, arguments))]
fn new(tool_index: usize, name: Option<String>, arguments: String) -> Self {
Self {
inner: ToolCallDelta {
tool_index,
name,
arguments,
},
}
Self(ToolCallDelta {
tool_index,
name,
arguments,
})
}
#[getter]
fn tool_index(&self) -> usize {
self.inner.tool_index
self.0.tool_index
}
#[getter]
fn name(&self) -> Option<&str> {
self.inner.name.as_deref()
self.0.name.as_deref()
}
#[getter]
fn arguments(&self) -> &str {
&self.inner.arguments
&self.0.arguments
}
}
@@ -114,96 +136,47 @@ impl PyToolCallDelta {
skip_from_py_object
)]
#[derive(Clone)]
struct PyToolParserOutput {
inner: ToolParserOutput,
}
impl PyToolParserOutput {
fn from_inner(inner: ToolParserOutput) -> Self {
Self { inner }
}
}
struct PyToolParserOutput(ToolParserOutput);
#[pymethods]
impl PyToolParserOutput {
#[new]
#[pyo3(signature = (normal_text="", calls=None))]
fn new(py: Python<'_>, normal_text: &str, calls: Option<Vec<Py<PyToolCallDelta>>>) -> Self {
let calls = calls
.unwrap_or_default()
.iter()
.map(|call| call.borrow(py).inner.clone())
.collect();
Self {
inner: ToolParserOutput {
normal_text: normal_text.to_owned(),
calls,
},
}
let calls =
calls.unwrap_or_default().iter().map(|call| call.borrow(py).0.clone()).collect();
Self(ToolParserOutput {
normal_text: normal_text.to_owned(),
calls,
})
}
#[getter]
fn normal_text(&self) -> &str {
&self.inner.normal_text
&self.0.normal_text
}
#[getter]
fn calls(&self) -> Vec<PyToolCallDelta> {
self.inner
.calls
.iter()
.cloned()
.map(|inner| PyToolCallDelta { inner })
.collect()
self.0.calls.iter().cloned().map(PyToolCallDelta).collect()
}
fn append(&mut self, other: PyRef<'_, PyToolParserOutput>) {
self.inner.append(other.inner.clone());
self.0.append(other.0.clone());
}
fn coalesce_calls(&self) -> Self {
Self::from_inner(self.inner.clone().coalesce_calls())
Self(self.0.clone().coalesce_calls())
}
}
#[pyclass(name = "ToolParser", module = "vllm._rust_tool_parser", unsendable)]
struct PyToolParser {
parser: Box<dyn ToolParser>,
}
struct PyToolParser(Box<dyn ToolParser>);
impl PyToolParser {
fn from_name(name: &str, tools: &[Tool]) -> PyResult<Self> {
let parser = match name {
"deepseek_v3" => DeepSeekV3ToolParser::create(tools),
"deepseek_v31" => DeepSeekV31ToolParser::create(tools),
"deepseek_v32" => DeepSeekV32ToolParser::create(tools),
"deepseek_v4" => DeepSeekV4ToolParser::create(tools),
"gemma4" => Gemma4ToolParser::create(tools),
"glm45" => Glm45MoeToolParser::create(tools),
"glm47" => Glm47MoeToolParser::create(tools),
"hermes" => HermesToolParser::create(tools),
"hy_v3" => HyV3ToolParser::create(tools),
"kimi_k2" => KimiK2ToolParser::create(tools),
"llama3_json" | "llama4_json" => Llama3JsonToolParser::create(tools),
"minimax_m2" => MinimaxM2ToolParser::create(tools),
"minimax_m3" => MinimaxM3ToolParser::create(tools),
"mistral" => MistralToolParser::create(tools),
"qwen3_xml" => Qwen3XmlToolParser::create(tools),
"qwen3_coder" => Qwen3CoderToolParser::create(tools),
_ => {
return Err(PyValueError::new_err(format!(
"unsupported tool parser `{name}`"
)));
}
}
.map_err(|error| PyValueError::new_err(error.to_report_string()))?;
Ok(Self { parser })
}
fn parse_into_output(&mut self, chunk: &str, output: &mut PyToolParserOutput) -> PyResult<()> {
self.parser
.parse_into(chunk, &mut output.inner)
self.0
.parse_into(chunk, &mut output.0)
.map_err(|error| PyValueError::new_err(error.to_report_string()))
}
}
@@ -212,8 +185,8 @@ impl PyToolParser {
impl PyToolParser {
#[new]
fn new(py: Python<'_>, parser_name: &str, tools: Vec<Py<PyTool>>) -> PyResult<Self> {
let tools = tools.iter().map(|tool| tool.borrow(py).inner.clone()).collect::<Vec<_>>();
Self::from_name(parser_name, &tools)
let tools = tools.iter().map(|tool| tool.borrow(py).0.clone()).collect::<Vec<_>>();
create_tool_parser(parser_name, &tools).map(Self)
}
fn parse_into(
@@ -225,18 +198,18 @@ impl PyToolParser {
}
fn finish(&mut self) -> PyResult<PyToolParserOutput> {
self.parser
self.0
.finish()
.map(PyToolParserOutput::from_inner)
.map(PyToolParserOutput)
.map_err(|error| PyValueError::new_err(error.to_report_string()))
}
fn reset(&mut self) -> String {
self.parser.reset()
self.0.reset()
}
fn preserve_special_tokens(&self) -> bool {
self.parser.preserve_special_tokens()
self.0.preserve_special_tokens()
}
}
@@ -254,8 +227,6 @@ mod tests {
use super::*;
use serde_json::json;
const NS: &str = "]<]minimax[>[";
fn with_python<R>(f: impl for<'py> FnOnce(Python<'py>) -> R) -> R {
Python::initialize();
Python::attach(f)
@@ -278,17 +249,13 @@ mod tests {
}
fn build_call() -> String {
format!(
"{NS}<tool_call>\n\
{NS}<invoke name=\"create_order\">\
{NS}<user_id>42{NS}</user_id>\
{NS}<shipping>\
{NS}<city>Singapore{NS}</city>\
{NS}<zip>018956{NS}</zip>\
{NS}</shipping>\
{NS}</invoke>\n\
{NS}</tool_call>"
)
r#"<DSMLtool_calls>
<DSMLinvoke name="create_order">
<DSMLparameter name="user_id" string="false">42</DSMLparameter>
<DSMLparameter name="shipping" string="false">{"city":"Singapore","zip":18956}</DSMLparameter>
</DSMLinvoke>
</DSMLtool_calls>"#
.to_owned()
}
fn make_py_tool(py: Python<'_>) -> PyResult<Py<PyTool>> {
@@ -353,8 +320,8 @@ mod tests {
fn parser_parse_finish_and_preserve_special_tokens() {
with_python(|py| {
let tool = make_py_tool(py)?;
let mut parser = PyToolParser::new(py, "minimax_m3", vec![tool])?;
assert!(!parser.preserve_special_tokens());
let mut parser = PyToolParser::new(py, "DeepSeekV4ToolParser", vec![tool])?;
assert!(parser.preserve_special_tokens());
let mut output = PyToolParserOutput::new(py, "", None);
parser.parse_into_output(&build_call(), &mut output)?;
+1 -1
View File
@@ -15,5 +15,5 @@ class MinimaxM3ToolParser(RustToolParser):
and represents nested arguments with parameter-name XML tags.
"""
rust_parser_name = "minimax_m3"
rust_parser_name = "MinimaxM3ToolParser"
tool_call_start_token = "]<]minimax[>[<tool_call>"