[Rust Frontend] Fix several hf chat template rendering issues (#44311)

Signed-off-by: Bugen Zhao <i@bugenzhao.com>
This commit is contained in:
Bugen Zhao
2026-06-03 16:04:43 +08:00
committed by GitHub
parent 6550ff12f2
commit 449be4f934
9 changed files with 183 additions and 56 deletions
+1
View File
@@ -5622,6 +5622,7 @@ dependencies = [
"expect-test",
"futures",
"half",
"indexmap 2.13.0",
"itertools 0.14.0",
"llm-multimodal",
"minijinja",
+2 -1
View File
@@ -43,6 +43,7 @@ half = { version = "2.7.1", features = ["bytemuck"] }
hex = "0.4.3"
hf-hub = { version = "0.5.0", features = ["tokio"] }
http-body = "1.0.1"
indexmap = "2.13.0"
itertools = "0.14.0"
libc = "0.2.177"
llm-multimodal = { git = "https://github.com/vllm-project/llm-multimodal", rev = "5b558989844d1c7af3e43d0f604069ffd9c06320" }
@@ -69,7 +70,7 @@ rustc-hash = "1.1.0"
serde = { version = "1.0.228", features = ["derive"] }
serde-json-fmt = "0.1.0"
serde_default = "0.2.0"
serde_json = { version = "1.0.145", features = ["arbitrary_precision", "preserve_order"] }
serde_json = { version = "1.0.145", features = ["preserve_order"] }
serde_repr = "0.1.20"
serde_tuple = "1.1.3"
serde_with = "3.18.0"
+1
View File
@@ -10,6 +10,7 @@ asynk-strim-attr.workspace = true
easy-ext.workspace = true
futures.workspace = true
half.workspace = true
indexmap.workspace = true
itertools.workspace = true
llm-multimodal.workspace = true
minijinja.workspace = true
+33 -7
View File
@@ -1,7 +1,7 @@
use std::collections::HashMap;
use serde::Serialize;
use serde_json::Value;
use serde_json::Value as JsonValue;
use thiserror_ext::AsReport as _;
use tracing::{info, trace, warn};
use vllm_text::Prompt;
@@ -13,6 +13,7 @@ use self::format::{
ChatTemplateContentFormat, ChatTemplateContentFormatOption as ContentFormatOption,
};
use self::template::{CompiledChatTemplate, TemplateContext};
use self::value::{TemplateValue, to_template_value};
use super::{ChatRenderer, RenderedPrompt};
use crate::error::Result;
use crate::request::{ChatContent, ChatContentPart, ChatMessage, ChatRequest};
@@ -24,6 +25,7 @@ mod error;
mod format;
mod template;
mod tojson;
mod value;
pub use template::{load_chat_template, resolve_chat_template};
@@ -38,7 +40,7 @@ pub struct MultimodalRenderInfo {
/// state.
pub struct HfChatRenderer {
default_template: Option<CompiledChatTemplate>,
default_template_kwargs: HashMap<String, Value>,
default_template_kwargs: HashMap<String, JsonValue>,
content_format: ContentFormatOption,
special_tokens: Option<HfSpecialTokens>,
multimodal: Option<MultimodalRenderInfo>,
@@ -48,7 +50,7 @@ impl HfChatRenderer {
/// Create a renderer from the given template string.
pub fn new(
template: Option<String>,
default_template_kwargs: HashMap<String, Value>,
default_template_kwargs: HashMap<String, JsonValue>,
content_format: ContentFormatOption,
) -> Result<Self> {
Ok(Self {
@@ -245,7 +247,7 @@ struct TemplateToolCall {
#[derive(Debug, Serialize)]
struct TemplateToolFunction {
name: String,
arguments: Value,
arguments: TemplateValue,
}
#[derive(Debug, Serialize)]
@@ -259,7 +261,7 @@ pub(super) struct TemplateTool {
struct TemplateToolDefinition {
name: String,
description: Option<String>,
parameters: Value,
parameters: TemplateValue,
strict: Option<bool>,
}
@@ -345,13 +347,14 @@ fn to_template_tool_calls(
let mut tool_calls = Vec::new();
for tool_call in content.tool_calls() {
let arguments = serde_json::from_str::<Value>(&tool_call.arguments).map_err(|error| {
let arguments = serde_json::from_str(&tool_call.arguments).map_err(|error| {
Error::ChatTemplate(format!(
"assistant tool call `{}` has invalid JSON arguments: {}",
tool_call.id,
error.as_report()
))
})?;
let arguments = to_template_value(arguments);
tool_calls.push(TemplateToolCall {
id: tool_call.id.clone(),
@@ -434,7 +437,7 @@ fn to_template_tools(tools: &[ChatTool]) -> Vec<TemplateTool> {
function: TemplateToolDefinition {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.parameters.clone(),
parameters: to_template_value(tool.parameters.clone()),
strict: tool.strict,
},
})
@@ -909,6 +912,29 @@ mod tests {
assert_eq!(rendered, "get_weather|Paris|call_1|Sunny");
}
// Regression test: tool-call arguments whose JSON object contains a key literally named
// `items` must still support the Python-style `arguments.items()` method call in a template,
// while subscript access (`arguments['items']`) keeps returning the stored field.
#[test]
fn chat_template_tool_call_argument_items_method_is_not_shadowed_by_field() {
// A single assistant message carrying one tool call; the arguments include an `items` key.
let request = sample_request(vec![ChatMessage::assistant_blocks(vec![
AssistantContentBlock::ToolCall(crate::AssistantToolCall {
id: "call_1".to_string(),
name: "add".to_string(),
arguments: r#"{"items":"operands","x":2,"y":1.0}"#.to_string(),
}),
])]);
// The template iterates `arguments.items()` and then reads the `items` field by subscript.
let rendered = render(
Some(
"{%- set arguments = messages[0].tool_calls[0].function.arguments -%}
{%- for key, value in arguments.items() -%}{{ key }}={{ value }};{%- endfor -%}
|{{ arguments['items'] }}",
),
&request,
)
.unwrap();
// Pairs are emitted in the same order as the keys of the arguments JSON, followed by the
// value of the `items` field itself.
assert_eq!(rendered, "items=operands;x=2;y=1.0;|operands");
}
#[test]
fn qwen35_template_renders_prefilled_reasoning_start_when_thinking_enabled() {
let mut request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]);
+18 -2
View File
@@ -208,11 +208,27 @@ mod tests {
}
#[test]
fn tojson_preserves_arbitrary_precision_number_spelling() {
fn tojson_uses_standard_serde_json_number_spelling() {
let payload = serde_json::from_str(r#"{"x":2,"y":1.00}"#).unwrap();
let rendered = render("{{ payload|tojson }}", payload);
assert_eq!(rendered, "{\"x\": 2, \"y\": 1.00}");
// TODO: we cannot preserve the original number precision by enabling `serde_json`'s
// `arbitrary_precision` feature, otherwise the following test
// `serialized_json_numbers_do_not_leak_serde_private_representation` will fail.
// See issue: https://github.com/mitsuhiko/minijinja/issues/641
assert_eq!(rendered, "{\"x\": 2, \"y\": 1.0}");
}
// Renders a parsed JSON object directly (without the `tojson` filter) and checks that numbers
// show up as plain literals instead of `serde_json`'s private number wrapper.
#[test]
fn serialized_json_numbers_do_not_leak_serde_private_representation() {
let payload: serde_json::Value = serde_json::from_str(r#"{"x":2,"y":1.00}"#).unwrap();
let rendered = render("{{ payload }}", payload);
// TODO: we cannot preserve the original number precision by enabling `serde_json`'s
// `arbitrary_precision` feature, otherwise this will fail.
// See issue: https://github.com/mitsuhiko/minijinja/issues/641
assert!(!rendered.contains("$serde_json::private::Number"));
// The input spelling `1.00` is expected to come out as `1.0`.
assert_eq!(rendered, r#"{"x": 2, "y": 1.0}"#);
}
#[test]
+77
View File
@@ -0,0 +1,77 @@
use std::sync::Arc;
use indexmap::IndexMap;
use minijinja::value::{Enumerator, Object, ObjectExt, ObjectRepr};
use minijinja::{Error as TemplateError, ErrorKind as TemplateErrorKind, State};
use serde::Serialize;
use serde_json::Value as JsonValue;
/// A value handed to the chat template, wrapping a `minijinja::Value`.
///
/// Instances are built from JSON with `to_template_value`. Because of `#[serde(transparent)]`,
/// the wrapper serializes exactly like the `minijinja::Value` it contains, so it can be used as
/// a field of the `Serialize` structs that make up the template context.
#[derive(Debug, Serialize)]
#[serde(transparent)]
pub(super) struct TemplateValue(minijinja::Value);
pub(super) fn to_template_value(value: JsonValue) -> TemplateValue {
TemplateValue(match value {
JsonValue::Array(values) => values
.into_iter()
.map(to_template_value)
.map(|value| value.0)
.collect::<minijinja::Value>(),
JsonValue::Object(values) => minijinja::Value::from_object(TemplateMap(
values
.into_iter()
.map(|(key, value)| (key, to_template_value(value).0))
.collect(),
)),
// For primitive values, directly convert them to `minijinja::Value` using `from_serialize`.
value => minijinja::Value::from_serialize(value),
})
}
/// Map object exposed to templates in place of MiniJinja's default map.
///
/// Every method call on this type reports `UnknownMethod`, so that pycompat can always handle
/// dict methods through the unknown-method callback.
///
/// Entries are stored in an `IndexMap` to preserve the original key order when iterating.
///
/// Rationale: MiniJinja's default map can resolve a same-named field before Python dict methods.
/// HF templates commonly call `dict.items()`, which would fail if the map had an `items` field.
/// See issue: https://github.com/mitsuhiko/minijinja/issues/903
#[derive(Debug)]
struct TemplateMap(IndexMap<String, minijinja::Value>);
impl Object for TemplateMap {
    /// Presents this object to MiniJinja as a map.
    fn repr(self: &Arc<Self>) -> ObjectRepr {
        ObjectRepr::Map
    }

    /// Enumerates the keys in the order stored by the underlying `IndexMap`.
    fn enumerate(self: &Arc<Self>) -> Enumerator {
        self.mapped_rev_enumerator(|map| {
            let names = map.0.keys();
            Box::new(names.map(|name| minijinja::Value::from(name.as_str())))
        })
    }

    /// Reports the number of entries in the map.
    fn enumerator_len(self: &Arc<Self>) -> Option<usize> {
        let entry_count = self.0.len();
        Some(entry_count)
    }

    /// Looks up an entry by its string key.
    fn get_value_by_str(self: &Arc<Self>, key: &str) -> Option<minijinja::Value> {
        let found = self.0.get(key)?;
        Some(found.clone())
    }

    /// Looks up an entry by a MiniJinja key; keys that are not strings never match.
    fn get_value(self: &Arc<Self>, key: &minijinja::Value) -> Option<minijinja::Value> {
        key.as_str().and_then(|name| self.0.get(name).cloned())
    }

    /// Rejects every method call with `UnknownMethod`.
    ///
    /// This is deliberate: it lets pycompat handle dict methods through the unknown-method
    /// callback instead of a same-named field being picked up.
    fn call_method(
        self: &Arc<Self>,
        _state: &State<'_, '_>,
        _method: &str,
        _args: &[minijinja::Value],
    ) -> std::result::Result<minijinja::Value, TemplateError> {
        Err(TemplateErrorKind::UnknownMethod.into())
    }
}
+12 -7
View File
@@ -233,7 +233,7 @@ async fn run_roundtrip_tool_call_mix(case: RoundtripCase) -> Result<()> {
"roundtrip-reasoning-tools",
vec![ChatMessage::text(
ChatRole::User,
"Check Shanghai weather and add 1.00 plus 2.",
"Check Shanghai weather and add 1.0 plus 2.",
)],
test_tools(),
Some(true), // always enable thinking in this fixture
@@ -261,9 +261,10 @@ async fn run_roundtrip_tool_call_mix(case: RoundtripCase) -> Result<()> {
AssistantContentBlock::ToolCall(AssistantToolCall {
id: "functions.add:1".to_string(),
name: "add".to_string(),
// Intentionally use a non-lexical order of keys and a different number
// formatting style to verify text-level fidelity of the roundtrip.
arguments: r#"{"y":1.00,"x":2}"#.to_string(),
// Intentionally use a non-lexical order of keys to verify text-level
// fidelity of the roundtrip where JSON formatting remains stable. The
// `items` key also exercises templates that call `arguments.items()`.
arguments: r#"{"y":1.0,"x":2,"items":["left","right"]}"#.to_string(),
}),
],
},
@@ -291,7 +292,7 @@ async fn run_roundtrip_tool_call_mix(case: RoundtripCase) -> Result<()> {
assert_eq!(tool_calls[1].name, "add");
assert_eq!(
tool_calls[1].arguments,
expected_arguments(&case, r#"{"y": 1.00, "x": 2}"#)?,
expected_arguments(&case, r#"{"y": 1.0, "x": 2, "items": ["left", "right"]}"#)?,
);
assert_eq!(
@@ -585,9 +586,13 @@ fn test_tools() -> Vec<ChatTool> {
"type": "object",
"properties": {
"y": { "type": "number" },
"x": { "type": "number" }
"x": { "type": "number" },
"items": {
"type": "array",
"items": { "type": "string" }
}
},
"required": ["y", "x"]
"required": ["y", "x", "items"]
}),
strict: None,
},
+27 -33
View File
@@ -34,26 +34,23 @@ const INTERNLM2_CONFIG: JsonToolCallConfig = JsonToolCallConfig {
/// This Rust port intentionally diverges from
/// `vllm/tool_parsers/internlm2_tool_parser.py` in the following user-visible ways:
///
/// - **Parallel tool calls are supported.** Python silently drops every
/// `<|action_start|>` block after the first (`current_tool_id > 0` returns
/// an empty delta); this parser emits every well-formed block with
/// incrementing `tool_index`. Models that legitimately emit multiple action
/// blocks therefore produce more tool calls under Rust than under Python.
/// - **End-marker bytes inside JSON string values are preserved.** Python
/// does `action.split("<|action_end|>")[0]` which truncates regardless of
/// JSON context; this parser scans matched braces and quotes so a literal
/// `<|action_end|>` inside an arguments string is forwarded intact.
/// - **Parallel tool calls are supported.** Python silently drops every `<|action_start|>` block
/// after the first (`current_tool_id > 0` returns an empty delta); this parser emits every
/// well-formed block with incrementing `tool_index`. Models that legitimately emit multiple
/// action blocks therefore produce more tool calls under Rust than under Python.
/// - **End-marker bytes inside JSON string values are preserved.** Python does
/// `action.split("<|action_end|>")[0]` which truncates regardless of JSON context; this parser
/// scans matched braces and quotes so a literal `<|action_end|>` inside an arguments string is
/// forwarded intact.
/// - **Only whitespace is allowed before the `{`.** Python's non-streaming
/// `action[action.find("{"):]` drops any bytes before the first `{`, but
/// its streaming path has no equivalent and the model format always emits
/// `<|plugin|>{...`; this parser allows only whitespace there, matching the
/// other JSON parsers in this crate.
/// - **Truncated tool calls error rather than silently dropping.** Python's
/// streaming wrapper swallows mid-stream errors with `except Exception:
/// return None` (logging a traceback) while its non-streaming path raises
/// `JSONDecodeError`; this parser returns an `incomplete InternLM2 tool
/// call` error from `finish()`, matching the other JSON parsers and Python's
/// non-streaming behavior.
/// `action[action.find("{"):]` drops any bytes before the first `{`, but its streaming path has
/// no equivalent and the model format always emits `<|plugin|>{...`; this parser allows only
/// whitespace there, matching the other JSON parsers in this crate.
/// - **Truncated tool calls error rather than silently dropping.** Python's streaming wrapper
/// swallows mid-stream errors with `except Exception: return None` (logging a traceback) while
/// its non-streaming path raises `JSONDecodeError`; this parser returns an `incomplete InternLM2
/// tool call` error from `finish()`, matching the other JSON parsers and Python's non-streaming
/// behavior.
///
/// # Known unaddressed divergences (TODO)
///
@@ -63,21 +60,18 @@ const INTERNLM2_CONFIG: JsonToolCallConfig = JsonToolCallConfig {
/// Qwen as well. If a real-world InternLM2 deployment hits one of these,
/// prioritize the corresponding fix.
///
/// - **Arguments value type.** The shared core requires the arguments value
/// to be a JSON object (`take_json_object` rejects anything not starting
/// with `{`). Python's `json.dumps(action_dict.get("parameters", ...))`
/// accepts `null`, arrays, strings, and numbers and round-trips them
/// verbatim. Models that legitimately emit `"parameters":null` will hard-
/// - **Arguments value type.** The shared core requires the arguments value to be a JSON object
/// (`take_json_object` rejects anything not starting with `{`). Python's
/// `json.dumps(action_dict.get("parameters", ...))` accepts `null`, arrays, strings, and numbers
/// and round-trips them verbatim. Models that legitimately emit `"parameters":null` will hard-
/// fail under Rust.
/// - **Unknown arguments key.** Python falls back to `{}` via
/// `action_dict.get("parameters", action_dict.get("arguments", {}))` when
/// neither key is present; the Rust header parser raises
/// `parsing failed: invalid InternLM2` for any unrecognized key. A model
/// that emits a typo (e.g. `"params"`) breaks the whole response.
/// - **Field order independence.** The header parser requires the JSON keys
/// to appear in the order `name` then arguments key. Python's
/// `json.loads` + `dict.get` is order-independent, so a model emitting
/// `{"parameters":{...},"name":"foo"}` parses in Python but fails in Rust.
/// - **Unknown arguments key.** Python falls back to `{}` via `action_dict.get("parameters",
/// action_dict.get("arguments", {}))` when neither key is present; the Rust header parser raises
/// `parsing failed: invalid InternLM2` for any unrecognized key. A model that emits a typo (e.g.
/// `"params"`) breaks the whole response.
/// - **Field order independence.** The header parser requires the JSON keys to appear in the order
/// `name` then arguments key. Python's `json.loads` + `dict.get` is order-independent, so a model
/// emitting `{"parameters":{...},"name":"foo"}` parses in Python but fails in Rust.
pub struct Internlm2ToolParser {
inner: JsonToolCallParser,
}
+12 -6
View File
@@ -501,15 +501,21 @@ mod tests {
assert_eq!(converted_number_text(&params, "5"), "5");
assert_eq!(converted_number_text(&params, "5.0"), "5.0");
assert_eq!(converted_number_text(&params, "5.00"), "5.00");
assert_eq!(converted_number_text(&params, "1e0"), "1e+0");
assert_eq!(converted_number_text(&params, "5."), "5.0");
assert_eq!(converted_number_text(&params, "+1"), "1");
assert_eq!(converted_number_text(&params, "+1.0"), "1.0");
assert_eq!(
converted_number_text(&params, "9223372036854775807.5"),
"9223372036854775807.5"
);
// TODO: we cannot preserve the original number precision by enabling `serde_json`'s
// `arbitrary_precision` feature, otherwise the test
// `serialized_json_numbers_do_not_leak_serde_private_representation` will fail.
// See issue: https://github.com/mitsuhiko/minijinja/issues/641
// assert_eq!(converted_number_text(&params, "5.00"), "5.00");
// assert_eq!(converted_number_text(&params, "1e0"), "1e+0");
// assert_eq!(
// converted_number_text(&params, "9223372036854775807.5"),
// "9223372036854775807.5"
// );
}
fn converted_number_text(params: &ToolSchema, value: &str) -> String {