mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Rust Frontend] Fix several hf chat template rendering issues (#44311)
Signed-off-by: Bugen Zhao <i@bugenzhao.com>
This commit is contained in:
Generated
+1
@@ -5622,6 +5622,7 @@ dependencies = [
|
||||
"expect-test",
|
||||
"futures",
|
||||
"half",
|
||||
"indexmap 2.13.0",
|
||||
"itertools 0.14.0",
|
||||
"llm-multimodal",
|
||||
"minijinja",
|
||||
|
||||
+2
-1
@@ -43,6 +43,7 @@ half = { version = "2.7.1", features = ["bytemuck"] }
|
||||
hex = "0.4.3"
|
||||
hf-hub = { version = "0.5.0", features = ["tokio"] }
|
||||
http-body = "1.0.1"
|
||||
indexmap = "2.13.0"
|
||||
itertools = "0.14.0"
|
||||
libc = "0.2.177"
|
||||
llm-multimodal = { git = "https://github.com/vllm-project/llm-multimodal", rev = "5b558989844d1c7af3e43d0f604069ffd9c06320" }
|
||||
@@ -69,7 +70,7 @@ rustc-hash = "1.1.0"
|
||||
serde = { version = "1.0.228", features = ["derive"] }
|
||||
serde-json-fmt = "0.1.0"
|
||||
serde_default = "0.2.0"
|
||||
serde_json = { version = "1.0.145", features = ["arbitrary_precision", "preserve_order"] }
|
||||
serde_json = { version = "1.0.145", features = ["preserve_order"] }
|
||||
serde_repr = "0.1.20"
|
||||
serde_tuple = "1.1.3"
|
||||
serde_with = "3.18.0"
|
||||
|
||||
@@ -10,6 +10,7 @@ asynk-strim-attr.workspace = true
|
||||
easy-ext.workspace = true
|
||||
futures.workspace = true
|
||||
half.workspace = true
|
||||
indexmap.workspace = true
|
||||
itertools.workspace = true
|
||||
llm-multimodal.workspace = true
|
||||
minijinja.workspace = true
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::Value as JsonValue;
|
||||
use thiserror_ext::AsReport as _;
|
||||
use tracing::{info, trace, warn};
|
||||
use vllm_text::Prompt;
|
||||
@@ -13,6 +13,7 @@ use self::format::{
|
||||
ChatTemplateContentFormat, ChatTemplateContentFormatOption as ContentFormatOption,
|
||||
};
|
||||
use self::template::{CompiledChatTemplate, TemplateContext};
|
||||
use self::value::{TemplateValue, to_template_value};
|
||||
use super::{ChatRenderer, RenderedPrompt};
|
||||
use crate::error::Result;
|
||||
use crate::request::{ChatContent, ChatContentPart, ChatMessage, ChatRequest};
|
||||
@@ -24,6 +25,7 @@ mod error;
|
||||
mod format;
|
||||
mod template;
|
||||
mod tojson;
|
||||
mod value;
|
||||
|
||||
pub use template::{load_chat_template, resolve_chat_template};
|
||||
|
||||
@@ -38,7 +40,7 @@ pub struct MultimodalRenderInfo {
|
||||
/// state.
|
||||
pub struct HfChatRenderer {
|
||||
default_template: Option<CompiledChatTemplate>,
|
||||
default_template_kwargs: HashMap<String, Value>,
|
||||
default_template_kwargs: HashMap<String, JsonValue>,
|
||||
content_format: ContentFormatOption,
|
||||
special_tokens: Option<HfSpecialTokens>,
|
||||
multimodal: Option<MultimodalRenderInfo>,
|
||||
@@ -48,7 +50,7 @@ impl HfChatRenderer {
|
||||
/// Create a renderer from the given template string.
|
||||
pub fn new(
|
||||
template: Option<String>,
|
||||
default_template_kwargs: HashMap<String, Value>,
|
||||
default_template_kwargs: HashMap<String, JsonValue>,
|
||||
content_format: ContentFormatOption,
|
||||
) -> Result<Self> {
|
||||
Ok(Self {
|
||||
@@ -245,7 +247,7 @@ struct TemplateToolCall {
|
||||
#[derive(Debug, Serialize)]
|
||||
struct TemplateToolFunction {
|
||||
name: String,
|
||||
arguments: Value,
|
||||
arguments: TemplateValue,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -259,7 +261,7 @@ pub(super) struct TemplateTool {
|
||||
struct TemplateToolDefinition {
|
||||
name: String,
|
||||
description: Option<String>,
|
||||
parameters: Value,
|
||||
parameters: TemplateValue,
|
||||
strict: Option<bool>,
|
||||
}
|
||||
|
||||
@@ -345,13 +347,14 @@ fn to_template_tool_calls(
|
||||
let mut tool_calls = Vec::new();
|
||||
|
||||
for tool_call in content.tool_calls() {
|
||||
let arguments = serde_json::from_str::<Value>(&tool_call.arguments).map_err(|error| {
|
||||
let arguments = serde_json::from_str(&tool_call.arguments).map_err(|error| {
|
||||
Error::ChatTemplate(format!(
|
||||
"assistant tool call `{}` has invalid JSON arguments: {}",
|
||||
tool_call.id,
|
||||
error.as_report()
|
||||
))
|
||||
})?;
|
||||
let arguments = to_template_value(arguments);
|
||||
|
||||
tool_calls.push(TemplateToolCall {
|
||||
id: tool_call.id.clone(),
|
||||
@@ -434,7 +437,7 @@ fn to_template_tools(tools: &[ChatTool]) -> Vec<TemplateTool> {
|
||||
function: TemplateToolDefinition {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
parameters: tool.parameters.clone(),
|
||||
parameters: to_template_value(tool.parameters.clone()),
|
||||
strict: tool.strict,
|
||||
},
|
||||
})
|
||||
@@ -909,6 +912,29 @@ mod tests {
|
||||
assert_eq!(rendered, "get_weather|Paris|call_1|Sunny");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_template_tool_call_argument_items_method_is_not_shadowed_by_field() {
|
||||
let request = sample_request(vec![ChatMessage::assistant_blocks(vec![
|
||||
AssistantContentBlock::ToolCall(crate::AssistantToolCall {
|
||||
id: "call_1".to_string(),
|
||||
name: "add".to_string(),
|
||||
arguments: r#"{"items":"operands","x":2,"y":1.0}"#.to_string(),
|
||||
}),
|
||||
])]);
|
||||
|
||||
let rendered = render(
|
||||
Some(
|
||||
"{%- set arguments = messages[0].tool_calls[0].function.arguments -%}
|
||||
{%- for key, value in arguments.items() -%}{{ key }}={{ value }};{%- endfor -%}
|
||||
|{{ arguments['items'] }}",
|
||||
),
|
||||
&request,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(rendered, "items=operands;x=2;y=1.0;|operands");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn qwen35_template_renders_prefilled_reasoning_start_when_thinking_enabled() {
|
||||
let mut request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]);
|
||||
|
||||
@@ -208,11 +208,27 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tojson_preserves_arbitrary_precision_number_spelling() {
|
||||
fn tojson_uses_standard_serde_json_number_spelling() {
|
||||
let payload = serde_json::from_str(r#"{"x":2,"y":1.00}"#).unwrap();
|
||||
let rendered = render("{{ payload|tojson }}", payload);
|
||||
|
||||
assert_eq!(rendered, "{\"x\": 2, \"y\": 1.00}");
|
||||
// TODO: we cannot preserve the original number precision by enabling `serde_json`'s
|
||||
// `arbitrary_precision` feature, otherwise the following test
|
||||
// `serialized_json_numbers_do_not_leak_serde_private_representation` will fail.
|
||||
// See issue: https://github.com/mitsuhiko/minijinja/issues/641
|
||||
assert_eq!(rendered, "{\"x\": 2, \"y\": 1.0}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialized_json_numbers_do_not_leak_serde_private_representation() {
|
||||
let payload: serde_json::Value = serde_json::from_str(r#"{"x":2,"y":1.00}"#).unwrap();
|
||||
let rendered = render("{{ payload }}", payload);
|
||||
|
||||
// TODO: we cannot preserve the original number precision by enabling `serde_json`'s
|
||||
// `arbitrary_precision` feature, otherwise this will fail.
|
||||
// See issue: https://github.com/mitsuhiko/minijinja/issues/641
|
||||
assert!(!rendered.contains("$serde_json::private::Number"));
|
||||
assert_eq!(rendered, r#"{"x": 2, "y": 1.0}"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use indexmap::IndexMap;
|
||||
use minijinja::value::{Enumerator, Object, ObjectExt, ObjectRepr};
|
||||
use minijinja::{Error as TemplateError, ErrorKind as TemplateErrorKind, State};
|
||||
use serde::Serialize;
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
/// A wrapper around `minijinja::Value` that can be constructed with `to_template_value` and used
|
||||
/// as a value in the chat template.
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(transparent)]
|
||||
pub(super) struct TemplateValue(minijinja::Value);
|
||||
|
||||
pub(super) fn to_template_value(value: JsonValue) -> TemplateValue {
|
||||
TemplateValue(match value {
|
||||
JsonValue::Array(values) => values
|
||||
.into_iter()
|
||||
.map(to_template_value)
|
||||
.map(|value| value.0)
|
||||
.collect::<minijinja::Value>(),
|
||||
JsonValue::Object(values) => minijinja::Value::from_object(TemplateMap(
|
||||
values
|
||||
.into_iter()
|
||||
.map(|(key, value)| (key, to_template_value(value).0))
|
||||
.collect(),
|
||||
)),
|
||||
// For primitive values, directly convert them to `minijinja::Value` using `from_serialize`.
|
||||
value => minijinja::Value::from_serialize(value),
|
||||
})
|
||||
}
|
||||
|
||||
/// A custom map type that always returns `UnknownMethod` for method calls, so that pycompat can
|
||||
/// always handle dict methods through the unknown-method callback.
|
||||
///
|
||||
/// Use `IndexMap` to preserve the original key order when iterating.
|
||||
///
|
||||
/// MiniJinja's default map can resolve a same-named field before Python dict methods. HF templates
|
||||
/// commonly call `dict.items()`, which would fail if the map had an `items` field.
|
||||
/// See issue: https://github.com/mitsuhiko/minijinja/issues/903
|
||||
#[derive(Debug)]
|
||||
struct TemplateMap(IndexMap<String, minijinja::Value>);
|
||||
|
||||
impl Object for TemplateMap {
|
||||
fn repr(self: &Arc<Self>) -> ObjectRepr {
|
||||
ObjectRepr::Map
|
||||
}
|
||||
|
||||
fn get_value(self: &Arc<Self>, key: &minijinja::Value) -> Option<minijinja::Value> {
|
||||
self.0.get(key.as_str()?).cloned()
|
||||
}
|
||||
|
||||
fn get_value_by_str(self: &Arc<Self>, key: &str) -> Option<minijinja::Value> {
|
||||
self.0.get(key).cloned()
|
||||
}
|
||||
|
||||
fn enumerate(self: &Arc<Self>) -> Enumerator {
|
||||
self.mapped_rev_enumerator(|this| {
|
||||
Box::new(this.0.keys().map(|key| minijinja::Value::from(key.as_str())))
|
||||
})
|
||||
}
|
||||
|
||||
fn enumerator_len(self: &Arc<Self>) -> Option<usize> {
|
||||
Some(self.0.len())
|
||||
}
|
||||
|
||||
fn call_method(
|
||||
self: &Arc<Self>,
|
||||
_state: &State<'_, '_>,
|
||||
_method: &str,
|
||||
_args: &[minijinja::Value],
|
||||
) -> std::result::Result<minijinja::Value, TemplateError> {
|
||||
// Always return `UnknownMethod` for method calls,
|
||||
// so that pycompat can handle dict methods through the unknown-method callback.
|
||||
Err(TemplateError::from(TemplateErrorKind::UnknownMethod))
|
||||
}
|
||||
}
|
||||
@@ -233,7 +233,7 @@ async fn run_roundtrip_tool_call_mix(case: RoundtripCase) -> Result<()> {
|
||||
"roundtrip-reasoning-tools",
|
||||
vec![ChatMessage::text(
|
||||
ChatRole::User,
|
||||
"Check Shanghai weather and add 1.00 plus 2.",
|
||||
"Check Shanghai weather and add 1.0 plus 2.",
|
||||
)],
|
||||
test_tools(),
|
||||
Some(true), // always enable thinking in this fixture
|
||||
@@ -261,9 +261,10 @@ async fn run_roundtrip_tool_call_mix(case: RoundtripCase) -> Result<()> {
|
||||
AssistantContentBlock::ToolCall(AssistantToolCall {
|
||||
id: "functions.add:1".to_string(),
|
||||
name: "add".to_string(),
|
||||
// Intentionally use a non-lexical order of keys and a different number
|
||||
// formatting style to verify text-level fidelity of the roundtrip.
|
||||
arguments: r#"{"y":1.00,"x":2}"#.to_string(),
|
||||
// Intentionally use a non-lexical order of keys to verify text-level
|
||||
// fidelity of the roundtrip where JSON formatting remains stable. The
|
||||
// `items` key also exercises templates that call `arguments.items()`.
|
||||
arguments: r#"{"y":1.0,"x":2,"items":["left","right"]}"#.to_string(),
|
||||
}),
|
||||
],
|
||||
},
|
||||
@@ -291,7 +292,7 @@ async fn run_roundtrip_tool_call_mix(case: RoundtripCase) -> Result<()> {
|
||||
assert_eq!(tool_calls[1].name, "add");
|
||||
assert_eq!(
|
||||
tool_calls[1].arguments,
|
||||
expected_arguments(&case, r#"{"y": 1.00, "x": 2}"#)?,
|
||||
expected_arguments(&case, r#"{"y": 1.0, "x": 2, "items": ["left", "right"]}"#)?,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
@@ -585,9 +586,13 @@ fn test_tools() -> Vec<ChatTool> {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"y": { "type": "number" },
|
||||
"x": { "type": "number" }
|
||||
"x": { "type": "number" },
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" }
|
||||
}
|
||||
},
|
||||
"required": ["y", "x"]
|
||||
"required": ["y", "x", "items"]
|
||||
}),
|
||||
strict: None,
|
||||
},
|
||||
|
||||
@@ -34,26 +34,23 @@ const INTERNLM2_CONFIG: JsonToolCallConfig = JsonToolCallConfig {
|
||||
/// This Rust port intentionally diverges from
|
||||
/// `vllm/tool_parsers/internlm2_tool_parser.py` in two user-visible ways:
|
||||
///
|
||||
/// - **Parallel tool calls are supported.** Python silently drops every
|
||||
/// `<|action_start|>` block after the first (`current_tool_id > 0` returns
|
||||
/// an empty delta); this parser emits every well-formed block with
|
||||
/// incrementing `tool_index`. Models that legitimately emit multiple action
|
||||
/// blocks therefore produce more tool calls under Rust than under Python.
|
||||
/// - **End-marker bytes inside JSON string values are preserved.** Python
|
||||
/// does `action.split("<|action_end|>")[0]` which truncates regardless of
|
||||
/// JSON context; this parser scans matched braces and quotes so a literal
|
||||
/// `<|action_end|>` inside an arguments string is forwarded intact.
|
||||
/// - **Parallel tool calls are supported.** Python silently drops every `<|action_start|>` block
|
||||
/// after the first (`current_tool_id > 0` returns an empty delta); this parser emits every
|
||||
/// well-formed block with incrementing `tool_index`. Models that legitimately emit multiple
|
||||
/// action blocks therefore produce more tool calls under Rust than under Python.
|
||||
/// - **End-marker bytes inside JSON string values are preserved.** Python does
|
||||
/// `action.split("<|action_end|>")[0]` which truncates regardless of JSON context; this parser
|
||||
/// scans matched braces and quotes so a literal `<|action_end|>` inside an arguments string is
|
||||
/// forwarded intact.
|
||||
/// - **Only whitespace is allowed before the `{`.** Python's non-streaming
|
||||
/// `action[action.find("{"):]` drops any bytes before the first `{`, but
|
||||
/// its streaming path has no equivalent and the model format always emits
|
||||
/// `<|plugin|>{...`; this parser allows only whitespace there, matching the
|
||||
/// other JSON parsers in this crate.
|
||||
/// - **Truncated tool calls error rather than silently dropping.** Python's
|
||||
/// streaming wrapper swallows mid-stream errors with `except Exception:
|
||||
/// return None` (logging a traceback) while its non-streaming path raises
|
||||
/// `JSONDecodeError`; this parser returns an `incomplete InternLM2 tool
|
||||
/// call` error from `finish()`, matching the other JSON parsers and Python's
|
||||
/// non-streaming behavior.
|
||||
/// `action[action.find("{"):]` drops any bytes before the first `{`, but its streaming path has
|
||||
/// no equivalent and the model format always emits `<|plugin|>{...`; this parser allows only
|
||||
/// whitespace there, matching the other JSON parsers in this crate.
|
||||
/// - **Truncated tool calls error rather than silently dropping.** Python's streaming wrapper
|
||||
/// swallows mid-stream errors with `except Exception: return None` (logging a traceback) while
|
||||
/// its non-streaming path raises `JSONDecodeError`; this parser returns an `incomplete InternLM2
|
||||
/// tool call` error from `finish()`, matching the other JSON parsers and Python's non-streaming
|
||||
/// behavior.
|
||||
///
|
||||
/// # Known unaddressed divergences (TODO)
|
||||
///
|
||||
@@ -63,21 +60,18 @@ const INTERNLM2_CONFIG: JsonToolCallConfig = JsonToolCallConfig {
|
||||
/// Qwen as well. If a real-world InternLM2 deployment hits one of these,
|
||||
/// prioritize the corresponding fix.
|
||||
///
|
||||
/// - **Arguments value type.** The shared core requires the arguments value
|
||||
/// to be a JSON object (`take_json_object` rejects anything not starting
|
||||
/// with `{`). Python's `json.dumps(action_dict.get("parameters", ...))`
|
||||
/// accepts `null`, arrays, strings, and numbers and round-trips them
|
||||
/// verbatim. Models that legitimately emit `"parameters":null` will hard-
|
||||
/// - **Arguments value type.** The shared core requires the arguments value to be a JSON object
|
||||
/// (`take_json_object` rejects anything not starting with `{`). Python's
|
||||
/// `json.dumps(action_dict.get("parameters", ...))` accepts `null`, arrays, strings, and numbers
|
||||
/// and round-trips them verbatim. Models that legitimately emit `"parameters":null` will hard-
|
||||
/// fail under Rust.
|
||||
/// - **Unknown arguments key.** Python falls back to `{}` via
|
||||
/// `action_dict.get("parameters", action_dict.get("arguments", {}))` when
|
||||
/// neither key is present; the Rust header parser raises
|
||||
/// `parsing failed: invalid InternLM2` for any unrecognized key. A model
|
||||
/// that emits a typo (e.g. `"params"`) breaks the whole response.
|
||||
/// - **Field order independence.** The header parser requires the JSON keys
|
||||
/// to appear in the order `name` then arguments key. Python's
|
||||
/// `json.loads` + `dict.get` is order-independent, so a model emitting
|
||||
/// `{"parameters":{...},"name":"foo"}` parses in Python but fails in Rust.
|
||||
/// - **Unknown arguments key.** Python falls back to `{}` via `action_dict.get("parameters",
|
||||
/// action_dict.get("arguments", {}))` when neither key is present; the Rust header parser raises
|
||||
/// `parsing failed: invalid InternLM2` for any unrecognized key. A model that emits a typo (e.g.
|
||||
/// `"params"`) breaks the whole response.
|
||||
/// - **Field order independence.** The header parser requires the JSON keys to appear in the order
|
||||
/// `name` then arguments key. Python's `json.loads` + `dict.get` is order-independent, so a model
|
||||
/// emitting `{"parameters":{...},"name":"foo"}` parses in Python but fails in Rust.
|
||||
pub struct Internlm2ToolParser {
|
||||
inner: JsonToolCallParser,
|
||||
}
|
||||
|
||||
@@ -501,15 +501,21 @@ mod tests {
|
||||
|
||||
assert_eq!(converted_number_text(¶ms, "5"), "5");
|
||||
assert_eq!(converted_number_text(¶ms, "5.0"), "5.0");
|
||||
assert_eq!(converted_number_text(¶ms, "5.00"), "5.00");
|
||||
assert_eq!(converted_number_text(¶ms, "1e0"), "1e+0");
|
||||
assert_eq!(converted_number_text(¶ms, "5."), "5.0");
|
||||
assert_eq!(converted_number_text(¶ms, "+1"), "1");
|
||||
assert_eq!(converted_number_text(¶ms, "+1.0"), "1.0");
|
||||
assert_eq!(
|
||||
converted_number_text(¶ms, "9223372036854775807.5"),
|
||||
"9223372036854775807.5"
|
||||
);
|
||||
|
||||
// TODO: we cannot preserve the original number precision by enabling `serde_json`'s
|
||||
// `arbitrary_precision` feature, otherwise the test
|
||||
// `serialized_json_numbers_do_not_leak_serde_private_representation` will fail.
|
||||
// See issue: https://github.com/mitsuhiko/minijinja/issues/641
|
||||
|
||||
// assert_eq!(converted_number_text(¶ms, "5.00"), "5.00");
|
||||
// assert_eq!(converted_number_text(¶ms, "1e0"), "1e+0");
|
||||
// assert_eq!(
|
||||
// converted_number_text(¶ms, "9223372036854775807.5"),
|
||||
// "9223372036854775807.5"
|
||||
// );
|
||||
}
|
||||
|
||||
fn converted_number_text(params: &ToolSchema, value: &str) -> String {
|
||||
|
||||
Reference in New Issue
Block a user