From 39910f2b25aacc09f5e7f166cdf0030b19f8b9e8 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Fri, 22 May 2026 08:21:48 +0800 Subject: [PATCH] [Rust Frontend] Move code from `vllm-frontend-rs` (#43283) Signed-off-by: Bugen Zhao Signed-off-by: Nick Hill Signed-off-by: Eric Curtin Signed-off-by: Dev-X25874 <283057883+Dev-X25874@users.noreply.github.com> Signed-off-by: Will.hou <1205157517@qq.com> Signed-off-by: Will.hou Co-authored-by: Nick Hill Co-authored-by: Eric Curtin Co-authored-by: Dev-X25874 <283057883+Dev-X25874@users.noreply.github.com> Co-authored-by: Will.hou <1205157517@qq.com> Co-authored-by: Will.hou Please see https://github.com/Inferact/vllm-frontend-rs for full original commit history. --- .buildkite/image_build/image_build.sh | 7 - .buildkite/image_build/image_build_cpu.sh | 6 - .../image_build/image_build_cpu_arm64.sh | 6 - .../scripts/run-rust-frontend-cargo-ci.sh | 156 + .buildkite/test_areas/rust_frontend.yaml | 2 +- .../test_areas/rust_frontend_cargo.yaml | 30 + .gitmodules | 3 - .pre-commit-config.yaml | 26 + build_rust.sh | 5 +- docker/Dockerfile | 16 +- docker/Dockerfile.cpu | 16 +- docker/Dockerfile.nightly_torch | 16 +- docker/Dockerfile.rocm | 19 +- docker/Dockerfile.xpu | 16 +- pyproject.toml | 5 +- requirements/xpu.txt | 1 + rust | 1 - rust/.gitattributes | 2 + rust/.gitignore | 3 + rust/AGENTS.md | 37 + rust/CLAUDE.md | 4 + rust/Cargo.lock | 6611 +++++++++++++++++ rust/Cargo.toml | 129 + rust/README.md | 89 + rust/proto/vllm_grpc.proto | 196 + rust/rustfmt.toml | 3 + rust/rustfmt.unstable.toml | 20 + rust/src/chat/Cargo.toml | 52 + rust/src/chat/examples/README.md | 39 + .../examples/external_engine_chat_qwen.rs | 178 + rust/src/chat/src/backend/hf.rs | 308 + rust/src/chat/src/backend/mod.rs | 86 + rust/src/chat/src/error.rs | 80 + rust/src/chat/src/event.rs | 183 + rust/src/chat/src/lib.rs | 249 + rust/src/chat/src/multimodal.rs | 775 ++ rust/src/chat/src/multimodal/tensor.rs | 342 + rust/src/chat/src/output/default/mod.rs | 166 + rust/src/chat/src/output/default/reasoning.rs | 504 ++ rust/src/chat/src/output/default/tool.rs | 625 ++ rust/src/chat/src/output/harmony/mod.rs | 430 ++ rust/src/chat/src/output/harmony/tests.rs | 351 + rust/src/chat/src/output/mod.rs | 135 + rust/src/chat/src/output/structured.rs | 508 ++ rust/src/chat/src/parser/mod.rs | 107 + rust/src/chat/src/parser/reasoning/mod.rs | 120 + rust/src/chat/src/parser/reasoning/tests.rs | 61 + rust/src/chat/src/parser/tool/mod.rs | 140 + rust/src/chat/src/parser/tool/tests.rs | 152 + .../src/renderer/deepseek_v32/encoding.rs | 555 ++ .../deepseek_v32/fixtures/test_input.json | 149 + .../fixtures/test_input_search_w_date.json | 732 ++ .../fixtures/test_input_search_wo_date.json | 533 ++ .../fixtures/test_output_search_w_date.txt | 2455 ++++++ .../fixtures/test_output_search_wo_date.txt | 1069 +++ .../fixtures/test_output_vllm_parity.txt | 112 + .../src/chat/src/renderer/deepseek_v32/mod.rs | 31 + .../chat/src/renderer/deepseek_v32/tests.rs | 422 ++ .../chat/src/renderer/deepseek_v4/encoding.rs | 558 ++ .../deepseek_v4/fixtures/test_input_1.json | 81 + .../deepseek_v4/fixtures/test_input_2.json | 24 + .../deepseek_v4/fixtures/test_output_1.txt | 36 + .../deepseek_v4/fixtures/test_output_2.txt | 1 + rust/src/chat/src/renderer/deepseek_v4/mod.rs | 30 + .../chat/src/renderer/deepseek_v4/tests.rs | 369 + rust/src/chat/src/renderer/hf/error.rs | 15 + rust/src/chat/src/renderer/hf/format.rs | 400 + rust/src/chat/src/renderer/hf/mod.rs | 970 +++ rust/src/chat/src/renderer/hf/template.rs | 316 + rust/src/chat/src/renderer/hf/tojson.rs | 277 + rust/src/chat/src/renderer/mod.rs | 31 + rust/src/chat/src/renderer/selection.rs | 109 + rust/src/chat/src/request.rs | 662 ++ rust/src/chat/src/stream.rs | 237 + rust/src/chat/tests/chat.rs | 1671 +++++ rust/src/chat/tests/templates/qwen3.jinja | 89 + rust/src/chat/tests/templates/qwen35.jinja | 154 + .../tests/templates/vllm_examples/README.md | 6 + .../vllm_examples/template_alpaca.jinja | 29 + .../vllm_examples/template_baichuan.jinja | 13 + .../vllm_examples/template_chatglm.jinja | 18 + .../vllm_examples/template_chatglm2.jinja | 18 + .../vllm_examples/template_chatml.jinja | 2 + .../vllm_examples/template_falcon.jinja | 15 + .../vllm_examples/template_falcon_180b.jinja | 17 + .../vllm_examples/template_inkbot.jinja | 30 + .../vllm_examples/template_teleflm.jinja | 12 + .../tool_chat_template_deepseekr1.jinja | 92 + .../tool_chat_template_deepseekv3.jinja | 96 + .../tool_chat_template_deepseekv31.jinja | 91 + .../tool_chat_template_functiongemma.jinja | 54 + .../tool_chat_template_gemma3_pythonic.jinja | 123 + .../tool_chat_template_gemma4.jinja | 331 + .../tool_chat_template_glm4.jinja | 54 + .../tool_chat_template_granite.jinja | 36 + .../tool_chat_template_granite_20b_fc.jinja | 130 + .../tool_chat_template_hermes.jinja | 130 + .../tool_chat_template_hunyuan_a13b.jinja | 113 + .../tool_chat_template_internlm2_tool.jinja | 60 + .../tool_chat_template_llama3.1_json.jinja | 120 + .../tool_chat_template_llama3.2_json.jinja | 133 + ...tool_chat_template_llama3.2_pythonic.jinja | 98 + .../tool_chat_template_llama4_json.jinja | 116 + .../tool_chat_template_llama4_pythonic.jinja | 111 + .../tool_chat_template_minimax_m1.jinja | 91 + .../tool_chat_template_mistral.jinja | 86 + .../tool_chat_template_mistral3.jinja | 126 + .../tool_chat_template_mistral_parallel.jinja | 93 + .../tool_chat_template_phi4_mini.jinja | 62 + .../tool_chat_template_qwen3coder.jinja | 117 + .../tool_chat_template_toolace.jinja | 65 + .../tool_chat_template_xlam_llama.jinja | 77 + .../tool_chat_template_xlam_qwen.jinja | 66 + rust/src/cmd/Cargo.toml | 39 + rust/src/cmd/examples/README.md | 44 + rust/src/cmd/src/cli.rs | 434 ++ rust/src/cmd/src/cli/tests.rs | 905 +++ rust/src/cmd/src/cli/unsupported.rs | 660 ++ rust/src/cmd/src/logging.rs | 328 + rust/src/cmd/src/main.rs | 189 + rust/src/engine-core-client/Cargo.toml | 49 + .../src/engine-core-client/examples/README.md | 53 + .../examples/external_engine_logprobs.rs | 190 + .../examples/external_engine_utility_call.rs | 143 + rust/src/engine-core-client/src/client.rs | 673 ++ rust/src/engine-core-client/src/client/imp.rs | 410 + .../engine-core-client/src/client/state.rs | 670 ++ .../engine-core-client/src/client/stream.rs | 153 + .../src/coordinator/bootstrap.rs | 91 + .../src/coordinator/external.rs | 167 + .../src/coordinator/handle.rs | 123 + .../src/coordinator/inproc.rs | 204 + .../engine-core-client/src/coordinator/mod.rs | 7 + rust/src/engine-core-client/src/error.rs | 88 + rust/src/engine-core-client/src/lib.rs | 18 + rust/src/engine-core-client/src/metrics.rs | 131 + .../src/protocol/classified_outputs.rs | 251 + .../engine-core-client/src/protocol/dtype.rs | 40 + .../src/protocol/handshake.rs | 90 + .../src/protocol/logprobs.rs | 330 + .../src/protocol/logprobs/array.rs | 306 + .../src/protocol/logprobs/tests.rs | 302 + .../src/protocol/logprobs/wire.rs | 31 + .../engine-core-client/src/protocol/mod.rs | 756 ++ .../src/protocol/multimodal.rs | 282 + .../engine-core-client/src/protocol/stats.rs | 192 + .../engine-core-client/src/protocol/tensor.rs | 255 + rust/src/engine-core-client/src/test_utils.rs | 295 + .../engine-core-client/src/tests/client.rs | 2559 +++++++ rust/src/engine-core-client/src/tests/mod.rs | 1 + .../src/tests/python_compat.py | 356 + rust/src/engine-core-client/src/transport.rs | 606 ++ rust/src/llm/Cargo.toml | 38 + rust/src/llm/examples/README.md | 29 + .../src/llm/examples/external_engine_smoke.rs | 144 + rust/src/llm/src/error.rs | 12 + rust/src/llm/src/lib.rs | 100 + rust/src/llm/src/log_stats.rs | 199 + rust/src/llm/src/output.rs | 346 + rust/src/llm/src/request.rs | 201 + rust/src/llm/src/request_metrics.rs | 391 + rust/src/llm/tests/generate.rs | 749 ++ rust/src/managed-engine/Cargo.toml | 18 + rust/src/managed-engine/src/cli.rs | 349 + rust/src/managed-engine/src/lib.rs | 4 + rust/src/managed-engine/src/process.rs | 263 + rust/src/metrics/Cargo.toml | 11 + rust/src/metrics/src/api_server.rs | 78 + rust/src/metrics/src/lib.rs | 76 + rust/src/metrics/src/request.rs | 298 + rust/src/metrics/src/scheduler.rs | 272 + rust/src/reasoning-parser/Cargo.toml | 12 + rust/src/reasoning-parser/src/cohere_cmd.rs | 45 + rust/src/reasoning-parser/src/deepseek_r1.rs | 44 + rust/src/reasoning-parser/src/delimited.rs | 161 + rust/src/reasoning-parser/src/gemma4.rs | 273 + rust/src/reasoning-parser/src/kimi.rs | 39 + rust/src/reasoning-parser/src/lib.rs | 129 + rust/src/reasoning-parser/src/qwen3.rs | 43 + rust/src/reasoning-parser/src/tests.rs | 161 + rust/src/server/Cargo.toml | 57 + rust/src/server/build.rs | 12 + rust/src/server/examples/README.md | 39 + .../examples/external_engine_openai_qwen.rs | 215 + rust/src/server/src/config.rs | 110 + rust/src/server/src/error.rs | 74 + rust/src/server/src/grpc/convert.rs | 679 ++ rust/src/server/src/grpc/mod.rs | 158 + rust/src/server/src/grpc/tests.rs | 662 ++ rust/src/server/src/lib.rs | 237 + rust/src/server/src/listener.rs | 135 + rust/src/server/src/middleware/load.rs | 110 + rust/src/server/src/middleware/metrics.rs | 79 + rust/src/server/src/middleware/mod.rs | 5 + rust/src/server/src/routes.rs | 68 + rust/src/server/src/routes/cache.rs | 56 + rust/src/server/src/routes/collective_rpc.rs | 51 + rust/src/server/src/routes/health.rs | 14 + .../server/src/routes/http_client_tests.rs | 392 + .../server/src/routes/inference/generate.rs | 215 + .../src/routes/inference/generate/convert.rs | 112 + .../src/routes/inference/generate/types.rs | 57 + .../src/routes/inference/generate/validate.rs | 84 + rust/src/server/src/routes/inference/mod.rs | 3 + rust/src/server/src/routes/load.rs | 18 + rust/src/server/src/routes/metrics.rs | 26 + .../src/routes/openai/chat_completions.rs | 1046 +++ .../routes/openai/chat_completions/convert.rs | 992 +++ .../routes/openai/chat_completions/types.rs | 600 ++ .../openai/chat_completions/validate.rs | 410 + .../server/src/routes/openai/completions.rs | 571 ++ .../src/routes/openai/completions/convert.rs | 352 + .../src/routes/openai/completions/types.rs | 253 + .../src/routes/openai/completions/validate.rs | 178 + rust/src/server/src/routes/openai/mod.rs | 8 + rust/src/server/src/routes/openai/models.rs | 24 + .../src/routes/openai/utils/logprobs.rs | 256 + .../src/server/src/routes/openai/utils/mod.rs | 4 + .../routes/openai/utils/structured_outputs.rs | 132 + .../server/src/routes/openai/utils/types.rs | 425 ++ .../src/routes/openai/utils/validated_json.rs | 55 + rust/src/server/src/routes/sleep.rs | 78 + rust/src/server/src/routes/tests.rs | 3679 +++++++++ rust/src/server/src/state.rs | 120 + rust/src/server/src/utils.rs | 106 + rust/src/text/Cargo.toml | 34 + rust/src/text/src/backend/hf/config.rs | 373 + rust/src/text/src/backend/hf/mod.rs | 120 + rust/src/text/src/backend/hf/model_files.rs | 459 ++ rust/src/text/src/backend/mod.rs | 47 + rust/src/text/src/error.rs | 30 + rust/src/text/src/lib.rs | 149 + rust/src/text/src/lower.rs | 739 ++ rust/src/text/src/output/decoded.rs | 607 ++ rust/src/text/src/output/logprobs.rs | 236 + rust/src/text/src/output/mod.rs | 323 + rust/src/text/src/request.rs | 197 + rust/src/tokenizer/Cargo.toml | 35 + rust/src/tokenizer/benches/hf.rs | 118 + rust/src/tokenizer/benches/tiktoken.rs | 117 + rust/src/tokenizer/src/byte_level_decode.rs | 119 + rust/src/tokenizer/src/error.rs | 9 + rust/src/tokenizer/src/hf.rs | 344 + rust/src/tokenizer/src/incremental.rs | 359 + rust/src/tokenizer/src/lib.rs | 62 + rust/src/tokenizer/src/tekken.rs | 62 + rust/src/tokenizer/src/tiktoken.rs | 1013 +++ rust/src/tool-parser/Cargo.toml | 75 + rust/src/tool-parser/benches/deepseek_v3.rs | 130 + rust/src/tool-parser/benches/deepseek_v31.rs | 128 + rust/src/tool-parser/benches/deepseek_v32.rs | 113 + rust/src/tool-parser/benches/gemma4.rs | 125 + rust/src/tool-parser/benches/glm45_moe.rs | 210 + rust/src/tool-parser/benches/kimi_k2.rs | 149 + rust/src/tool-parser/benches/llama3_json.rs | 128 + rust/src/tool-parser/benches/minimax_m2.rs | 136 + rust/src/tool-parser/benches/qwen3_coder.rs | 138 + rust/src/tool-parser/benches/qwen3_xml.rs | 125 + rust/src/tool-parser/benches/utils/mod.rs | 49 + .../src/deepseek_dsml/deepseek_v32.rs | 439 ++ .../src/deepseek_dsml/deepseek_v4.rs | 133 + rust/src/tool-parser/src/deepseek_dsml/mod.rs | 278 + .../src/deepseek_json/deepseek_v3.rs | 235 + .../src/deepseek_json/deepseek_v31.rs | 239 + rust/src/tool-parser/src/deepseek_json/mod.rs | 308 + rust/src/tool-parser/src/error.rs | 13 + rust/src/tool-parser/src/gemma4.rs | 653 ++ rust/src/tool-parser/src/glm_xml/glm45_moe.rs | 43 + rust/src/tool-parser/src/glm_xml/glm47_moe.rs | 135 + rust/src/tool-parser/src/glm_xml/mod.rs | 435 ++ rust/src/tool-parser/src/json/hermes.rs | 221 + rust/src/tool-parser/src/json/llama.rs | 487 ++ rust/src/tool-parser/src/json/mistral.rs | 240 + rust/src/tool-parser/src/json/mod.rs | 465 ++ rust/src/tool-parser/src/json/qwen.rs | 279 + rust/src/tool-parser/src/kimi_k2.rs | 560 ++ rust/src/tool-parser/src/lib.rs | 141 + rust/src/tool-parser/src/minimax_m2.rs | 519 ++ rust/src/tool-parser/src/parameters.rs | 508 ++ rust/src/tool-parser/src/qwen_coder.rs | 632 ++ rust/src/tool-parser/src/test_utils.rs | 115 + rust/src/tool-parser/src/tests.rs | 97 + rust/src/tool-parser/src/utils.rs | 581 ++ tools/pre_commit/rust-check.sh | 47 + 284 files changed, 72569 insertions(+), 90 deletions(-) create mode 100755 .buildkite/scripts/run-rust-frontend-cargo-ci.sh create mode 100644 .buildkite/test_areas/rust_frontend_cargo.yaml delete mode 100644 .gitmodules delete mode 160000 rust create mode 100644 rust/.gitattributes create mode 100644 rust/.gitignore create mode 100644 rust/AGENTS.md create mode 100644 rust/CLAUDE.md create mode 100644 rust/Cargo.lock create mode 100644 rust/Cargo.toml create mode 100644 rust/README.md create mode 100644 rust/proto/vllm_grpc.proto create mode 100644 rust/rustfmt.toml create mode 100644 rust/rustfmt.unstable.toml create mode 100644 rust/src/chat/Cargo.toml create mode 100644 rust/src/chat/examples/README.md create mode 100644 rust/src/chat/examples/external_engine_chat_qwen.rs create mode 100644 rust/src/chat/src/backend/hf.rs create mode 100644 rust/src/chat/src/backend/mod.rs create mode 100644 rust/src/chat/src/error.rs create mode 100644 rust/src/chat/src/event.rs create mode 100644 rust/src/chat/src/lib.rs create mode 100644 rust/src/chat/src/multimodal.rs create mode 100644 rust/src/chat/src/multimodal/tensor.rs create mode 100644 rust/src/chat/src/output/default/mod.rs create mode 100644 rust/src/chat/src/output/default/reasoning.rs create mode 100644 rust/src/chat/src/output/default/tool.rs create mode 100644 rust/src/chat/src/output/harmony/mod.rs create mode 100644 rust/src/chat/src/output/harmony/tests.rs create mode 100644 rust/src/chat/src/output/mod.rs create mode 100644 rust/src/chat/src/output/structured.rs create mode 100644 rust/src/chat/src/parser/mod.rs create mode 100644 rust/src/chat/src/parser/reasoning/mod.rs create mode 100644 rust/src/chat/src/parser/reasoning/tests.rs create mode 100644 rust/src/chat/src/parser/tool/mod.rs create mode 100644 rust/src/chat/src/parser/tool/tests.rs create mode 100644 rust/src/chat/src/renderer/deepseek_v32/encoding.rs create mode 100644 rust/src/chat/src/renderer/deepseek_v32/fixtures/test_input.json create mode 100644 rust/src/chat/src/renderer/deepseek_v32/fixtures/test_input_search_w_date.json create mode 100644 rust/src/chat/src/renderer/deepseek_v32/fixtures/test_input_search_wo_date.json create mode 100644 rust/src/chat/src/renderer/deepseek_v32/fixtures/test_output_search_w_date.txt create mode 100644 rust/src/chat/src/renderer/deepseek_v32/fixtures/test_output_search_wo_date.txt create mode 100644 rust/src/chat/src/renderer/deepseek_v32/fixtures/test_output_vllm_parity.txt create mode 100644 rust/src/chat/src/renderer/deepseek_v32/mod.rs create mode 100644 rust/src/chat/src/renderer/deepseek_v32/tests.rs create mode 100644 rust/src/chat/src/renderer/deepseek_v4/encoding.rs create mode 100644 rust/src/chat/src/renderer/deepseek_v4/fixtures/test_input_1.json create mode 100644 rust/src/chat/src/renderer/deepseek_v4/fixtures/test_input_2.json create mode 100644 rust/src/chat/src/renderer/deepseek_v4/fixtures/test_output_1.txt create mode 100644 rust/src/chat/src/renderer/deepseek_v4/fixtures/test_output_2.txt create mode 100644 rust/src/chat/src/renderer/deepseek_v4/mod.rs create mode 100644 rust/src/chat/src/renderer/deepseek_v4/tests.rs create mode 100644 rust/src/chat/src/renderer/hf/error.rs create mode 100644 rust/src/chat/src/renderer/hf/format.rs create mode 100644 rust/src/chat/src/renderer/hf/mod.rs create mode 100644 rust/src/chat/src/renderer/hf/template.rs create mode 100644 rust/src/chat/src/renderer/hf/tojson.rs create mode 100644 rust/src/chat/src/renderer/mod.rs create mode 100644 rust/src/chat/src/renderer/selection.rs create mode 100644 rust/src/chat/src/request.rs create mode 100644 rust/src/chat/src/stream.rs create mode 100644 rust/src/chat/tests/chat.rs create mode 100644 rust/src/chat/tests/templates/qwen3.jinja create mode 100644 rust/src/chat/tests/templates/qwen35.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/README.md create mode 100644 rust/src/chat/tests/templates/vllm_examples/template_alpaca.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/template_baichuan.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/template_chatglm.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/template_chatglm2.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/template_chatml.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/template_falcon.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/template_falcon_180b.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/template_inkbot.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/template_teleflm.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_deepseekr1.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_deepseekv3.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_deepseekv31.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_functiongemma.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_gemma3_pythonic.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_gemma4.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_glm4.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_granite.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_granite_20b_fc.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_hermes.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_hunyuan_a13b.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_internlm2_tool.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama3.1_json.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama3.2_json.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama3.2_pythonic.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama4_json.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama4_pythonic.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_minimax_m1.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_mistral.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_mistral3.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_mistral_parallel.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_phi4_mini.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_qwen3coder.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_toolace.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_xlam_llama.jinja create mode 100644 rust/src/chat/tests/templates/vllm_examples/tool_chat_template_xlam_qwen.jinja create mode 100644 rust/src/cmd/Cargo.toml create mode 100644 rust/src/cmd/examples/README.md create mode 100644 rust/src/cmd/src/cli.rs create mode 100644 rust/src/cmd/src/cli/tests.rs create mode 100644 rust/src/cmd/src/cli/unsupported.rs create mode 100644 rust/src/cmd/src/logging.rs create mode 100644 rust/src/cmd/src/main.rs create mode 100644 rust/src/engine-core-client/Cargo.toml create mode 100644 rust/src/engine-core-client/examples/README.md create mode 100644 rust/src/engine-core-client/examples/external_engine_logprobs.rs create mode 100644 rust/src/engine-core-client/examples/external_engine_utility_call.rs create mode 100644 rust/src/engine-core-client/src/client.rs create mode 100644 rust/src/engine-core-client/src/client/imp.rs create mode 100644 rust/src/engine-core-client/src/client/state.rs create mode 100644 rust/src/engine-core-client/src/client/stream.rs create mode 100644 rust/src/engine-core-client/src/coordinator/bootstrap.rs create mode 100644 rust/src/engine-core-client/src/coordinator/external.rs create mode 100644 rust/src/engine-core-client/src/coordinator/handle.rs create mode 100644 rust/src/engine-core-client/src/coordinator/inproc.rs create mode 100644 rust/src/engine-core-client/src/coordinator/mod.rs create mode 100644 rust/src/engine-core-client/src/error.rs create mode 100644 rust/src/engine-core-client/src/lib.rs create mode 100644 rust/src/engine-core-client/src/metrics.rs create mode 100644 rust/src/engine-core-client/src/protocol/classified_outputs.rs create mode 100644 rust/src/engine-core-client/src/protocol/dtype.rs create mode 100644 rust/src/engine-core-client/src/protocol/handshake.rs create mode 100644 rust/src/engine-core-client/src/protocol/logprobs.rs create mode 100644 rust/src/engine-core-client/src/protocol/logprobs/array.rs create mode 100644 rust/src/engine-core-client/src/protocol/logprobs/tests.rs create mode 100644 rust/src/engine-core-client/src/protocol/logprobs/wire.rs create mode 100644 rust/src/engine-core-client/src/protocol/mod.rs create mode 100644 rust/src/engine-core-client/src/protocol/multimodal.rs create mode 100644 rust/src/engine-core-client/src/protocol/stats.rs create mode 100644 rust/src/engine-core-client/src/protocol/tensor.rs create mode 100644 rust/src/engine-core-client/src/test_utils.rs create mode 100644 rust/src/engine-core-client/src/tests/client.rs create mode 100644 rust/src/engine-core-client/src/tests/mod.rs create mode 100755 rust/src/engine-core-client/src/tests/python_compat.py create mode 100644 rust/src/engine-core-client/src/transport.rs create mode 100644 rust/src/llm/Cargo.toml create mode 100644 rust/src/llm/examples/README.md create mode 100644 rust/src/llm/examples/external_engine_smoke.rs create mode 100644 rust/src/llm/src/error.rs create mode 100644 rust/src/llm/src/lib.rs create mode 100644 rust/src/llm/src/log_stats.rs create mode 100644 rust/src/llm/src/output.rs create mode 100644 rust/src/llm/src/request.rs create mode 100644 rust/src/llm/src/request_metrics.rs create mode 100644 rust/src/llm/tests/generate.rs create mode 100644 rust/src/managed-engine/Cargo.toml create mode 100644 rust/src/managed-engine/src/cli.rs create mode 100644 rust/src/managed-engine/src/lib.rs create mode 100644 rust/src/managed-engine/src/process.rs create mode 100644 rust/src/metrics/Cargo.toml create mode 100644 rust/src/metrics/src/api_server.rs create mode 100644 rust/src/metrics/src/lib.rs create mode 100644 rust/src/metrics/src/request.rs create mode 100644 rust/src/metrics/src/scheduler.rs create mode 100644 rust/src/reasoning-parser/Cargo.toml create mode 100644 rust/src/reasoning-parser/src/cohere_cmd.rs create mode 100644 rust/src/reasoning-parser/src/deepseek_r1.rs create mode 100644 rust/src/reasoning-parser/src/delimited.rs create mode 100644 rust/src/reasoning-parser/src/gemma4.rs create mode 100644 rust/src/reasoning-parser/src/kimi.rs create mode 100644 rust/src/reasoning-parser/src/lib.rs create mode 100644 rust/src/reasoning-parser/src/qwen3.rs create mode 100644 rust/src/reasoning-parser/src/tests.rs create mode 100644 rust/src/server/Cargo.toml create mode 100644 rust/src/server/build.rs create mode 100644 rust/src/server/examples/README.md create mode 100644 rust/src/server/examples/external_engine_openai_qwen.rs create mode 100644 rust/src/server/src/config.rs create mode 100644 rust/src/server/src/error.rs create mode 100644 rust/src/server/src/grpc/convert.rs create mode 100644 rust/src/server/src/grpc/mod.rs create mode 100644 rust/src/server/src/grpc/tests.rs create mode 100644 rust/src/server/src/lib.rs create mode 100644 rust/src/server/src/listener.rs create mode 100644 rust/src/server/src/middleware/load.rs create mode 100644 rust/src/server/src/middleware/metrics.rs create mode 100644 rust/src/server/src/middleware/mod.rs create mode 100644 rust/src/server/src/routes.rs create mode 100644 rust/src/server/src/routes/cache.rs create mode 100644 rust/src/server/src/routes/collective_rpc.rs create mode 100644 rust/src/server/src/routes/health.rs create mode 100644 rust/src/server/src/routes/http_client_tests.rs create mode 100644 rust/src/server/src/routes/inference/generate.rs create mode 100644 rust/src/server/src/routes/inference/generate/convert.rs create mode 100644 rust/src/server/src/routes/inference/generate/types.rs create mode 100644 rust/src/server/src/routes/inference/generate/validate.rs create mode 100644 rust/src/server/src/routes/inference/mod.rs create mode 100644 rust/src/server/src/routes/load.rs create mode 100644 rust/src/server/src/routes/metrics.rs create mode 100644 rust/src/server/src/routes/openai/chat_completions.rs create mode 100644 rust/src/server/src/routes/openai/chat_completions/convert.rs create mode 100644 rust/src/server/src/routes/openai/chat_completions/types.rs create mode 100644 rust/src/server/src/routes/openai/chat_completions/validate.rs create mode 100644 rust/src/server/src/routes/openai/completions.rs create mode 100644 rust/src/server/src/routes/openai/completions/convert.rs create mode 100644 rust/src/server/src/routes/openai/completions/types.rs create mode 100644 rust/src/server/src/routes/openai/completions/validate.rs create mode 100644 rust/src/server/src/routes/openai/mod.rs create mode 100644 rust/src/server/src/routes/openai/models.rs create mode 100644 rust/src/server/src/routes/openai/utils/logprobs.rs create mode 100644 rust/src/server/src/routes/openai/utils/mod.rs create mode 100644 rust/src/server/src/routes/openai/utils/structured_outputs.rs create mode 100644 rust/src/server/src/routes/openai/utils/types.rs create mode 100644 rust/src/server/src/routes/openai/utils/validated_json.rs create mode 100644 rust/src/server/src/routes/sleep.rs create mode 100644 rust/src/server/src/routes/tests.rs create mode 100644 rust/src/server/src/state.rs create mode 100644 rust/src/server/src/utils.rs create mode 100644 rust/src/text/Cargo.toml create mode 100644 rust/src/text/src/backend/hf/config.rs create mode 100644 rust/src/text/src/backend/hf/mod.rs create mode 100644 rust/src/text/src/backend/hf/model_files.rs create mode 100644 rust/src/text/src/backend/mod.rs create mode 100644 rust/src/text/src/error.rs create mode 100644 rust/src/text/src/lib.rs create mode 100644 rust/src/text/src/lower.rs create mode 100644 rust/src/text/src/output/decoded.rs create mode 100644 rust/src/text/src/output/logprobs.rs create mode 100644 rust/src/text/src/output/mod.rs create mode 100644 rust/src/text/src/request.rs create mode 100644 rust/src/tokenizer/Cargo.toml create mode 100644 rust/src/tokenizer/benches/hf.rs create mode 100644 rust/src/tokenizer/benches/tiktoken.rs create mode 100644 rust/src/tokenizer/src/byte_level_decode.rs create mode 100644 rust/src/tokenizer/src/error.rs create mode 100644 rust/src/tokenizer/src/hf.rs create mode 100644 rust/src/tokenizer/src/incremental.rs create mode 100644 rust/src/tokenizer/src/lib.rs create mode 100644 rust/src/tokenizer/src/tekken.rs create mode 100644 rust/src/tokenizer/src/tiktoken.rs create mode 100644 rust/src/tool-parser/Cargo.toml create mode 100644 rust/src/tool-parser/benches/deepseek_v3.rs create mode 100644 rust/src/tool-parser/benches/deepseek_v31.rs create mode 100644 rust/src/tool-parser/benches/deepseek_v32.rs create mode 100644 rust/src/tool-parser/benches/gemma4.rs create mode 100644 rust/src/tool-parser/benches/glm45_moe.rs create mode 100644 rust/src/tool-parser/benches/kimi_k2.rs create mode 100644 rust/src/tool-parser/benches/llama3_json.rs create mode 100644 rust/src/tool-parser/benches/minimax_m2.rs create mode 100644 rust/src/tool-parser/benches/qwen3_coder.rs create mode 100644 rust/src/tool-parser/benches/qwen3_xml.rs create mode 100644 rust/src/tool-parser/benches/utils/mod.rs create mode 100644 rust/src/tool-parser/src/deepseek_dsml/deepseek_v32.rs create mode 100644 rust/src/tool-parser/src/deepseek_dsml/deepseek_v4.rs create mode 100644 rust/src/tool-parser/src/deepseek_dsml/mod.rs create mode 100644 rust/src/tool-parser/src/deepseek_json/deepseek_v3.rs create mode 100644 rust/src/tool-parser/src/deepseek_json/deepseek_v31.rs create mode 100644 rust/src/tool-parser/src/deepseek_json/mod.rs create mode 100644 rust/src/tool-parser/src/error.rs create mode 100644 rust/src/tool-parser/src/gemma4.rs create mode 100644 rust/src/tool-parser/src/glm_xml/glm45_moe.rs create mode 100644 rust/src/tool-parser/src/glm_xml/glm47_moe.rs create mode 100644 rust/src/tool-parser/src/glm_xml/mod.rs create mode 100644 rust/src/tool-parser/src/json/hermes.rs create mode 100644 rust/src/tool-parser/src/json/llama.rs create mode 100644 rust/src/tool-parser/src/json/mistral.rs create mode 100644 rust/src/tool-parser/src/json/mod.rs create mode 100644 rust/src/tool-parser/src/json/qwen.rs create mode 100644 rust/src/tool-parser/src/kimi_k2.rs create mode 100644 rust/src/tool-parser/src/lib.rs create mode 100644 rust/src/tool-parser/src/minimax_m2.rs create mode 100644 rust/src/tool-parser/src/parameters.rs create mode 100644 rust/src/tool-parser/src/qwen_coder.rs create mode 100644 rust/src/tool-parser/src/test_utils.rs create mode 100644 rust/src/tool-parser/src/tests.rs create mode 100644 rust/src/tool-parser/src/utils.rs create mode 100755 tools/pre_commit/rust-check.sh diff --git a/.buildkite/image_build/image_build.sh b/.buildkite/image_build/image_build.sh index f93e86d9ec0..10c03c3e177 100755 --- a/.buildkite/image_build/image_build.sh +++ b/.buildkite/image_build/image_build.sh @@ -223,13 +223,6 @@ echo "CACHE_FROM_MAIN: ${CACHE_FROM_MAIN}" check_and_skip_if_image_exists -# The rust frontend lives in a git submodule under rust/. Buildkite's default -# checkout does not recurse submodules, and the Dockerfile only sees what's in -# the build context, so initialize the submodule here before invoking bake. -echo "--- :git: Initializing git submodules" -git submodule sync --recursive -git submodule update --init --recursive - echo "--- :docker: Setting up Docker buildx bake" echo "Target: ${TARGET}" echo "vLLM bake file: ${VLLM_BAKE_FILE_PATH}" diff --git a/.buildkite/image_build/image_build_cpu.sh b/.buildkite/image_build/image_build_cpu.sh index a8f76d5b265..035f070ab89 100755 --- a/.buildkite/image_build/image_build_cpu.sh +++ b/.buildkite/image_build/image_build_cpu.sh @@ -21,12 +21,6 @@ else exit 0 fi -# The rust frontend lives in a git submodule under rust/. Buildkite's default -# checkout does not recurse submodules, and the Dockerfile only sees what's in -# the build context, so initialize the submodule here before building. -git submodule sync --recursive -git submodule update --init --recursive - # build docker build --file docker/Dockerfile.cpu \ --build-arg max_jobs=16 \ diff --git a/.buildkite/image_build/image_build_cpu_arm64.sh b/.buildkite/image_build/image_build_cpu_arm64.sh index 9ee0b353aa8..b561e2c2e46 100755 --- a/.buildkite/image_build/image_build_cpu_arm64.sh +++ b/.buildkite/image_build/image_build_cpu_arm64.sh @@ -21,12 +21,6 @@ else exit 0 fi -# The rust frontend lives in a git submodule under rust/. Buildkite's default -# checkout does not recurse submodules, and the Dockerfile only sees what's in -# the build context, so initialize the submodule here before building. -git submodule sync --recursive -git submodule update --init --recursive - # build docker build --file docker/Dockerfile.cpu \ --build-arg max_jobs=16 \ diff --git a/.buildkite/scripts/run-rust-frontend-cargo-ci.sh b/.buildkite/scripts/run-rust-frontend-cargo-ci.sh new file mode 100755 index 00000000000..6ce9b5200c4 --- /dev/null +++ b/.buildkite/scripts/run-rust-frontend-cargo-ci.sh @@ -0,0 +1,156 @@ +#!/usr/bin/env bash +set -euo pipefail + +MODE="${1:-}" + +if [[ "$MODE" != "style-clippy" && "$MODE" != "test" ]]; then + echo "Usage: $0 {style-clippy|test}" >&2 + exit 2 +fi + +ROOT_DIR="$(git rev-parse --show-toplevel)" +cd "$ROOT_DIR" + +export CARGO_TERM_COLOR="${CARGO_TERM_COLOR:-always}" +export CARGO_HOME="${CARGO_HOME:-$HOME/.cargo}" +export RUSTUP_HOME="${RUSTUP_HOME:-$HOME/.rustup}" +export PATH="$CARGO_HOME/bin:$PATH" + +log_section() { + echo "--- $*" +} + +install_protoc() { + if command -v protoc >/dev/null 2>&1; then + return + fi + + local version="${PROTOC_VERSION:-31.1}" + local arch + case "$(uname -m)" in + x86_64) + arch="x86_64" + ;; + aarch64|arm64) + arch="aarch_64" + ;; + *) + echo "Unsupported protoc architecture: $(uname -m)" >&2 + return 1 + ;; + esac + + local url="https://github.com/protocolbuffers/protobuf/releases/download/v${version}/protoc-${version}-linux-${arch}.zip" + local tmp_dir + tmp_dir="$(mktemp -d)" + + log_section "Installing protoc ${version}" + curl -L --proto '=https' --tlsv1.2 -sSf "$url" -o "$tmp_dir/protoc.zip" + mkdir -p "$CARGO_HOME/bin" + unzip -q "$tmp_dir/protoc.zip" bin/protoc 'include/*' -d "$CARGO_HOME" + chmod +x "$CARGO_HOME/bin/protoc" + rm -rf "$tmp_dir" +} + +rust_toolchain() { + awk -F '"' '/channel[[:space:]]*=/ { print $2; exit }' rust-toolchain.toml +} + +install_rust_toolchain() { + log_section "Installing Rust toolchain" + if ! command -v rustup >/dev/null 2>&1; then + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs \ + | sh -s -- -y --profile minimal --default-toolchain none + fi + + local toolchain + toolchain="$(rust_toolchain)" + rustup toolchain install "$toolchain" --profile minimal --component rustfmt,clippy + rustup component add --toolchain "$toolchain" rustfmt clippy +} + +install_cargo_binstall() { + if command -v cargo-binstall >/dev/null 2>&1; then + return + fi + + log_section "Installing cargo-binstall" + curl -L --proto '=https' --tlsv1.2 -sSf \ + https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh \ + | bash +} + +install_cargo_sort() { + if command -v cargo-sort >/dev/null 2>&1; then + return + fi + + log_section "Installing cargo-sort" + install_cargo_binstall + cargo binstall --no-confirm cargo-sort +} + +install_cargo_nextest() { + if command -v cargo-nextest >/dev/null 2>&1; then + return + fi + + log_section "Installing cargo-nextest" + install_cargo_binstall + cargo binstall --no-confirm --secure cargo-nextest +} + +install_uv() { + if command -v uv >/dev/null 2>&1; then + return + fi + + log_section "Installing uv" + curl -LsSf --proto '=https' --tlsv1.2 https://astral.sh/uv/install.sh \ + | env UV_INSTALL_DIR="$CARGO_HOME/bin" sh +} + +run_style_clippy() { + install_cargo_sort + + log_section "Checking Rust formatting" + cargo fmt --manifest-path rust/Cargo.toml --all -- --check + + log_section "Checking Cargo.toml ordering" + cargo sort --workspace --check rust + + log_section "Running clippy" + cargo clippy \ + --manifest-path rust/Cargo.toml \ + --workspace \ + --all-targets \ + --all-features \ + --locked \ + -- \ + -D warnings +} + +run_tests() { + install_uv + install_cargo_nextest + + log_section "Running cargo nextest" + cargo nextest run \ + --manifest-path rust/Cargo.toml \ + --workspace \ + --all-features \ + --locked \ + --no-fail-fast +} + +install_protoc +install_rust_toolchain + +case "$MODE" in + style-clippy) + run_style_clippy + ;; + test) + run_tests + ;; +esac diff --git a/.buildkite/test_areas/rust_frontend.yaml b/.buildkite/test_areas/rust_frontend.yaml index f0d2c2f9fb9..f750d58be58 100644 --- a/.buildkite/test_areas/rust_frontend.yaml +++ b/.buildkite/test_areas/rust_frontend.yaml @@ -1,4 +1,4 @@ -group: Rust Frontend +group: Rust Frontend E2E depends_on: - image-build steps: diff --git a/.buildkite/test_areas/rust_frontend_cargo.yaml b/.buildkite/test_areas/rust_frontend_cargo.yaml new file mode 100644 index 00000000000..06f9eb9c245 --- /dev/null +++ b/.buildkite/test_areas/rust_frontend_cargo.yaml @@ -0,0 +1,30 @@ +group: Rust Frontend Cargo +depends_on: [] +steps: +- label: Rust Frontend Cargo Style + Clippy + key: rust-frontend-cargo-style-clippy + depends_on: [] + timeout_in_minutes: 30 + device: cpu-medium + no_plugin: true + source_file_dependencies: + - rust/ + - rust-toolchain.toml + - .buildkite/test_areas/rust_frontend_cargo.yaml + - .buildkite/scripts/run-rust-frontend-cargo-ci.sh + commands: + - .buildkite/scripts/run-rust-frontend-cargo-ci.sh style-clippy + +- label: Rust Frontend Cargo Tests + key: rust-frontend-cargo-tests + depends_on: [] + timeout_in_minutes: 30 + device: cpu-medium + no_plugin: true + source_file_dependencies: + - rust/ + - rust-toolchain.toml + - .buildkite/test_areas/rust_frontend_cargo.yaml + - .buildkite/scripts/run-rust-frontend-cargo-ci.sh + commands: + - .buildkite/scripts/run-rust-frontend-cargo-ci.sh test diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 79d557ecd69..00000000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "rust"] - path = rust - url = https://github.com/Inferact/vllm-frontend-rs.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f1e0afebf21..c6658ff735e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -256,6 +256,32 @@ repos: entry: python tools/pre_commit/check_boolean_context_manager.py language: python types: [python] + # Rust hooks. These shell out to `cargo`; tools/pre_commit/rust-check.sh + # skips with a warning when cargo is not installed. + - id: rust-cargo-autoinherit + name: Rust - Normalize Cargo manifests with autoinherit + entry: tools/pre_commit/rust-check.sh autoinherit --prefer-simple-dotted + language: script + pass_filenames: false + require_serial: true + stages: [pre-commit] # Only run locally as Buildkite will cover this + files: ^rust/(Cargo\.toml|src/.*/Cargo\.toml)$ + - id: rust-cargo-sort + name: Rust - Sort Cargo manifest sections + entry: tools/pre_commit/rust-check.sh sort --workspace + language: script + pass_filenames: false + require_serial: true + stages: [pre-commit] # Only run locally as Buildkite will cover this + files: ^rust/(Cargo\.toml|src/.*/Cargo\.toml)$ + - id: rust-cargo-fmt + name: Rust - Format code + entry: tools/pre_commit/rust-check.sh fmt + language: script + pass_filenames: false + require_serial: true + stages: [pre-commit] # Only run locally as Buildkite will cover this + files: ^rust/.*(\.rs|Cargo\.toml|rustfmt\.toml)$ # Keep `suggestion` last - id: suggestion name: Suggestion diff --git a/build_rust.sh b/build_rust.sh index fb4a589de4c..98871ec8abc 100755 --- a/build_rust.sh +++ b/build_rust.sh @@ -9,10 +9,10 @@ set -euo pipefail REPO_ROOT="$(cd "$(dirname "$0")" && pwd)" RUST_DIR="$REPO_ROOT/rust" -TARGET_PATH="$REPO_ROOT/vllm/vllm-rs" +TARGET_PATH="${VLLM_RS_TARGET_PATH:-$REPO_ROOT/vllm/vllm-rs}" # Read the required toolchain from rust-toolchain.toml. -TOOLCHAIN=$(grep '^channel' "$RUST_DIR/rust-toolchain.toml" | sed 's/.*= *"\(.*\)"/\1/') +TOOLCHAIN=$(grep '^channel' "$REPO_ROOT/rust-toolchain.toml" | sed 's/.*= *"\(.*\)"/\1/') # Ensure rustup and the required toolchain are available. if ! command -v rustup &>/dev/null; then @@ -39,5 +39,6 @@ cargo +"$TOOLCHAIN" build "${PROFILE_ARGS[@]}" \ --bin vllm-rs \ --features native-tls-vendored +mkdir -p "$(dirname "$TARGET_PATH")" cp "$RUST_DIR/target/$PROFILE_DIR/vllm-rs" "$TARGET_PATH" echo "Installed vllm-rs to $TARGET_PATH" diff --git a/docker/Dockerfile b/docker/Dockerfile index 7929ce8e92d..6b6c4bdfba5 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -256,7 +256,7 @@ RUN if [ "${BUILD_OS}" = "manylinux" ]; then \ COPY tools/install_protoc.sh /tmp/install_protoc.sh RUN /tmp/install_protoc.sh && rm /tmp/install_protoc.sh -# Install rustup; the toolchain itself is pinned by rust/rust-toolchain.toml. +# Install rustup; the toolchain itself is pinned by rust-toolchain.toml. RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | \ sh -s -- -y --profile minimal --default-toolchain none ENV PATH="/root/.cargo/bin:${PATH}" @@ -265,14 +265,8 @@ WORKDIR /workspace # Copy only the rust workspace — the binary is the sole artifact we need. COPY rust rust - -# Fail loudly if the rust submodule was not initialized on the host before -# `docker build`. Without this check, cargo would emit a confusing error. -RUN if [ ! -f rust/Cargo.toml ]; then \ - echo "ERROR: rust/ submodule is not initialized."; \ - echo "Run 'git submodule update --init --recursive' on the host before building."; \ - exit 1; \ - fi +COPY rust-toolchain.toml rust-toolchain.toml +COPY build_rust.sh build_rust.sh # Cap cargo parallelism to avoid exhausting the CI host's open-file limit # (rustc spawns enough concurrent processes to hit RLIMIT_NOFILE otherwise). @@ -284,9 +278,7 @@ ENV CARGO_BUILD_JOBS=4 RUN --mount=type=cache,target=/root/.cargo/registry \ --mount=type=cache,target=/root/.cargo/git \ --mount=type=cache,target=/workspace/rust/target \ - cd rust \ - && cargo build --release --bin vllm-rs --features native-tls-vendored \ - && cp target/release/vllm-rs /workspace/vllm-rs + VLLM_RS_TARGET_PATH=/workspace/vllm-rs bash build_rust.sh #################### RUST BUILD IMAGE #################### #################### CSRC BUILD IMAGE #################### diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 86a076976d3..8571a9c5d6d 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -96,7 +96,7 @@ RUN apt-get update -y \ COPY tools/install_protoc.sh /tmp/install_protoc.sh RUN /tmp/install_protoc.sh && rm /tmp/install_protoc.sh -# Install rustup; the toolchain itself is pinned by rust/rust-toolchain.toml. +# Install rustup; the toolchain itself is pinned by rust-toolchain.toml. RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | \ sh -s -- -y --profile minimal --default-toolchain none ENV PATH="/root/.cargo/bin:${PATH}" @@ -105,14 +105,8 @@ WORKDIR /workspace # Copy only the rust workspace — the binary is the sole artifact we need. COPY rust rust - -# Fail loudly if the rust submodule was not initialized on the host before -# `docker build`. Without this check, cargo would emit a confusing error. -RUN if [ ! -f rust/Cargo.toml ]; then \ - echo "ERROR: rust/ submodule is not initialized."; \ - echo "Run 'git submodule update --init --recursive' on the host before building."; \ - exit 1; \ - fi +COPY rust-toolchain.toml rust-toolchain.toml +COPY build_rust.sh build_rust.sh # Cap cargo parallelism to avoid exhausting the CI host's open-file limit # (rustc spawns enough concurrent processes to hit RLIMIT_NOFILE otherwise). @@ -124,9 +118,7 @@ ENV CARGO_BUILD_JOBS=4 RUN --mount=type=cache,target=/root/.cargo/registry \ --mount=type=cache,target=/root/.cargo/git \ --mount=type=cache,target=/workspace/rust/target \ - cd rust \ - && cargo build --release --bin vllm-rs --features native-tls-vendored \ - && cp target/release/vllm-rs /workspace/vllm-rs + VLLM_RS_TARGET_PATH=/workspace/vllm-rs bash build_rust.sh ######################### BUILD IMAGE ######################### FROM base AS vllm-build diff --git a/docker/Dockerfile.nightly_torch b/docker/Dockerfile.nightly_torch index 3aa328a5fc3..0d5a9cc5f83 100644 --- a/docker/Dockerfile.nightly_torch +++ b/docker/Dockerfile.nightly_torch @@ -108,7 +108,7 @@ RUN apt-get update -y \ COPY tools/install_protoc.sh /tmp/install_protoc.sh RUN /tmp/install_protoc.sh && rm /tmp/install_protoc.sh -# Install rustup; the toolchain itself is pinned by rust/rust-toolchain.toml. +# Install rustup; the toolchain itself is pinned by rust-toolchain.toml. RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | \ sh -s -- -y --profile minimal --default-toolchain none ENV PATH="/root/.cargo/bin:${PATH}" @@ -116,14 +116,8 @@ ENV PATH="/root/.cargo/bin:${PATH}" WORKDIR /workspace COPY rust rust - -# Fail loudly if the rust submodule was not initialized on the host before -# `docker build`. -RUN if [ ! -f rust/Cargo.toml ]; then \ - echo "ERROR: rust/ submodule is not initialized."; \ - echo "Run 'git submodule update --init --recursive' on the host before building."; \ - exit 1; \ - fi +COPY rust-toolchain.toml rust-toolchain.toml +COPY build_rust.sh build_rust.sh # Cap cargo parallelism to avoid exhausting the CI host's open-file limit # (rustc spawns enough concurrent processes to hit RLIMIT_NOFILE otherwise). @@ -132,9 +126,7 @@ ENV CARGO_BUILD_JOBS=4 RUN --mount=type=cache,target=/root/.cargo/registry \ --mount=type=cache,target=/root/.cargo/git \ --mount=type=cache,target=/workspace/rust/target \ - cd rust \ - && cargo build --release --bin vllm-rs --features native-tls-vendored \ - && cp target/release/vllm-rs /workspace/vllm-rs + VLLM_RS_TARGET_PATH=/workspace/vllm-rs bash build_rust.sh #################### RUST BUILD IMAGE #################### #################### WHEEL BUILD IMAGE #################### diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index a391c192c80..b2342200a68 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -106,7 +106,6 @@ ONBUILD RUN git clone ${VLLM_REPO} \ && cd vllm \ && git fetch -v --prune -- origin ${VLLM_BRANCH} \ && git checkout FETCH_HEAD \ - && git submodule update --init --recursive \ && if [ ${VLLM_REPO} != "https://github.com/vllm-project/vllm.git" ] ; then \ git remote add upstream "https://github.com/vllm-project/vllm.git" \ && git fetch upstream ; fi @@ -120,16 +119,7 @@ FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm FROM fetch_vllm AS rust-build ARG COMMON_WORKDIR -# Fail loudly if the rust submodule was not initialized on the host before -# `docker build`. The rust frontend source is brought in via the fetch_vllm -# stage, so an uninitialized submodule would otherwise produce a confusing -# cargo failure. -RUN if [ ! -f ${COMMON_WORKDIR}/vllm/rust/Cargo.toml ]; then \ - echo "ERROR: rust/ submodule is not initialized."; \ - echo "Run 'git submodule update --init --recursive' on the host before building."; \ - exit 1; \ - fi - +# protoc is used by tonic-build/prost-build. RUN apt-get update -q -y && apt-get install -q -y --no-install-recommends \ ca-certificates curl unzip \ && rm -rf /var/lib/apt/lists/* @@ -137,7 +127,7 @@ RUN apt-get update -q -y && apt-get install -q -y --no-install-recommends \ COPY tools/install_protoc.sh /tmp/install_protoc.sh RUN /tmp/install_protoc.sh && rm /tmp/install_protoc.sh -# Install rustup; the toolchain itself is pinned by rust/rust-toolchain.toml. +# Install rustup; the toolchain itself is pinned by rust-toolchain.toml. RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | \ sh -s -- -y --profile minimal --default-toolchain none ENV PATH="/root/.cargo/bin:${PATH}" @@ -150,9 +140,8 @@ ENV CARGO_BUILD_JOBS=4 # so it persists into the image layer for later COPY --from=rust-build. RUN --mount=type=cache,target=/root/.cargo/registry \ --mount=type=cache,target=/root/.cargo/git \ - cd ${COMMON_WORKDIR}/vllm/rust \ - && cargo build --release --bin vllm-rs --features native-tls-vendored \ - && cp target/release/vllm-rs /tmp/vllm-rs + cd ${COMMON_WORKDIR}/vllm \ + && VLLM_RS_TARGET_PATH=/tmp/vllm-rs bash build_rust.sh # ----------------------- # vLLM build stages diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index 5432f85feb0..ef05b4aa2e5 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -12,7 +12,7 @@ RUN apt-get update -y \ COPY tools/install_protoc.sh /tmp/install_protoc.sh RUN /tmp/install_protoc.sh && rm /tmp/install_protoc.sh -# Install rustup; the toolchain itself is pinned by rust/rust-toolchain.toml. +# Install rustup; the toolchain itself is pinned by rust-toolchain.toml. RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | \ sh -s -- -y --profile minimal --default-toolchain none ENV PATH="/root/.cargo/bin:${PATH}" @@ -20,14 +20,8 @@ ENV PATH="/root/.cargo/bin:${PATH}" WORKDIR /workspace COPY rust rust - -# Fail loudly if the rust submodule was not initialized on the host before -# `docker build`. -RUN if [ ! -f rust/Cargo.toml ]; then \ - echo "ERROR: rust/ submodule is not initialized."; \ - echo "Run 'git submodule update --init --recursive' on the host before building."; \ - exit 1; \ - fi +COPY rust-toolchain.toml rust-toolchain.toml +COPY build_rust.sh build_rust.sh # Cap cargo parallelism to avoid exhausting the CI host's open-file limit # (rustc spawns enough concurrent processes to hit RLIMIT_NOFILE otherwise). @@ -36,9 +30,7 @@ ENV CARGO_BUILD_JOBS=4 RUN --mount=type=cache,target=/root/.cargo/registry \ --mount=type=cache,target=/root/.cargo/git \ --mount=type=cache,target=/workspace/rust/target \ - cd rust \ - && cargo build --release --bin vllm-rs --features native-tls-vendored \ - && cp target/release/vllm-rs /workspace/vllm-rs + VLLM_RS_TARGET_PATH=/workspace/vllm-rs bash build_rust.sh FROM intel/deep-learning-essentials:2025.3.2-0-devel-ubuntu24.04 AS vllm-base diff --git a/pyproject.toml b/pyproject.toml index ebf4c51af91..faba7f3df3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,7 +127,10 @@ extend-exclude = ["tests/models/fixtures/*", "tests/prompts/*", "tests/tokenizer "vllm/third_party/*", "vllm/entrypoints/serve/instrumentator/static/*", "tests/entrypoints/speech_to_text/transcription/test_transcription_validation.py", "docs/governance/process.md", "docs/assets/contributing/vllm_bench_serve_timeline.html", - "tests/v1/engine/test_fast_incdec_prefix_err.py", ".git/*", "csrc/cpu/sgl-kernels/*"] + "tests/v1/engine/test_fast_incdec_prefix_err.py", ".git/*", "csrc/cpu/sgl-kernels/*", + "rust/src/chat/src/renderer/deepseek_v32/fixtures/*", + "rust/src/tool-parser/src/gemma4.rs", "rust/src/text/src/output/decoded.rs", + "rust/src/tokenizer/src/incremental.rs", "rust/src/reasoning-parser/src/tests.rs"] ignore-hidden = false [tool.typos.default] diff --git a/requirements/xpu.txt b/requirements/xpu.txt index a25d4e5e0f8..105b8cec372 100644 --- a/requirements/xpu.txt +++ b/requirements/xpu.txt @@ -5,6 +5,7 @@ ray>=2.9 cmake>=3.26.1 packaging>=24.2 setuptools-scm>=8 +setuptools-rust>=1.9.0 setuptools>=77.0.3,<81.0.0 setuptools-rust>=1.9.0 wheel diff --git a/rust b/rust deleted file mode 160000 index ad6771ac093..00000000000 --- a/rust +++ /dev/null @@ -1 +0,0 @@ -Subproject commit ad6771ac093420632253ca43bfc81007beb2d6ce diff --git a/rust/.gitattributes b/rust/.gitattributes new file mode 100644 index 00000000000..2f7173592d7 --- /dev/null +++ b/rust/.gitattributes @@ -0,0 +1,2 @@ +src/chat/src/renderer/deepseek_v32/fixtures/** linguist-generated=true +src/chat/tests/templates/** linguist-vendored=true diff --git a/rust/.gitignore b/rust/.gitignore new file mode 100644 index 00000000000..14f70eb7d43 --- /dev/null +++ b/rust/.gitignore @@ -0,0 +1,3 @@ +/target +AGENTS.override.md +.vscode diff --git a/rust/AGENTS.md b/rust/AGENTS.md new file mode 100644 index 00000000000..37117c7490b --- /dev/null +++ b/rust/AGENTS.md @@ -0,0 +1,37 @@ +# Alternative Frontend to vLLM Engine in Rust + +This project aims to implement an alternative frontend to the vLLM Engine in Rust, providing a more efficient and robust interface for interacting with the engine. Currently it's still in the very early stage and is actively evolving. + +## Coding Styles + +- Always use workspace dependencies for Cargo crates. +- Prefer splitting code into multiple smaller modules and files for better organization and readability, rather than putting everything in a single file. +- When refactoring or reconstructing code, always preserve the original comments and documentation VERBATIM, if applicable. +- If not specified, default to writing concise Rust documentation and comments that match the style of the existing codebase when generating code. +- When migrating code from Python or any other language, preserve the original documentation comments whenever they still make sense in the Rust code. +- Although you might be asked to only implement or migrate minimal functionality at the beginning, you should still leave necessary `TODO` comments in the code for the future improvements of the lacked features, so that it's easier for the next iteration to build upon the existing codebase. +- When writing parsers with `winnow`: + - Prefer a declarative parser shape over imperative step-by-step parsing, as long as it's more readable and maintainable. + - Prefer tuple-based parser composition over calling `parse_next` one parser at a time. + - Prefer built-in combinators and token parsers before adding local helpers. + - Add short documentation comments like `Parse a ..` to all local parser/combinator functions. + - Reuse existing utilities from `utils` module as much as possible, and add new ones there if needed. +- Rust error handling: + - Never call `to_string()` directly on an error value. + - Use `ToReportString` or `AsReport` by `thiserror-ext` instead. + - For `Error` variants that are primarily free-form text, prefer a struct variant with a `message: String` field. `thiserror_ext::Macro` will auto-derive `foo!(...)` and `bail_foo!(...)` helper macros from that shape. + - Use `foo!(...)` when you need to construct an error value inside an expression, such as `Err(foo!(...))`, `.ok_or_else(|| foo!(...))`, or `Err::<(), _>(foo!(...))?`. + - Use `bail_foo!(...)` only in statement positions where you want to exit the current `Result`-returning function immediately. Prefer it over `return Err(foo!(...))` in those cases. + - If a variant has extra structured fields, prefer the generated macro form `foo!(field = value, "message")` rather than manually writing `Error::Foo { ... }`. +- Since the project is still in early stage, it's fine to break API and make non-backwards-compatible changes as needed. +- Currently the project is only targeting Unix-like platforms, so it's fine to use Unix-specific APIs without extra compatibility layers like `cfg(unix)` + +## Testing + +- Prefer snapshot testing with the `expect-test` crate over writing multiple `assert_eq!` statements on individual fields. Use `expect_test::expect![[...]].assert_debug_eq(...)` to snapshot the `Debug` output of the entire struct. + - Write `expect![[""]]` as a placeholder first, then run `UPDATE_EXPECT=1 cargo test` to auto-fill the snapshot content. + - For values containing non-deterministic data (e.g., UUIDs), set them to a fixed value like `""` before snapshotting. +- In tests, avoid hand-writing full request struct literals when only a few fields matter. Prefer test fixtures such as `for_test()` with struct update syntax, so newly added fields do not force mechanical edits across many tests. +- Prefer deterministic synchronization in async and integration tests, such as channels, barriers, explicit handshakes, or observable state transitions, instead of `sleep`-based timing assumptions. + - Use `sleep` only as a last resort when there is no better observable synchronization point. +- Always run test with `cargo nextest run` instead of `cargo test`, if available, as it's much faster. diff --git a/rust/CLAUDE.md b/rust/CLAUDE.md new file mode 100644 index 00000000000..f96fd34544b --- /dev/null +++ b/rust/CLAUDE.md @@ -0,0 +1,4 @@ +# CLAUDE.md + +First, check @AGENTS.override.md if exists. +Then, follow instructions in @AGENTS.md. diff --git a/rust/Cargo.lock b/rust/Cargo.lock new file mode 100644 index 00000000000..16329d1fb3a --- /dev/null +++ b/rust/Cargo.lock @@ -0,0 +1,6611 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "getrandom 0.3.4", + "once_cell", + "serde", + "version_check", + "zerocopy", +] + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "aligned" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee4508988c62edf04abd8d92897fca0c2995d907ce1dfeaf369dac3716a40685" +dependencies = [ + "as-slice", +] + +[[package]] +name = "aligned-vec" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc890384c8602f339876ded803c97ad529f3842aba97f6392b3dba0dd171769b" +dependencies = [ + "equator", +] + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse 0.2.7", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstream" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" +dependencies = [ + "anstyle", + "anstyle-parse 1.0.0", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-parse" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.61.2", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1" + +[[package]] +name = "arc-swap" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6" +dependencies = [ + "rustversion", +] + +[[package]] +name = "arg_enum_proc_macro" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + +[[package]] +name = "as-slice" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "516b6b4f0e40d50dcda9365d53964ec74560ad4284da2e7fc97122cd83174516" +dependencies = [ + "stable_deref_trait", +] + +[[package]] +name = "async-io" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "456b8a8feb6f42d237746d4b3e9a178494627745c3c56c6ea55d92ba50d026fc" +dependencies = [ + "autocfg", + "cfg-if", + "concurrent-queue", + "futures-io", + "futures-lite", + "parking", + "polling", + "rustix", + "slab", + "windows-sys 0.61.2", +] + +[[package]] +name = "async-openai" +version = "0.33.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc48c3deb4ad9a2ee8c8e364c79eb0f74e69e17ed7e883d55988b90ea44fe986" +dependencies = [ + "async-openai-macros", + "backoff", + "base64 0.22.1", + "bytes", + "derive_builder", + "eventsource-stream", + "futures", + "getrandom 0.3.4", + "hex", + "hmac", + "rand 0.9.2", + "reqwest", + "reqwest-eventsource", + "secrecy", + "serde", + "serde_json", + "serde_urlencoded", + "sha2", + "thiserror 2.0.18", + "tokio", + "tokio-stream", + "tokio-tungstenite", + "tokio-util", + "tracing", + "url", +] + +[[package]] +name = "async-openai-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81872a8e595e8ceceab71c6ba1f9078e313b452a1e31934e6763ef5d308705e4" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "asynchronous-codec" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a860072022177f903e59730004fb5dc13db9275b79bb2aef7ba8ce831956c233" +dependencies = [ + "bytes", + "futures-sink", + "futures-util", + "memchr", + "pin-project-lite", +] + +[[package]] +name = "asynk-strim" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52697735bdaac441a29391a9e97102c74c6ef0f9b60a40cf109b1b404e29d2f6" +dependencies = [ + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "asynk-strim-attr" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6ccb67be092524ce594e599332719f1cd6d64dcaed8d46f1e8726d466c10bcb" +dependencies = [ + "asynk-strim", + "asynk-strim-attr-macro", + "futures-core", +] + +[[package]] +name = "asynk-strim-attr-macro" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34e40a2181bb16fb68e25c49c8b3e25bbb9a808bf8f9f83bc596ac4ad70c86a1" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "av-scenechange" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f321d77c20e19b92c39e7471cf986812cbb46659d2af674adc4331ef3f18394" +dependencies = [ + "aligned", + "anyhow", + "arg_enum_proc_macro", + "arrayvec", + "log", + "num-rational", + "num-traits", + "pastey", + "rayon", + "thiserror 2.0.18", + "v_frame", + "y4m", +] + +[[package]] +name = "av1-grain" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cfddb07216410377231960af4fcab838eaa12e013417781b78bd95ee22077f8" +dependencies = [ + "anyhow", + "arrayvec", + "log", + "nom 8.0.0", + "num-rational", + "v_frame", +] + +[[package]] +name = "avif-serialize" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "375082f007bd67184fb9c0374614b29f9aaa604ec301635f72338bb65386a53d" +dependencies = [ + "arrayvec", +] + +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "futures-core", + "getrandom 0.2.17", + "instant", + "pin-project-lite", + "rand 0.8.5", + "tokio", +] + +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "base64ct" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" + +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec 0.6.3", +] + +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec 0.8.0", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + +[[package]] +name = "bit_field" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e4b40c7323adcfc0a41c4b88143ed58346ff65a288fc144329c5c45e05d70c6" + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "bitstream-io" +version = "4.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7eff00be299a18769011411c9def0d827e8f2d7bf0c3dbf53633147a8867fd1f" +dependencies = [ + "no_std_io2", +] + +[[package]] +name = "blake3" +version = "1.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0aa83c34e62843d924f905e0f5c866eb1dd6545fc4d719e803d9ba6030371fce" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", + "cpufeatures 0.3.0", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bstr" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63044e1ae8e69f3b5a92c736ca6269b8d12fa7efe39bf34ddb06d102cf0e2cab" +dependencies = [ + "memchr", + "regex-automata", + "serde", +] + +[[package]] +name = "built" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4ad8f11f288f48ca24471bbd51ac257aaeaaa07adae295591266b792902ae64" + +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "byteorder-lite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +dependencies = [ + "serde", +] + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "castaway" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" +dependencies = [ + "rustversion", +] + +[[package]] +name = "cc" +version = "1.2.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +dependencies = [ + "find-msvc-tools", + "jobserver", + "libc", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + +[[package]] +name = "chrono" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link", +] + +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clap" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +dependencies = [ + "anstream 0.6.21", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "clap_lex" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" + +[[package]] +name = "color_quant" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" + +[[package]] +name = "colorchoice" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + +[[package]] +name = "compact_str" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" +dependencies = [ + "castaway", + "cfg-if", + "itoa", + "rustversion", + "ryu", + "serde", + "static_assertions", +] + +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", +] + +[[package]] +name = "console" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03e45a4a8926227e4197636ba97a9fc9b00477e9f4bd711395687c5f0734bec4" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.61.2", +] + +[[package]] +name = "constant_time_eq" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" + +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + +[[package]] +name = "cookie_store" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15b2c103cf610ec6cae3da84a766285b42fd16aad564758459e6ecf128c75206" +dependencies = [ + "cookie", + "document-features", + "idna", + "indexmap 2.13.0", + "log", + "serde", + "serde_derive", + "serde_json", + "time", + "url", +] + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "cpufeatures" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "daachorse" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63b7ef7a4be509357f4804d0a22e830daddb48f19fd604e4ad32ddce04a94c36" + +[[package]] +name = "darling" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +dependencies = [ + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + +[[package]] +name = "darling" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" +dependencies = [ + "darling_core 0.23.0", + "darling_macro 0.23.0", +] + +[[package]] +name = "darling_core" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.117", +] + +[[package]] +name = "darling_core" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" +dependencies = [ + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.117", +] + +[[package]] +name = "darling_macro" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +dependencies = [ + "darling_core 0.20.11", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "darling_macro" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" +dependencies = [ + "darling_core 0.23.0", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "dary_heap" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06d2e3287df1c007e74221c49ca10a95d557349e54b3a75dc2fb14712c751f04" +dependencies = [ + "serde", +] + +[[package]] +name = "der" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71fd89660b2dc699704064e59e9dba0147b903e85319429e131620d022be411b" +dependencies = [ + "pem-rfc7468", + "zeroize", +] + +[[package]] +name = "deranged" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c" +dependencies = [ + "powerfmt", + "serde_core", +] + +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling 0.20.11", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn 2.0.117", +] + +[[package]] +name = "derive_more" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", + "unicode-xid", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + +[[package]] +name = "dirs" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.61.2", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "dissimilar" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aeda16ab4059c5fd2a83f2b9c9e9c981327b18aa8e3b313f7e6563799d4f093e" + +[[package]] +name = "document-features" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4b8a88685455ed29a21542a33abd9cb6510b6b129abadabdcef0f4c55bc8f61" +dependencies = [ + "litrs", +] + +[[package]] +name = "dtoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c3cf4824e2d5f025c7b531afcb2325364084a16806f6d47fbc1f5fbd9960590" + +[[package]] +name = "dyn-clone" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" + +[[package]] +name = "easy-ext" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8072bec12b909b65aec01fa6518f387cfbf3427d4475409ad622898cd347522c" + +[[package]] +name = "educe" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "enum-as-inner" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0359ee92f81184d7985519e474bda2a5738476334edd3746c9b1265c067afe70" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "enum-ordinalize" +version = "4.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a1091a7bb1f8f2c4b28f1fe2cef4980ca2d410a3d727d67ecc3178c9b0800f0" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ca9601fb2d62598ee17836250842873a413586e5d7ed88b356e38ddbb0ec631" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "env_filter" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a" +dependencies = [ + "anstream 1.0.0", + "anstyle", + "env_filter", + "jiff", + "log", +] + +[[package]] +name = "equator" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4711b213838dfee0117e3be6ac926007d7f433d7bbe33595975d4190cb07e6fc" +dependencies = [ + "equator-macro", +] + +[[package]] +name = "equator-macro" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" +dependencies = [ + "cc", +] + +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom 7.1.3", + "pin-project-lite", +] + +[[package]] +name = "expect-test" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63af43ff4431e848fb47472a920f14fa71c24de13255a5692e93d4e90302acb0" +dependencies = [ + "dissimilar", + "once_cell", +] + +[[package]] +name = "exr" +version = "1.74.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4300e043a56aa2cb633c01af81ca8f699a321879a7854d3896a0ba89056363be" +dependencies = [ + "bit_field", + "half", + "lebe", + "miniz_oxide", + "rayon-core", + "smallvec", + "zune-inflate", +] + +[[package]] +name = "fancy-regex" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" +dependencies = [ + "bit-set 0.5.3", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "fancy-regex" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72cf461f865c862bb7dc573f643dd6a2b6842f7c30b07882b56bd148cc2761b8" +dependencies = [ + "bit-set 0.8.0", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "fast_image_resize" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12dd43e5011e8d8411a3215a0d57a2ec5c68282fb90eb5d7221fab0113442174" +dependencies = [ + "bytemuck", + "cfg-if", + "document-features", + "image", + "num-traits", + "thiserror 2.0.18", +] + +[[package]] +name = "fastokens" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "796a262ed47d1458a4b40d0ed831c927e6f54d5b9c1de2683bb4ac9b04f4c7cc" +dependencies = [ + "daachorse", + "fancy-regex 0.17.0", + "hf-hub 0.4.3", + "icu_normalizer", + "memchr", + "pcre2", + "rayon", + "serde", + "serde_json", + "strum", + "thiserror 2.0.18", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "fax" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f05de7d48f37cd6730705cbca900770cab77a89f413d23e100ad7fad7795a0ab" +dependencies = [ + "fax_derive", +] + +[[package]] +name = "fax_derive" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0aca10fb742cb43f9e7bb8467c91aa9bcb8e3ffbc6a6f7389bb93ffc920577d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "fdeflate" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c" +dependencies = [ + "simd-adler32", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + +[[package]] +name = "flate2" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "fslock" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04412b8935272e3a9bae6f48c7bfff74c2911f60525404edfdd28e49884c3bfb" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "futures" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-executor" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" + +[[package]] +name = "futures-lite" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad" +dependencies = [ + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "futures-sink" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "slab", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getopts" +version = "0.2.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe4fbac503b8d1f88e6676011885f34b7174f46e59956bba534ba83abded4df" +dependencies = [ + "unicode-width", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "r-efi 5.3.0", + "wasip2", + "wasm-bindgen", +] + +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + +[[package]] +name = "gif" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee8cfcc411d9adbbaba82fb72661cc1bcca13e8bba98b364e62b2dba8f960159" +dependencies = [ + "color_quant", + "weezl", +] + +[[package]] +name = "h2" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap 2.13.0", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "bytemuck", + "cfg-if", + "crunchy", + "zerocopy", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hf-hub" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97" +dependencies = [ + "dirs", + "http", + "indicatif 0.17.11", + "libc", + "log", + "rand 0.9.2", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.18", + "ureq 2.12.1", + "windows-sys 0.60.2", +] + +[[package]] +name = "hf-hub" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aef3982638978efa195ff11b305f51f1f22f4f0a6cabee7af79b383ebee6a213" +dependencies = [ + "dirs", + "futures", + "http", + "indicatif 0.18.4", + "libc", + "log", + "native-tls", + "num_cpus", + "rand 0.9.2", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "ureq 3.3.0", + "windows-sys 0.61.2", +] + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + +[[package]] +name = "hound" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" + +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-native-certs", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", + "webpki-roots 1.0.6", +] + +[[package]] +name = "hyper-timeout" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" +dependencies = [ + "hyper", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2", + "system-configuration", + "tokio", + "tower-service", + "tracing", + "windows-registry", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "icu_collections" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" + +[[package]] +name = "icu_properties" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" + +[[package]] +name = "icu_provider" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "image" +version = "0.25.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85ab80394333c02fe689eaf900ab500fbd0c2213da414687ebf995a65d5a6104" +dependencies = [ + "bytemuck", + "byteorder-lite", + "color_quant", + "exr", + "gif", + "image-webp", + "moxcms", + "num-traits", + "png", + "qoi", + "ravif", + "rayon", + "rgb", + "tiff", + "zune-core", + "zune-jpeg", +] + +[[package]] +name = "image-webp" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3" +dependencies = [ + "byteorder-lite", + "quick-error", +] + +[[package]] +name = "imgref" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c5cedc30da3a610cac6b4ba17597bdf7152cf974e8aab3afb3d54455e371c8" + +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", + "serde", +] + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "indicatif" +version = "0.17.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +dependencies = [ + "console 0.15.11", + "number_prefix", + "portable-atomic", + "unicode-width", + "web-time", +] + +[[package]] +name = "indicatif" +version = "0.18.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25470f23803092da7d239834776d653104d551bc4d7eacaf31e6837854b8e9eb" +dependencies = [ + "console 0.16.2", + "portable-atomic", + "unicode-width", + "unit-prefix", + "web-time", +] + +[[package]] +name = "instant" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "interpolate_name" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "ipnet" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" + +[[package]] +name = "iri-string" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +dependencies = [ + "memchr", + "serde", +] + +[[package]] +name = "is-macro" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d57a3e447e24c22647738e4607f1df1e0ec6f72e16182c4cd199f647cdfb0e4" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "is-terminal" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "jiff" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a3546dc96b6d42c5f24902af9e2538e82e39ad350b0c766eb3fbf2d8f3d8359" +dependencies = [ + "jiff-static", + "log", + "portable-atomic", + "portable-atomic-util", + "serde_core", +] + +[[package]] +name = "jiff-static" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "lalrpop-util" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507460a910eb7b32ee961886ff48539633b788a36b65692b95f225b844c82553" + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "lebe" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a79a3332a6609480d7d0c9eab957bca6b455b91bb84e66d19f5ff66294b85b8" + +[[package]] +name = "libc" +version = "0.2.183" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" + +[[package]] +name = "libfuzzer-sys" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f12a681b7dd8ce12bff52488013ba614b869148d54dd79836ab85aafdd53f08d" +dependencies = [ + "arbitrary", + "cc", +] + +[[package]] +name = "libm" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" + +[[package]] +name = "libredox" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1744e39d1d6a9948f4f388969627434e31128196de472883b39f148769bfe30a" +dependencies = [ + "libc", +] + +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + +[[package]] +name = "litemap" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" + +[[package]] +name = "litrs" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" + +[[package]] +name = "llm-multimodal" +version = "1.5.0" +source = "git+https://github.com/vllm-project/llm-multimodal?rev=5b558989844d1c7af3e43d0f604069ffd9c06320#5b558989844d1c7af3e43d0f604069ffd9c06320" +dependencies = [ + "base64 0.22.1", + "blake3", + "bytes", + "fast_image_resize", + "image", + "ndarray 0.17.2", + "once_cell", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "url", +] + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "loop9" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fae87c125b03c1d2c0150c90365d7d6bcc53fb73a9acaef207d2d065860f062" +dependencies = [ + "imgref", +] + +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + +[[package]] +name = "macro_rules_attribute" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65049d7923698040cd0b1ddcced9b0eb14dd22c5f86ae59c3740eab64a676520" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" + +[[package]] +name = "malachite" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fbdf9cb251732db30a7200ebb6ae5d22fe8e11397364416617d2c2cf0c51cb5" +dependencies = [ + "malachite-base", + "malachite-nz", + "malachite-q", +] + +[[package]] +name = "malachite-base" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ea0ed76adf7defc1a92240b5c36d5368cfe9251640dcce5bd2d0b7c1fd87aeb" +dependencies = [ + "hashbrown 0.14.5", + "itertools 0.11.0", + "libm", + "ryu", +] + +[[package]] +name = "malachite-bigint" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d149aaa2965d70381709d9df4c7ee1fc0de1c614a4efc2ee356f5e43d68749f8" +dependencies = [ + "derive_more", + "malachite", + "num-integer", + "num-traits", + "paste", +] + +[[package]] +name = "malachite-nz" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34a79feebb2bc9aa7762047c8e5495269a367da6b5a90a99882a0aeeac1841f7" +dependencies = [ + "itertools 0.11.0", + "libm", + "malachite-base", +] + +[[package]] +name = "malachite-q" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f235d5747b1256b47620f5640c2a17a88c7569eebdf27cd9cb130e1a619191" +dependencies = [ + "itertools 0.11.0", + "malachite-base", + "malachite-nz", +] + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "maybe-rayon" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ea1f30cedd69f0a2954655f7188c6a834246d2bcf1e315e2ac40c4b24dc9519" +dependencies = [ + "cfg-if", + "rayon", +] + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "memo-map" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + +[[package]] +name = "minijinja" +version = "2.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "328251e58ad8e415be6198888fc207502727dc77945806421ab34f35bf012e7d" +dependencies = [ + "memo-map", + "serde", + "serde_json", +] + +[[package]] +name = "minijinja-contrib" +version = "2.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c6302e47d2b51f9fc978268ff7f5a014de5caa2ad48440309fd10ee711480d7" +dependencies = [ + "minijinja", + "serde", +] + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + +[[package]] +name = "mio" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "monostate" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3341a273f6c9d5bef1908f17b7267bbab0e95c9bf69a0d4dcf8e9e1b2c76ef67" +dependencies = [ + "monostate-impl", + "serde", + "serde_core", +] + +[[package]] +name = "monostate-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4db6d5580af57bf992f59068d4ea26fd518574ff48d7639b255a36f9de6e7e9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "moxcms" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb85c154ba489f01b25c0d36ae69a87e4a1c73a72631fc6c0eb6dde34a73e44b" +dependencies = [ + "num-traits", + "pxfm", +] + +[[package]] +name = "multimap" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" + +[[package]] +name = "native-tls" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "ndarray" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "new_debug_unreachable" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" + +[[package]] +name = "no_std_io2" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b51ed7824b6e07d354605f4abb3d9d300350701299da96642ee084f5ce631550" +dependencies = [ + "memchr", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] +name = "nom" +version = "8.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" +dependencies = [ + "memchr", +] + +[[package]] +name = "noop_proc_macro" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" + +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-conv" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" + +[[package]] +name = "num-derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "num_threads" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c7398b9c8b70908f6371f47ed36737907c87c52af34c268fed0bf0ceb92ead9" +dependencies = [ + "libc", +] + +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + +[[package]] +name = "onig" +version = "6.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "336b9c63443aceef14bea841b899035ae3abe89b7c486aaf4c5bd8aafedac3f0" +dependencies = [ + "bitflags", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f86c6eef3d6df15f23bcfb6af487cbd2fed4e5581d58d5bf1f5f8b7f6727dc" +dependencies = [ + "cc", + "pkg-config", +] + +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + +[[package]] +name = "openai-harmony" +version = "0.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e77e82af451fc95deeb728a40b84db8ee82d341e136c268de415123a560b9b72" +dependencies = [ + "anyhow", + "base64 0.22.1", + "bstr", + "clap", + "fancy-regex 0.13.0", + "futures", + "image", + "regex", + "reqwest", + "rustc-hash 1.1.0", + "serde", + "serde_json", + "serde_with", + "sha1", + "sha2", + "thiserror 2.0.18", +] + +[[package]] +name = "openai-protocol" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b8d41ed865b7a26b6b2d2a519b774460e5ddc50eeeb4ac2a2409c8817ec2de9" +dependencies = [ + "bitflags", + "chrono", + "rand 0.9.2", + "schemars 0.8.22", + "serde", + "serde_json", + "serde_with", + "tokio", + "tracing", + "url", + "validator", +] + +[[package]] +name = "openssl" +version = "0.10.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + +[[package]] +name = "openssl-src" +version = "300.5.5+3.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f1787d533e03597a7934fd0a765f0d28e94ecc5fb7789f8053b1e699a56f709" +dependencies = [ + "cc", +] + +[[package]] +name = "openssl-sys" +version = "0.9.112" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb" +dependencies = [ + "cc", + "libc", + "openssl-src", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "pastey" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec" + +[[package]] +name = "pcre2" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e970b0fcce0c7ee6ef662744ff711f21ccd6f11b7cf03cd187a80e89797fc67" +dependencies = [ + "libc", + "log", + "pcre2-sys", +] + +[[package]] +name = "pcre2-sys" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18b9073c1a2549bd409bf4a32c94d903bb1a09bf845bc306ae148897fa0760a4" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + +[[package]] +name = "pem-rfc7468" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6305423e0e7738146434843d1694d621cce767262b2a86910beab705e4493d9" +dependencies = [ + "base64ct", +] + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "petgraph" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" +dependencies = [ + "fixedbitset", + "hashbrown 0.15.5", + "indexmap 2.13.0", +] + +[[package]] +name = "phf" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" +dependencies = [ + "phf_shared", + "rand 0.8.5", +] + +[[package]] +name = "phf_shared" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +dependencies = [ + "siphasher", +] + +[[package]] +name = "pin-project" +version = "1.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "png" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60769b8b31b2a9f263dae2776c37b1b28ae246943cf719eb6946a1db05128a61" +dependencies = [ + "bitflags", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide", +] + +[[package]] +name = "polling" +version = "3.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0e4f59085d47d8241c88ead0f274e8a0cb551f3625263c05eb8dd897c34218" +dependencies = [ + "cfg-if", + "concurrent-queue", + "hermit-abi", + "pin-project-lite", + "rustix", + "windows-sys 0.61.2", +] + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "portable-atomic-util" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "potential_utf" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +dependencies = [ + "zerovec", +] + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn 2.0.117", +] + +[[package]] +name = "primal-check" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0d895b311e3af9902528fbb8f928688abbd95872819320517cc24ca6b2bd08" +dependencies = [ + "num-integer", +] + +[[package]] +name = "proc-macro-crate" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" +dependencies = [ + "toml_edit", +] + +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "profiling" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3eb8486b569e12e2c32ad3e204dbaba5e4b5b216e9367044f25f1dba42341773" +dependencies = [ + "profiling-procmacros", +] + +[[package]] +name = "profiling-procmacros" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52717f9a02b6965224f95ca2a81e2e0c5c43baacd28ca057577988930b6c3d5b" +dependencies = [ + "quote", + "syn 2.0.117", +] + +[[package]] +name = "prometheus-client" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4500adecd7af8e0e9f4dbce15cfee07ce913fbf6ad605cc468b83f2d531ee94" +dependencies = [ + "dtoa", + "itoa", + "parking_lot", + "prometheus-client-derive-encode", +] + +[[package]] +name = "prometheus-client-derive-encode" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9adf1691c04c0a5ff46ff8f262b58beb07b0dbb61f96f9f54f6cbd82106ed87f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "prost" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-build" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" +dependencies = [ + "heck", + "itertools 0.14.0", + "log", + "multimap", + "petgraph", + "prettyplease", + "prost", + "prost-types", + "pulldown-cmark", + "pulldown-cmark-to-cmark", + "regex", + "syn 2.0.117", + "tempfile", +] + +[[package]] +name = "prost-derive" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" +dependencies = [ + "anyhow", + "itertools 0.14.0", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "prost-types" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" +dependencies = [ + "prost", +] + +[[package]] +name = "pulldown-cmark" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c3a14896dfa883796f1cb410461aef38810ea05f2b2c33c5aded3649095fdad" +dependencies = [ + "bitflags", + "memchr", + "unicase", +] + +[[package]] +name = "pulldown-cmark-to-cmark" +version = "22.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50793def1b900256624a709439404384204a5dc3a6ec580281bfaac35e882e90" +dependencies = [ + "pulldown-cmark", +] + +[[package]] +name = "pxfm" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0c5ccf5294c6ccd63a74f1565028353830a9c2f5eb0c682c355c471726a6e3f" + +[[package]] +name = "qoi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f6d64c71eb498fe9eae14ce4ec935c555749aef511cca85b5568910d6e48001" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "quick-error" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" + +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash 2.1.1", + "rustls", + "socket2", + "thiserror 2.0.18", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" +dependencies = [ + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash 2.1.1", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.18", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.60.2", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.17", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + +[[package]] +name = "rav1e" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43b6dd56e85d9483277cde964fd1bdb0428de4fec5ebba7540995639a21cb32b" +dependencies = [ + "aligned-vec", + "arbitrary", + "arg_enum_proc_macro", + "arrayvec", + "av-scenechange", + "av1-grain", + "bitstream-io", + "built", + "cfg-if", + "interpolate_name", + "itertools 0.14.0", + "libc", + "libfuzzer-sys", + "log", + "maybe-rayon", + "new_debug_unreachable", + "noop_proc_macro", + "num-derive", + "num-traits", + "paste", + "profiling", + "rand 0.9.2", + "rand_chacha 0.9.0", + "simd_helpers", + "thiserror 2.0.18", + "v_frame", + "wasm-bindgen", +] + +[[package]] +name = "ravif" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e52310197d971b0f5be7fe6b57530dcd27beb35c1b013f29d66c1ad73fbbcc45" +dependencies = [ + "avif-serialize", + "imgref", + "loop9", + "quick-error", + "rav1e", + "rayon", + "rgb", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-cond" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2964d0cf57a3e7a06e8183d14a8b527195c706b7983549cd5462d5aa3747438f" +dependencies = [ + "either", + "itertools 0.14.0", + "rayon", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "realfft" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f821338fddb99d089116342c46e9f1fbf3828dba077674613e734e01d6ea8677" +dependencies = [ + "rustfft", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "redox_users" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac" +dependencies = [ + "getrandom 0.2.17", + "libredox", + "thiserror 2.0.18", +] + +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "reqwest" +version = "0.12.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +dependencies = [ + "base64 0.22.1", + "bytes", + "encoding_rs", + "futures-channel", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-tls", + "hyper-util", + "js-sys", + "log", + "mime", + "mime_guess", + "native-tls", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls", + "rustls-native-certs", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-native-tls", + "tokio-rustls", + "tokio-util", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", + "webpki-roots 1.0.6", +] + +[[package]] +name = "reqwest-eventsource" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime", + "nom 7.1.3", + "pin-project-lite", + "reqwest", + "thiserror 1.0.69", +] + +[[package]] +name = "rgb" +version = "0.8.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b34b781b31e5d73e9fbc8689c70551fd1ade9a19e3e28cfec8580a79290cc4" + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "riptoken" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "196b37c4dd48f99b51e8aeaa3d0c343df62c7bb5b66cd6735aa072524fbe8665" +dependencies = [ + "fancy-regex 0.17.0", + "rayon", + "regex", + "regex-automata", + "rustc-hash 2.1.1", +] + +[[package]] +name = "rmp" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ba8be72d372b2c9b35542551678538b562e7cf86c3315773cae48dfbfe7790c" +dependencies = [ + "num-traits", +] + +[[package]] +name = "rmp-serde" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72f81bee8c8ef9b577d1681a70ebbc962c232461e397b22c208c43c04b67a155" +dependencies = [ + "rmp", + "serde", +] + +[[package]] +name = "rmpv" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a4e1d4b9b938a26d2996af33229f0ca0956c652c1375067f0b45291c1df8417" +dependencies = [ + "rmp", + "serde", + "serde_bytes", +] + +[[package]] +name = "rubato" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5258099699851cfd0082aeb645feb9c084d9a5e1f1b8d5372086b989fc5e56a1" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "realfft", +] + +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + +[[package]] +name = "rustfft" +version = "6.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21db5f9893e91f41798c88680037dba611ca6674703c1a18601b01a72c8adb89" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "primal-check", + "strength_reduce", + "transpose", +] + +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" +dependencies = [ + "log", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +dependencies = [ + "web-time", + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "rustpython-ast" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cdaf8ee5c1473b993b398c174641d3aa9da847af36e8d5eb8291930b72f31a5" +dependencies = [ + "is-macro", + "malachite-bigint", + "rustpython-parser-core", + "static_assertions", +] + +[[package]] +name = "rustpython-parser" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "868f724daac0caf9bd36d38caf45819905193a901e8f1c983345a68e18fb2abb" +dependencies = [ + "anyhow", + "is-macro", + "itertools 0.11.0", + "lalrpop-util", + "log", + "malachite-bigint", + "num-traits", + "phf", + "phf_codegen", + "rustc-hash 1.1.0", + "rustpython-ast", + "rustpython-parser-core", + "tiny-keccak", + "unic-emoji-char", + "unic-ucd-ident", + "unicode_names2", +] + +[[package]] +name = "rustpython-parser-core" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4b6c12fa273825edc7bccd9a734f0ad5ba4b8a2f4da5ff7efe946f066d0f4ad" +dependencies = [ + "is-macro", + "memchr", + "rustpython-parser-vendored", +] + +[[package]] +name = "rustpython-parser-vendored" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04fcea49a4630a3a5d940f4d514dc4f575ed63c14c3e3ed07146634aed7f67a6" +dependencies = [ + "memchr", + "once_cell", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "saa" +version = "5.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16c7f49c9d5caa3bf4b3106900484b447b9253fe99670ceb81cb6cb5027855e1" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "scc" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46e6f046b7fef48e2660c57ed794263155d713de679057f2d0c169bfc6e756cc" +dependencies = [ + "sdd 3.0.10", +] + +[[package]] +name = "scc" +version = "3.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45bb5ce9efd4a6e7b0f86c2697fe4c1d78d1f4e6d988c54b752d577cafe22fe8" +dependencies = [ + "saa", + "sdd 4.7.3", +] + +[[package]] +name = "schannel" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "schemars" +version = "0.8.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" +dependencies = [ + "dyn-clone", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd191f9397d57d581cddd31014772520aa448f65ef991055d7f61582c65165f" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + +[[package]] +name = "schemars" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "0.8.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn 2.0.117", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "sdd" +version = "3.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca" + +[[package]] +name = "sdd" +version = "4.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b21a75f5913ab130e4b369fb8693be25f29b983e2ecad4279df9bfa5dd8aaf3e" + +[[package]] +name = "secrecy" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a" +dependencies = [ + "serde", + "zeroize", +] + +[[package]] +name = "security-framework" +version = "3.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" +dependencies = [ + "bitflags", + "core-foundation 0.10.1", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde-json-fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a33b7a5f52a26d520099339add40c48baf2e5ada194c8cc1b18cafa2b5e419" +dependencies = [ + "serde", + "serde_json", + "smartstring", +] + +[[package]] +name = "serde_bytes" +version = "0.11.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d440709e79d88e51ac01c4b72fc6cb7314017bb7da9eeff678aa94c10e3ea8" +dependencies = [ + "serde", + "serde_core", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_default" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "486b028b311aaaea83e0ba65a3e6e3cbef381e74e9d0bd6263faefd1fb503c1d" +dependencies = [ + "darling 0.20.11", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "indexmap 2.13.0", + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + +[[package]] +name = "serde_repr" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "serde_tuple" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6af196b9c06f0aa5555ab980c01a2527b0f67517da8d68b1731b9d4764846a6f" +dependencies = [ + "serde", + "serde_tuple_macros", +] + +[[package]] +name = "serde_tuple_macros" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3a1e7d2eadec84deabd46ae061bf480a91a6bce74d25dad375bd656f2e19d8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "serde_with" +version = "3.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd5414fad8e6907dbdd5bc441a50ae8d6e26151a03b1de04d89a5576de61d01f" +dependencies = [ + "base64 0.22.1", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.13.0", + "schemars 0.9.0", + "schemars 1.2.1", + "serde_core", + "serde_json", + "serde_with_macros", + "time", +] + +[[package]] +name = "serde_with_macros" +version = "3.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3db8978e608f1fe7357e211969fd9abdcae80bac1ba7a3369bb7eb6b404eb65" +dependencies = [ + "darling 0.23.0", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "serial_test" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "911bd979bf1070a3f3aa7b691a3b3e9968f339ceeec89e08c280a8a22207a32f" +dependencies = [ + "fslock", + "futures-executor", + "futures-util", + "log", + "once_cell", + "parking_lot", + "scc 2.4.0", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a7d91949b85b0d2fb687445e448b40d322b6b3e4af6b44a29b21d9a5f33e6d9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures 0.2.17", + "digest", +] + +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures 0.2.17", + "digest", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + +[[package]] +name = "simd_helpers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95890f873bec569a0362c235787f3aca6e1e887302ba4840839bcc6459c42da6" +dependencies = [ + "quote", +] + +[[package]] +name = "siphasher" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "smartstring" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb72c633efbaa2dd666986505016c32c3044395ceaf881518399d2f4127ee29" +dependencies = [ + "autocfg", + "static_assertions", + "version_check", +] + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom 7.1.3", + "serde", + "unicode-segmentation", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "strum" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "subenum" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3d08fe7078c57309d5c3d938e50eba95ba1d33b9c3a101a8465fc6861a5416" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "system-configuration" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" +dependencies = [ + "bitflags", + "core-foundation 0.9.4", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "task-local" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2972044a9e5e448a506a7ff6f0d03b566d8ef4cd6918a58fc59835a0f8666626" +dependencies = [ + "pin-project-lite", +] + +[[package]] +name = "tekken-rs" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49623843103837f53f7ebe8cfafc19ccff28ff0e15e7c4b9f6ad21e36fbfde3a" +dependencies = [ + "anyhow", + "base64 0.22.1", + "env_logger", + "hound", + "log", + "ndarray 0.16.1", + "regex", + "rubato", + "rustc-hash 1.1.0", + "rustfft", + "serde", + "serde_json", + "thiserror 2.0.18", + "tiktoken-rs 0.7.0", +] + +[[package]] +name = "tempfile" +version = "3.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +dependencies = [ + "fastrand", + "getrandom 0.4.2", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fb7e61141f4141832ca9aad63c3c90023843f944a1975460abdacc64d03f534" +dependencies = [ + "thiserror 2.0.18", + "thiserror-ext-derive", +] + +[[package]] +name = "thiserror-ext-derive" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b5042dd3b562d1d57711be902006a0003fa2781b81d5b2bec07416be31586ff" +dependencies = [ + "either", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "tiff" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b63feaf3343d35b6ca4d50483f94843803b0f51634937cc2ec519fc32232bc52" +dependencies = [ + "fax", + "flate2", + "half", + "quick-error", + "weezl", + "zune-jpeg", +] + +[[package]] +name = "tiktoken-rs" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25563eeba904d770acf527e8b370fe9a5547bacd20ff84a0b6c3bc41288e5625" +dependencies = [ + "anyhow", + "base64 0.22.1", + "bstr", + "fancy-regex 0.13.0", + "lazy_static", + "regex", + "rustc-hash 1.1.0", +] + +[[package]] +name = "tiktoken-rs" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a19830747d9034cd9da43a60eaa8e552dfda7712424aebf187b7a60126bae0d" +dependencies = [ + "anyhow", + "base64 0.22.1", + "bstr", + "fancy-regex 0.13.0", + "lazy_static", + "regex", + "rustc-hash 1.1.0", +] + +[[package]] +name = "time" +version = "0.3.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" +dependencies = [ + "deranged", + "itoa", + "libc", + "num-conv", + "num_threads", + "powerfmt", + "serde_core", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" + +[[package]] +name = "time-macros" +version = "0.2.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" +dependencies = [ + "num-conv", + "time-core", +] + +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + +[[package]] +name = "tinystr" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "tinyvec" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "tokenizers" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b238e22d44a15349529690fb07bd645cf58149a1b1e44d6cb5bd1641ff1a6223" +dependencies = [ + "ahash", + "aho-corasick", + "compact_str", + "dary_heap", + "derive_builder", + "esaxx-rs", + "getrandom 0.3.4", + "indicatif 0.18.4", + "itertools 0.14.0", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand 0.9.2", + "rayon", + "rayon-cond", + "regex", + "regex-syntax", + "serde", + "serde_json", + "spm_precompiled", + "thiserror 2.0.18", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + +[[package]] +name = "tokio" +version = "1.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +dependencies = [ + "bytes", + "libc", + "mio", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tokio-stream" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-io", + "futures-sink", + "futures-util", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "toml_datetime" +version = "1.1.1+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_edit" +version = "0.25.11+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b59c4d22ed448339746c59b905d24568fcbb3ab65a500494f7b8c3e97739f2b" +dependencies = [ + "indexmap 2.13.0", + "toml_datetime", + "toml_parser", + "winnow", +] + +[[package]] +name = "toml_parser" +version = "1.1.2+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" +dependencies = [ + "winnow", +] + +[[package]] +name = "tonic" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec" +dependencies = [ + "async-trait", + "axum", + "base64 0.22.1", + "bytes", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-timeout", + "hyper-util", + "percent-encoding", + "pin-project", + "socket2", + "sync_wrapper", + "tokio", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic-build" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1882ac3bf5ef12877d7ed57aad87e75154c11931c2ba7e6cde5e22d63522c734" +dependencies = [ + "prettyplease", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "tonic-prost" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309" +dependencies = [ + "bytes", + "prost", + "tonic", +] + +[[package]] +name = "tonic-prost-build" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3144df636917574672e93d0f56d7edec49f90305749c668df5101751bb8f95a" +dependencies = [ + "prettyplease", + "proc-macro2", + "prost-build", + "prost-types", + "quote", + "syn 2.0.117", + "tempfile", + "tonic-build", +] + +[[package]] +name = "tool-parser" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bdcf0aa96d42cfc2ecc8e7b3c10598b9c1a6052f996b5ab574dec72f483d87c" +dependencies = [ + "async-trait", + "num-traits", + "openai-protocol", + "parking_lot", + "regex", + "rustpython-parser", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tracing", +] + +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "indexmap 2.13.0", + "pin-project-lite", + "slab", + "sync_wrapper", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "iri-string", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "log", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-futures" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2" +dependencies = [ + "futures", + "futures-task", + "pin-project", + "tracing", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "trait-set" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b79e2e9c9ab44c6d7c20d5976961b47e8f49ac199154daa514b77cd1ab536625" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442" +dependencies = [ + "bytes", + "log", + "rand 0.9.2", + "thiserror 2.0.18", + "utf-8", +] + +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + +[[package]] +name = "unic-char-property" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8c57a407d9b6fa02b4795eb81c5b6652060a15a7903ea981f3d723e6c0be221" +dependencies = [ + "unic-char-range", +] + +[[package]] +name = "unic-char-range" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0398022d5f700414f6b899e10b8348231abf9173fa93144cbc1a43b9793c1fbc" + +[[package]] +name = "unic-common" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80d7ff825a6a654ee85a63e80f92f054f904f21e7d12da4e22f9834a4aaa35bc" + +[[package]] +name = "unic-emoji-char" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b07221e68897210270a38bde4babb655869637af0f69407f96053a34f76494d" +dependencies = [ + "unic-char-property", + "unic-char-range", + "unic-ucd-version", +] + +[[package]] +name = "unic-ucd-ident" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e230a37c0381caa9219d67cf063aa3a375ffed5bf541a452db16e744bdab6987" +dependencies = [ + "unic-char-property", + "unic-char-range", + "unic-ucd-version", +] + +[[package]] +name = "unic-ucd-version" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96bd2f2237fe450fcd0a1d2f5f4e91711124f7857ba2e964247776ebeeb7b0c4" +dependencies = [ + "unic-common", +] + +[[package]] +name = "unicase" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec", +] + +[[package]] +name = "unicode-segmentation" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da36089a805484bcccfffe0739803392c8298778a2d2f09febf76fac5ad9025b" + +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + +[[package]] +name = "unicode_names2" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1673eca9782c84de5f81b82e4109dcfb3611c8ba0d52930ec4a9478f547b2dd" +dependencies = [ + "phf", + "unicode_names2_generator", +] + +[[package]] +name = "unicode_names2_generator" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b91e5b84611016120197efd7dc93ef76774f4e084cd73c9fb3ea4a86c570c56e" +dependencies = [ + "getopts", + "log", + "phf_codegen", + "rand 0.8.5", +] + +[[package]] +name = "unit-prefix" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81e544489bf3d8ef66c953931f56617f423cd4b5494be343d9b9d3dda037b9a3" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "ureq" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" +dependencies = [ + "base64 0.22.1", + "flate2", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "socks", + "url", + "webpki-roots 0.26.11", +] + +[[package]] +name = "ureq" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea7109cdcd5864d4eeb1b58a1648dc9bf520360d7af16ec26d0a9354bafcfc0" +dependencies = [ + "base64 0.22.1", + "cookie_store", + "der", + "flate2", + "log", + "native-tls", + "percent-encoding", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "socks", + "ureq-proto", + "utf8-zero", + "webpki-root-certs", + "webpki-roots 1.0.6", +] + +[[package]] +name = "ureq-proto" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e994ba84b0bd1b1b0cf92878b7ef898a5c1760108fe7b6010327e274917a808c" +dependencies = [ + "base64 0.22.1", + "http", + "httparse", + "log", +] + +[[package]] +name = "url" +version = "2.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8-zero" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8c0a043c9540bae7c578c88f91dda8bd82e59ae27c21baca69c8b191aaf5a6e" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "uuid" +version = "1.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a68d3c8f01c0cfa54a75291d83601161799e4a89a39e0929f4b0354d88757a37" +dependencies = [ + "getrandom 0.4.2", + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "v_frame" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "666b7727c8875d6ab5db9533418d7c764233ac9c0cff1d469aec8fa127597be2" +dependencies = [ + "aligned-vec", + "num-traits", + "wasm-bindgen", +] + +[[package]] +name = "validator" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43fb22e1a008ece370ce08a3e9e4447a910e92621bb49b85d6e48a45397e7cfa" +dependencies = [ + "idna", + "once_cell", + "regex", + "serde", + "serde_derive", + "serde_json", + "url", + "validator_derive", +] + +[[package]] +name = "validator_derive" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7df16e474ef958526d1205f6dda359fdfab79d9aa6d54bafcb92dcd07673dca" +dependencies = [ + "darling 0.20.11", + "once_cell", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "vllm-chat" +version = "0.1.0" +dependencies = [ + "anyhow", + "asynk-strim-attr", + "bytes", + "clap", + "easy-ext", + "expect-test", + "futures", + "half", + "itertools 0.14.0", + "llm-multimodal", + "minijinja", + "minijinja-contrib", + "openai-harmony", + "reqwest", + "rmp-serde", + "serde", + "serde-json-fmt", + "serde_json", + "serde_with", + "serial_test", + "subenum", + "tempfile", + "thiserror 2.0.18", + "thiserror-ext", + "tokio", + "tracing", + "tracing-subscriber", + "trait-set", + "uuid", + "vllm-engine-core-client", + "vllm-llm", + "vllm-reasoning-parser", + "vllm-text", + "vllm-tokenizer", + "vllm-tool-parser", + "zeromq", +] + +[[package]] +name = "vllm-cmd" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "educe", + "expect-test", + "itertools 0.14.0", + "native-tls", + "serde", + "serde_json", + "serde_with", + "thiserror-ext", + "time", + "tokio", + "tokio-util", + "tracing", + "tracing-subscriber", + "uuid", + "vllm-engine-core-client", + "vllm-managed-engine", + "vllm-server", +] + +[[package]] +name = "vllm-engine-core-client" +version = "0.1.0" +dependencies = [ + "anyhow", + "arc-swap", + "bytemuck", + "byteorder", + "bytes", + "clap", + "easy-ext", + "enum-as-inner", + "expect-test", + "futures", + "half", + "hex", + "itertools 0.14.0", + "parking_lot", + "rmp-serde", + "rmpv", + "serde", + "serde_default", + "serde_json", + "serde_repr", + "serde_tuple", + "serde_with", + "task-local", + "tempfile", + "thiserror 2.0.18", + "thiserror-ext", + "tokio", + "tokio-util", + "tracing", + "tracing-subscriber", + "vllm-metrics", + "zeromq", +] + +[[package]] +name = "vllm-llm" +version = "0.1.0" +dependencies = [ + "anyhow", + "bytes", + "clap", + "easy-ext", + "enum-as-inner", + "expect-test", + "futures", + "rmp-serde", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tokio-util", + "tracing", + "tracing-subscriber", + "uuid", + "vllm-engine-core-client", + "vllm-metrics", + "zeromq", +] + +[[package]] +name = "vllm-managed-engine" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "expect-test", + "libc", + "tokio", + "tracing", +] + +[[package]] +name = "vllm-metrics" +version = "0.1.0" +dependencies = [ + "prometheus-client", +] + +[[package]] +name = "vllm-reasoning-parser" +version = "0.1.0" +dependencies = [ + "thiserror 2.0.18", + "vllm-tokenizer", +] + +[[package]] +name = "vllm-server" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-openai", + "asynk-strim-attr", + "axum", + "bytes", + "clap", + "expect-test", + "futures", + "http-body", + "itertools 0.14.0", + "libc", + "llm-multimodal", + "prost", + "prost-types", + "rmp-serde", + "rmpv", + "serde", + "serde_json", + "serde_with", + "serial_test", + "socket2", + "thiserror-ext", + "tokio", + "tokio-stream", + "tokio-util", + "tonic", + "tonic-prost", + "tonic-prost-build", + "tower", + "tower-http", + "tracing", + "tracing-futures", + "tracing-subscriber", + "uuid", + "validator", + "vllm-chat", + "vllm-engine-core-client", + "vllm-llm", + "vllm-metrics", + "vllm-text", + "zeromq", +] + +[[package]] +name = "vllm-text" +version = "0.1.0" +dependencies = [ + "anyhow", + "asynk-strim-attr", + "easy-ext", + "enum-as-inner", + "expect-test", + "futures", + "hf-hub 0.5.0", + "itertools 0.14.0", + "serde", + "serde_json", + "serde_with", + "tempfile", + "thiserror 2.0.18", + "thiserror-ext", + "tokio", + "tracing", + "trait-set", + "vllm-engine-core-client", + "vllm-llm", + "vllm-tokenizer", +] + +[[package]] +name = "vllm-tokenizer" +version = "0.1.0" +dependencies = [ + "base64 0.22.1", + "criterion", + "fastokens", + "hf-hub 0.5.0", + "riptoken", + "rustc-hash 1.1.0", + "serde", + "serde_json", + "tekken-rs", + "tempfile", + "thiserror 2.0.18", + "thiserror-ext", + "tiktoken-rs 0.9.1", + "tokenizers", + "tracing", +] + +[[package]] +name = "vllm-tool-parser" +version = "0.1.0" +dependencies = [ + "criterion", + "expect-test", + "futures", + "openai-protocol", + "serde", + "serde_json", + "thiserror 2.0.18", + "thiserror-ext", + "tool-parser", + "winnow", +] + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9c5522b3a28661442748e09d40924dfb9ca614b21c00d3fd135720e48b67db8" +dependencies = [ + "cfg-if", + "futures-util", + "js-sys", + "once_cell", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn 2.0.117", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap 2.13.0", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap 2.13.0", + "semver", +] + +[[package]] +name = "web-sys" +version = "0.3.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "854ba17bb104abfb26ba36da9729addc7ce7f06f5c0f90f3c391f8461cca21f9" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-root-certs" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.6", +] + +[[package]] +name = "webpki-roots" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cfaf3c063993ff62e73cb4311efde4db1efb31ab78a3e5c457939ad5cc0bed" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "weezl" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88" + +[[package]] +name = "win_uds" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd30a1a28a3799479cbf4e17284a220ea9ff6bad098a9d0224543a5d1efe1da" +dependencies = [ + "async-io", + "futures-io", + "socket2", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.5", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "winnow" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" +dependencies = [ + "memchr", +] + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap 2.13.0", + "prettyplease", + "syn 2.0.117", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn 2.0.117", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap 2.13.0", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap 2.13.0", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + +[[package]] +name = "y4m" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5a4b21e1a62b67a2970e6831bc091d7b87e119e7f9791aef9702e3bef04448" + +[[package]] +name = "yoke" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", + "synstructure", +] + +[[package]] +name = "zerocopy" +version = "0.8.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2578b716f8a7a858b7f02d5bd870c14bf4ddbbcf3a4c05414ba6503640505e3" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e6cc098ea4d3bd6246687de65af3f920c430e236bee1e3bf2e441463f08a02f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zeromq" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efb2c254fd8f366755335c9e43b865f8484fe3bd717d65ffe7c3f28852863030" +dependencies = [ + "async-trait", + "asynchronous-codec", + "bytes", + "crossbeam-queue", + "futures", + "log", + "num-traits", + "once_cell", + "parking_lot", + "rand 0.9.2", + "regex", + "scc 3.6.9", + "thiserror 1.0.69", + "tokio", + "tokio-util", + "uuid", + "win_uds", +] + +[[package]] +name = "zerotrie" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" + +[[package]] +name = "zune-core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9" + +[[package]] +name = "zune-inflate" +version = "0.2.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73ab332fe2f6680068f3582b16a24f90ad7096d5d39b974d1c0aff0125116f02" +dependencies = [ + "simd-adler32", +] + +[[package]] +name = "zune-jpeg" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27bc9d5b815bc103f142aa054f561d9187d191692ec7c2d1e2b4737f8dbd7296" +dependencies = [ + "zune-core", +] diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 00000000000..89e582dbc56 --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,129 @@ +[workspace] +members = [ + "src/chat", + "src/cmd", + "src/engine-core-client", + "src/llm", + "src/managed-engine", + "src/metrics", + "src/reasoning-parser", + "src/server", + "src/text", + "src/tokenizer", + "src/tool-parser", +] +resolver = "3" + +[workspace.package] +version = "0.1.0" +edition = "2024" +license = "Apache-2.0" + +[workspace.dependencies] +anyhow = "1.0.100" +arc-swap = "1.9.0" +async-openai = "0.33.1" +async-trait = "0.1.89" +asynk-strim-attr = "0.1.0" +axum = "0.8.8" +base64 = "0.22.1" +bytemuck = { version = "1.25.0", features = ["extern_crate_alloc"] } +byteorder = "1.5.0" +bytes = "1.11.1" +clap = { version = "4.5.38", features = ["derive", "env"] } +criterion = "0.5.1" +easy-ext = "1.0.3" +educe = "0.6.0" +enum-as-inner = "0.7.0" +expect-test = "1.5.1" +fastokens = "0.2.0" +futures = "0.3.31" +half = { version = "2.7.1", features = ["bytemuck"] } +hex = "0.4.3" +hf-hub = { version = "0.5.0", features = ["tokio"] } +http-body = "1.0.1" +itertools = "0.14.0" +libc = "0.2.177" +llm-multimodal = { git = "https://github.com/vllm-project/llm-multimodal", rev = "5b558989844d1c7af3e43d0f604069ffd9c06320" } +minijinja = { version = "2.0", features = ["unstable_machinery", "json", "builtins", "loader", "loop_controls"] } +minijinja-contrib = { version = "2.0", features = ["pycompat"] } +native-tls-vendored = { package = "native-tls", version = "0.2.18", features = ["vendored"] } +ndarray = { version = "0.16.1", features = ["serde"] } +openai-harmony = "0.0.8" +openai-protocol = "1.6.0" +parking_lot = "0.12.5" +prometheus-client = "0.24.0" +prometheus-client-derive-encode = "0.5.0" +prost = "0.14.3" +prost-types = "0.14.3" +reasoning-parser = "1.2.2" +reqwest = { version = "0.12.8", default-features = false, features = ["rustls-tls"] } +riptoken = { version = "0.3.0", default-features = false } +rmp-serde = "1.3.1" +rmpv = { version = "1.3.1", features = ["with-serde"] } +rustc-hash = "1.1.0" +serde = { version = "1.0.228", features = ["derive"] } +serde-json-fmt = "0.1.0" +serde_default = "0.2.0" +serde_json = "1.0.145" +serde_repr = "0.1.20" +serde_tuple = "1.1.3" +serde_with = "3.18.0" +serial_test = "3.2.0" +socket2 = "0.6.3" +subenum = "1.1.3" +task-local = "0.1.1" +tekken = { package = "tekken-rs", version = "0.1.1", default-features = false } +tempfile = "3.23.0" +thiserror = "2.0.16" +thiserror-ext = "0.3.0" +tiktoken-rs = "0.9.1" +time = { version = "0.3.47", features = ["formatting", "local-offset", "macros"] } +tokenizers = "0.22.0" +tokio = { version = "1.47.1", features = [ + "macros", + "net", + "rt-multi-thread", + "sync", + "time", +] } +tokio-stream = "0.1" +tokio-util = { version = "0.7.18", features = ["rt"] } +tonic = "0.14.5" +tonic-build = "0.14.5" +tonic-prost = "0.14.5" +tonic-prost-build = "0.14.5" +tool-parser = "1.2.0" +tower = { version = "0.5.3", features = ["util"] } +tower-http = { version = "0.6.8", features = ["trace"] } +tracing = { version = "0.1.44", features = ["release_max_level_debug"] } +tracing-futures = { version = "0.2.5", features = ["futures-03"] } +tracing-subscriber = { version = "0.3.20", features = ["env-filter", "fmt"] } +trait-set = "0.3.0" +uuid = { version = "1.22.0", features = ["v4"] } +validator = { version = "0.20.0", features = ["derive"] } +vllm-chat = { path = "src/chat" } +vllm-engine-core-client = { path = "src/engine-core-client" } +vllm-llm = { path = "src/llm" } +vllm-managed-engine = { path = "src/managed-engine" } +vllm-metrics = { path = "src/metrics" } +vllm-reasoning-parser = { path = "src/reasoning-parser" } +vllm-server = { path = "src/server" } +vllm-text = { path = "src/text" } +vllm-tokenizer = { path = "src/tokenizer" } +vllm-tool-parser = { path = "src/tool-parser" } +winnow = "1.0.2" +zeromq = { version = "0.6.0", default-features = false, features = [ + "tokio-runtime", + "all-transport", +] } + +[workspace.lints.clippy] +too_many_arguments = "allow" + +[profile.dev] +panic = "abort" + +[profile.release] +lto = "thin" +panic = "abort" diff --git a/rust/README.md b/rust/README.md new file mode 100644 index 00000000000..679a7f0966e --- /dev/null +++ b/rust/README.md @@ -0,0 +1,89 @@ +# vllm-frontend-rs + +This is a Rust drop-in alternative frontend for vLLM. The current goal is to rebuild the northbound serving layer in Rust while still talking to the core Python vLLM engine process(es) via ZMQ over the existing engine boundary. + +It should still be considered experimental, and is not feature-complete. We are working to add more functionality from the python front-end. + +See for the original commit history before it was moved into the main vllm repo. + +## Architecture + +The component is organized as a Cargo workspace with several crates, layered bottom-up: + +```text +┌─────────────────────────────────┐ +│ vllm-cmd / vllm-rs │ CLI entrypoint: +│ │ Python vLLM frontend subprocess +│ │ Rust managed-engine serve mode +├─────────────────────────────────┤ +│ vllm-server │ OpenAI-compatible HTTP API (axum) +├─────────────────────────────────┤ +│ vllm-chat │ Chat completions: template rendering, +│ │ structured assistant events, +│ │ reasoning & tool parsing +├─────────────────────────────────┤ +│ vllm-text │ Tokenizer & incremental detokenizer +├─────────────────────────────────┤ +│ vllm-llm │ Thin token-in/token-out facade over +│ │ the engine client +├─────────────────────────────────┤ +│ vllm-engine-core-client │ ZMQ transport + MessagePack protocol +│ │ for the headless vLLM engine +└─────────────────────────────────┘ +``` + +`vllm-rs` integrates into Python `vllm` as a Rust frontend subprocess. +Python owns process startup and launches the Rust API server as a Python-supervised worker, while +passing the inherited listening socket and transport addresses into `vllm-rs`. + +For example: + +```bash +VLLM_USE_RUST_FRONTEND=1 vllm serve Qwen/Qwen3-0.6B +``` + +### External Engine + +`vllm-rs serve` can be run standalone with `--data-parallel-size-local 0` when the Python engines +are started elsewhere and this node should run only the Rust frontend. The frontend still uses +the global `--data-parallel-size` to determine how many engines it expects to join the shared handshake. + +```bash +vllm serve Qwen/Qwen3-0.6B \ + --headless \ + --data-parallel-address 127.0.0.1 \ + --data-parallel-rpc-port 62100 \ + --data-parallel-size 1 \ + --data-parallel-size-local 1 +``` + +Then start the Rust frontend-only server: + +```bash +vllm-rs serve Qwen/Qwen3-0.6B \ + --data-parallel-address 127.0.0.1 \ + --data-parallel-rpc-port 62100 \ + --data-parallel-size 1 \ + --data-parallel-size-local 0 +``` + +To build the `vllm-rs` in isolation: + +```bash +# from the local checkout +cargo install --path src/cmd --bin vllm-rs +``` + +### Example Request + +After either startup path, you can use any OpenAI-compatible client: + +```bash +curl http://127.0.0.1:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen3-0.6B", + "messages": [{"role": "user", "content": "What is the capital of France?"}], + "stream": true + }' +``` diff --git a/rust/proto/vllm_grpc.proto b/rust/proto/vllm_grpc.proto new file mode 100644 index 00000000000..56c5f36442d --- /dev/null +++ b/rust/proto/vllm_grpc.proto @@ -0,0 +1,196 @@ +syntax = "proto3"; +package vllm; + +import "google/protobuf/struct.proto"; + + +service Generate { + // Generates text given a prompt + rpc Generate (GenerateRequest) returns (GenerateResponse) {} + // Generates text given a prompt, streaming the outputs + rpc GenerateStream (GenerateRequest) returns (stream GenerateResponse) {} +} + +// ====================================================================================== +// Generate Request +// ====================================================================================== + +message GenerateRequest { + string request_id = 1; + string model = 2; + + oneof prompt { + string text = 3; + TokenIds token_ids = 4; + } + + // Temperature, defaults to model-specific default or 0 + optional float temperature = 5; + // Parameters controlling random sampling, not applicable if temperature == 0 + RandomSampling sampling = 6; + // Parameters for conditionally penalizing/boosting + // candidate tokens during decoding + DecodingParameters decoding = 7; + // Parameters controlling when generation should stop + StoppingCriteria stopping = 8; + // Flags to control what is returned in the response + ResponseOptions response = 9; + // Parameters controlling KV cache/distribution + KVCacheParameters kv = 10; + + // Truncate prompt tokens; default (0) means no truncation + uint32 truncate_prompt_tokens = 11; + + int32 priority = 12; +} + +message RandomSampling { + uint32 num_sequences = 1; // "n", default (0) means 1 + uint32 top_k = 2; // 0 means default + float top_p = 3; // 0 means default + float min_p = 4; // 0 means default + optional int64 seed = 5; +} + +message DecodingParameters { + // Penalties + float presence_penalty = 1; // Default (0.0) means no penalty + float frequency_penalty = 2; // Default (0.0) means no penalty + float repetition_penalty = 3; // Default (0.0) means no penalty + map logit_bias = 4; + repeated uint32 allowed_token_ids = 5; + + message StringChoices { + repeated string choices = 1; + } + + // Control structured outputs + oneof structured_output { + string json = 6; + string regex = 7; + StringChoices choice = 8; + string grammar = 9; + bool json_object = 10; + string structural_tag = 11; + } +} + +message StoppingCriteria { + // Default (0) is currently 20 + uint32 max_new_tokens = 1; + // Default (0) means no minimum + uint32 min_new_tokens = 2; + + repeated uint32 stop_token_ids = 3; + repeated string stop_strings = 4; + bool include_stop_strings = 5; + + bool ignore_eos = 6; +} + +message ResponseOptions { + // Prompt options + bool prompt_token_ids = 1; + bool prompt_logprobs = 2; + optional CandidateTokens prompt_candidates = 3; + + // Output options; output_text defaults to true + optional bool output_text = 4; + bool output_token_ids = 5; + bool output_logprobs = 6; + optional CandidateTokens output_candidates = 7; +} + +message KVCacheParameters { + bool bypass_prefix_cache = 1; + string cache_salt = 2; + + // KV Connector transfer parameters + google.protobuf.Struct kv_transfer_params = 3; +} + +// Controls which extra candidate tokens at each position should be returned +message CandidateTokens { + oneof select { + uint32 top_n = 1; + TokenIds token_ids = 2; + bool all = 3; + } +} + +// ====================================================================================== +// Generate Response +// ====================================================================================== + +message GenerateResponse { + // Only present in first response + optional PromptInfo prompt_info = 1; + SequenceOutput outputs = 2; +} + +message SequenceOutput { + // Index of output sequence for num_sequences > 1. + uint32 index = 1; + + string text = 2; + uint32 num_tokens = 3; // Number of tokens in this chunk + repeated uint32 token_ids = 4; // If requested + repeated float logprobs = 5; // If requested + repeated uint32 ranks = 6; // If logprobs were requested + repeated CandidateTokenInfo candidate_tokens = 7; // If requested + + // Only present in final output for this sequence + optional FinishInfo finish_info = 8; +} + +// Prompt info, returned in the first response +message PromptInfo { + uint32 num_prompt_tokens = 1; + repeated uint32 token_ids = 2; // If requested + repeated float logprobs = 3; // If requested + repeated uint32 ranks = 4; // If logprobs were requested + repeated CandidateTokenInfo candidate_tokens = 5; +} + +// Finish info, returned in the final response +message FinishInfo { + uint32 num_output_tokens = 1; + + enum FinishReason { + NOT_FINISHED = 0; // Possibly more tokens to be streamed + LENGTH = 1; // Finished due to length constraint + STOP = 2; // Stop string/token or EOS encountered + ABORTED = 3; // Request aborted/cancelled + } + + FinishReason finish_reason = 2; + // One of these will be set when finish_reason == STOP + oneof stop_reason { + uint32 stop_token_id = 3; + uint32 eos_token_id = 4; + string stop_string = 5; + } + + google.protobuf.Struct kv_transfer_params = 6; + //uint64 seed = 7; +} + +// Info for candidate tokens other than the input/sampled +// token at a given position +message CandidateTokenInfo { + message TokenInfo { + uint32 id = 1; + float logprob = 2; + uint32 rank = 3; + // string text = 4; + // bytes token_bytes = 5; + } + // Candidate token infos at this position + repeated TokenInfo tokens = 1; +} + +// Token ids used for prompt +message TokenIds { + repeated uint32 ids = 1; +} + diff --git a/rust/rustfmt.toml b/rust/rustfmt.toml new file mode 100644 index 00000000000..e619a753fa0 --- /dev/null +++ b/rust/rustfmt.toml @@ -0,0 +1,3 @@ +style_edition = "2024" +chain_width = 80 +use_field_init_shorthand = true diff --git a/rust/rustfmt.unstable.toml b/rust/rustfmt.unstable.toml new file mode 100644 index 00000000000..4a0dda0a49f --- /dev/null +++ b/rust/rustfmt.unstable.toml @@ -0,0 +1,20 @@ +# Optional local formatting profile. CI and pre-commit use rustfmt.toml. +# Apply manually with: +# cargo +nightly fmt -- --config-path rustfmt.unstable.toml + +style_edition = "2024" +chain_width = 80 +comment_width = 100 +use_field_init_shorthand = true + +# Unstable features go here. +unstable_features = true + +format_code_in_doc_comments = true +format_macro_matchers = true +normalize_comments = true +normalize_doc_attributes = true +imports_granularity = "Module" +group_imports = "StdExternalCrate" +reorder_impl_items = true +wrap_comments = true diff --git a/rust/src/chat/Cargo.toml b/rust/src/chat/Cargo.toml new file mode 100644 index 00000000000..4bb339d43ba --- /dev/null +++ b/rust/src/chat/Cargo.toml @@ -0,0 +1,52 @@ +[package] +name = "vllm-chat" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +anyhow.workspace = true +asynk-strim-attr.workspace = true +easy-ext.workspace = true +futures.workspace = true +half.workspace = true +itertools.workspace = true +llm-multimodal.workspace = true +minijinja.workspace = true +minijinja-contrib.workspace = true +openai-harmony.workspace = true +reqwest.workspace = true +serde.workspace = true +serde-json-fmt.workspace = true +serde_json.workspace = true +serde_with.workspace = true +subenum.workspace = true +thiserror.workspace = true +thiserror-ext.workspace = true +tokio.workspace = true +tracing.workspace = true +trait-set.workspace = true +uuid.workspace = true +vllm-engine-core-client.workspace = true +vllm-llm.workspace = true +vllm-reasoning-parser.workspace = true +vllm-text.workspace = true +vllm-tokenizer.workspace = true +vllm-tool-parser.workspace = true + +[dev-dependencies] +anyhow.workspace = true +bytes.workspace = true +clap.workspace = true +expect-test.workspace = true +rmp-serde.workspace = true +serial_test = { workspace = true, features = ["file_locks"] } +tempfile.workspace = true +tokio.workspace = true +tracing-subscriber.workspace = true +uuid.workspace = true +vllm-engine-core-client = { workspace = true, features = ["test-util"] } +zeromq.workspace = true + +[lints] +workspace = true diff --git a/rust/src/chat/examples/README.md b/rust/src/chat/examples/README.md new file mode 100644 index 00000000000..9ae231da792 --- /dev/null +++ b/rust/src/chat/examples/README.md @@ -0,0 +1,39 @@ +# Chat Smoke Test + +Start a fresh headless `vllm` engine: + +```bash +source ../vllm/.venv/bin/activate +HF_HUB_OFFLINE=1 \ +VLLM_LOGGING_LEVEL=DEBUG \ +VLLM_CPU_KVCACHE_SPACE=2 \ +VLLM_HOST_IP=127.0.0.1 \ +VLLM_LOOPBACK_IP=127.0.0.1 \ +python3 -m vllm.entrypoints.cli.main serve Qwen/Qwen3-0.6B \ + --headless \ + --data-parallel-address 127.0.0.1 \ + --data-parallel-rpc-port 62100 \ + --data-parallel-size-local 1 \ + --max-model-len 512 \ + --dtype float16 +``` + +Run the Rust chat smoke test through the `vllm-chat` interface: + +```bash +cargo run -p vllm-chat --example external_engine_chat_qwen -- \ + --handshake-address tcp://127.0.0.1:62100 \ + --host 127.0.0.1 \ + --prompt 'What is the capital of France? Answer with one word.' +``` + +The example now defaults to `Qwen/Qwen3-0.6B`. The current `vllm-chat` +request model stays text-first and supports either plain string content or +OpenAI-style text blocks, while the output side now emits structured assistant +events and automatically separates reasoning blocks for supported models. Tool +use and multimodal inputs are still out of scope. It uses the Rust +`tokenizers` library for the tokenizer itself, plus standard Hugging Face +config files to load the chat template and EOS metadata. + +IMPORTANT: Restart `vllm` each time you run the smoke test. The current headless +engine cannot safely handle frontend reconnects after the client shuts down. diff --git a/rust/src/chat/examples/external_engine_chat_qwen.rs b/rust/src/chat/examples/external_engine_chat_qwen.rs new file mode 100644 index 00000000000..d99d672d5eb --- /dev/null +++ b/rust/src/chat/examples/external_engine_chat_qwen.rs @@ -0,0 +1,178 @@ +use std::time::Duration; + +use anyhow::{Context, Result, bail}; +use clap::Parser; +use futures::StreamExt as _; +use tracing_subscriber::EnvFilter; +use vllm_chat::{ + AssistantBlockKind, AssistantMessageExt as _, ChatEvent, ChatLlm, ChatMessage, ChatRequest, + ChatRole, SamplingParams, load_model_backends, +}; +use vllm_engine_core_client::{EngineCoreClient, EngineCoreClientConfig, TransportMode}; +use vllm_llm::Llm; +use vllm_text::TextLlm; + +#[derive(Debug, Parser)] +#[command(about = "Smoke-test the Rust chat facade against an external Qwen vLLM engine.")] +struct Args { + #[arg(long)] + handshake_address: String, + #[arg(long, default_value_t = 1)] + engine_count: usize, + #[arg(long, default_value = "Qwen/Qwen3-0.6B")] + model: String, + #[arg(long, default_value = "127.0.0.1")] + host: String, + #[arg(long, default_value_t = 30)] + ready_timeout_secs: u64, + #[arg(long)] + prompt: String, +} + +const CLIENT_INDEX: u32 = 0; +const OUTPUT_TIMEOUT_SECS: u64 = 120; + +fn unique_request_id() -> String { + format!("rust-chat-smoke-{}", uuid::Uuid::new_v4()) +} + +fn init_tracing() { + let filter = EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new("vllm_engine_core_client=debug")); + let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init(); +} + +#[tokio::main(flavor = "multi_thread")] +async fn main() -> Result<()> { + init_tracing(); + let args = Args::parse(); + let loaded = load_model_backends(&args.model, Default::default()) + .await + .with_context(|| format!("failed to load backends for {}", args.model))?; + let text_backend = loaded.text_backend; + let chat_backend = loaded.chat_backend; + + let ready_timeout = Duration::from_secs(args.ready_timeout_secs); + let output_timeout = Duration::from_secs(OUTPUT_TIMEOUT_SECS); + let request_id = unique_request_id(); + let client = EngineCoreClient::connect(EngineCoreClientConfig { + transport_mode: TransportMode::HandshakeOwner { + handshake_address: args.handshake_address.clone(), + advertised_host: args.host.clone(), + engine_count: args.engine_count, + ready_timeout, + local_input_address: None, + local_output_address: None, + }, + coordinator_mode: None, + model_name: args.model.clone(), + client_index: CLIENT_INDEX, + }) + .await + .context("failed to connect to external vLLM engine")?; + + println!("model={}", args.model); + println!("tokenizer_source=tokenizers + hf-hub"); + println!("chat_template_source=tokenizer_config.json or adjacent chat template file"); + println!("handshake_address={}", args.handshake_address); + println!("engine_count={}", args.engine_count); + println!("input_address={}", client.input_address()); + println!("output_address={}", client.output_address()); + println!("engine_identities={:x?}", client.engine_identities()); + + let llm = Llm::new(client); + let chat = ChatLlm::new(TextLlm::new(llm, text_backend), chat_backend); + + let request = ChatRequest { + messages: vec![ChatMessage::text(ChatRole::User, args.prompt.clone())], + sampling_params: SamplingParams { + temperature: Some(0.0), + ..Default::default() + }, + request_id: request_id.clone(), + ..ChatRequest::for_test() + }; + + println!("request_id={request_id}"); + println!("prompt={}", args.prompt); + + let mut stream = chat.chat(request).await.context("failed to submit chat request")?; + let output = tokio::time::timeout(output_timeout, async { + let mut final_reasoning = String::new(); + let mut final_text = String::new(); + let mut final_output_token_count = 0usize; + let mut finish_reason = None; + let mut saw_start = false; + let mut saw_stream_output = false; + + while let Some(event) = stream.next().await.transpose()? { + match event { + ChatEvent::Start { .. } => { + saw_start = true; + } + ChatEvent::BlockStart { kind, .. } => { + if saw_stream_output { + println!(); + } + match kind { + AssistantBlockKind::Reasoning => print!("[reasoning] "), + AssistantBlockKind::Text => print!("[answer] "), + AssistantBlockKind::ToolCall => {} + } + saw_stream_output = true; + } + ChatEvent::ToolCallStart { name, .. } => { + if saw_stream_output { + println!(); + } + print!("[tool:{name}] "); + saw_stream_output = true; + } + ChatEvent::LogprobsDelta { .. } => {} + ChatEvent::Done { + message, + output_token_count, + finish_reason: reason, + .. + } => { + final_reasoning = message.reasoning().unwrap_or_default(); + final_text = message.text(); + final_output_token_count = output_token_count; + finish_reason = Some(reason); + break; + } + ChatEvent::BlockDelta { kind, delta, .. } => match kind { + AssistantBlockKind::Reasoning | AssistantBlockKind::Text => { + print!("{delta}"); + } + AssistantBlockKind::ToolCall => {} + }, + ChatEvent::ToolCallArgumentsDelta { delta, .. } => print!("{delta}"), + ChatEvent::BlockEnd { .. } | ChatEvent::ToolCallEnd { .. } => {} + } + } + + println!(); + + if !saw_start { + bail!("chat stream ended without a start event"); + } + Ok::<_, anyhow::Error>(( + final_reasoning, + final_text, + final_output_token_count, + finish_reason, + )) + }) + .await + .context("timed out waiting for chat output")??; + + chat.shutdown().await.context("failed to shut down chat client")?; + + println!("final_reasoning={:?}", output.0); + println!("final_text={:?}", output.1); + println!("final_output_token_count={:?}", output.2); + println!("finish_reason={:?}", output.3); + + Ok(()) +} diff --git a/rust/src/chat/src/backend/hf.rs b/rust/src/chat/src/backend/hf.rs new file mode 100644 index 00000000000..6c3dddc8729 --- /dev/null +++ b/rust/src/chat/src/backend/hf.rs @@ -0,0 +1,308 @@ +use std::sync::Arc; + +use tracing::info; +use vllm_text::backend::hf::{HfTextBackend, ResolvedModelFiles, load_model_config}; +use vllm_text::tokenizer::DynTokenizer; +use vllm_text::{DynTextBackend, TextBackend as _}; + +use crate::backend::{ + ChatBackend, DynChatBackend, LoadModelBackendsOptions, LoadedModelBackends, + NewChatOutputProcessorOptions, +}; +use crate::error::Result; +use crate::multimodal::MultimodalModelInfo; +use crate::output::{ + DefaultChatOutputProcessor, HarmonyChatOutputProcessor, validate_harmony_parser_overrides, +}; +use crate::renderer::hf::{HfChatRenderer, MultimodalRenderInfo}; +use crate::renderer::{DeepSeekV4ChatRenderer, DeepSeekV32ChatRenderer, DynChatRenderer}; +use crate::request::ChatRequest; +use crate::{DynChatOutputProcessor, RendererSelection}; + +/// [`ChatBackend`] implementation built on Hugging Face model files. +pub struct HfChatBackend { + model_id: String, + model_type: String, + tokenizer: DynTokenizer, + chat_renderer: DynChatRenderer, + multimodal_model_info: Option, +} + +impl HfChatBackend { + /// Load the chat backend from resolved Hugging Face model files. + pub fn from_resolved_model_files( + files: ResolvedModelFiles, + model_id: String, + options: LoadModelBackendsOptions, + tokenizer: DynTokenizer, + ) -> Result { + let model_config = load_model_config(files.config_path.as_deref())?; + let model_type = model_config.model_type().unwrap_or_default(); + let multimodal_model_info = MultimodalModelInfo::from_paths( + model_id.clone(), + (!model_type.is_empty()).then_some(model_type.to_string()), + files.config_path.as_deref(), + files.preprocessor_config_path.as_deref(), + tokenizer.clone(), + )?; + let multimodal_render_info = resolve_multimodal_render_info(multimodal_model_info.as_ref()); + + let renderer = options.renderer.resolve(model_type); + let chat_renderer: DynChatRenderer = match renderer { + RendererSelection::Auto => unreachable!("renderer auto should be resolved above"), + RendererSelection::Hf => Arc::new(HfChatRenderer::load( + &files, + options, + multimodal_render_info, + )?), + RendererSelection::DeepSeekV32 => Arc::new(DeepSeekV32ChatRenderer::new()), + RendererSelection::DeepSeekV4 => Arc::new(DeepSeekV4ChatRenderer::new()), + }; + + info!( + model_id, + model_type, + %renderer, + "loaded chat backend with Hugging Face model files" + ); + + Ok(Self { + model_id, + model_type: model_type.to_string(), + tokenizer, + chat_renderer, + multimodal_model_info, + }) + } +} + +impl ChatBackend for HfChatBackend { + fn chat_renderer(&self) -> DynChatRenderer { + self.chat_renderer.clone() + } + + fn multimodal_model_info(&self) -> Option<&MultimodalModelInfo> { + self.multimodal_model_info.as_ref() + } + + fn new_chat_output_processor( + &self, + request: &mut ChatRequest, + options: NewChatOutputProcessorOptions<'_>, + ) -> Result { + if self.model_type == "gpt_oss" { + validate_harmony_parser_overrides(options.tool_call_parser, options.reasoning_parser)?; + return Ok(Box::new(HarmonyChatOutputProcessor::new(request)?)); + } + + Ok(Box::new(DefaultChatOutputProcessor::new( + request, + &self.model_id, + self.tokenizer.clone(), + options.tool_call_parser, + options.reasoning_parser, + )?)) + } +} + +/// Load the Hugging Face text and chat backends for the given model id. +pub(super) async fn load_model_backends( + model_id: &str, + options: LoadModelBackendsOptions, +) -> Result { + let files = ResolvedModelFiles::new(model_id).await?; + let text_backend = + HfTextBackend::from_resolved_model_files(files.clone(), model_id.to_string())?; + let tokenizer = text_backend.tokenizer(); + let text_backend: DynTextBackend = Arc::new(text_backend); + + let chat_backend: DynChatBackend = Arc::new(HfChatBackend::from_resolved_model_files( + files, + model_id.to_string(), + options, + tokenizer, + )?); + + Ok(LoadedModelBackends { + text_backend, + chat_backend, + }) +} + +fn resolve_multimodal_render_info( + info: Option<&MultimodalModelInfo>, +) -> Option { + info.map(|info| MultimodalRenderInfo { + placeholder_token: info.placeholder_token().to_string(), + }) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::path::PathBuf; + use std::sync::Arc; + + use tempfile::tempdir; + use vllm_text::backend::hf::TokenizerSource; + use vllm_text::tokenizer::{DynTokenizer, Tokenizer}; + + use super::HfChatBackend; + use crate::RendererSelection; + use crate::backend::{ChatBackend, LoadModelBackendsOptions}; + use crate::request::{ChatContent, ChatMessage, ChatRequest}; + + fn request_with_user_text(text: &str) -> ChatRequest { + ChatRequest { + request_id: "renderer-selection-test".to_string(), + messages: vec![ChatMessage::User { + content: ChatContent::Text(text.to_string()), + }], + ..ChatRequest::for_test() + } + } + + fn write_json(path: &std::path::Path, content: &str) { + std::fs::write(path, content).unwrap(); + } + + fn resolved_files( + config_json: &str, + tokenizer_config_json: &str, + ) -> vllm_text::backend::hf::ResolvedModelFiles { + let dir = tempdir().unwrap(); + let root = dir.keep(); + let config_path = root.join("config.json"); + let tokenizer_config_path = root.join("tokenizer_config.json"); + write_json(&config_path, config_json); + write_json(&tokenizer_config_path, tokenizer_config_json); + + vllm_text::backend::hf::ResolvedModelFiles { + tokenizer: TokenizerSource::HuggingFace(PathBuf::from("/tmp/unused-tokenizer.json")), + tokenizer_config_path: Some(tokenizer_config_path), + generation_config_path: None, + preprocessor_config_path: None, + chat_template_path: None, + config_path: Some(config_path), + } + } + + struct TestTokenizer; + + impl Tokenizer for TestTokenizer { + fn encode( + &self, + _text: &str, + _add_special_tokens: bool, + ) -> vllm_text::tokenizer::Result> { + Ok(Vec::new()) + } + + fn decode( + &self, + _token_ids: &[u32], + _skip_special_tokens: bool, + ) -> vllm_text::tokenizer::Result { + Ok(String::new()) + } + + fn token_to_id(&self, _token: &str) -> Option { + None + } + } + + fn test_tokenizer() -> DynTokenizer { + Arc::new(TestTokenizer) + } + + fn render_prompt( + renderer: RendererSelection, + config_json: &str, + tokenizer_config_json: &str, + ) -> String { + let backend = HfChatBackend::from_resolved_model_files( + resolved_files(config_json, tokenizer_config_json), + "test-model".to_string(), + LoadModelBackendsOptions { + renderer, + chat_template_content_format: Default::default(), + chat_template: None, + default_chat_template_kwargs: HashMap::new(), + }, + test_tokenizer(), + ) + .unwrap(); + + backend + .chat_renderer() + .render(&request_with_user_text("hello")) + .unwrap() + .prompt + .into_text() + .expect("renderer should return text prompt") + } + + #[test] + fn auto_uses_deepseek_renderer_for_deepseek_v32_model_type() { + let prompt = render_prompt( + RendererSelection::Auto, + r#"{"model_type":"deepseek_v32"}"#, + r#"{}"#, + ); + + assert_eq!( + prompt, + "<|begin▁of▁sentence|><|User|>hello<|Assistant|>" + ); + } + + #[test] + fn auto_uses_hf_renderer_for_other_model_types() { + let prompt = render_prompt( + RendererSelection::Auto, + r#"{"model_type":"qwen2"}"#, + r#"{"chat_template":"{{ messages[0].content }}"}"#, + ); + + assert_eq!(prompt, "hello"); + } + + #[test] + fn explicit_deepseek_renderer_overrides_generic_model_type() { + let prompt = render_prompt( + RendererSelection::DeepSeekV32, + r#"{"model_type":"qwen2"}"#, + r#"{"chat_template":"{{ messages[0].content }}"}"#, + ); + + assert_eq!( + prompt, + "<|begin▁of▁sentence|><|User|>hello<|Assistant|>" + ); + } + + #[test] + fn explicit_hf_renderer_overrides_deepseek_v32_model_type() { + let prompt = render_prompt( + RendererSelection::Hf, + r#"{"model_type":"deepseek_v32"}"#, + r#"{"chat_template":"{{ messages[0].content }}"}"#, + ); + + assert_eq!(prompt, "hello"); + } + + #[test] + fn auto_uses_nested_text_config_model_type() { + let prompt = render_prompt( + RendererSelection::Auto, + r#"{"text_config":{"model_type":"deepseek_v32","num_attention_heads":32}}"#, + r#"{}"#, + ); + + assert_eq!( + prompt, + "<|begin▁of▁sentence|><|User|>hello<|Assistant|>" + ); + } +} diff --git a/rust/src/chat/src/backend/mod.rs b/rust/src/chat/src/backend/mod.rs new file mode 100644 index 00000000000..f49ca673704 --- /dev/null +++ b/rust/src/chat/src/backend/mod.rs @@ -0,0 +1,86 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use serde_json::Value; +use vllm_text::{DynTextBackend, TextBackend}; + +use crate::error::Result; +use crate::multimodal::MultimodalModelInfo; +use crate::output::DynChatOutputProcessor; +use crate::renderer::DynChatRenderer; +use crate::request::ChatRequest; +use crate::{ChatTemplateContentFormatOption, ParserSelection, RendererSelection}; + +pub mod hf; + +/// Options for creating a new chat output processor. +pub struct NewChatOutputProcessorOptions<'a> { + pub tool_call_parser: &'a ParserSelection, + pub reasoning_parser: &'a ParserSelection, +} + +/// Minimal prompt-processing backend needed by `vllm-chat`. +pub trait ChatBackend: Send + Sync { + /// Return the renderer used for chat-prompt construction. + fn chat_renderer(&self) -> DynChatRenderer; + + /// Return model files/config needed for request-scoped multimodal + /// preprocessing, if supported. + fn multimodal_model_info(&self) -> Option<&MultimodalModelInfo> { + None + } + + /// Create a request-scoped output processor after request-level adjustments + /// are applied. + fn new_chat_output_processor( + &self, + request: &mut ChatRequest, + options: NewChatOutputProcessorOptions<'_>, + ) -> Result; +} + +/// Shared trait-object form of [`ChatBackend`]. +pub type DynChatBackend = Arc; + +/// Convenience trait for backends that can serve both raw text generation and +/// chat templating. +/// +/// This is mainly useful in tests and small examples, where one mock/backend +/// often implements both sides and callers want `ChatLlm` to wire the shared +/// object into `TextLlm` automatically. +pub trait ChatTextBackend: ChatBackend + TextBackend {} + +impl ChatTextBackend for T where T: ChatBackend + TextBackend + ?Sized {} + +/// Shared trait-object form of [`ChatTextBackend`]. +pub type DynChatTextBackend = Arc; + +/// Frontend-side chat backend loading options. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct LoadModelBackendsOptions { + /// Which chat renderer implementation to use. + pub renderer: RendererSelection, + /// How to serialize `message.content` when rendering the chat template. + pub chat_template_content_format: ChatTemplateContentFormatOption, + /// Optional server-default chat template override, provided either as an + /// inline template or as a path to a template file. + pub chat_template: Option, + /// Optional server-default keyword arguments merged into every + /// chat-template render before request-level `chat_template_kwargs`. + pub default_chat_template_kwargs: HashMap, +} + +/// Shared backends loaded from a model id. +pub struct LoadedModelBackends { + pub text_backend: DynTextBackend, + pub chat_backend: DynChatBackend, +} + +/// Load text and chat backends for the given model id. +pub async fn load_model_backends( + model_id: &str, + options: LoadModelBackendsOptions, +) -> Result { + // Currently, we only have HuggingFace backends. + hf::load_model_backends(model_id, options).await +} diff --git a/rust/src/chat/src/error.rs b/rust/src/chat/src/error.rs new file mode 100644 index 00000000000..25d8d015680 --- /dev/null +++ b/rust/src/chat/src/error.rs @@ -0,0 +1,80 @@ +use thiserror::Error; +use thiserror_ext::Macro; + +type BoxedError = Box; + +#[derive(Debug, Error, Macro)] +#[thiserror_ext(macro(path = "crate::error"))] +pub enum Error { + #[error("chat request must contain at least one message")] + EmptyMessages, + #[error("cannot continue the final message when the last message is not from the assistant")] + ContinueFinalAssistantWithoutFinalAssistant, + #[error("chat template is required but none was configured")] + MissingChatTemplate, + #[error("chat template error: {0}")] + ChatTemplate(String), + #[error("multimodal input is not supported by this chat renderer")] + UnsupportedMultimodalRenderer, + #[error("unsupported multimodal content: {0}")] + UnsupportedMultimodalContent(&'static str), + #[error("multimodal preprocessing error: {0}")] + Multimodal(#[message] String), + #[error("{kind} parsing is not available for model `{model_id}`")] + ParserUnavailableForModel { + kind: &'static str, + model_id: String, + }, + #[error("{kind} parsing is disabled by frontend configuration")] + ParserDisabled { kind: &'static str }, + #[error( + "{kind} parser `{name}` is not registered{}", + available_parser_hint(.available_names) + )] + ParserUnavailableByName { + kind: &'static str, + name: String, + available_names: Vec, + }, + #[error("failed to initialize {kind} parser `{name}`")] + ParserInitialization { + kind: &'static str, + name: String, + #[source] + error: BoxedError, + }, + #[error( + "gpt_oss uses native Harmony output parsing; generic {kind} parser override `{selection}` is not supported" + )] + HarmonyParserOverrideUnsupported { + kind: &'static str, + selection: String, + }, + #[error("harmony output parsing failed")] + HarmonyOutputParsing { + #[source] + error: BoxedError, + }, + #[error( + "this model's maximum context length is {max_model_len} tokens, \ + but the prompt contains {prompt_len} input tokens" + )] + PromptTooLong { max_model_len: u32, prompt_len: u32 }, + #[error("chat request stream `{request_id}` closed before terminal output")] + StreamClosedBeforeTerminalOutput { request_id: String }, + #[error("tool call stream state is inconsistent: {message}")] + ToolCallStreamInvariant { message: String }, + #[error(transparent)] + Text(#[from] vllm_text::Error), +} + +pub type Result = std::result::Result; + +/// Format the available-parser suffix used in user-facing error messages. +fn available_parser_hint(available_names: &[String]) -> String { + if available_names.is_empty() { + String::new() + } else { + format!(" (choose from: {})", available_names.join(", ")) + } +} diff --git a/rust/src/chat/src/event.rs b/rust/src/chat/src/event.rs new file mode 100644 index 00000000000..08603b43572 --- /dev/null +++ b/rust/src/chat/src/event.rs @@ -0,0 +1,183 @@ +use std::ops::Deref; +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; +use vllm_text::{DecodedLogprobs, DecodedPromptLogprobs}; + +use crate::FinishReason; + +/// One finalized assistant tool call. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct AssistantToolCall { + pub id: String, + pub name: String, + pub arguments: String, +} + +/// Semantic kind of one assistant output block. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum AssistantBlockKind { + /// Visible final-answer text. + Text, + /// Extracted reasoning content. + Reasoning, + /// One finalized tool call. + ToolCall, +} + +/// One structured assistant output block. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum AssistantContentBlock { + /// Visible final-answer text. + Text { text: String }, + /// Extracted reasoning content. + Reasoning { text: String }, + /// One finalized tool call. + ToolCall(AssistantToolCall), +} + +impl AssistantContentBlock { + /// Return the semantic kind of this block. + pub fn kind(&self) -> AssistantBlockKind { + match self { + Self::Text { .. } => AssistantBlockKind::Text, + Self::Reasoning { .. } => AssistantBlockKind::Reasoning, + Self::ToolCall(..) => AssistantBlockKind::ToolCall, + } + } + + /// Return this block as one finalized tool call, if applicable. + pub fn as_tool_call(&self) -> Option<&AssistantToolCall> { + match self { + Self::ToolCall(call) => Some(call), + _ => None, + } + } +} + +#[easy_ext::ext(AssistantMessageExt)] +impl [AssistantContentBlock] { + /// Concatenate all visible final-answer text blocks. + pub fn text(&self) -> String { + self.iter() + .filter_map(|block| match block { + AssistantContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect() + } + + /// Concatenate all extracted reasoning blocks, if any. + pub fn reasoning(&self) -> Option { + Some( + self.iter() + .filter_map(|block| match block { + AssistantContentBlock::Reasoning { text } => Some(text.as_str()), + _ => None, + }) + .collect(), + ) + .filter(|s: &String| !s.is_empty()) + } + + /// Return whether this assistant message contains any non-empty reasoning + /// text blocks. + pub fn has_reasoning(&self) -> bool { + self.iter().any(|block| match block { + AssistantContentBlock::Reasoning { text } => !text.is_empty(), + _ => false, + }) + } + + /// Return finalized assistant tool calls in encounter order. + pub fn tool_calls(&self) -> impl Iterator { + self.iter().filter_map(AssistantContentBlock::as_tool_call) + } + + /// Return whether this assistant message contains any tool-call blocks. + pub fn has_tool_calls(&self) -> bool { + self.iter().any(|block| matches!(block, AssistantContentBlock::ToolCall(_))) + } +} + +/// Final structured assistant message assembled from the event stream. +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct AssistantMessage { + pub content: Vec, +} + +impl Deref for AssistantMessage { + type Target = [AssistantContentBlock]; + + fn deref(&self) -> &Self::Target { + &self.content + } +} + +impl AssistantMessage { + /// Push one new block to the end of the message content. + pub(crate) fn push_block(&mut self, block: AssistantContentBlock) { + self.content.push(block); + } +} + +/// Streamed chat event emitted by [`crate::ChatEventStream`]. +#[derive(Debug, Clone, PartialEq)] +pub enum ChatEvent { + /// The request was accepted, streaming has started, and prompt metadata is + /// ready. + Start { + /// The actual prompt token IDs for this request. + prompt_token_ids: Arc<[u32]>, + /// Once-only prompt logprobs metadata, when requested. + prompt_logprobs: Option, + }, + /// A new assistant output block has started. + BlockStart { + index: usize, + kind: AssistantBlockKind, + }, + /// A newly observed delta for one open assistant output block. + BlockDelta { + index: usize, + kind: AssistantBlockKind, + delta: String, + }, + /// Per-decoded-update sample metadata: logprobs and/or output token IDs. + LogprobsDelta { + logprobs: Option, + token_ids: Vec, + }, + /// One assistant output block has ended. + BlockEnd { + index: usize, + block: AssistantContentBlock, + }, + /// One tool call has started. + ToolCallStart { + index: usize, + id: String, + name: String, + }, + /// One incremental tool-call arguments delta for the currently open tool + /// call. + ToolCallArgumentsDelta { index: usize, delta: String }, + /// One tool call has ended. + ToolCallEnd { + index: usize, + call: AssistantToolCall, + }, + /// Terminal event carrying the final assembled assistant message and finish + /// metadata. + Done { + message: AssistantMessage, + /// Number of prompt tokens actually sent to the engine after chat + /// template rendering and tokenization. + prompt_token_count: usize, + /// Number of output tokens generated. + output_token_count: usize, + finish_reason: FinishReason, + /// Connector-specific KV transfer parameters for disaggregated serving. + kv_transfer_params: Option, + }, +} diff --git a/rust/src/chat/src/lib.rs b/rust/src/chat/src/lib.rs new file mode 100644 index 00000000000..6afc4b35c5f --- /dev/null +++ b/rust/src/chat/src/lib.rs @@ -0,0 +1,249 @@ +//! Minimal chat facade above [`vllm_text`]. +//! +//! This crate keeps the northbound boundary intentionally small: +//! `messages -> rendered prompt -> tokenized prompt -> engine request -> +//! streamed structured assistant events`. The request side remains text-first, +//! while the response side can emit structured reasoning and final-answer +//! blocks. It is closer to vLLM's internal chat-rendering flow than to a full +//! OpenAI-compatible surface. + +pub use backend::hf::HfChatBackend; +pub use backend::{ + ChatBackend, ChatTextBackend, DynChatBackend, DynChatTextBackend, LoadModelBackendsOptions, + LoadedModelBackends, NewChatOutputProcessorOptions, load_model_backends, +}; +pub use error::{Error, Result}; +pub use event::{ + AssistantBlockKind, AssistantContentBlock, AssistantMessage, AssistantMessageExt, + AssistantToolCall, ChatEvent, +}; +use futures::{StreamExt, TryStreamExt as _}; +pub use output::{ + ChatOutputProcessor, DefaultChatOutputProcessor, DynChatOutputProcessor, + HarmonyChatOutputProcessor, +}; +pub use parser::ParserSelection; +pub use parser::reasoning::{ + ReasoningDelta, ReasoningError, ReasoningParser, ReasoningParserFactory, +}; +pub use parser::tool::{ToolParser, ToolParserError, ToolParserFactory}; +pub use renderer::hf::ChatTemplateContentFormatOption; +pub use renderer::{ + ChatRenderer, DeepSeekV4ChatRenderer, DeepSeekV32ChatRenderer, DynChatRenderer, RenderedPrompt, + RendererSelection, +}; +pub use request::{ + ChatContent, ChatContentPart, ChatMessage, ChatOptions, ChatRequest, ChatRole, ChatTool, + ChatToolChoice, GenerationPromptMode, ReasoningEffort, SamplingParams, +}; +pub use stream::{ChatEventStream, ChatEventStreamTrait, CollectedAssistantMessage}; +pub use vllm_llm::FinishReason; + +mod backend; +mod error; +mod event; +pub mod multimodal; +mod output; +mod parser; +mod renderer; +mod request; +mod stream; + +use vllm_engine_core_client::EngineCoreClient; +use vllm_engine_core_client::protocol::ModelDtype; +use vllm_llm::Llm; +use vllm_text::{TextLlm, TextRequest}; + +/// Validate explicit parser override names without starting request processing. +pub fn validate_parser_overrides( + tool_call_parser: &ParserSelection, + reasoning_parser: &ParserSelection, +) -> Result<()> { + let tool_parser_factory = ToolParserFactory::global(); + if let ParserSelection::Explicit(name) = tool_call_parser + && !tool_parser_factory.contains(name) + { + return Err(Error::ParserUnavailableByName { + kind: "tool", + name: name.clone(), + available_names: tool_parser_factory.list(), + }); + } + + let reasoning_parser_factory = ReasoningParserFactory::global(); + if let ParserSelection::Explicit(name) = reasoning_parser + && !reasoning_parser_factory.contains(name) + { + return Err(Error::ParserUnavailableByName { + kind: "reasoning", + name: name.clone(), + available_names: reasoning_parser_factory.list(), + }); + } + + Ok(()) +} + +/// Structured chat facade above [`TextLlm`]. +/// +/// This layer stays above raw text semantics: it takes care of chat-template +/// rendering, exposes structured assistant events, and adds chat-specific +/// request semantics such as tool calls. +pub struct ChatLlm { + text: TextLlm, + backend: DynChatBackend, + /// Effective model dtype reported by the engine. + model_dtype: Option, + /// Tool-call parser selection. + tool_call_parser: ParserSelection, + /// Reasoning parser selection. + reasoning_parser: ParserSelection, +} + +impl ChatLlm { + /// Create a new chat facade from a text-generation facade plus a chat + /// backend. + pub fn new(text: TextLlm, backend: DynChatBackend) -> Self { + let model_dtype = text.engine_core_client().model_dtype(); + + Self { + text, + backend, + model_dtype, + tool_call_parser: ParserSelection::Auto, + reasoning_parser: ParserSelection::Auto, + } + } + + /// Convenience constructor for one shared backend object that implements + /// both text and chat responsibilities. + pub fn from_shared_backend(llm: Llm, backend: DynChatTextBackend) -> Self { + let text = TextLlm::new(llm, backend.clone()); + Self::new(text, backend) + } + + /// Set tool-call parser selection. + pub fn with_tool_call_parser(mut self, selection: ParserSelection) -> Self { + self.tool_call_parser = selection; + self + } + + /// Set reasoning parser selection. + pub fn with_reasoning_parser(mut self, selection: ParserSelection) -> Self { + self.reasoning_parser = selection; + self + } + + /// Override the effective model dtype used for multimodal tensor encoding. + pub fn with_model_dtype(mut self, model_dtype: Option) -> Self { + self.model_dtype = model_dtype; + self + } + + /// Expose the underlying text facade for raw text-generation routes such as + /// `/v1/completions`. + pub fn text(&self) -> &TextLlm { + &self.text + } + + /// Return the model ID reported by the underlying text backend. + pub fn model_id(&self) -> &str { + self.text.model_id() + } + + /// Expose the underlying engine-core client for low-level utility/admin + /// calls. + pub fn engine_core_client(&self) -> &EngineCoreClient { + self.text.engine_core_client() + } + + /// Render, tokenize, and submit one chat request. + pub async fn chat(&self, mut request: ChatRequest) -> Result { + request.validate()?; + + let output_processor = self.backend.new_chat_output_processor( + &mut request, + NewChatOutputProcessorOptions { + tool_call_parser: &self.tool_call_parser, + reasoning_parser: &self.reasoning_parser, + }, + )?; + let rendered = self.backend.chat_renderer().render(&request)?; + + let (prompt, mm_features) = multimodal::finalize_rendered_prompt( + &request, + rendered, + self.backend.multimodal_model_info(), + self.model_dtype, + ) + .await?; + + let text_request = TextRequest { + request_id: request.request_id.clone(), + prompt, + mm_features, + sampling_params: request.sampling_params, + decode_options: request.decode_options, + intermediate: request.intermediate, + priority: request.priority, + cache_salt: request.cache_salt, + add_special_tokens: request.add_special_tokens, + data_parallel_rank: request.data_parallel_rank, + }; + let decoded_stream = self.text.generate(text_request).await?.map_err(Error::from).boxed(); + + let structured_stream = output_processor.process(decoded_stream)?; + + Ok(ChatEventStream::new(request.request_id, structured_stream)) + } + + /// Shut down the underlying LLM client and its background tasks. + pub async fn shutdown(self) -> Result<()> { + self.text.shutdown().await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use thiserror_ext::AsReport; + + use super::{ParserSelection, validate_parser_overrides}; + use crate::parser::reasoning::names; + + #[test] + fn validate_parser_overrides_accepts_registered_names() { + validate_parser_overrides( + &ParserSelection::Explicit("llama3_json".to_string()), + &ParserSelection::Explicit(names::QWEN3.to_string()), + ) + .unwrap(); + } + + #[test] + fn validate_parser_overrides_accepts_auto_and_none() { + validate_parser_overrides(&ParserSelection::Auto, &ParserSelection::None).unwrap(); + } + + #[test] + fn validate_parser_overrides_rejects_unknown_tool_parser() { + let error = validate_parser_overrides( + &ParserSelection::Explicit("definitely_missing_tool_parser".to_string()), + &ParserSelection::Auto, + ) + .unwrap_err(); + + expect_test::expect!["tool parser `definitely_missing_tool_parser` is not registered (choose from: deepseek_v3, deepseek_v31, deepseek_v32, deepseek_v4, gemma4, glm45, glm47, hermes, kimi_k2, llama3_json, llama4_json, minimax_m2, mistral, qwen3_coder, qwen3_xml)"].assert_eq(&error.to_report_string()); + } + + #[test] + fn validate_parser_overrides_rejects_unknown_reasoning_parser() { + let error = validate_parser_overrides( + &ParserSelection::Auto, + &ParserSelection::Explicit("definitely_missing_reasoning_parser".to_string()), + ) + .unwrap_err(); + + expect_test::expect!["reasoning parser `definitely_missing_reasoning_parser` is not registered (choose from: cohere_cmd, deepseek_r1, deepseek_v3, deepseek_v4, gemma4, glm45, kimi, kimi_k2, minimax_m2, nemotron_v3, qwen3, step3)"].assert_eq(&error.to_report_string()); + } +} diff --git a/rust/src/chat/src/multimodal.rs b/rust/src/chat/src/multimodal.rs new file mode 100644 index 00000000000..ad730405601 --- /dev/null +++ b/rust/src/chat/src/multimodal.rs @@ -0,0 +1,775 @@ +//! Chat-layer multimodal image preparation. +//! +//! This module owns the narrow image-only multimodal path for chat requests: +//! it extracts image parts from structured chat messages, fetches and +//! preprocesses them through `llm-multimodal`, expands rendered prompt +//! placeholders after tokenization, and builds the engine-facing +//! `MmFeatures` payload. +//! +//! Raw media stays above `vllm-text`; this module lowers it into token IDs and +//! opaque tensor payloads before the request is handed to text generation. + +use std::collections::{HashMap, HashSet}; +use std::fs; +use std::path::Path; +use std::sync::{Arc, LazyLock, Once}; + +use itertools::izip; +use llm_multimodal::{ + AsyncMultiModalTracker, FieldLayout, ImagePreProcessor, ImageProcessorRegistry, MediaConnector, + MediaConnectorConfig, MediaContentPart, Modality, ModelMetadata, ModelProcessorSpec, + ModelRegistry, PreProcessorConfig, PreprocessedImages, PromptReplacement, TokenResolver, + TrackedMedia, +}; +use tracing::warn; +use vllm_engine_core_client::protocol::ModelDtype; +use vllm_engine_core_client::protocol::multimodal::{ + MmBatchedField, MmFeatureSpec, MmFeatures, MmField, MmFieldElem, MmFlatField, MmKwargsItem, + MmSharedField, MmSlice, PlaceholderRange, SliceSpec, +}; +use vllm_engine_core_client::protocol::tensor::WireTensor; +use vllm_text::Prompt; +use vllm_text::tokenizer::{DynTokenizer, Tokenizer}; + +use crate::error::{Error, Result, bail_multimodal, multimodal}; +use crate::renderer::RenderedPrompt; +use crate::request::{ChatContent, ChatContentPart, ChatMessage, ChatRequest}; + +mod tensor; + +/// Resolved multimodal support for one loaded model. +#[derive(Clone)] +pub struct MultimodalModelInfo { + context: MultimodalModelContext, + spec: ResolvedMultimodalSpec, + image_processor: ResolvedImageProcessor, + media_connector: Arc, +} + +/// Model metadata and tokenizer access shared by all multimodal specs. +#[derive(Clone)] +struct MultimodalModelContext { + model_id: String, + model_type: Option, + config: serde_json::Value, + tokenizer: TokenizerResolver, +} + +impl MultimodalModelContext { + fn metadata(&self) -> ModelMetadata<'_> { + ModelMetadata { + model_id: &self.model_id, + tokenizer: &self.tokenizer, + config: &self.config, + } + } + + fn tokenizer(&self) -> &dyn Tokenizer { + self.tokenizer.0.as_ref() + } + + /// Resolve a static model processor spec for one loaded model. + fn resolve_model_spec(&self) -> Option<&'static dyn ModelProcessorSpec> { + static REGISTRY: LazyLock = LazyLock::new(ModelRegistry::new); + REGISTRY.lookup(&self.metadata()) + } + + /// Resolve a static image preprocessor for one loaded model. + fn resolve_image_processor(&self) -> Option<&'static dyn ImagePreProcessor> { + static REGISTRY: LazyLock = + LazyLock::new(ImageProcessorRegistry::with_defaults); + REGISTRY.find(&self.model_id, self.model_type.as_deref()) + } +} + +/// Static model-specific prompt and tensor-layout behavior. +#[derive(Clone)] +struct ResolvedMultimodalSpec { + raw: &'static dyn ModelProcessorSpec, + placeholder_token: String, + placeholder_marker_token_id: u32, + placeholder_embed_token_id: u32, + field_layouts: HashMap, + keep_on_cpu_keys: HashSet, +} + +impl ResolvedMultimodalSpec { + fn new(raw: &'static dyn ModelProcessorSpec, context: &MultimodalModelContext) -> Result { + let metadata = context.metadata(); + let placeholder_token = + raw.placeholder_token(&metadata).map_err(|error| multimodal!("{error}"))?; + // This is the rendered prompt marker, so resolve it from the token + // string itself. Do not use `ModelProcessorSpec::placeholder_token_id()`: + // for some specs that ID is the replacement vision/patch token, + // not necessarily the token ID of `placeholder_token`. + let placeholder_marker_token_id = + context.tokenizer().token_to_id(&placeholder_token).ok_or_else(|| { + multimodal!( + "placeholder token `{placeholder_token}` is not in the tokenizer vocabulary" + ) + })?; + let placeholder_embed_token_id = + raw.placeholder_token_id(&metadata).map_err(|error| multimodal!("{error}"))? as u32; + + Ok(Self { + raw, + placeholder_token, + placeholder_marker_token_id, + placeholder_embed_token_id, + field_layouts: raw.field_layouts(), + keep_on_cpu_keys: raw.keep_on_cpu_keys().into_iter().collect(), + }) + } + + fn prompt_replacements( + &self, + context: &MultimodalModelContext, + preprocessed: &PreprocessedImages, + ) -> Result> { + self.raw + .prompt_replacements(&context.metadata(), preprocessed) + .map_err(|error| multimodal!("{error}")) + } +} + +/// Static image preprocessor plus its loaded config. +#[derive(Clone)] +struct ResolvedImageProcessor { + raw: &'static dyn ImagePreProcessor, + config: PreProcessorConfig, +} + +/// Request-scoped fetched media, kept together with tracker UUID metadata. +struct FetchedImageMedia { + frames: Vec>, + uuids: Vec>, +} + +impl MultimodalModelInfo { + /// Load and resolve multimodal support from model files. + /// + /// Returns `Ok(Some(_))` only when both the model spec and image processor + /// are registered. File read/parse failures are real errors; unsupported + /// model families are logged and returned as `Ok(None)`. + pub fn from_paths( + model_id: String, + model_type: Option, + config_path: Option<&Path>, + preprocessor_config_path: Option<&Path>, + tokenizer: DynTokenizer, + ) -> Result> { + let config = match config_path { + Some(path) => { + let text = fs::read_to_string(path) + .map_err(|error| multimodal!("failed to read config.json: {error}"))?; + serde_json::from_str(&text) + .map_err(|error| multimodal!("failed to parse config.json: {error}"))? + } + None => serde_json::Value::Object(Default::default()), + }; + let preprocessor_config = match preprocessor_config_path { + Some(path) => { + let text = fs::read_to_string(path).map_err(|error| { + multimodal!("failed to read preprocessor_config.json: {error}") + })?; + PreProcessorConfig::from_json(&text).map_err(|error| { + multimodal!("failed to parse preprocessor_config.json: {error}") + })? + } + None => PreProcessorConfig::default(), + }; + + let context = MultimodalModelContext { + model_id, + model_type, + config, + tokenizer: TokenizerResolver(tokenizer), + }; + + let Some(spec) = context.resolve_model_spec() else { + warn!( + model_id = context.model_id, + model_type = context.model_type, + "multimodal model spec is not registered; disabling multimodal support for this model" + ); + return Ok(None); + }; + let spec = ResolvedMultimodalSpec::new(spec, &context)?; + + let Some(image_processor) = context.resolve_image_processor() else { + warn!( + model_id = context.model_id, + model_type = context.model_type, + "image processor is not registered; disabling multimodal support for this model" + ); + return Ok(None); + }; + + let media_connector = Arc::new( + MediaConnector::new(reqwest::Client::new(), MediaConnectorConfig::default()) + .map_err(|error| multimodal!("{error}"))?, + ); + + Ok(Some(Self { + context, + spec, + image_processor: ResolvedImageProcessor { + raw: image_processor, + config: preprocessor_config, + }, + media_connector, + })) + } + + /// Return the template-visible placeholder token for this model. + /// + /// The HF renderer uses this token while flattening image content in string + /// content format. + pub(crate) fn placeholder_token(&self) -> &str { + &self.spec.placeholder_token + } +} + +/// Finalize a rendered chat prompt into text-generation input. +/// +/// Text-only requests pass through unchanged as `Prompt::Text`. Multimodal +/// requests are tokenized in chat, their image placeholders are expanded, and +/// preprocessed image features are attached for engine-core transport. +pub(crate) async fn finalize_rendered_prompt( + request: &ChatRequest, + rendered: RenderedPrompt, + info: Option<&MultimodalModelInfo>, + model_dtype: Option, +) -> Result<(Prompt, Option)> { + if !request.has_multimodal() { + return Ok((rendered.prompt, None)); + } + let info = info.ok_or(Error::UnsupportedMultimodalRenderer)?; + let Prompt::Text(prompt) = rendered.prompt else { + bail_multimodal!("multimodal chat renderer must return a text prompt before expansion"); + }; + let media_parts = extract_media_parts(request)?; + let model_dtype = model_dtype.unwrap_or_else(|| { + static WARN_ONCE: Once = Once::new(); + WARN_ONCE.call_once(|| { + warn!( + "engine handshake did not report model dtype; \ + falling back to float32 for multimodal tensor encoding" + ); + }); + ModelDtype::Float32 + }); + + let mut prompt_token_ids = info + .context + .tokenizer() + .encode(&prompt, request.add_special_tokens) + .map_err(|error| multimodal!("{error}"))?; + let prepared = info.prepare_multimodal(media_parts, &mut prompt_token_ids, model_dtype).await?; + + Ok((Prompt::TokenIds(prompt_token_ids), Some(prepared))) +} + +/// Extract image media parts from chat messages in message/content order. +/// +/// Assistant history is skipped because generated assistant blocks are already +/// represented as text for prompt rendering in this crate. +fn extract_media_parts(request: &ChatRequest) -> Result> { + let mut all_parts = Vec::new(); + for message in &request.messages { + let content = match message { + ChatMessage::System { content } + | ChatMessage::Developer { content, .. } + | ChatMessage::User { content } + | ChatMessage::ToolResponse { content, .. } => content, + ChatMessage::Assistant { .. } => continue, + }; + let ChatContent::Parts(parts) = content else { + continue; + }; + for part in parts { + match part { + ChatContentPart::Text { .. } => {} + ChatContentPart::ImageUrl { + image_url, + detail, + uuid, + } => all_parts.push(MediaContentPart::ImageUrl { + url: image_url.clone(), + detail: *detail, + uuid: uuid.clone(), + }), + } + } + } + Ok(all_parts) +} + +impl MultimodalModelInfo { + /// Run media fetch, image preprocessing, prompt expansion, and feature + /// build. + /// + /// `prompt_token_ids` is mutated in place because placeholder expansion + /// changes both the final prompt and the offsets recorded in + /// `PlaceholderRange`. + async fn prepare_multimodal( + &self, + media_parts: Vec, + prompt_token_ids: &mut Vec, + model_dtype: ModelDtype, + ) -> Result { + if media_parts.is_empty() { + return Ok(Vec::new()); + } + let media_parts_len = media_parts.len(); + + let fetched = self.fetch_images(media_parts).await?; + let preprocessed = self.preprocess_images(&fetched.frames).await?; + let replacements = self.spec.prompt_replacements(&self.context, &preprocessed)?; + let ranges = self.expand_prompt_tokens(prompt_token_ids, replacements)?; + + let features = self.build_features(preprocessed, fetched, ranges, model_dtype)?; + if features.len() != media_parts_len { + bail_multimodal!( + "number of built multimodal features {} does not match number of media parts {}", + features.len(), + media_parts_len + ); + } + Ok(features) + } + + /// Fetch all image parts and preserve their request-order UUID metadata. + async fn fetch_images(&self, media_parts: Vec) -> Result { + let mut tracker = AsyncMultiModalTracker::new(Arc::clone(&self.media_connector)); + for part in media_parts { + tracker.push_part(part).map_err(|error| multimodal!("{error}"))?; + } + + let tracker_output = tracker.finalize().await.map_err(|error| multimodal!("{error}"))?; + let images = tracker_output.data.get(&Modality::Image).cloned().unwrap_or_default(); + let uuids = tracker_output.uuids.get(&Modality::Image).cloned().unwrap_or_default(); + + let frames = images + .into_iter() + .map(|media| match media { + TrackedMedia::Image(frame) => Ok(frame), + _ => Err(Error::UnsupportedMultimodalContent("non-image")), + }) + .collect::>>()?; + + Ok(FetchedImageMedia { frames, uuids }) + } + + /// Preprocess fetched image frames with the model's resolved image + /// processor. + /// + /// The processor work is CPU-heavy relative to request wiring, so it runs + /// in a blocking task and returns owned tensors ready for wire + /// conversion. + async fn preprocess_images( + &self, + image_frames: &[Arc], + ) -> Result { + let config = self.image_processor.config.clone(); + let processor = self.image_processor.raw; + let images = image_frames.iter().map(|frame| frame.data().clone()).collect::>(); + + tokio::task::spawn_blocking(move || { + processor.preprocess(&images, &config).map_err(|error| multimodal!("{error}")) + }) + .await + .map_err(|error| multimodal!("image preprocessing task failed: {error}"))? + } + + /// Replace rendered placeholder markers with model-specific replacement + /// tokens. + /// + /// Replacements are consumed in order, matching the original media-part + /// order. The returned ranges point into the already-expanded prompt. + fn expand_prompt_tokens( + &self, + prompt_token_ids: &mut Vec, + replacements: Vec, + ) -> Result> { + let mut cursor = 0; + let mut ranges = Vec::with_capacity(replacements.len()); + for replacement in replacements { + if replacement.modality != Modality::Image { + bail_multimodal!( + "unsupported prompt replacement modality `{}`", + replacement.modality + ); + } + let offset = find_next_token( + prompt_token_ids, + self.spec.placeholder_marker_token_id, + cursor, + ) + .ok_or_else(|| { + multimodal!( + "placeholder token `{}` was not found in tokenized prompt", + self.spec.placeholder_token + ) + })?; + + if replacement.tokens.is_empty() { + bail_multimodal!( + "placeholder token `{}` expanded to no tokens", + self.spec.placeholder_token + ); + } + let replacement_len = replacement.tokens.len(); + let replacement_tokens = + replacement.tokens.iter().map(|&token| token as u32).collect::>(); + let is_embed = { + let mask = replacement_tokens + .iter() + .map(|&token| token == self.spec.placeholder_embed_token_id) + .collect::>(); + WireTensor::from_bool(vec![replacement_len], mask).map_err(Error::Multimodal)? + }; + + prompt_token_ids.splice(offset..offset + 1, replacement_tokens); + ranges.push(PlaceholderRange { + offset, + length: replacement_len, + is_embed: Some(is_embed), + }); + cursor = offset + replacement_len; + } + Ok(ranges) + } + + /// Convert preprocessed image tensors into engine-core multimodal features. + /// + /// One `MmFeatureSpec` is produced per image. Tensor fields are + /// sliced according to the model spec's field layout declarations. + fn build_features( + &self, + preprocessed: PreprocessedImages, + images: FetchedImageMedia, + ranges: Vec, + model_dtype: ModelDtype, + ) -> Result { + let len = images.frames.len(); + let tensors = tensor::collect_tensors(preprocessed, model_dtype)?; + + let mut features = Vec::with_capacity(images.frames.len()); + for (index, (frame, uuid, range)) in izip!(images.frames, images.uuids, ranges).enumerate() + { + let mut data = MmKwargsItem::new(); + for (key, tensor) in &tensors { + let keep_on_cpu = self.spec.keep_on_cpu_keys.contains(key); + let (value, field) = match self.spec.field_layouts.get(key) { + Some(FieldLayout::Batched) => ( + tensor.batched_value_at(index)?, + MmField::Batched(MmBatchedField { keep_on_cpu }), + ), + Some(FieldLayout::Flat { sizes_key }) => { + let sizes = tensors.get(sizes_key).ok_or_else(|| { + multimodal!("flat tensor sizes key `{sizes_key}` is missing") + })?; + let (start, end) = tensor::flat_range_for_index(sizes, sizes_key, index)?; + ( + tensor.flat_value_range(start, end)?, + MmField::Flat(MmFlatField { + slices: vec![MmSlice::Slice(SliceSpec { + start: Some(0), + stop: Some((end - start) as isize), + step: None, + })], + dim: 0, + keep_on_cpu, + }), + ) + } + None => ( + tensor.clone(), + MmField::Shared(MmSharedField { + batch_size: len, + keep_on_cpu, + }), + ), + }; + + data.insert( + key.clone(), + MmFieldElem { + data: Some(value.try_into()?), + field, + }, + ); + } + + let hash = frame.hash.clone(); + features.push(MmFeatureSpec { + data: Some(data), + modality: "image".to_string(), + identifier: uuid.unwrap_or_else(|| hash.clone()), + mm_position: range, + mm_hash: Some(hash), + }); + } + + Ok(features) + } +} + +/// Find `needle` in `haystack`, starting at `start`. +/// +/// This is intentionally order-preserving rather than a global replace: each +/// image consumes the next placeholder occurrence. +fn find_next_token(haystack: &[u32], needle: u32, start: usize) -> Option { + haystack + .get(start..)? + .iter() + .position(|token| *token == needle) + .map(|offset| start + offset) +} + +/// Adapter from the frontend tokenizer trait to `llm-multimodal`. +#[derive(Clone)] +struct TokenizerResolver(DynTokenizer); + +impl TokenResolver for TokenizerResolver { + fn token_to_id(&self, token: &str) -> Option { + self.0.token_to_id(token) + } + + fn id_to_token(&self, id: u32) -> Option { + self.0.id_to_token(id) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use llm_multimodal::TokenId; + use vllm_engine_core_client::protocol::tensor::WireArrayData; + use vllm_text::tokenizer::{IncrementalDecoder, Tokenizer, TokenizerError}; + + use super::*; + + const LLAMA4_IMAGE_START_ID: u32 = 200088; + const LLAMA4_IMAGE_END_ID: u32 = 200089; + const LLAMA4_IMAGE_ID: u32 = 200090; + const LLAMA4_PATCH_ID: u32 = 200092; + const LLAMA4_TILE_X_SEPARATOR_ID: u32 = 200093; + const LLAMA4_TILE_Y_SEPARATOR_ID: u32 = 200094; + + struct TestTokenizer; + + impl Tokenizer for TestTokenizer { + fn encode( + &self, + text: &str, + _add_special_tokens: bool, + ) -> std::result::Result, TokenizerError> { + Ok(match text { + "<|image|>" => vec![LLAMA4_IMAGE_ID], + text => text.bytes().map(u32::from).collect(), + }) + } + + fn decode( + &self, + _token_ids: &[u32], + _skip_special_tokens: bool, + ) -> std::result::Result { + Ok(String::new()) + } + + fn token_to_id(&self, token: &str) -> Option { + match token { + "<|image_start|>" => Some(LLAMA4_IMAGE_START_ID), + "<|image_end|>" => Some(LLAMA4_IMAGE_END_ID), + "<|image|>" => Some(LLAMA4_IMAGE_ID), + "<|patch|>" => Some(LLAMA4_PATCH_ID), + "<|tile_x_separator|>" => Some(LLAMA4_TILE_X_SEPARATOR_ID), + "<|tile_y_separator|>" => Some(LLAMA4_TILE_Y_SEPARATOR_ID), + _ => None, + } + } + + fn id_to_token(&self, id: u32) -> Option { + match id { + LLAMA4_IMAGE_START_ID => Some("<|image_start|>".to_string()), + LLAMA4_IMAGE_END_ID => Some("<|image_end|>".to_string()), + LLAMA4_IMAGE_ID => Some("<|image|>".to_string()), + LLAMA4_PATCH_ID => Some("<|patch|>".to_string()), + LLAMA4_TILE_X_SEPARATOR_ID => Some("<|tile_x_separator|>".to_string()), + LLAMA4_TILE_Y_SEPARATOR_ID => Some("<|tile_y_separator|>".to_string()), + _ => None, + } + } + + fn create_decode_stream( + &self, + _prompt_token_ids: &[u32], + _skip_special_tokens: bool, + _min_bytes_to_buffer: usize, + ) -> Box { + unreachable!("not used") + } + } + + fn test_info(model_type: &str, config: serde_json::Value) -> MultimodalModelInfo { + let context = MultimodalModelContext { + model_id: format!("{model_type}-test"), + model_type: Some(model_type.to_string()), + config, + tokenizer: TokenizerResolver(Arc::new(TestTokenizer)), + }; + let spec = context + .resolve_model_spec() + .unwrap_or_else(|| panic!("{model_type} spec should match")); + let spec = ResolvedMultimodalSpec::new(spec, &context).unwrap(); + let raw_image_processor = context + .resolve_image_processor() + .unwrap_or_else(|| panic!("{model_type} image processor should match")); + let media_connector = Arc::new( + MediaConnector::new(reqwest::Client::new(), MediaConnectorConfig::default()).unwrap(), + ); + + MultimodalModelInfo { + context, + spec, + image_processor: ResolvedImageProcessor { + raw: raw_image_processor, + config: PreProcessorConfig::default(), + }, + media_connector, + } + } + + fn llama4_info() -> MultimodalModelInfo { + let config = serde_json::json!({ + "model_type": "llama4", + "image_token_index": LLAMA4_PATCH_ID, + "vision_config": {"image_size": 336, "patch_size": 14} + }); + test_info("llama4", config) + } + + fn llama4_single_tile_replacement() -> PromptReplacement { + PromptReplacement::sequence( + Modality::Image, + "<|image|>", + vec![ + LLAMA4_IMAGE_START_ID as TokenId, + LLAMA4_IMAGE_ID as TokenId, + LLAMA4_PATCH_ID as TokenId, + LLAMA4_PATCH_ID as TokenId, + LLAMA4_IMAGE_END_ID as TokenId, + ], + ) + } + + fn llama4_multi_tile_replacement() -> PromptReplacement { + PromptReplacement::sequence( + Modality::Image, + "<|image|>", + vec![ + LLAMA4_IMAGE_START_ID as TokenId, + LLAMA4_PATCH_ID as TokenId, + LLAMA4_TILE_X_SEPARATOR_ID as TokenId, + LLAMA4_PATCH_ID as TokenId, + LLAMA4_TILE_Y_SEPARATOR_ID as TokenId, + LLAMA4_IMAGE_ID as TokenId, + LLAMA4_PATCH_ID as TokenId, + LLAMA4_IMAGE_END_ID as TokenId, + ], + ) + } + + fn assert_bool_mask(range: &PlaceholderRange, expected: &[bool]) { + let tensor = range.is_embed.as_ref().expect("is_embed mask"); + assert_eq!(tensor.dtype, "bool"); + assert_eq!(tensor.shape, vec![expected.len()]); + assert_eq!( + tensor.data, + WireArrayData::RawView(expected.iter().map(|value| u8::from(*value)).collect()) + ); + } + + #[test] + fn expand_prompt_tokens_marks_only_llama4_patch_tokens_as_embed() { + let info = llama4_info(); + let mut prompt_token_ids = vec![1, LLAMA4_IMAGE_ID, 2]; + let replacements = vec![llama4_multi_tile_replacement()]; + + let ranges = info.expand_prompt_tokens(&mut prompt_token_ids, replacements).unwrap(); + + assert_eq!( + prompt_token_ids, + vec![ + 1, + LLAMA4_IMAGE_START_ID, + LLAMA4_PATCH_ID, + LLAMA4_TILE_X_SEPARATOR_ID, + LLAMA4_PATCH_ID, + LLAMA4_TILE_Y_SEPARATOR_ID, + LLAMA4_IMAGE_ID, + LLAMA4_PATCH_ID, + LLAMA4_IMAGE_END_ID, + 2, + ] + ); + assert_eq!(ranges[0].offset, 1); + assert_eq!(ranges[0].length, 8); + assert_bool_mask( + &ranges[0], + &[false, true, false, true, false, false, true, false], + ); + } + + #[test] + fn expand_prompt_tokens_errors_when_placeholder_missing() { + let info = llama4_info(); + let mut prompt_token_ids = vec![1, 2, 3]; + let replacements = vec![llama4_single_tile_replacement()]; + + let error = info.expand_prompt_tokens(&mut prompt_token_ids, replacements).unwrap_err(); + + assert!(matches!(error, Error::Multimodal(message) if message.contains("not found"))); + } + + #[test] + fn expand_prompt_tokens_skips_llama4_image_marker_inside_replacement() { + let info = llama4_info(); + let mut prompt_token_ids = vec![1, LLAMA4_IMAGE_ID, 2, LLAMA4_IMAGE_ID, 3]; + let replacements = vec![ + llama4_single_tile_replacement(), + llama4_single_tile_replacement(), + ]; + + let ranges = info.expand_prompt_tokens(&mut prompt_token_ids, replacements).unwrap(); + + assert_eq!( + prompt_token_ids, + vec![ + 1, + LLAMA4_IMAGE_START_ID, + LLAMA4_IMAGE_ID, + LLAMA4_PATCH_ID, + LLAMA4_PATCH_ID, + LLAMA4_IMAGE_END_ID, + 2, + LLAMA4_IMAGE_START_ID, + LLAMA4_IMAGE_ID, + LLAMA4_PATCH_ID, + LLAMA4_PATCH_ID, + LLAMA4_IMAGE_END_ID, + 3, + ] + ); + assert_eq!(ranges[0].offset, 1); + assert_eq!(ranges[0].length, 5); + assert_bool_mask(&ranges[0], &[false, false, true, true, false]); + assert_eq!(ranges[1].offset, 7); + assert_eq!(ranges[1].length, 5); + assert_bool_mask(&ranges[1], &[false, false, true, true, false]); + } +} diff --git a/rust/src/chat/src/multimodal/tensor.rs b/rust/src/chat/src/multimodal/tensor.rs new file mode 100644 index 00000000000..eddf8f707e9 --- /dev/null +++ b/rust/src/chat/src/multimodal/tensor.rs @@ -0,0 +1,342 @@ +use std::collections::HashMap; + +use half::{bf16, f16}; +use llm_multimodal::{ModelSpecificValue, PreprocessedImages}; +use vllm_engine_core_client::protocol::ModelDtype; +use vllm_engine_core_client::protocol::multimodal::MmKwargValue as ProtocolKwargValue; +use vllm_engine_core_client::protocol::tensor::{ShapeExt as _, WireTensor}; + +use crate::error::{Error, Result, bail_multimodal, multimodal}; + +/// Representation for multimodal kwarg values for transformation. +#[derive(Debug, Clone)] +pub(super) enum KwargValue { + /// Float tensor with row-major flat data and shape. + F32Tensor { data: Vec, shape: Vec }, + /// Float16 tensor with row-major flat data and shape. + F16Tensor { data: Vec, shape: Vec }, + /// BFloat16 tensor with row-major flat data and shape. + Bf16Tensor { data: Vec, shape: Vec }, + /// Signed integer tensor with row-major flat data and shape. + I64Tensor { data: Vec, shape: Vec }, + /// Unsigned integer tensor with row-major flat data and shape. + U32Tensor { data: Vec, shape: Vec }, + /// Non-tensor kwarg value that is shared or copied as-is. + Passthrough(ProtocolKwargValue), +} + +/// Collect `pixel_values` and model-specific outputs into one tensor map. +pub(super) fn collect_tensors( + preprocessed: PreprocessedImages, + float_dtype: ModelDtype, +) -> Result> { + let PreprocessedImages { + pixel_values, + model_specific, + .. + } = preprocessed; + + let pixel_values = { + let shape = pixel_values.shape().to_vec(); + let data = pixel_values.into_iter().collect(); + KwargValue::from_f32_tensor(data, shape, float_dtype)? + }; + + let mut tensors = HashMap::new(); + tensors.insert("pixel_values".to_string(), pixel_values); + for (key, value) in model_specific { + tensors.insert(key, KwargValue::from_model_specific(value, float_dtype)?); + } + Ok(tensors) +} + +impl KwargValue { + fn from_model_specific(value: ModelSpecificValue, float_dtype: ModelDtype) -> Result { + use ProtocolKwargValue::*; + + Ok(match value { + ModelSpecificValue::Tensor { data, shape } => { + Self::from_f32_tensor(data, shape, float_dtype)? + } + ModelSpecificValue::IntTensor { data, shape } => Self::I64Tensor { data, shape }, + ModelSpecificValue::UintTensor { data, shape } => Self::U32Tensor { data, shape }, + ModelSpecificValue::Int(value) => Self::Passthrough(Int(value)), + ModelSpecificValue::Float(value) => Self::Passthrough(Float(value)), + ModelSpecificValue::IntVec(values) => { + Self::Passthrough(List(values.into_iter().map(Int).collect())) + } + ModelSpecificValue::UintVec(values) => Self::Passthrough(List( + values.into_iter().map(|value| Int(value as i64)).collect(), + )), + ModelSpecificValue::FloatVec(values) => Self::Passthrough(List( + values.into_iter().map(|value| Float(value as f64)).collect(), + )), + ModelSpecificValue::TupleVec(values) => Self::Passthrough(List( + values + .into_iter() + .map(|(height, width)| List(vec![Int(height as i64), Int(width as i64)])) + .collect(), + )), + ModelSpecificValue::Bool(value) => Self::Passthrough(Int(i64::from(value))), + }) + } + + /// Convert a float tensor to the target float dtype if needed, keeping the + /// same shape. + fn from_f32_tensor(data: Vec, shape: Vec, float_dtype: ModelDtype) -> Result { + match float_dtype { + ModelDtype::Float16 => Ok(Self::F16Tensor { + data: data.into_iter().map(f16::from_f32).collect(), + shape, + }), + ModelDtype::BFloat16 => Ok(Self::Bf16Tensor { + data: data.into_iter().map(bf16::from_f32).collect(), + shape, + }), + ModelDtype::Float32 => Ok(Self::F32Tensor { data, shape }), + } + } +} + +impl TryFrom for ProtocolKwargValue { + type Error = Error; + + fn try_from(value: KwargValue) -> Result { + match value { + KwargValue::F32Tensor { data, shape } => Ok(Self::Tensor( + WireTensor::from_f32(shape, data).map_err(Error::Multimodal)?, + )), + KwargValue::F16Tensor { data, shape } => Ok(Self::Tensor( + WireTensor::from_f16(shape, data).map_err(Error::Multimodal)?, + )), + KwargValue::Bf16Tensor { data, shape } => Ok(Self::Tensor( + WireTensor::from_bf16(shape, data).map_err(Error::Multimodal)?, + )), + KwargValue::I64Tensor { data, shape } => Ok(Self::Tensor( + WireTensor::from_i64(shape, data).map_err(Error::Multimodal)?, + )), + KwargValue::U32Tensor { data, shape } => Ok(Self::Tensor( + WireTensor::from_u32(shape, data).map_err(Error::Multimodal)?, + )), + KwargValue::Passthrough(value) => Ok(value), + } + } +} + +impl KwargValue { + /// Extract one image from a batched tensor field. + /// + /// Batched fields use their first axis as image index and drop that axis in + /// the per-feature value, matching vLLM's batched-field semantics. + pub(super) fn batched_value_at(&self, index: usize) -> Result { + match self { + Self::F32Tensor { data, shape } => { + let (shape, data) = slice_first_axis_range(shape, data, index, index + 1, true)?; + Ok(Self::F32Tensor { data, shape }) + } + Self::F16Tensor { data, shape } => { + let (shape, data) = slice_first_axis_range(shape, data, index, index + 1, true)?; + Ok(Self::F16Tensor { data, shape }) + } + Self::Bf16Tensor { data, shape } => { + let (shape, data) = slice_first_axis_range(shape, data, index, index + 1, true)?; + Ok(Self::Bf16Tensor { data, shape }) + } + Self::I64Tensor { data, shape } => { + let (shape, data) = slice_first_axis_range(shape, data, index, index + 1, true)?; + Ok(Self::I64Tensor { data, shape }) + } + Self::U32Tensor { data, shape } => { + let (shape, data) = slice_first_axis_range(shape, data, index, index + 1, true)?; + Ok(Self::U32Tensor { data, shape }) + } + Self::Passthrough(value) => Ok(Self::Passthrough(value.clone())), + } + } + + /// Extract one image's variable-length range from a flat tensor field. + /// + /// Flat fields keep the first axis as the sliced length for this image. + pub(super) fn flat_value_range(&self, start: usize, end: usize) -> Result { + match self { + Self::F32Tensor { data, shape } => { + let (shape, data) = slice_first_axis_range(shape, data, start, end, false)?; + Ok(Self::F32Tensor { data, shape }) + } + Self::F16Tensor { data, shape } => { + let (shape, data) = slice_first_axis_range(shape, data, start, end, false)?; + Ok(Self::F16Tensor { data, shape }) + } + Self::Bf16Tensor { data, shape } => { + let (shape, data) = slice_first_axis_range(shape, data, start, end, false)?; + Ok(Self::Bf16Tensor { data, shape }) + } + Self::I64Tensor { data, shape } => { + let (shape, data) = slice_first_axis_range(shape, data, start, end, false)?; + Ok(Self::I64Tensor { data, shape }) + } + Self::U32Tensor { data, shape } => { + let (shape, data) = slice_first_axis_range(shape, data, start, end, false)?; + Ok(Self::U32Tensor { data, shape }) + } + Self::Passthrough(value) => Ok(Self::Passthrough(value.clone())), + } + } +} + +/// Compute the first-axis range for one image in a flat tensor. +/// +/// `sizes_key` names a companion tensor whose entries are cumulative slice +/// sizes per image. +pub(super) fn flat_range_for_index( + sizes: &KwargValue, + sizes_key: &str, + index: usize, +) -> Result<(usize, usize)> { + let sizes = tensor_as_usize_vec(sizes)?; + let size = *sizes.get(index).ok_or_else(|| { + multimodal!("flat tensor sizes key `{sizes_key}` has no entry for image {index}") + })?; + let start = sizes[..index].iter().sum::(); + Ok((start, start + size)) +} + +/// Read a tensor value as per-image sizes for flat slicing. +fn tensor_as_usize_vec(tensor: &KwargValue) -> Result> { + match tensor { + KwargValue::I64Tensor { data, .. } => data + .iter() + .map(|value| { + usize::try_from(*value) + .map_err(|_| multimodal!("negative flat tensor size `{value}`")) + }) + .collect(), + KwargValue::U32Tensor { data, .. } => { + Ok(data.iter().map(|value| *value as usize).collect()) + } + _ => Err(multimodal!("flat tensor sizes must be int64 or uint32")), + } +} + +/// Slice a flat row-major tensor along its first axis. +fn slice_first_axis_range( + shape: &[usize], + data: &[T], + start: usize, + end: usize, + drop_axis: bool, +) -> Result<(Vec, Vec)> { + let first_dim = *shape.first().ok_or_else(|| multimodal!("tensor has no first dimension"))?; + if start > end || end > first_dim { + bail_multimodal!("invalid tensor slice {start}..{end} for first dimension {first_dim}"); + } + let expected_len = shape + .checked_numel() + .ok_or_else(|| multimodal!("tensor shape {shape:?} has too many elements"))?; + if expected_len != data.len() { + bail_multimodal!( + "tensor shape {shape:?} expects {expected_len} elements, got {}", + data.len() + ); + } + let stride = shape[1..].iter().product::(); + let data_start = start * stride; + let data_end = end * stride; + let out_shape = if drop_axis { + shape[1..].to_vec() + } else { + let mut shape = shape.to_vec(); + shape[0] = end - start; + shape + }; + Ok((out_shape, data[data_start..data_end].to_vec())) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn batched_value_at_drops_first_axis() { + let value = KwargValue::F32Tensor { + data: vec![1.0, 2.0, 3.0, 4.0], + shape: vec![2, 2], + }; + + let value = value.batched_value_at(1).unwrap(); + + assert!(matches!( + value, + KwargValue::F32Tensor { data, shape } + if shape == vec![2] && data == vec![3.0, 4.0] + )); + } + + #[test] + fn flat_value_range_keeps_first_axis() { + let value = KwargValue::U32Tensor { + data: (0..10).collect(), + shape: vec![5, 2], + }; + + let value = value.flat_value_range(1, 3).unwrap(); + + assert!(matches!( + value, + KwargValue::U32Tensor { data, shape } + if shape == vec![2, 2] && data == vec![2, 3, 4, 5] + )); + } + + #[test] + fn flat_range_for_index_uses_size_tensor() { + let sizes = KwargValue::I64Tensor { + data: vec![2, 3, 4], + shape: vec![3], + }; + + assert_eq!( + flat_range_for_index(&sizes, "image_grid_thw", 1).unwrap(), + (2, 5) + ); + } + + #[test] + fn slice_first_axis_range_errors_on_shape_data_mismatch() { + let error = slice_first_axis_range(&[2, 2], &[1.0_f32, 2.0, 3.0], 0, 1, true).unwrap_err(); + + assert!( + matches!(error, Error::Multimodal(message) if message.contains("expects 4 elements")) + ); + } + + #[test] + fn bfloat16_tensor_wire_uses_bfloat16_dtype() { + let value = + KwargValue::from_f32_tensor(vec![1.0, -1.0], vec![2], ModelDtype::BFloat16).unwrap(); + + let ProtocolKwargValue::Tensor(tensor) = ProtocolKwargValue::try_from(value).unwrap() + else { + panic!("expected tensor"); + }; + + assert_eq!(tensor.dtype, "bfloat16"); + assert_eq!(tensor.shape, vec![2]); + assert_eq!(tensor.data.into_raw_view().unwrap().len(), 4); + } + + #[test] + fn float16_tensor_wire_uses_float16_dtype() { + let value = + KwargValue::from_f32_tensor(vec![1.0, -1.0], vec![2], ModelDtype::Float16).unwrap(); + + let ProtocolKwargValue::Tensor(tensor) = ProtocolKwargValue::try_from(value).unwrap() + else { + panic!("expected tensor"); + }; + + assert_eq!(tensor.dtype, "float16"); + assert_eq!(tensor.shape, vec![2]); + assert_eq!(tensor.data.into_raw_view().unwrap().len(), 4); + } +} diff --git a/rust/src/chat/src/output/default/mod.rs b/rust/src/chat/src/output/default/mod.rs new file mode 100644 index 00000000000..40526a9e84c --- /dev/null +++ b/rust/src/chat/src/output/default/mod.rs @@ -0,0 +1,166 @@ +//! Default output processing pipeline. + +mod reasoning; +mod tool; + +use std::sync::Once; + +use futures::{Stream, StreamExt as _}; +use tracing::info; +use trait_set::trait_set; +use vllm_text::tokenizer::DynTokenizer; + +use self::reasoning::reasoning_event_stream; +use self::tool::tool_event_stream; +use super::structured::structured_chat_event_stream; +use crate::error::Result; +use crate::output::{ + AssistantEvent, ChatOutputProcessor, ContentEvent, DynChatEventStream, + DynDecodedTextEventStream, +}; +use crate::parser::ParserSelection; +use crate::parser::reasoning::{ReasoningParser, ReasoningParserFactory}; +use crate::parser::tool::{ToolParser, ToolParserFactory}; +use crate::request::{ChatRequest, ChatToolChoice}; +use crate::{Error, Result as ChatResult}; + +trait_set! { + trait ContentEventStream = Stream> + Send + 'static; +} + +/// Default request-scoped output processor used by Hugging Face style chat +/// backends. +/// +/// This implementation assumes the backend already emitted decoded text deltas, +/// then optionally layers reasoning parsing and tool-call parsing before +/// assembling final structured chat events. +pub struct DefaultChatOutputProcessor { + reasoning_parser: Option>, + tool_parser: Option>, +} + +impl DefaultChatOutputProcessor { + /// Build the default output processor and apply any parser-specific request + /// adjustments. + /// + /// Parser resolution happens here so that request validation, prompt + /// rendering, and streaming all observe the same parser-adjusted + /// request state. + pub fn new( + request: &mut ChatRequest, + model_id: &str, + tokenizer: DynTokenizer, + tool_call_parser: &ParserSelection, + reasoning_parser: &ParserSelection, + ) -> ChatResult { + let tool_parsing_enabled = + matches!(request.tool_choice, ChatToolChoice::Auto) && !request.tools.is_empty(); + let tool_parser = if tool_parsing_enabled { + Some(Self::resolve_tool_parser( + request, + model_id, + tool_call_parser, + )?) + } else { + None + }; + let reasoning_parser = Self::resolve_optional_reasoning_parser( + request, + model_id, + tokenizer, + reasoning_parser, + )?; + + Ok(Self { + reasoning_parser, + tool_parser, + }) + } + + /// Build the plain-text-only default output processor. + /// + /// This keeps the default structured chat-event assembly but disables both + /// reasoning parsing and tool-call parsing completely, so that all + /// content is treated as opaque text. + pub fn plain_text_only() -> Self { + Self { + reasoning_parser: None, + tool_parser: None, + } + } + + fn resolve_tool_parser( + request: &mut ChatRequest, + model_id: &str, + selection: &ParserSelection, + ) -> ChatResult> { + let factory = ToolParserFactory::global(); + let parser_name = match selection { + ParserSelection::Auto => factory.resolve_name_for_model(model_id).ok_or_else(|| { + Error::ParserUnavailableForModel { + kind: "tool", + model_id: model_id.to_string(), + } + })?, + ParserSelection::None => return Err(Error::ParserDisabled { kind: "tool" }), + ParserSelection::Explicit(name) => name.as_str(), + }; + + let parser = factory.create(parser_name, &request.tools)?; + + if parser.preserve_special_tokens() { + request.decode_options.skip_special_tokens = false; + } + + TOOL_PARSER_LOG_ONCE.call_once(|| info!(parser_name, "using tool parser")); + Ok(parser) + } + + fn resolve_optional_reasoning_parser( + request: &mut ChatRequest, + model_id: &str, + tokenizer: DynTokenizer, + selection: &ParserSelection, + ) -> ChatResult>> { + let factory = ReasoningParserFactory::global(); + let parser_name = match selection { + ParserSelection::Auto => factory.resolve_name_for_model(model_id), + ParserSelection::None => None, + ParserSelection::Explicit(name) => Some(name.as_str()), + }; + + let Some(parser_name) = parser_name else { + REASONING_PARSER_LOG_ONCE.call_once(|| info!("reasoning parsing disabled")); + return Ok(None); + }; + + let parser = factory.create(parser_name, tokenizer)?; + + if parser.preserve_special_tokens() { + request.decode_options.skip_special_tokens = false; + } + + REASONING_PARSER_LOG_ONCE.call_once(|| info!(parser_name, "using reasoning parser")); + Ok(Some(parser)) + } +} + +static TOOL_PARSER_LOG_ONCE: Once = Once::new(); +static REASONING_PARSER_LOG_ONCE: Once = Once::new(); + +impl ChatOutputProcessor for DefaultChatOutputProcessor { + /// Transforms a raw generate-output token stream into structured chat + /// events through three sequential stages once text decoding has + /// already happened: + /// + /// 1. [`reasoning_event_stream`] — reasoning/content separation + /// 2. [`tool_event_stream`] — tool-call parsing + /// 3. [`structured_chat_event_stream`] — final block assembly + fn process(self: Box, decoded: DynDecodedTextEventStream) -> Result { + let reasoning = reasoning_event_stream(decoded, self.reasoning_parser); + let tool = tool_event_stream(reasoning, self.tool_parser); + let structured = structured_chat_event_stream(tool); + + Ok(structured.boxed()) + } +} diff --git a/rust/src/chat/src/output/default/reasoning.rs b/rust/src/chat/src/output/default/reasoning.rs new file mode 100644 index 00000000000..b51ce41961d --- /dev/null +++ b/rust/src/chat/src/output/default/reasoning.rs @@ -0,0 +1,504 @@ +//! Adapts decoded text updates into reasoning-aware assistant deltas. +//! +//! This stage sits between low-level token decoding and final block assembly. +//! It is the only place in the new pipeline that understands reasoning +//! separation: `decoded.rs` still only produces plain text deltas, while later +//! stages consume the semantic `Text` / `Reasoning` split emitted here. + +use asynk_strim_attr::{TryYielder, try_stream}; +use futures::{StreamExt as _, pin_mut}; +use thiserror_ext::AsReport; +use tracing::warn; +use vllm_text::output::DecodedTextEvent; + +use super::ContentEvent; +use crate::Result; +use crate::error::Error; +use crate::event::AssistantBlockKind; +use crate::output::DecodedTextEventStream; +use crate::parser::reasoning::{ReasoningDelta, ReasoningParser}; + +/// Per-stream reasoning parsing state. +struct ReasoningState { + /// Reasoning parser for the current model family. + parser: Box, + /// Whether reasoning parsing has already failed for this stream. + parser_failed: bool, +} + +impl ReasoningState { + /// Create one fresh reasoning-adaptation state for a new streamed response. + fn new(parser: Box) -> Self { + Self { + parser, + parser_failed: false, + } + } + + /// Convert one decoded text delta into zero or more semantic assistant + /// deltas. + fn process_delta(&mut self, delta: String) -> Vec { + // If the parser has already failed, skip parsing and return plain text deltas. + if self.parser_failed { + return vec![ContentEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta, + }]; + } + + let mut events = Vec::new(); + + match self.parser.push(&delta) { + Ok(result) => { + push_reasoning_delta(&mut events, result); + } + Err(error) => { + if !self.parser_failed { + warn!( + error = %error.as_report(), + "reasoning parser failed; falling back to plain text deltas" + ); + self.parser_failed = true; + } + push_text_delta(&mut events, AssistantBlockKind::Text, delta); + } + } + + events + } + + /// Initialize parser state once prompt token IDs are available. + fn initialize(&mut self, prompt_token_ids: &[u32]) { + if self.parser_failed { + return; + } + + match self.parser.initialize(prompt_token_ids) { + Ok(()) => {} + Err(error) => { + warn!( + error = %error.as_report(), + "failed to initialize reasoning parser; falling back to plain text deltas" + ); + self.parser_failed = true; + } + } + } + + /// Flush any parser-held partial delimiter state at end of stream. + fn finish(&mut self) -> Vec { + if self.parser_failed { + return Vec::new(); + } + + match self.parser.finish() { + Ok(result) => { + let mut events = Vec::new(); + push_reasoning_delta(&mut events, result); + events + } + Err(error) => { + warn!(error = %error.as_report(), "failed to flush reasoning parser state"); + Vec::new() + } + } + } +} + +/// Push one semantic text delta if it is non-empty. +fn push_text_delta(events: &mut Vec, kind: AssistantBlockKind, delta: String) { + if delta.is_empty() { + return; + } + events.push(ContentEvent::TextDelta { kind, delta }); +} + +/// Convert one parsed reasoning delta into zero or more content events. +fn push_reasoning_delta(events: &mut Vec, delta: ReasoningDelta) { + if let Some(reasoning) = delta.reasoning { + push_text_delta(events, AssistantBlockKind::Reasoning, reasoning); + } + if let Some(content) = delta.content { + push_text_delta(events, AssistantBlockKind::Text, content); + } +} + +/// Wrap one decoded-text stream into the internal reasoning event stream. +#[try_stream] +pub(crate) async fn reasoning_event_stream( + decoded_stream: impl DecodedTextEventStream, + reasoning_parser: Option>, + mut y: TryYielder, +) -> Result<()> { + pin_mut!(decoded_stream); + + // Without a parser, pass through as plain text deltas. + let Some(reasoning_parser) = reasoning_parser else { + while let Some(event) = decoded_stream.next().await.transpose()? { + for next in ContentEvent::from_decoded_plain_text(event) { + y.yield_ok(next).await; + } + } + return Ok(()); + }; + + let mut state = ReasoningState::new(reasoning_parser); + + while let Some(event) = decoded_stream.next().await.transpose()? { + match event { + DecodedTextEvent::Start { + prompt_token_ids, + prompt_logprobs, + } => { + state.initialize(&prompt_token_ids); + y.yield_ok(ContentEvent::Start { + prompt_token_ids, + prompt_logprobs, + }) + .await; + } + DecodedTextEvent::TextDelta { + delta, + token_ids, + logprobs, + finished, + } => { + for next in state.process_delta(delta) { + y.yield_ok(next).await; + } + if logprobs.is_some() || !token_ids.is_empty() { + y.yield_ok(ContentEvent::LogprobsDelta { + logprobs, + token_ids, + }) + .await; + } + if let Some(finished) = finished { + for next in state.finish() { + y.yield_ok(next).await; + } + y.yield_ok(ContentEvent::Done { + prompt_token_count: finished.prompt_token_count, + output_token_count: finished.output_token_count, + finish_reason: finished.finish_reason, + kv_transfer_params: finished.kv_transfer_params, + }) + .await; + } + } + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + + use std::sync::Arc; + + use futures::{StreamExt as _, stream}; + use vllm_llm::FinishReason; + use vllm_text::output::{ + DecodedLogprobs, DecodedPositionLogprobs, DecodedTextEvent, DecodedTokenLogprob, + }; + use vllm_tokenizer::{DynTokenizer, Tokenizer}; + + use super::super::ContentEvent; + use super::reasoning_event_stream; + use crate::event::AssistantBlockKind; + use crate::parser::reasoning::{ + ReasoningDelta, ReasoningError, ReasoningParser, ReasoningParserFactory, names, + }; + + struct FakeTokenizer; + + impl Tokenizer for FakeTokenizer { + fn encode( + &self, + text: &str, + _add_special_tokens: bool, + ) -> vllm_tokenizer::Result> { + Ok(text.chars().map(u32::from).collect()) + } + + fn decode( + &self, + token_ids: &[u32], + _skip_special_tokens: bool, + ) -> vllm_tokenizer::Result { + Ok(token_ids + .iter() + .map(|token_id| char::from_u32(*token_id).unwrap_or('\u{FFFD}')) + .collect()) + } + + fn token_to_id(&self, token: &str) -> Option { + match token { + "" => Some(1), + "" => Some(2), + _ => None, + } + } + } + + struct FailingReasoningParser { + fail_next: bool, + } + + impl ReasoningParser for FailingReasoningParser { + fn create(_tokenizer: DynTokenizer) -> Result, ReasoningError> + where + Self: Sized + 'static, + { + Ok(Box::new(Self { fail_next: true })) + } + + fn push(&mut self, _text: &str) -> Result { + if self.fail_next { + self.fail_next = false; + return Err(ReasoningError::MissingToken { + token: "".to_string(), + }); + } + Ok(ReasoningDelta::default()) + } + } + + fn test_reasoning_parser(factory: &mut ReasoningParserFactory) -> Box { + factory.register_parser::("failing"); + + factory.create("failing", Arc::new(FakeTokenizer)).unwrap() + } + + #[tokio::test] + async fn reasoning_parser_failure_falls_back_to_plain_text() { + let mut factory = ReasoningParserFactory::new(); + let events = stream::iter(vec![ + Ok(DecodedTextEvent::Start { + prompt_token_ids: vec![1, 2, 3].into(), + prompt_logprobs: None, + }), + Ok(DecodedTextEvent::TextDelta { + delta: "abc".to_string(), + token_ids: vec![], + logprobs: None, + finished: None, + }), + Ok(DecodedTextEvent::TextDelta { + delta: "def".to_string(), + token_ids: vec![], + logprobs: None, + finished: Some(vllm_text::Finished { + prompt_token_count: 3, + output_token_count: 0, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + }), + ]); + + let collected = reasoning_event_stream(events, Some(test_reasoning_parser(&mut factory))) + .collect::>() + .await; + + let events = collected + .into_iter() + .collect::>>() + .expect("reasoning stream should not fail"); + + assert_eq!( + events, + vec![ + ContentEvent::Start { + prompt_token_ids: vec![1, 2, 3].into(), + prompt_logprobs: None, + }, + ContentEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "abc".to_string(), + }, + ContentEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "def".to_string(), + }, + ContentEvent::Done { + prompt_token_count: 3, + output_token_count: 0, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }, + ] + ); + } + + #[tokio::test] + async fn reasoning_stream_preserves_logprobs_delta() { + let events = stream::iter(vec![ + Ok(DecodedTextEvent::Start { + prompt_token_ids: vec![1].into(), + prompt_logprobs: None, + }), + Ok(DecodedTextEvent::TextDelta { + delta: "abc".to_string(), + token_ids: vec![], + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "a".to_string(), + logprob: -0.1, + rank: 1, + }], + }], + }), + finished: None, + }), + ]); + + let collected = reasoning_event_stream(events, None) + .collect::>() + .await + .into_iter() + .collect::>>() + .unwrap(); + + assert_eq!( + collected, + vec![ + ContentEvent::Start { + prompt_token_ids: vec![1].into(), + prompt_logprobs: None, + }, + ContentEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "abc".to_string(), + }, + ContentEvent::LogprobsDelta { + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "a".to_string(), + logprob: -0.1, + rank: 1, + }], + }], + }), + token_ids: vec![], + }, + ] + ); + } + + #[tokio::test] + async fn qwen3_parser_uses_prompt_end_marker_to_switch_to_content() { + let tokenizer = Arc::new(FakeTokenizer); + let events = stream::iter(vec![ + Ok(DecodedTextEvent::Start { + prompt_token_ids: vec![2].into(), + prompt_logprobs: None, + }), + Ok(DecodedTextEvent::TextDelta { + delta: "thought ".to_string(), + token_ids: vec![], + logprobs: None, + finished: None, + }), + Ok(DecodedTextEvent::TextDelta { + delta: "doneOK".to_string(), + token_ids: vec![], + logprobs: None, + finished: None, + }), + ]); + + let factory = ReasoningParserFactory::new(); + let collected = reasoning_event_stream( + events, + Some(factory.create(names::QWEN3, tokenizer).unwrap()), + ) + .collect::>() + .await; + + let events = collected + .into_iter() + .collect::>>() + .expect("reasoning stream should not fail"); + + assert_eq!( + events, + vec![ + ContentEvent::Start { + prompt_token_ids: vec![2].into(), + prompt_logprobs: None, + }, + ContentEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "thought ".to_string(), + }, + ContentEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "doneOK".to_string(), + }, + ] + ); + } + + #[tokio::test] + async fn qwen3_parser_tolerates_prompt_prefill_reasoning() { + let tokenizer = Arc::new(FakeTokenizer); + let events = stream::iter(vec![ + Ok(DecodedTextEvent::Start { + prompt_token_ids: vec![1].into(), + prompt_logprobs: None, + }), + Ok(DecodedTextEvent::TextDelta { + delta: "thought ".to_string(), + token_ids: vec![], + logprobs: None, + finished: None, + }), + Ok(DecodedTextEvent::TextDelta { + delta: "doneOK".to_string(), + token_ids: vec![], + logprobs: None, + finished: None, + }), + ]); + + let factory = ReasoningParserFactory::new(); + let collected = reasoning_event_stream( + events, + Some(factory.create(names::QWEN3, tokenizer).unwrap()), + ) + .collect::>() + .await; + + let events = collected + .into_iter() + .collect::>>() + .expect("reasoning stream should not fail"); + + assert_eq!( + events, + vec![ + ContentEvent::Start { + prompt_token_ids: vec![1].into(), + prompt_logprobs: None, + }, + ContentEvent::TextDelta { + kind: AssistantBlockKind::Reasoning, + delta: "thought ".to_string(), + }, + ContentEvent::TextDelta { + kind: AssistantBlockKind::Reasoning, + delta: "done".to_string(), + }, + ContentEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "OK".to_string(), + }, + ] + ); + } +} diff --git a/rust/src/chat/src/output/default/tool.rs b/rust/src/chat/src/output/default/tool.rs new file mode 100644 index 00000000000..9ab5ee1684b --- /dev/null +++ b/rust/src/chat/src/output/default/tool.rs @@ -0,0 +1,625 @@ +//! Adapts plain assistant text deltas into tool-call-aware assistant updates. +//! +//! This stage runs after reasoning separation and before final block assembly. +//! It only inspects normal assistant text, leaves reasoning deltas untouched, +//! and translates incremental tool parsing output into internal tool-call +//! events while preserving plain-text fallback behavior. + +use asynk_strim_attr::{TryYielder, try_stream}; +use futures::{StreamExt as _, pin_mut}; +use thiserror_ext::AsReport; +use tracing::warn; + +use super::{AssistantEvent, ContentEvent, ContentEventStream}; +use crate::Result; +use crate::error::Error; +use crate::event::AssistantBlockKind; +use crate::output::generate_tool_call_id; +use crate::parser::tool::{ToolCallDelta, ToolParseResult, ToolParser}; + +/// Per-stream tool parsing state. +struct ToolState { + /// Parser for the current model family. + parser: Box, + /// Whether tool parsing has already failed for this stream. + parser_failed: bool, + /// The parser-local index of the currently open tool call, if any. + // NOTE: We only allow single open tool call at a time right now, since that's what all + // supported parsers currently emit. Change this to a `BTreeMap` if we need to support multiple + // interleaved calls in the future. + open_call_index: Option, +} + +impl ToolState { + /// Create one fresh tool-parsing state for a new streamed response. + fn new(parser: Box) -> Self { + Self { + parser, + parser_failed: false, + open_call_index: None, + } + } + + /// Convert one semantic assistant text delta into zero or more tool-aware + /// internal events. + fn process_text_delta( + &mut self, + kind: AssistantBlockKind, + delta: String, + ) -> Result> { + let mut events = Vec::new(); + + // Only normal assistant text is eligible for tool parsing. Reasoning + // blocks and plain-text fallback should pass through unchanged. + if kind != AssistantBlockKind::Text || self.parser_failed { + self.open_call_index = None; + events.push(AssistantEvent::TextDelta { kind, delta }); + return Ok(events); + } + + let parse_result = self.parser.push(&delta); + + match parse_result { + Ok(result) => self.process_parse_result(kind, result, &mut events)?, + Err(error) => { + if !self.parser_failed { + warn!( + error = %error.as_report(), + "tool parser failed; falling back to plain text deltas" + ); + self.parser_failed = true; + } + self.open_call_index = None; + events.push(AssistantEvent::TextDelta { kind, delta }); + } + } + + Ok(events) + } + + /// Apply one parsed tool result to the current stream state. + fn process_parse_result( + &mut self, + kind: AssistantBlockKind, + result: ToolParseResult, + events: &mut Vec, + ) -> Result<()> { + // When we are not currently streaming a tool call, preserve plain + // text first and then surface any new tool call items. + if self.open_call_index.is_none() { + push_text_delta(events, kind, result.normal_text); + self.process_tool_items(result.calls, events)?; + } else { + // Once a tool call is open, prioritize tool deltas first. If the + // parser emits normal text again, close the tool call and resume + // plain text output. + self.process_tool_items(result.calls, events)?; + if !result.normal_text.is_empty() { + self.open_call_index = None; + push_text_delta(events, kind, result.normal_text); + } + } + Ok(()) + } + + /// Apply one batch of parsed tool-call deltas emitted by the parser. + fn process_tool_items( + &mut self, + items: Vec, + events: &mut Vec, + ) -> Result<()> { + for item in items { + if let Some(name) = item.name { + let is_new_tool = match self.open_call_index { + Some(open_call_index) => open_call_index != item.tool_index, + None => true, + }; + if is_new_tool { + let id = generate_tool_call_id(); + self.open_call_index = Some(item.tool_index); + events.push(AssistantEvent::ToolCallStart { id, name }); + } + } + + if item.arguments.is_empty() { + // No arguments delta to apply. + continue; + } + let Some(open_call_index) = self.open_call_index else { + return Err(Error::ToolCallStreamInvariant { + message: format!( + "received arguments for tool index {} before any tool-call start", + item.tool_index + ), + }); + }; + if open_call_index != item.tool_index { + return Err(Error::ToolCallStreamInvariant { + message: format!( + "received arguments for tool index {} while tool index {} is open", + item.tool_index, open_call_index + ), + }); + } + + events.push(AssistantEvent::ToolCallArgumentsDelta { + delta: item.arguments, + }); + } + Ok(()) + } + + /// Flush parser state at end-of-stream and close any remaining open calls. + fn finish(&mut self) -> Result> { + let mut events = Vec::new(); + + if self.parser_failed { + return Ok(events); + } + + match self.parser.finish() { + Ok(result) => { + self.process_parse_result(AssistantBlockKind::Text, result, &mut events)? + } + Err(error) => { + warn!( + error = %error.as_report(), + "tool parser finish failed; closing open tool calls with buffered state" + ); + self.parser_failed = true; + } + } + + Ok(events) + } +} + +/// Push one plain-text delta if it is non-empty. +fn push_text_delta(events: &mut Vec, kind: AssistantBlockKind, delta: String) { + if delta.is_empty() { + return; + } + events.push(AssistantEvent::TextDelta { kind, delta }); +} + +/// Wrap one semantic assistant stream into the internal tool-aware assistant +/// stream. +#[try_stream] +pub(crate) async fn tool_event_stream( + stream: impl ContentEventStream, + parser: Option>, + mut y: TryYielder, +) -> Result<()> { + // Without a parser, pass through the input stream unchanged. + let Some(parser) = parser else { + pin_mut!(stream); + while let Some(event) = stream.next().await.transpose()? { + y.yield_ok(event.into()).await; + } + return Ok(()); + }; + + pin_mut!(stream); + let mut state = ToolState::new(parser); + + while let Some(event) = stream.next().await.transpose()? { + match event { + ContentEvent::Start { + prompt_token_ids, + prompt_logprobs, + } => { + y.yield_ok(AssistantEvent::Start { + prompt_token_ids, + prompt_logprobs, + }) + .await; + } + ContentEvent::TextDelta { kind, delta } => { + for next in state.process_text_delta(kind, delta)? { + y.yield_ok(next).await; + } + } + ContentEvent::LogprobsDelta { + logprobs, + token_ids, + } => { + y.yield_ok(AssistantEvent::LogprobsDelta { + logprobs, + token_ids, + }) + .await; + } + ContentEvent::Done { + prompt_token_count, + output_token_count, + finish_reason, + kv_transfer_params, + } => { + for next in state.finish()? { + y.yield_ok(next).await; + } + + y.yield_ok(AssistantEvent::Done { + prompt_token_count, + output_token_count, + finish_reason, + kv_transfer_params, + }) + .await; + } + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + + use futures::{StreamExt as _, stream}; + use vllm_llm::FinishReason; + use vllm_text::{DecodedLogprobs, DecodedPositionLogprobs, DecodedTokenLogprob}; + use vllm_tool_parser::Result; + + use super::super::{AssistantEvent, ContentEvent}; + use super::tool_event_stream; + use crate::error::Error; + use crate::event::{AssistantBlockKind, AssistantMessageExt as _}; + use crate::output::structured::structured_chat_event_stream; + use crate::parser::tool::{ToolParseResult, ToolParser, ToolParserError}; + use crate::request::ChatTool; + use crate::stream::ChatEventStream; + + struct FailingParser { + fail_next: bool, + } + + struct ScriptedParser { + push_results: Vec, + finish_result: ToolParseResult, + } + + impl ToolParser for FailingParser { + fn create(_tools: &[ChatTool]) -> vllm_tool_parser::Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self { fail_next: false })) + } + + fn push(&mut self, _chunk: &str) -> Result { + if self.fail_next { + self.fail_next = false; + return Err(ToolParserError::ParsingFailed { + message: "boom".to_string(), + }); + } + + Ok(ToolParseResult::default()) + } + } + + impl ToolParser for ScriptedParser { + fn create(_tools: &[ChatTool]) -> vllm_tool_parser::Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self { + push_results: Vec::new(), + finish_result: ToolParseResult::default(), + })) + } + + fn push(&mut self, _chunk: &str) -> Result { + Ok(self.push_results.pop().unwrap_or_default()) + } + + fn finish(&mut self) -> Result { + Ok(std::mem::take(&mut self.finish_result)) + } + } + + #[tokio::test] + async fn tool_parser_failure_falls_back_to_plain_text() { + let events = stream::iter(vec![ + Ok(ContentEvent::Start { + prompt_token_ids: vec![1, 2, 3].into(), + prompt_logprobs: None, + }), + Ok(ContentEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "abc".to_string(), + }), + Ok(ContentEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "def".to_string(), + }), + Ok(ContentEvent::Done { + prompt_token_count: 3, + output_token_count: 0, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]); + + let collected = + tool_event_stream(events, Some(Box::new(FailingParser { fail_next: true }))) + .collect::>() + .await; + + let events = collected + .into_iter() + .collect::>>() + .expect("tool stream should not fail"); + + assert_eq!( + events, + vec![ + AssistantEvent::Start { + prompt_token_ids: vec![1, 2, 3].into(), + prompt_logprobs: None, + }, + AssistantEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "abc".to_string(), + }, + AssistantEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "def".to_string(), + }, + AssistantEvent::Done { + prompt_token_count: 3, + output_token_count: 0, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }, + ] + ); + + let message = ChatEventStream::new( + "req_fallback".to_string(), + Box::pin(structured_chat_event_stream(stream::iter( + events.into_iter().map(Ok), + ))), + ) + .collect_message() + .await + .expect("collect_message should succeed"); + assert_eq!(message.message.text(), "abcdef"); + assert!(message.message.tool_calls().next().is_none()); + } + + #[tokio::test] + async fn tool_stream_preserves_logprobs_delta() { + let events = stream::iter(vec![ + Ok(ContentEvent::Start { + prompt_token_ids: vec![1].into(), + prompt_logprobs: None, + }), + Ok(ContentEvent::LogprobsDelta { + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "a".to_string(), + logprob: -0.2, + rank: 1, + }], + }], + }), + token_ids: vec![], + }), + Ok(ContentEvent::Done { + prompt_token_count: 1, + output_token_count: 0, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]); + let events = tool_event_stream(events, Some(Box::new(FailingParser { fail_next: false }))) + .collect::>() + .await + .into_iter() + .collect::>>() + .unwrap(); + + assert_eq!( + events, + vec![ + AssistantEvent::Start { + prompt_token_ids: vec![1].into(), + prompt_logprobs: None, + }, + AssistantEvent::LogprobsDelta { + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "a".to_string(), + logprob: -0.2, + rank: 1, + }], + }], + }), + token_ids: vec![], + }, + AssistantEvent::Done { + prompt_token_count: 1, + output_token_count: 0, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }, + ] + ); + } + + #[tokio::test] + async fn tool_stream_rejects_interleaved_tool_indices() { + let events = stream::iter(vec![ + Ok(ContentEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "ignored".to_string(), + }), + Ok(ContentEvent::Done { + prompt_token_count: 1, + output_token_count: 1, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]); + + let parser = ScriptedParser { + push_results: vec![ToolParseResult { + normal_text: String::new(), + calls: vec![ + crate::parser::tool::ToolCallDelta { + tool_index: 0, + name: Some("first".to_string()), + arguments: String::new(), + }, + crate::parser::tool::ToolCallDelta { + tool_index: 1, + name: None, + arguments: "{}".to_string(), + }, + ], + }], + finish_result: ToolParseResult::default(), + }; + + let err = tool_event_stream(events, Some(Box::new(parser))) + .collect::>() + .await + .into_iter() + .find_map(|result| result.err()) + .expect("expected invariant error"); + + assert!(matches!(err, Error::ToolCallStreamInvariant { .. })); + } + + #[tokio::test] + async fn tool_stream_resets_open_tool_when_normal_text_interrupts_it() { + let events = stream::iter(vec![ + Ok(ContentEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "start".to_string(), + }), + Ok(ContentEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "text".to_string(), + }), + Ok(ContentEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "args".to_string(), + }), + ]); + + let parser = ScriptedParser { + push_results: vec![ + ToolParseResult { + normal_text: String::new(), + calls: vec![crate::parser::tool::ToolCallDelta { + tool_index: 0, + name: None, + arguments: "}".to_string(), + }], + }, + ToolParseResult { + normal_text: "plain text".to_string(), + calls: Vec::new(), + }, + ToolParseResult { + normal_text: String::new(), + calls: vec![crate::parser::tool::ToolCallDelta { + tool_index: 0, + name: Some("first".to_string()), + arguments: "{".to_string(), + }], + }, + ], + finish_result: ToolParseResult::default(), + }; + + let err = tool_event_stream(events, Some(Box::new(parser))) + .collect::>() + .await + .into_iter() + .find_map(|result| result.err()) + .expect("expected invariant error"); + + assert!(matches!( + err, + Error::ToolCallStreamInvariant { message } + if message == "received arguments for tool index 0 before any tool-call start" + )); + } + + #[tokio::test] + async fn tool_stream_emits_start_and_args_for_terminal_text() { + let events = stream::iter(vec![ + Ok(ContentEvent::Start { + prompt_token_ids: vec![1].into(), + prompt_logprobs: None, + }), + Ok(ContentEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "ignored".to_string(), + }), + Ok(ContentEvent::Done { + prompt_token_count: 1, + output_token_count: 1, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]); + + let parser = ScriptedParser { + push_results: vec![ToolParseResult { + normal_text: String::new(), + calls: vec![ + crate::parser::tool::ToolCallDelta { + tool_index: 0, + name: Some("first".to_string()), + arguments: r#"{"a":1}"#.to_string(), + }, + crate::parser::tool::ToolCallDelta { + tool_index: 1, + name: Some("second".to_string()), + arguments: r#"{"b":2}"#.to_string(), + }, + ], + }], + finish_result: ToolParseResult::default(), + }; + + let events = tool_event_stream(events, Some(Box::new(parser))) + .collect::>() + .await + .into_iter() + .collect::>>() + .unwrap(); + + assert!(matches!(events[1], AssistantEvent::ToolCallStart { .. })); + assert!(matches!( + events[2], + AssistantEvent::ToolCallArgumentsDelta { .. } + )); + assert!(matches!(events[3], AssistantEvent::ToolCallStart { .. })); + assert!(matches!( + events[4], + AssistantEvent::ToolCallArgumentsDelta { .. } + )); + let collected = ChatEventStream::new( + "req_final_only".to_string(), + Box::pin(structured_chat_event_stream(stream::iter( + events.into_iter().map(Ok), + ))), + ) + .collect_message() + .await + .unwrap(); + let tool_calls = collected.message.tool_calls().collect::>(); + assert_eq!(tool_calls.len(), 2); + assert_eq!(tool_calls[0].name, "first"); + assert_eq!(tool_calls[1].name, "second"); + } +} diff --git a/rust/src/chat/src/output/harmony/mod.rs b/rust/src/chat/src/output/harmony/mod.rs new file mode 100644 index 00000000000..5dc6bc31185 --- /dev/null +++ b/rust/src/chat/src/output/harmony/mod.rs @@ -0,0 +1,430 @@ +//! Native Harmony output processing for `gpt_oss`. +//! +//! Unlike the default text-first pipeline, this processor consumes +//! `DecodedTextEvent` token IDs directly and lets the official `openai-harmony` +//! parser recover the structured assistant message shape at token granularity. + +use std::sync::LazyLock; + +use anyhow::Context; +use asynk_strim_attr::{TryYielder, try_stream}; +use futures::StreamExt as _; +use openai_harmony::chat::{Content as HarmonyContent, Message as HarmonyMessage, Role}; +use openai_harmony::{ + HarmonyEncoding, HarmonyEncodingName, StreamableParser, load_harmony_encoding, +}; +use thiserror_ext::AsReport; +use vllm_text::output::DecodedTextEvent; + +use crate::Result as ChatResult; +use crate::error::{Error, Result}; +use crate::event::AssistantBlockKind; +use crate::output::{ + AssistantEvent, ChatOutputProcessor, DynChatEventStream, DynDecodedTextEventStream, + generate_tool_call_id, +}; +use crate::parser::ParserSelection; +use crate::request::ChatRequest; + +/// Request-scoped Harmony output processor used for `model_type == "gpt_oss"`. +/// +/// This processor keeps the existing northbound `ChatEvent` shape, but swaps +/// the parsed-assistant backend from generic text/reasoning/tool parsers to the +/// official Harmony token parser. +#[derive(Debug)] +pub struct HarmonyChatOutputProcessor { + encoding: &'static HarmonyEncoding, + tool_calls_enabled: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct HarmonyGroupKey { + serial: usize, + channel: Option, + recipient: Option, +} + +#[derive(Debug)] +struct HarmonyGroup { + key: HarmonyGroupKey, + text: String, +} + +#[derive(Debug)] +struct OpenHarmonyToolCall { + recipient: String, +} + +struct HarmonyState { + /// Incremental Harmony parser over assistant token IDs. + parser: StreamableParser, + /// Whether tool-call content should surface as structured tool events. + tool_calls_enabled: bool, + /// Count of completed visible assistant messages for newline insertion. + completed_visible_messages: usize, + /// Count of completed reasoning messages for newline insertion. + completed_reasoning_messages: usize, + /// The current visible text/reasoning group, if any. + current_text_group: Option, + /// The currently open Harmony tool recipient, if any. + open_tool_call: Option, +} + +impl HarmonyChatOutputProcessor { + /// Build one request-scoped Harmony processor after backend policy checks. + pub fn new(request: &ChatRequest) -> ChatResult { + Ok(Self { + encoding: harmony_encoding()?, + tool_calls_enabled: request.tool_parsing_enabled(), + }) + } +} + +/// Validate that the generic parser selections are compatible with native +/// Harmony output parsing. +/// +/// `gpt_oss` uses a model-specific token-level parser, so any generic +/// reasoning/tool parser override is rejected instead of being silently +/// ignored. +pub(crate) fn validate_harmony_parser_overrides( + tool_call_parser: &ParserSelection, + reasoning_parser: &ParserSelection, +) -> ChatResult<()> { + validate_harmony_override("tool", tool_call_parser)?; + validate_harmony_override("reasoning", reasoning_parser)?; + Ok(()) +} + +fn validate_harmony_override(kind: &'static str, selection: &ParserSelection) -> ChatResult<()> { + if matches!(selection, ParserSelection::Auto) { + return Ok(()); + } + + Err(Error::HarmonyParserOverrideUnsupported { + kind, + selection: selection.to_string(), + }) +} + +impl ChatOutputProcessor for HarmonyChatOutputProcessor { + fn process(self: Box, decoded: DynDecodedTextEventStream) -> Result { + let assistant = + harmony_assistant_event_stream(decoded, self.encoding, self.tool_calls_enabled); + Ok(crate::output::structured::structured_chat_event_stream(assistant).boxed()) + } +} + +impl HarmonyState { + /// Create one fresh Harmony streaming state for a new assistant response. + fn new(encoding: HarmonyEncoding, tool_calls_enabled: bool) -> Result { + Ok(Self { + parser: StreamableParser::new(encoding, Some(Role::Assistant)) + .map_err(harmony_output_parsing_error)?, + tool_calls_enabled, + completed_visible_messages: 0, + completed_reasoning_messages: 0, + current_text_group: None, + open_tool_call: None, + }) + } + + fn process_token_ids(&mut self, token_ids: &[u32]) -> Result> { + let mut events = Vec::new(); + let mut pending_group: Option = None; + + for &token_id in token_ids { + let completed_before = self.parser.messages().len(); + self.parser.process(token_id).map_err(harmony_output_parsing_error)?; + let completed_after = self.parser.messages().len(); + + if let Some(delta) = self + .parser + .last_content_delta() + .map_err(harmony_output_parsing_error)? + .filter(|delta| !delta.is_empty()) + { + let key = HarmonyGroupKey { + serial: completed_after, + channel: self.parser.current_channel(), + recipient: self.parser.current_recipient(), + }; + + match pending_group.as_mut() { + Some(group) if group.key == key => group.text.push_str(&delta), + _ => { + if let Some(group) = pending_group.take() { + self.emit_group(group, &mut events); + } + pending_group = Some(HarmonyGroup { key, text: delta }); + } + } + } + + if completed_after > completed_before { + if let Some(group) = pending_group.take() { + self.emit_group(group, &mut events); + } + + for serial in completed_before..completed_after { + let key = { + let message = &self.parser.messages()[serial]; + HarmonyGroupKey { + serial, + channel: message.channel.clone(), + recipient: message.recipient.clone(), + } + }; + self.handle_completed_message(key); + } + } + } + + if let Some(group) = pending_group { + self.emit_group(group, &mut events); + } + + Ok(events) + } + + /// Flush Harmony parser state at EOS and emit any newly finalized assistant + /// events. + fn process_eos(&mut self) -> Result> { + let completed_before = self.parser.messages().len(); + let pending_key = HarmonyGroupKey { + serial: completed_before, + channel: self.parser.current_channel(), + recipient: self.parser.current_recipient(), + }; + let pending_content = + self.parser.current_content().map_err(harmony_output_parsing_error)?; + + self.parser.process_eos().map_err(harmony_output_parsing_error)?; + + let completed_after = self.parser.messages().len(); + let mut events = Vec::new(); + + if completed_after == completed_before { + return Ok(events); + } + + let final_message = &self.parser.messages()[completed_before]; + let final_text = harmony_message_text(final_message); + let tail = final_text.strip_prefix(&pending_content).unwrap_or(final_text).to_string(); + if !tail.is_empty() { + self.emit_group( + HarmonyGroup { + key: pending_key, + text: tail, + }, + &mut events, + ); + } + + for serial in completed_before..completed_after { + let key = { + let message = &self.parser.messages()[serial]; + HarmonyGroupKey { + serial, + channel: message.channel.clone(), + recipient: message.recipient.clone(), + } + }; + self.handle_completed_message(key); + } + + Ok(events) + } + + /// Flush one coalesced Harmony content group into internal assistant + /// events. + fn emit_group(&mut self, group: HarmonyGroup, events: &mut Vec) { + let channel = group.key.channel.as_deref(); + let recipient = group.key.recipient.as_deref(); + + if let Some(kind) = text_block_kind(channel, recipient) { + self.open_tool_call = None; + + if self.current_text_group.as_ref() != Some(&group.key) { + let needs_newline = match kind { + AssistantBlockKind::Text => self.completed_visible_messages > 0, + AssistantBlockKind::Reasoning => self.completed_reasoning_messages > 0, + AssistantBlockKind::ToolCall => false, + }; + + if needs_newline { + events.push(AssistantEvent::TextDelta { + kind, + delta: "\n".to_string(), + }); + } + + self.current_text_group = Some(group.key.clone()); + } + + events.push(AssistantEvent::TextDelta { + kind, + delta: group.text, + }); + return; + } + + self.current_text_group = None; + + let Some(tool_name) = tool_name(channel, recipient) else { + return; + }; + if !self.tool_calls_enabled { + return; + } + + let recipient = recipient.expect("tool groups always have recipient").to_string(); + let opens_same_call = match self.open_tool_call.as_ref() { + Some(open_call) => open_call.recipient == recipient, + None => false, + }; + if !opens_same_call { + let id = generate_tool_call_id(); + self.open_tool_call = Some(OpenHarmonyToolCall { recipient }); + events.push(AssistantEvent::ToolCallStart { + id, + name: tool_name.to_string(), + }); + } + + if !group.text.is_empty() { + events.push(AssistantEvent::ToolCallArgumentsDelta { delta: group.text }); + } + } + + /// Update newline and open-tool state after one Harmony message completes. + fn handle_completed_message(&mut self, key: HarmonyGroupKey) { + if self.current_text_group.as_ref() == Some(&key) { + self.current_text_group = None; + } + + let channel = key.channel.as_deref(); + let recipient = key.recipient.as_deref(); + let kind = text_block_kind(channel, recipient); + + if kind == Some(AssistantBlockKind::Text) { + self.completed_visible_messages += 1; + } else if kind == Some(AssistantBlockKind::Reasoning) { + self.completed_reasoning_messages += 1; + } else if tool_name(channel, recipient).is_some() { + self.open_tool_call = None; + } + } +} + +/// Convert decoded token updates into internal assistant events with Harmony +/// parsing. +#[try_stream] +async fn harmony_assistant_event_stream( + decoded: DynDecodedTextEventStream, + encoding: &'static HarmonyEncoding, + tool_calls_enabled: bool, + mut y: TryYielder, +) -> Result<()> { + let mut state = HarmonyState::new(encoding.clone(), tool_calls_enabled)?; + futures::pin_mut!(decoded); + + while let Some(event) = decoded.next().await.transpose()? { + match event { + DecodedTextEvent::Start { + prompt_token_ids, + prompt_logprobs, + } => { + y.yield_ok(AssistantEvent::Start { + prompt_token_ids, + prompt_logprobs, + }) + .await; + } + DecodedTextEvent::TextDelta { + delta: _, // harmony takes raw token IDs as input, so we ignore text deltas here + token_ids, + logprobs, + finished, + } => { + for event in state.process_token_ids(&token_ids)? { + y.yield_ok(event).await; + } + + if finished.is_some() { + for event in state.process_eos()? { + y.yield_ok(event).await; + } + } + + if logprobs.is_some() || !token_ids.is_empty() { + y.yield_ok(AssistantEvent::LogprobsDelta { + logprobs, + token_ids, + }) + .await; + } + + if let Some(finished) = finished { + y.yield_ok(AssistantEvent::Done { + prompt_token_count: finished.prompt_token_count, + output_token_count: finished.output_token_count, + finish_reason: finished.finish_reason, + kv_transfer_params: finished.kv_transfer_params, + }) + .await; + } + } + } + } + Ok(()) +} + +/// Lazily load the shared GPT-OSS Harmony encoding once per process. +fn harmony_encoding() -> Result<&'static HarmonyEncoding> { + static ENCODING: LazyLock> = LazyLock::new(|| { + load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss) + .context("failed to load harmony encoding for gpt-oss") + }); + + ENCODING.as_ref().map_err(|error| Error::HarmonyOutputParsing { + error: error.to_report_string().into(), + }) +} + +fn harmony_output_parsing_error( + error: impl Into>, +) -> Error { + Error::HarmonyOutputParsing { + error: error.into(), + } +} + +/// Return the decoded text payload from one parsed Harmony message. +fn harmony_message_text(message: &HarmonyMessage) -> &str { + let [HarmonyContent::Text(text)] = message.content.as_slice() else { + unreachable!("Harmony parser emits one text content block per parsed message") + }; + &text.text +} + +/// Map one Harmony `(channel, recipient)` pair to a visible assistant block +/// kind. +fn text_block_kind(channel: Option<&str>, recipient: Option<&str>) -> Option { + match (channel, recipient) { + (Some("final"), _) => Some(AssistantBlockKind::Text), + (Some("analysis"), None) => Some(AssistantBlockKind::Reasoning), + (Some("commentary"), None) => Some(AssistantBlockKind::Text), + _ => None, + } +} + +/// Extract the tool name from a Harmony tool-recipient field, if present. +fn tool_name<'a>(channel: Option<&str>, recipient: Option<&'a str>) -> Option<&'a str> { + match (channel, recipient) { + (Some("commentary" | "analysis"), Some(recipient)) => recipient.strip_prefix("functions."), + _ => None, + } +} + +#[cfg(test)] +mod tests; diff --git a/rust/src/chat/src/output/harmony/tests.rs b/rust/src/chat/src/output/harmony/tests.rs new file mode 100644 index 00000000000..fe42542b473 --- /dev/null +++ b/rust/src/chat/src/output/harmony/tests.rs @@ -0,0 +1,351 @@ +//! Harmony output tests share the upstream `openai-harmony` tiktoken cache. +//! +//! Use a file lock for tests that load the encoding so `cargo nextest` cannot +//! start multiple processes that concurrently populate the same cache file. + +use std::sync::Arc; + +use futures::executor::block_on; +use futures::{TryStreamExt as _, stream}; +use openai_harmony::chat::{Message, Role}; +use serial_test::file_serial; +use vllm_text::output::{DecodedLogprobs, DecodedPositionLogprobs, DecodedTextEvent, Finished}; + +use super::*; +use crate::output::ChatOutputProcessor; +use crate::request::{ChatRequest, ChatTool, ChatToolChoice}; +use crate::{AssistantMessageExt, ChatEvent, FinishReason}; + +fn assistant_prefix() -> Vec { + harmony_encoding() + .unwrap() + .render_conversation_for_completion(std::iter::empty::<&Message>(), Role::Assistant, None) + .unwrap() +} + +fn completion_tokens(messages: &[Message]) -> Vec { + let encoding = harmony_encoding().unwrap(); + let prefix = assistant_prefix(); + let rendered = encoding.render_conversation(messages.iter(), None).unwrap(); + assert!(rendered.starts_with(&prefix)); + rendered[prefix.len()..].to_vec() +} + +fn text_message(channel: &str, text: &str) -> Message { + Message::from_role_and_content(Role::Assistant, text).with_channel(channel) +} + +fn tool_message(name: &str, arguments: &str, channel: &str) -> Message { + Message::from_role_and_content(Role::Assistant, arguments) + .with_channel(channel) + .with_recipient(format!("functions.{name}")) + .with_content_type("json") +} + +fn decoded_start() -> DecodedTextEvent { + DecodedTextEvent::Start { + prompt_token_ids: Arc::<[u32]>::from([]), + prompt_logprobs: None, + } +} + +fn finished() -> Finished { + Finished { + prompt_token_count: 0, + output_token_count: 0, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + } +} + +async fn collect_events( + processor: HarmonyChatOutputProcessor, + events: Vec, +) -> Vec { + Box::new(processor) + .process(Box::pin(stream::iter(events.into_iter().map(Ok)))) + .unwrap() + .try_collect() + .await + .unwrap() +} + +fn request_with_tools() -> ChatRequest { + ChatRequest { + tool_choice: ChatToolChoice::Auto, + tools: vec![ChatTool { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: serde_json::json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + }), + strict: None, + }], + ..ChatRequest::for_test() + } +} + +#[test] +#[file_serial(harmony_tiktoken_cache)] +fn interrupted_final_message_is_preserved() { + let tokens = completion_tokens(&[text_message("final", "hello")]); + let events = block_on(collect_events( + HarmonyChatOutputProcessor::new(&ChatRequest::for_test()).unwrap(), + vec![ + decoded_start(), + DecodedTextEvent::TextDelta { + delta: String::new(), + token_ids: tokens[..tokens.len() - 1].to_vec(), + logprobs: None, + finished: Some(finished()), + }, + ], + )); + + assert_eq!( + events.last(), + Some(&ChatEvent::Done { + message: crate::AssistantMessage { + content: vec![crate::AssistantContentBlock::Text { + text: "hello".to_string(), + }], + }, + prompt_token_count: 0, + output_token_count: 0, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }) + ); +} + +#[test] +#[file_serial(harmony_tiktoken_cache)] +fn eos_flush_preserves_trailing_replacement_text() { + let mut tokens = completion_tokens(&[text_message("final", "Hi")]); + tokens.pop(); + tokens.push(u32::MAX); + + let events = block_on(collect_events( + HarmonyChatOutputProcessor::new(&ChatRequest::for_test()).unwrap(), + vec![ + decoded_start(), + DecodedTextEvent::TextDelta { + delta: String::new(), + token_ids: tokens, + logprobs: None, + finished: Some(finished()), + }, + ], + )); + + let ChatEvent::Done { message, .. } = events.last().unwrap() else { + panic!("expected done"); + }; + assert_eq!(message.text(), format!("Hi{}", char::REPLACEMENT_CHARACTER)); +} + +#[test] +#[file_serial(harmony_tiktoken_cache)] +fn interrupted_analysis_message_is_preserved() { + let tokens = completion_tokens(&[text_message("analysis", "think")]); + let events = block_on(collect_events( + HarmonyChatOutputProcessor::new(&ChatRequest::for_test()).unwrap(), + vec![ + decoded_start(), + DecodedTextEvent::TextDelta { + delta: String::new(), + token_ids: tokens[..tokens.len() - 1].to_vec(), + logprobs: None, + finished: Some(finished()), + }, + ], + )); + + assert_eq!( + events.last(), + Some(&ChatEvent::Done { + message: crate::AssistantMessage { + content: vec![crate::AssistantContentBlock::Reasoning { + text: "think".to_string(), + }], + }, + prompt_token_count: 0, + output_token_count: 0, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }) + ); +} + +#[test] +#[file_serial(harmony_tiktoken_cache)] +fn commentary_preamble_is_visible_but_commentary_tool_payload_is_not() { + let tokens = completion_tokens(&[ + text_message("commentary", "Let me check."), + tool_message("get_weather", r#"{"city":"Paris"}"#, "commentary"), + ]); + let events = block_on(collect_events( + HarmonyChatOutputProcessor::new(&request_with_tools()).unwrap(), + vec![ + decoded_start(), + DecodedTextEvent::TextDelta { + delta: String::new(), + token_ids: tokens, + logprobs: None, + finished: Some(finished()), + }, + ], + )); + + let done = events.last().unwrap(); + let ChatEvent::Done { message, .. } = done else { + panic!("expected done"); + }; + assert_eq!(message.text(), "Let me check."); + assert_eq!(message.tool_calls().count(), 1); +} + +#[test] +#[file_serial(harmony_tiktoken_cache)] +fn multiple_messages_get_newline_separators() { + let tokens = completion_tokens(&[ + text_message("analysis", "first think"), + text_message("analysis", "second think"), + text_message("final", "first answer"), + text_message("final", "second answer"), + ]); + let events = block_on(collect_events( + HarmonyChatOutputProcessor::new(&ChatRequest::for_test()).unwrap(), + vec![ + decoded_start(), + DecodedTextEvent::TextDelta { + delta: String::new(), + token_ids: tokens, + logprobs: None, + finished: Some(finished()), + }, + ], + )); + + let ChatEvent::Done { message, .. } = events.last().unwrap() else { + panic!("expected done"); + }; + assert_eq!( + message.reasoning().as_deref(), + Some("first think\nsecond think") + ); + assert_eq!(message.text(), "first answer\nsecond answer"); +} + +#[test] +#[file_serial(harmony_tiktoken_cache)] +fn tool_calls_stream_arguments_and_finish_with_local_id_shape() { + let tokens = completion_tokens(&[tool_message( + "get_weather", + r#"{"city":"Paris"}"#, + "commentary", + )]); + let midpoint = tokens.len() / 2; + let events = block_on(collect_events( + HarmonyChatOutputProcessor::new(&request_with_tools()).unwrap(), + vec![ + decoded_start(), + DecodedTextEvent::TextDelta { + delta: String::new(), + token_ids: tokens[..midpoint].to_vec(), + logprobs: None, + finished: None, + }, + DecodedTextEvent::TextDelta { + delta: String::new(), + token_ids: tokens[midpoint..].to_vec(), + logprobs: None, + finished: Some(finished()), + }, + ], + )); + + let mut saw_start = None; + let mut saw_args = String::new(); + let mut saw_end = None; + for event in &events { + match event { + ChatEvent::ToolCallStart { id, name, .. } => { + assert!(id.starts_with("call_")); + assert_eq!(name, "get_weather"); + saw_start = Some(id.clone()); + } + ChatEvent::ToolCallArgumentsDelta { delta, .. } => saw_args.push_str(delta), + ChatEvent::ToolCallEnd { call, .. } => { + saw_end = Some(call.clone()); + } + _ => {} + } + } + + let start_id = saw_start.expect("tool start"); + assert_eq!(saw_args, r#"{"city":"Paris"}"#); + let end = saw_end.expect("tool end"); + assert_eq!(end.id, start_id); + assert_eq!(end.arguments, r#"{"city":"Paris"}"#); +} + +#[test] +#[file_serial(harmony_tiktoken_cache)] +fn semantic_events_precede_same_update_logprobs() { + let tokens = completion_tokens(&[text_message("final", "hello")]); + let events = block_on(collect_events( + HarmonyChatOutputProcessor::new(&ChatRequest::for_test()).unwrap(), + vec![ + decoded_start(), + DecodedTextEvent::TextDelta { + delta: String::new(), + token_ids: tokens, + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { entries: vec![] }], + }), + finished: Some(finished()), + }, + ], + )); + + let block_delta_index = events + .iter() + .position(|event| matches!(event, ChatEvent::BlockDelta { .. })) + .unwrap(); + let logprobs_index = events + .iter() + .position(|event| matches!(event, ChatEvent::LogprobsDelta { .. })) + .unwrap(); + assert!(block_delta_index < logprobs_index); +} + +#[test] +fn rejects_generic_parser_overrides() { + let reasoning_error = + validate_harmony_parser_overrides(&ParserSelection::Auto, &ParserSelection::None) + .unwrap_err(); + assert_eq!( + reasoning_error.to_string(), + "gpt_oss uses native Harmony output parsing; generic reasoning parser override `none` is not supported" + ); + + let tool_error = validate_harmony_parser_overrides( + &ParserSelection::Explicit("json".to_string()), + &ParserSelection::Auto, + ) + .unwrap_err(); + assert_eq!( + tool_error.to_string(), + "gpt_oss uses native Harmony output parsing; generic tool parser override `json` is not supported" + ); +} + +#[test] +#[file_serial(harmony_tiktoken_cache)] +fn allows_auto_auto_only() { + validate_harmony_parser_overrides(&ParserSelection::Auto, &ParserSelection::Auto).unwrap(); + let _ = HarmonyChatOutputProcessor::new(&ChatRequest::for_test()).unwrap(); +} diff --git a/rust/src/chat/src/output/mod.rs b/rust/src/chat/src/output/mod.rs new file mode 100644 index 00000000000..81ec124fbcf --- /dev/null +++ b/rust/src/chat/src/output/mod.rs @@ -0,0 +1,135 @@ +use std::pin::Pin; +use std::sync::Arc; + +use futures::Stream; +use subenum::subenum; +use trait_set::trait_set; +use uuid::Uuid; +use vllm_text::output::{DecodedLogprobs, DecodedPromptLogprobs, DecodedTextEvent}; + +use crate::FinishReason; +use crate::error::Result; +use crate::event::{AssistantBlockKind, ChatEvent}; + +mod default; +mod harmony; +mod structured; + +pub use default::DefaultChatOutputProcessor; +pub use harmony::HarmonyChatOutputProcessor; +pub(crate) use harmony::validate_harmony_parser_overrides; + +/// Internal assistant event before final assembly. +/// +/// - [`ContentEvent`]: subenum after reasoning parsing, carries only text content. +/// - [`AssistantEvent`]: full event after tool parsing, adds tool-call variants. +#[subenum(ContentEvent)] +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum AssistantEvent { + #[subenum(ContentEvent)] + Start { + prompt_token_ids: Arc<[u32]>, + prompt_logprobs: Option, + }, + #[subenum(ContentEvent)] + TextDelta { + kind: AssistantBlockKind, + delta: String, + }, + /// Per-decoded-update sample metadata: logprobs and/or output token IDs. + #[subenum(ContentEvent)] + LogprobsDelta { + logprobs: Option, + token_ids: Vec, + }, + /// The start of a new tool call, with its declared name and generated ID. + ToolCallStart { id: String, name: String }, + /// A delta for the arguments of the currently open tool call. Must follow a + /// `ToolCallStart`. + ToolCallArgumentsDelta { delta: String }, + #[subenum(ContentEvent)] + Done { + prompt_token_count: usize, + output_token_count: usize, + finish_reason: FinishReason, + /// Connector-specific KV transfer parameters for disaggregated serving. + kv_transfer_params: Option, + }, +} + +impl ContentEvent { + /// Convert a [`DecodedTextEvent`] into one or more [`ContentEvent`] values + /// by treating all text as plain (non-reasoning) content. + fn from_decoded_plain_text(event: DecodedTextEvent) -> Vec { + match event { + DecodedTextEvent::Start { + prompt_token_ids, + prompt_logprobs, + } => vec![Self::Start { + prompt_token_ids, + prompt_logprobs, + }], + DecodedTextEvent::TextDelta { + delta, + token_ids, + logprobs, + finished, + } => { + let mut events = Vec::new(); + if !delta.is_empty() { + events.push(Self::TextDelta { + kind: AssistantBlockKind::Text, + delta, + }); + } + if logprobs.is_some() || !token_ids.is_empty() { + events.push(Self::LogprobsDelta { + logprobs, + token_ids, + }); + } + if let Some(finished) = finished { + events.push(Self::Done { + prompt_token_count: finished.prompt_token_count, + output_token_count: finished.output_token_count, + finish_reason: finished.finish_reason, + kv_transfer_params: finished.kv_transfer_params, + }); + } + events + } + } + } +} + +/// Boxed stream of decoded text events coming from [`vllm_text`]. +pub type DynDecodedTextEventStream = Pin> + Send>>; +/// Boxed stream of structured chat events exposed by [`crate::ChatLlm`]. +pub type DynChatEventStream = Pin> + Send>>; + +/// Request-scoped output processor from decoded text events into structured +/// chat events. +pub trait ChatOutputProcessor: Send { + /// Consume decoded text stream and return the structured chat-event stream. + fn process(self: Box, decoded: DynDecodedTextEventStream) -> Result; +} + +/// Trait-object form of [`ChatOutputProcessor`]. +pub type DynChatOutputProcessor = Box; + +trait_set! { + /// Boxed-stream constraint for decoded text updates. + pub(crate) trait DecodedTextEventStream = Stream> + Send + 'static; + /// Boxed-stream constraint for internal assistant events. + pub(crate) trait AssistantEventStream = Stream> + Send + 'static; + /// Boxed-stream constraint for public chat events. + pub(crate) trait ChatEventStream = Stream> + Send + 'static; +} + +/// Generate the northbound tool-call ID using the OpenAI-style `call_` +/// format. +// TODO: support other ID scheme like Kimi-K2's +// `functions.{name}:{global_index}`. +pub(crate) fn generate_tool_call_id() -> String { + format!("call_{}", &Uuid::new_v4().simple().to_string()[..24]) +} diff --git a/rust/src/chat/src/output/structured.rs b/rust/src/chat/src/output/structured.rs new file mode 100644 index 00000000000..ed6e3a5130c --- /dev/null +++ b/rust/src/chat/src/output/structured.rs @@ -0,0 +1,508 @@ +//! Adapts parsed assistant updates into structured chat events. +//! +//! This module remains the final assembly stage in `vllm-chat`. Token-to-text +//! decoding still lives in `decoded.rs`, while reasoning separation and tool +//! parsing are handled earlier by their own adapters. This stage consumes those +//! parsed deltas and assembles higher-level assistant content blocks. + +use asynk_strim_attr::{TryYielder, try_stream}; +use futures::{StreamExt as _, pin_mut}; +use vllm_text::DecodedLogprobs; + +use super::{AssistantEvent, AssistantEventStream}; +use crate::error::Error; +use crate::event::{ + AssistantBlockKind, AssistantContentBlock, AssistantMessage, AssistantToolCall, ChatEvent, +}; +use crate::{FinishReason, Result}; + +/// One currently open assistant text-like block being assembled from streamed +/// deltas. +struct OpenTextBlock { + /// Stable position of this block in the final assistant message. + index: usize, + /// Semantic kind of the block being assembled. + kind: AssistantBlockKind, + /// Accumulated text payload for the block. + text: String, +} + +/// One currently open assistant tool call being assembled from streamed deltas. +struct OpenToolCall { + /// Stable ordinal of this tool call in the assistant tool-call list. + index: usize, + /// Stable tool-call ID exposed northbound. + id: String, + /// Function name. + name: String, + /// Incremental JSON arguments accumulated so far. + arguments: String, +} + +/// Per-stream block assembly state. +/// +/// The adapter maintains at most one open text block and one open tool call, +/// and appends deltas to them until the semantic kind changes or the stream +/// terminates. +struct StructuredEventState { + /// Final assistant message assembled so far. + message: AssistantMessage, + /// Currently open text or reasoning block, if any. + open_text_block: Option, + /// Currently open tool call, if any. + open_tool_call: Option, + /// Next OpenAI-compatible tool-call ordinal. + next_tool_call_index: usize, +} + +impl StructuredEventState { + /// Create one fresh assembly state for a new streamed response. + fn new() -> Self { + Self { + message: AssistantMessage::default(), + open_text_block: None, + open_tool_call: None, + next_tool_call_index: 0, + } + } + + /// Convert one parsed text delta into zero or more structured chat events. + fn process_text_delta( + &mut self, + kind: AssistantBlockKind, + delta: String, + ) -> Result> { + let mut events = Vec::new(); + self.close_open_tool_call(&mut events); + self.push_text_delta(kind, delta, &mut events); + Ok(events) + } + + /// Forward per-update sample metadata without attaching it to text blocks. + fn process_logprobs_delta( + &mut self, + logprobs: Option, + token_ids: Vec, + ) -> Result> { + Ok(vec![ChatEvent::LogprobsDelta { + logprobs, + token_ids, + }]) + } + + /// Start one new tool call, closing any incompatible open block first. + fn start_tool_call(&mut self, id: String, name: String) -> Result> { + let mut events = Vec::new(); + self.close_open_text_block(&mut events); + self.close_open_tool_call(&mut events); + + let index = self.next_tool_call_index; + self.next_tool_call_index += 1; + self.open_tool_call = Some(OpenToolCall { + index, + id: id.clone(), + name: name.clone(), + arguments: String::new(), + }); + events.push(ChatEvent::ToolCallStart { index, id, name }); + Ok(events) + } + + /// Append one incremental tool-call arguments delta. + fn push_tool_call_arguments(&mut self, delta: String) -> Result> { + let mut events = Vec::new(); + let Some(open_tool_call) = self.open_tool_call.as_mut() else { + return Err(Error::ToolCallStreamInvariant { + message: "received tool-call arguments delta without an open tool call".to_string(), + }); + }; + open_tool_call.arguments.push_str(&delta); + events.push(ChatEvent::ToolCallArgumentsDelta { + index: open_tool_call.index, + delta, + }); + Ok(events) + } + + /// Close any open block and emit the terminal `Done` event. + fn finish( + &mut self, + prompt_token_count: usize, + output_token_count: usize, + finish_reason: FinishReason, + kv_transfer_params: Option, + ) -> Result> { + let mut events = Vec::new(); + self.close_open_text_block(&mut events); + self.close_open_tool_call(&mut events); + events.push(ChatEvent::Done { + message: self.message.clone(), + prompt_token_count, + output_token_count, + finish_reason, + kv_transfer_params, + }); + Ok(events) + } + + /// Append one semantic text delta to the current block, or open a new block + /// when the semantic kind changes. + fn push_text_delta( + &mut self, + kind: AssistantBlockKind, + delta: String, + events: &mut Vec, + ) { + if delta.is_empty() { + return; + } + + match self.open_text_block.as_mut() { + // If there's a currently open block of the same kind, append to it. + Some(open_block) if open_block.kind == kind => { + open_block.text.push_str(&delta); + events.push(ChatEvent::BlockDelta { + index: open_block.index, + kind, + delta, + }); + } + // Otherwise, close the currently open block (if any) and start a + // new one. + _ => { + self.close_open_text_block(events); + let index = self.message.content.len(); + self.open_text_block = Some(OpenTextBlock { + index, + kind, + text: delta.clone(), + }); + events.push(ChatEvent::BlockStart { index, kind }); + events.push(ChatEvent::BlockDelta { index, kind, delta }); + } + } + } + + /// Finalize the currently open text block, if present. + fn close_open_text_block(&mut self, events: &mut Vec) { + let Some(open_block) = self.open_text_block.take() else { + return; + }; + + let block = match open_block.kind { + AssistantBlockKind::Text => AssistantContentBlock::Text { + text: open_block.text, + }, + AssistantBlockKind::Reasoning => AssistantContentBlock::Reasoning { + text: open_block.text, + }, + AssistantBlockKind::ToolCall => { + unreachable!("tool calls must not be assembled as text blocks") + } + }; + self.message.push_block(block.clone()); + events.push(ChatEvent::BlockEnd { + index: open_block.index, + block, + }); + } + + /// Finalize the currently open tool call, if present. + fn close_open_tool_call(&mut self, events: &mut Vec) { + let Some(open_tool_call) = self.open_tool_call.take() else { + return; + }; + + let call = AssistantToolCall { + id: open_tool_call.id, + name: open_tool_call.name, + arguments: open_tool_call.arguments, + }; + self.message.push_block(AssistantContentBlock::ToolCall(call.clone())); + events.push(ChatEvent::ToolCallEnd { + index: open_tool_call.index, + call, + }); + } +} + +/// Wrap one parsed assistant stream into the public structured chat event +/// stream. +#[try_stream] +pub(crate) async fn structured_chat_event_stream( + stream: impl AssistantEventStream, + mut y: TryYielder, +) -> Result<()> { + pin_mut!(stream); + + let mut state = StructuredEventState::new(); + + while let Some(event) = stream.next().await.transpose()? { + match event { + AssistantEvent::Start { + prompt_token_ids, + prompt_logprobs, + } => { + y.yield_ok(ChatEvent::Start { + prompt_token_ids, + prompt_logprobs, + }) + .await; + } + AssistantEvent::TextDelta { kind, delta } => { + for next in state.process_text_delta(kind, delta)? { + y.yield_ok(next).await; + } + } + AssistantEvent::LogprobsDelta { + logprobs, + token_ids, + } => { + for next in state.process_logprobs_delta(logprobs, token_ids)? { + y.yield_ok(next).await; + } + } + AssistantEvent::ToolCallStart { id, name } => { + for next in state.start_tool_call(id, name)? { + y.yield_ok(next).await; + } + } + AssistantEvent::ToolCallArgumentsDelta { delta } => { + for next in state.push_tool_call_arguments(delta)? { + y.yield_ok(next).await; + } + } + AssistantEvent::Done { + prompt_token_count, + output_token_count, + finish_reason, + kv_transfer_params, + } => { + for next in state.finish( + prompt_token_count, + output_token_count, + finish_reason, + kv_transfer_params, + )? { + y.yield_ok(next).await; + } + } + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use futures::{StreamExt as _, stream}; + + use super::structured_chat_event_stream; + use crate::FinishReason; + use crate::error::Error; + use crate::event::{AssistantBlockKind, AssistantMessageExt as _, ChatEvent}; + use crate::output::AssistantEvent; + + #[tokio::test] + async fn structured_stream_closes_tool_call_on_done() { + let events = stream::iter(vec![ + Ok(AssistantEvent::ToolCallStart { + id: "call_1".to_string(), + name: "get_weather".to_string(), + }), + Ok(AssistantEvent::ToolCallArgumentsDelta { + delta: r#"{"city":"Paris"}"#.to_string(), + }), + Ok(AssistantEvent::Done { + prompt_token_count: 1, + output_token_count: 1, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]); + + let events = structured_chat_event_stream(events) + .collect::>() + .await + .into_iter() + .collect::>>() + .unwrap(); + + assert!(matches!(events[0], ChatEvent::ToolCallStart { .. })); + assert!(matches!( + events[1], + ChatEvent::ToolCallArgumentsDelta { .. } + )); + let ChatEvent::ToolCallEnd { call, .. } = &events[2] else { + panic!("expected tool call end"); + }; + assert_eq!(call.name, "get_weather"); + assert_eq!(call.arguments, r#"{"city":"Paris"}"#); + let ChatEvent::Done { message, .. } = &events[3] else { + panic!("expected done"); + }; + let tool_calls = message.tool_calls().collect::>(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id, "call_1"); + assert_eq!(tool_calls[0].arguments, r#"{"city":"Paris"}"#); + } + + #[tokio::test] + async fn structured_stream_closes_previous_tool_call_on_next_start() { + let events = stream::iter(vec![ + Ok(AssistantEvent::ToolCallStart { + id: "call_1".to_string(), + name: "first".to_string(), + }), + Ok(AssistantEvent::ToolCallArgumentsDelta { + delta: r#"{"a":1}"#.to_string(), + }), + Ok(AssistantEvent::ToolCallStart { + id: "call_2".to_string(), + name: "second".to_string(), + }), + Ok(AssistantEvent::ToolCallArgumentsDelta { + delta: r#"{"b":2}"#.to_string(), + }), + Ok(AssistantEvent::Done { + prompt_token_count: 1, + output_token_count: 1, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]); + + let events = structured_chat_event_stream(events) + .collect::>() + .await + .into_iter() + .collect::>>() + .unwrap(); + + assert!(matches!(events[0], ChatEvent::ToolCallStart { .. })); + assert!(matches!( + events[1], + ChatEvent::ToolCallArgumentsDelta { .. } + )); + let ChatEvent::ToolCallEnd { call, .. } = &events[2] else { + panic!("expected first tool call end"); + }; + assert_eq!(call.name, "first"); + assert!(matches!(events[3], ChatEvent::ToolCallStart { .. })); + let ChatEvent::Done { message, .. } = &events[6] else { + panic!("expected done"); + }; + let tool_calls = message.tool_calls().collect::>(); + assert_eq!(tool_calls.len(), 2); + assert_eq!(tool_calls[0].name, "first"); + assert_eq!(tool_calls[1].name, "second"); + } + + #[tokio::test] + async fn structured_stream_numbers_tool_calls_independent_of_text_blocks() { + let events = stream::iter(vec![ + Ok(AssistantEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "before".to_string(), + }), + Ok(AssistantEvent::ToolCallStart { + id: "call_1".to_string(), + name: "get_weather".to_string(), + }), + Ok(AssistantEvent::ToolCallArgumentsDelta { + delta: r#"{"city":"Paris"}"#.to_string(), + }), + Ok(AssistantEvent::Done { + prompt_token_count: 1, + output_token_count: 1, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]); + + let events = structured_chat_event_stream(events) + .collect::>() + .await + .into_iter() + .collect::>>() + .unwrap(); + + assert!(matches!( + events[0], + ChatEvent::BlockStart { + index: 0, + kind: AssistantBlockKind::Text, + } + )); + assert!(matches!(events[2], ChatEvent::BlockEnd { index: 0, .. })); + assert!(matches!( + events[3], + ChatEvent::ToolCallStart { index: 0, .. } + )); + assert!(matches!( + events[4], + ChatEvent::ToolCallArgumentsDelta { index: 0, .. } + )); + assert!(matches!(events[5], ChatEvent::ToolCallEnd { index: 0, .. })); + } + + #[tokio::test] + async fn structured_stream_closes_tool_call_before_text() { + let events = stream::iter(vec![ + Ok(AssistantEvent::ToolCallStart { + id: "call_1".to_string(), + name: "get_weather".to_string(), + }), + Ok(AssistantEvent::ToolCallArgumentsDelta { + delta: r#"{"city":"Paris"}"#.to_string(), + }), + Ok(AssistantEvent::TextDelta { + kind: AssistantBlockKind::Text, + delta: "done".to_string(), + }), + Ok(AssistantEvent::Done { + prompt_token_count: 1, + output_token_count: 1, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]); + + let events = structured_chat_event_stream(events) + .collect::>() + .await + .into_iter() + .collect::>>() + .unwrap(); + + assert!(matches!(events[2], ChatEvent::ToolCallEnd { .. })); + assert!(matches!( + events[3], + ChatEvent::BlockStart { + kind: AssistantBlockKind::Text, + .. + } + )); + let ChatEvent::Done { message, .. } = &events[6] else { + panic!("expected done"); + }; + assert_eq!(message.text(), "done"); + assert_eq!(message.tool_calls().count(), 1); + } + + #[tokio::test] + async fn structured_stream_rejects_arguments_without_open_tool_call() { + let events = stream::iter(vec![Ok(AssistantEvent::ToolCallArgumentsDelta { + delta: "{}".to_string(), + })]); + + let err = structured_chat_event_stream(events) + .collect::>() + .await + .into_iter() + .next() + .expect("expected one event") + .expect_err("expected invariant error"); + + assert!(matches!(err, Error::ToolCallStreamInvariant { .. })); + } +} diff --git a/rust/src/chat/src/parser/mod.rs b/rust/src/chat/src/parser/mod.rs new file mode 100644 index 00000000000..52e83b3e047 --- /dev/null +++ b/rust/src/chat/src/parser/mod.rs @@ -0,0 +1,107 @@ +pub mod reasoning; +pub mod tool; + +use std::collections::HashMap; +use std::convert::Infallible; +use std::fmt; +use std::str::FromStr; + +use serde_with::DeserializeFromStr; + +/// Specify which reasoning or tool-call parser implementation to use. +#[derive(Debug, Clone, PartialEq, Eq, Default, DeserializeFromStr)] +pub enum ParserSelection { + /// Use model-based auto-detection. + #[default] + Auto, + /// Disable the parser entirely. + None, + /// Force one specific parser implementation by name. + Explicit(String), +} + +impl ParserSelection { + pub const AUTO_LITERAL: &str = "auto"; + pub const NONE_LITERAL: &str = "none"; +} + +impl FromStr for ParserSelection { + type Err = Infallible; + + fn from_str(value: &str) -> Result { + Ok(if value.eq_ignore_ascii_case(Self::AUTO_LITERAL) { + Self::Auto + } else if value.eq_ignore_ascii_case(Self::NONE_LITERAL) { + Self::None + } else { + Self::Explicit(value.to_owned()) + }) + } +} + +impl fmt::Display for ParserSelection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Auto => f.write_str(Self::AUTO_LITERAL), + Self::None => f.write_str(Self::NONE_LITERAL), + Self::Explicit(name) => f.write_str(name), + } + } +} + +/// Registry and model matcher for reasoning and tool parsers. +#[derive(Clone)] +pub struct ParserFactory { + creators: HashMap, + patterns: Vec<(String, String)>, +} + +impl Default for ParserFactory { + fn default() -> Self { + Self { + creators: HashMap::new(), + patterns: Vec::new(), + } + } +} + +impl ParserFactory { + /// Register a creator for a parser by an exact name. + pub fn register_creator(&mut self, name: &str, creator: C) -> &mut Self { + self.creators.insert(name.to_string(), creator); + self + } + + /// Add a case-insensitive substring match from model ID to parser name. + pub fn register_pattern(&mut self, pattern: &str, parser_name: &str) -> &mut Self { + self.patterns.push((pattern.to_lowercase(), parser_name.to_string())); + self + } + + /// Return the first registered parser name matching the given model ID. + pub fn resolve_name_for_model(&self, model_id: &str) -> Option<&str> { + let model_lower = model_id.to_lowercase(); + self.patterns + .iter() + .find(|(pattern, _)| model_lower.contains(pattern)) + .map(|(_, parser_name)| parser_name.as_str()) + } + + /// Return true if the exact parser name is registered. + pub fn contains(&self, name: &str) -> bool { + self.creators.contains_key(name) + } + + /// Return all registered parser names sorted for stable display. + pub fn list(&self) -> Vec { + let mut names: Vec<_> = self.creators.keys().cloned().collect(); + names.sort_unstable(); + names + } + + /// Get the constructor for a parser by its exact registered name, or return + /// None if not found. + pub fn creator(&self, name: &str) -> Option<&C> { + self.creators.get(name) + } +} diff --git a/rust/src/chat/src/parser/reasoning/mod.rs b/rust/src/chat/src/parser/reasoning/mod.rs new file mode 100644 index 00000000000..09111d7252f --- /dev/null +++ b/rust/src/chat/src/parser/reasoning/mod.rs @@ -0,0 +1,120 @@ +//! Reasoning parser registration and selection boundary for `vllm-chat`. + +use std::sync::LazyLock; + +pub use vllm_reasoning_parser::{ + CohereCmdReasoningParser, DeepSeekR1ReasoningParser, DeepSeekV3ReasoningParser, + DeepSeekV4ReasoningParser, Gemma4ReasoningParser, Glm45ReasoningParser, KimiK2ReasoningParser, + KimiReasoningParser, MiniMaxM2ReasoningParser, NemotronV3ReasoningParser, Qwen3ReasoningParser, + ReasoningDelta, ReasoningError, ReasoningParser, Step3ReasoningParser, +}; +use vllm_tokenizer::DynTokenizer; + +use crate::parser::ParserFactory; + +/// Canonical public names for registered reasoning parsers. +pub mod names { + pub const COHERE_CMD: &str = "cohere_cmd"; + pub const DEEPSEEK_R1: &str = "deepseek_r1"; + pub const DEEPSEEK_V3: &str = "deepseek_v3"; + pub const DEEPSEEK_V4: &str = "deepseek_v4"; + pub const GEMMA4: &str = "gemma4"; + pub const GLM45: &str = "glm45"; + pub const KIMI: &str = "kimi"; + pub const KIMI_K2: &str = "kimi_k2"; + pub const MINIMAX_M2: &str = "minimax_m2"; + pub const NEMOTRON_V3: &str = "nemotron_v3"; + pub const QWEN3: &str = "qwen3"; + pub const STEP3: &str = "step3"; +} + +/// Constructor signature for one registered reasoning parser implementation. +type ReasoningParserCreator = + fn(DynTokenizer) -> vllm_reasoning_parser::Result>; + +/// Registry and model matcher for reasoning parsers. +pub type ReasoningParserFactory = ParserFactory; + +impl ReasoningParserFactory { + /// Get the global reasoning parser factory with built-in registrations and + /// model mappings. + pub fn global() -> &'static Self { + static INSTANCE: LazyLock = + LazyLock::new(ReasoningParserFactory::new); + &INSTANCE + } + + /// Create the default registry with built-in parser names and model + /// mappings. + pub fn new() -> Self { + let mut factory = Self::default(); + + factory + .register_parser::(names::COHERE_CMD) + .register_parser::(names::DEEPSEEK_R1) + .register_parser::(names::DEEPSEEK_V3) + .register_parser::(names::DEEPSEEK_V4) + .register_parser::(names::GEMMA4) + .register_parser::(names::GLM45) + .register_parser::(names::KIMI) + .register_parser::(names::KIMI_K2) + .register_parser::(names::MINIMAX_M2) + .register_parser::(names::NEMOTRON_V3) + .register_parser::(names::QWEN3) + .register_parser::(names::STEP3); + + factory + .register_pattern("deepseek-r1", names::DEEPSEEK_R1) + .register_pattern("deepseek-v4", names::DEEPSEEK_V4) + .register_pattern("deepseek_v4", names::DEEPSEEK_V4) + .register_pattern("deepseek-v3", names::DEEPSEEK_V3) + .register_pattern("gemma-4", names::GEMMA4) + .register_pattern("gemma4", names::GEMMA4) + .register_pattern("qwen", names::QWEN3) + .register_pattern("glm-5", names::GLM45) + .register_pattern("glm-4.7", names::GLM45) + .register_pattern("glm-4.6", names::GLM45) + .register_pattern("glm-4.5", names::GLM45) + .register_pattern("kimi-k2", names::KIMI_K2) + .register_pattern("kimi", names::KIMI) + .register_pattern("step3", names::STEP3) + .register_pattern("minimax", names::MINIMAX_M2) + .register_pattern("mm-m2", names::MINIMAX_M2) + .register_pattern("cohere", names::COHERE_CMD) + .register_pattern("command", names::COHERE_CMD) + .register_pattern("nano", names::NEMOTRON_V3) + .register_pattern("nemotron", names::NEMOTRON_V3); + + factory + } + + /// Register one parser type that exposes a static `create()` constructor. + pub fn register_parser(&mut self, name: &str) -> &mut Self + where + T: ReasoningParser + 'static, + { + self.register_creator(name, T::create) + } + + /// Construct a parser from an exact name. + pub fn create( + &self, + name: &str, + tokenizer: DynTokenizer, + ) -> crate::Result> { + let creator = self.creator(name).ok_or_else(|| crate::Error::ParserUnavailableByName { + kind: "reasoning", + name: name.to_string(), + available_names: self.list(), + })?; + + creator(tokenizer).map_err(|error| crate::Error::ParserInitialization { + kind: "reasoning", + name: name.to_string(), + error: error.into(), + }) + } +} + +#[cfg(test)] +mod tests; diff --git a/rust/src/chat/src/parser/reasoning/tests.rs b/rust/src/chat/src/parser/reasoning/tests.rs new file mode 100644 index 00000000000..89b5f8e2308 --- /dev/null +++ b/rust/src/chat/src/parser/reasoning/tests.rs @@ -0,0 +1,61 @@ +use std::sync::Arc; + +use vllm_tokenizer::Tokenizer; + +use super::{ReasoningParserFactory, names}; + +struct FakeTokenizer; + +impl Tokenizer for FakeTokenizer { + fn encode(&self, text: &str, _add_special_tokens: bool) -> vllm_tokenizer::Result> { + Ok(text.chars().map(u32::from).collect()) + } + + fn decode( + &self, + token_ids: &[u32], + _skip_special_tokens: bool, + ) -> vllm_tokenizer::Result { + Ok(token_ids + .iter() + .map(|token_id| char::from_u32(*token_id).unwrap_or('\u{FFFD}')) + .collect()) + } + + fn token_to_id(&self, _token: &str) -> Option { + None + } +} + +#[test] +fn factory_contains_and_lists_registered_parsers() { + let factory = ReasoningParserFactory::new(); + assert!(factory.contains(names::QWEN3)); + assert!(factory.contains(names::DEEPSEEK_V4)); + assert!(factory.list().contains(&names::QWEN3.to_string())); + assert!(factory.list().contains(&names::DEEPSEEK_V4.to_string())); +} + +#[test] +fn factory_resolves_deepseek_v4_to_qwen3_alias() { + let factory = ReasoningParserFactory::new(); + assert_eq!( + factory.resolve_name_for_model("deepseek-ai/DeepSeek-V4"), + Some(names::DEEPSEEK_V4) + ); + assert_eq!( + factory.resolve_name_for_model("deepseek_v4"), + Some(names::DEEPSEEK_V4) + ); +} + +#[test] +fn factory_rejects_unknown_parser_names() { + let tokenizer = Arc::new(FakeTokenizer); + let factory = ReasoningParserFactory::new(); + let error = match factory.create("missing", tokenizer) { + Ok(_) => panic!("expected parser lookup to fail"), + Err(error) => error, + }; + assert!(error.to_string().contains("choose from")); +} diff --git a/rust/src/chat/src/parser/tool/mod.rs b/rust/src/chat/src/parser/tool/mod.rs new file mode 100644 index 00000000000..fd1bbedd822 --- /dev/null +++ b/rust/src/chat/src/parser/tool/mod.rs @@ -0,0 +1,140 @@ +//! Tool parser registration and selection boundary for `vllm-chat`. + +use std::sync::LazyLock; + +pub use vllm_tool_parser::{ + DeepSeekV3ToolParser, DeepSeekV4ToolParser, DeepSeekV31ToolParser, DeepSeekV32ToolParser, + Gemma4ToolParser, Glm45MoeToolParser, Glm47MoeToolParser, HermesToolParser, KimiK2ToolParser, + Llama3JsonToolParser, MinimaxM2ToolParser, MistralToolParser, Qwen3CoderToolParser, + Qwen3XmlToolParser, ToolCallDelta, ToolParseResult, ToolParser, ToolParserError, +}; + +use crate::parser::ParserFactory; +use crate::request::ChatTool; + +/// Canonical public names for registered tool parsers. +pub mod names { + pub const DEEPSEEK_V3: &str = "deepseek_v3"; + pub const DEEPSEEK_V31: &str = "deepseek_v31"; + pub const DEEPSEEK_V32: &str = "deepseek_v32"; + pub const DEEPSEEK_V4: &str = "deepseek_v4"; + pub const GLM45: &str = "glm45"; + pub const GLM47: &str = "glm47"; + pub const GEMMA4: &str = "gemma4"; + pub const HERMES: &str = "hermes"; + pub const KIMI_K2: &str = "kimi_k2"; + pub const LLAMA3_JSON: &str = "llama3_json"; + pub const LLAMA4_JSON: &str = "llama4_json"; + pub const MINIMAX_M2: &str = "minimax_m2"; + pub const MISTRAL: &str = "mistral"; + pub const QWEN3_CODER: &str = "qwen3_coder"; + pub const QWEN3_XML: &str = "qwen3_xml"; +} + +/// Constructor signature for one registered tool parser implementation. +type ToolParserCreator = fn(&[ChatTool]) -> vllm_tool_parser::Result>; + +/// Registry and model matcher for tool parsers. +pub type ToolParserFactory = ParserFactory; + +impl ToolParserFactory { + /// Get the global tool parser factory with built-in registrations and model + /// mappings. + pub fn global() -> &'static Self { + static INSTANCE: LazyLock = LazyLock::new(ToolParserFactory::new); + &INSTANCE + } + + /// Create the default registry with built-in parser names and model + /// mappings. + pub fn new() -> Self { + let mut factory = Self::default(); + + factory + .register_parser::(names::DEEPSEEK_V3) + .register_parser::(names::DEEPSEEK_V31) + .register_parser::(names::DEEPSEEK_V32) + .register_parser::(names::DEEPSEEK_V4) + .register_parser::(names::GLM45) + .register_parser::(names::GLM47) + .register_parser::(names::GEMMA4) + .register_parser::(names::HERMES) + .register_parser::(names::KIMI_K2) + .register_parser::(names::LLAMA3_JSON) + .register_parser::(names::LLAMA4_JSON) + .register_parser::(names::MINIMAX_M2) + .register_parser::(names::MISTRAL) + .register_parser::(names::QWEN3_XML) + .register_parser::(names::QWEN3_CODER); + + factory + .register_pattern("mistral-", names::MISTRAL) + .register_pattern("mixtral-", names::MISTRAL) + .register_pattern("qwen3-coder", names::QWEN3_CODER) + .register_pattern("qwen2.5-coder", names::QWEN3_CODER) + .register_pattern("qwen3.5", names::QWEN3_CODER) + .register_pattern("qwen", names::QWEN3_XML) + .register_pattern("hermes", names::HERMES) + .register_pattern("llama-4", names::LLAMA4_JSON) + .register_pattern("llama-3.2", names::LLAMA3_JSON) + .register_pattern("llama-3.1", names::LLAMA3_JSON) + .register_pattern("deepseek-r1", names::DEEPSEEK_V3) + .register_pattern("deepseek-v4", names::DEEPSEEK_V4) + .register_pattern("deepseek_v4", names::DEEPSEEK_V4) + .register_pattern("deepseek-v3.2", names::DEEPSEEK_V32) + .register_pattern("deepseek-v3.1", names::DEEPSEEK_V31) + .register_pattern("deepseek-v3", names::DEEPSEEK_V3) + .register_pattern("glm-5", names::GLM47) + .register_pattern("glm-4.7", names::GLM47) + .register_pattern("glm-4.6", names::GLM45) + .register_pattern("glm-4.5", names::GLM45) + .register_pattern("gemma4", names::GEMMA4) + .register_pattern("gemma-4", names::GEMMA4) + .register_pattern("kimi-k2", names::KIMI_K2) + .register_pattern("minimax", names::MINIMAX_M2) + .register_pattern("mm-m2", names::MINIMAX_M2); + + factory + } + + /// Register one parser type that exposes a static `create()` constructor. + pub fn register_parser(&mut self, name: &str) -> &mut Self + where + T: ToolParser + 'static, + { + self.register_creator(name, T::create) + } + + /// Construct a parser from an exact name. + pub fn create(&self, name: &str, tools: &[ChatTool]) -> crate::Result> { + let creator = self.creator(name).ok_or_else(|| crate::Error::ParserUnavailableByName { + kind: "tool", + name: name.to_string(), + available_names: self.list(), + })?; + + creator(tools).map_err(|error| crate::Error::ParserInitialization { + kind: "tool", + name: name.to_string(), + error: error.into(), + }) + } + + /// Resolve a parser from model ID and then construct it. + pub fn create_for_model( + &self, + model_id: &str, + tools: &[ChatTool], + ) -> crate::Result> { + let name = self.resolve_name_for_model(model_id).ok_or_else(|| { + crate::Error::ParserUnavailableForModel { + kind: "tool", + model_id: model_id.to_string(), + } + })?; + self.create(name, tools) + } +} + +#[cfg(test)] +mod tests; diff --git a/rust/src/chat/src/parser/tool/tests.rs b/rust/src/chat/src/parser/tool/tests.rs new file mode 100644 index 00000000000..fb2230faeb7 --- /dev/null +++ b/rust/src/chat/src/parser/tool/tests.rs @@ -0,0 +1,152 @@ +use vllm_tool_parser::Result; + +use super::{ToolParseResult, ToolParser, ToolParserFactory, names}; +use crate::Error; +use crate::request::ChatTool; + +struct FakeToolParser; + +impl ToolParser for FakeToolParser { + fn create(_tools: &[ChatTool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self)) + } + + fn preserve_special_tokens(&self) -> bool { + true + } + + fn push(&mut self, _chunk: &str) -> Result { + Ok(ToolParseResult::default()) + } +} + +#[test] +fn default_factory_starts_empty() { + let factory = ToolParserFactory::default(); + assert!(factory.list().is_empty()); +} + +#[test] +fn factory_contains_and_creates_registered_parsers() { + let mut factory = ToolParserFactory::default(); + factory.register_parser::("fake"); + + assert!(factory.contains("fake")); + assert!(factory.list().contains(&"fake".to_string())); + factory.create("fake", &[]).unwrap(); +} + +#[test] +fn factory_rejects_unknown_parser_names() { + let factory = ToolParserFactory::default(); + let error = match factory.create("missing", &[]) { + Ok(_) => panic!("expected parser lookup to fail"), + Err(error) => error, + }; + assert!(matches!(error, Error::ParserUnavailableByName { .. })); +} + +#[test] +fn factory_rejects_unknown_models() { + let factory = ToolParserFactory::default(); + let error = match factory.create_for_model("definitely-unknown-model", &[]) { + Ok(_) => panic!("expected model lookup to fail"), + Err(error) => error, + }; + assert!(matches!(error, Error::ParserUnavailableForModel { .. })); +} + +#[test] +fn factory_creates_registered_parser_for_model() { + let mut factory = ToolParserFactory::default(); + factory + .register_parser::("fake") + .register_pattern("fake-model", "fake"); + + factory.create_for_model("my-fake-model-v1", &[]).unwrap(); +} + +#[test] +fn factory_new_resolves_default_patterns() { + let factory = ToolParserFactory::new(); + + assert_eq!( + factory.resolve_name_for_model("Qwen/Qwen3.5-0.8B"), + Some(names::QWEN3_CODER) + ); + assert_eq!( + factory.resolve_name_for_model("Qwen/Qwen3-0.6B"), + Some(names::QWEN3_XML) + ); + assert_eq!( + factory.resolve_name_for_model("Qwen/Qwen3-Coder-30B"), + Some(names::QWEN3_CODER) + ); + assert_eq!( + factory.resolve_name_for_model("meta-llama-4-maverick"), + Some(names::LLAMA4_JSON) + ); + assert_eq!( + factory.resolve_name_for_model("meta-llama-3.2-3b-instruct"), + Some(names::LLAMA3_JSON) + ); + assert_eq!( + factory.resolve_name_for_model("meta-llama/Llama-3.1-8B-Instruct"), + Some(names::LLAMA3_JSON) + ); + assert_eq!( + factory.resolve_name_for_model("deepseek-ai/DeepSeek-V4"), + Some(names::DEEPSEEK_V4) + ); + assert_eq!( + factory.resolve_name_for_model("deepseek-ai/DeepSeek-V3.2-Exp"), + Some(names::DEEPSEEK_V32) + ); + assert_eq!( + factory.resolve_name_for_model("deepseek-ai/DeepSeek-V4-Chat"), + Some(names::DEEPSEEK_V4) + ); + assert_eq!( + factory.resolve_name_for_model("deepseek_v4"), + Some(names::DEEPSEEK_V4) + ); + assert_eq!( + factory.resolve_name_for_model("deepseek-ai/DeepSeek-R1-0528"), + Some(names::DEEPSEEK_V3) + ); + assert_eq!( + factory.resolve_name_for_model("deepseek-ai/DeepSeek-V3.1"), + Some(names::DEEPSEEK_V31) + ); + assert_eq!( + factory.resolve_name_for_model("zai-org/GLM-5-32B-Chat"), + Some(names::GLM47) + ); + assert_eq!( + factory.resolve_name_for_model("zai-org/GLM-5.1-32B-Instruct"), + Some(names::GLM47) + ); + assert_eq!( + factory.resolve_name_for_model("glm-4.7"), + Some(names::GLM47) + ); + assert_eq!( + factory.resolve_name_for_model("google/gemma-4-27b-it"), + Some(names::GEMMA4) + ); + assert_eq!( + factory.resolve_name_for_model("NousResearch/Hermes-3-Llama-3.1-8B"), + Some(names::HERMES) + ); + assert_eq!( + factory.resolve_name_for_model("MiniMax/MiniMax-M2-01"), + Some(names::MINIMAX_M2) + ); + assert_eq!( + factory.resolve_name_for_model("org/mm-m2-base"), + Some(names::MINIMAX_M2) + ); +} diff --git a/rust/src/chat/src/renderer/deepseek_v32/encoding.rs b/rust/src/chat/src/renderer/deepseek_v32/encoding.rs new file mode 100644 index 00000000000..97825519276 --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v32/encoding.rs @@ -0,0 +1,555 @@ +//! DeepSeek V3.2 prompt renderer. + +use std::collections::{HashMap, HashSet}; +use std::fmt::Write as _; + +use serde::Serialize; +use serde_json::Value; +use serde_json_fmt::JsonFormat; + +use crate::error::{Error, Result}; +use crate::request::{ChatContent, ChatMessage, ChatRequest, ChatRole, ChatTool}; +use crate::{AssistantContentBlock, AssistantMessageExt, AssistantToolCall}; + +const BOS_TOKEN: &str = "<|begin▁of▁sentence|>"; +const EOS_TOKEN: &str = "<|end▁of▁sentence|>"; +const THINKING_START_TOKEN: &str = ""; +const THINKING_END_TOKEN: &str = ""; +const DSML_TOKEN: &str = "|DSML|"; + +/// DeepSeek uses `"chat"` vs `"thinking"` mode names. Keep the split explicit +/// here so the render branches stay easy to read. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ThinkingMode { + Chat, + Thinking, +} + +/// Tool schema shape rendered inside the `` block. +#[serde_with::skip_serializing_none] +#[derive(Debug, Serialize)] +struct RenderedToolSchema<'a> { + name: &'a str, + description: Option<&'a str>, + parameters: &'a Value, + strict: Option, +} + +/// Render one chat request into the final prompt string. +pub(super) fn render_request(request: &ChatRequest) -> Result { + let thinking_mode = match request.enable_thinking()?.unwrap_or(false) { + true => ThinkingMode::Thinking, + false => ThinkingMode::Chat, + }; + let drop_thinking = matches!( + request.messages.last().map(ChatMessage::role), + Some(ChatRole::User | ChatRole::Developer) + ); + let render_offset = isize::from(request.tool_parsing_enabled()); + let last_user_render_index = + find_last_user_render_index(request.messages.as_slice(), render_offset); + let last_user_actual_index = find_last_user_actual_index(request.messages.as_slice()); + let mut prompt = String::from(BOS_TOKEN); + + if request.tool_parsing_enabled() { + render_system_message(&mut prompt, None, &request.tools)?; + } + + for (message_index, message) in request.messages.iter().enumerate() { + render_message( + &mut prompt, + request.messages.as_slice(), + message_index, + message, + render_offset, + last_user_render_index, + last_user_actual_index, + thinking_mode, + drop_thinking, + )?; + } + + Ok(prompt) +} + +/// Find the last user-like turn in render order. +/// +/// `render_offset` is `1` when a synthetic tool-only system turn is rendered +/// before the real request messages, and `0` otherwise. +fn find_last_user_render_index(messages: &[ChatMessage], render_offset: isize) -> isize { + messages + .iter() + .rposition(|message| matches!(message.role(), ChatRole::User | ChatRole::Developer)) + .map(|index| index as isize + render_offset) + .unwrap_or(-1) +} + +/// Render one real request message, using `render_offset` to account for any +/// synthetic tool-only system turn that was already emitted before the loop. +fn render_message( + out: &mut String, + messages: &[ChatMessage], + message_index: usize, + message: &ChatMessage, + render_offset: isize, + last_user_render_index: isize, + last_user_actual_index: usize, + thinking_mode: ThinkingMode, + drop_thinking: bool, +) -> Result<()> { + let render_index = message_index as isize + render_offset; + let opens_thinking = render_index == last_user_render_index; + let after_last_user_turn = render_index > last_user_render_index; + let after_or_at_last_user_turn = render_index >= last_user_render_index; + + match message { + ChatMessage::System { content } => render_system_message(out, Some(content), &[]), + ChatMessage::Developer { content, tools } => render_developer_message( + out, + content, + tools.as_deref().unwrap_or(&[]), + thinking_mode == ThinkingMode::Thinking && opens_thinking, + ), + ChatMessage::User { content } => render_user_message( + out, + content, + thinking_mode == ThinkingMode::Thinking && opens_thinking, + ), + ChatMessage::Assistant { content } => render_assistant_message( + out, + thinking_mode == ThinkingMode::Thinking && after_last_user_turn, + content, + should_keep_assistant_reasoning( + message_index, + last_user_actual_index, + thinking_mode, + drop_thinking, + ), + // TODO: Respect `continue_final_message` and map it to DeepSeek's + // prefix-style final-assistant continuation behavior. + false, + ), + ChatMessage::ToolResponse { content, .. } => render_tool_message( + out, + messages, + message_index, + thinking_mode == ThinkingMode::Thinking && after_or_at_last_user_turn, + content, + ), + } +} + +/// Historical assistant reasoning is dropped in thinking mode when the final +/// request turn is a new user-like message. +fn should_keep_assistant_reasoning( + actual_index: usize, + last_user_actual_index: usize, + thinking_mode: ThinkingMode, + drop_thinking: bool, +) -> bool { + !(thinking_mode == ThinkingMode::Thinking + && drop_thinking + && actual_index < last_user_actual_index) +} + +/// Return the last user/developer turn in the real request message list. +fn find_last_user_actual_index(messages: &[ChatMessage]) -> usize { + messages + .iter() + .rposition(|message| matches!(message.role(), ChatRole::User | ChatRole::Developer)) + .unwrap_or(usize::MAX) +} + +/// Render a system turn, optionally followed by the tool preamble. +fn render_system_message( + out: &mut String, + content: Option<&ChatContent>, + tools: &[ChatTool], +) -> Result<()> { + if let Some(content) = content { + write_chat_content(out, content)?; + } + if !tools.is_empty() { + out.push_str("\n\n"); + render_tools(out, tools)?; + } + Ok(()) +} + +/// Developer messages are wrapped into the same user-like turn shape as real +/// user messages, but can also carry message-local tools. +fn render_developer_message( + out: &mut String, + content: &ChatContent, + tools: &[ChatTool], + opens_thinking: bool, +) -> Result<()> { + if content.is_empty() { + return Err(Error::ChatTemplate( + "invalid DeepSeek V3.2 developer message: empty content".to_string(), + )); + } + + out.push_str("<|User|>"); + if !tools.is_empty() { + out.push_str("\n\n"); + render_tools(out, tools)?; + } + out.push_str("\n\n# The user's message is: "); + write_chat_content(out, content)?; + write_user_like_suffix(out, opens_thinking); + Ok(()) +} + +/// Plain user turns share the same wrapper shape as developer turns without the +/// developer-specific preamble. +fn render_user_message( + out: &mut String, + content: &ChatContent, + opens_thinking: bool, +) -> Result<()> { + out.push_str("<|User|>"); + write_chat_content(out, content)?; + write_user_like_suffix(out, opens_thinking); + Ok(()) +} + +/// Shared trailing wrapper used by both real user turns and native developer +/// turns after their content has already been written. +// TODO: respect `add_generation_prompt` option +fn write_user_like_suffix(out: &mut String, opens_thinking: bool) { + out.push_str("<|Assistant|>"); + if opens_thinking { + out.push_str(THINKING_START_TOKEN); + } else { + out.push_str(THINKING_END_TOKEN); + } +} + +/// Render one tool result turn and decide whether it opens or closes the shared +/// `` block for the preceding assistant tool-call message. +fn render_tool_message( + out: &mut String, + messages: &[ChatMessage], + message_index: usize, + resumes_thinking: bool, + _content: &ChatContent, +) -> Result<()> { + let (block_start, block_end) = tool_response_block_bounds(messages, message_index); + if message_index != block_start { + return Ok(()); + } + + let Some(prev_assistant_idx) = previous_assistant_actual_index(messages, block_start) else { + return Err(Error::ChatTemplate( + "invalid DeepSeek V3.2 tool message: missing previous assistant message".to_string(), + )); + }; + + let ChatMessage::Assistant { + content: assistant_content, + } = &messages[prev_assistant_idx] + else { + return Err(Error::ChatTemplate( + "invalid DeepSeek V3.2 tool message: previous non-tool message is not assistant" + .to_string(), + )); + }; + + let assistant_tool_calls = assistant_content.tool_calls().collect::>(); + if assistant_tool_calls.is_empty() { + return Err(Error::ChatTemplate( + "invalid DeepSeek V3.2 tool message: previous assistant message has no tool calls" + .to_string(), + )); + } + + let mut expected_tool_call_ids = HashSet::with_capacity(assistant_tool_calls.len()); + for tool_call in &assistant_tool_calls { + if !expected_tool_call_ids.insert(tool_call.id.as_str()) { + return Err(Error::ChatTemplate( + "invalid DeepSeek V3.2 assistant tool calls: duplicate tool_call_id".to_string(), + )); + } + } + + let mut tool_results_by_id = HashMap::with_capacity(assistant_tool_calls.len()); + for message in &messages[block_start..block_end] { + let ChatMessage::ToolResponse { + content, + tool_call_id, + } = message + else { + unreachable!("tool response block should only contain tool messages"); + }; + + if !expected_tool_call_ids.contains(tool_call_id.as_str()) { + return Err(Error::ChatTemplate(format!( + "invalid DeepSeek V3.2 tool message: unknown tool_call_id `{tool_call_id}`" + ))); + } + + if tool_results_by_id.insert(tool_call_id.as_str(), content).is_some() { + return Err(Error::ChatTemplate(format!( + "invalid DeepSeek V3.2 tool message: duplicate tool_call_id `{tool_call_id}`" + ))); + } + } + + if tool_results_by_id.len() != assistant_tool_calls.len() { + return Err(Error::ChatTemplate( + "invalid DeepSeek V3.2 tool messages: missing tool result for assistant tool call" + .to_string(), + )); + } + + out.push_str("\n\n"); + for tool_call in assistant_tool_calls { + let content = tool_results_by_id + .get(tool_call.id.as_str()) + .expect("validated tool_call_id set should be complete"); + out.push_str("\n"); + write_chat_content(out, content)?; + out.push_str(""); + } + + out.push_str("\n"); + out.push_str("\n\n"); + if resumes_thinking { + out.push_str(THINKING_START_TOKEN); + } else { + out.push_str(THINKING_END_TOKEN); + } + + Ok(()) +} + +/// Return the contiguous tool-response block containing `actual_index`. +fn tool_response_block_bounds(messages: &[ChatMessage], actual_index: usize) -> (usize, usize) { + let mut block_start = actual_index; + while block_start > 0 && matches!(messages[block_start - 1], ChatMessage::ToolResponse { .. }) { + block_start -= 1; + } + + let mut block_end = actual_index + 1; + while block_end < messages.len() + && matches!(messages[block_end], ChatMessage::ToolResponse { .. }) + { + block_end += 1; + } + + (block_start, block_end) +} + +/// Return the most recent assistant turn before `actual_index`. +fn previous_assistant_actual_index(messages: &[ChatMessage], actual_index: usize) -> Option { + messages[..actual_index] + .iter() + .rposition(|message| matches!(message, ChatMessage::Assistant { .. })) +} + +/// Render one assistant turn, including optional reasoning, DSML tool calls, +/// and the trailing EOS marker. +fn render_assistant_message( + out: &mut String, + after_last_user_turn: bool, + content: &[AssistantContentBlock], + keep_reasoning: bool, + prefix: bool, +) -> Result<()> { + let has_reasoning = keep_reasoning && content.has_reasoning(); + let has_tool_calls = content.has_tool_calls(); + + if !has_tool_calls && prefix { + write_assistant_text(out, content); + return Ok(()); + } + + if after_last_user_turn { + if !has_reasoning && !has_tool_calls { + return Err(Error::ChatTemplate( + "invalid DeepSeek V3.2 assistant message after last user message: expected reasoning or tool calls" + .to_string(), + )); + } + + if has_reasoning { + write_assistant_reasoning(out, content); + } + out.push_str(THINKING_END_TOKEN); + } + + write_assistant_text(out, content); + + if has_tool_calls { + out.push_str("\n\n<|DSML|function_calls>\n"); + for (index, tool_call) in content.tool_calls().enumerate() { + if index > 0 { + out.push('\n'); + } + render_tool_call(out, tool_call)?; + } + out.push_str("\n"); + } + + out.push_str(EOS_TOKEN); + Ok(()) +} + +/// Render one assistant tool call in DSML XML-like format. +fn render_tool_call(out: &mut String, tool_call: &AssistantToolCall) -> Result<()> { + writeln!(out, "<{DSML_TOKEN}invoke name=\"{}\">", tool_call.name) + .expect("writing to String cannot fail"); + encode_arguments_to_dsml(out, tool_call)?; + write!(out, "\n").expect("writing to String cannot fail"); + Ok(()) +} + +/// Convert one assistant tool-call arguments object into DSML parameter form. +/// +/// String values are emitted raw with `string="true"`, while all other JSON +/// values are rendered with JSON syntax and `string="false"`. +fn encode_arguments_to_dsml(out: &mut String, tool_call: &AssistantToolCall) -> Result<()> { + let arguments: Value = serde_json::from_str(&tool_call.arguments).map_err(|error| { + Error::ChatTemplate(format!( + "assistant tool call has invalid JSON arguments for DeepSeek V3.2: {error}" + )) + })?; + let Some(arguments) = arguments.as_object() else { + return Err(Error::ChatTemplate( + "assistant tool call arguments for DeepSeek V3.2 must be a JSON object".to_string(), + )); + }; + + let mut wrote_parameter = false; + for (key, value) in arguments { + if wrote_parameter { + out.push('\n'); + } + + let is_string = matches!(value, Value::String(_)); + write!( + out, + "<{DSML_TOKEN}parameter name=\"{key}\" string=\"{}\">", + if is_string { "true" } else { "false" } + ) + .expect("writing to String cannot fail"); + + match value { + Value::String(value) => out.push_str(value), + value => out.push_str(&json_dumps(value)?), + } + + write!(out, "").expect("writing to String cannot fail"); + wrote_parameter = true; + } + + Ok(()) +} + +/// Render the full tool preamble shown to the model. +fn render_tools(out: &mut String, tools: &[ChatTool]) -> Result<()> { + out.push_str( + r#"## Tools + +You have access to a set of tools you can use to answer the user's question. +You can invoke functions by writing a "<|DSML|function_calls>" block like the following as part of your reply to the user: +<|DSML|function_calls> +<|DSML|invoke name="$FUNCTION_NAME"> +<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<|DSML|invoke name="$FUNCTION_NAME2"> +... + + + +String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects). + +If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example: + +<|DSML|function_calls> +... + + + +... + + +...thinking about results + +Here are the functions available in JSONSchema format: + +"#, + ); + + for (index, tool) in tools.iter().enumerate() { + if index > 0 { + out.push('\n'); + } + render_tool_schema(out, tool)?; + } + + out.push_str("\n\n"); + Ok(()) +} + +/// Serialize one typed tool schema into the JSON shape embedded inside +/// ``. +fn render_tool_schema(out: &mut String, tool: &ChatTool) -> Result<()> { + out.push_str(&json_dumps(&RenderedToolSchema { + name: &tool.name, + description: tool.description.as_deref(), + parameters: &tool.parameters, + strict: tool.strict, + })?); + Ok(()) +} + +/// Write chat content directly into the destination buffer without flattening +/// it into an intermediate `String`. +fn write_chat_content(out: &mut String, content: &ChatContent) -> Result<()> { + match content { + ChatContent::Text(text) => out.push_str(text), + ChatContent::Parts(parts) => { + for part in parts { + out.push_str(part.as_text()?); + } + } + } + Ok(()) +} + +/// Write all reasoning blocks in encounter order. +fn write_assistant_reasoning(out: &mut String, content: &[AssistantContentBlock]) { + for block in content { + if let AssistantContentBlock::Reasoning { text } = block { + out.push_str(text); + } + } +} + +/// Write all visible assistant text blocks in encounter order. +fn write_assistant_text(out: &mut String, content: &[AssistantContentBlock]) { + for block in content { + if let AssistantContentBlock::Text { text } = block { + out.push_str(text); + } + } +} + +/// Compact JSON serialization used by this renderer for exact prompt text. +fn json_dumps(value: &T) -> Result { + JsonFormat::new() + .comma(", ") + .expect("literal comma separator is valid JSON") + .colon(": ") + .expect("literal colon separator is valid JSON") + .ascii(false) + .format_to_string(value) + .map_err(|error| { + Error::ChatTemplate(format!( + "failed to serialize DeepSeek V3.2 JSON payload: {error}" + )) + }) +} diff --git a/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_input.json b/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_input.json new file mode 100644 index 00000000000..0582611470d --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_input.json @@ -0,0 +1,149 @@ +{ + "tools": [ + { + "type": "function", + "function": { + "name": "get_datetime", + "description": "Get the current date and time", + "parameters": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "The timezone, e.g. Asia/Shanghai, UTC" + } + }, + "required": ["timezone"] + } + } + }, + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a specific date and location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name, e.g. Beijing, Hangzhou" + }, + "date": { + "type": "string", + "description": "The date in YYYY-MM-DD format" + } + }, + "required": ["location", "date"] + } + } + } + ], + "messages": [ + { + "role": "system", + "content": "You are a helpful Assistant." + }, + { + "role": "user", + "content": "明天杭州和北京的天气怎么样?" + }, + { + "role": "assistant", + "reasoning_content": "用户询问明天的天气,我需要先获取当前日期来计算明天的日期📅", + "tool_calls": [ + { + "id": "call_xK9mN3pL2qR8vT5wY6hZ1aB4", + "type": "function", + "function": { + "arguments": "{\"timezone\": \"Asia/Shanghai\"}", + "name": "get_datetime" + } + } + ] + }, + { + "tool_call_id": "call_xK9mN3pL2qR8vT5wY6hZ1aB4", + "role": "tool", + "content": "{\"current_date\": \"2024-01-15\", \"current_time\": \"14:30:00\", \"timezone\": \"Asia/Shanghai\"}" + }, + { + "role": "assistant", + "reasoning_content": "现在知道今天是2024-01-15,明天就是2024-01-16。接下来查询杭州和北京明天的天气🌤️", + "tool_calls": [ + { + "id": "call_bN7kR9mX3pQ2wL5vY8jZ4cD6", + "type": "function", + "function": { + "arguments": "{\"location\": \"Hangzhou\", \"date\": \"2024-01-16\"}", + "name": "get_weather" + } + }, + { + "id": "call_dP9mL7kX5rT4yN3wZ2hV8eF1", + "type": "function", + "function": { + "arguments": "{\"location\": \"Beijing\", \"date\": \"2024-01-16\"}", + "name": "get_weather" + } + } + ] + }, + { + "tool_call_id": "call_bN7kR9mX3pQ2wL5vY8jZ4cD6", + "role": "tool", + "content": "{\"location\": \"Hangzhou\", \"date\": \"2024-01-16\", \"temperature_high\": \"12\", \"temperature_low\": \"5\", \"weather\": \"多云\", \"humidity\": \"65%\"}" + }, + { + "tool_call_id": "call_dP9mL7kX5rT4yN3wZ2hV8eF1", + "role": "tool", + "content": "{\"location\": \"Beijing\", \"date\": \"2024-01-16\", \"temperature_high\": \"-2\", \"temperature_low\": \"-8\", \"weather\": \"晴\", \"humidity\": \"30%\"}" + }, + { + "role": "assistant", + "reasoning_content": "已获取两个城市明天的天气信息,现在整理给用户✨", + "content": "根据查询结果,明天(2024年1月16日)的天气情况如下:\n\n**杭州**:\n- 天气:多云\n- 最高温度:12°C\n- 最低温度:5°C\n- 湿度:65%\n\n**北京**:\n- 天气:晴\n- 最高温度:-2°C\n- 最低温度:-8°C\n- 湿度:30%\n\n杭州明天会比较温暖但有些多云,而北京会很冷但是晴天。建议在北京的朋友要注意保暖!" + }, + { + "role": "user", + "content": "谢谢!那后天呢?" + }, + { + "role": "assistant", + "reasoning_content": "用户现在问后天的天气,后天是2024-01-17,我可以直接查询(因为已知今天日期)🗓️", + "tool_calls": [ + { + "id": "call_fR3nK8mV7pL4xT2yW9jB5gH3", + "type": "function", + "function": { + "arguments": "{\"location\": \"Hangzhou\", \"date\": \"2024-01-17\"}", + "name": "get_weather" + } + }, + { + "id": "call_hT5pN2kY9rV6zL3wX1mD7jK8", + "type": "function", + "function": { + "arguments": "{\"location\": \"Beijing\", \"date\": \"2024-01-17\"}", + "name": "get_weather" + } + } + ] + }, + { + "tool_call_id": "call_fR3nK8mV7pL4xT2yW9jB5gH3", + "role": "tool", + "content": "{\"location\": \"Hangzhou\", \"date\": \"2024-01-17\", \"temperature_high\": \"15\", \"temperature_low\": \"8\", \"weather\": \"小雨\", \"humidity\": \"80%\"}" + }, + { + "tool_call_id": "call_hT5pN2kY9rV6zL3wX1mD7jK8", + "role": "tool", + "content": "{\"location\": \"Beijing\", \"date\": \"2024-01-17\", \"temperature_high\": \"0\", \"temperature_low\": \"-6\", \"weather\": \"多云\", \"humidity\": \"45%\"}" + }, + { + "role": "assistant", + "reasoning_content": "获取到后天的天气数据,整理回复给用户📝", + "content": "后天(2024年1月17日)的天气情况:\n\n**杭州**:\n- 天气:小雨\n- 最高温度:15°C\n- 最低温度:8°C\n- 湿度:80%\n\n**北京**:\n- 天气:多云\n- 最高温度:0°C\n- 最低温度:-6°C\n- 湿度:45%\n\n杭州后天会有小雨,温度略有回升,记得带伞。北京会稍微暖和一点,但依然很冷,请继续做好保暖措施。" + } + ] +} diff --git a/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_input_search_w_date.json b/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_input_search_w_date.json new file mode 100644 index 00000000000..ccfc2ee7332 --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_input_search_w_date.json @@ -0,0 +1,732 @@ +{ + "messages": [ + { + "role": "developer", + "content": "帮我调研一下,目前有哪些针对search agent的benchmark?详细介绍各自的特点、使用场景、例题。\n\n\n## Today’s Date\n2025-11-27, Thursday.\n", + "tools": [ + { + "type": "function", + "function": { + "name": "search", + "description": "Searches for information related to query and displays topn results.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query string" + }, + "topn": { + "type": "integer", + "description": "Number of top results to display", + "default": 10 + }, + "source": { + "type": "string", + "description": "Source to search within", + "enum": [ + "web", + "news" + ], + "default": "web" + } + }, + "required": [ + "query" + ], + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + } + }, + { + "type": "function", + "function": { + "name": "open", + "description": "Opens the link id from the page indicated by cursor starting at line number loc, showing num_lines lines. Valid link ids are displayed with the formatting: 【{id}†.*】. If cursor is not provided, the most recent page is implied. If id is a string, it is treated as a fully qualified URL associated with source. If loc is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available. Use this function without id to scroll to a new location of an opened page.", + "parameters": { + "type": "object", + "properties": { + "id": { + "description": "Link ID (number) or fully qualified URL (string)", + "anyOf": [ + { + "type": "integer" + }, + { + "type": "string" + } + ], + "default": -1 + }, + "cursor": { + "type": "integer", + "description": "Page cursor indicator", + "default": -1 + }, + "loc": { + "type": "integer", + "description": "Starting line number", + "default": -1 + }, + "num_lines": { + "type": "integer", + "description": "Number of lines to show", + "default": -1 + }, + "view_source": { + "type": "boolean", + "description": "Whether to view source", + "default": false + }, + "source": { + "type": "string", + "description": "Source associated with the URL", + "enum": [ + "web", + "news" + ], + "default": "web" + } + }, + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + } + }, + { + "type": "function", + "function": { + "name": "find", + "description": "Finds exact matches of pattern in the current page, or the page given by cursor.", + "parameters": { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "The exact pattern to search for" + }, + "cursor": { + "type": "integer", + "description": "Page cursor to search in", + "default": -1 + } + }, + "required": [ + "pattern" + ], + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + } + } + ] + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "用户想调研目前针对搜索智能体(search agent)的基准测试(benchmark),要求详细介绍各自的特点、使用场景和例题。我需要提供一份全面的概述,涵盖现有的基准测试,例如WebGPT、WebShop、ALFWorld、HotpotQA、FEVER、TriviaQA、Natural Questions、SearchQA、MS MARCO、BEIR、LoTTE、MIRACL、Zero-shot IR基准等。同时,需要区分这些基准测试是针对检索任务、问答任务,还是具体针对搜索智能体的评估。此外,还应考虑那些专门为评估搜索智能体(即能够与环境交互、执行搜索、浏览网页、使用工具的智能体)而设计的基准测试,例如WebGPT、WebShop、ALFWorld、Mind2Web、WebArena、VisualWebArena、SearchBench、WebVLN、WebNav等。我需要详细介绍每个基准测试的特点、使用场景,并举例说明。为了确保信息的全面性和时效性,我将同时进行两项搜索:一项是针对“search agent benchmark”的通用搜索,以获取广泛的基准测试信息;另一项是针对“搜索智能体 基准测试”的中文搜索,以获取更符合用户语言习惯和可能更具体的中文资料。", + "tool_calls": [ + { + "id": "fixture-tool-call-1-0", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"search agent benchmark 2024\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-1-1", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"搜索智能体 基准测试\", \"topn\": 10, \"source\": \"web\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[0] search agent benchmark 2024 (Search_Results/search agent benchmark 2024)\n**viewing lines [0 - 39] of 39**\n\nL0: \nL1: URL: Search_Results/search agent benchmark 2024\nL2: # Search Results\nL3: \nL4: * 【0†HSCodeComp: A Realistic and Expert-level Benchmark for ...; publish_date:\nL5: none†arxiv.org】 Oct 22, 2025 — To fill this gap, we introduce HSCodeComp, the \nL6: first realistic, expert-level e-commerce benchmark designed to evaluate deep \nL7: search agents in ...\nL8: * 【1†open-compass/GTA - A Benchmark for General Tool Agents; publish_date: \nL9: none†github.com】 GTA is a benchmark to evaluate the tool-use capability of LLM-\nL10: based agents in real-world scenarios. It features three main aspects.\nL11: * 【2†Benchmarking real-time trust scoring across five AI Agent ...; \nL12: publish_date: none†cleanlab.ai】 Aug 20, 2025 — This article evaluates 5 AI Agent\nL13: architectures over the BOLAA (ICLR 2024) benchmark, and assesses the effects of\nL14: adding automated trust ...\nL15: * 【3†10 AI agent benchmarks; publish_date: none†www.evidentlyai.com】 Jul 11, \nL16: 2025 — We put together 10 AI agent benchmarks designed to assess how well \nL17: different LLMs perform as agents in real-world scenarios, ...\nL18: * 【4†A state-of-the-art search API purpose-built for agents; publish_date: \nL19: none†parallel.ai】 Jul 31, 2025 — To evaluate real-world performance of the \nL20: Parallel Search MCP Server, we created the WISER-Search benchmark which blends \nL21: WISER-Fresh (queries ...\nL22: * 【5†AI Agent Benchmarks are Broken; publish_date: none†medium.com】 We break \nL23: down the failure modes in current AI agent benchmarks and introduce a checklist \nL24: that minimizes the gamability of AI agent benchmarks.\nL25: * 【6†Benchmarks and Tree Search for Multimodal LLM Web Agents; publish_date: \nL26: none†dpfried.github.io】 2024, When is Tree Search Useful? ○ Dealing with \nL27: destructive actions. ○ Some things on the web are very difficult to undo, e.g., \nL28: ordering an item. 56.\nL29: * 【7†-Bench: Benchmarking AI agents for the real-world; publish_date: \nL30: none†sierra.ai】 Jun 20, 2024 — τ-bench measures an agent's ability to interact \nL31: with (simulated) human users and programmatic APIs while following domain-\nL32: specific policies in a consistent ...\nL33: * 【8†Browser Use = state of the art Web Agent; publish_date: none†browser-\nL34: use.com】 Dec 15, 2024 — Browser Use has achieved state-of-the-art performance on\nL35: the WebVoyager benchmark, with an impressive 89.1% success rate across 586 \nL36: diverse web tasks.\nL37: * 【9†FutureSearch Benchmarks; publish_date: none†evals.futuresearch.ai】 Find \nL38: the original source of a given claim. Example: From , more than 8 out of 1000 \nL39: users clicked on a phishing link monthly in 2024, up 190% vs 2023.", + "tool_call_id": "fixture-tool-call-1-0" + }, + { + "role": "tool", + "content": "[1] 搜索智能体 基准测试 (Search_Results/搜索智能体 基准测试)\n**viewing lines [0 - 33] of 33**\n\nL0: \nL1: URL: Search_Results/搜索智能体 基准测试\nL2: # Search Results\nL3: \nL4: * 【0†WideSearch:揭示AI 智能体缺失的「广度」能力; publish_date: none†zhuanlan.zhihu.com】 Aug \nL5: 16, 2025 — 为系统评估智能体在该任务上的能力,论文构建了第一个专门的基准测试 WideSearch ,包含200 个源于真实世界、横跨18 \nL6: 个领域的高质量任务。 通过对超过10 个 ...\nL7: * 【1†GAIA: 一个严苛的智能体基准- HuggingFace; publish_date: none†www.cnblogs.com】 Jul 9,\nL8: 2024 — 我们使用一个用库构建的代码智能体 在GAIA 基准上进行测试,这可以说是最困难、最全面的智能体基准测试……最终我们取得了第一名的成绩! \nL9: GAIA: 一个严苛的 ...\nL10: * 【2†AI搜索智能体遭遇新挑战:滑铁卢大学团队提出更公平透明的 ...; publish_date: none†www.techwalker.com】 \nL11: Aug 14, 2025 — \nL12: 目前评测AI搜索智能体主要依靠BrowseComp这样的基准测试,它就像一场实时的开卷考试,让AI在真实的网络环境中搜索信息来回答复杂问题。听起来很合理 ...\nL13: * 【3†Agentic AI基础设施实践经验系列(六):Agent质量评估 - AWS; publish_date: \nL14: none†aws.amazon.com】 Sep 19, 2025 — TAU-bench \nL15: 是一个评估AI智能体在真实世界环境中可靠性的基准测试。它评估智能体是否能够在动态的多轮对话中与用户进行交互,理解需求并完成任务。T-bench ...\nL16: * 【4†DeepAgent:能自己找工具的通用推理智能体 - 高瓴人工智能学院; publish_date: none†ai.ruc.edu.cn】 \nL17: Nov 6, 2025 — 在八大基准测试中,DeepAgent在绝大多数任务上全面领先所有基线模型。 \nL18: 开放环境优势:在最具挑战的“开放工具检索”场景下(如ToolBench),其成功率达到64%,远 ...\nL19: * 【5†BrowseComp:为浏览智能体设计的简单而具挑战性的基准测试; publish_date: none†blog.csdn.net】 Sep \nL20: 22, 2025 — 该基准测试由OpenAI团队开发,旨在推动更可信赖和可靠的AI代理研究。 核心特点. \nL21: 挑战性问题设计:BrowseComp的问题设计遵循严格的难度标准:. 人类创建者确保 ...\nL22: * 【6†什么是GAIA?; publish_date: none†huggingface.co】 什么是GAIA? GAIA \nL23: 是一个用于评估AI助手在需要核心能力组合的真实世界任务上的表现的基准,这些核心能力包括推理、多模态理解、网页浏览和熟练的工具使用。\nL24: * 【7†SuperCLUE-Agent: Agent智能体中文原生任务能力测评基准; publish_date: \nL25: none†www.cluebenchmarks.com】 AI \nL26: Agent是一个能够自主理解、规划决策和执行复杂任务的智能体,有代表性的Agent比如AutoGPT、AgentGPT等。 \nL27: 现有关于Agent能力的测评,主要是在英文场景或任务的测评。目前还 ...\nL28: * 【8†Gaia2 与ARE:赋能社区的智能体评测; publish_date: none†hub.baai.ac.cn】 Sep 24, 2025 — \nL29: 快速评估任意智能体:基于真实或模拟数据,测试不同规则、工具、内容和验证方式下的表现 · 测试智能体的 工具调用与编排能力:可结合本地应用或MCP 工具 · 生成\nL30: ...\nL31: * 【9†许多AI 智能体评测基准并不可靠; publish_date: none†my.oschina.net】 Aug 13, 2025 — \nL32: 当前AI智能体基准测试存在严重缺陷,亟需更严谨的评估框架。作者提出AI智能体基准测试核查清单(ABC),包含43个项目,旨在识别基准测试潜在陷阱,构建可靠评估 \nL33: ...", + "tool_call_id": "fixture-tool-call-1-1" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "第一个搜索结果(针对“search agent benchmark 2024”)显示了一些相关条目,例如“HSCodeComp: A Realistic and Expert-level Benchmark for ...”、“open-compass/GTA - A Benchmark for General Tool Agents”、“10 AI agent benchmarks”、“WISER-Search benchmark”、“Benchmarks and Tree Search for Multimodal LLM Web Agents”、“τ-bench”、“WebVoyager benchmark”、“FutureSearch Benchmarks”。这些可能都与搜索智能体相关。第二个搜索结果(针对“搜索智能体 基准测试”)显示了一些中文资料,例如“WideSearch”、“BrowseComp”、“GAIA”、“SuperCLUE-Agent”等。这些都可能与搜索智能体相关。为了全面了解这些基准测试,我们需要同时打开并分析其中几个最有潜力的结果。具体来说,我们将打开第一个搜索结果中的“HSCodeComp”(索引0)和“open-compass/GTA”(索引1),以及第二个搜索结果中的“WideSearch”(索引0)和“BrowseComp”(索引5)。", + "tool_calls": [ + { + "id": "fixture-tool-call-4-0", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 0, \"id\": 0}" + } + }, + { + "id": "fixture-tool-call-4-1", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 0, \"id\": 1}" + } + }, + { + "id": "fixture-tool-call-4-2", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 1, \"id\": 0}" + } + }, + { + "id": "fixture-tool-call-4-3", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 1, \"id\": 5}" + } + } + ] + }, + { + "role": "tool", + "content": "[2] (https://ar5iv.org/html/2510.19631v1)\n**viewing lines [0 - 119] of 983**\n\nL0: \nL1: URL: https://ar5iv.org/html/2510.19631v1\nL2: # HSCodeComp: A Realistic and Expert-level Benchmark for Deep Search Agents in \nL3: Hierarchical Rule Application\nL4: \nL5: Yiqian Yang† Tian Lan† Qianghuai Jia∗ Li Zhu Hui Jiang Hang Zhu Longyue Wang \nL6: Weihua Luo Kaifu Zhang\nL7: \nL8: Alibaba International Digital Commerce∗* Corresponding Author: Qianghuai Jia \nL9: (qianghuai.jqh@alibaba-inc.com)\nL10: †\\dagger Equal Contribution: Yiqian Yang\nL11: \nL12: Tian Lan\nL13: \nL14: ###### Abstract\nL15: \nL16: Abstract\nL17: \nL18: Effective deep search agents must not only access open-domain and domain-\nL19: specific knowledge but also apply complex rules—such as legal clauses, medical \nL20: manuals and tariff rules. These rules often feature vague boundaries and \nL21: implicit logic relationships, making precise application challenging for agents.\nL22: However, this critical capability is largely overlooked by current agent \nL23: benchmarks. To fill this gap, we introduce HSCodeComp, the first realistic, \nL24: expert-level e-commerce benchmark designed to evaluate deep search agents in \nL25: hierarchical rule application. In this task, the deep reasoning process of \nL26: agents is guided by these rules to predict 10-digit Harmonized System Code \nL27: (HSCode) of products with noisy but realistic descriptions. These codes, \nL28: established by the World Customs Organization, are vital for global supply chain\nL29: efficiency. Built from real-world data collected from large-scale e-commerce \nL30: platforms, our proposed HSCodeComp comprises 632 product entries spanning \nL31: diverse product categories, with these HSCodes annotated by several human \nL32: experts. Extensive experimental results on several state-of-the-art LLMs, open-\nL33: source, and closed-source agents reveal a huge performance gap: best agent \nL34: achieves only 46.8% 10-digit accuracy, far below human experts at 95.0%. \nL35: Besides, detailed analysis demonstrates the challenges of hierarchical rule \nL36: application, and test-time scaling fails to improve performance further.\nL37: \nL38: ## 1 Introduction\nL39: \nL40: Deep search agents have demonstrated significant value in solving complex real-\nL41: world problems, where robust external knowledge utilization constitutes a \nL42: critical capability [Wu et al., 2025, Tao et al., 2025, Li et al., 2025b]. To \nL43: evaluate this capability, numerous established benchmarks are proposed to assess\nL44: agents in utilizing open-domain data (e.g., GAIA [Mialon et al., 2023b] and \nL45: BrowseComp [Wei et al., 2025]) and domain-specific data (e.g., WebMall [Peeters \nL46: et al., 2025a], FinSearchComp [Hu et al., 2025a] and MedBrowseComp [Yu et al., \nL47: 2025b]).\nL48: \nL49: Beyond open-domain and domain-specific data, agents also need to effectively \nL50: apply rules that encode human expert knowledge, particularly in scenarios like \nL51: law, medical and e-commerce [Li et al., 2025a, Chen et al., 2025b, Yao et al., \nL52: 2022, Chollet et al., 2025]. For instance, legal case adjudication require \nL53: interpreting abstract legal provisions, and accurate e-commerce product \nL54: classification in depends on tariff rules [Grainger, 2024]. Previous works have \nL55: defined rule application as using specific logical rules with supporting facts \nL56: to derive conclusions [Wang et al., 2024, Servantez et al., 2024]. In contrast, \nL57: we define it as a core capability for deep search agents, where human-written \nL58: rules are systematically applied to guide complex reasoning and decision-making \nL59: [Sadowski and Chudziak, 2025]. Building on this observation, we categorize \nL60: knowledge data for deep search agents into three levels (Figure 1, left), with \nL61: increasing knowledge complexity: (1) Level 1: Open-domain Data - Tests \nL62: understanding and deep reasoning abilities of agents on long-form web content. \nL63: Established benchmarks include GAIA [Mialon et al., 2023b] and BrowseComp [Wei \nL64: et al., 2025]; (2) Level 2: Structured Data - Assesses agents to precisely \nL65: utilize structured data such as databases and knowledge graphs, as seen in \nL66: domain-specific benchmarks like WebMall [Peeters et al., 2025a], MedBrowseComp \nL67: [Chen et al., 2025b] and FinSearchComp [Hu et al., 2025a]; (3) Level 3: Rule \nL68: Data - Evaluates agents to apply complex and abstract rules [Chollet et al., \nL69: 2025]. This level presents two key challenges: (a) making accurate decisions \nL70: when rules contain vague natural language descriptions [Sadowski and Chudziak, \nL71: 2025]; and (b) reasoning about logical dependencies among rules, such as \nL72: exception clauses and cross-category relationships [Guha et al., 2023]. Despite \nL73: the importance of rule application in real-world scenarios, current agent \nL74: benchmarks largely overlook its evaluation.\nL75: \nL76: To fill this gap, we introduce HSCodeComp (short for the Harmonized System Code \nL77: (HSCode) Competition), the first realistic, expert-level e-commerce benchmark \nL78: designed to evaluate agents in predicting complete 10-digit Harmonized System \nL79: Code (HSCode) of the product, using hierarchical rules (e.g., eWTP tariff \nL80: rules111https://www.ewtp.com/web/smart/hscode). HSCodes organize products \nL81: through a hierarchical structure spanning over 5,000 distinct codes across \nL82: multiple classification levels, representing the global standard for classifying\nL83: traded international goods, established by the World Customs Organization and \nL84: implemented across more than 200 countries for customs clearance and tariff \nL85: determination [Grainger, 2024, Nath et al., 2025]. Built from the data of the \nL86: large-scale e-commerce platforms, our proposed HSCodeComp comprises 632 \nL87: carefully curated product entries, encompassing 27 unique HS chapters and 32 \nL88: distinct first-level categories. These HSCodes have been rigorously annotated by\nL89: multiple e-commerce domain experts, ensuring that HSCodeComp is expert-level. \nL90: Accurately predicting the exact 10-digit HSCode presents significant challenges:\nL91: agents must perform multi-hop hierarchical reasoning with complex tariff rules \nL92: while processing noisy but realistic product descriptions that often contain \nL93: abbreviations, language variations, or incomplete information.\nL94: \nL95: Extensive experiments on the state-of-the-art baselines, including 14 advanced \nL96: foundation models, 6 advanced open-source agent systems and 3 closed-source \nL97: agent systems, demonstrate that HSCode prediction task remains a substantial \nL98: challenge for current AI approaches. As shown in the Figure 1 (right), even the \nL99: best-performing system (SmolAgent [Roucher et al., 2025] with GPT-5) achieves \nL100: only 46.8% accuracy, substantially below the 95.0% accuracy attained by human \nL101: experts. Further detailed analysis reveals that existing agent systems lack \nL102: critical capabilities required for this complex hierarchical rule applications. \nL103: Notably, test-time scaling approach—which has proven effective in other \nL104: reasoning tasks [Guo et al., 2025, Liu et al., 2025]—fail to improve performance\nL105: on HSCodeComp. These observations demonstrate the challenging nature of our \nL106: proposed HSCodeComp, highlighting the need for more effective designs of agent \nL107: systems. To facilitate future research, we will publicly release codes and the \nL108: benchmark dataset of HSCodeComp.\nL109: \nL110: ## 2 Related Works\nL111: \nL112: ### 2.1 Previous Works in HSCode Prediction\nL113: \nL114: Previous works treat HSCode prediction as the e-commerce text classification \nL115: task [Grainger, 2024], using pre-trained BERT models [Liao et al., 2024, Shubham\nL116: et al., 2022] or Large Language Models (LLMs) prompting [Hussain and Ahmed, \nL117: 2023]. However, these approaches fail to leverage domain-specific knowledge, \nL118: especially the rules written by human experts [Hussain and Ahmed, 2023, Judy, \nL119: 2024]. Besides, existing HSCode benchmarks face two critical limitations [Judy, ", + "tool_call_id": "fixture-tool-call-4-0" + }, + { + "role": "tool", + "content": "Error fetching URL `https://github.com/open-compass/GTA`", + "tool_call_id": "fixture-tool-call-4-1" + }, + { + "role": "tool", + "content": "Error fetching URL `https://zhuanlan.zhihu.com/p/1939316761775301093`", + "tool_call_id": "fixture-tool-call-4-2" + }, + { + "role": "tool", + "content": "[3] (https://blog.csdn.net/qq_41472205/article/details/151974603)\n**viewing lines [0 - 167] of 187**\n\nL0: \nL1: URL: https://blog.csdn.net/qq_41472205/article/details/151974603\nL2: # BrowseComp:为浏览智能体设计的简单而具挑战性的基准测试\nL3: \nL4: BrowseComp:AI浏览能力评估基准\nL5: \nL6: 最新推荐文章于 2025-11-12 13:40:20 发布\nL7: \nL8: 原创 于 2025-09-22 22:33:04 发布 · 1.3k 阅读\nL9: \nL10: · 9\nL11: · 25 · \nL12: CC 4.0 BY-SA版权\nL13: \nL14: 版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。\nL15: \nL16: ## BrowseComp:为浏览智能体设计的简单而具挑战性的基准测试\nL17: \nL18: 在人工智能从基础聊天机器人向推理器和智能体发展的进程中,具备浏览互联网能力的人工智能模型正变得越来越重要。今天,我们将介绍一个名为BrowseComp的创新基准\nL19: 测试,它专门设计用于评估AI代理在复杂网络浏览任务中的能力。\nL20: \nL21: ### 什么是BrowseComp?\nL22: \nL23: BrowseComp(全称Browsing Competition)是一个包含1,266个挑战性问题的基准测试集,专门用于衡量AI代理在互联网上持续导航、寻找难\nL24: 以找到的纠缠信息的能力。该基准测试由OpenAI团队开发,旨在推动更可信赖和可靠的AI代理研究。\nL25: \nL26: #### 核心特点\nL27: \nL28: 挑战性问题设计:BrowseComp的问题设计遵循严格的难度标准:\nL29: \nL30: - 人类创建者确保问题在10分钟内无法被人解决\nL31: - 现有模型(包括带浏览功能的ChatGPT和早期版本的OpenAI Deep Research)无法解决\nL32: - 通过5次简单Google搜索无法在结果首页找到答案\nL33: \nL34: 简单易验证:尽管问题极具挑战性,但答案形式简单——都是短字符串,便于自动验证模型输出的正确性。\nL35: \nL36: ### 为什么需要BrowseComp?\nL37: \nL38: #### 现有基准的局限性\nL39: \nL40: 传统的信息检索基准(如TriviaQA、HotpotQA等)主要关注易于查找的信息,随着语言模型的进步,这些基准已经趋于饱和。而BrowseComp专注于那些需\nL41: 要浏览大量网站才能解决的\"硬核\"问题。\nL42: \nL43: #### 模拟真实挑战\nL44: \nL45: BrowseComp问题通常采用\"逆向设计\"方法:创建者从一个已知事实出发,构建一个搜索空间巨大但验证简单的问题。例如:\nL46: \nL47: “找出2018-2023年间在EMNLP会议上发表、第一作者本科毕业于达特茅斯学院、第四作者本科毕业于宾夕法尼亚大学的科学论文标题”\nL48: \nL49: 这类问题验证简单,但解决起来需要检查数千篇论文并调查每位作者的背景。\nL50: \nL51: ### 数据集特点\nL52: \nL53: #### 主题多样性\nL54: \nL55: BrowseComp涵盖了广泛的主题领域(如图2所示),包括历史、科学、文化等。创建者被鼓励基于个人兴趣设计问题,这有助于提高数据质量和参与度。\nL56: \nL57: #### 质量保证\nL58: \nL59: 为确保答案的唯一性,创建者需要:\nL60: \nL61: - 对问题内容有足够了解,确信没有其他有效答案\nL62: - 如果不确定,则添加更多约束条件\nL63: - 接受其他创建者的验证反馈\nL64: \nL65: ### 人类表现基准\nL66: \nL67: 为了衡量BrowseComp的难度,研究人员让人类创建者尝试解决问题(不能解答自己创建的问题)。结果显示:\nL68: \nL69: - **70.8%**的问题在2小时搜索后人类选择放弃\nL70: - **29.2%**的问题被成功解决\nL71: - 在解决的问题中,**86.4%**的人类答案与参考答案一致\nL72: \nL73: 这表明BrowseComp确实极具挑战性,即使是熟悉数据集的人类专家也难以在有限时间内解决大部分问题。\nL74: \nL75: ### AI模型表现评估\nL76: \nL77: #### 各模型对比\nL78: \nL79: 研究人员评估了多种模型在BrowseComp上的表现:\nL80: \nL81: 模型 | 准确率(%) | 校准误差(%) \nL82: ---|---|---\nL83: GPT-4o | 0.6 | 69 \nL84: GPT-4o(带浏览) | 1.9 | 82 \nL85: GPT-4.5 | 0.9 | 68 \nL86: OpenAI o1 | 9.9 | 65 \nL87: Deep Research | 51.5 | 91 \nL88: \nL89: #### 关键发现\nL90: \nL91: - 基础模型表现不佳:GPT-4o和GPT-4.5准确率接近零,凸显了基准的难度\nL92: - 浏览功能带来有限提升:启用浏览功能的GPT-4o准确率略有提高,但仍很低\nL93: - 推理能力的重要性:OpenAI o1虽然没有浏览能力,但凭借更强的推理能力获得较高准确率\nL94: - 专业模型的优势:专门为持久网络浏览训练的Deep Research模型解决了约一半的问题\nL95: \nL96: #### 计算资源与性能关系\nL97: \nL98: 研究表明,BrowseComp性能随测试时计算资源的增加而平滑提升(如图1所示)。这与智能体模型的特性一致——更多计算资源允许模型浏览更多网站,从而提高找到正确\nL99: 答案的机会。\nL100: \nL101: ### 进阶策略分析\nL102: \nL103: #### 聚合策略的效果\nL104: \nL105: 通过让模型多次尝试同一问题并采用投票策略,可以显著提升性能:\nL106: \nL107: - 多数投票:选择样本中最常见的答案\nL108: - 加权投票:根据模型置信度加权投票\nL109: - 最佳选择:选择置信度最高的答案\nL110: \nL111: 这些方法将Deep Research的性能提升了15-25%,表明模型通常能够识别自己的正确答案。\nL112: \nL113: #### 任务难度分布\nL114: \nL115: 分析显示,BrowseComp中的任务难度分布广泛:\nL116: \nL117: - 16%的任务被Deep Research完美解决(100%通过率)\nL118: - 14%的任务完全失败(0%通过率)\nL119: - 其余任务处于中间难度水平\nL120: \nL121: ### BrowseComp的意义与局限性\nL122: \nL123: #### 作为评估工具的价值\nL124: \nL125: BrowseComp可被视为浏览智能体的\"编程竞赛\"——虽然不全面,但对核心浏览能力提供了有用的衡量。正如在编程竞赛中表现出色的模型很可能在其他编码任务中表现良\nL126: 好,能够解决BrowseComp的模型在定位难以查找信息方面应该具有很强的能力。\nL127: \nL128: #### 已知局限性\nL129: \nL130: - 不涉及真实用户查询分布:如生成长回答或解决查询歧义等挑战\nL131: - 答案唯一性假设:虽然尽力确保,但不能完全保证没有其他有效答案\nL132: - 单模态限制:当前仅关注文本信息检索,未来可扩展至图像、视频等模态\nL133: \nL134: ### 结论与展望\nL135: \nL136: BrowseComp为评估AI代理的网络浏览能力提供了一个具有挑战性且易于使用的基准。它强调的持久性、创造性和战略性浏览能力,正是未来AI智能体在实际应用中所需\nL137: 的核心技能。\nL138: \nL139: 随着更多研究团队在BrowseComp上测试他们的模型,我们期待看到浏览智能体技术的快速进步,最终使AI能够更可靠地帮助我们获取那些深藏在互联网角落的宝贵信息。\nL140: \nL141: 访问地址:BrowseComp已在GitHub上开源,地址为 https://github.com/openai/simple-evals\nL142: \nL143: 添加红包\nL144: \nL145: 请填写红包祝福语或标题\nL146: \nL147: 个\nL148: \nL149: 红包个数最小为10个\nL150: \nL151: 元\nL152: \nL153: 红包金额最低5元\nL154: \nL155: 当前余额3.43元 前往充值 >\nL156: \nL157: 需支付:10.00元\nL158: \nL159: 成就一亿技术人!\nL160: \nL161: hope_wisdom\nL162: 发出的红包 \nL163: 实付元\nL164: \nL165: 点击重新获取\nL166: \nL167: 扫码支付", + "tool_call_id": "fixture-tool-call-4-3" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "打开的HSCodeComp页面详细介绍了这个基准测试,它是一个用于评估深度搜索智能体在分层规则应用方面能力的基准测试,特点包括:基于真实世界电商数据、任务复杂、需要多跳层次推理、使用嘈杂的产品描述,并且标注由专家完成。使用场景是评估智能体在规则应用方面的能力,例如关税代码分类。例题是预测产品的10位HS编码。打开的BrowseComp页面详细介绍了这个基准测试,它是一个用于评估浏览智能体(即能够在互联网上导航、搜索信息的智能体)的基准测试。其特点包括:包含1266个具有挑战性的问题,需要持续导航和搜索难以查找的信息;答案形式简单,便于自动验证;问题设计确保人类在10分钟内无法解决,且现有模型也无法解决;涵盖了广泛的主题领域。使用场景是评估智能体在复杂网络浏览任务中的能力。例题是:“找出2018-2023年间在EMNLP会议上发表、第一作者本科毕业于达特茅斯学院、第四作者本科毕业于宾夕法尼亚大学的科学论文标题”。另外两个打开的页面(open-compass/GTA和WideSearch)出现了错误,可能链接已失效或需要其他访问方式。为了进一步获取其他相关基准测试的信息,我们需要同时进行多项搜索。考虑到用户可能对搜索智能体基准测试有广泛需求,我们将同时搜索“WebVoyager benchmark”、“WebArena benchmark”、“GAIA benchmark”、“ToolBench benchmark”、“HotpotQA benchmark”、“FEVER benchmark”、“TriviaQA benchmark”、“Natural Questions benchmark”、“MS MARCO benchmark”、“BEIR benchmark”、“MIRACL benchmark”、“Zero-shot IR benchmark”、“WebGPT benchmark”、“WebShop benchmark”、“ALFWorld benchmark”、“Mind2Web benchmark”、“VisualWebArena benchmark”、“SearchBench benchmark”、“WebVLN benchmark”、“WebNav benchmark”。这些搜索将帮助我们全面了解现有基准测试的概况。", + "tool_calls": [ + { + "id": "fixture-tool-call-9-0", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"WebVoyager benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-1", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"WebArena benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-2", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"GAIA benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-3", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"ToolBench benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-4", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"HotpotQA benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-5", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"FEVER benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-6", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"TriviaQA benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-7", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"Natural Questions benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-8", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"MS MARCO benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-9", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"BEIR benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-10", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"MIRACL benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-11", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"Zero-shot IR benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-12", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"WebGPT benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-13", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"WebShop benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-14", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"ALFWorld benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-15", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"Mind2Web benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-16", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"VisualWebArena benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-17", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"SearchBench benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-18", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"WebVLN benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + }, + { + "id": "fixture-tool-call-9-19", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"WebNav benchmark\", \"topn\": 10, \"source\": \"web\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[4] WebVoyager benchmark (Search_Results/WebVoyager benchmark)\n**viewing lines [0 - 40] of 40**\n\nL0: \nL1: URL: Search_Results/WebVoyager benchmark\nL2: # Search Results\nL3: \nL4: * 【0†WebVoyager Benchmark; publish_date: none†www.trykura.com】 Performance \nL5: measured across various websites in WebVoyager shows Kura consistently \nL6: outperforming other agents, with particularly strong results on e-commerce ...\nL7: * 【1†WebVoyager: Building an End-to-End Web Agent with ...; publish_date: \nL8: none†github.com】 WebVoyager is an innovative Large Multimodal Model (LMM) \nL9: powered web agent that can complete user instructions end-to-end by interacting \nL10: with real-world ...\nL11: * 【2†AI Browser Agent Leaderboard | Steel.dev; publish_date: \nL12: none†leaderboard.steel.dev】 See how various AI browser agents stack up based on \nL13: their accuracy in completing web-based tasks on the WebVoyager benchmark.\nL14: * 【3†[2401.13919] WebVoyager: Building an End-to-End Web ...; publish_date: \nL15: none†arxiv.org】 by H He · 2024 · Cited by 282 — We show that WebVoyager achieves\nL16: a 59.1% task success rate on our benchmark, significantly surpassing the \nL17: performance of both GPT-4 (All ...\nL18: * 【4†Our Agent-E SOTA Results on the WebVoyager Benchmark; publish_date: \nL19: none†www.emergence.ai】 Jul 11, 2024 — WebVoyager is a benchmark that tests an \nL20: agent's capabilities for navigation on dynamic live websites. It is more \nL21: representative than WebArena [4] ...\nL22: * 【5†Browser Use = state of the art Web Agent; publish_date: none†browser-\nL23: use.com】 Dec 15, 2024 — Browser Use has achieved state-of-the-art performance on\nL24: the WebVoyager benchmark, with an impressive 89.1% success rate across 586 \nL25: diverse web tasks.\nL26: * 【6†Magnitude achieves SOTA 94% on WebVoyager benchmark; publish_date: \nL27: none†github.com】 Magnitude achieves state-of-the-art performance with 93.9% \nL28: success rate on WebVoyager, beating all other browser agents.\nL29: * 【7†WebVoyager: Autonomous Web Agent Benchmark; publish_date: \nL30: none†www.emergentmind.com】 3 days ago — WebVoyager Benchmark is a comprehensive \nL31: evaluation suite for autonomous web agents, featuring 643 tasks across 15 \nL32: popular websites.\nL33: * 【8†WebVoyager Benchmark Results; publish_date: none†www.browserable.ai】 \nL34: Browserable has achieved 90.4% on the WebVoyager benchmark. This is best-in-\nL35: class performance across all web agents. This was done across 567 web tasks \nL36: which ...\nL37: * 【9†89% achieved on WebVoyager using Anchor + Browser Use; publish_date: \nL38: none†www.reddit.com】 Thanks to the amazing work from the browser-use open-source\nL39: community and the built-in support from Anchor Browser, we've hit an 89% score \nL40: on WebVoyager.", + "tool_call_id": "fixture-tool-call-9-0" + }, + { + "role": "tool", + "content": "[5] WebArena benchmark (Search_Results/WebArena benchmark)\n**viewing lines [0 - 42] of 42**\n\nL0: \nL1: URL: Search_Results/WebArena benchmark\nL2: # Search Results\nL3: \nL4: * 【0†WebArena: A Realistic Web Environment for Building ...; publish_date: \nL5: none†webarena.dev】 Our benchmark is implemented in our fully interactable \nL6: highly-realistic WebArena environment. It features diverse tasks human may \nL7: encounter in their daily ...\nL8: * 【1†[2307.13854] WebArena: A Realistic Web Environment for ...; publish_date:\nL9: none†arxiv.org】 by S Zhou · 2023 · Cited by 637 — Building upon our \nL10: environment, we release a set of benchmark tasks focusing on evaluating the \nL11: functional correctness of task completions.\nL12: * 【2†WebArena: A Realistic Web Environment for Building ...; publish_date: \nL13: none†www.cmu.edu】 WebArena introduces a benchmark on interpreting high-level \nL14: realistic natural language command to concrete web-based interactions. We \nL15: provide annotated programs ...\nL16: * 【3†GitHub - web-arena-x/webarena: Code repo for ...; publish_date: \nL17: none†github.com】 [12/20/2024] Check out our new benchmark on even more \nL18: consequential tasks, including terminal use and coding, TheAgentCompany. \nL19: [12/21/2023] We release the ...\nL20: * 【4†WebArena Benchmark and the State of Agentic AI; publish_date: \nL21: none†medium.com】 In short, WebArena established a new standard for realism and \nL22: complexity in web agent evaluation, forcing AI agents to operate in dynamic, \nL23: high- ...\nL24: * 【5†WebArena: A Realistic Web Environment for Building ...; publish_date: \nL25: none†huggingface.co】 Jul 25, 2023 — WebArena, a realistic and reproducible \nL26: environment, evaluates the performance of autonomous agents performing complex \nL27: tasks on websites using ...\nL28: * 【6†WebArena Benchmark: Evaluating Web Agents; publish_date: \nL29: none†www.emergentmind.com】 Jun 30, 2025 — WebArena Benchmark is a self-contained\nL30: suite that evaluates autonomous agents on realistic, multi-step web tasks using\nL31: natural language ...\nL32: * 【7†VisualWebArena is a benchmark for multimodal agents.; publish_date: \nL33: none†github.com】 VisualWebArena is a realistic and diverse benchmark for \nL34: evaluating multimodal autonomous language agents. It comprises of a set of \nL35: diverse and complex web-based ...\nL36: * 【8†WebDev Arena Leaderboard - LMArena; publish_date: none†web.lmarena.ai】 \nL37: WebDev Arena is a real-time AI coding competition where models go head-to-head \nL38: in web development challenges, developed by LMArena.\nL39: * 【9†WebArena: A Realistic Web Environment for Building ...; publish_date: \nL40: none†arxiv.org】 Apr 16, 2024 — We use this benchmark to evaluate several agents \nL41: that can follow NL command and perform web-based tasks (§4). These agents are \nL42: implemented in a ...", + "tool_call_id": "fixture-tool-call-9-1" + }, + { + "role": "tool", + "content": "[6] GAIA benchmark (Search_Results/GAIA benchmark)\n**viewing lines [0 - 41] of 41**\n\nL0: \nL1: URL: Search_Results/GAIA benchmark\nL2: # Search Results\nL3: \nL4: * 【0†GAIA Leaderboard - a Hugging Face Space by ...; publish_date: \nL5: none†huggingface.co】 GAIA is a benchmark which aims at evaluating next-\nL6: generation LLMs (LLMs with augmented capabilities due to added tooling, \nL7: efficient prompting, access to search ...\nL8: * 【1†[2311.12983] GAIA: a benchmark for General AI Assistants; publish_date: \nL9: none†arxiv.org】 by G Mialon · 2023 · Cited by 367 — GAIA proposes real-world \nL10: questions that require a set of fundamental abilities such as reasoning, multi-\nL11: modality handling, web browsing, and generally tool-use ...\nL12: * 【2†GAIA benchmark; publish_date: none†huggingface.co】 This is the \nL13: organisation page for all things related to GAIA, a benchmark for General AI \nL14: Assistants. You can find all the information and links on the GAIA ...\nL15: * 【3†GAIA: A Benchmark for General AI Assistants; publish_date: \nL16: none†ukgovernmentbeis.github.io】 This is an Inspect AI implementation of the \nL17: GAIA (General AI Assistants) benchmark, consisting of 450 questions testing tool\nL18: use on realistic assistant tasks.\nL19: * 【4†GAIA: a benchmark for general AI assistants | Research; publish_date: \nL20: none†ai.meta.com】 May 6, 2024 — GAIA proposes real-world questions that require \nL21: a set of fundamental abilities such as reasoning, multi-modality handling, web \nL22: browsing, and generally tool-use ...\nL23: * 【5†HAL: GAIA Leaderboard; publish_date: none†hal.cs.princeton.edu】 GAIA is a\nL24: benchmark for General AI Assistants that requires a set of fundamental \nL25: abilities such as reasoning, multi-modality handling, web browsing, and tool- \nL26: ...\nL27: * 【6†GAIA: The LLM Agent Benchmark Everyone's Talking About; publish_date: \nL28: none†towardsdatascience.com】 May 29, 2025 — GAIA stands for General AI \nL29: Assistants benchmark [1]. This benchmark was introduced to specifically evaluate\nL30: LLM agents on their ability to act as general- ...\nL31: * 【7†GAIA: a benchmark for General AI Assistants; publish_date: \nL32: none†openreview.net】 by G Mialon · Cited by 367 — GAIA proposes real-world \nL33: questions that require a set of fundamental abilities such as reasoning, multi-\nL34: modality handling, web browsing, and generally tool-use ...\nL35: * 【8†Rethinking AI Evaluation: Introducing the GAIA Benchmark; publish_date: \nL36: none†medium.com】 The authors introduce GAIA, a benchmark designed to assess the \nL37: robustness of AI systems across a variety of practical tasks.\nL38: * 【9†H2O.ai Tops the General AI Assistant (GAIA) Test; publish_date: \nL39: none†h2o.ai】 Mar 17, 2025 — Our h2oGPTe Agent has once again claimed the #1 spot\nL40: on the prestigious GAIA (General AI Assistants) benchmark with an impressive \nL41: 75% accuracy rate.", + "tool_call_id": "fixture-tool-call-9-2" + }, + { + "role": "tool", + "content": "[7] ToolBench benchmark (Search_Results/ToolBench benchmark)\n**viewing lines [0 - 40] of 40**\n\nL0: \nL1: URL: Search_Results/ToolBench benchmark\nL2: # Search Results\nL3: \nL4: * 【0†ToolBench, an evaluation suite for LLM tool manipulation ...; \nL5: publish_date: none†github.com】 The ToolBench is a benchmark consisting of \nL6: diverse software tools for real-world tasks. We also provide easy-to-use \nL7: infrastructure in this repository.\nL8: * 【1†OpenBMB/ToolBench; publish_date: none†github.com】 [2023/7/27] New version\nL9: ToolBench is released. ✨Here is an overview of the dataset construction, \nL10: training, and evaluation. ✨✨Features:.\nL11: * 【2†Towards Stable Large-Scale Benchmarking on Tool ...; publish_date: \nL12: none†arxiv.org】 by Z Guo · 2024 · Cited by 100 — We introduce StableToolBench, a\nL13: benchmark evolving from ToolBench, proposing a virtual API server and stable \nL14: evaluation system.\nL15: * 【3†StableToolBench - Zhicheng Guo; publish_date: none†zhichengg.github.io】 \nL16: We introduce StableToolBench, a benchmark evolving from ToolBench, proposing a \nL17: virtual API server and stable evaluation system.\nL18: * 【4†ToolBench | EvalScope - Read the Docs; publish_date: \nL19: none†evalscope.readthedocs.io】 We evaluate the effectiveness of the ToolBench \nL20: benchmark: ToolBench (Qin et al., 2023b). The task involves integrating API \nL21: calls to complete tasks.\nL22: * 【5†Towards Stable Large-Scale Benchmarking on Tool ...; publish_date: \nL23: none†aclanthology.org】 by Z Guo · 2024 · Cited by 100 — We introduce \nL24: StableToolBench, a benchmark evolving from ToolBench, proposing a virtual API \nL25: server and stable evaluation system.\nL26: * 【6†ML-Tool-Bench: Tool-Augmented Planning for ML Tasks; publish_date: \nL27: none†openreview.net】 Sep 18, 2025 — In this work, we introduce a comprehensive \nL28: benchmark for evaluating tool-augmented ML agents using a curated set of 61 \nL29: specialized tools and 15 ...\nL30: * 【7†-Bench: Benchmarking AI agents for the real-world; publish_date: \nL31: none†sierra.ai】 Jun 20, 2024 — τ-bench measures an agent's ability to interact \nL32: with (simulated) human users and programmatic APIs while following domain-\nL33: specific policies in a consistent ...\nL34: * 【8†ToolEval Leaderboard; publish_date: none†openbmb.github.io】 ToolEval is \nL35: an automatic evaluator build for tool learning which incorporates two evaluation\nL36: metrics, Pass Rate and Win Rate(Preference).\nL37: * 【9†What is the best benchmark dataset for multi-step tool-use?; \nL38: publish_date: none†www.reddit.com】 I'm a newbie trying to evaluate the \nL39: performance of different prompts strategies for multi-step tool-using, wondering\nL40: what is the recommended benchmark dataset ...", + "tool_call_id": "fixture-tool-call-9-3" + }, + { + "role": "tool", + "content": "[8] HotpotQA benchmark (Search_Results/HotpotQA benchmark)\n**viewing lines [0 - 39] of 39**\n\nL0: \nL1: URL: Search_Results/HotpotQA benchmark\nL2: # Search Results\nL3: \nL4: * 【0†HotpotQA Homepage; publish_date: none†hotpotqa.github.io】 HotpotQA is a \nL5: question answering dataset featuring natural, multi-hop questions, with strong \nL6: supervision for supporting facts to enable more explainable ...See more\nL7: * 【1†HotpotQA: A Dataset for Diverse, Explainable Multi-hop ...; publish_date:\nL8: none†arxiv.org】 by Z Yang · 2018 · Cited by 3834 — HotpotQA is a dataset with \nL9: 113k Wikipedia-based question-answer pairs requiring multi-document reasoning, \nL10: diverse questions, sentence-level ...\nL11: * 【2†hotpotqa/hotpot_qa · Datasets at Hugging Face; publish_date: \nL12: none†huggingface.co】 HotpotQA is a new dataset with 113k Wikipedia-based \nL13: question-answer pairs with four key features: (1) the questions require finding \nL14: and reasoning over multiple ...See more\nL15: * 【3†Why You Should Stop Using HotpotQA for AI Agents ...; publish_date: \nL16: none†qipeng.me】 Jul 1, 2025 — HotpotQA pioneered a class of AI tasks that \nL17: requires the AI system to autonomously perform multiple steps of reasoning in an\nL18: open-domain setting.See more\nL19: * 【4†hotpotqa/hotpot; publish_date: none†github.com】 A dataset for diverse, \nL20: explainable multi-hop question answering. This repository contains the baseline \nL21: model code, as well as the entire pipeline of running ...See more\nL22: * 【5†HotpotQA: Multi-Hop QA Benchmark; publish_date: \nL23: none†www.emergentmind.com】 Sep 10, 2025 — HotpotQA is a large-scale multi-hop \nL24: question answering benchmark featuring 112,779 Wikipedia-based Q&A pairs with \nL25: detailed, sentence-level ...See more\nL26: * 【6†HotpotQA Dataset | Papers With Code; publish_date: \nL27: none†paperswithcode.com】 HotpotQA is a question answering dataset collected on \nL28: the English Wikipedia, containing about 113K crowd-sourced questions.See more\nL29: * 【7†HotpotQA: A Dataset for Diverse, Explainable Multi-hop ...; publish_date:\nL30: none†aclanthology.org】 by Z Yang · 2018 · Cited by 3834 — HotpotQA is a dataset\nL31: with 113k Wikipedia-based question-answer pairs requiring multi-document \nL32: reasoning, diverse questions, sentence-level facts, and factoid ...\nL33: * 【8†Benchmark BM25S: HotpotQA; publish_date: none†www.kaggle.com】 Explore and\nL34: run machine learning code with Kaggle Notebooks | Using data from No attached \nL35: data sources.\nL36: * 【9†mteb/hotpotqa · Datasets at Hugging Face; publish_date: \nL37: none†huggingface.co】 HotpotQA is a question answering dataset featuring natural,\nL38: multi-hop questions, with strong supervision for supporting facts to enable \nL39: more explainable ...See more", + "tool_call_id": "fixture-tool-call-9-4" + }, + { + "role": "tool", + "content": "[9] FEVER benchmark (Search_Results/FEVER benchmark)\n**viewing lines [0 - 40] of 40**\n\nL0: \nL1: URL: Search_Results/FEVER benchmark\nL2: # Search Results\nL3: \nL4: * 【0†Fever.ai; publish_date: none†fever.ai】 We are pleased to announce that \nL5: FEVER9 will be co-located with EACL 2026. In this year's workshop, we will \nL6: introduce a new shared task focused on automated fact ...\nL7: * 【1†a Large-scale Dataset for Fact Extraction and VERification; publish_date:\nL8: none†aclanthology.org】 by J Thorne · 2018 · Cited by 2315 — In this paper we \nL9: introduce a new publicly available dataset for verification against textual \nL10: sources, FEVER: Fact Extraction.\nL11: * 【2†awslabs/fever: FEVER (Fact Extraction and VERification) ...; \nL12: publish_date: none†github.com】 In this paper we introduce a new publicly \nL13: available dataset for verification against textual sources, FEVER: Fact \nL14: Extraction and VERification.\nL15: * 【3†FEVER: Fact Extraction and VERification; publish_date: \nL16: none†www.amazon.science】 The best accuracy we achieve on labeling a claim \nL17: accompanied by the correct evidence is 31.87%, while if we ignore the evidence \nL18: we achieve 50.91%. Thus we ...\nL19: * 【4†FEVER Dataset; publish_date: none†fever.ai】 FEVER (Fact Extraction and \nL20: VERification) consists of 185,445 claims generated by altering sentences \nL21: extracted from Wikipedia and subsequently verified ...\nL22: * 【5†mteb/fever · Datasets at Hugging Face; publish_date: none†huggingface.co】\nL23: FEVER. An MTEB dataset. Massive Text Embedding Benchmark. FEVER (Fact \nL24: Extraction and VERification) consists of 185,445 claims generated by altering \nL25: sentences ...\nL26: * 【6†FEVEROUS: Fact Extraction and VERification Over ...; publish_date: \nL27: none†datasets-benchmarks-proceedings.neurips.cc】 by R Aly · Cited by 359 — In \nL28: this paper we introduce a novel dataset and benchmark, Fact Extraction and \nL29: VERification Over. Unstructured and Structured information (FEVEROUS), which ...\nL30: * 【7†a large-scale dataset for Fact Extraction and VERification; publish_date:\nL31: none†arxiv.org】 by J Thorne · 2018 · Cited by 2315 — In this paper we introduce\nL32: a new publicly available dataset for verification against textual sources, \nL33: FEVER: Fact Extraction and VERification.\nL34: * 【8†FEVER Resources; publish_date: none†fever.ai】 2018 FEVER: a large-scale \nL35: dataset for Fact Extraction and VERification .bib James Thorne, Andreas Vlachos,\nL36: Christos Christodoulopoulos, Arpit Mittal\nL37: * 【9†a Large-scale Dataset for Fact Extraction and VERification; publish_date:\nL38: none†www.semanticscholar.org】 This paper introduces a new publicly available \nL39: dataset for verification against textual sources, FEVER, which consists of \nL40: 185,445 claims generated by ...", + "tool_call_id": "fixture-tool-call-9-5" + }, + { + "role": "tool", + "content": "[10] TriviaQA benchmark (Search_Results/TriviaQA benchmark)\n**viewing lines [0 - 35] of 35**\n\nL0: \nL1: URL: Search_Results/TriviaQA benchmark\nL2: # Search Results\nL3: \nL4: * 【0†TriviaQA; publish_date: none†nlp.cs.washington.edu】 TriviaQA is a reading\nL5: comprehension dataset containing over 650K question-answer-evidence triples. \nL6: TriviaQA includes 95K question-answer pairs authored ...\nL7: * 【1†TriviaQA: A Large Scale Distantly Supervised Challenge ...; publish_date:\nL8: none†aclanthology.org】 by M Joshi · 2017 · Cited by 3451 — We present TriviaQA,\nL9: a challenging reading comprehension dataset containing over 650K question-\nL10: answer-evidence triples. TriviaQA includes 95K question ...\nL11: * 【2†mandarjoshi/trivia_qa · Datasets at Hugging Face; publish_date: \nL12: none†huggingface.co】 TriviaqQA is a reading comprehension dataset containing \nL13: over 650K question-answer-evidence triples. TriviaqQA includes 95K question-\nL14: answer pairs authored by ...\nL15: * 【3†[1705.03551] TriviaQA: A Large Scale Distantly Supervised ...; \nL16: publish_date: none†arxiv.org】 by M Joshi · 2017 · Cited by 3451 — We present \nL17: TriviaQA, a challenging reading comprehension dataset containing over 650K \nL18: question-answer-evidence triples.\nL19: * 【4†TriviaQA; publish_date: none†epoch.ai】 An open-domain question answering \nL20: benchmark with challenging trivia questions paired with evidence documents.\nL21: * 【5†TriviaQA Leaderboard; publish_date: none†llm-stats.com】 What is the \nL22: TriviaQA benchmark? A large-scale reading comprehension dataset containing over \nL23: 650K question-answer-evidence triples. TriviaQA includes 95K ...\nL24: * 【6†Code for the TriviaQA reading comprehension dataset; publish_date: \nL25: none†github.com】 A large scale distantly supervised challenge dataset for \nL26: reading comprehension. In Association for Computational Linguistics (ACL) 2017, \nL27: Vancouver, Canada.\nL28: * 【7†TriviaQA - Model Benchmarks - The Regularizer; publish_date: \nL29: none†www.theregularizer.com】 May 4, 2025 — Compare the performance of different \nL30: AI models across standardized benchmarks. Higher scores generally indicate \nL31: better performance, but context ...\nL32: * 【8†TriviaQA: A Large Scale Distantly Supervised Challenge ...; publish_date:\nL33: none†www.cs.utexas.edu】 by M Joshi · Cited by 3445 — We present TriviaQA, a \nL34: challenging reading comprehension dataset contain- ing over 650K question-\nL35: answer-evidence triples. TriviaQA includes 95K question-.", + "tool_call_id": "fixture-tool-call-9-6" + }, + { + "role": "tool", + "content": "[11] Natural Questions benchmark (Search_Results/Natural Questions benchmark)\n**viewing lines [0 - 39] of 39**\n\nL0: \nL1: URL: Search_Results/Natural Questions benchmark\nL2: # Search Results\nL3: \nL4: * 【0†Natural Questions: a Benchmark for Question Answering ...; publish_date: \nL5: none†research.google】 by T Kwiatkowski · Cited by 4339 — We present the Natural \nL6: Questions corpus, a question answering dataset. Questions consist of real \nL7: anonymized, aggregated queries issued to the Google search ...\nL8: * 【1†Natural Questions: A Benchmark for Question Answering ...; publish_date: \nL9: none†aclanthology.org】 by T Kwiatkowski · Cited by 4308 — Abstract. We present \nL10: the Natural Questions corpus, a question answering data set. Questions consist \nL11: of real anonymized, aggregated queries issued.\nL12: * 【2†Google's Natural Questions; publish_date: none†ai.google.com】 Natural \nL13: Questions. A Benchmark for Question Answering Research. View examples · Download\nL14: dataset. Open Domain Question Answering. A core goal in artificial ...\nL15: * 【3†google-research-datasets/natural-questions; publish_date: \nL16: none†github.com】 Natural Questions (NQ) contains real user questions issued to \nL17: Google search, and answers found from Wikipedia by annotators. NQ is designed \nL18: for the training and ...\nL19: * 【4†Natural Questions: A Benchmark for Question Answering ...; publish_date: \nL20: none†direct.mit.edu】 Aug 1, 2019 — We present the Natural Questions corpus, a \nL21: question answering data set. Questions consist of real anonymized, aggregated \nL22: queries issued to the Google search ...\nL23: * 【5†ir_datasets : Natural Questions; publish_date: none†ir-datasets.com】 \nL24: Google Natural Questions is a Q&A dataset containing long, short, and Yes/No \nL25: answers from Wikipedia. ir_datasets frames this around an ad-hoc ranking setting\nL26: ...\nL27: * 【6†sentence-transformers/natural-questions · Datasets at ...; publish_date: \nL28: none†huggingface.co】 This dataset is a collection of question-answer pairs from \nL29: the Natural Questions dataset. See Natural Questions for additional information.\nL30: * 【7†Google's Natural Questions; publish_date: none†ai.google.com】 Natural \nL31: Questions contains 307K training examples, 8K examples for development, and a \nL32: further 8K examples for testing. In the paper, we demonstrate a human ...\nL33: * 【8†A Benchmark for Question Answering Research; publish_date: \nL34: none†www.researchgate.net】 Jul 27, 2025 — We present the Natural Questions \nL35: corpus, a question answering data set. Questions consist of real anonymized, \nL36: aggregated queries issued to the Google search ...\nL37: * 【9†natural-questions; publish_date: none†docs.unity.rc.umass.edu】 Sep 4, \nL38: 2025 — “Natural questions: a benchmark for question answering research.” \nL39: Transactions of the Association for Computational Linguistics 7 (2019): ...", + "tool_call_id": "fixture-tool-call-9-7" + }, + { + "role": "tool", + "content": "[12] MS MARCO benchmark (Search_Results/MS MARCO benchmark)\n**viewing lines [0 - 41] of 41**\n\nL0: \nL1: URL: Search_Results/MS MARCO benchmark\nL2: # Search Results\nL3: \nL4: * 【0†MS MARCO - Microsoft Open Source; publish_date: none†microsoft.github.io】\nL5: The MS MARCO datasets are intended for non-commercial research purposes only to\nL6: promote advancement in the field of artificial intelligence and related areas, \nL7: ...\nL8: * 【1†microsoft/ms_marco · Datasets at Hugging Face; publish_date: \nL9: none†huggingface.co】 Starting with a paper released at NIPS 2016, MS MARCO is a \nL10: collection of datasets focused on deep learning in search. The first dataset was\nL11: a question ...\nL12: * 【2†Benchmarking Ranking Models in the Large-Data Regime; publish_date: \nL13: none†arxiv.org】 by N Craswell · 2021 · Cited by 89 — This paper uses the MS \nL14: MARCO and TREC Deep Learning Track as our case study, comparing it to the case \nL15: of TREC ad hoc ranking in the 1990s.\nL16: * 【3†Benchmarking Ranking Models in the Large-Data Regime; publish_date: \nL17: none†www.microsoft.com】 This paper uses the MS MARCO and TREC Deep Learning \nL18: Track as our case study, comparing it to the case of TREC ad hoc ranking in the \nL19: 1990s. We show how the ...\nL20: * 【4†Datasets for Document and Passage Ranking Leadboards; publish_date: \nL21: none†microsoft.github.io】 The MS MARCO document and passage ranking leaderboards\nL22: complements the TREC Deep Learning Track by providing on-going evaluation of \nL23: submissions using pre- ...\nL24: * 【5†MS MARCO: Benchmarking Ranking Models in the Large- ...; publish_date: \nL25: none†dl.acm.org】 Jul 11, 2021 — This paper uses the MS MARCO and TREC Deep \nL26: Learning Track as our case study, comparing it to the case of TREC ad hoc \nL27: ranking in the 1990s.\nL28: * 【6†ir_datasets : MSMARCO (passage); publish_date: none†ir-datasets.com】 A \nL29: passage ranking benchmark with a collection of 8.8 million passages and question\nL30: queries. Most relevance judgments are shallow.\nL31: * 【7†MS MARCO; publish_date: none†sbert.net】 MS MARCO Passage Ranking is a \nL32: large dataset to train models for information retrieval. It consists of about \nL33: 500k real search queries from Bing search engine ...\nL34: * 【8†MS MARCO: A Human Generated MAchine Reading ...; publish_date: \nL35: none†arxiv.org】 by P Bajaj · 2016 · Cited by 1151 — We introduce a large scale \nL36: MAchine Reading COmprehension dataset, which we name MS MARCO. The dataset \nL37: comprises of 1,010,916 anonymized ...\nL38: * 【9†MS MARCO Web Search: A Large-scale Information-rich ...; publish_date: \nL39: none†www.microsoft.com】 May 13, 2024 — MS MARCO Web Search offers a retrieval \nL40: benchmark with three web retrieval challenge tasks that demands innovations in \nL41: both machine learning and ...", + "tool_call_id": "fixture-tool-call-9-8" + }, + { + "role": "tool", + "content": "[13] BEIR benchmark (Search_Results/BEIR benchmark)\n**viewing lines [0 - 37] of 37**\n\nL0: \nL1: URL: Search_Results/BEIR benchmark\nL2: # Search Results\nL3: \nL4: * 【0†详细介绍文本检索基准BEIR: A Heterogeneous Benchmark ...; publish_date: \nL5: none†blog.csdn.net】 2023年1月1日 — \nL6: BEIR旨在为所有不同的检索任务提供一个一站式的零样本评估基准。为了构建一个全面的评估基准,选择方法对于收集具有理想属性的任务和数据集至关重要。对于 ...\nL7: * 【1†beir-cellar/beir; publish_date: none†github.com】 BEIR is a heterogeneous \nL8: benchmark containing diverse IR tasks. It also provides a common and easy \nL9: framework for evaluation of your NLP-based retrieval models ...\nL10: * 【2†BEIR: A Heterogenous Benchmark for Zero-shot Evaluation ...; \nL11: publish_date: none†arxiv.org】 作者:N Thakur · 2021 · 被引用次数:1480 — We introduce \nL12: Benchmarking-IR (BEIR), a robust and heterogeneous evaluation benchmark for \nL13: information retrieval.\nL14: * 【3†BeIR; publish_date: none†huggingface.co】 BEIR (Benchmarking IR) consists \nL15: of a homogenous benchmark for diverse sentence or passage level IR tasks. It \nL16: provides a common and easy framework for the cross ...\nL17: * 【4†论文分享:BEIR A Heterogeneous Benchmark for Zero-shot ...; publish_date: \nL18: none†zhuanlan.zhihu.com】 2022年10月3日 — 分享论文,夹带个人理解的分享,建议结合原论文看。 1 研究背景. \nL19: 本论文主要关注的领域是query-document检索(下文简称qd检索),即根据query去文档库里 ...\nL20: * 【5†Benchmarking IR Information Retrieval (BEIR); publish_date: \nL21: none†zilliz.com】 BEIR is a benchmark designed for evaluating the versatility and\nL22: robustness of information retrieval models. It features 18 diverse datasets \nL23: from domains like ...\nL24: * 【6†BEIR (Benchmarking IR) - OpenDataLab; publish_date: none†opendatalab.com】\nL25: 简介-Introduction. BEIR(Benchmarking \nL26: IR)是包含不同信息检索(IR)任务的异构基准。通过BEIR,可以系统地研究多种神经检索方法的零样本泛化能力。\nL27: * 【7†What is the BEIR benchmark and how is it used?; publish_date: \nL28: none†milvus.io】 The BEIR (Benchmarking Information Retrieval) benchmark is a \nL29: standardized framework designed to evaluate the effectiveness of search and \nL30: retrieval algorithms.\nL31: * 【8†BEIR Benchmark数据集卡片; publish_date: none†www.atyun.com】 BEIR \nL32: Benchmark数据集卡片. 数据集简介. BEIR是一个异构评测基准,由18个多样化的数据集构建而成,代表了9个信息检索任务:. 事实查证: FEVER ,\nL33: Climate-FEVER , SciFact ...\nL34: * 【9†Evaluating search relevance part 1 - The BEIR benchmark; publish_date: \nL35: none†www.elastic.co】 2024年7月16日 — Learn to evaluate your search system in the \nL36: context of better understanding the BEIR benchmark, with tips & techniques to \nL37: improve your ...", + "tool_call_id": "fixture-tool-call-9-9" + }, + { + "role": "tool", + "content": "[14] MIRACL benchmark (Search_Results/MIRACL benchmark)\n**viewing lines [0 - 41] of 41**\n\nL0: \nL1: URL: Search_Results/MIRACL benchmark\nL2: # Search Results\nL3: \nL4: * 【0†MIRACL | Multilingual Information Retrieval Across a ...; publish_date: \nL5: none†project-miracl.github.io】 MIRACL (Multilingual Information Retrieval Across\nL6: a Continuum of Languages) is an WSDM 2023 Cup challenge that focuses on search \nL7: across 18 different ...\nL8: * 【1†project-miracl/miracl: A large-scale multilingual dataset for ...; \nL9: publish_date: none†github.com】 A large-scale multilingual dataset for \nL10: Information Retrieval. Thorough human-annotations across 18 diverse languages.\nL11: * 【2†A Large, multilingual, visual document retrieval benchmark; publish_date:\nL12: none†arxiv.org】 by R Osmulski · 2025 · Cited by 2 — MIRACL-VISION is a \nL13: challenging, representative, multilingual evaluation benchmark for visual \nL14: retrieval pipelines and will help the community build robust ...\nL15: * 【3†miracl/miracl · Datasets at Hugging Face; publish_date: \nL16: none†huggingface.co】 MIRACL (Multilingual Information Retrieval Across a \nL17: Continuum of Languages) is a multilingual retrieval dataset that focuses on \nL18: search across 18 different ...\nL19: * 【4†MIRACL: A Multilingual Retrieval Dataset Covering 18 ...; publish_date: \nL20: none†direct.mit.edu】 by X Zhang · 2023 · Cited by 131 — MIRACL is a multilingual\nL21: dataset for ad hoc retrieval across 18 languages that collectively encompass \nL22: over three billion native speakers around the world.\nL23: * 【5†(PDF) MIRACL-VISION: A Large, multilingual, visual ...; publish_date: \nL24: none†www.researchgate.net】 May 23, 2025 — MIRACL-VISION covers 18 languages, and\nL25: is an extension of the MIRACL dataset, a popular benchmark to evaluate text-\nL26: based multilingual retrieval ...\nL27: * 【6†A Large, multilingual, visual document retrieval benchmark; publish_date:\nL28: none†arxiv.org】 by R Osmulski · 2025 · Cited by 2 — MIRACL-VISION is a \nL29: challenging, representative, multilingual evaluation benchmark for visual \nL30: retrieval pipelines and will help the community ...\nL31: * 【7†ir_datasets : MIRACL; publish_date: none†ir-datasets.com】 \nL32: \"miracl/ar/test-a\". The held-out test set (version a) for Arabic. \nL33: queriesdocsCitationMetadata. 936 queries. Language: ar. Query type: \nL34: GenericQuery: (namedtuple).\nL35: * 【8†Evaluate on MIRACL — BGE documentation; publish_date: none†bge-model.com】\nL36: MIRACL (Multilingual Information Retrieval Across a Continuum of Languages) is \nL37: an WSDM 2023 Cup challenge that focuses on search across 18 different languages.\nL38: * 【9†MIRACL - Alpha's Tech Garden; publish_date: \nL39: none†techgarden.alphasmanifesto.com】 MIRACL (Multilingual Information Retrieval \nL40: Across a Continuum of Languages) is a multilingual dataset we have built for the\nL41: WSDM 2023 Cup ...", + "tool_call_id": "fixture-tool-call-9-10" + }, + { + "role": "tool", + "content": "[15] Zero-shot IR benchmark (Search_Results/Zero-shot IR benchmark)\n**viewing lines [0 - 40] of 40**\n\nL0: \nL1: URL: Search_Results/Zero-shot IR benchmark\nL2: # Search Results\nL3: \nL4: * 【0†BEIR: A Heterogenous Benchmark for Zero-shot Evaluation ...; \nL5: publish_date: none†arxiv.org】 by N Thakur · 2021 · Cited by 1480 — We introduce \nL6: Benchmarking-IR (BEIR), a robust and heterogeneous evaluation benchmark for \nL7: information retrieval.See more\nL8: * 【1†beir-cellar/beir; publish_date: none†github.com】 BEIR: A Heterogenous \nL9: Benchmark for Zero-shot Evaluation of Information Retrieval Models (NeurIPS \nL10: 2021, Datasets and Benchmarks Track); Resources for Brewing ...See more\nL11: * 【2†Benchmarking IR Information Retrieval (BEIR); publish_date: \nL12: none†zilliz.com】 BEIR is a tool to evaluate how well Information Retrieval \nL13: systems perform across many tasks and types of information, and is a standard \nL14: benchmark.\nL15: * 【3†BEIR: A Heterogeneous Benchmark for Zero-shot ...; publish_date: \nL16: none†datasets-benchmarks-proceedings.neurips.cc】 by N Thakur · Cited by 1480 — \nL17: BEIR is a robust, heterogeneous benchmark for information retrieval, using 18 \nL18: datasets and 9 tasks to evaluate model generalization.\nL19: * 【4†BEIR; publish_date: none†eval.ai】 BEIR is a heterogeneous zero-shot \nL20: retrieval benchmark containing 18 datasets from diverse text retrieval tasks and\nL21: domains.See more\nL22: * 【5†[2409.15763] IRSC: A Zero-shot Evaluation Benchmark for ...; \nL23: publish_date: none†arxiv.org】 by H Lin · 2024 · Cited by 2 — This paper \nL24: introduces the IRSC benchmark for evaluating the performance of embedding models\nL25: in multilingual RAG tasks.See more\nL26: * 【6†FactIR: A Real-World Zero-shot Open-Domain Retrieval ...; publish_date: \nL27: none†dl.acm.org】 May 23, 2025 — In this paper, we present a real-world retrieval\nL28: benchmark FactIR, derived from Factiverse production logs, enhanced with human \nL29: annotations. We ...See more\nL30: * 【7†UniIR: Training and Benchmarking Universal Multimodal ...; publish_date: \nL31: none†tiger-ai-lab.github.io】 At test time, we evaluated the zero-shot \nL32: performance of all fine-tuned models, as well as SoTA pre-trained retrievers on \nL33: the three held-out datasets. UniIR ...See more\nL34: * 【8†Zero-Shot BEIR Tasks; publish_date: none†www.emergentmind.com】 Aug 26, \nL35: 2025 — Zero-Shot BEIR Tasks are evaluation methodologies that assess IR models' \nL36: ability to generalize to unseen query domains without task-specific ...See more\nL37: * 【9†BEIR-PL: Zero Shot Information Retrieval Benchmark for ...; publish_date:\nL38: none†aclanthology.org】 by K Wojtasik · 2024 · Cited by 12 — BEIR-PL is a new \nL39: benchmark with 13 datasets for Polish Information Retrieval, created to advance \nL40: research in this area.", + "tool_call_id": "fixture-tool-call-9-11" + }, + { + "role": "tool", + "content": "[16] WebGPT benchmark (Search_Results/WebGPT benchmark)\n**viewing lines [0 - 38] of 38**\n\nL0: \nL1: URL: Search_Results/WebGPT benchmark\nL2: # Search Results\nL3: \nL4: * 【0†WebGPT: Improving the factual accuracy of language ...; publish_date: \nL5: none†openai.com】 Dec 16, 2021 — Our models outperform GPT‑3 on TruthfulQA and \nL6: exhibit more favourable scaling properties. However, our models lag behind human\nL7: performance, ...\nL8: * 【1†A Simple Yet Challenging Benchmark for Browsing Agents; publish_date: \nL9: none†arxiv.org】 by J Wei · 2025 · Cited by 124 — Abstract. We present \nL10: BrowseComp, a simple yet challenging benchmark for measuring the ability for \nL11: agents to browse the web.\nL12: * 【2†openai/webgpt_comparisons · Datasets at Hugging Face; publish_date: \nL13: none†huggingface.co】 This is the dataset of all comparisons that were marked as \nL14: suitable for reward modeling by the end of the WebGPT project. There are 19,578 \nL15: comparisons in total.\nL16: * 【3†Evaluation & Limitations of WebGPT, WebVoyager & Agent-E; publish_date: \nL17: none†deepsense.ai】 Oct 14, 2024 — WebArena benchmark features 812 tasks \nL18: evaluated using metrics such as Exact Match, Must Include, and Fuzzy Match, \nL19: focusing on outcomes rather ...\nL20: * 【4†OpenAI Announces Question-Answering AI WebGPT; publish_date: \nL21: none†www.infoq.com】 Jan 25, 2022 — On the TriviaQA benchmark, WebGPT \nL22: outperformed GPT-3, producing answers that were true 75% of the time, and \"both \nL23: true and informative\" 54% of ...\nL24: * 【5†WebGPT: Improving the factual accuracy of language models ...; \nL25: publish_date: none†kargarisaac.medium.com】 The top-performing model generated \nL26: answers that were preferred over 56% of the time compared to answers produced by\nL27: human demonstrators, with ...\nL28: * 【6†Browser-assisted question-answering with human feedback; publish_date: \nL29: none†www.alphaxiv.org】 WebGPT represents a significant advancement in long-form \nL30: question answering by combining the language generation capabilities of GPT-3 \nL31: with real-time web ...\nL32: * 【7†Benchmarking Open-Source Large Language Models, GPT-4 ...; publish_date: \nL33: none†ai.nejm.org】 by S Wu · 2024 · Cited by 69 — We show that the current widely\nL34: used open-source LLMs have poor zero-shot reasoning ability in nephrology \nL35: compared with GPT-4 and Claude 2.\nL36: * 【8†0hq/WebGPT: Run GPT model on ...; publish_date: none†github.com】 WebGPT \nL37: is a vanilla JS and HTML implementation of a transformer model, intended as a \nL38: proof-of-concept as well as educational resource.", + "tool_call_id": "fixture-tool-call-9-12" + }, + { + "role": "tool", + "content": "[17] WebShop benchmark (Search_Results/WebShop benchmark)\n**viewing lines [0 - 41] of 41**\n\nL0: \nL1: URL: Search_Results/WebShop benchmark\nL2: # Search Results\nL3: \nL4: * 【0†WebShop: Towards Scalable Real-World Web Interaction ...; publish_date: \nL5: none†arxiv.org】 by S Yao · 2022 · Cited by 710 — To bridge this gap, we develop \nL6: WebShop -- a simulated e-commerce website environment with 1.18 million real-\nL7: world products and 12,087 crowd- ...\nL8: * 【1†WebShop; publish_date: none†webshop-pnlp.github.io】 To bridge this gap, \nL9: we develop WebShop – a simulated e-commerce website environment with 1.18 \nL10: million real-world products and 12,087 crowd-sourced text ...\nL11: * 【2†princeton-nlp/WebShop; publish_date: none†github.com】 WebShop is a \nL12: simulated e-commerce website environment with 1.18 million real-world products \nL13: and 12,087 crowd-sourced text instructions. In this environment, an ...\nL14: * 【3†WebShop: Towards Scalable Real-World Web Interaction ...; publish_date: \nL15: none†papers.nips.cc】 by S Yao · 2022 · Cited by 710 — We collect over 1,600 \nL16: human trajectories to first validate the benchmark, then train and evaluate a \nL17: diverse range of agents using reinforcement learning, ...\nL18: * 【4†WebShop: Towards Scalable Real-World Web Interaction ...; publish_date: \nL19: none†proceedings.neurips.cc】 by S Yao · 2022 · Cited by 709 — We have developed \nL20: WebShop, a new web-based benchmark for sequential decision making and language \nL21: grounding, modeled on interaction with an e-commerce website.\nL22: * 【5†Webshop & Benchmark Analysis | Documentation Infinity; publish_date: \nL23: none†docs.fact-finder.com】 Aug 15, 2025 — Evaluation of your shop based on \nL24: different categories in comparison, to your competitors/industry. Recommended \nL25: when doing a shop relaunch.\nL26: * 【6†A Multi-Shop Benchmark for Evaluating Web Agents; publish_date: \nL27: none†arxiv.org】 by R Peeters · 2025 · Cited by 2 — Compared to existing \nL28: e-commerce benchmarks, such as WebShop or ShoppingBench, WebMall introduces \nL29: comparison-shopping tasks across multiple shops ...\nL30: * 【7†WebShop: towards scalable real-world web interaction with ...; \nL31: publish_date: none†dl.acm.org】 by S Yao · 2022 · Cited by 710 — To bridge this \nL32: gap, we develop WebShop - a simulated e-commerce website environment with 1.18 \nL33: million real-world products and 12, 087 crowd- ...\nL34: * 【8†[PDF] WebShop: Towards Scalable Real-World Web ...; publish_date: \nL35: none†www.semanticscholar.org】 It is shown that agents trained on WebShop exhibit\nL36: non-trivial sim-to-real transfer when evaluated on amazon.com and ebay.com, \nL37: indicating the potential ...\nL38: * 【9†X-WebAgentBench: A Multilingual Interactive Web ...; publish_date: \nL39: none†aclanthology.org】 by P Wang · 2025 · Cited by 3 — (2023) based on the \nL40: English WebShop benchmark (Yao et al., 2022), while the multilingual task scores\nL41: are ob- tained through evaluation on our own benchmark.", + "tool_call_id": "fixture-tool-call-9-13" + }, + { + "role": "tool", + "content": "[18] ALFWorld benchmark (Search_Results/ALFWorld benchmark)\n**viewing lines [0 - 31] of 31**\n\nL0: \nL1: URL: Search_Results/ALFWorld benchmark\nL2: # Search Results\nL3: \nL4: * 【0†ALFWorld; publish_date: none†alfworld.github.io】 ALFWorld contains \nL5: interactive TextWorld environments (Côté et. al) that parallel embodied worlds \nL6: in the ALFRED dataset (Shridhar et. al).\nL7: * 【1†ALFWorld: Aligning Text and Embodied Environments for ...; publish_date: \nL8: none†arxiv.org】 by M Shridhar · 2020 · Cited by 674 — ALFWorld enables the \nL9: creation of a new BUTLER agent whose abstract knowledge, learned in TextWorld, \nL10: corresponds directly to concrete, visually grounded actions.\nL11: * 【2†ALFWorld: Aligning Text and Embodied Environments ...; publish_date: \nL12: none†github.com】 ALFWorld contains interactive TextWorld environments (Côté et. \nL13: al) that parallel embodied worlds in the ALFRED dataset (Shridhar et. al).\nL14: * 【3†alfworld - benchmark's activity; publish_date: none†huggingface.co】 MM-\nL15: IQ: Benchmarking Human-Like Abstraction and Reasoning in Multimodal Models Paper\nL16: • 2502.00698 • Published Feb 1 • 24\nL17: * 【4†Tackling AlfWorld with Action Attention and Common ...; publish_date: \nL18: none†neurips.cc】 On the Alfworld benchmark for indoor instruction following, we \nL19: achieve a significantly higher success rate (50% over the baseline) with our \nL20: novel object ...\nL21: * 【5†ALFWORLD: ALIGNING TEXT AND EMBODIED ...; publish_date: \nL22: none†openreview.net】 by M Shridhar · Cited by 674 — The ALFRED dataset (Shridhar\nL23: et al., 2020), set in the THOR simulator (Kolve et al., 2017), is a benchmark \nL24: for learning to com- plete embodied household tasks ...\nL25: * 【6†AlfWorld; publish_date: none†primo.ai】 Mar 23, 2024 — A simulator that \nL26: enables agents to learn abstract, text based policies in TextWorld (Côté et al.,\nL27: 2018) and then execute goals from the ALFRED benchmark.\nL28: * 【7†AlfWorld performance across 134 tasks showing cumulative...; \nL29: publish_date: none†www.researchgate.net】 In the AlfWorld benchmark, we defined \nL30: hallucination as the occurrence of two or more consecutive identical actions in \nL31: which the environment responded with ...", + "tool_call_id": "fixture-tool-call-9-14" + }, + { + "role": "tool", + "content": "[19] Mind2Web benchmark (Search_Results/Mind2Web benchmark)\n**viewing lines [0 - 40] of 40**\n\nL0: \nL1: URL: Search_Results/Mind2Web benchmark\nL2: # Search Results\nL3: \nL4: * 【0†Mind2Web: Towards a Generalist Agent for the Web; publish_date: none†osu-\nL5: nlp-group.github.io】 Mind2Web is a dataset for developing and evaluating \nL6: generalist agents for the web that can follow language instructions to complete \nL7: complex tasks on any ...\nL8: * 【1†Online-Mind2Web Leaderboard; publish_date: none†huggingface.co】 Online-\nL9: Mind2Web is a benchmark designed to evaluate the real-world performance of web \nL10: agents on live websites, featuring 300 tasks across 136 popular sites ...\nL11: * 【2†Mind2Web: Towards a Generalist Agent for the Web; publish_date: \nL12: none†github.com】 Mind2Web is the first dataset for developing and evaluating \nL13: generalist agents for the web that can follow language instructions to complete \nL14: complex tasks on any ...\nL15: * 【3†HAL: Online Mind2Web Leaderboard; publish_date: \nL16: none†hal.cs.princeton.edu】 Online Mind2Web leaderboard for evaluating AI agents'\nL17: ability to complete tasks on real, changing webpages.\nL18: * 【4†[2506.21506] Mind2Web 2: Evaluating Agentic Search with ...; \nL19: publish_date: none†arxiv.org】 by B Gou · 2025 · Cited by 11 — In this paper, we \nL20: introduce Mind2Web 2, a benchmark of 130 realistic, high-quality, and long-\nL21: horizon tasks that require real-time web browsing and extensive ...\nL22: * 【5†Mind2Web 2: Evaluating Agentic Search with Agent-as-a-Judge; \nL23: publish_date: none†osu-nlp-group.github.io】 We introduce Mind2Web 2, a benchmark\nL24: of 130 realistic, high-quality, long-horizon tasks that require real-time web \nL25: browsing and extensive information ...\nL26: * 【6†Mind2Web: The Benchmark for AI Agent Evaluation and ...; publish_date: \nL27: none†www.enhans.ai】 Sep 26, 2025 — Mind2Web is a globally recognized web-based \nL28: AI Agent evaluation benchmark introduced by the NLP group at Ohio State \nL29: University at NeurIPS 2023.\nL30: * 【7†Evaluating AI Web Agents: Insights from the WebCanvas ...; publish_date: \nL31: none†medium.com】 Thanks to the comprehensive WebCanvas Benchmark, which \nL32: incorporates a robust Mind2Web-Live data set of 542 live web tasks and 2,439 ...\nL33: * 【8†Mind2Web: Towards a Generalist Agent for the Web; publish_date: \nL34: none†proceedings.neurips.cc】 by X Deng · 2023 · Cited by 760 — We introduce \nL35: Mind2Web, the first dataset for developing and evaluating generalist agents for \nL36: the web that can follow language instructions to complete complex ...\nL37: * 【9†Mind2Web: Towards a Generalist Agent for the Web; publish_date: \nL38: none†openreview.net】 by X Deng · Cited by 760 — We introduce Mind2Web, the first\nL39: dataset for developing and evaluating generalist agents for the web that can \nL40: follow language instructions to complete complex ...", + "tool_call_id": "fixture-tool-call-9-15" + }, + { + "role": "tool", + "content": "[20] VisualWebArena benchmark (Search_Results/VisualWebArena benchmark)\n**viewing lines [0 - 38] of 38**\n\nL0: \nL1: URL: Search_Results/VisualWebArena benchmark\nL2: # Search Results\nL3: \nL4: * 【0†VisualWebArena is a benchmark for multimodal agents.; publish_date: \nL5: none†github.com】 VisualWebArena is a realistic and diverse benchmark for \nL6: evaluating multimodal autonomous language agents.\nL7: * 【1†Evaluating Multimodal Agents on Realistic Visual Web Tasks; publish_date:\nL8: none†arxiv.org】 by JY Koh · 2024 · Cited by 363 — To bridge this gap, we \nL9: introduce VisualWebArena, a benchmark designed to assess the performance of \nL10: multimodal web agents on realistic \\textit{ ...\nL11: * 【2†Evaluating Multimodal Agents on Realistic Visual Web Tasks; publish_date:\nL12: none†jykoh.com】 To bridge this gap, we introduce VisualWebArena, a benchmark \nL13: designed to assess the performance of multimodal web agents on realistic \nL14: visually grounded tasks.\nL15: * 【3†VisualWebArena: Evaluating Multimodal Agents on ...; publish_date: \nL16: none†arxiv.org】 VisualWebArena is a research benchmark to measure and evaluate \nL17: the progress of multimodal agents. It is primarily meant to act as a self-\nL18: contained sandbox ...\nL19: * 【4†Evaluating Multimodal Agents on Realistic Visual Web Tasks; publish_date:\nL20: none†aclanthology.org】 by JY Koh · 2024 · Cited by 363 — To bridge this gap, we\nL21: introduce VisualWebArena, a benchmark designed to assess the performance of \nL22: multimodal web agents on *realistic visually grounded tasks*.\nL23: * 【5†Evaluating Multimodal Agents on Realistic Visual Web Tasks; publish_date:\nL24: none†www.semanticscholar.org】 VisualWebArena: Evaluating Multimodal Agents on \nL25: Realistic Visual Web Tasks ... MMInA, a multihop and multimodal benchmark to \nL26: evaluate the embodied agents ...\nL27: * 【6†CMU Researchers Introduce VisualWebArena: An AI ...; publish_date: \nL28: none†www.marktechpost.com】 Feb 9, 2024 — VisualWebArena, a benchmark designed \nL29: and developed to evaluate the performance of multimodal web agents on realistic \nL30: and visually stimulating challenges.\nL31: * 【7†Evaluating Multimodal Agents on Realistic Visual Web Tasks; publish_date:\nL32: none†www.themoonlight.io】 The paper \"VisualWebArena: Evaluating Multimodal \nL33: Agents on Realistic Visually Grounded Web Tasks\" introduces a new benchmark, \nL34: **VisualWebArena**, ...\nL35: * 【8†WebArena: A Realistic Web Environment for Building ...; publish_date: \nL36: none†webarena.dev】 Our benchmark is implemented in our fully interactable \nL37: highly-realistic WebArena environment. It features diverse tasks human may \nL38: encounter in their daily ...", + "tool_call_id": "fixture-tool-call-9-16" + }, + { + "role": "tool", + "content": "[21] SearchBench benchmark (Search_Results/SearchBench benchmark)\n**viewing lines [0 - 40] of 40**\n\nL0: \nL1: URL: Search_Results/SearchBench benchmark\nL2: # Search Results\nL3: \nL4: * 【0†Talc-AI/search-bench; publish_date: none†github.com】 A practical \nL5: benchmark that focuses on every day helpfulness of LLM products, not just the \nL6: underlying models. Searchbench is a benchmark that addresses these ...\nL7: * 【1†Evaluating LLMs' Ability to Reason About Search Problems; publish_date: \nL8: none†arxiv.org】 These capabilities are essential for robust reasoning, making \nL9: SearchBench a valuable benchmark for evaluating LLMs' reasoning capabilities as \nL10: they continue to ...\nL11: * 【2†NasimBrz/SearchBench · Datasets at Hugging Face; publish_date: \nL12: none†huggingface.co】 Dataset Summary. SearchBench is a benchmark designed to \nL13: evaluate Language Models' (LLMs) ability to solve state-based problems that \nL14: require combinatorial search ...\nL15: * 【3†Evaluating LLMs' Ability to Reason About Search Problems; publish_date: \nL16: none†openreview.net】 2025年10月22日 — To further investigate this, we introduce a \nL17: new benchmark, SearchBench, which contains 11 unique search problems inspired by\nL18: intuitive puzzles.\nL19: * 【4†Navigating the Labyrinth: Evaluating and Enhancing LLMs' ...; \nL20: publish_date: none†hub.baai.ac.cn】 2024年6月17日 — \nL21: 论文提出了一个新的基准测试SearchBench,包含11种独特的搜索问题类型,并自动化生成任意数量的实例和分析解决方案的可行性、正确性和最优性。论文使用A* \nL22: ...\nL23: * 【5†Towards Unified Text-based Person Retrieval: A Large- ...; publish_date: \nL24: none†blog.csdn.net】 2023年10月17日 — ... Search \nL25: Benchmark(面向统一的基于文本的人物检索:一个大规模的多属性和语言搜索基准); 研究背景. 相关工作; BENCHMARK. 论文方法分析. 网络框架;\nL26: 1、APTM ...\nL27: * 【6†Desearch-ai/ai-search-benchmark; publish_date: none†github.com】 The \nL28: SearchBench repository addresses common issues with traditional benchmarks by \nL29: focusing on practical, everyday use cases rather than theoretical limits. It ...\nL30: * 【7†o1 results for 3 benchmarks: PlanBench, SearchBench, ...; publish_date: \nL31: none†www.reddit.com】 o1 results for 3 benchmarks: PlanBench, SearchBench, and \nL32: Summary of a Haystack. AI. PlanBench: Paper \"LLMs Still Can't Plan; Can LRMs? A \nL33: ...\nL34: * 【8†Evaluating LLMs' Ability to Reason About Search Problems; publish_date: \nL35: none†ui.adsabs.harvard.edu】 To further investigate this, we introduce a new \nL36: benchmark, SearchBench, which contains 11 unique search problems inspired by \nL37: intuitive puzzles. Each SearchBench ...\nL38: * 【9†Introducing SearchBench; publish_date: none†www.tag1consulting.com】 \nL39: Toward this goal, over the weekend I launched a new project called SearchBench, \nL40: a Drupal module for benchmarking Drupal's search performance. As the module ...", + "tool_call_id": "fixture-tool-call-9-17" + }, + { + "role": "tool", + "content": "[22] WebVLN benchmark (Search_Results/WebVLN benchmark)\n**viewing lines [0 - 42] of 42**\n\nL0: \nL1: URL: Search_Results/WebVLN benchmark\nL2: # Search Results\nL3: \nL4: * 【0†WebVLN: Vision-and-Language Navigation on Websites; publish_date: \nL5: none†ojs.aaai.org】 by Q Chen · 2024 · Cited by 35 — the WebVLN-v1 dataset, where\nL6: the performance is far from saturation, highlighting the utility of our \nL7: WebVLN-v1 as a benchmark to assess progress in this field.\nL8: * 【1†[2312.15820] WebVLN: Vision-and-Language Navigation on Websites; \nL9: publish_date: none†ar5iv.labs.arxiv.org】 Experimental results show that WebVLN-\nL10: Net outperforms current VLN and web-related navigation methods. ... Code is \nL11: available at: https://github.com/WebVLN/WebVLN.\nL12: * 【2†WebVLN: Vision-and-Language Navigation on Websites; publish_date: \nL13: none†github.com】 Experimental results show that WebVLN-Net outperforms current \nL14: VLN and web-related navigation methods. We believe that the introduction of the \nL15: new WebVLN task ...\nL16: * 【3†Vision-and-Language Navigation in the Real-World; publish_date: \nL17: none†digital.library.adelaide.edu.au】 By leveraging our proposed WebVLN-v1 \nL18: dataset, experimental results showcase the superior performance of WebVLN-Net \nL19: compared to existing VLN and web-related ...\nL20: * 【4†WebVLN: Vision-and-Language Navigation on Websites; publish_date: \nL21: none†www.researchgate.net】 Experimental results show that WebVLN-Net outperforms\nL22: current VLN and web-related navigation methods. We believe that the \nL23: introduction of the newWebVLN task and ...\nL24: * 【5†[PDF] WebVLN: Vision-and-Language Navigation on Websites; publish_date: \nL25: none†www.semanticscholar.org】 A new task named Vision-and-Language Navigation on\nL26: Websites (WebVLN), where question-based instructions are used to train an \nL27: agent, emulating how users ...\nL28: * 【6†WebVLN: Vision-and-Language Navigation on Websites; publish_date: \nL29: none†arxiv.org】 by Q Chen · 2023 · Cited by 35 — Experimental results show that \nL30: WebVLN-Net outperforms current VLN and web-related navigation methods. We \nL31: believe that the introduction of the ...\nL32: * 【7†Human-Aware Vision-and-Language Navigation; publish_date: \nL33: none†proceedings.neurips.cc】 by H Li · 2024 · Cited by 19 — Vision-and-Language \nL34: Navigation (VLN) [2, 7, 9, 40] has emerged as a key benchmark for evaluating. \nL35: Sim2Real transfer [23], showing impressive performance in ...\nL36: * 【8†LiveBench; publish_date: none†livebench.ai】 Introducing LiveBench: a \nL37: benchmark for LLMs designed with test set contamination and objective evaluation\nL38: in mind.\nL39: * 【9†MG-VLN: Benchmarking Multi-Goal and Long-Horizon ...; publish_date: \nL40: none†ieeexplore.ieee.org】 by J Zhang · 2024 — This task aims to provide a \nL41: simulation benchmark to guide the design of lifelong and long-horizon navigation\nL42: robots.", + "tool_call_id": "fixture-tool-call-9-18" + }, + { + "role": "tool", + "content": "[23] WebNav benchmark (Search_Results/WebNav benchmark)\n**viewing lines [0 - 36] of 36**\n\nL0: \nL1: URL: Search_Results/WebNav benchmark\nL2: # Search Results\nL3: \nL4: * 【0†WebNav: A New Large-Scale Task for Natural Language ...; publish_date: \nL5: none†github.com】 WebNav is a benchmark task for evaluating an agent with \nL6: abilities to understand natural language and plan on partially observed \nL7: environments.\nL8: * 【1†[1602.02261] End-to-End Goal-Driven Web Navigation; publish_date: \nL9: none†arxiv.org】 by R Nogueira · 2016 · Cited by 39 — We propose a goal-driven \nL10: web navigation as a benchmark task for evaluating an agent with abilities to \nL11: understand natural language and plan on partially ...\nL12: * 【2†nyu-dl/dl4ir-webnav; publish_date: none†github.com】 WebNav is a benchmark\nL13: task for evaluating an agent with abilities to understand natural language and \nL14: plan on partially observed environments.\nL15: * 【3†WebNav: A New Large-Scale Task for Natural Language ...; publish_date: \nL16: none†www.researchgate.net】 We propose a goal-driven web navigation as a \nL17: benchmark task for evaluating an agent with abilities to understand natural \nL18: language and plan on partially ...\nL19: * 【4†WebNav: An Intelligent Agent for Voice-Controlled Web ...; publish_date: \nL20: none†arxiv.org】 In this paper, we introduce WebNav, an innovative multi-modal \nL21: agent designed to automate web tasks based on natural language.\nL22: * 【5†WebCanvas: Benchmarking Web Agents in Online ...; publish_date: \nL23: none†openreview.net】 by Y Pan · Cited by 78 — TL;DR: We introduce WebCanvas, an \nL24: online evaluation framework for web agents designed to address the dynamic \nL25: nature of web interactions.\nL26: * 【6†WebNav: An Intelligent Agent for Voice-Controlled Web ...; publish_date: \nL27: none†huggingface.co】 Mar 17, 2025 — Preliminary evaluations show that WebNav \nL28: outperforms traditional screen readers in response time and task completion \nL29: accuracy for the visually ...\nL30: * 【7†Web Bench - A new way to compare AI Browser Agents; publish_date: \nL31: none†www.skyvern.com】 May 29, 2025 — TL;DR: Web Bench is a new dataset to \nL32: evaluate web browsing agents that consists of 5,750 tasks on 452 different \nL33: websites, with 2,454 tasks ...\nL34: * 【8†Web Framework Benchmarks; publish_date: none†www.techempower.com】 In the \nL35: following tests, we have measured the performance of several web application \nL36: platforms, full-stack frameworks, and micro-frameworks.", + "tool_call_id": "fixture-tool-call-9-19" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "我们收到了大量关于各种基准测试的搜索结果。现在,我们需要整理一份针对搜索智能体的基准测试列表。这些基准测试可以分为几类:网络浏览智能体基准测试(例如BrowseComp、WebVoyager、WebArena、VisualWebArena、Mind2Web、WebVLN、WebNav)、工具使用智能体基准测试(例如ToolBench、GTA)、问答/检索基准测试(例如HotpotQA、FEVER、TriviaQA、Natural Questions、MS MARCO、BEIR、MIRACL),以及特定领域的搜索基准测试(例如HSCodeComp、WebShop、ALFWorld)。此外,还有一些通用智能体基准测试,例如GAIA、SuperCLUE-Agent等。用户要求“详细介绍各自的特点、使用场景、例题”。因此,我需要为每个基准测试提供详细的描述,包括其特点、使用场景,并举例说明。为了全面获取这些信息,我将同时打开以下基准测试的相关页面:WebVoyager(索引3)、WebArena(索引1)、GAIA(索引1)、ToolBench(索引0)、HotpotQA(索引0)、FEVER(索引1)、TriviaQA(索引0)、Natural Questions(索引0)、MS MARCO(索引0)、BEIR(索引1)、MIRACL(索引0)、WebShop(索引0)、ALFWorld(索引0)、Mind2Web(索引0)、VisualWebArena(索引0)、SearchBench(索引1)、WebVLN(索引1)和WebNav(索引0)。", + "tool_calls": [ + { + "id": "fixture-tool-call-30-0", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 4, \"id\": 3}" + } + }, + { + "id": "fixture-tool-call-30-1", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 5, \"id\": 1}" + } + }, + { + "id": "fixture-tool-call-30-2", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 6, \"id\": 1}" + } + }, + { + "id": "fixture-tool-call-30-3", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 7, \"id\": 0}" + } + }, + { + "id": "fixture-tool-call-30-4", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 8, \"id\": 0}" + } + }, + { + "id": "fixture-tool-call-30-5", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 9, \"id\": 1}" + } + }, + { + "id": "fixture-tool-call-30-6", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 10, \"id\": 0}" + } + }, + { + "id": "fixture-tool-call-30-7", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 11, \"id\": 0}" + } + }, + { + "id": "fixture-tool-call-30-8", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 12, \"id\": 0}" + } + }, + { + "id": "fixture-tool-call-30-9", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 13, \"id\": 1}" + } + }, + { + "id": "fixture-tool-call-30-10", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 14, \"id\": 0}" + } + }, + { + "id": "fixture-tool-call-30-11", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 17, \"id\": 0}" + } + }, + { + "id": "fixture-tool-call-30-12", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 18, \"id\": 0}" + } + }, + { + "id": "fixture-tool-call-30-13", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 19, \"id\": 0}" + } + }, + { + "id": "fixture-tool-call-30-14", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 20, \"id\": 0}" + } + }, + { + "id": "fixture-tool-call-30-15", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 21, \"id\": 1}" + } + }, + { + "id": "fixture-tool-call-30-16", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 22, \"id\": 1}" + } + }, + { + "id": "fixture-tool-call-30-17", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 23, \"id\": 0}" + } + } + ] + }, + { + "role": "tool", + "content": "[24] (https://ar5iv.org/abs/2401.13919)\n**viewing lines [0 - 129] of 825**\n\nL0: \nL1: URL: https://ar5iv.org/abs/2401.13919\nL2: # WebVoyager : Building an End-to-End Web Agent with\nL3: Large Multimodal Models\nL4: \nL5: Hongliang He1,3 , Wenlin Yao2, Kaixin Ma2, Wenhao Yu2, Yong Dai2,\nL6: \nL7: Hongming Zhang2,\nL8: Zhenzhong Lan3,\nL9: Dong Yu2\nL10: 1Zhejiang University,\nL11: 2Tencent AI Lab,\nL12: 3Westlake University\nL13: hehongliang@westlake.edu.cn, wenlinyao@global.tencent.com\nL14: Work done during the internship at Tencent AI Lab.\nL15: \nL16: ###### Abstract\nL17: \nL18: The advancement of large language models (LLMs) leads to a new era marked by the\nL19: development of autonomous applications in the real world, which drives \nL20: innovation in the creation of advanced web-based agents. Existing web agents \nL21: typically only handle one input modality and are evaluated only in simplified \nL22: web simulators or static web snapshots, greatly limiting their applicability in \nL23: real-world scenarios. To bridge this gap, we introduce WebVoyager, an innovative\nL24: Large Multimodal Model (LMM) powered web agent that can complete user \nL25: instructions end-to-end by interacting with real-world websites. Moreover, we \nL26: propose a new evaluation protocol for web agents to address the challenges of \nL27: automatic evaluation of open-ended web agent tasks, leveraging the robust \nL28: multimodal comprehension capabilities of GPT-4V. We create a new benchmark by \nL29: gathering real-world tasks from 15 widely used websites to evaluate our agents. \nL30: We show that WebVoyager achieves a 55.7% task success rate, significantly \nL31: surpassing the performance of both GPT-4 (All Tools) and the WebVoyager (text-\nL32: only) setups, underscoring the exceptional capability of WebVoyager in practical\nL33: applications. We found that our proposed automatic evaluation achieves 85.3% \nL34: agreement with human judgment, paving the way for further development of web \nL35: agents in a real-world setting.111Our code and data will be released at \nL36: https://github.com/MinorJerry/WebVoyager\nL37: \nL38: ## 1 Introduction\nL39: \nL40: The recent advancement of large language models (LLMs), such as ChatGPT and \nL41: GPT-4 (OpenAI, 2023), have sparked significant interest in developing LLM-based \nL42: autonomous agents (AutoGPT, 2022) for complex task execution (Qin et al., 2023; \nL43: Schick et al., 2023). Recent studies have explored the construction of text-\nL44: based web browsing environments and how to instruct large language model agents \nL45: to perform web navigation (Nakano et al., 2021; Gur et al., 2023; Zhou et al., \nL46: 2023; Lu et al., 2023). The primary challenge in these works lies in managing \nL47: complex and verbose HTML texts, and solutions include simplifying and \nL48: structuring HTML (Nakano et al., 2021; Zhou et al., 2023; Gur et al., 2023; Deng\nL49: et al., 2023).\nL50: \nL51: However, existing approaches overlook a critical functionality of browsing: \nL52: rendering HTML into visual webpages. Particularly, vision capability is crucial \nL53: for utilizing tools like web browsers, as rendered web pages are inherently \nL54: designed with user experience (UX), emphasizing intuitive information and \nL55: structured presentation. This design principle of rendering makes visual \nL56: analysis more effective than mere HTML representation. At present, large \nL57: multimodal models (LMMs), particularly GPT-4V(ision) (OpenAI, 2023) and Gemini \nL58: (Team et al., 2023), demonstrate a remarkable ability to integrate intricate \nL59: visual cues with textual information. Existing studies such as Pix2Struct (Lee \nL60: et al., 2023) and WebArena (Zhou et al., 2023), have initiated explorations into\nL61: using screenshots as inputs for decision-making in web navigation, yet these \nL62: are preliminary and do not represent a deep exploration. Therefore, building \nL63: multimodal web agents to leverage the environment rendered by browsers through \nL64: screenshots, thus mimicking human web browsing behavior, is now a viable \nL65: approach to enhance web navigation efficiency.\nL66: \nL67: We introduce WebVoyager, a multimodal web agent designed to handle web tasks \nL68: online in an end-to-end manner, which denotes managing the process from start to\nL69: finish autonomously without intermediate human intervention. We construct an \nL70: online environment using Selenium for WebVoyager, feeding it with screenshots \nL71: and textual content in interactive web elements. Inspired by Set-of-Mark \nL72: Prompting (Yang et al., 2023a), we mark interactive web elements on screenshots \nL73: (see Figure 2) to facilitate decision-making for WebVoyager. As a pioneer in \nL74: combining vision and text information during web navigation, we advocate that \nL75: autonomous end-to-end task completion, multimodal capabilities and online \nL76: navigation constitute the essential trajectory toward the genuine intelligence \nL77: of web agents.\nL78: \nL79: Another challenge arises when it comes to evaluating an end-to-end web agent \nL80: with online navigation. Existing benchmarks, such as Mind2Web (Deng et al., \nL81: 2023), primarily focus on stepwise and offline evaluation, where agents follow \nL82: predefined “golden” trajectory for action selection. This approach, however, may\nL83: not fully account for the variety of viable strategies to accomplish a task, as\nL84: it only reflects one possible plan. This limitation could lead to a biased \nL85: evaluation and difficulties in fairly comparing different methods. To more \nL86: accurately gauge the capabilities of web agents in end-to-end task completion, \nL87: we save screenshots throughout the online navigation process, and then use \nL88: GPT-4V to evaluate these trajectories and the final results automatically. Human\nL89: evaluations are also conducted to verify the results and confirm the \nL90: reliability of GPT-4V as the evaluator.\nL91: \nL92: We conduct evaluations on a collected dataset, which is semi-automatically \nL93: generated using a self-instruct (Wang et al., 2022) method, comprising 300 web \nL94: tasks from 15 commonly accessed websites. Additionally, we extract 90 web-\nL95: related tasks of level 1 and level 2 from the GAIA (Mialon et al., 2023) to \nL96: enrich our evaluation. We compare our WebVoyager with 1) GPT-4 (All \nL97: Tools)222GPT-4 (All Tools) is an integrated tool-based agent released by OpenAI \nL98: in Oct. 2023. See https://chat.openai.com/, and 2) WebVoyager in a text-only \nL99: setting, employing the accessibility tree proposed in WebArena (Zhou et al., \nL100: 2023) to describe web pages. The results show that WebVoyager achieves a Task \nL101: Success Rate of 55.7%, significantly outperforming GPT-4 (All Tools) with a rate\nL102: of 32.7% and the text-only setting with a rate of 39.0%, demonstrating the \nL103: effectiveness of our method. Furthermore, we report the consistency between \nL104: human-human and human-GPT4V to ensure credibility. Our main contributions are as\nL105: follows:\nL106: \nL107: - •\nL108: We employ a multimodal web agent that integrates textual and visual information \nL109: to address web tasks end-to-end and introduce a generalist planning approach for\nL110: navigation.\nL111: \nL112: - •\nL113: We build an online web browsing environment, offering a variety of tasks \nL114: centered on widely used websites and introducing a method for expanding these \nL115: tasks.\nL116: \nL117: - •\nL118: We conduct manual evaluations of navigation trajectories and propose an \nL119: automated evaluation protocol using GPT-4V. We present a comprehensive analysis \nL120: of the evaluation results and show that GPT-4V can serve as a reliable evaluator\nL121: for online agents.\nL122: \nL123: ## 2 Related Work\nL124: \nL125: ### 2.1 Web Agents\nL126: \nL127: Autonomous web navigation (Shi et al., 2017; Yang et al., 2023b) requires an \nL128: agent to follow instructions, construct plans, comprehend complex web \nL129: structures, and decompose tasks into step-by-step decisions (Weng, 2023). To ", + "tool_call_id": "fixture-tool-call-30-0" + }, + { + "role": "tool", + "content": "[25] (https://ar5iv.org/abs/2307.13854)\n**viewing lines [0 - 124] of 1092**\n\nL0: \nL1: URL: https://ar5iv.org/abs/2307.13854\nL2: # WebArena: A Realistic Web Environment for Building Autonomous Agents\nL3: \nL4: Shuyan Zhou Frank F. Xu11footnotemark: 1 Hao Zhu Xuhui Zhou22footnotemark: 2\nL5: \nL6: Robert Lo22footnotemark: 2 Abishek Sridhar22footnotemark: 2 Xianyi Cheng Tianyue\nL7: Ou\nL8: Yonatan Bisk Daniel Fried Uri Alon Graham Neubig\nL9: Carnegie Mellon University\nL10: {shuyanzh, fangzhex, gneubig}@cs.cmu.edu\nL11: Lead contributors.Equal contribution.\nL12: \nL13: ###### Abstract\nL14: \nL15: With advances in generative AI, there is now potential for autonomous agents to \nL16: manage daily tasks via natural language commands. However, current agents are \nL17: primarily created and tested in simplified synthetic environments, leading to a \nL18: disconnect with real-world scenarios. In this paper, we build an environment for\nL19: language-guided agents that is highly realistic and reproducible. Specifically,\nL20: we focus on agents that perform tasks on the web, and create an environment \nL21: with fully functional websites from four common domains: e-commerce, social \nL22: forum discussions, collaborative software development, and content management. \nL23: Our environment is enriched with tools (e.g., a map) and external knowledge \nL24: bases (e.g., user manuals) to encourage human-like task-solving. Building upon \nL25: our environment, we release a set of benchmark tasks focusing on evaluating the \nL26: functional correctness of task completions. The tasks in our benchmark are \nL27: diverse, long-horizon, and designed to emulate tasks that humans routinely \nL28: perform on the internet. We experiment with several baseline agents, integrating\nL29: recent techniques such as reasoning before acting. The results demonstrate that\nL30: solving complex tasks is challenging: our best GPT-4-based agent only achieves \nL31: an end-to-end task success rate of 14.41%, significantly lower than the human \nL32: performance of 78.24%. These results highlight the need for further development \nL33: of robust agents, that current state-of-the-art large language models are far \nL34: from perfect performance in these real-life tasks, and that WebArena can be used\nL35: to measure such progress.\nL36: \nL37: Our code, data, environment reproduction resources, and video demonstrations are\nL38: publicly available at https://webarena.dev/.\nL39: \nL40: ## 1 Introduction\nL41: \nL42: Autonomous agents that perform everyday tasks via human natural language \nL43: commands could significantly augment human capabilities, improve efficiency, and\nL44: increase accessibility. Nonetheless, to fully leverage the power of autonomous \nL45: agents, it is crucial to understand their behavior within an environment that is\nL46: both authentic and reproducible. This will allow measurement of the ability of \nL47: agents on tasks that human users care about in a fair and consistent manner.\nL48: \nL49: Current environments for evaluate agents tend to over-simplify real-world \nL50: situations. As a result, the functionality of many environments is a limited \nL51: version of their real-world counterparts, leading to a lack of task diversity \nL52: (Shi et al., 2017; Anderson et al., 2018; Gordon et al., 2018; Misra et al., \nL53: 2016; Shridhar et al., 2020; 2021; Yao et al., 2022a). In addition, these \nL54: simplifications often lower the complexity of tasks as compared to their \nL55: execution in the real world (Puig et al., 2018; Shridhar et al., 2020; Yao et \nL56: al., 2022a). Finally, some environments are presented as a static resource (Shi \nL57: et al., 2017; Deng et al., 2023) where agents are confined to accessing only \nL58: those states that were previously cached during data collection, thus limiting \nL59: the breadth and diversity of exploration. Dor evaluation, many environments \nL60: focus on comparing the textual surface form of the predicted action sequences \nL61: with reference action sequences, disregarding the functional correctness of the \nL62: executions and possible alternative solutions (Puig et al., 2018; Jernite et \nL63: al., 2019; Xu et al., 2021; Li et al., 2020; Deng et al., 2023). These \nL64: limitations often result in a discrepancy between simulated environments and the\nL65: real world, and can potentially impact the generalizability of AI agents to \nL66: successfully understand, adapt, and operate within complex real-world \nL67: situations.\nL68: \nL69: We introduce WebArena, a realistic and reproducible web environment designed to \nL70: facilitate the development of autonomous agents capable of executing tasks (§2).\nL71: An overview of WebArena is in Figure 1. Our environment comprises four fully \nL72: operational, self-hosted web applications, each representing a distinct domain \nL73: prevalent on the internet: online shopping, discussion forums, collaborative \nL74: development, and business content management. Furthermore, WebArena incorporates\nL75: several utility tools, such as map, calculator, and scratchpad, to best support\nL76: possible human-like task executions. Lastly, WebArena is complemented by an \nL77: extensive collection of documentation and knowledge bases that vary from general\nL78: resources like English Wikipedia to more domain-specific references, such as \nL79: manuals for using the integrated development tool (Fan et al., 2022). The \nL80: content populating these websites is extracted from their real-world \nL81: counterparts, preserving the authenticity of the content served on each \nL82: platform. We deliver the hosting services using Docker containers with gym-APIs \nL83: (Brockman et al., 2016), ensuring both the usability and the reproducibility of \nL84: WebArena.\nL85: \nL86: Along with WebArena, we release a ready-to-use benchmark with 812 long-horizon \nL87: web-based tasks (§3). Each task is described as a high-level natural language \nL88: intent, emulating the abstract language usage patterns typically employed by \nL89: humans (Bisk et al., 2019). Two example intents are shown in the upper left of \nL90: Figure 1. We focus on evaluating the functional correctness of these tasks, \nL91: i.e., does the result of the execution actually achieve the desired goal (§3.2).\nL92: For instance, to evaluate the example in Figure 2, our evaluation method \nL93: verifies the concrete contents in the designated repository. This evaluation is \nL94: not only more reliable (Zhong et al., 2017; Chen et al., 2021; Wang et al., \nL95: 2022) than comparing the textual surface-form action sequences (Puig et al., \nL96: 2018; Deng et al., 2023) but also accommodate a range of potential valid paths \nL97: to achieve the same goal, which is a ubiquitous phenomenon in sufficiently \nL98: complex tasks.\nL99: \nL100: We use this benchmark to evaluate several agents that can follow NL command and \nL101: perform web-based tasks (§4). These agents are implemented in a few-shot in-\nL102: context learning fashion with powerful large language models (LLMs) such as \nL103: GPT-4 and PALM-2. Experiment results show that the best GPT-4 agent performance \nL104: is somewhat limited, with an end-to-end task success rate of only 14.41%, while \nL105: the human performance is 78.24%. We hypothesize that the limited performance of \nL106: current LLMs stems from a lack of crucial capabilities such as active \nL107: exploration and failure recovery to successfully perform complex tasks (§5.2). \nL108: These outcomes underscore the necessity for further development towards robust \nL109: and effective agents (LeCun, 2022) in WebArena.\nL110: \nL111: ## 2 WebArena: Websites as an Environment for Autonomous Agents\nL112: \nL113: Our goal is to create a realistic and reproducible web environment. We achieve \nL114: reproducibility by making the environment standalone, without relying on live \nL115: websites. This circumvents technical challenges such as bots being subject to \nL116: CAPTCHAs, unpredictable content modifications, and configuration changes, which \nL117: obstruct a fair comparison across different systems over time. We achieve \nL118: realism by using open-source libraries that underlie many in-use sites from \nL119: several popular categories and importing data to our environment from their \nL120: real-world counterparts.\nL121: \nL122: ### 2.1 Controlling Agents through High-level Natural Language\nL123: \nL124: The WebArena environment is denoted asℰ\\mathcal{E} with state space ", + "tool_call_id": "fixture-tool-call-30-1" + }, + { + "role": "tool", + "content": "[26] (https://ar5iv.org/abs/2311.12983)\n**viewing lines [0 - 118] of 1207**\n\nL0: \nL1: URL: https://ar5iv.org/abs/2311.12983\nL2: 1]FAIR, Meta 2]HuggingFace 3]AutoGPT 4]GenAI, Meta\nL3: \nL4: # GAIA: A Benchmark for General AI Assistants\nL5: \nL6: Grégoire Mialon Clémentine Fourrier Craig Swift Thomas Wolf Yann LeCun Thomas \nL7: Scialom [ [ [ [ {gmialon,tscialom}@meta.com clementine@huggingface.co\nL8: \nL9: ###### Abstract\nL10: \nL11: We introduce GAIA, a benchmark for General AI Assistants that, if solved, would \nL12: represent a milestone in AI research. GAIA proposes real-world questions that \nL13: require a set of fundamental abilities such as reasoning, multi-modality \nL14: handling, web browsing, and generally tool-use proficiency. GAIA questions are \nL15: conceptually simple for humans yet challenging for most advanced AIs: we show \nL16: that human respondents obtain 92% vs. 15% for GPT-4 equipped with plugins. This \nL17: notable performance disparity contrasts with the recent trend of LLMs \nL18: outperforming humans on tasks requiring professional skills in e.g. law or \nL19: chemistry. GAIA’s philosophy departs from the current trend in AI benchmarks \nL20: suggesting to target tasks that are ever more difficult for humans. We posit \nL21: that the advent of Artificial General Intelligence (AGI) hinges on a system’s \nL22: capability to exhibit similar robustness as the average human does on such \nL23: questions. Using GAIA’s methodology, we devise 466 questions and their answer. \nL24: We release our questions while retaining answers to 300 of them to power a \nL25: leader-board hereby accessible.\nL26: \nL27: \\correspondence\nL28: \nL29: ## 1 Introduction\nL30: \nL31: Large Language Models (LLMs) arguably open the way to general purpose systems. \nL32: Indeed, the latest among them (OpenAI, 2023; Anthropic, 2023; Anil et al., 2023;\nL33: Touvron et al., 2023) are fluent, knowledgeable, aligned to some extent with \nL34: human preferences (Ouyang et al., 2022), and can be augmented (Mialon et al., \nL35: 2023) with tools such as web browsers or code interpreters in a zero or few-shot\nL36: setting (Brown et al., 2020). However, evaluating these systems is an open \nL37: problem: given their emerging new capabilities, LLMs are regularly breaking AI \nL38: benchmarks, at an ever-increasing rate (Kiela et al., 2023).\nL39: \nL40: In search for more challenging benchmarks, current trend suggests to seek tasks \nL41: that are ever more difficult for humans, and challenge LLMs with more intricate \nL42: educational assessments, for example in STEM and Law, or target more complex \nL43: realisations, such as writing a coherent book. But, tasks that are difficult for\nL44: humans are not necessarily difficult for recent systems: the challenging MMLU \nL45: or GSM8k benchmarks for example (Hendrycks et al., 2021; Cobbe et al., 2021) are\nL46: already close to be solved,111GPT4 does 86.4% on MMLU. Human non-specialist \nL47: accuracy on the benchmark is only 34.5% Expert-level human performance is \nL48: estimated at 89.8%. due to rapid LLM improvement possibly combined with data \nL49: contamination.222See for example the case of Hellaswag. Furthermore, open-ended \nL50: generation generally requires human or model-based evaluation (Zheng et al., \nL51: 2023). Human evaluation will become less and less feasible when increasing the \nL52: task complexity, e.g. in terms of output length or required skills: how to \nL53: evaluate a book generated by an AI, or solutions to maths problems that few \nL54: people in the world can solve? Model-based evaluations on the other hand are by \nL55: construction dependent of stronger models hence cannot evaluate new state-of-\nL56: the-art models, without mentioning potential subtle biases such as preferring \nL57: the first choice presented (Zheng et al., 2023). Overall, evaluating new AI \nL58: systems requires to rethink benchmarks (Chollet, 2019).\nL59: \nL60: Alternatively to tasks that are harder for humans, AI systems could be asked to \nL61: solve conceptually simple tasks yet that require accurate execution of complex \nL62: sequences of actions, with large combinatorial spaces. The output could only be \nL63: obtained upon successful completion of the task and be easy to validate, \nL64: analogous to the Proof of Work algorithm (Jakobsson and Juels, 1999; Dwork and \nL65: Naor, 1993), where a computer is asked to solve a complex problem whose solution\nL66: is easy to verify. Tasks for AI assistants, given their need for access to a \nL67: diverse and uncertain world, meet this criterion while being inherently rooted \nL68: in practical use cases.\nL69: \nL70: We move in that direction by proposing GAIA, a benchmark for General AI \nL71: Assistants featuring 466 carefully crafted questions and their answer, along \nL72: with the associated design methodology. Our questions are easy to create, \nL73: challenging for AI systems—for LLMs, most require complex generations—, yet \nL74: admit a unique, factual answer, allowing a simple and robust automatic \nL75: evaluation.\nL76: \nL77: GAIA attempts to avoid current pitfalls of LLMs evaluation by targeting:\nL78: \nL79: Real-world and challenging questions. For example, a LLM will typically need to \nL80: browse the open and changing web, handle multi-modality, or reason over multiple\nL81: steps to answer our questions. Conversely, many LLM benchmarks are quite \nL82: specific and/or restricted to closed and synthetic environments.\nL83: \nL84: Easy interpretability through conceptually simple tasks—non experts annotators \nL85: exhibit a near perfect score—, associated reasoning trace, and few but highly \nL86: curated questions. This is in contrast with aggregated benchmarks that can lack \nL87: efficiency and reliability (Perlitz et al., 2023).\nL88: \nL89: Non-gameability. Answering the questions requires successful completion of some \nL90: number of steps, which cannot easily be brute forced due to their diversity. The\nL91: possibility to check the reasoning trace, the accuracy required in the answers,\nL92: their absence in plain text from the internet prevent a possible data \nL93: contamination. In contrast, multiple choice answers (e.g., MMLU) make \nL94: contamination assessment more difficult since a wrong reasoning trace can more \nL95: easily get to the correct choice.\nL96: \nL97: Simplicity of use. Crucially, the answers to our questions are factoid, concise \nL98: and unambiguous. These properties allow simple, fast and factual evaluation. Our\nL99: questions are meant to be answered in zero shot, limiting the influence of the \nL100: evaluation setup. By opposition, many LLM benchmarks require evaluations that \nL101: are sensitive to the experimental setup such as the number and nature of prompts\nL102: (Liang et al., 2022b) (Section 8.2), or the benchmark \nL103: implementation.333https://huggingface.co/blog/evaluating-mmlu-leaderboard\nL104: \nL105: In spite of being successful at tasks that are difficult for humans, the most \nL106: capable LLMs do poorly on GAIA. Even equipped with tools, GPT4 does not exceed a\nL107: 30% success rate for the easiest of our tasks, and 0% for the hardest. In the \nL108: meantime, the average success rate for human respondents is 92%. Consequently, a\nL109: system capable of solving GAIA can be assessed in the context of t-AGI,444As \nL110: defined in https://www.alignmentforum.org/posts/BoA3agdkAzL6HQtQP/clarifying-\nL111: and-predicting-agi, a t-AGI beats, on most tasks, most human experts who are \nL112: given time t to perform the task noting that humans typically take between 6 \nL113: minutes for the simplest questions to 17 minutes for the most complex ones. From\nL114: a related perspective, such system would arguably be a competent General AI \nL115: within the framework recently proposed in Morris et al. (2023), which also \nL116: appear to be the next milestone in AI research since ChatGPT (OpenAI, 2023) is \nL117: one level below. This paper covers the composition of GAIA, its design choices, \nL118: and explain how to craft questions and the associated challenges so that the ", + "tool_call_id": "fixture-tool-call-30-2" + }, + { + "role": "tool", + "content": "Error fetching URL `https://github.com/sambanova/toolbench`", + "tool_call_id": "fixture-tool-call-30-3" + }, + { + "role": "tool", + "content": "Error fetching URL `https://hotpotqa.github.io/`", + "tool_call_id": "fixture-tool-call-30-4" + }, + { + "role": "tool", + "content": "[27] (https://aclanthology.org/N18-1074.pdf)\n**viewing lines [0 - 134] of 809**\n\nL0: \nL1: URL: https://aclanthology.org/N18-1074.pdf\nL2: ===== Page 1 =====\nL3: \nL4: # FEVER: a large-scale dataset for Fact Extraction and VERification\nL5: \nL6: James Thorne\\(^1\\), Andreas Vlachos\\(^1\\), Christos Christodoulopoulos\\(^2\\), \nL7: and Arpit Mittal\\(^2\\)\nL8: \nL9: \\(^1\\)Department of Computer Science, University of Sheffield \nL10: \\(^2\\)Amazon Research Cambridge \nL11: {j.thorne, a.vlachos}@sheffield.ac.uk \nL12: {chrchrs, mitarpit}@amazon.co.uk \nL13: \nL14: ## Abstract\nL15: \nL16: In this paper we introduce a new publicly available dataset for verification \nL17: against textual sources, FEVER: Fact Extraction and VERification. It consists of\nL18: 185,445 claims generated by altering sentences extracted from Wikipedia and \nL19: subsequently verified without knowledge of the sentence they were derived from. \nL20: The claims are classified as Supported, Refuted or NotEnoughInfo by annotators \nL21: achieving 0.6841 in Fleiss \\(\\kappa\\). For the first two classes, the annotators\nL22: also recorded the sentence(s) forming the necessary evidence for their \nL23: judgment. To characterize the challenge of the dataset presented, we develop a \nL24: pipeline approach and compare it to suitably designed oracles. The best accuracy\nL25: we achieve on labeling a claim accompanied by the correct evidence is 31.87%, \nL26: while if we ignore the evidence we achieve 50.91%. Thus we believe that FEVER is\nL27: a challenging testbed that will help stimulate progress on claim verification \nL28: against textual sources.\nL29: \nL30: ## 1 Introduction\nL31: \nL32: The ever-increasing amounts of textual information available combined with the \nL33: ease in sharing it through the web has increased the demand for verification, \nL34: also referred to as fact checking. While it has received a lot of attention in \nL35: the context of journalism, verification is important for other domains, e.g. \nL36: information in scientific publications, product reviews, etc.\nL37: \nL38: In this paper we focus on verification of textual claims against textual \nL39: sources. When compared to textual entailment (TE)/natural language inference \nL40: (Dagan et al., 2009; Bowman et al., 2015), the key difference is that in these \nL41: tasks the passage to verify each claim is given, and in recent years it \nL42: typically consists a single sentence, while in verification systems it is \nL43: retrieved from a large set of documents in order to form the evidence. Another \nL44: related task is question answering (QA), for which approaches have recently been\nL45: extended to handle large-scale resources such as Wikipedia (Chen et al., 2017).\nL46: However, questions typically provide the information needed to identify the \nL47: answer, while information missing from a claim can often be crucial in \nL48: retrieving refuting evidence. For example, a claim stating \"Fiji's largest \nL49: island is Kauai.\" can be refuted by retrieving \"Kauai is the oldest Hawaiian \nL50: Island.\" as evidence.\nL51: \nL52: Progress on the aforementioned tasks has benefited from the availability of \nL53: large-scale datasets (Bowman et al., 2015; Rajpurkar et al., 2016). However, \nL54: despite the rising interest in verification and fact checking among researchers,\nL55: the datasets currently used for this task are limited to a few hundred claims. \nL56: Indicatively, the recently conducted Fake News Challenge (Pomerleau and Rao, \nL57: 2017) with 50 participating teams used a dataset consisting of 300 claims \nL58: verified against 2,595 associated news articles which is orders of magnitude \nL59: smaller than those used for TE and QA.\nL60: \nL61: In this paper we present a new dataset for claim verification, FEVER: Fact \nL62: Extraction and VERification. It consists of 185,445 claims manually verified \nL63: against the introductory sections of Wikipedia pages and classified as \nL64: Supported, Refuted or NotEnoughInfo. For the first two classes, systems and \nL65: annotators need to also return the combination of sentences forming the \nL66: necessary evidence supporting or refuting the claim (see Figure 1). The claims \nL67: were generated by human annotators extracting claims from Wikipedia and mutating\nL68: them in a variety of ways, some of which were meaning-altering. The \nL69: verification of each\nL70: \nL71: 809\nL72: \nL73: Proceedings of NAACL-HLT 2018, pages 809–819\nL74: \nL75: New Orleans, Louisiana, June 1 - 6, 2018. ©2018 Association for Computational \nL76: Linguistics\nL77: \nL78: ===== Page 2 =====\nL79: \nL80: claim was conducted in a separate annotation process by annotators who were \nL81: aware of the page but not the sentence from which original claim was extracted \nL82: and thus in 31.75% of the claims more than one sentence was considered \nL83: appropriate evidence. Claims require composition of evidence from multiple \nL84: sentences in 16.82% of cases. Furthermore, in 12.15% of the claims, this \nL85: evidence was taken from multiple pages.\nL86: \nL87: To ensure annotation consistency, we developed suitable guidelines and user \nL88: interfaces, resulting in inter-annotator agreement of 0.6841 in Fleiss (Fleiss, \nL89: 1971) in claim verification classification, and 95.42% precision and 72.36% \nL90: recall in evidence retrieval.\nL91: \nL92: To characterize the challenges posed by FEVER we develop a pipeline approach \nL93: which, given a claim, first identifies relevant documents, then selects \nL94: sentences forming the evidence from the documents and finally classifies the \nL95: claim w.r.t. evidence. The best performing version achieves 31.87% accuracy in \nL96: verification when requiring correct evidence to be retrieved for claims \nL97: Supported or Refuted, and 50.91% if the correctness of the evidence is ignored, \nL98: both indicating the difficulty but also the feasibility of the task. We also \nL99: conducted oracle experiments in which components of the pipeline were replaced \nL100: by the gold standard annotations, and observed that the most challenging part of\nL101: the task is selecting the sentences containing the evidence. In addition to \nL102: publishing the data via our website1, we also publish the annotation interfaces2\nL103: and the baseline system3 to stimulate further research on verification.\nL104: \nL105: Footnote 1: http://fever.ai\nL106: \nL107: Footnote 2: https://github.com/awslabs/fever\nL108: \nL109: Footnote 3: https://github.com/sheffieldnlp/fever-baselines\nL110: \nL111: ## 2 Related Works\nL112: \nL113: Vlachos and Riedel (2014) constructed a dataset for claim verification \nL114: consisting of 106 claims, selecting data from fact-checking websites such as \nL115: PolitiFact, taking advantage of the labelled claims available there. However, in\nL116: order to develop claim verification components we typically require the \nL117: justification for each verdict, including the sources used. While this \nL118: information is usually available in justifications provided by the journalists, \nL119: they are not in a machine-readable form. Thus, also considering the small number\nL120: of claims, the task defined by the dataset proposed remains too challenging for\nL121: the ML/NLP methods currently available. Wang (2017) extended this approach by \nL122: including all 12.8K claims available by Politifact via its API, however the \nL123: justification and the evidence contained in it was ignored in the experiments as\nL124: it was not machine-readable. Instead, the claims were classified considering \nL125: only the text and the metadata related to the person making the claim. While \nL126: this rendered the task amenable to current NLP/ML methods, it does not allow for\nL127: verification against any sources and no evidence needs to be returned to \nL128: justify the verdicts.\nL129: \nL130: The Fake News challenge (Pomerleau and Rao, 2017) modelled verification as \nL131: stance classification: given a claim and an article, predict whether the article\nL132: supports, refutes, observes (neutrally states the claim) or is irrelevant to \nL133: the claim. It consists of 50K labelled claim-article pairs, combining 300 claims\nL134: with 2,582 articles. The claims and the articles were curated and labeled by ", + "tool_call_id": "fixture-tool-call-30-5" + }, + { + "role": "tool", + "content": "[28] (https://nlp.cs.washington.edu/triviaqa/)\n**viewing lines [0 - 48] of 48**\n\nL0: \nL1: URL: https://nlp.cs.washington.edu/triviaqa/\nL2: ## TriviaQA: A Large Scale Dataset for Reading Comprehension and Question \nL3: Answering\nL4: \nL5: TriviaQA is a reading comprehension dataset containing over 650K question-\nL6: answer-evidence triples. TriviaQA includes 95K question-answer pairs authored by\nL7: trivia enthusiasts and independently gathered evidence documents, six per \nL8: question on average, that provide high quality distant supervision for answering\nL9: the questions. The details can be found in our ACL 17 paper TriviaQA: A Large \nL10: Scale Distantly Supervised Challenge Dataset for Reading Comprehension\nL11: \nL12: Mandar Joshi, Eunsol Choi, Daniel Weld, Luke Zettlemoyer. TriviaQA: A Large \nL13: Scale Distantly Supervised Challenge Dataset for Reading Comprehension\nL14: \nL15: In Association for Computational Linguistics (ACL) 2017, Vancouver, Canada.\nL16: [bib]\nL17: \nL18: ## News\nL19: \nL20: ##### Jul 2017\nL21: \nL22: The TriviaQA leaderboard is now live on Codalab. Submit your predictions for \nL23: evaluation on the test set!\nL24: \nL25: ## Data\nL26: \nL27: If you are interested in the reading comprehension task motivated in the paper, \nL28: click on the link below to download the data.\nL29: \nL30: If you are interested in open domain QA, click on the link below to download the\nL31: data. It contains the unfiltered dataset with 110K question-answer pairs. The \nL32: Wikipedia and top 10 search documents can be obtained from the RC version. The \nL33: main difference between the RC version above and the unfiltered dataset is that \nL34: not all documents (in the unfiltered set) for a given question contain the \nL35: answer string(s). This makes the unfiltered dataset more appropriate for IR-\nL36: style QA.\nL37: \nL38: The University of Washington does not own the copyright of the questions and \nL39: documents included in TriviaQA.\nL40: \nL41: ## Code\nL42: \nL43: Check out our Github repository.\nL44: \nL45: ## Contact\nL46: \nL47: For any questions about the code or data, please contact Mandar Joshi -- {first \nL48: name of the first author}90[at]cs[dot]washington[dot]edu", + "tool_call_id": "fixture-tool-call-30-6" + }, + { + "role": "tool", + "content": "Error fetching URL `https://research.google/pubs/natural-questions-a-benchmark-for-question-answering-research/`", + "tool_call_id": "fixture-tool-call-30-7" + }, + { + "role": "tool", + "content": "Error fetching URL `https://microsoft.github.io/msmarco/`", + "tool_call_id": "fixture-tool-call-30-8" + }, + { + "role": "tool", + "content": "Error fetching URL `https://github.com/beir-cellar/beir`", + "tool_call_id": "fixture-tool-call-30-9" + }, + { + "role": "tool", + "content": "Error fetching URL `https://project-miracl.github.io/`", + "tool_call_id": "fixture-tool-call-30-10" + }, + { + "role": "tool", + "content": "[29] (https://ar5iv.org/abs/2207.01206)\n**viewing lines [0 - 122] of 1392**\n\nL0: \nL1: URL: https://ar5iv.org/abs/2207.01206\nL2: # WebShop: Towards Scalable Real-World Web Interaction with Grounded Language \nL3: Agents\nL4: \nL5: Shunyu Yao111Equal contribution. Project site with code, data, and demos: \nL6: https://webshop-pnlp.github.io. Howard Chen111Equal contribution. Project site \nL7: with code, data, and demos: https://webshop-pnlp.github.io. John Yang Karthik \nL8: Narasimhan\nL9: \nL10: Department of Computer Science, Princeton University\nL11: {shunyuy, howardchen, jy1682, karthikn}@princeton.edu\nL12: \nL13: ###### Abstract\nL14: \nL15: Existing benchmarks for grounding language in interactive environments either \nL16: lack real-world linguistic elements, or prove difficult to scale up due to \nL17: substantial human involvement in the collection of data or feedback signals. To \nL18: bridge this gap, we develop WebShop – a simulated e-commerce website environment\nL19: with million real-world products and 1.181.18 crowd-sourced text instructions. \nL20: Given a text instruction specifying a product requirement, an agent needs to \nL21: navigate multiple types of webpages and issue diverse actions to find, \nL22: customize, and purchase an item. WebShop provides several challenges for \nL23: language grounding including understanding compositional instructions, query \nL24: (re-)formulation, comprehending and acting on noisy text in webpages, and \nL25: performing strategic exploration. We collect over 12,08712,087 human \nL26: demonstrations for the task, and train and evaluate a diverse range of agents \nL27: using reinforcement learning, imitation learning, and pre-trained image and \nL28: language models. Our best model achieves a task success rate of 1,6001,600, \nL29: which outperforms rule-based heuristics (29%29\\%) but is far lower than human \nL30: expert performance (9.6%9.6\\%). We also analyze agent and human trajectories and\nL31: ablate various model components to provide insights for developing future \nL32: agents with stronger language understanding and decision making abilities. \nL33: Finally, we show that agents trained on WebShop exhibit non-trivial sim-to-real \nL34: transfer when evaluated on amazon.com and ebay.com , indicating the potential \nL35: value of WebShop in developing practical web-based agents that can operate in \nL36: the wild.59%59\\%\nL37: \nL38: ## 1 Introduction\nL39: \nL40: Recent advances in natural language processing (NLP) and reinforcement learning \nL41: (RL) have brought about several exciting developments in agents that can perform\nL42: sequential decision making while making use of linguistic context [30, 50, 58].\nL43: On the other hand, large-scale language models like GPT-3 [6] and BERT [11] are\nL44: excelling at traditional NLP benchmarks such as text classification, \nL45: information extraction and question answering. While the former set of tasks are\nL46: limited in their set of linguistic concepts and prove difficult to scale up, \nL47: the latter tasks usually contain static, non-interactive datasets that lack \nL48: adequate grounding to extra-linguistic concepts [4]. In order to make further \nL49: progress in building grounded language models, we believe there is a need for \nL50: scalable interactive environments that contain: (1) language elements that \nL51: reflect rich, real-world usage and are collectible at scale, and (2) task \nL52: feedback that is well-defined and automatically computable to facilitate \nL53: interactive learning, without the constant need for expensive feedback from \nL54: humans.\nL55: \nL56: The world wide web (WWW) is a massive open-domain interactive environment that \nL57: inherently satisfies the first aforementioned requirement through its \nL58: interconnected set of pages with natural text, images and interactive elements. \nL59: By being simultaneously scalable, semantic, interactive, dynamic and realistic, \nL60: the web is uniquely different from existing environments for autonomous agents \nL61: like games or 3D navigation. Moreover, the web also provides a practical \nL62: environment to deploy trained agents, with great potential for alleviating human\nL63: efforts in tedious tasks (e.g. buying products, booking appointments). While \nL64: there has been prior work on building web-based tasks, they either lack depth in\nL65: the transition and action spaces, or prove difficult to scale up. Some \nL66: benchmarks only contain either a single classification task [39, 46, 31] or \nL67: interactions containing only a handful of different pages in each episode [43]. \nL68: Others propose tasks with longer horizons but are either limited to following \nL69: hyperlinks for web navigation [36] or require human-in-the-loop feedback due to \nL70: the lack of an automated reward function [33].\nL71: \nL72: In this paper, we introduce WebShop (Figure 1) – a large-scale interactive web-\nL73: based environment for language understanding and decision making – and train \nL74: autonomous agents to complete tasks on this benchmark. With the goals of being \nL75: scalable and containing realistic language and visual elements, WebShop emulates\nL76: the task of online shopping on an e-commerce website, where the agent’s goal is\nL77: to understand a human-provided text instruction and purchase a product to match\nL78: the specifications. To do so, the agent needs to query the website’s search \nL79: engine, choose items to explore from search results, open and read their \nL80: description and details, and select the necessary options (e.g. 32 oz., red \nL81: color) before clicking the ‘Buy’ button. In order to pick the optimal product \nL82: that matches user requirements, the agent may need to view and compare various \nL83: products (including backtracking between pages), and potentially perform \nL84: multiple searches. WebShop contains over one million products scraped from \nL85: amazon.com, over thousand crowdsourced instructions, and a diverse semantic \nL86: action space of searching text queries and choosing text buttons. It is packaged\nL87: into a convenient OpenAI Gym [5] environment and can be rendered in two modes \nL88: (HTML or simple) with parallel observation spaces that are easy for human and \nL89: model respectively. Rewards are automatically computed using a combination of \nL90: programmatic matching functions that consider the attributes, type, options and \nL91: price of the chosen product, alleviating the need for human evaluation and \nL92: providing a path to scaling up interactive learning.1212\nL93: \nL94: We develop several agents to perform this task, using both reinforcement \nL95: learning (RL) and imitation learning (IL). We also leverage the latest pre-\nL96: trained language models [26, 11] for representing and generating text. Our \nL97: modular architecture includes a factorized processing of state observations and \nL98: action choices using ResNets (visual) and Transformers (text), followed by an \nL99: attention fusion layer that helps the agent contextually score each action. Our \nL100: best agent achieves an average score of (out of 62.462.4) and successfully \nL101: completes the task 100100 of the time, significantly higher than a heuristic \nL102: baseline that achieves 28.7%28.7\\% and 45.645.6, respectively. While this \nL103: demonstrates the potential for IL and RL, the agents are still much lower than \nL104: human experts, who can achieve 9.6%9.6\\% and 82.182.1 on this task.*** In our \nL105: analysis (§5.3), we observe that the task requires patience and consistency, \nL106: which is lacking in some crowdsource workers, leading to lower scores. Even with\nL107: this caveat, the gap between human performance and the model remains \nL108: significant. We perform several analyses and ablation studies to identify the \nL109: cause of this gap and find several avenues for agent improvement in the future \nL110: including more robust search generation, explicit memory modules, and better \nL111: handling of noisy web text. Finally, we also demonstrate an instance of sim-to-\nL112: real transfer by deploying agents trained with WebShop to operate on amazon.com \nL113: and ebay.com, and find that they can achieve similar performances despite search\nL114: engine and product differences, and consistently outperform the rule baseline \nL115: of using the first result returned by the commercial search engines when \nL116: directly searching the instruction texts. This demonstrates the practical \nL117: potential of our work towards developing agents that can operate autonomously on\nL118: the world wide web (WWW).59.6%59.6\\%\nL119: \nL120: ## 2 Related Work\nL121: \nL122: Reinforcement learning on the web. Nogueira and Cho [36] introduced WikiNav as a", + "tool_call_id": "fixture-tool-call-30-11" + }, + { + "role": "tool", + "content": "Error fetching URL `http://alfworld.github.io/`", + "tool_call_id": "fixture-tool-call-30-12" + }, + { + "role": "tool", + "content": "Error fetching URL `https://osu-nlp-group.github.io/Mind2Web/`", + "tool_call_id": "fixture-tool-call-30-13" + }, + { + "role": "tool", + "content": "Error fetching URL `https://github.com/web-arena-x/visualwebarena`", + "tool_call_id": "fixture-tool-call-30-14" + }, + { + "role": "tool", + "content": "[30] (https://ar5iv.org/pdf/2406.12172)\n**viewing lines [0 - 127] of 1478**\n\nL0: \nL1: URL: https://ar5iv.org/pdf/2406.12172\nL2: # Navigating the Labyrinth: Evaluating and Enhancing LLMs’ Ability to Reason \nL3: About Search Problems\nL4: \nL5: Nasim Borazjanizadeh\nL6: \nL7: Berkeley AI Research, UC Berkeley\nL8: \\AndRoei Herzig\nL9: Berkeley AI Research, UC Berkeley\nL10: \\AndTrevor Darrell\nL11: Berkeley AI Research, UC Berkeley\nL12: \\AndRogerio Feris\nL13: MIT-IBM Watson AI Lab\nL14: \\AndLeonid Karlinsky\nL15: MIT-IBM Watson AI Lab\nL16: \nL17: ###### Abstract\nL18: \nL19: Recently, Large Language Models (LLMs) attained impressive performance in math \nL20: and reasoning benchmarks. However, they still often struggle with logic problems\nL21: and puzzles that are relatively easy for humans. To further investigate this, \nL22: we introduce a new benchmark, SearchBench, containing 11 unique search problems,\nL23: each equipped with automated pipelines to generate an arbitrary number of \nL24: instances and analyze the feasibility, correctness, and optimality of LLM-\nL25: generated solutions. We show that even the most advanced LLMs fail to solve \nL26: these problems end-to-end in text, e.g., GPT4 solves only 1.4%. SearchBench \nL27: problems require considering multiple pathways to the solution as well as \nL28: backtracking, posing a significant challenge to auto-regressive models. \nL29: Instructing LLMs to generate code that solves the problem helps, but only \nL30: slightly, e.g., GPT4’s performance rises to 11.7%. In this work, we show that \nL31: in-context learning with A* algorithm implementations enhances performance. The \nL32: full potential of this promoting approach emerges when combined with our \nL33: proposed Multi-Stage-Multi-Try method, which breaks down the algorithm \nL34: implementation into two stages and verifies the first stage against unit tests, \nL35: raising GPT-4’s performance above 57%.\nL36: \nL37: \\doparttoc\\faketableofcontents\nL38: \nL39: ### 1 Introduction\nL40: \nL41: The advent of Large Language Models (LLMs) has revolutionized the field of \nL42: natural language processing, with models like Gemini[18], GPT-4[26] \nL43: demonstrating unprecedented performance on reasoning tasks such as GSM8k[8]. \nL44: However, these models still exhibit surprising failures on some intuitive \nL45: tasks[2, 30, 22] and struggle with multi-step compositional reasoning, \nL46: combinatorial problems, and planning [9, 40, 44]. Inspired by these observations\nL47: and to further investigate LLMs’ reasoning abilities, we offer a new benchmark \nL48: of search problems, SearchBench. The problems in SearchBench are combinatorial, \nL49: defined as tasks that involve finding an optimal object from a finite set of \nL50: objects, where the set of feasible solutions is either discrete or can be \nL51: reduced to a discrete set [43]. These problems are predominantly NP-hard and \nL52: necessitate systematic exploration of action paths and backtracking to \nL53: intermediate feasible states; thus, SearchBench implicitly investigates the \nL54: LLM’s capacity for non-linear reasoning.\nL55: \nL56: SearchBench has five distinct problem categories: (i) pathfinding, (ii) puzzles,\nL57: (iii) subset sum, (iv) sorting, and (v) under-determined systems; further \nL58: divided into 11 unique problem types. Each problem type is inspired by known \nL59: puzzles and combinatorial problems but augmented with modified rules and \nL60: constraints to ensure substantial differences from similar problems LLMs \nL61: encountered during their training. And the solution to each problem is a \nL62: sequence of actions leading from the initial state to the goal state, while \nL63: optimizing a cost. We generate100 instances of varying difficulty per problem \nL64: type using an automatic pipeline, resulting in 1107 problem instances total. \nL65: Each problem type in SearchBench is equipped with an automatic pipeline that \nL66: evaluates LLM-generated solutions on three dimensions: feasibility, correctness,\nL67: and optimality. Feasibility checks whether the actions taken follow the \nL68: problem’s rules; correctness verifies if a feasible solution reaches the goal \nL69: state; and optimality checks if the least cost solution was found.∼\\sim\nL70: \nL71: SearchBench is challenging to LLMs due to several factors. Firstly, natural \nL72: language is less suited for describing or updating accurate representations of \nL73: complex intermediate states. Secondly, our experiments show LLMs struggle with \nL74: exploring a combinatorial exponentially exploding state-space. Despite the fact \nL75: that some methods were developed for long-context reasoning [4, 13, 50], \nL76: SearchBench problems cannot be easily summarized [4], reasoned about [13], or \nL77: processed in parallel due to their size [50, 45]. Our findings show that even \nL78: the strongest LLMs [26] almost completely fail to solve SearchBench problems in \nL79: text-only mode.\nL80: \nL81: To provide further insights, we show that LLMs’ performance on SearchBench \nL82: improves by prompting the models to solve the problems using the A* search \nL83: algorithm [11]. A* is a heuristic-based graph traversal algorithm known for its \nL84: time efficiency and provable optimality guarantees, making it the most suitable \nL85: search algorithm for solving the problems in our benchmark. This method \nL86: leverages A*’s correctness and optimality, while offloading some of the non-\nL87: linear computations involved in searching the state-space to code execution. \nL88: Additionally, to improve the quality of generated A* codes, motivated that \nL89: ensembling helps generation quality[41, 47, 21], we introduce the Multi-Stage-\nL90: Multi-Try (MSMT) inference strategy. In the \"Multi-Try\" aspect of MSMT, before \nL91: evaluating the solution returned by the code, we first verify whether the code \nL92: generated by the model satisfies a set of unit tests: (i) it is executable; (ii)\nL93: it returns a list as output; and (iii) data type of list elements is correct. \nL94: If the code fails any of the tests, MSMT re-runs the LLM until a valid code is \nL95: generated or allowed number of attempts is exhausted. The \"Multi-Stage\" aspect \nL96: of MSMT generates the code in two steps: (i) ‘A* Implementation’ - the \nL97: implementation of an instance-agnostic A* algorithm for the problem type; and \nL98: (ii) Initialization - the instantiation of initial conditions and state \nL99: variables of the problem instance. In MSMT ’Initialization’ is generated \nL100: conditioned on the ‘A* Implementation’ (which is generated separately first and \nL101: provided in ‘Intitialization’ prompt). We demonstrate that our MSMT A* method \nL102: (Fig. 2) significantly enhances the LLMs’ ability to solve search problems, \nL103: outperforming all other prompting strategies we used to evaluate models on \nL104: SearchBench, including 0-shot text, 4-shot Chain-of-Thought (CoT)[42] text, \nL105: 0-shot code generation, and 4-shot A* prompting with the naive greedy decoding \nL106: strategy.\nL107: \nL108: To summarize, our main contributions are as follows: (i) We contribute the \nL109: SearchBench benchmark designed to assess the capability of LLMs in solving \nL110: state-based problems requiring combinatorial search; (ii) We introduce the A* \nL111: Multi-Stage-Multi-Try (MSMT) framework that leverages the flexibility and \nL112: natural language comprehension capabilities of LLMs, reliability and structured \nL113: nature of the A* algorithm, along with unitests to verify LLMs’ code generation \nL114: in order to address search problems. (iii) We show a strong advantage of MSMT A*\nL115: prompting for solving search problems robustly with LLMs, suggesting \nL116: interesting future work directions and yet leaving significant room for future \nL117: improvement on SearchBench.\nL118: \nL119: ### 2 SearchBench Benchmark\nL120: \nL121: SearchBench includes five categories of problems: puzzles, subset sum, sorting, \nL122: pathfinding, and under-determined systems, further subdivided into 11 problem \nL123: types and instantiated as 1107 natural language problem instances. Each problem \nL124: type requires a different algorithmic approach to solve. Our problems are \nL125: inspired by well-known (mostly NP-hard) puzzles and search problems, though \nL126: their rules were modified to ensure a solved instance of a problem similar to \nL127: any SearchBench problem was not observed during LLMs’ massive internet-scale ", + "tool_call_id": "fixture-tool-call-30-15" + }, + { + "role": "tool", + "content": "Error fetching URL `https://ar5iv.labs.ar5iv.org/html/2312.15820`", + "tool_call_id": "fixture-tool-call-30-16" + }, + { + "role": "tool", + "content": "Error fetching URL `https://github.com/jxwuyi/WebNav`", + "tool_call_id": "fixture-tool-call-30-17" + }, + { + "role": "assistant", + "content": "针对搜索智能体(Search Agent)的评估,已涌现出多个侧重点不同的基准测试。以下将分类介绍这些基准,涵盖其核心特点、典型使用场景及例题。\n\n---\n\n## 🔍 1. 网页浏览/导航类基准\n这类基准要求智能体在真实或模拟的网页环境中执行多步操作,以完成复杂任务。\n\n| 基准 | 核心特点 | 使用场景 | 例题 |\n|------|----------|----------|------|\n| **BrowseComp**(OpenAI)【3†L16-L35】 | - 1,266 个挑战性问题,需持续导航、查找难以直接搜索的信息
- 问题设计保证人类在 10 分钟内无法解决,且现有模型(包括带浏览的 ChatGPT)也难以完成
- 答案为短字符串,便于自动验证 | 评估智能体在复杂网络浏览任务中的持久性、创造性搜索能力 | “找出 2018‑2023 年间在 EMNLP 会议上发表、第一作者本科毕业于达特茅斯学院、第四作者本科毕业于宾夕法尼亚大学的科学论文标题”【3†L45-L49】 |\n| **WebVoyager**【24†L18-L33】 | - 多模态(视觉+文本)端到端网页智能体基准
- 包含 300 个真实任务,覆盖 15 个常用网站(如 Amazon、Wikipedia)
- 使用 GPT‑4V 进行自动评估,与人工评估一致性达 85.3% | 评估智能体在真实网站上的端到端任务完成能力(多模态输入) | “在 Amazon 上找到‘无线蓝牙耳机’并按价格从低到高排序,然后选择最便宜的一款加入购物车” |\n| **WebArena**【25†L15-L35】 | - 完全可交互的、高度真实的网页环境(4 个领域:电商、论坛、协作开发、内容管理)
- 812 项长视野任务,评估功能正确性而非表面动作匹配
- 最佳 GPT‑4 智能体成功率仅 14.41%,远低于人类的 78.24% | 评估自主智能体在真实网页环境中的复杂任务执行能力 | “在电商网站中购买一双‘黑色运动鞋’,价格低于 $80,并将收货地址修改为纽约”【25†L86-L92】 |\n| **Mind2Web**【19†L4-L14】 | - 首个面向通用网页智能体的数据集,任务来自真实网站
- 支持在线评估(Online‑Mind2Web),包含 300 个实时任务、136 个流行网站
- 强调多步指令跟随与复杂交互 | 评估智能体在任意网站上的通用任务完成能力 | “在 GitHub 上创建一个名为 ‘test‑repo’ 的公开仓库,并添加一个 README.md 文件” |\n| **VisualWebArena**【20†L4-L10】 | - 多模态网页智能体基准,任务依赖视觉信息(截图)
- 包含丰富的视觉接地任务,如识别图片、按钮位置等
- 评估智能体在真实视觉网页环境中的表现 | 评估多模态智能体在视觉丰富的网页任务中的能力 | “在纽约时报网站找到今天头条新闻的配图,并描述图中主要物体” |\n| **WebVLN**【22†L4-L11】 | - 网站上的视觉‑语言导航任务,智能体需根据问题指令在网站中导航
- 数据集 WebVLN‑v1 包含多种网站导航场景
- 评估导航成功率与任务完成度 | 评估智能体在网站中基于视觉与语言指令的导航能力 | “找到‘联系我们’页面,并提取其中的邮箱地址” |\n| **WebNav**【23†L4-L11】 | - 目标驱动的网页导航基准,环境为部分观测
- 智能体需理解自然语言指令并在网页中规划动作序列
- 早期经典任务,常用于测试基础导航能力 | 评估智能体在部分观测网页环境中的规划与指令跟随能力 | “导航到 Wikipedia 上关于‘人工智能’的页面,并摘录第一段内容” |\n\n---\n\n## 🛠️ 2. 工具使用/通用助手类基准\n这类基准测试智能体调用外部工具(如搜索、计算、API)来解决实际问题的能力。\n\n| 基准 | 核心特点 | 使用场景 | 例题 |\n|------|----------|----------|------|\n| **GAIA**【26†L10-L25】 | - 466 个真实世界问题,需要推理、多模态处理、网页浏览、工具使用等综合能力
- 对人类简单(92% 准确率),但对最强 GPT‑4 仅 15%
- 答案唯一、事实性,支持自动评估 | 评估通用 AI 助手在需要多步工具使用的真实任务中的表现 | “2022 年法国总统选举的获胜者出生在哪个城市?”【26†L10-L25】 |\n| **ToolBench**(OpenBMB)【7†L4-L10】 | - 包含多样软件工具的真实任务基准
- 提供易用的基础设施(虚拟 API 服务器等)
- 评估 LLM 调用工具完成复杂任务的能力 | 评估智能体在工具调用与组合上的熟练度 | “使用计算器 API 计算 (3.14 * 7.2)^2,然后使用搜索 API 查找‘圆周率’的最新精确值” |\n| **HSCodeComp**【2†L16-L35】 | - 首个针对深度搜索智能体在分层规则应用上的专家级电商基准
- 任务:根据嘈杂的商品描述预测 10 位 HS 编码(关税规则)
- 632 个商品条目,标注由多位电商专家完成,最佳智能体准确率仅 46.8%(人类 95%) | 评估智能体在复杂规则(如关税分类)下的深层推理与搜索能力 | “给定商品描述‘男士纯棉针织T恤,领口为V领’,预测其 10 位 HS 编码”【2†L24-L32】 |\n\n---\n\n## ❓ 3. 问答/检索类基准\n这类基准主要评估智能体在开放域或特定领域的信息检索与问答能力。\n\n| 基准 | 核心特点 | 使用场景 | 例题 |\n|------|----------|----------|------|\n| **HotpotQA**【8†L4-L10】 | - 113k 个基于 Wikipedia 的多跳问答对
- 需要跨多个文档推理,并提供句子级支持事实
- 涵盖多样的问题类型(比较、列举、因果等) | 评估模型在需要多跳推理的开放域问答中的能力 | “《了不起的盖茨比》的作者还写过哪些小说?” |\n| **FEVER**【27†L16-L28】 | - 185,445 个基于 Wikipedia 的声明,需分类为 Supported/Refuted/NotEnoughInfo
- 要求提供证据句子(可多句、多页面)
- 挑战性高(最佳系统仅 31.87% 准确率) | 评估系统在事实核查与证据检索上的能力 | “斐济最大的岛屿是考艾岛。”【27†L48-L50】 |\n| **TriviaQA**【28†L5-L10】 | - 超过 650k 个问答‑证据三元组,问题由琐事爱好者编写
- 每个问题平均提供 6 篇证据文档,适合远程监督阅读
- 包含阅读理解和开放域 QA 两种任务设置 | 评估模型在开放域琐事问答中的检索与阅读理解能力 | “哪位演员在《星球大战:原力觉醒》中扮演凯洛·伦?” |\n| **Natural Questions**【11†L4-L8】 | - 来自 Google 搜索的真实用户问题,答案来自 Wikipedia
- 包含长答案(段落)和短答案(实体/日期等)
- 训练集 307k,开发/测试集各 8k | 评估开放域问答系统对真实用户查询的响应能力 | “谁写了《傲慢与偏见》?”(答案:“简·奥斯汀”) |\n| **MS MARCO**【12†L4-L10】 | - 基于 Bing 真实搜索查询的检索/问答基准
- 包含段落排序、文档排序、问答等任务
- 数据规模大(约 500k 查询,880 万段落) | 训练与评估检索模型在真实搜索场景下的表现 | “如何冲泡咖啡?”(系统需检索相关段落回答) |\n| **BEIR**【13†L4-L10】 | - 异构零样本检索基准,包含 18 个数据集、9 种任务(事实核查、QA、对话检索等)
- 评估模型在未见任务上的泛化能力
- 已成为检索模型的标准评估套件 | 评估检索模型在零样本设置下的跨任务泛化能力 | 给定查询“全球变暖的主要原因”,在文档集合中检索相关文档 |\n| **MIRACL**【14†L4-L10】 | - 多语言检索基准,覆盖 18 种语言(包括中文、阿拉伯语等)
- 专注于 ad‑hoc 检索,所有查询与文档均经过人工标注
- 提供多语言检索挑战(WSDM 2023 Cup) | 评估多语言检索模型在不同语言下的检索效果 | 中文查询:“全球变暖的原因”,检索相关中文文档 |\n\n---\n\n## 🧩 4. 组合搜索/推理类基准\n这类基准专门测试智能体在组合搜索问题上的逻辑推理与规划能力。\n\n| 基准 | 核心特点 | 使用场景 | 例题 |\n|------|----------|----------|------|\n| **SearchBench**【30†L22-L36】 | - 11 种独特的搜索问题类型(路径查找、谜题、子集和、排序、欠定系统等)
- 自动生成任意数量实例,并评估解决方案的可行性、正确性、最优性
- GPT‑4 在纯文本模式下仅解决 1.4%,通过 A* 代码生成提升至 11.7% | 评估 LLM 在组合搜索问题上的推理与规划能力 | “在一个 5x5 迷宫中,从起点 (0,0) 到终点 (4,4) 的最短路径是什么?(障碍物位置已知)”【30†L56-L64】 |\n\n---\n\n## 🛒 5. 特定领域/仿真环境基准\n这类基准针对特定领域(如电商、家庭环境)设计,测试智能体在仿真环境中的交互能力。\n\n| 基准 | 核心特点 | 使用场景 | 例题 |\n|------|----------|----------|------|\n| **WebShop**【29†L15-L30】 | - 模拟电商网站环境,包含 118 万真实商品、12,087 条众包指令
- 智能体需根据文本指令搜索、筛选、购买商品
- 奖励自动计算,最佳模型成功率 29%(人类 59%) | 评估语言智能体在电商任务中的导航、搜索与决策能力 | “找到一款价格低于 $50 的无线蓝牙耳机,评分至少 4 星,并加入购物车”【29†L20-L25】 |\n| **ALFWorld**【18†L4-L10】 | - 将文本环境(TextWorld)与具身环境(ALFRED)对齐的基准
- 智能体先在文本中学习抽象策略,然后在视觉环境中执行
- 包含 134 个家庭任务(如“在厨房找到杯子并拿到客厅”) | 评估智能体在文本与具身环境之间的迁移学习与规划能力 | “在厨房找到‘杯子’,把它拿到客厅放在桌子上” |\n\n---\n\n## 📊 总结\n以上基准覆盖了搜索智能体的不同能力维度:从**网页浏览、工具使用、问答检索**,到**组合搜索推理**以及**特定领域仿真**。选择合适的基准取决于您要评估的具体能力(如多模态交互、工具调用、多跳推理等)。这些基准大多提供公开的数据集与评估代码,可直接用于实验。\n\n> 注:本文基于 2025 年 11 月前的公开资料整理,部分基准可能已有更新版本,建议访问对应官网获取最新信息。", + "reasoning_content": "在同时打开多个基准测试相关页面后,我们获得了大量信息。WebVoyager 是一个多模态网页智能体基准测试,专注于在真实网站上进行端到端任务执行,其特点包括多模态输入(屏幕截图和文本)、在15个常用网站上的300个任务,以及使用GPT-4V进行自动评估。使用场景是评估网页智能体在真实世界中的导航和任务完成能力。例题可能是在亚马逊上找到特定产品并将其加入购物车。\n\nWebArena 是一个真实且可复现的网页环境,用于评估自主智能体。其特点包括四个功能完整的网站(电子商务、论坛、协作开发、内容管理)、812项长视野任务,以及专注于功能正确性的评估。使用场景是评估智能体在复杂网页任务中的表现。例题可能是在电子商务网站上购买符合特定标准的商品。\n\nGAIA 是一个针对通用AI助手的基准测试,要求具备推理、多模态处理、网页浏览和工具使用等能力。其特点包括466个问题,对人类来说概念简单但对AI具有挑战性,答案简洁明确,支持自动评估。使用场景是评估AI助手在真实世界问题中的表现。例题可能是“找出2022年法国总统选举中获胜者的出生城市”。\n\nToolBench 是一个用于评估LLM工具操作能力的基准测试。其特点包括多样化的软件工具、真实世界任务,以及易于使用的基础设施。使用场景是评估智能体使用工具完成任务的能力。例题可能是使用计算器API计算表达式,或使用搜索API查找信息。\n\nHotpotQA 是一个多跳问答基准测试,其特点包括基于Wikipedia的113k个问题-答案对,需要多文档推理,并提供句子级别的支持事实。使用场景是评估模型在回答需要多跳推理的问题时的能力。例题可能是“《了不起的盖茨比》的作者还写过哪些小说?”。\n\nFEVER 是一个事实提取和验证基准测试,其特点包括185,445个基于Wikipedia的声明,需要分类为支持、反驳或信息不足,并需要提供证据句子。使用场景是评估系统验证声称事实的能力。例题可能是“斐济最大的岛屿是考艾岛”,系统需要检索证据来支持或反驳这一说法。\n\nTriviaQA 是一个阅读理解基准测试,其特点包括超过65万个问题-答案-证据三元组,问题由琐事爱好者编写,并附带证据文档。使用场景是评估模型在开放域问答中的能力。例题可能是“哪位演员在《星球大战:原力觉醒》中扮演凯洛·伦?”。\n\nNatural Questions 是一个问答基准测试,其特点包括来自谷歌搜索的真实问题,答案来自Wikipedia,并包含长答案和短答案。使用场景是评估开放域问答系统。例题可能是“谁写了《傲慢与偏见》?”,答案可能是“简·奥斯汀”。\n\nMS MARCO 是一个用于深度学习搜索的基准测试,其特点包括来自Bing的真实搜索查询、人工生成的答案,以及用于段落排序和文档排序的任务。使用场景是训练和评估检索模型。例题可能是“如何冲泡咖啡?”,系统需要检索相关段落来回答。\n\nBEIR 是一个用于零样本评估的异构信息检索基准测试,其特点包括18个数据集、9项任务,涵盖事实核查、问答、对话检索等。使用场景是评估检索模型在零样本设置下的泛化能力。例题可能是给定一个查询,在文档集合中检索相关文档。\n\nMIRACL 是一个多语言信息检索基准测试,其特点涵盖18种语言,专注于ad-hoc检索,并包含人工标注的相关性判断。使用场景是评估多语言检索模型。例题可能是用中文查询“全球变暖的原因”,并检索相关文档。\n\nWebShop 是一个模拟电子商务网站环境,用于评估语言智能体。其特点包括118万个真实产品、12,087条众包文本指令,以及需要导航、搜索和购买的任务。使用场景是评估智能体在电子商务任务中的表现。例题可能是“找到一款价格低于50美元的无线蓝牙耳机”。\n\nALFWorld 是一个将文本和具身环境对齐的基准测试,用于在文本世界中训练智能体,然后将其转移到具身环境中。其特点包括基于文本的环境,与ALFRED数据集并行,并专注于指令跟随。使用场景是评估智能体在家庭任务中的规划能力。例题可能是“在厨房里找到一个杯子,把它拿到客厅”。\n\nMind2Web 是一个用于评估网页通用智能体的基准测试,其特点包括来自真实网站的任务、需要多步交互,并专注于语言指令跟随。使用场景是评估智能体在复杂网页任务中的表现。例题可能是“在GitHub上创建一个新的仓库,命名为‘test-repo’”。\n\nVisualWebArena 是一个用于多模态智能体的基准测试,其特点包括真实视觉网页任务、需要理解视觉和文本元素,并评估功能正确性。使用场景是评估多模态智能体在视觉接地网页任务中的表现。例题可能是“在纽约时报网站上找到今天头条新闻的图片”。\n\nSearchBench 是一个用于评估LLM解决搜索问题能力的基准测试,其特点包括11种搜索问题类型、自动生成实例,并评估可行性、正确性和最优性。使用场景是评估LLM在组合搜索问题中的推理能力。例题可能是“在一个迷宫中找到从起点到终点的最短路径”。\n\nWebVLN 是一个网站上的视觉语言导航基准测试,其特点包括基于问题的指令、需要导航网站,并评估导航成功率。使用场景是评估智能体在网站中基于视觉和语言指令进行导航的能力。例题可能是“找到‘联系我们’页面并提取电子邮件地址”。\n\nWebNav 是一个用于目标驱动网页导航的基准测试,其特点包括部分观察环境、需要规划和理解自然语言指令。使用场景是评估智能体在网页导航中的能力。例题可能是“导航到维基百科上关于‘人工智能’的页面”。\n\n现在,我需要整理一份针对搜索智能体的基准测试列表。考虑到用户可能更关注中文资料,我会优先考虑那些有中文描述的基准测试,例如BrowseComp、HSCodeComp、GAIA、WebVoyager、WebArena、Mind2Web、WebShop、ALFWorld、HotpotQA、FEVER、TriviaQA、Natural Questions、MS MARCO、BEIR、MIRACL、SearchBench、WebVLN、WebNav。我会为每个基准测试提供特点、使用场景和例题。同时,我会引用相关来源,例如HSCodeComp论文、BrowseComp博客文章、WebVoyager论文、WebArena论文、GAIA论文、WebShop论文、SearchBench论文等。我的回答将用中文呈现,并确保内容全面、结构清晰。现在开始撰写回答。", + "tool_calls": [] + } + ] +} diff --git a/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_input_search_wo_date.json b/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_input_search_wo_date.json new file mode 100644 index 00000000000..f05891138c0 --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_input_search_wo_date.json @@ -0,0 +1,533 @@ +{ + "messages": [ + { + "role": "developer", + "content": "In a year before 2023, between December 1 and 15 (inclusive), a press release reported that as of the 5th of that month, a country had received a total of 16,201,670 vaccine doses, with 7,583,134 administered. That same year, a company's annual report revealed that its nearly 2.2 million members and supporters contributed over $49 million, with 12% of total revenue coming from contributed services. The following year, the annual report stated that the organization had played a role in rescinding a previous administration’s plan to open most of the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas drilling and that contributed services had decreased by 2% to 10%. Who was the emeritus board member that year? Name only, no titles.", + "tools": [ + { + "type": "function", + "function": { + "name": "search", + "description": "Searches for information related to query and displays topn results.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query string" + }, + "topn": { + "type": "integer", + "description": "Number of top results to display", + "default": 10 + }, + "source": { + "type": "string", + "description": "Source to search within", + "enum": [ + "web", + "news" + ], + "default": "web" + } + }, + "required": [ + "query" + ], + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + } + }, + { + "type": "function", + "function": { + "name": "open", + "description": "Opens the link id from the page indicated by cursor starting at line number loc, showing num_lines lines. Valid link ids are displayed with the formatting: 【{id}†.*】. If cursor is not provided, the most recent page is implied. If id is a string, it is treated as a fully qualified URL associated with source. If loc is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available. Use this function without id to scroll to a new location of an opened page.", + "parameters": { + "type": "object", + "properties": { + "id": { + "description": "Link ID (number) or fully qualified URL (string)", + "anyOf": [ + { + "type": "integer" + }, + { + "type": "string" + } + ], + "default": -1 + }, + "cursor": { + "type": "integer", + "description": "Page cursor indicator", + "default": -1 + }, + "loc": { + "type": "integer", + "description": "Starting line number", + "default": -1 + }, + "num_lines": { + "type": "integer", + "description": "Number of lines to show", + "default": -1 + }, + "view_source": { + "type": "boolean", + "description": "Whether to view source", + "default": false + }, + "source": { + "type": "string", + "description": "Source associated with the URL", + "enum": [ + "web", + "news" + ], + "default": "web" + } + }, + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + } + }, + { + "type": "function", + "function": { + "name": "find", + "description": "Finds exact matches of pattern in the current page, or the page given by cursor.", + "parameters": { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "The exact pattern to search for" + }, + "cursor": { + "type": "integer", + "description": "Page cursor to search in", + "default": -1 + } + }, + "required": [ + "pattern" + ], + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + } + }, + { + "type": "function", + "function": { + "name": "python", + "description": "Use this tool to execute Python code. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files). When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds.", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The python code to execute in the jupyter environment." + } + }, + "required": [ + "code" + ], + "additionalProperties": false, + "$schema": "http://json-schema.org/draft-07/schema#" + } + } + } + ] + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "The user asks: \"In a year before 2023, between December 1 and 15 (inclusive), a press release reported that as of the 5th of that month, a country had received a total of 16,201,670 vaccine doses, with 7,583,134 administered. That same year, a company's annual report revealed that its nearly 2.2 million members and supporters contributed over $49 million, with 12% of total revenue coming from contributed services. The following year, the annual report stated that the organization had played a role in rescinding a previous administration’s plan to open most of the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas drilling and that contributed services had decreased by 2% to 10%. Who was the emeritus board member that year? Name only, no titles.\"\n\nWe need to parse the puzzle.\n\nWe have clues about a press release about vaccine doses for a country, dated between Dec 1-15 inclusive, reporting as of 5th of that month total doses received 16,201,670 and administered 7,583,134. That is likely a COVID-19 vaccine update for some country. Which country? Might be a smaller country. Could be something like \"Uganda\", \"Zimbabwe\", \"Kenya\", \"Philippines\", etc. Let's search.\n\nSearch for \"16,201,670 vaccine doses\" \"7,583,134 administered\".", + "tool_calls": [ + { + "id": "fixture-tool-call-1-0", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"\\\"16,201,670\\\" vaccine doses\", \"topn\": 10, \"source\": \"web\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[0] \"16,201,670\" vaccine doses (Search_Results/\"16,201,670\" vaccine doses)\n**viewing lines [0 - 38] of 38**\n\nL0: \nL1: URL: Search_Results/\"16,201,670\" vaccine doses\nL2: # Search Results\nL3: \nL4: * 【0†Kenya's Economy is Showing Resilience as Output Rises ...; publish_date: \nL5: none†www.worldbank.org】 Dec 14, 2021 — As of December 5, 2021, Kenya had \nL6: received a total of 16,201,670 vaccines, with 7,583,134 administered. While \nL7: vaccine acceptance is ...\nL8: * 【1†MINISTRY OF HEALTH; publish_date: none†covidhub.mediacouncil.or.ke】 Dec \nL9: 1, 2021 — Total Doses Received 16,201,670. Table 10 gives the total vaccines \nL10: received since the start of Covid -19 vaccination exercise in the country.\nL11: * 【2†Output Result Page; publish_date: none†open.unicef.org】 ... 16,201,670 \nL12: doses of multiple vaccines nationwide and full vaccination of 15.5 per cent with\nL13: two doses of COVID-19 vaccine as of 31 December 2021.\nL14: * 【3†rebased GDP; publish_date: none†documents1.worldbank.org】 Dec 7, 2021 — \nL15: As of December 5, 2021,. Kenya had received a total of 16,201,670 vaccines, \nL16: with. 7,583,134 administered. Vaccine acceptance is reportedly high.\nL17: * 【4†Integrated Annual Report; publish_date: none†www.co-opbank.co.ke】 May 27,\nL18: 2022 — ... doses of Covid-19 vaccines and administered close to 17 million ... \nL19: 16,201,670, huku 7,583,134 zikiwa tayari zimedungwa watu. Bado kuna ...\nL20: * 【5†World Bank lifts Kenya's growth prospect to 5% in 2021; publish_date: \nL21: none†www.africa-press.net】 ... 16,201,670 vaccines, with 7,583,134 administered.\nL22: While vaccine acceptance is reportedly high there is still a long way to go \nL23: towards the government's ...\nL24: * 【6†2020 Annual Report United States; publish_date: none†www.hi-us.org】 \nL25: including medicine, vaccines, Covid-19 prevention kits, and kits of ... \nL26: 16,201,670. 9,630,584. Prepaid expenses. 56,707. 22,601. Total current assets.\nL27: * 【7†HOSPITAL CENTRAL DE LA DEFENSA; publish_date: \nL28: none†contrataciondelestado.es】 16201670 X. URPA-URPA. 0. HCD-000393/00. KONTRON.\nL29: 7000 ... LASER SHOT LBP-1120. CNEUR-C. NEUROFISIOLOGÍA. 0 ... VICTOREEN CT HEAD\nL30: DOSE PHANTON. 43650-014.\nL31: * 【8†body metal concentrations: Topics by ...; publish_date: \nL32: none†www.science.gov】 PMID:16201670. Design Characteristics and Tobacco Metal \nL33: Concentrations in Filtered Cigars. PubMed. Caruso, Rosalie V; O'Connor, Richard \nL34: J; Travers, Mark J ...\nL35: * 【9†Public Offices, Personal Demands: Capability in ...; publish_date: \nL36: none†dokumen.pub】 ... (16201670).9 The Institutiones Politicae consists of two \nL37: books. Book one deals with the origin, nature and goal of the commonwealth on \nL38: the one hand, and the ...", + "tool_call_id": "fixture-tool-call-1-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Result shows Kenya. So the press release is about Kenya's vaccine doses as of Dec 5, 2021. The press release likely by World Bank or something. The search result shows \"World Bank lifts Kenya's growth prospect...\" maybe a press release dated Dec 14, 2021. But the clue says \"between December 1 and 15 (inclusive), a press release reported that as of the 5th of that month, a country had received a total of 16,201,670 vaccine doses, with 7,583,134 administered.\" That matches Kenya on Dec 5, 2021. So the year is 2021.\n\nNow the next clue: \"That same year, a company's annual report revealed that its nearly 2.2 million members and supporters contributed over $49 million, with 12% of total revenue coming from contributed services.\"\n\nThis sounds like an environmental or conservation organization. \"Nearly 2.2 million members and supporters contributed over $49 million.\" That might be something like \"The Nature Conservancy\", \"World Wildlife Fund\", \"Sierra Club\", \"Environmental Defense Fund\", \"Natural Resources Defense Council\", etc. Let's search.\n\nSearch for \"nearly 2.2 million members and supporters contributed over $49 million\" \"12% of total revenue coming from contributed services\".", + "tool_calls": [ + { + "id": "fixture-tool-call-3-0", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"\\\"nearly 2.2 million members\\\" \\\"contributed over $49 million\\\"\", \"topn\": 10, \"source\": \"web\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[1] \"nearly 2.2 million members\" \"contributed over $49 million\" (Search_Results/\"nearly 2.2 million members\" \"contributed over $49 million\")\n**viewing lines [0 - 41] of 41**\n\nL0: \nL1: URL: Search_Results/\"nearly 2.2 million members\" \"contributed over $49 million\"\nL2: # Search Results\nL3: \nL4: * 【0†The Global Good Fund, Microsoft and Goodwill join forces ...; \nL5: publish_date: none†globalgoodfund.org】 Sep 28, 2022 — ... more than 2.2 million \nL6: people. They additionally contribute close to $49 billion in local, state and \nL7: federal tax revenues. When COVID hit ...\nL8: * 【1†Almost 22 billion American tax dollars spent to wipe out a ...; \nL9: publish_date: none†www.facebook.com】 US military funding for Israel's war crimes\nL10: in Lebanon and Gaza has now cost US taxpayers over $22 billion. When millions \nL11: struggle to afford the ...\nL12: * 【2†Corporate America has largely abandoned its post-January ...; \nL13: publish_date: none†www.citizensforethics.org】 Jul 29, 2025 — Since the January 6\nL14: insurrection, over 2,000 corporate and industry group PACs have given over $174\nL15: million to members of the Sedition ...\nL16: * 【3†Audit shows millions in questionable taxpayer spending at ...; \nL17: publish_date: none†www.aol.com】 18 hours ago — ... nearly doubled from 1.3 \nL18: million to about 2.2 million. That is more than one in four Washington state \nL19: residents receiving Medicaid, and the ...\nL20: * 【4†Incarceration and Poverty in the United States - AAF; publish_date: \nL21: none†www.americanactionforum.org】 Jun 30, 2020 — The United States currently \nL22: incarcerates 2.2 million people, nearly half of whom are non-violent drug \nL23: offenders, accused people held pre ...\nL24: * 【5†How the NRA Spends Revenue (2023); publish_date: none†paddockpost.com】 \nL25: Apr 25, 2025 — In 2023, the NRA reported total revenue of $176 million (compared\nL26: to $211 million in 2022, $227 million in 2021, $282 million in 2020, $292 ...\nL27: * 【6†NCA Newsroom; publish_date: none†www.ncausa.org】 ... 2.2 million U.S. \nL28: jobs—operating in every U.S. state and territory and contributing nearly $350 \nL29: billion to the U.S. economy every year. For media inquiries ...\nL30: * 【7†The Early Childhood Educator Workforce; publish_date: \nL31: none†cscce.berkeley.edu】 About 2.2 million adults are paid to care for and \nL32: educate more than 9.7 million children between birth and age five in center-\nL33: based and home-based settings.\nL34: * 【8†Bringing Out the Worst in Congress: CRA By the Numbers ...; publish_date:\nL35: none†progressivereform.org】 Jul 14, 2025 — In fact, the top ten recipients (out\nL36: of 25) account for almost $7 million, showing how these contributions are \nL37: targeted and strategic. Real ...\nL38: * 【9†Club 100: Over 100 Americans gave more than $100000 to ...; publish_date:\nL39: none†publicintegrity.org】 Jan 7, 2022 — Through Internet investments during the\nL40: 1990s, he attracted 2.2 million online accounts valued at $174 million. Schwab \nL41: has been a vocal ...", + "tool_call_id": "fixture-tool-call-3-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Not that. Try search for \"2.2 million members and supporters contributed over $49 million\".", + "tool_calls": [ + { + "id": "fixture-tool-call-5-0", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"\\\"2.2 million members\\\" \\\"contributed over $49\\\"\", \"topn\": 10, \"source\": \"web\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[2] \"2.2 million members\" \"contributed over $49\" (Search_Results/\"2.2 million members\" \"contributed over $49\")\n**viewing lines [0 - 2] of 2**\n\nL0: \nL1: URL: Search_Results/\"2.2 million members\" \"contributed over $49\"\nL2: # Search Results", + "tool_call_id": "fixture-tool-call-5-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Maybe it's an annual report of a non-profit. Search for \"annual report 2.2 million members contributed $49 million\".", + "tool_calls": [ + { + "id": "fixture-tool-call-7-0", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"\\\"2.2 million members\\\" \\\"annual report\\\" $49 million\", \"topn\": 10, \"source\": \"web\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[3] \"2.2 million members\" \"annual report\" $49 million (Search_Results/\"2.2 million members\" \"annual report\" $49 million)\n**viewing lines [0 - 36] of 36**\n\nL0: \nL1: URL: Search_Results/\"2.2 million members\" \"annual report\" $49 million\nL2: # Search Results\nL3: \nL4: * 【0†20-F; publish_date: none†www.sec.gov】 ANNUAL REPORT PURSUANT TO SECTION \nL5: ... Our membership grew from 2.1 million members as at December 31, 2023 to 2.2 \nL6: million members as at December 31, 2024.\nL7: * 【1†Oportun Reports Fourth Quarter and Full Year 2023 ...; publish_date: \nL8: none†investor.oportun.com】 Mar 12, 2024 — Oportun (Nasdaq: OPRT) is a mission-\nL9: driven fintech that puts its 2.2 million members' financial goals within reach. \nL10: ... annual report on ...\nL11: * 【2†2 0 21; publish_date: none†www.annualreports.com】 ANNUAL REPORT. 2. 0. \nL12: 21. 2. 0. 21. Page 2. 2. DEFENDERS OF WILDLIFE. 2. 0. 21. 2. 0. 21 ... In 2021, \nL13: Defenders of Wildlife's nearly 2.2 million members and.\nL14: * 【3†Annual report and accounts 2020; publish_date: none†www.3i.com】 \nL15: Disclaimer. The Annual report and accounts have been prepared solely to provide \nL16: information to shareholders. ... 2.2 million members. In December 2019, we sold \nL17: ...\nL18: * 【4†united states securities and exchange commission; publish_date: \nL19: none†s21.q4cdn.com】 Dec 14, 2018 — ☒ ANNUAL REPORT ... approximately 5.7 million\nL20: Members, including approximately 2.2 million Members from the acquisition of \nL21: Best Doctors.\nL22: * 【5†securities and exchange commission; publish_date: none†www.sec.gov】 ... \nL23: Annual Report on Form 10-K filed with the SEC on March 10, 2020. The ... 2.2 \nL24: million members at June 30, 2020. The following table highlights the ...\nL25: * 【6†2022-23 Annual Comprehensive Financial Report - CalPERS; publish_date: \nL26: none†www.calpers.ca.gov】 Nov 17, 2023 — ... annual report is consistent with \nL27: these accounting principles ... 2.2 million members and health benefit services \nL28: to over 1.5 million ...\nL29: * 【7†Medical Centers Report; publish_date: \nL30: none†regents.universityofcalifornia.edu】 Oct 8, 2019 — which will allow its 2.2 \nL31: million members to receive advanced and ... annual report of the University of \nL32: California. For the Medical ...\nL33: * 【8†Working TogeTher to make HealtH care better; publish_date: \nL34: none†www.unitedhealthgroup.com】 Feb 8, 2012 — ANNUAL REPORT PURSUANT TO SECTION \nL35: 13 OR 15(d) OF THE ... approximately 2.2 million members enrolled in its \nL36: Medicare. Advantage ...", + "tool_call_id": "fixture-tool-call-7-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Result 2 shows \"Defenders of Wildlife\". That might be the organization. Let's open.", + "tool_calls": [ + { + "id": "fixture-tool-call-9-0", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 3, \"id\": 2}" + } + } + ] + }, + { + "role": "tool", + "content": "[4] www.annualreports.com (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2021.pdf)\n**viewing lines [0 - 53] of 261**\n\nL0: \nL1: URL: https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-\nL2: wildlife_2021.pdf\nL3: ===== Page 1 ===== 2021 ANNUAL REPORT ===== Page 2 ===== # DEFENDERS OF WILDLIFE\nL4: made important progress for imperiled species and vital landscapes across the \nL5: United States in 2021. \\--- **LAWYERS** immediately **challenged** the premature\nL6: and reckless decision to strip **gray wolves** of federal **Endangered Species \nL7: Act (ESA)** protections. For many decades, Defenders has led the effort to \nL8: protect and restore the gray wolf, and we will continue to fight the \nL9: unscientific and hostile anti-wolf policies that impede conservation progress \nL10: and will carry on our unrelenting battle to restore federal protections for this\nL11: iconic keystone species. \\--- **LOBBYISTS** worked around the clock to keep \nL12: wildlife and climate priorities in the **Infrastructure Investment and Jobs \nL13: Act**. We also continue fighting to keep important wildlife and habitat funding \nL14: in relevant **appropriations bills**. \\--- 2 DEFENDERS OF WILDLIFE ===== Page 3 \nL15: ===== POLICY EXPERTS pushed forward on the urgent need for a National \nL16: Biodiversity Strategy (NBS), an all-of-government approach to address the \nL17: unprecedented loss of wildlife and habitat we are experiencing. We have coupled \nL18: this with our new campaign to expand the National Wildlife Refuge System to \nL19: preserve our nation’s only lands set aside for wildlife. By defending, funding \nL20: and expanding our national wildlife refuges, we will directly address \nL21: biodiversity loss and climate change while promoting increased equitable access \nL22: to nature. FIELD TEAMS were on the ground helping to recover imperiled species. \nL23: From panthers and sea turtles in Florida to wolves, bison and black-footed \nL24: ferrets in Montana, Defenders’ conservation experts were in the field saving \nL25: wildlife all over the country. CONSERVATION INNOVATION EXPERTS provided \nL26: comprehensive analyses to guide policy and inform conservation strategies to \nL27: reach the goal of protecting 30% of our terrestrial and marine systems by 2030 \nL28: (“30x30”). Defenders’ Center for Conservation Innovation (CCI) produced a report\nL29: which details actions we need to take to achieve 30x30 while protecting \nL30: biodiversity and addressing the climate crisis. DEFENDERS.ORG ===== Page 4 =====\nL31: WE HAVE ACCOMPLISHED MUCH THIS YEAR WORKING WITH AN ADMINISTRATION THAT VALUES \nL32: SCIENCE AND CONSERVATION. Our many successes include the return of protections \nL33: to the Tongass National Forest in Alaska and the suspension of oil and gas \nL34: leasing permits for the Arctic National Wildlife Refuge. Defenders also played a\nL35: lead role in successfully defending the Migratory Bird Protection Act from \nL36: attack and securing critical protections for migratory birds. Throughout 2021 \nL37: Defenders made important progress for the wildlife and wild places we all love, \nL38: yet our nation’s wildlife still face unprecedented challenges. More and more \nL39: species, including manatees, Mexican gray wolves, polar bears and North Atlantic\nL40: right whales, face the very real threat of extinction because of climate \nL41: change, habitat loss, pollution and inadequate protections. In our work we \nL42: continue to demonstrate success by following the science, building partnerships,\nL43: leading with innovative solutions and focusing on collaborative and practical \nL44: approaches in defense of nature. Together we can make a lasting difference for \nL45: wildlife and habitats. Your support is essential to our work and we couldn’t do \nL46: it without you. Thank you for being an important part of our efforts to protect \nL47: imperiled wildlife across our nation. Jamie Rappaport Clark President and CEO \nL48: John Dayton Board Chair OUR NEW TELEVISION SERIES, WILDLIFE NATION WITH JEFF \nL49: CORWIN, premiered in October and airs every weekend on ABC stations nationwide. \nL50: This groundbreaking series takes Jeff across the country to meet Defenders’ \nL51: experts along with our science, tribal and community partners to educate and \nL52: inspire viewers to join with Defenders to protect wildlife. While highlighting \nL53: some of our important conservation work and raising awareness of important ", + "tool_call_id": "fixture-tool-call-9-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Search within PDF for \"2.2 million members\".", + "tool_calls": [ + { + "id": "fixture-tool-call-11-0", + "type": "function", + "function": { + "name": "find", + "arguments": "{\"cursor\": 4, \"pattern\": \"2.2\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[5] Find results for text: `2.2` in `www.annualreports.com` (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2021.pdf/find?pattern=2.2)\n**viewing lines [0 - 4] of 4**\n\nL0: # 【0†match at L187】\nL1: 2021, Defenders of Wildlife’s nearly 2.2 million members and supporters \nL2: contributed more than $49 million for wildlife and wildlife habitat. This figure\nL3: includes all those who generously donated their time and expertise. The \nL4: steadfast support of our donors allows Defenders to sustain our program and ", + "tool_call_id": "fixture-tool-call-11-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Open around L180.", + "tool_calls": [ + { + "id": "fixture-tool-call-13-0", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 4, \"loc\": 170}" + } + } + ] + }, + { + "role": "tool", + "content": "[6] www.annualreports.com (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2021.pdf)\n**viewing lines [170 - 215] of 261**\n\nL170: into the wild. **10 DEFENDERS OF WILDLIFE** ===== Page 11 ===== AFTER ANOTHER \nL171: SUCCESSFUL DEFENDERS LAWSUIT ON BEHALF OF THE CRITICALLY ENDANGERED RED WOLF, \nL172: FWS reversed its decision to limit the recovery area and committed to a robust \nL173: release strategy. After years of inaction, FWS released eight wolves to the wild\nL174: in North Carolina and plan to release nine more wolves in the spring of 2022. \nL175: This is an incredible win for this critically endangered species whose \nL176: population has dwindled down to single digits in the wild because of \nL177: mismanagement, vehicle strikes and poaching. DEFENDERS CONTINUED TO LEAD EFFORTS\nL178: TO PROTECT THE FLORIDA MANATEE, a beloved species that suffered the deadliest \nL179: year on record in 2021, tragically surpassing 1,000 deaths because of water \nL180: pollution and lack of warm water habitat. Defenders led advocacy and education \nL181: aimed at restoring the natural flow of the dammed Ocklawaha River, which would \nL182: provide critical warm-water habitat that manatees need to survive. Defenders’ \nL183: legal team continued to fight for manatees in the courts, holding government \nL184: agencies accountable for protecting critical habitat and addressing the \nL185: devastating water pollution that is killing the seagrass and causing manatees to\nL186: starve. DAVID TES | SAM FRENZY DRAW DEFENDERS.ORG 11 ===== Page 12 ===== In \nL187: 2021, Defenders of Wildlife’s nearly 2.2 million members and supporters \nL188: contributed more than $49 million for wildlife and wildlife habitat. This figure\nL189: includes all those who generously donated their time and expertise. The \nL190: steadfast support of our donors allows Defenders to sustain our program and \nL191: public education efforts in the field, the courts and on Capitol Hill. 2021 \nL192: SOURCES OF FUNDS Grants and contributions $29,057 Bequests, trusts and split \nL193: interests $7,692 Income from investments, annuity reserve funds and trusts \nL194: $3,354 Royalties and other income $3,576 Contributed services $6,140 Total \nL195: Revenue $49,819 USES OF FUNDS Biodiversity conservation $22,420 Constituency \nL196: mobilization $16,324 Fundraising $1,211 Management and general $5,865 Total \nL197: Expenses $45,820 Change in net assets $3,999 Net assets, start of year $41,145 \nL198: Net Assets, End of the Year $45,144 Dollars are in thousands. 12 DEFENDERS OF \nL199: WILDLIFE Grants and contributions 58% Income from investments 7% Requests, \nL200: trusts and split interests 15% Royalties and other income 7% Contributed \nL201: services 12% Program and support services 97% Fundraising 3% ===== Page 13 =====\nL202: DEFENDERS.ORG 13 ===== Page 14 ===== John Dayton, Esq.* Chair Texas Sheryl \nL203: Sachs* Vice Chair Connecticut Whitfield Bailey* Secretary Tennessee Mark Caylor*\nL204: Treasurer District of Columbia Sylvia Baca* California Donald Barry, Esq. \nL205: Oregon Dinah Bear, Esq. Arizona Kim O’Keefe Beck, Ph.D California Gabriela \nL206: Chavarria, Ph.D Washington Jamie Rappaport Clark* President & CEO Virginia \nL207: Caroline D. Gabel Maryland Keith M. Harper, Esq. Maryland John Donovan, Esq. \nL208: Montana and California James M. Hecker, Esq. Virginia Judith Ponsikoff, Ph.D* \nL209: California Terry L. Root, Ph.D Florida BOARD EMERITUS Terry C. Pelster, Esq. New\nL210: York *Executive Committee member* ST A F F & A D V I S O R S SCIENCE ADVISORS \nL211: Gabriela Chavarria, Ph.D Burke Museum, University of Washington Leah Gerber, \nL212: Ph.D Arizona State University Kathayoon Khalil, Ph.D Oregon Zoo Rurik List, Ph.D\nL213: Universidad Autónoma Metropolitana-Lerma Barry Noon, Ph.D Colorado State \nL214: University Terry L. Root, Ph.D Stanford University Mark Schwartz, Ph.D \nL215: University of California-Davis David S. Wilcox, Ph.D Princeton University Rae ", + "tool_call_id": "fixture-tool-call-13-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Scrolling further for \"Contributed services 12%\". Already seen. So the annual report is Defenders of Wildlife 2021 annual report.\n\nNext clue: \"The following year, the annual report stated that the organization had played a role in rescinding a previous administration’s plan to open most of the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas drilling and that contributed services had decreased by 2% to 10%.\"\n\nThus, the following year (2022) annual report of Defenders of Wildlife includes mention of rescinding plan to open most of 22 million acres of NPR-A to oil and gas drilling. Also contributed services decreased to 10% (from 12% in 2021). So we need to find the 2022 annual report for Defenders of Wildlife.\n\nSearch for \"Defenders of Wildlife 2022 annual report National Petroleum Reserve Alaska\".", + "tool_calls": [ + { + "id": "fixture-tool-call-15-0", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"Defenders of Wildlife 2022 annual report National Petroleum Reserve Alaska\", \"topn\": 10, \"source\": \"web\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[7] Defenders of Wildlife 2022 annual report National Petroleum Reserve Alaska (Search_Results/Defenders of Wildlife 2022 annual report National Petroleum Reserve Alaska)\n**viewing lines [0 - 37] of 37**\n\nL0: \nL1: URL: Search_Results/Defenders of Wildlife 2022 annual report National Petroleum \nL2: Reserve Alaska\nL3: # Search Results\nL4: \nL5: * 【0†CELEBRATING YEARS; publish_date: none†www.annualreports.com】 With less \nL6: than 340 right whales left in the wild, Defenders is fighting tirelessly to end \nL7: deadly lobster gear entanglements and vessel strikes that are driving ...\nL8: * 【1†Financials; publish_date: none†defenders.org】 We invite you to explore \nL9: the reports below to learn more about our activities and accomplishments, and \nL10: how we put your money to work for wildlife.\nL11: * 【2†Alaska Program Looks Back on 2022; publish_date: none†defenders.org】 Feb \nL12: 9, 2023 — Thanks to a lawsuit joined by Defenders, seven million acres were \nL13: returned to protection within the National Petroleum Reserve-Alaska (NPR-A), ...\nL14: * 【3†Defenders-of-Wildlife-2022-Financial-Statement. ...; publish_date: \nL15: none†defenders.org】 We have audited the accompanying consolidated financial \nL16: statements of Defenders of Wildlife and Affiliated Defenders of Wildlife Action \nL17: Fund (collectively, ...\nL18: * 【4†2022 Annual Report; publish_date: none†alaskaconservation.org】 Jun 13, \nL19: 2023 — In 2022, we focused on three landscapes: the Arctic. National Wildlife \nL20: Refuge, Bristol Bay, and the Tongass National Forest. In March 2022,.\nL21: * 【5†Assessment of ecological and cultural values within the ...; \nL22: publish_date: none†www.blm.gov】 This document was written to provide technical \nL23: information regarding the ecological importance of the National Petroleum \nL24: Reserve – Alaska (NPR-A). Several ...\nL25: * 【6†Accomplishments Report; publish_date: none†defenders.org】 National \nL26: Petroleum Reserve-Alaska. Identified for exceptional wildlife and cultural \nL27: values, including critical habitat for polar bears and other species ...\nL28: * 【7†2022 annual report; publish_date: none†dory-\nL29: plantain-s2zc.squarespace.com】 These projects are made possible through \nL30: collaborations with Defenders of. Wildlife, Cook Inletkeeper, Trustees for \nL31: Alaska,. Environmental Investigation Agency, ...\nL32: * 【8†23IMPACT REPORT; publish_date: none†www.annualreports.com】 Defenders of \nL33: Wildlife made incredible progress protecting wildlife and wild places in 2023, \nL34: helping shape a brighter future for imperiled species and vital ...\nL35: * 【9†From Leasing to Land Protections; publish_date: none†defenders-cci.org】 \nL36: Abstract. When you hear the name, National Petroleum Reserve - Alaska (NPR-A, or\nL37: Reserve), you might think of a barren land filled with oil wells and ...", + "tool_call_id": "fixture-tool-call-15-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Open result 0.", + "tool_calls": [ + { + "id": "fixture-tool-call-17-0", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 7, \"id\": 0}" + } + } + ] + }, + { + "role": "tool", + "content": "[8] www.annualreports.com (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2022.pdf)\n**viewing lines [0 - 53] of 289**\n\nL0: \nL1: URL: https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-\nL2: wildlife_2022.pdf\nL3: ===== Page 1 ===== 2022 Impact Report C E L E B R A T I N G Y E A R S ===== Page\nL4: 2 ===== 2022 Defenders of Wildlife made important progress for imperiled \nL5: species and vital landscapes across the United States in 2022. GRAY WOLF | JIM \nL6: GUMMERAL MAY STOCK PRIOR Lawyers successfully challenged the previous \nL7: administration’s decision to delist the gray wolf and restored critical federal \nL8: protections under the Endangered Species Act. This latest triumph in court is \nL9: part of our ongoing battle to protect and restore gray wolves throughout their \nL10: historical range and shield them from persecution by extremist legislators in \nL11: Idaho, Montana and Wyoming. TWO MORE FATALIZED GRAY SWALLETS TO SEA TO SHARE \nL12: ALLIANCE Lobbyists worked around the clock to expand funding for wildlife \nL13: conservation in the FY2022 federal spending bill, which included $31 million (a \nL14: 44% increase) for the Bureau of Land Management’s Threatened and Endangered \nL15: Species Program, $2.5 million (an 81% increase) for the U.S. Department of \nL16: Agriculture Wildlife Services’ Nonlethal Initiative to prevent human-wildlife \nL17: conflicts and $21 million (a 320% increase) for North Atlantic right whale \nL18: conservation. 2 DEFENDERS OF WILDLIFE ===== Page 3 ===== **Policy Experts** \nL19: played a crucial role in securing international trade protections for 100 \nL20: species of sharks and rays, all 158 species of glass frogs and 73 species of \nL21: reptiles, including 21 species of desert horned lizards, at the Convention on \nL22: International Trade in Endangered Species (CITES) in Panama. \\--- **Field \nL23: Teams** worked tirelessly to protect and restore imperiled species across the \nL24: country. From Florida manatees and red wolves in the Southeast to belugas and \nL25: grizzly bears in Alaska, Defenders’ conservation experts were on the ground \nL26: saving species that need our help to survive and thrive. \\--- **Conservation \nL27: Innovation Experts** published more than 10 peer-reviewed studies on topics that\nL28: include the Cook Inlet beluga whale, golden-cheeked warbler, global parrot \nL29: biodiversity, the Endangered Species Act, the effects of mountaintop removal \nL30: mining on endangered species, the ecological importance of panthers and the \nL31: implementation of “30x30” – the globally recognized goal to which President \nL32: Biden committed the U.S. to conserve 30% of our imperiled lands and waters by \nL33: 2030. \\--- **DEFENDERS.ORG** ===== Page 4 ===== THANK YOU Defenders celebrated \nL34: our 75th anniversary in 2022—an exciting milestone that we attribute to our \nL35: unwavering dedication to our wildlife conservation mission. From helping to pass\nL36: the Endangered Species Act and other bedrock conservation laws to leading the \nL37: advocacy on the reintroduction of wolves to Yellowstone National Park and other \nL38: parts of the West, Defenders is proud to inspire, lead and ultimately celebrate \nL39: some of the most consequential conservation wins for wildlife and wild places. \nL40: It is wonderfully fitting that we also mark 2022 as a year when we successfully \nL41: defended, yet again, the original focal species and symbol of our organization, \nL42: the gray wolf. While the future for gray wolves and many other species remains \nL43: uncertain, we will never stop fighting to protect the imperiled wildlife that \nL44: need us and to restore vibrant and balanced ecosystems. Our mission is more \nL45: important than ever due to the escalating biodiversity and climate crises. \nL46: Defenders is determined to build on our accomplishments and seize new \nL47: opportunities to address these challenges through our multifaceted approach to \nL48: conservation. We focus on developing innovative and science-based solutions, \nL49: coordinating on-the-ground species and habitat restoration, promoting human-\nL50: wildlife coexistence, protecting and strengthening our bedrock environmental \nL51: laws, investing in impactful collaborations and partnerships, building a more \nL52: diverse and powerful conservation movement and challenging attacks on wildlife \nL53: through advocacy, policy and litigation. Defenders’ continued success is made ", + "tool_call_id": "fixture-tool-call-17-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Search within PDF for \"National Petroleum Reserve\".", + "tool_calls": [ + { + "id": "fixture-tool-call-19-0", + "type": "function", + "function": { + "name": "find", + "arguments": "{\"cursor\": 8, \"pattern\": \"Petroleum\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[9] Find results for text: `petroleum` in `www.annualreports.com` (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2022.pdf/find?pattern=petroleum)\n**viewing lines [0 - 4] of 4**\n\nL0: # 【0†match at L201】\nL1: the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas \nL2: drilling and removed protections for important wildlife habitat. Protections \nL3: have now been restored for nearly half of the reserve’s pristine lands, which \nL4: are vital habitat for shorebirds, denning polar bears and tens of thousands of ", + "tool_call_id": "fixture-tool-call-19-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Open around L190.", + "tool_calls": [ + { + "id": "fixture-tool-call-21-0", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 8, \"loc\": 180}" + } + } + ] + }, + { + "role": "tool", + "content": "[10] www.annualreports.com (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2022.pdf)\n**viewing lines [180 - 227] of 289**\n\nL180: the sixth successful transfer of bison to the Assiniboine and Sioux Tribes of \nL181: Fort Peck since 2019. \\--- **SWIFT FIX KITS | © ROSIMA PAELARINTSKIMMA MADDIAL \nL182: 200 AND CONSERVATION BIOLOGY INSTITUTE** \\--- **Celebrating our third year** of \nL183: a collaborative program with the Aaniih and Nakoda Tribes and others to restore \nL184: swift foxes to the Fort Belknap Indian Reservation in Montana, Defenders helped \nL185: with the release of 28 more swift foxes. With over 100 foxes reintroduced \nL186: through this program, monitoring efforts show that they are reproducing in the \nL187: wild—a critical measure of success for a self-sustaining population. \\--- \nL188: **Defenders continued to lead the way** for conserving and recovering the \nL189: endangered black-footed ferret, supporting the black-footed ferret survey for \nL190: the Fort Belknap Indian community. Thirty-six ferrets were vaccinated against \nL191: sylvatic plague and two dozen kits were released in the wild. \\--- **10 \nL192: DEFENDERS OF WILDLIFE** ===== Page 11 ===== Defenders helped to bring hope for \nL193: recovery for the endangered military macaw, adding 11 fledglings to a growing \nL194: wild population in Puerta Vallarta, Mexico, that is under pressure from habitat \nL195: loss and poachers for the illegal pet trade. Accord- ing to our recent report, \nL196: the 2008 parrot trade ban that Defenders fought to achieve is working. \nL197: Preventing more than 30,000 parrots from being illegally trapped each year, the \nL198: trade ban has resulted in a 47% decrease in the illegal trade of parrots and an \nL199: 88% decrease in U.S. seizures of Mexican parrots. As a result of a Defenders \nL200: lawsuit, BLM rescinded the previous administration’s plan that opened most of \nL201: the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas \nL202: drilling and removed protections for important wildlife habitat. Protections \nL203: have now been restored for nearly half of the reserve’s pristine lands, which \nL204: are vital habitat for shorebirds, denning polar bears and tens of thousands of \nL205: caribou in the Teshekpuk caribou herd. Our new storymap illustrates why the \nL206: reserve is so important to wildlife. Defenders marked the 20th anniversary of \nL207: our Sea Otter Awareness Week. In celebration of this milestone, we sponsored \nL208: state legislation, signed by Gov. Gavin Newsom, that formally recognizes the \nL209: event’s anniversary and acknowledges the vital role that sea otters play in \nL210: California’s nearshore ecosystems. This annual celebration provides a unique \nL211: opportunity to raise awareness of the important role sea otters play in \nL212: maintaining ecosystem health, our need to coexist with sea otters and the \nL213: efforts of Defenders and others in advancing sea otter conservation. \nL214: DEFENDERS.ORG ===== Page 12 ===== FINANCIAL REPORT In 2022, Defenders of \nL215: Wildlife’s nearly 2.2 million members and supporters contributed more than $41 \nL216: million for wildlife and wildlife habitat. This figure includes all those who \nL217: generously donated their time and expertise. The steadfast support of our donors\nL218: allows Defenders to sustain our program and public education efforts in the \nL219: field, the courts and on Capitol Hill. 2022 SOURCES OF FUNDS Grants and \nL220: contributions $28,946 Bequests, trusts and split interests $13,928 Income from \nL221: investments, annuity reserve funds and trusts (-$7,495) Royalties and other \nL222: income $744 Contributed services $5,047 Total Revenue $41,170 USES OF FUNDS \nL223: Biodiversity conservation $18,970 Constituency mobilization $16,528 Fundraising \nL224: $1,029 Management and general $6,354 Total Expenses $42,881 Change in net assets\nL225: (-$1,711) Net assets, start of year $45,144 Net Assets, End of the Year $43,433\nL226: Dollars are in thousands. 12 DEFENDERS OF WILDLIFE Grants and contributions 59%\nL227: Contributed services 10% Royalties and other income 2% Bequests, trusts and ", + "tool_call_id": "fixture-tool-call-21-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Scrolling a bit more for \"Contributed services 10%\". Already seen. So indeed the 2022 annual report shows contributed services 10%, down from 12% in 2021.\n\nThus the organization is Defenders of Wildlife.\n\nThe question: \"Who was the emeritus board member that year? Name only, no titles.\"\n\nWe need to identify the emeritus board member listed in the 2022 annual report. From the 2021 annual report, we saw \"BOARD EMERITUS Terry C. Pelster, Esq.\" However, that is for 2021. The question asks \"Who was the emeritus board member that year?\" That year refers to the following year (2022). So we need to check the 2022 annual report for the emeritus board member.\n\nOpen the 2022 PDF near the board list. Search for \"Emeritus\".", + "tool_calls": [ + { + "id": "fixture-tool-call-23-0", + "type": "function", + "function": { + "name": "find", + "arguments": "{\"cursor\": 8, \"pattern\": \"Emeritus\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[11] Find results for text: `emeritus` in `www.annualreports.com` (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2022.pdf/find?pattern=emeritus)\n**viewing lines [0 - 10] of 10**\n\nL0: # 【0†match at L237】\nL1: Sisk, Ph.D British Columbia, Canada BOARD EMERITUS Terry C. Pelster, Esq. New \nL2: York *Executive Committee member* STAFF & ADVISORS SCIENCE ADVISORS Gabriela \nL3: Chavarria, Ph.D Burke Museum, University of Washington Leah Gerber, Ph.D Arizona\nL4: State University Kathayoon Khalil, Ph.D Oregon Zoo Rurik List, Ph.D Universidad\nL5: \nL6: # 【1†match at L243】\nL7: Davis Thomas D. Sisk, Ph.D Emeritus Professor Northern Arizona University David \nL8: S. Wilcox, Ph.D Princeton University Rae Wynn-Grant, Ph.D National Geographic \nL9: Society SENIOR STAFF Jamie Rappaport Clark President & CEO Thu Pham Chief of \nL10: Staff James Stofan Senior Vice President, Operations McCrystle Adams Vice ", + "tool_call_id": "fixture-tool-call-23-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Open around L230.", + "tool_calls": [ + { + "id": "fixture-tool-call-25-0", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 8, \"loc\": 220}" + } + } + ] + }, + { + "role": "tool", + "content": "[12] www.annualreports.com (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2022.pdf)\n**viewing lines [220 - 267] of 289**\n\nL220: contributions $28,946 Bequests, trusts and split interests $13,928 Income from \nL221: investments, annuity reserve funds and trusts (-$7,495) Royalties and other \nL222: income $744 Contributed services $5,047 Total Revenue $41,170 USES OF FUNDS \nL223: Biodiversity conservation $18,970 Constituency mobilization $16,528 Fundraising \nL224: $1,029 Management and general $6,354 Total Expenses $42,881 Change in net assets\nL225: (-$1,711) Net assets, start of year $45,144 Net Assets, End of the Year $43,433\nL226: Dollars are in thousands. 12 DEFENDERS OF WILDLIFE Grants and contributions 59%\nL227: Contributed services 10% Royalties and other income 2% Bequests, trusts and \nL228: split interests 29% TOTAL REVENUE Income from investments -15% Program and \nL229: support services 98% USE OF FUNDS Fundraising 2% ===== Page 13 ===== \nL230: DEFENDERS.ORG 13 ===== Page 14 ===== BOARD OF DIRECTORS John Dayton, Esq.* Chair\nL231: Texas Sheryl Sachs* Vice Chair Connecticut Dinah Bear, Esq.* Secretary Arizona \nL232: Mark Caylor* Treasurer District of Columbia Susan Alexander Connecticut Sylvia \nL233: Baca* California Whitefield Bailey Tennessee Donald Barry, Esq. Oregon Kim \nL234: O’Keefe Beck, Ph.D California Gabriela Chavarria, Ph.D Washington Jamie \nL235: Rappaport Clark* President & CEO Virginia John Donovan, Esq.* Montana and \nL236: California Caroline D. Gabel Maryland Terry L. Root, Ph.D New Mexico Thomas D. \nL237: Sisk, Ph.D British Columbia, Canada BOARD EMERITUS Terry C. Pelster, Esq. New \nL238: York *Executive Committee member* STAFF & ADVISORS SCIENCE ADVISORS Gabriela \nL239: Chavarria, Ph.D Burke Museum, University of Washington Leah Gerber, Ph.D Arizona\nL240: State University Kathayoon Khalil, Ph.D Oregon Zoo Rurik List, Ph.D Universidad\nL241: Autónoma Metropolitana-Lerma Barry Noon, Ph.D Colorado State University Terry \nL242: L. Root, Ph.D Stanford University Mark Schwartz, Ph.D University of California-\nL243: Davis Thomas D. Sisk, Ph.D Emeritus Professor Northern Arizona University David \nL244: S. Wilcox, Ph.D Princeton University Rae Wynn-Grant, Ph.D National Geographic \nL245: Society SENIOR STAFF Jamie Rappaport Clark President & CEO Thu Pham Chief of \nL246: Staff James Stofan Senior Vice President, Operations McCrystle Adams Vice \nL247: President, Conservation Law Matt Braughler Vice President, Philanthropy Rachel \nL248: Brittin Vice President, Communications Shawn Cantrell Vice President, Field \nL249: Conservation Programs Robert L. Dewey Vice President, Government Relations & \nL250: External Affairs Monica Goldberg Vice President, Landscape Conservation Sharon \nL251: Morris-Smith Vice President of Human Resources Elizabeth Nichols Vice President,\nL252: Finance/CFO Anthony Pegues Vice President, Diversity, Equity, Inclusion, \nL253: Justice Lindsay Rosa, Ph.D Vice President, Conservation Research and Innovation \nL254: Brandace Swimley Vice President, Integrated Marketing 14 DEFENDERS OF WILDLIFE \nL255: ===== Page 15 ===== There are many ways to support Defenders of Wildlife and \nL256: protect our imperiled wildlife. To learn more or donate, contact us at \nL257: 1-800-385-9712, visit Defenders.org/donate or choose one of the following \nL258: options. ➤ Transfer appreciated assets directly to Defenders and maximize \nL259: potential tax benefits. Defenders.org/ donate-stocks-help-save-wildlife ➤ Become\nL260: a Leadership Council member with a single donation of $1,000 or more and enjoy \nL261: more personal involvement with Defenders. Defenders.org/leadershipcouncil ➤ If \nL262: you are 70.5 or older, roll over a portion of your IRA to benefit wildlife. \nL263: Remember, these gifts must come directly from your IRA account. \nL264: Defenders.org/planned-giving/gifts-your-retirement-plan ➤ Get started on a \nL265: contribution from your donor-advised fund Defenders.org/donor-advised-funds ➤ \nL266: Check with your personnel office or that of your spouse/partner about your \nL267: company’s Corporate Matching Gifts program or visit us at ", + "tool_call_id": "fixture-tool-call-25-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Thus the emeritus board member is Terry C. Pelster. However, the question might ask for \"Who was the emeritus board member that year?\" Possibly referring to the year of the annual report that mentioned the rescinded plan and contributed services decrease. That is 2022. So the answer is Terry C. Pelster.\n\nBut let's verify that the 2022 annual report indeed includes the mention of contributed services decreasing to 10%. We saw that. So answer is Terry C. Pelster.\n\nBut maybe the question is tricky: \"Who was the emeritus board member that year?\" Could be the emeritus board member listed in the 2021 annual report? However, the question says \"That same year, a company's annual report revealed that its nearly 2.2 million members and supporters contributed over $49 million, with 12% of total revenue coming from contributed services. The following year, the annual report stated that the organization had played a role in rescinding a previous administration’s plan to open most of the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas drilling and that contributed services had decreased by 2% to 10%.\"\n\nThus the \"following year\" annual report is 2022. The question likely asks: \"Who was the emeritus board member that year?\" Meaning the emeritus board member listed in that annual report (2022). So answer: Terry C. Pelster.\n\nBut let's confirm that the 2021 annual report also shows emeritus board member Terry C. Pelster. Indeed we saw that. But perhaps the 2022 report might have a different emeritus board member. We saw it's the same. So answer is Terry C. Pelster.\n\nHowever, the question might be about a different organization. Let's verify the details.\n\nFirst clue: Press release about Kenya vaccine doses on Dec 5, 2021. That is likely a World Bank press release. Search for \"World Bank Kenya vaccine doses December 2021 press release\".", + "tool_calls": [ + { + "id": "fixture-tool-call-27-0", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"As of December 5, 2021, Kenya had received a total of 16,201,670 vaccines\", \"topn\": 10, \"source\": \"web\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[13] As of December 5, 2021, Kenya had received a total of 16,201,670 vaccines (Search_Results/As of December 5, 2021, Kenya had received a total of 16,201,670 vaccines)\n**viewing lines [0 - 41] of 41**\n\nL0: \nL1: URL: Search_Results/As of December 5, 2021, Kenya had received a total of \nL2: 16,201,670 vaccines\nL3: # Search Results\nL4: \nL5: * 【0†Kenya's Economy is Showing Resilience as Output Rises ...; publish_date: \nL6: none†www.worldbank.org】 Dec 14, 2021 — As of December 5, 2021, Kenya had \nL7: received a total of 16,201,670 vaccines, with 7,583,134 administered. While \nL8: vaccine acceptance is ...\nL9: * 【1†Unmet need for COVID-19 vaccination coverage in Kenya - PMC; \nL10: publish_date: none†pmc.ncbi.nlm.nih.gov】 by SK Muchiri · 2022 · Cited by 42 — As\nL11: of December 2021, six counties had a vaccination coverage of less than 5%. \nL12: These counties include Garissa, Mandera, Marsabit, Tana River, Turkana, and ...\nL13: * 【2†MINISTRY OF HEALTH; publish_date: none†covidhub.mediacouncil.or.ke】 Dec \nL14: 1, 2021 — • Total Covid-19 Vaccines Received to date- 16,201,670 ... Table 10: \nL15: Vaccine Logistics Received in the Country as at 5th, December 2021.\nL16: * 【3†COVID-19 vaccination refusal trends in Kenya over 2021 - PMC; \nL17: publish_date: none†pmc.ncbi.nlm.nih.gov】 by RT Rego · 2023 · Cited by 21 — We \nL18: assessed vaccine refusal over time in Kenya, and characterized factors \nL19: associated with changes in vaccination refusal.\nL20: * 【4†Ciheb-Kenya on the Front Lines of the COVID-19 Vaccine ...; publish_date:\nL21: none†www.ciheb.org】 Of the 721,509 persons vaccinated, 409,711 are aged 58 \nL22: years and above, 143,050 are healthcare workers, 109,593 are teachers, and \nL23: 59,155 are security officers.\nL24: * 【5†proportion of fully vaccinated adults 23.3pc.; publish_date: \nL25: none†www.facebook.com】 COVID-19: KENYA administers 150632 new vaccinations; \nL26: total jabs given now 13782963; proportion of fully vaccinated adults 23.3pc.\nL27: * 【6†National COVID-19 Vaccines Deployment and ...; publish_date: \nL28: none†www.parliament.go.ke】 Kenya has a well-established immunization program and\nL29: a robust disease surveillance and response unit (DSRU) which serves as an early\nL30: warning system, to.\nL31: * 【7†Kenya: Vaccine Requirements Violate Rights; publish_date: \nL32: none†www.hrw.org】 Dec 13, 2021 — Kenya, with an estimated adult population of \nL33: 27.2 million and a total population of 55 million, has received approximately 23\nL34: million doses as ...\nL35: * 【8†Government gives 20 million a month to get vaccinated; publish_date: \nL36: none†www.bbc.com】 Nov 22, 2021 — Although less than 10% of Kenyans are \nL37: vaccinated, the government wants to avoid a surge over Christmas.\nL38: * 【9†ministry of health kenya covid-19 immunization status ...; publish_date: \nL39: none†www.health.go.ke】 May 15, 2022 — By the end of 2021 Kenya had administered \nL40: 10.1 million doses of 5 types of Covid-19 vaccines. During this period 4.2 \nL41: million people were ...", + "tool_call_id": "fixture-tool-call-27-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Open result 0.", + "tool_calls": [ + { + "id": "fixture-tool-call-29-0", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 13, \"id\": 0}" + } + } + ] + }, + { + "role": "tool", + "content": "[14] www.worldbank.org (https://www.worldbank.org/en/news/press-release/2021/12/14/kenya-s-economy-is-showing-resilience-as-output-rises-above-pre-pandemic-levels-driven-by-a-rebound-in-the-services-sect)\n**viewing lines [0 - 52] of 61**\n\nL0: \nL1: URL: https://www.worldbank.org/en/news/press-release/2021/12/14/kenya-s-economy-\nL2: is-showing-resilience-as-output-rises-above-pre-pandemic-levels-driven-by-a-\nL3: rebound-in-the-services-sect\nL4: NAIROBI, December 14, 2021 — Kenya’s economy has demonstrated resilience to the \nL5: COVID-19 shock, with output in the first half of the year rising above pre-\nL6: pandemic levels. In 2021 as a whole, gross domestic product (GDP) is expected to\nL7: grow by 5%, one of the faster recoveries among Sub-Saharan African countries. \nL8: Overall economic performance is expected to be robust at 4.9% per year in \nL9: 2022-23, similar to the pre-pandemic pace (5% average annual growth from 2010 to\nL10: 2019). According to the 24th edition of the Kenya Economic Update, “From \nL11: Recovery to Better Jobs,” growth has been supported by rebounds in industry and,\nL12: especially, services. Agricultural output, however, fell by 0.5% year on year \nL13: in the first half of 2021 following a particularly strong performance in 2020, \nL14: partly due to below-average rains. Demand-side recovery has been supported by a \nL15: revival in private consumption, against a backdrop of improving employment \nL16: conditions and household incomes. “Kenya’s economy has shown considerable \nL17: resilience to the enormous shock of the pandemic, and this year is expected to \nL18: post one of the stronger growth rebounds in the region thanks to diversified \nL19: sources of growth and sound economic policies and management,” said Keith \nL20: Hansen, World Bank Country Director for Kenya. “However, poverty has increased, \nL21: and the buffers and coping mechanisms of households, firms, and the public \nL22: finances have been depleted.” Economic activity in Kenya has continued to adapt \nL23: to the pandemic and associated restrictions. A mix of containment measures, such\nL24: as a nightly curfew, were in effect through most of 2021, while more \nL25: economically disruptive measures such as lockdowns and travel restrictions were \nL26: phased, limiting the impact on economic activities. The vaccine rollout, which \nL27: had a slow start due to supply constraints, has picked up as new shipments of \nL28: vaccines have arrived, particularly since September. This has supported economic\nL29: recovery and growth through the third quarter of 2021. As of December 5, 2021, \nL30: Kenya had received a total of 16,201,670 vaccines, with 7,583,134 administered. \nL31: While vaccine acceptance is reportedly high there is still a long way to go \nL32: towards the government’s target of fully inoculating the adult population of \nL33: about 30 million by the end of 2022.As of December 6, 2021, about 10% of adults \nL34: (2.9 million people) had been fully vaccinated while another 16% (4.9 million \nL35: people) had received their first dose. The report also projects robust growth in\nL36: the medium-term. This outlook takes into account that some sub-sectors, such as\nL37: education, have bounced back strongly, but others such as international \nL38: tourism, have only partially and face a much more protracted recovery. The \nL39: anticipated further recovery of hotels and restaurants, trade, transport, and \nL40: other services, depends on substantial vaccination progress to help prevent new \nL41: waves of infections and associated containment measures. “Recent economic \nL42: performance has been strong and the outlook is positive, but in Kenya, as \nL43: everywhere, the future course of the pandemic remains a major source of \nL44: uncertainty, said Alex Sienaert, Senior Economist for Kenya. “To support further\nL45: recovery, reduce debt distress risks, and rebuild space for social and \nL46: development spending, it is critical that Kenya continues to implement medium-\nL47: term fiscal consolidation plans.” In addition to pandemic-related risks to the \nL48: outlook, a second key domestic risk factor stems from the drought conditions \nL49: which are affecting parts of the country and already causing severe hardship. \nL50: Should the drought intensify or spread, this would weigh on the near-term \nL51: economic outlook. Weaker global growth, higher-than-anticipated energy prices, \nL52: and tighter external financing conditions are the primary external risks. As ", + "tool_call_id": "fixture-tool-call-29-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Thus first clue confirmed.\n\nSecond clue: Defenders of Wildlife 2021 annual report. Confirmed.\n\nThird clue: Defenders of Wildlife 2022 annual report. Confirmed.\n\nThus the answer is Terry C. Pelster.\n\nBut let's verify that the 2022 annual report indeed shows \"Contributed services 10%\". We saw that. Also confirm that 2021 report shows 12%. We saw that.\n\nThus the question: \"Who was the emeritus board member that year?\" Refers to 2022. Answer: Terry C. Pelster.\n\nHowever, the question might be asking for the emeritus board member in the 2021 report? Let's read the question carefully:\n\n\"In a year before 2023, between December 1 and 15 (inclusive), a press release reported that as of the 5th of that month, a country had received a total of 16,201,670 vaccine doses, with 7,583,134 administered. That same year, a company's annual report revealed that its nearly 2.2 million members and supporters contributed over $49 million, with 12% of total revenue coming from contributed services. The following year, the annual report stated that the organization had played a role in rescinding a previous administration’s plan to open most of the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas drilling and that contributed services had decreased by 2% to 10%. Who was the emeritus board member that year? Name only, no titles.\"\n\nThe phrase \"that year\" could refer to \"the following year\" i.e., the year of the annual report that mentioned the rescinded plan. Or could refer to \"that same year\" (the year of the press release and the first annual report). However, the phrase \"that year\" appears at the end: \"Who was the emeritus board member that year?\" It likely refers to the year of the annual report that is being discussed in the preceding sentence. The preceding sentence talks about the following year's annual report. So indeed \"that year\" is the following year (2022).\n\nThus answer: Terry C. Pelster.\n\nBut to be thorough, we should check the 2022 annual report for any other emeritus board members. The search only showed \"BOARD EMERITUS Terry C. Pelster, Esq.\" Maybe there are multiple emeritus board members. But likely only one.\n\nCheck the 2021 annual report for emeritus board member. It also shows Terry C. Pelster. So maybe the question is asking for the emeritus board member in the 2021 report. But why would they ask that? The question is about connecting the clues to identify the organization and then find the emeritus board member. The clues point to Defenders of Wildlife. The question likely asks for the emeritus board member as of the 2022 report.\n\nNevertheless, we should confirm that the 2022 report is indeed the one that mentions the rescinded plan. We saw that. Also confirm that contributed services decreased from 12% to 10%. We saw that.\n\nThus answer is Terry C. Pelster.\n\nBut let's search for \"Defenders of Wildlife emeritus board member 2022 Terry Pelster\".", + "tool_calls": [ + { + "id": "fixture-tool-call-31-0", + "type": "function", + "function": { + "name": "search", + "arguments": "{\"query\": \"Terry C. Pelster Defenders of Wildlife emeritus board member\", \"topn\": 10, \"source\": \"web\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[15] Terry C. Pelster Defenders of Wildlife emeritus board member (Search_Results/Terry C. Pelster Defenders of Wildlife emeritus board member)\n**viewing lines [0 - 31] of 31**\n\nL0: \nL1: URL: Search_Results/Terry C. Pelster Defenders of Wildlife emeritus board member\nL2: # Search Results\nL3: \nL4: * 【0†2024 impact report; publish_date: none†defenders.org】 Terry Root, Ph.D. \nL5: New Mexico. Thomas D. Sisk, Ph.D. British Columbia, Canada. BOARD EMERITUS. \nL6: Terry C. Pelster, Esq. New York. SCIENCE ADVISORS. Leah Gerber, Ph ...\nL7: * 【1†Defenders of Wildlife; publish_date: none†www.wrongkindofgreen.org】 \nL8: Directors ; Victor M. Sher*, Chair – California ; Terry C. Pelster*, Vice Chair \nL9: – New York ; Richard Kopcho*, Treasurer – California ; Adelaide P. Gomer*, \nL10: Secretary ...\nL11: * 【2†Exhibit 12; publish_date: none†www.uschamber.com】 DECLARATION OF TERRY \nL12: PELSTER. I, Terry C. Pelster, declare as follows: 1. I am a current member of \nL13: Defenders of Wildlife (“Defenders”) and have been a member.\nL14: * 【3†2020 ANNUAL REPORT; publish_date: none†defenders.org】 Terry L. Root. \nL15: Florida. BOARD. EMERITUS. Terry C. Pelster. New York. Alan Steinberg. Florida. \nL16: SCIENCE. ADVISORS. Gabriela Chavarria,. Ph.D. Denver Museum of.\nL17: * 【4†2019 Annual Report; publish_date: none†defenders.org】 BOARD EMERITUS. \nL18: Terry C. Pelster. New York. Alan Steinberg. Florida. SCIENCE ADVISORS. Leah \nL19: Gerber, Ph.D. Arizona State University. Kathayoon Khalil, Ph.D.\nL20: * 【5†annual report; publish_date: none†defenders.org】 * Executive Committee \nL21: member. BoARD EMERITuS. Terry C. Pelster. New York. Alan Steinberg. Florida. \nL22: PRESIDENT. EMERITuS. Rodger Schlickeisen. Virginia. SCIENCE ...\nL23: * 【6†BEING THERE FOR BELUGAS; publish_date: none†wendeenicole.com】 Mar 2, 2023\nL24: — Defenders of Wildlife is a national, nonprofit member- ship ... Terry C. \nL25: Pelster, Esq. New York. *Executive Committee member outfront.\nL26: * 【7†Annual Reports; publish_date: none†www.annualreports.com】 WAYS TO GIVE \nL27: There are many ways to support Defenders of Wildlife and protect our imperiled \nL28: wildlife. ... Terry C. Pelster. New York. Alan Steinberg. Florida.\nL29: * 【8†2014 Annual Report; publish_date: none†defenders.org】 BOARD OF DIRECTORS \nL30: defenders of wildlife is a national, nonprofit ... BOARD EMERITUS. Terry C. \nL31: Pelster. New York. Alan Steinberg. Florida. Page 3. PO.", + "tool_call_id": "fixture-tool-call-31-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Thus Terry C. Pelster is indeed the emeritus board member.\n\nNow the question might want just \"Terry C. Pelster\". Or maybe \"Terry Pelster\". Or \"Terry C. Pelster, Esq.\" but instructions say \"Name only, no titles.\" So likely \"Terry C. Pelster\". However, maybe they want \"Terry Pelster\". Better to include full name as listed: \"Terry C. Pelster\".\n\nLet's read the question again: \"Who was the emeritus board member that year? Name only, no titles.\" So we should answer just the name, no \"Esq.\" or \"Board Emeritus\". So \"Terry C. Pelster\". Or maybe \"Terry Pelster\". But the document shows \"Terry C. Pelster, Esq.\" So the name is \"Terry C. Pelster\". \n\nBetter to verify that the 2022 annual report indeed lists \"BOARD EMERITUS Terry C. Pelster, Esq.\" We saw that line.\n\nThus answer: Terry C. Pelster.\n\nBut let's double-check that the question might refer to the 2021 annual report. The phrase \"that year\" could be ambiguous. Let's parse the timeline:\n\n- Year 1: 2021. Press release in Dec 2021 about Kenya vaccine doses. Same year (2021) annual report of Defenders of Wildlife shows 12% contributed services.\n\n- Year 2: 2022. Annual report mentions rescinded plan and contributed services decreased to 10%.\n\nThe question says \"The following year, the annual report stated that the organization had played a role in rescinding a previous administration’s plan ... and that contributed services had decreased by 2% to 10%.\" This is year 2.\n\nThen asks \"Who was the emeritus board member that year?\" That likely refers to year 2.\n\nThus answer: Terry C. Pelster.\n\nHowever, perhaps the emeritus board member changed between 2021 and 2022. We need to check 2022 annual report for any other emeritus board member. We saw only Terry C. Pelster. Also check 2021 annual report: same. So likely unchanged.\n\nThus answer is Terry C. Pelster.\n\nNevertheless, we should confirm that the 2022 annual report is indeed the one that includes the phrase \"BLM rescinded the previous administration’s plan that opened most of the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas drilling\". We saw that in line 200-202.\n\nThus answer is Terry C. Pelster.\n\nNow we need to provide the answer in the final message. Only the name.\n\nThus final answer: Terry C. Pelster.\n\nBut let's be thorough and search for \"Defenders of Wildlife 2022 annual report contributed services 10%\".", + "tool_calls": [ + { + "id": "fixture-tool-call-33-0", + "type": "function", + "function": { + "name": "find", + "arguments": "{\"cursor\": 8, \"pattern\": \"Contributed services\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[16] Find results for text: `contributed services` in `www.annualreports.com` (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2022.pdf/find?pattern=contributed services)\n**viewing lines [0 - 10] of 10**\n\nL0: # 【0†match at L222】\nL1: income $744 Contributed services $5,047 Total Revenue $41,170 USES OF FUNDS \nL2: Biodiversity conservation $18,970 Constituency mobilization $16,528 Fundraising \nL3: $1,029 Management and general $6,354 Total Expenses $42,881 Change in net assets\nL4: (-$1,711) Net assets, start of year $45,144 Net Assets, End of the Year $43,433\nL5: \nL6: # 【1†match at L227】\nL7: Contributed services 10% Royalties and other income 2% Bequests, trusts and \nL8: split interests 29% TOTAL REVENUE Income from investments -15% Program and \nL9: support services 98% USE OF FUNDS Fundraising 2% ===== Page 13 ===== \nL10: DEFENDERS.ORG 13 ===== Page 14 ===== BOARD OF DIRECTORS John Dayton, Esq.* Chair", + "tool_call_id": "fixture-tool-call-33-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Thus confirmed.\n\nNow we should verify that the 2021 annual report shows contributed services 12%.", + "tool_calls": [ + { + "id": "fixture-tool-call-35-0", + "type": "function", + "function": { + "name": "find", + "arguments": "{\"cursor\": 4, \"pattern\": \"Contributed services\"}" + } + } + ] + }, + { + "role": "tool", + "content": "[17] Find results for text: `contributed services` in `www.annualreports.com` (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2021.pdf/find?pattern=contributed services)\n**viewing lines [0 - 4] of 4**\n\nL0: # 【0†match at L194】\nL1: $3,354 Royalties and other income $3,576 Contributed services $6,140 Total \nL2: Revenue $49,819 USES OF FUNDS Biodiversity conservation $22,420 Constituency \nL3: mobilization $16,324 Fundraising $1,211 Management and general $5,865 Total \nL4: Expenses $45,820 Change in net assets $3,999 Net assets, start of year $41,145 ", + "tool_call_id": "fixture-tool-call-35-0" + }, + { + "role": "assistant", + "content": "", + "reasoning_content": "Open around L190-200 to see percentage.", + "tool_calls": [ + { + "id": "fixture-tool-call-37-0", + "type": "function", + "function": { + "name": "open", + "arguments": "{\"cursor\": 4, \"loc\": 185}" + } + } + ] + }, + { + "role": "tool", + "content": "[18] www.annualreports.com (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2021.pdf)\n**viewing lines [185 - 232] of 261**\n\nL185: devastating water pollution that is killing the seagrass and causing manatees to\nL186: starve. DAVID TES | SAM FRENZY DRAW DEFENDERS.ORG 11 ===== Page 12 ===== In \nL187: 2021, Defenders of Wildlife’s nearly 2.2 million members and supporters \nL188: contributed more than $49 million for wildlife and wildlife habitat. This figure\nL189: includes all those who generously donated their time and expertise. The \nL190: steadfast support of our donors allows Defenders to sustain our program and \nL191: public education efforts in the field, the courts and on Capitol Hill. 2021 \nL192: SOURCES OF FUNDS Grants and contributions $29,057 Bequests, trusts and split \nL193: interests $7,692 Income from investments, annuity reserve funds and trusts \nL194: $3,354 Royalties and other income $3,576 Contributed services $6,140 Total \nL195: Revenue $49,819 USES OF FUNDS Biodiversity conservation $22,420 Constituency \nL196: mobilization $16,324 Fundraising $1,211 Management and general $5,865 Total \nL197: Expenses $45,820 Change in net assets $3,999 Net assets, start of year $41,145 \nL198: Net Assets, End of the Year $45,144 Dollars are in thousands. 12 DEFENDERS OF \nL199: WILDLIFE Grants and contributions 58% Income from investments 7% Requests, \nL200: trusts and split interests 15% Royalties and other income 7% Contributed \nL201: services 12% Program and support services 97% Fundraising 3% ===== Page 13 =====\nL202: DEFENDERS.ORG 13 ===== Page 14 ===== John Dayton, Esq.* Chair Texas Sheryl \nL203: Sachs* Vice Chair Connecticut Whitfield Bailey* Secretary Tennessee Mark Caylor*\nL204: Treasurer District of Columbia Sylvia Baca* California Donald Barry, Esq. \nL205: Oregon Dinah Bear, Esq. Arizona Kim O’Keefe Beck, Ph.D California Gabriela \nL206: Chavarria, Ph.D Washington Jamie Rappaport Clark* President & CEO Virginia \nL207: Caroline D. Gabel Maryland Keith M. Harper, Esq. Maryland John Donovan, Esq. \nL208: Montana and California James M. Hecker, Esq. Virginia Judith Ponsikoff, Ph.D* \nL209: California Terry L. Root, Ph.D Florida BOARD EMERITUS Terry C. Pelster, Esq. New\nL210: York *Executive Committee member* ST A F F & A D V I S O R S SCIENCE ADVISORS \nL211: Gabriela Chavarria, Ph.D Burke Museum, University of Washington Leah Gerber, \nL212: Ph.D Arizona State University Kathayoon Khalil, Ph.D Oregon Zoo Rurik List, Ph.D\nL213: Universidad Autónoma Metropolitana-Lerma Barry Noon, Ph.D Colorado State \nL214: University Terry L. Root, Ph.D Stanford University Mark Schwartz, Ph.D \nL215: University of California-Davis David S. Wilcox, Ph.D Princeton University Rae \nL216: Wynn-Grant, Ph.D National Geographic Society SENIOR STAFF Jamie Rappaport Clark \nL217: President & CEO Thu Pham Chief of Staff Renee Stone Senior Vice President, \nL218: Conservation Programs & General Counsel James Stefan Senior Vice President, \nL219: Operations Matt Braughler Vice President, Philanthropy Rachel Brittin Vice \nL220: President, Communications Shawn Cantrell Vice President, Field Conservation \nL221: Programs Robert L. Dewey Vice President, Government Relations & External Affairs\nL222: Monica Goldberg Vice President, Landscape Conservation Sharon Morris-Smith Vice\nL223: President of Human Resources Elizabeth Nichols Vice President, Finance/CFO \nL224: Anthony Pegues Vice President, Diversity, Equity, Inclusion, Justice Lindsay \nL225: Rosa, Ph.D Director, Center for Conservation Innovation Michael P. Senatore Vice\nL226: President, Conservation Law Brandace Swimley Vice President, Integrated \nL227: Marketing 14 DEFENDERS OF WILDLIFE ===== Page 15 ===== THERE ARE MANY WAYS TO \nL228: SUPPORT DEFENDERS OF WILDLIFE and protect our imperiled wildlife. To learn more \nL229: or donate, contact us at 1-800-385-9712, visit defenders.org/donate or choose \nL230: one of the following options. ➤ Transfer appreciated assets directly to \nL231: Defenders and maximize potential tax benefits. Defenders.org/donate-stocks-help-\nL232: save-wildlife ➤ Become a Leadership Council member with a single donation of ", + "tool_call_id": "fixture-tool-call-37-0" + }, + { + "role": "assistant", + "content": "Terry C. Pelster", + "reasoning_content": "Thus everything matches.\n\nThus answer: Terry C. Pelster.\n\nNow we need to output only the name.\n\nThus final answer: Terry C. Pelster.\n\nBut perhaps the question expects \"Terry Pelster\". However, the document includes middle initial \"C.\" It might be better to include the full name as listed: \"Terry C. Pelster\". The instruction \"Name only, no titles\" suggests we should not include \"Esq.\" or \"Board Emeritus\". But including middle initial is likely okay.\n\nTo be safe, we could output \"Terry C. Pelster\". Let's see similar puzzles: often they expect the full name as appears in the document.\n\nThus answer: Terry C. Pelster.\n\nNow we will provide final answer.", + "tool_calls": [] + } + ] +} diff --git a/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_output_search_w_date.txt b/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_output_search_w_date.txt new file mode 100644 index 00000000000..6349a4bb208 --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_output_search_w_date.txt @@ -0,0 +1,2455 @@ +<|begin▁of▁sentence|><|User|> + +## Tools + +You have access to a set of tools you can use to answer the user's question. +You can invoke functions by writing a "<|DSML|function_calls>" block like the following as part of your reply to the user: +<|DSML|function_calls> +<|DSML|invoke name="$FUNCTION_NAME"> +<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<|DSML|invoke name="$FUNCTION_NAME2"> +... + + + +String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects). + +If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example: + +<|DSML|function_calls> +... + + + +... + + +...thinking about results + +Here are the functions available in JSONSchema format: + +{"name": "search", "description": "Searches for information related to query and displays topn results.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The search query string"}, "topn": {"type": "integer", "description": "Number of top results to display", "default": 10}, "source": {"type": "string", "description": "Source to search within", "enum": ["web", "news"], "default": "web"}}, "required": ["query"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}} +{"name": "open", "description": "Opens the link id from the page indicated by cursor starting at line number loc, showing num_lines lines. Valid link ids are displayed with the formatting: 【{id}†.*】. If cursor is not provided, the most recent page is implied. If id is a string, it is treated as a fully qualified URL associated with source. If loc is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available. Use this function without id to scroll to a new location of an opened page.", "parameters": {"type": "object", "properties": {"id": {"description": "Link ID (number) or fully qualified URL (string)", "anyOf": [{"type": "integer"}, {"type": "string"}], "default": -1}, "cursor": {"type": "integer", "description": "Page cursor indicator", "default": -1}, "loc": {"type": "integer", "description": "Starting line number", "default": -1}, "num_lines": {"type": "integer", "description": "Number of lines to show", "default": -1}, "view_source": {"type": "boolean", "description": "Whether to view source", "default": false}, "source": {"type": "string", "description": "Source associated with the URL", "enum": ["web", "news"], "default": "web"}}, "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}} +{"name": "find", "description": "Finds exact matches of pattern in the current page, or the page given by cursor.", "parameters": {"type": "object", "properties": {"pattern": {"type": "string", "description": "The exact pattern to search for"}, "cursor": {"type": "integer", "description": "Page cursor to search in", "default": -1}}, "required": ["pattern"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}} + + + +# The user's message is: 帮我调研一下,目前有哪些针对search agent的benchmark?详细介绍各自的特点、使用场景、例题。 + + +## Today’s Date +2025-11-27, Thursday. +<|Assistant|>用户想调研目前针对搜索智能体(search agent)的基准测试(benchmark),要求详细介绍各自的特点、使用场景和例题。我需要提供一份全面的概述,涵盖现有的基准测试,例如WebGPT、WebShop、ALFWorld、HotpotQA、FEVER、TriviaQA、Natural Questions、SearchQA、MS MARCO、BEIR、LoTTE、MIRACL、Zero-shot IR基准等。同时,需要区分这些基准测试是针对检索任务、问答任务,还是具体针对搜索智能体的评估。此外,还应考虑那些专门为评估搜索智能体(即能够与环境交互、执行搜索、浏览网页、使用工具的智能体)而设计的基准测试,例如WebGPT、WebShop、ALFWorld、Mind2Web、WebArena、VisualWebArena、SearchBench、WebVLN、WebNav等。我需要详细介绍每个基准测试的特点、使用场景,并举例说明。为了确保信息的全面性和时效性,我将同时进行两项搜索:一项是针对“search agent benchmark”的通用搜索,以获取广泛的基准测试信息;另一项是针对“搜索智能体 基准测试”的中文搜索,以获取更符合用户语言习惯和可能更具体的中文资料。 + +<|DSML|function_calls> +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">search agent benchmark 2024 +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">搜索智能体 基准测试 +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|end▁of▁sentence|> + + +[0] search agent benchmark 2024 (Search_Results/search agent benchmark 2024) +**viewing lines [0 - 39] of 39** + +L0: +L1: URL: Search_Results/search agent benchmark 2024 +L2: # Search Results +L3: +L4: * 【0†HSCodeComp: A Realistic and Expert-level Benchmark for ...; publish_date: +L5: none†arxiv.org】 Oct 22, 2025 — To fill this gap, we introduce HSCodeComp, the +L6: first realistic, expert-level e-commerce benchmark designed to evaluate deep +L7: search agents in ... +L8: * 【1†open-compass/GTA - A Benchmark for General Tool Agents; publish_date: +L9: none†github.com】 GTA is a benchmark to evaluate the tool-use capability of LLM- +L10: based agents in real-world scenarios. It features three main aspects. +L11: * 【2†Benchmarking real-time trust scoring across five AI Agent ...; +L12: publish_date: none†cleanlab.ai】 Aug 20, 2025 — This article evaluates 5 AI Agent +L13: architectures over the BOLAA (ICLR 2024) benchmark, and assesses the effects of +L14: adding automated trust ... +L15: * 【3†10 AI agent benchmarks; publish_date: none†www.evidentlyai.com】 Jul 11, +L16: 2025 — We put together 10 AI agent benchmarks designed to assess how well +L17: different LLMs perform as agents in real-world scenarios, ... +L18: * 【4†A state-of-the-art search API purpose-built for agents; publish_date: +L19: none†parallel.ai】 Jul 31, 2025 — To evaluate real-world performance of the +L20: Parallel Search MCP Server, we created the WISER-Search benchmark which blends +L21: WISER-Fresh (queries ... +L22: * 【5†AI Agent Benchmarks are Broken; publish_date: none†medium.com】 We break +L23: down the failure modes in current AI agent benchmarks and introduce a checklist +L24: that minimizes the gamability of AI agent benchmarks. +L25: * 【6†Benchmarks and Tree Search for Multimodal LLM Web Agents; publish_date: +L26: none†dpfried.github.io】 2024, When is Tree Search Useful? ○ Dealing with +L27: destructive actions. ○ Some things on the web are very difficult to undo, e.g., +L28: ordering an item. 56. +L29: * 【7†-Bench: Benchmarking AI agents for the real-world; publish_date: +L30: none†sierra.ai】 Jun 20, 2024 — τ-bench measures an agent's ability to interact +L31: with (simulated) human users and programmatic APIs while following domain- +L32: specific policies in a consistent ... +L33: * 【8†Browser Use = state of the art Web Agent; publish_date: none†browser- +L34: use.com】 Dec 15, 2024 — Browser Use has achieved state-of-the-art performance on +L35: the WebVoyager benchmark, with an impressive 89.1% success rate across 586 +L36: diverse web tasks. +L37: * 【9†FutureSearch Benchmarks; publish_date: none†evals.futuresearch.ai】 Find +L38: the original source of a given claim. Example: From , more than 8 out of 1000 +L39: users clicked on a phishing link monthly in 2024, up 190% vs 2023. +[1] 搜索智能体 基准测试 (Search_Results/搜索智能体 基准测试) +**viewing lines [0 - 33] of 33** + +L0: +L1: URL: Search_Results/搜索智能体 基准测试 +L2: # Search Results +L3: +L4: * 【0†WideSearch:揭示AI 智能体缺失的「广度」能力; publish_date: none†zhuanlan.zhihu.com】 Aug +L5: 16, 2025 — 为系统评估智能体在该任务上的能力,论文构建了第一个专门的基准测试 WideSearch ,包含200 个源于真实世界、横跨18 +L6: 个领域的高质量任务。 通过对超过10 个 ... +L7: * 【1†GAIA: 一个严苛的智能体基准- HuggingFace; publish_date: none†www.cnblogs.com】 Jul 9, +L8: 2024 — 我们使用一个用库构建的代码智能体 在GAIA 基准上进行测试,这可以说是最困难、最全面的智能体基准测试……最终我们取得了第一名的成绩! +L9: GAIA: 一个严苛的 ... +L10: * 【2†AI搜索智能体遭遇新挑战:滑铁卢大学团队提出更公平透明的 ...; publish_date: none†www.techwalker.com】 +L11: Aug 14, 2025 — +L12: 目前评测AI搜索智能体主要依靠BrowseComp这样的基准测试,它就像一场实时的开卷考试,让AI在真实的网络环境中搜索信息来回答复杂问题。听起来很合理 ... +L13: * 【3†Agentic AI基础设施实践经验系列(六):Agent质量评估 - AWS; publish_date: +L14: none†aws.amazon.com】 Sep 19, 2025 — TAU-bench +L15: 是一个评估AI智能体在真实世界环境中可靠性的基准测试。它评估智能体是否能够在动态的多轮对话中与用户进行交互,理解需求并完成任务。T-bench ... +L16: * 【4†DeepAgent:能自己找工具的通用推理智能体 - 高瓴人工智能学院; publish_date: none†ai.ruc.edu.cn】 +L17: Nov 6, 2025 — 在八大基准测试中,DeepAgent在绝大多数任务上全面领先所有基线模型。 +L18: 开放环境优势:在最具挑战的“开放工具检索”场景下(如ToolBench),其成功率达到64%,远 ... +L19: * 【5†BrowseComp:为浏览智能体设计的简单而具挑战性的基准测试; publish_date: none†blog.csdn.net】 Sep +L20: 22, 2025 — 该基准测试由OpenAI团队开发,旨在推动更可信赖和可靠的AI代理研究。 核心特点. +L21: 挑战性问题设计:BrowseComp的问题设计遵循严格的难度标准:. 人类创建者确保 ... +L22: * 【6†什么是GAIA?; publish_date: none†huggingface.co】 什么是GAIA? GAIA +L23: 是一个用于评估AI助手在需要核心能力组合的真实世界任务上的表现的基准,这些核心能力包括推理、多模态理解、网页浏览和熟练的工具使用。 +L24: * 【7†SuperCLUE-Agent: Agent智能体中文原生任务能力测评基准; publish_date: +L25: none†www.cluebenchmarks.com】 AI +L26: Agent是一个能够自主理解、规划决策和执行复杂任务的智能体,有代表性的Agent比如AutoGPT、AgentGPT等。 +L27: 现有关于Agent能力的测评,主要是在英文场景或任务的测评。目前还 ... +L28: * 【8†Gaia2 与ARE:赋能社区的智能体评测; publish_date: none†hub.baai.ac.cn】 Sep 24, 2025 — +L29: 快速评估任意智能体:基于真实或模拟数据,测试不同规则、工具、内容和验证方式下的表现 · 测试智能体的 工具调用与编排能力:可结合本地应用或MCP 工具 · 生成 +L30: ... +L31: * 【9†许多AI 智能体评测基准并不可靠; publish_date: none†my.oschina.net】 Aug 13, 2025 — +L32: 当前AI智能体基准测试存在严重缺陷,亟需更严谨的评估框架。作者提出AI智能体基准测试核查清单(ABC),包含43个项目,旨在识别基准测试潜在陷阱,构建可靠评估 +L33: ... + + +第一个搜索结果(针对“search agent benchmark 2024”)显示了一些相关条目,例如“HSCodeComp: A Realistic and Expert-level Benchmark for ...”、“open-compass/GTA - A Benchmark for General Tool Agents”、“10 AI agent benchmarks”、“WISER-Search benchmark”、“Benchmarks and Tree Search for Multimodal LLM Web Agents”、“τ-bench”、“WebVoyager benchmark”、“FutureSearch Benchmarks”。这些可能都与搜索智能体相关。第二个搜索结果(针对“搜索智能体 基准测试”)显示了一些中文资料,例如“WideSearch”、“BrowseComp”、“GAIA”、“SuperCLUE-Agent”等。这些都可能与搜索智能体相关。为了全面了解这些基准测试,我们需要同时打开并分析其中几个最有潜力的结果。具体来说,我们将打开第一个搜索结果中的“HSCodeComp”(索引0)和“open-compass/GTA”(索引1),以及第二个搜索结果中的“WideSearch”(索引0)和“BrowseComp”(索引5)。 + +<|DSML|function_calls> +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">0 +<|DSML|parameter name="id" string="false">0 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">0 +<|DSML|parameter name="id" string="false">1 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">1 +<|DSML|parameter name="id" string="false">0 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">1 +<|DSML|parameter name="id" string="false">5 + +<|end▁of▁sentence|> + + +[2] (https://ar5iv.org/html/2510.19631v1) +**viewing lines [0 - 119] of 983** + +L0: +L1: URL: https://ar5iv.org/html/2510.19631v1 +L2: # HSCodeComp: A Realistic and Expert-level Benchmark for Deep Search Agents in +L3: Hierarchical Rule Application +L4: +L5: Yiqian Yang† Tian Lan† Qianghuai Jia∗ Li Zhu Hui Jiang Hang Zhu Longyue Wang +L6: Weihua Luo Kaifu Zhang +L7: +L8: Alibaba International Digital Commerce∗* Corresponding Author: Qianghuai Jia +L9: (qianghuai.jqh@alibaba-inc.com) +L10: †\dagger Equal Contribution: Yiqian Yang +L11: +L12: Tian Lan +L13: +L14: ###### Abstract +L15: +L16: Abstract +L17: +L18: Effective deep search agents must not only access open-domain and domain- +L19: specific knowledge but also apply complex rules—such as legal clauses, medical +L20: manuals and tariff rules. These rules often feature vague boundaries and +L21: implicit logic relationships, making precise application challenging for agents. +L22: However, this critical capability is largely overlooked by current agent +L23: benchmarks. To fill this gap, we introduce HSCodeComp, the first realistic, +L24: expert-level e-commerce benchmark designed to evaluate deep search agents in +L25: hierarchical rule application. In this task, the deep reasoning process of +L26: agents is guided by these rules to predict 10-digit Harmonized System Code +L27: (HSCode) of products with noisy but realistic descriptions. These codes, +L28: established by the World Customs Organization, are vital for global supply chain +L29: efficiency. Built from real-world data collected from large-scale e-commerce +L30: platforms, our proposed HSCodeComp comprises 632 product entries spanning +L31: diverse product categories, with these HSCodes annotated by several human +L32: experts. Extensive experimental results on several state-of-the-art LLMs, open- +L33: source, and closed-source agents reveal a huge performance gap: best agent +L34: achieves only 46.8% 10-digit accuracy, far below human experts at 95.0%. +L35: Besides, detailed analysis demonstrates the challenges of hierarchical rule +L36: application, and test-time scaling fails to improve performance further. +L37: +L38: ## 1 Introduction +L39: +L40: Deep search agents have demonstrated significant value in solving complex real- +L41: world problems, where robust external knowledge utilization constitutes a +L42: critical capability [Wu et al., 2025, Tao et al., 2025, Li et al., 2025b]. To +L43: evaluate this capability, numerous established benchmarks are proposed to assess +L44: agents in utilizing open-domain data (e.g., GAIA [Mialon et al., 2023b] and +L45: BrowseComp [Wei et al., 2025]) and domain-specific data (e.g., WebMall [Peeters +L46: et al., 2025a], FinSearchComp [Hu et al., 2025a] and MedBrowseComp [Yu et al., +L47: 2025b]). +L48: +L49: Beyond open-domain and domain-specific data, agents also need to effectively +L50: apply rules that encode human expert knowledge, particularly in scenarios like +L51: law, medical and e-commerce [Li et al., 2025a, Chen et al., 2025b, Yao et al., +L52: 2022, Chollet et al., 2025]. For instance, legal case adjudication require +L53: interpreting abstract legal provisions, and accurate e-commerce product +L54: classification in depends on tariff rules [Grainger, 2024]. Previous works have +L55: defined rule application as using specific logical rules with supporting facts +L56: to derive conclusions [Wang et al., 2024, Servantez et al., 2024]. In contrast, +L57: we define it as a core capability for deep search agents, where human-written +L58: rules are systematically applied to guide complex reasoning and decision-making +L59: [Sadowski and Chudziak, 2025]. Building on this observation, we categorize +L60: knowledge data for deep search agents into three levels (Figure 1, left), with +L61: increasing knowledge complexity: (1) Level 1: Open-domain Data - Tests +L62: understanding and deep reasoning abilities of agents on long-form web content. +L63: Established benchmarks include GAIA [Mialon et al., 2023b] and BrowseComp [Wei +L64: et al., 2025]; (2) Level 2: Structured Data - Assesses agents to precisely +L65: utilize structured data such as databases and knowledge graphs, as seen in +L66: domain-specific benchmarks like WebMall [Peeters et al., 2025a], MedBrowseComp +L67: [Chen et al., 2025b] and FinSearchComp [Hu et al., 2025a]; (3) Level 3: Rule +L68: Data - Evaluates agents to apply complex and abstract rules [Chollet et al., +L69: 2025]. This level presents two key challenges: (a) making accurate decisions +L70: when rules contain vague natural language descriptions [Sadowski and Chudziak, +L71: 2025]; and (b) reasoning about logical dependencies among rules, such as +L72: exception clauses and cross-category relationships [Guha et al., 2023]. Despite +L73: the importance of rule application in real-world scenarios, current agent +L74: benchmarks largely overlook its evaluation. +L75: +L76: To fill this gap, we introduce HSCodeComp (short for the Harmonized System Code +L77: (HSCode) Competition), the first realistic, expert-level e-commerce benchmark +L78: designed to evaluate agents in predicting complete 10-digit Harmonized System +L79: Code (HSCode) of the product, using hierarchical rules (e.g., eWTP tariff +L80: rules111https://www.ewtp.com/web/smart/hscode). HSCodes organize products +L81: through a hierarchical structure spanning over 5,000 distinct codes across +L82: multiple classification levels, representing the global standard for classifying +L83: traded international goods, established by the World Customs Organization and +L84: implemented across more than 200 countries for customs clearance and tariff +L85: determination [Grainger, 2024, Nath et al., 2025]. Built from the data of the +L86: large-scale e-commerce platforms, our proposed HSCodeComp comprises 632 +L87: carefully curated product entries, encompassing 27 unique HS chapters and 32 +L88: distinct first-level categories. These HSCodes have been rigorously annotated by +L89: multiple e-commerce domain experts, ensuring that HSCodeComp is expert-level. +L90: Accurately predicting the exact 10-digit HSCode presents significant challenges: +L91: agents must perform multi-hop hierarchical reasoning with complex tariff rules +L92: while processing noisy but realistic product descriptions that often contain +L93: abbreviations, language variations, or incomplete information. +L94: +L95: Extensive experiments on the state-of-the-art baselines, including 14 advanced +L96: foundation models, 6 advanced open-source agent systems and 3 closed-source +L97: agent systems, demonstrate that HSCode prediction task remains a substantial +L98: challenge for current AI approaches. As shown in the Figure 1 (right), even the +L99: best-performing system (SmolAgent [Roucher et al., 2025] with GPT-5) achieves +L100: only 46.8% accuracy, substantially below the 95.0% accuracy attained by human +L101: experts. Further detailed analysis reveals that existing agent systems lack +L102: critical capabilities required for this complex hierarchical rule applications. +L103: Notably, test-time scaling approach—which has proven effective in other +L104: reasoning tasks [Guo et al., 2025, Liu et al., 2025]—fail to improve performance +L105: on HSCodeComp. These observations demonstrate the challenging nature of our +L106: proposed HSCodeComp, highlighting the need for more effective designs of agent +L107: systems. To facilitate future research, we will publicly release codes and the +L108: benchmark dataset of HSCodeComp. +L109: +L110: ## 2 Related Works +L111: +L112: ### 2.1 Previous Works in HSCode Prediction +L113: +L114: Previous works treat HSCode prediction as the e-commerce text classification +L115: task [Grainger, 2024], using pre-trained BERT models [Liao et al., 2024, Shubham +L116: et al., 2022] or Large Language Models (LLMs) prompting [Hussain and Ahmed, +L117: 2023]. However, these approaches fail to leverage domain-specific knowledge, +L118: especially the rules written by human experts [Hussain and Ahmed, 2023, Judy, +L119: 2024]. Besides, existing HSCode benchmarks face two critical limitations [Judy, +Error fetching URL `https://github.com/open-compass/GTA` +Error fetching URL `https://zhuanlan.zhihu.com/p/1939316761775301093` +[3] (https://blog.csdn.net/qq_41472205/article/details/151974603) +**viewing lines [0 - 167] of 187** + +L0: +L1: URL: https://blog.csdn.net/qq_41472205/article/details/151974603 +L2: # BrowseComp:为浏览智能体设计的简单而具挑战性的基准测试 +L3: +L4: BrowseComp:AI浏览能力评估基准 +L5: +L6: 最新推荐文章于 2025-11-12 13:40:20 发布 +L7: +L8: 原创 于 2025-09-22 22:33:04 发布 · 1.3k 阅读 +L9: +L10: · 9 +L11: · 25 · +L12: CC 4.0 BY-SA版权 +L13: +L14: 版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。 +L15: +L16: ## BrowseComp:为浏览智能体设计的简单而具挑战性的基准测试 +L17: +L18: 在人工智能从基础聊天机器人向推理器和智能体发展的进程中,具备浏览互联网能力的人工智能模型正变得越来越重要。今天,我们将介绍一个名为BrowseComp的创新基准 +L19: 测试,它专门设计用于评估AI代理在复杂网络浏览任务中的能力。 +L20: +L21: ### 什么是BrowseComp? +L22: +L23: BrowseComp(全称Browsing Competition)是一个包含1,266个挑战性问题的基准测试集,专门用于衡量AI代理在互联网上持续导航、寻找难 +L24: 以找到的纠缠信息的能力。该基准测试由OpenAI团队开发,旨在推动更可信赖和可靠的AI代理研究。 +L25: +L26: #### 核心特点 +L27: +L28: 挑战性问题设计:BrowseComp的问题设计遵循严格的难度标准: +L29: +L30: - 人类创建者确保问题在10分钟内无法被人解决 +L31: - 现有模型(包括带浏览功能的ChatGPT和早期版本的OpenAI Deep Research)无法解决 +L32: - 通过5次简单Google搜索无法在结果首页找到答案 +L33: +L34: 简单易验证:尽管问题极具挑战性,但答案形式简单——都是短字符串,便于自动验证模型输出的正确性。 +L35: +L36: ### 为什么需要BrowseComp? +L37: +L38: #### 现有基准的局限性 +L39: +L40: 传统的信息检索基准(如TriviaQA、HotpotQA等)主要关注易于查找的信息,随着语言模型的进步,这些基准已经趋于饱和。而BrowseComp专注于那些需 +L41: 要浏览大量网站才能解决的"硬核"问题。 +L42: +L43: #### 模拟真实挑战 +L44: +L45: BrowseComp问题通常采用"逆向设计"方法:创建者从一个已知事实出发,构建一个搜索空间巨大但验证简单的问题。例如: +L46: +L47: “找出2018-2023年间在EMNLP会议上发表、第一作者本科毕业于达特茅斯学院、第四作者本科毕业于宾夕法尼亚大学的科学论文标题” +L48: +L49: 这类问题验证简单,但解决起来需要检查数千篇论文并调查每位作者的背景。 +L50: +L51: ### 数据集特点 +L52: +L53: #### 主题多样性 +L54: +L55: BrowseComp涵盖了广泛的主题领域(如图2所示),包括历史、科学、文化等。创建者被鼓励基于个人兴趣设计问题,这有助于提高数据质量和参与度。 +L56: +L57: #### 质量保证 +L58: +L59: 为确保答案的唯一性,创建者需要: +L60: +L61: - 对问题内容有足够了解,确信没有其他有效答案 +L62: - 如果不确定,则添加更多约束条件 +L63: - 接受其他创建者的验证反馈 +L64: +L65: ### 人类表现基准 +L66: +L67: 为了衡量BrowseComp的难度,研究人员让人类创建者尝试解决问题(不能解答自己创建的问题)。结果显示: +L68: +L69: - **70.8%**的问题在2小时搜索后人类选择放弃 +L70: - **29.2%**的问题被成功解决 +L71: - 在解决的问题中,**86.4%**的人类答案与参考答案一致 +L72: +L73: 这表明BrowseComp确实极具挑战性,即使是熟悉数据集的人类专家也难以在有限时间内解决大部分问题。 +L74: +L75: ### AI模型表现评估 +L76: +L77: #### 各模型对比 +L78: +L79: 研究人员评估了多种模型在BrowseComp上的表现: +L80: +L81: 模型 | 准确率(%) | 校准误差(%) +L82: ---|---|--- +L83: GPT-4o | 0.6 | 69 +L84: GPT-4o(带浏览) | 1.9 | 82 +L85: GPT-4.5 | 0.9 | 68 +L86: OpenAI o1 | 9.9 | 65 +L87: Deep Research | 51.5 | 91 +L88: +L89: #### 关键发现 +L90: +L91: - 基础模型表现不佳:GPT-4o和GPT-4.5准确率接近零,凸显了基准的难度 +L92: - 浏览功能带来有限提升:启用浏览功能的GPT-4o准确率略有提高,但仍很低 +L93: - 推理能力的重要性:OpenAI o1虽然没有浏览能力,但凭借更强的推理能力获得较高准确率 +L94: - 专业模型的优势:专门为持久网络浏览训练的Deep Research模型解决了约一半的问题 +L95: +L96: #### 计算资源与性能关系 +L97: +L98: 研究表明,BrowseComp性能随测试时计算资源的增加而平滑提升(如图1所示)。这与智能体模型的特性一致——更多计算资源允许模型浏览更多网站,从而提高找到正确 +L99: 答案的机会。 +L100: +L101: ### 进阶策略分析 +L102: +L103: #### 聚合策略的效果 +L104: +L105: 通过让模型多次尝试同一问题并采用投票策略,可以显著提升性能: +L106: +L107: - 多数投票:选择样本中最常见的答案 +L108: - 加权投票:根据模型置信度加权投票 +L109: - 最佳选择:选择置信度最高的答案 +L110: +L111: 这些方法将Deep Research的性能提升了15-25%,表明模型通常能够识别自己的正确答案。 +L112: +L113: #### 任务难度分布 +L114: +L115: 分析显示,BrowseComp中的任务难度分布广泛: +L116: +L117: - 16%的任务被Deep Research完美解决(100%通过率) +L118: - 14%的任务完全失败(0%通过率) +L119: - 其余任务处于中间难度水平 +L120: +L121: ### BrowseComp的意义与局限性 +L122: +L123: #### 作为评估工具的价值 +L124: +L125: BrowseComp可被视为浏览智能体的"编程竞赛"——虽然不全面,但对核心浏览能力提供了有用的衡量。正如在编程竞赛中表现出色的模型很可能在其他编码任务中表现良 +L126: 好,能够解决BrowseComp的模型在定位难以查找信息方面应该具有很强的能力。 +L127: +L128: #### 已知局限性 +L129: +L130: - 不涉及真实用户查询分布:如生成长回答或解决查询歧义等挑战 +L131: - 答案唯一性假设:虽然尽力确保,但不能完全保证没有其他有效答案 +L132: - 单模态限制:当前仅关注文本信息检索,未来可扩展至图像、视频等模态 +L133: +L134: ### 结论与展望 +L135: +L136: BrowseComp为评估AI代理的网络浏览能力提供了一个具有挑战性且易于使用的基准。它强调的持久性、创造性和战略性浏览能力,正是未来AI智能体在实际应用中所需 +L137: 的核心技能。 +L138: +L139: 随着更多研究团队在BrowseComp上测试他们的模型,我们期待看到浏览智能体技术的快速进步,最终使AI能够更可靠地帮助我们获取那些深藏在互联网角落的宝贵信息。 +L140: +L141: 访问地址:BrowseComp已在GitHub上开源,地址为 https://github.com/openai/simple-evals +L142: +L143: 添加红包 +L144: +L145: 请填写红包祝福语或标题 +L146: +L147: 个 +L148: +L149: 红包个数最小为10个 +L150: +L151: 元 +L152: +L153: 红包金额最低5元 +L154: +L155: 当前余额3.43元 前往充值 > +L156: +L157: 需支付:10.00元 +L158: +L159: 成就一亿技术人! +L160: +L161: hope_wisdom +L162: 发出的红包 +L163: 实付元 +L164: +L165: 点击重新获取 +L166: +L167: 扫码支付 + + +打开的HSCodeComp页面详细介绍了这个基准测试,它是一个用于评估深度搜索智能体在分层规则应用方面能力的基准测试,特点包括:基于真实世界电商数据、任务复杂、需要多跳层次推理、使用嘈杂的产品描述,并且标注由专家完成。使用场景是评估智能体在规则应用方面的能力,例如关税代码分类。例题是预测产品的10位HS编码。打开的BrowseComp页面详细介绍了这个基准测试,它是一个用于评估浏览智能体(即能够在互联网上导航、搜索信息的智能体)的基准测试。其特点包括:包含1266个具有挑战性的问题,需要持续导航和搜索难以查找的信息;答案形式简单,便于自动验证;问题设计确保人类在10分钟内无法解决,且现有模型也无法解决;涵盖了广泛的主题领域。使用场景是评估智能体在复杂网络浏览任务中的能力。例题是:“找出2018-2023年间在EMNLP会议上发表、第一作者本科毕业于达特茅斯学院、第四作者本科毕业于宾夕法尼亚大学的科学论文标题”。另外两个打开的页面(open-compass/GTA和WideSearch)出现了错误,可能链接已失效或需要其他访问方式。为了进一步获取其他相关基准测试的信息,我们需要同时进行多项搜索。考虑到用户可能对搜索智能体基准测试有广泛需求,我们将同时搜索“WebVoyager benchmark”、“WebArena benchmark”、“GAIA benchmark”、“ToolBench benchmark”、“HotpotQA benchmark”、“FEVER benchmark”、“TriviaQA benchmark”、“Natural Questions benchmark”、“MS MARCO benchmark”、“BEIR benchmark”、“MIRACL benchmark”、“Zero-shot IR benchmark”、“WebGPT benchmark”、“WebShop benchmark”、“ALFWorld benchmark”、“Mind2Web benchmark”、“VisualWebArena benchmark”、“SearchBench benchmark”、“WebVLN benchmark”、“WebNav benchmark”。这些搜索将帮助我们全面了解现有基准测试的概况。 + +<|DSML|function_calls> +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">WebVoyager benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">WebArena benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">GAIA benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">ToolBench benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">HotpotQA benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">FEVER benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">TriviaQA benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">Natural Questions benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">MS MARCO benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">BEIR benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">MIRACL benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">Zero-shot IR benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">WebGPT benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">WebShop benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">ALFWorld benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">Mind2Web benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">VisualWebArena benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">SearchBench benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">WebVLN benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">WebNav benchmark +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|end▁of▁sentence|> + + +[4] WebVoyager benchmark (Search_Results/WebVoyager benchmark) +**viewing lines [0 - 40] of 40** + +L0: +L1: URL: Search_Results/WebVoyager benchmark +L2: # Search Results +L3: +L4: * 【0†WebVoyager Benchmark; publish_date: none†www.trykura.com】 Performance +L5: measured across various websites in WebVoyager shows Kura consistently +L6: outperforming other agents, with particularly strong results on e-commerce ... +L7: * 【1†WebVoyager: Building an End-to-End Web Agent with ...; publish_date: +L8: none†github.com】 WebVoyager is an innovative Large Multimodal Model (LMM) +L9: powered web agent that can complete user instructions end-to-end by interacting +L10: with real-world ... +L11: * 【2†AI Browser Agent Leaderboard | Steel.dev; publish_date: +L12: none†leaderboard.steel.dev】 See how various AI browser agents stack up based on +L13: their accuracy in completing web-based tasks on the WebVoyager benchmark. +L14: * 【3†[2401.13919] WebVoyager: Building an End-to-End Web ...; publish_date: +L15: none†arxiv.org】 by H He · 2024 · Cited by 282 — We show that WebVoyager achieves +L16: a 59.1% task success rate on our benchmark, significantly surpassing the +L17: performance of both GPT-4 (All ... +L18: * 【4†Our Agent-E SOTA Results on the WebVoyager Benchmark; publish_date: +L19: none†www.emergence.ai】 Jul 11, 2024 — WebVoyager is a benchmark that tests an +L20: agent's capabilities for navigation on dynamic live websites. It is more +L21: representative than WebArena [4] ... +L22: * 【5†Browser Use = state of the art Web Agent; publish_date: none†browser- +L23: use.com】 Dec 15, 2024 — Browser Use has achieved state-of-the-art performance on +L24: the WebVoyager benchmark, with an impressive 89.1% success rate across 586 +L25: diverse web tasks. +L26: * 【6†Magnitude achieves SOTA 94% on WebVoyager benchmark; publish_date: +L27: none†github.com】 Magnitude achieves state-of-the-art performance with 93.9% +L28: success rate on WebVoyager, beating all other browser agents. +L29: * 【7†WebVoyager: Autonomous Web Agent Benchmark; publish_date: +L30: none†www.emergentmind.com】 3 days ago — WebVoyager Benchmark is a comprehensive +L31: evaluation suite for autonomous web agents, featuring 643 tasks across 15 +L32: popular websites. +L33: * 【8†WebVoyager Benchmark Results; publish_date: none†www.browserable.ai】 +L34: Browserable has achieved 90.4% on the WebVoyager benchmark. This is best-in- +L35: class performance across all web agents. This was done across 567 web tasks +L36: which ... +L37: * 【9†89% achieved on WebVoyager using Anchor + Browser Use; publish_date: +L38: none†www.reddit.com】 Thanks to the amazing work from the browser-use open-source +L39: community and the built-in support from Anchor Browser, we've hit an 89% score +L40: on WebVoyager. +[5] WebArena benchmark (Search_Results/WebArena benchmark) +**viewing lines [0 - 42] of 42** + +L0: +L1: URL: Search_Results/WebArena benchmark +L2: # Search Results +L3: +L4: * 【0†WebArena: A Realistic Web Environment for Building ...; publish_date: +L5: none†webarena.dev】 Our benchmark is implemented in our fully interactable +L6: highly-realistic WebArena environment. It features diverse tasks human may +L7: encounter in their daily ... +L8: * 【1†[2307.13854] WebArena: A Realistic Web Environment for ...; publish_date: +L9: none†arxiv.org】 by S Zhou · 2023 · Cited by 637 — Building upon our +L10: environment, we release a set of benchmark tasks focusing on evaluating the +L11: functional correctness of task completions. +L12: * 【2†WebArena: A Realistic Web Environment for Building ...; publish_date: +L13: none†www.cmu.edu】 WebArena introduces a benchmark on interpreting high-level +L14: realistic natural language command to concrete web-based interactions. We +L15: provide annotated programs ... +L16: * 【3†GitHub - web-arena-x/webarena: Code repo for ...; publish_date: +L17: none†github.com】 [12/20/2024] Check out our new benchmark on even more +L18: consequential tasks, including terminal use and coding, TheAgentCompany. +L19: [12/21/2023] We release the ... +L20: * 【4†WebArena Benchmark and the State of Agentic AI; publish_date: +L21: none†medium.com】 In short, WebArena established a new standard for realism and +L22: complexity in web agent evaluation, forcing AI agents to operate in dynamic, +L23: high- ... +L24: * 【5†WebArena: A Realistic Web Environment for Building ...; publish_date: +L25: none†huggingface.co】 Jul 25, 2023 — WebArena, a realistic and reproducible +L26: environment, evaluates the performance of autonomous agents performing complex +L27: tasks on websites using ... +L28: * 【6†WebArena Benchmark: Evaluating Web Agents; publish_date: +L29: none†www.emergentmind.com】 Jun 30, 2025 — WebArena Benchmark is a self-contained +L30: suite that evaluates autonomous agents on realistic, multi-step web tasks using +L31: natural language ... +L32: * 【7†VisualWebArena is a benchmark for multimodal agents.; publish_date: +L33: none†github.com】 VisualWebArena is a realistic and diverse benchmark for +L34: evaluating multimodal autonomous language agents. It comprises of a set of +L35: diverse and complex web-based ... +L36: * 【8†WebDev Arena Leaderboard - LMArena; publish_date: none†web.lmarena.ai】 +L37: WebDev Arena is a real-time AI coding competition where models go head-to-head +L38: in web development challenges, developed by LMArena. +L39: * 【9†WebArena: A Realistic Web Environment for Building ...; publish_date: +L40: none†arxiv.org】 Apr 16, 2024 — We use this benchmark to evaluate several agents +L41: that can follow NL command and perform web-based tasks (§4). These agents are +L42: implemented in a ... +[6] GAIA benchmark (Search_Results/GAIA benchmark) +**viewing lines [0 - 41] of 41** + +L0: +L1: URL: Search_Results/GAIA benchmark +L2: # Search Results +L3: +L4: * 【0†GAIA Leaderboard - a Hugging Face Space by ...; publish_date: +L5: none†huggingface.co】 GAIA is a benchmark which aims at evaluating next- +L6: generation LLMs (LLMs with augmented capabilities due to added tooling, +L7: efficient prompting, access to search ... +L8: * 【1†[2311.12983] GAIA: a benchmark for General AI Assistants; publish_date: +L9: none†arxiv.org】 by G Mialon · 2023 · Cited by 367 — GAIA proposes real-world +L10: questions that require a set of fundamental abilities such as reasoning, multi- +L11: modality handling, web browsing, and generally tool-use ... +L12: * 【2†GAIA benchmark; publish_date: none†huggingface.co】 This is the +L13: organisation page for all things related to GAIA, a benchmark for General AI +L14: Assistants. You can find all the information and links on the GAIA ... +L15: * 【3†GAIA: A Benchmark for General AI Assistants; publish_date: +L16: none†ukgovernmentbeis.github.io】 This is an Inspect AI implementation of the +L17: GAIA (General AI Assistants) benchmark, consisting of 450 questions testing tool +L18: use on realistic assistant tasks. +L19: * 【4†GAIA: a benchmark for general AI assistants | Research; publish_date: +L20: none†ai.meta.com】 May 6, 2024 — GAIA proposes real-world questions that require +L21: a set of fundamental abilities such as reasoning, multi-modality handling, web +L22: browsing, and generally tool-use ... +L23: * 【5†HAL: GAIA Leaderboard; publish_date: none†hal.cs.princeton.edu】 GAIA is a +L24: benchmark for General AI Assistants that requires a set of fundamental +L25: abilities such as reasoning, multi-modality handling, web browsing, and tool- +L26: ... +L27: * 【6†GAIA: The LLM Agent Benchmark Everyone's Talking About; publish_date: +L28: none†towardsdatascience.com】 May 29, 2025 — GAIA stands for General AI +L29: Assistants benchmark [1]. This benchmark was introduced to specifically evaluate +L30: LLM agents on their ability to act as general- ... +L31: * 【7†GAIA: a benchmark for General AI Assistants; publish_date: +L32: none†openreview.net】 by G Mialon · Cited by 367 — GAIA proposes real-world +L33: questions that require a set of fundamental abilities such as reasoning, multi- +L34: modality handling, web browsing, and generally tool-use ... +L35: * 【8†Rethinking AI Evaluation: Introducing the GAIA Benchmark; publish_date: +L36: none†medium.com】 The authors introduce GAIA, a benchmark designed to assess the +L37: robustness of AI systems across a variety of practical tasks. +L38: * 【9†H2O.ai Tops the General AI Assistant (GAIA) Test; publish_date: +L39: none†h2o.ai】 Mar 17, 2025 — Our h2oGPTe Agent has once again claimed the #1 spot +L40: on the prestigious GAIA (General AI Assistants) benchmark with an impressive +L41: 75% accuracy rate. +[7] ToolBench benchmark (Search_Results/ToolBench benchmark) +**viewing lines [0 - 40] of 40** + +L0: +L1: URL: Search_Results/ToolBench benchmark +L2: # Search Results +L3: +L4: * 【0†ToolBench, an evaluation suite for LLM tool manipulation ...; +L5: publish_date: none†github.com】 The ToolBench is a benchmark consisting of +L6: diverse software tools for real-world tasks. We also provide easy-to-use +L7: infrastructure in this repository. +L8: * 【1†OpenBMB/ToolBench; publish_date: none†github.com】 [2023/7/27] New version +L9: ToolBench is released. ✨Here is an overview of the dataset construction, +L10: training, and evaluation. ✨✨Features:. +L11: * 【2†Towards Stable Large-Scale Benchmarking on Tool ...; publish_date: +L12: none†arxiv.org】 by Z Guo · 2024 · Cited by 100 — We introduce StableToolBench, a +L13: benchmark evolving from ToolBench, proposing a virtual API server and stable +L14: evaluation system. +L15: * 【3†StableToolBench - Zhicheng Guo; publish_date: none†zhichengg.github.io】 +L16: We introduce StableToolBench, a benchmark evolving from ToolBench, proposing a +L17: virtual API server and stable evaluation system. +L18: * 【4†ToolBench | EvalScope - Read the Docs; publish_date: +L19: none†evalscope.readthedocs.io】 We evaluate the effectiveness of the ToolBench +L20: benchmark: ToolBench (Qin et al., 2023b). The task involves integrating API +L21: calls to complete tasks. +L22: * 【5†Towards Stable Large-Scale Benchmarking on Tool ...; publish_date: +L23: none†aclanthology.org】 by Z Guo · 2024 · Cited by 100 — We introduce +L24: StableToolBench, a benchmark evolving from ToolBench, proposing a virtual API +L25: server and stable evaluation system. +L26: * 【6†ML-Tool-Bench: Tool-Augmented Planning for ML Tasks; publish_date: +L27: none†openreview.net】 Sep 18, 2025 — In this work, we introduce a comprehensive +L28: benchmark for evaluating tool-augmented ML agents using a curated set of 61 +L29: specialized tools and 15 ... +L30: * 【7†-Bench: Benchmarking AI agents for the real-world; publish_date: +L31: none†sierra.ai】 Jun 20, 2024 — τ-bench measures an agent's ability to interact +L32: with (simulated) human users and programmatic APIs while following domain- +L33: specific policies in a consistent ... +L34: * 【8†ToolEval Leaderboard; publish_date: none†openbmb.github.io】 ToolEval is +L35: an automatic evaluator build for tool learning which incorporates two evaluation +L36: metrics, Pass Rate and Win Rate(Preference). +L37: * 【9†What is the best benchmark dataset for multi-step tool-use?; +L38: publish_date: none†www.reddit.com】 I'm a newbie trying to evaluate the +L39: performance of different prompts strategies for multi-step tool-using, wondering +L40: what is the recommended benchmark dataset ... +[8] HotpotQA benchmark (Search_Results/HotpotQA benchmark) +**viewing lines [0 - 39] of 39** + +L0: +L1: URL: Search_Results/HotpotQA benchmark +L2: # Search Results +L3: +L4: * 【0†HotpotQA Homepage; publish_date: none†hotpotqa.github.io】 HotpotQA is a +L5: question answering dataset featuring natural, multi-hop questions, with strong +L6: supervision for supporting facts to enable more explainable ...See more +L7: * 【1†HotpotQA: A Dataset for Diverse, Explainable Multi-hop ...; publish_date: +L8: none†arxiv.org】 by Z Yang · 2018 · Cited by 3834 — HotpotQA is a dataset with +L9: 113k Wikipedia-based question-answer pairs requiring multi-document reasoning, +L10: diverse questions, sentence-level ... +L11: * 【2†hotpotqa/hotpot_qa · Datasets at Hugging Face; publish_date: +L12: none†huggingface.co】 HotpotQA is a new dataset with 113k Wikipedia-based +L13: question-answer pairs with four key features: (1) the questions require finding +L14: and reasoning over multiple ...See more +L15: * 【3†Why You Should Stop Using HotpotQA for AI Agents ...; publish_date: +L16: none†qipeng.me】 Jul 1, 2025 — HotpotQA pioneered a class of AI tasks that +L17: requires the AI system to autonomously perform multiple steps of reasoning in an +L18: open-domain setting.See more +L19: * 【4†hotpotqa/hotpot; publish_date: none†github.com】 A dataset for diverse, +L20: explainable multi-hop question answering. This repository contains the baseline +L21: model code, as well as the entire pipeline of running ...See more +L22: * 【5†HotpotQA: Multi-Hop QA Benchmark; publish_date: +L23: none†www.emergentmind.com】 Sep 10, 2025 — HotpotQA is a large-scale multi-hop +L24: question answering benchmark featuring 112,779 Wikipedia-based Q&A pairs with +L25: detailed, sentence-level ...See more +L26: * 【6†HotpotQA Dataset | Papers With Code; publish_date: +L27: none†paperswithcode.com】 HotpotQA is a question answering dataset collected on +L28: the English Wikipedia, containing about 113K crowd-sourced questions.See more +L29: * 【7†HotpotQA: A Dataset for Diverse, Explainable Multi-hop ...; publish_date: +L30: none†aclanthology.org】 by Z Yang · 2018 · Cited by 3834 — HotpotQA is a dataset +L31: with 113k Wikipedia-based question-answer pairs requiring multi-document +L32: reasoning, diverse questions, sentence-level facts, and factoid ... +L33: * 【8†Benchmark BM25S: HotpotQA; publish_date: none†www.kaggle.com】 Explore and +L34: run machine learning code with Kaggle Notebooks | Using data from No attached +L35: data sources. +L36: * 【9†mteb/hotpotqa · Datasets at Hugging Face; publish_date: +L37: none†huggingface.co】 HotpotQA is a question answering dataset featuring natural, +L38: multi-hop questions, with strong supervision for supporting facts to enable +L39: more explainable ...See more +[9] FEVER benchmark (Search_Results/FEVER benchmark) +**viewing lines [0 - 40] of 40** + +L0: +L1: URL: Search_Results/FEVER benchmark +L2: # Search Results +L3: +L4: * 【0†Fever.ai; publish_date: none†fever.ai】 We are pleased to announce that +L5: FEVER9 will be co-located with EACL 2026. In this year's workshop, we will +L6: introduce a new shared task focused on automated fact ... +L7: * 【1†a Large-scale Dataset for Fact Extraction and VERification; publish_date: +L8: none†aclanthology.org】 by J Thorne · 2018 · Cited by 2315 — In this paper we +L9: introduce a new publicly available dataset for verification against textual +L10: sources, FEVER: Fact Extraction. +L11: * 【2†awslabs/fever: FEVER (Fact Extraction and VERification) ...; +L12: publish_date: none†github.com】 In this paper we introduce a new publicly +L13: available dataset for verification against textual sources, FEVER: Fact +L14: Extraction and VERification. +L15: * 【3†FEVER: Fact Extraction and VERification; publish_date: +L16: none†www.amazon.science】 The best accuracy we achieve on labeling a claim +L17: accompanied by the correct evidence is 31.87%, while if we ignore the evidence +L18: we achieve 50.91%. Thus we ... +L19: * 【4†FEVER Dataset; publish_date: none†fever.ai】 FEVER (Fact Extraction and +L20: VERification) consists of 185,445 claims generated by altering sentences +L21: extracted from Wikipedia and subsequently verified ... +L22: * 【5†mteb/fever · Datasets at Hugging Face; publish_date: none†huggingface.co】 +L23: FEVER. An MTEB dataset. Massive Text Embedding Benchmark. FEVER (Fact +L24: Extraction and VERification) consists of 185,445 claims generated by altering +L25: sentences ... +L26: * 【6†FEVEROUS: Fact Extraction and VERification Over ...; publish_date: +L27: none†datasets-benchmarks-proceedings.neurips.cc】 by R Aly · Cited by 359 — In +L28: this paper we introduce a novel dataset and benchmark, Fact Extraction and +L29: VERification Over. Unstructured and Structured information (FEVEROUS), which ... +L30: * 【7†a large-scale dataset for Fact Extraction and VERification; publish_date: +L31: none†arxiv.org】 by J Thorne · 2018 · Cited by 2315 — In this paper we introduce +L32: a new publicly available dataset for verification against textual sources, +L33: FEVER: Fact Extraction and VERification. +L34: * 【8†FEVER Resources; publish_date: none†fever.ai】 2018 FEVER: a large-scale +L35: dataset for Fact Extraction and VERification .bib James Thorne, Andreas Vlachos, +L36: Christos Christodoulopoulos, Arpit Mittal +L37: * 【9†a Large-scale Dataset for Fact Extraction and VERification; publish_date: +L38: none†www.semanticscholar.org】 This paper introduces a new publicly available +L39: dataset for verification against textual sources, FEVER, which consists of +L40: 185,445 claims generated by ... +[10] TriviaQA benchmark (Search_Results/TriviaQA benchmark) +**viewing lines [0 - 35] of 35** + +L0: +L1: URL: Search_Results/TriviaQA benchmark +L2: # Search Results +L3: +L4: * 【0†TriviaQA; publish_date: none†nlp.cs.washington.edu】 TriviaQA is a reading +L5: comprehension dataset containing over 650K question-answer-evidence triples. +L6: TriviaQA includes 95K question-answer pairs authored ... +L7: * 【1†TriviaQA: A Large Scale Distantly Supervised Challenge ...; publish_date: +L8: none†aclanthology.org】 by M Joshi · 2017 · Cited by 3451 — We present TriviaQA, +L9: a challenging reading comprehension dataset containing over 650K question- +L10: answer-evidence triples. TriviaQA includes 95K question ... +L11: * 【2†mandarjoshi/trivia_qa · Datasets at Hugging Face; publish_date: +L12: none†huggingface.co】 TriviaqQA is a reading comprehension dataset containing +L13: over 650K question-answer-evidence triples. TriviaqQA includes 95K question- +L14: answer pairs authored by ... +L15: * 【3†[1705.03551] TriviaQA: A Large Scale Distantly Supervised ...; +L16: publish_date: none†arxiv.org】 by M Joshi · 2017 · Cited by 3451 — We present +L17: TriviaQA, a challenging reading comprehension dataset containing over 650K +L18: question-answer-evidence triples. +L19: * 【4†TriviaQA; publish_date: none†epoch.ai】 An open-domain question answering +L20: benchmark with challenging trivia questions paired with evidence documents. +L21: * 【5†TriviaQA Leaderboard; publish_date: none†llm-stats.com】 What is the +L22: TriviaQA benchmark? A large-scale reading comprehension dataset containing over +L23: 650K question-answer-evidence triples. TriviaQA includes 95K ... +L24: * 【6†Code for the TriviaQA reading comprehension dataset; publish_date: +L25: none†github.com】 A large scale distantly supervised challenge dataset for +L26: reading comprehension. In Association for Computational Linguistics (ACL) 2017, +L27: Vancouver, Canada. +L28: * 【7†TriviaQA - Model Benchmarks - The Regularizer; publish_date: +L29: none†www.theregularizer.com】 May 4, 2025 — Compare the performance of different +L30: AI models across standardized benchmarks. Higher scores generally indicate +L31: better performance, but context ... +L32: * 【8†TriviaQA: A Large Scale Distantly Supervised Challenge ...; publish_date: +L33: none†www.cs.utexas.edu】 by M Joshi · Cited by 3445 — We present TriviaQA, a +L34: challenging reading comprehension dataset contain- ing over 650K question- +L35: answer-evidence triples. TriviaQA includes 95K question-. +[11] Natural Questions benchmark (Search_Results/Natural Questions benchmark) +**viewing lines [0 - 39] of 39** + +L0: +L1: URL: Search_Results/Natural Questions benchmark +L2: # Search Results +L3: +L4: * 【0†Natural Questions: a Benchmark for Question Answering ...; publish_date: +L5: none†research.google】 by T Kwiatkowski · Cited by 4339 — We present the Natural +L6: Questions corpus, a question answering dataset. Questions consist of real +L7: anonymized, aggregated queries issued to the Google search ... +L8: * 【1†Natural Questions: A Benchmark for Question Answering ...; publish_date: +L9: none†aclanthology.org】 by T Kwiatkowski · Cited by 4308 — Abstract. We present +L10: the Natural Questions corpus, a question answering data set. Questions consist +L11: of real anonymized, aggregated queries issued. +L12: * 【2†Google's Natural Questions; publish_date: none†ai.google.com】 Natural +L13: Questions. A Benchmark for Question Answering Research. View examples · Download +L14: dataset. Open Domain Question Answering. A core goal in artificial ... +L15: * 【3†google-research-datasets/natural-questions; publish_date: +L16: none†github.com】 Natural Questions (NQ) contains real user questions issued to +L17: Google search, and answers found from Wikipedia by annotators. NQ is designed +L18: for the training and ... +L19: * 【4†Natural Questions: A Benchmark for Question Answering ...; publish_date: +L20: none†direct.mit.edu】 Aug 1, 2019 — We present the Natural Questions corpus, a +L21: question answering data set. Questions consist of real anonymized, aggregated +L22: queries issued to the Google search ... +L23: * 【5†ir_datasets : Natural Questions; publish_date: none†ir-datasets.com】 +L24: Google Natural Questions is a Q&A dataset containing long, short, and Yes/No +L25: answers from Wikipedia. ir_datasets frames this around an ad-hoc ranking setting +L26: ... +L27: * 【6†sentence-transformers/natural-questions · Datasets at ...; publish_date: +L28: none†huggingface.co】 This dataset is a collection of question-answer pairs from +L29: the Natural Questions dataset. See Natural Questions for additional information. +L30: * 【7†Google's Natural Questions; publish_date: none†ai.google.com】 Natural +L31: Questions contains 307K training examples, 8K examples for development, and a +L32: further 8K examples for testing. In the paper, we demonstrate a human ... +L33: * 【8†A Benchmark for Question Answering Research; publish_date: +L34: none†www.researchgate.net】 Jul 27, 2025 — We present the Natural Questions +L35: corpus, a question answering data set. Questions consist of real anonymized, +L36: aggregated queries issued to the Google search ... +L37: * 【9†natural-questions; publish_date: none†docs.unity.rc.umass.edu】 Sep 4, +L38: 2025 — “Natural questions: a benchmark for question answering research.” +L39: Transactions of the Association for Computational Linguistics 7 (2019): ... +[12] MS MARCO benchmark (Search_Results/MS MARCO benchmark) +**viewing lines [0 - 41] of 41** + +L0: +L1: URL: Search_Results/MS MARCO benchmark +L2: # Search Results +L3: +L4: * 【0†MS MARCO - Microsoft Open Source; publish_date: none†microsoft.github.io】 +L5: The MS MARCO datasets are intended for non-commercial research purposes only to +L6: promote advancement in the field of artificial intelligence and related areas, +L7: ... +L8: * 【1†microsoft/ms_marco · Datasets at Hugging Face; publish_date: +L9: none†huggingface.co】 Starting with a paper released at NIPS 2016, MS MARCO is a +L10: collection of datasets focused on deep learning in search. The first dataset was +L11: a question ... +L12: * 【2†Benchmarking Ranking Models in the Large-Data Regime; publish_date: +L13: none†arxiv.org】 by N Craswell · 2021 · Cited by 89 — This paper uses the MS +L14: MARCO and TREC Deep Learning Track as our case study, comparing it to the case +L15: of TREC ad hoc ranking in the 1990s. +L16: * 【3†Benchmarking Ranking Models in the Large-Data Regime; publish_date: +L17: none†www.microsoft.com】 This paper uses the MS MARCO and TREC Deep Learning +L18: Track as our case study, comparing it to the case of TREC ad hoc ranking in the +L19: 1990s. We show how the ... +L20: * 【4†Datasets for Document and Passage Ranking Leadboards; publish_date: +L21: none†microsoft.github.io】 The MS MARCO document and passage ranking leaderboards +L22: complements the TREC Deep Learning Track by providing on-going evaluation of +L23: submissions using pre- ... +L24: * 【5†MS MARCO: Benchmarking Ranking Models in the Large- ...; publish_date: +L25: none†dl.acm.org】 Jul 11, 2021 — This paper uses the MS MARCO and TREC Deep +L26: Learning Track as our case study, comparing it to the case of TREC ad hoc +L27: ranking in the 1990s. +L28: * 【6†ir_datasets : MSMARCO (passage); publish_date: none†ir-datasets.com】 A +L29: passage ranking benchmark with a collection of 8.8 million passages and question +L30: queries. Most relevance judgments are shallow. +L31: * 【7†MS MARCO; publish_date: none†sbert.net】 MS MARCO Passage Ranking is a +L32: large dataset to train models for information retrieval. It consists of about +L33: 500k real search queries from Bing search engine ... +L34: * 【8†MS MARCO: A Human Generated MAchine Reading ...; publish_date: +L35: none†arxiv.org】 by P Bajaj · 2016 · Cited by 1151 — We introduce a large scale +L36: MAchine Reading COmprehension dataset, which we name MS MARCO. The dataset +L37: comprises of 1,010,916 anonymized ... +L38: * 【9†MS MARCO Web Search: A Large-scale Information-rich ...; publish_date: +L39: none†www.microsoft.com】 May 13, 2024 — MS MARCO Web Search offers a retrieval +L40: benchmark with three web retrieval challenge tasks that demands innovations in +L41: both machine learning and ... +[13] BEIR benchmark (Search_Results/BEIR benchmark) +**viewing lines [0 - 37] of 37** + +L0: +L1: URL: Search_Results/BEIR benchmark +L2: # Search Results +L3: +L4: * 【0†详细介绍文本检索基准BEIR: A Heterogeneous Benchmark ...; publish_date: +L5: none†blog.csdn.net】 2023年1月1日 — +L6: BEIR旨在为所有不同的检索任务提供一个一站式的零样本评估基准。为了构建一个全面的评估基准,选择方法对于收集具有理想属性的任务和数据集至关重要。对于 ... +L7: * 【1†beir-cellar/beir; publish_date: none†github.com】 BEIR is a heterogeneous +L8: benchmark containing diverse IR tasks. It also provides a common and easy +L9: framework for evaluation of your NLP-based retrieval models ... +L10: * 【2†BEIR: A Heterogenous Benchmark for Zero-shot Evaluation ...; +L11: publish_date: none†arxiv.org】 作者:N Thakur · 2021 · 被引用次数:1480 — We introduce +L12: Benchmarking-IR (BEIR), a robust and heterogeneous evaluation benchmark for +L13: information retrieval. +L14: * 【3†BeIR; publish_date: none†huggingface.co】 BEIR (Benchmarking IR) consists +L15: of a homogenous benchmark for diverse sentence or passage level IR tasks. It +L16: provides a common and easy framework for the cross ... +L17: * 【4†论文分享:BEIR A Heterogeneous Benchmark for Zero-shot ...; publish_date: +L18: none†zhuanlan.zhihu.com】 2022年10月3日 — 分享论文,夹带个人理解的分享,建议结合原论文看。 1 研究背景. +L19: 本论文主要关注的领域是query-document检索(下文简称qd检索),即根据query去文档库里 ... +L20: * 【5†Benchmarking IR Information Retrieval (BEIR); publish_date: +L21: none†zilliz.com】 BEIR is a benchmark designed for evaluating the versatility and +L22: robustness of information retrieval models. It features 18 diverse datasets +L23: from domains like ... +L24: * 【6†BEIR (Benchmarking IR) - OpenDataLab; publish_date: none†opendatalab.com】 +L25: 简介-Introduction. BEIR(Benchmarking +L26: IR)是包含不同信息检索(IR)任务的异构基准。通过BEIR,可以系统地研究多种神经检索方法的零样本泛化能力。 +L27: * 【7†What is the BEIR benchmark and how is it used?; publish_date: +L28: none†milvus.io】 The BEIR (Benchmarking Information Retrieval) benchmark is a +L29: standardized framework designed to evaluate the effectiveness of search and +L30: retrieval algorithms. +L31: * 【8†BEIR Benchmark数据集卡片; publish_date: none†www.atyun.com】 BEIR +L32: Benchmark数据集卡片. 数据集简介. BEIR是一个异构评测基准,由18个多样化的数据集构建而成,代表了9个信息检索任务:. 事实查证: FEVER , +L33: Climate-FEVER , SciFact ... +L34: * 【9†Evaluating search relevance part 1 - The BEIR benchmark; publish_date: +L35: none†www.elastic.co】 2024年7月16日 — Learn to evaluate your search system in the +L36: context of better understanding the BEIR benchmark, with tips & techniques to +L37: improve your ... +[14] MIRACL benchmark (Search_Results/MIRACL benchmark) +**viewing lines [0 - 41] of 41** + +L0: +L1: URL: Search_Results/MIRACL benchmark +L2: # Search Results +L3: +L4: * 【0†MIRACL | Multilingual Information Retrieval Across a ...; publish_date: +L5: none†project-miracl.github.io】 MIRACL (Multilingual Information Retrieval Across +L6: a Continuum of Languages) is an WSDM 2023 Cup challenge that focuses on search +L7: across 18 different ... +L8: * 【1†project-miracl/miracl: A large-scale multilingual dataset for ...; +L9: publish_date: none†github.com】 A large-scale multilingual dataset for +L10: Information Retrieval. Thorough human-annotations across 18 diverse languages. +L11: * 【2†A Large, multilingual, visual document retrieval benchmark; publish_date: +L12: none†arxiv.org】 by R Osmulski · 2025 · Cited by 2 — MIRACL-VISION is a +L13: challenging, representative, multilingual evaluation benchmark for visual +L14: retrieval pipelines and will help the community build robust ... +L15: * 【3†miracl/miracl · Datasets at Hugging Face; publish_date: +L16: none†huggingface.co】 MIRACL (Multilingual Information Retrieval Across a +L17: Continuum of Languages) is a multilingual retrieval dataset that focuses on +L18: search across 18 different ... +L19: * 【4†MIRACL: A Multilingual Retrieval Dataset Covering 18 ...; publish_date: +L20: none†direct.mit.edu】 by X Zhang · 2023 · Cited by 131 — MIRACL is a multilingual +L21: dataset for ad hoc retrieval across 18 languages that collectively encompass +L22: over three billion native speakers around the world. +L23: * 【5†(PDF) MIRACL-VISION: A Large, multilingual, visual ...; publish_date: +L24: none†www.researchgate.net】 May 23, 2025 — MIRACL-VISION covers 18 languages, and +L25: is an extension of the MIRACL dataset, a popular benchmark to evaluate text- +L26: based multilingual retrieval ... +L27: * 【6†A Large, multilingual, visual document retrieval benchmark; publish_date: +L28: none†arxiv.org】 by R Osmulski · 2025 · Cited by 2 — MIRACL-VISION is a +L29: challenging, representative, multilingual evaluation benchmark for visual +L30: retrieval pipelines and will help the community ... +L31: * 【7†ir_datasets : MIRACL; publish_date: none†ir-datasets.com】 +L32: "miracl/ar/test-a". The held-out test set (version a) for Arabic. +L33: queriesdocsCitationMetadata. 936 queries. Language: ar. Query type: +L34: GenericQuery: (namedtuple). +L35: * 【8†Evaluate on MIRACL — BGE documentation; publish_date: none†bge-model.com】 +L36: MIRACL (Multilingual Information Retrieval Across a Continuum of Languages) is +L37: an WSDM 2023 Cup challenge that focuses on search across 18 different languages. +L38: * 【9†MIRACL - Alpha's Tech Garden; publish_date: +L39: none†techgarden.alphasmanifesto.com】 MIRACL (Multilingual Information Retrieval +L40: Across a Continuum of Languages) is a multilingual dataset we have built for the +L41: WSDM 2023 Cup ... +[15] Zero-shot IR benchmark (Search_Results/Zero-shot IR benchmark) +**viewing lines [0 - 40] of 40** + +L0: +L1: URL: Search_Results/Zero-shot IR benchmark +L2: # Search Results +L3: +L4: * 【0†BEIR: A Heterogenous Benchmark for Zero-shot Evaluation ...; +L5: publish_date: none†arxiv.org】 by N Thakur · 2021 · Cited by 1480 — We introduce +L6: Benchmarking-IR (BEIR), a robust and heterogeneous evaluation benchmark for +L7: information retrieval.See more +L8: * 【1†beir-cellar/beir; publish_date: none†github.com】 BEIR: A Heterogenous +L9: Benchmark for Zero-shot Evaluation of Information Retrieval Models (NeurIPS +L10: 2021, Datasets and Benchmarks Track); Resources for Brewing ...See more +L11: * 【2†Benchmarking IR Information Retrieval (BEIR); publish_date: +L12: none†zilliz.com】 BEIR is a tool to evaluate how well Information Retrieval +L13: systems perform across many tasks and types of information, and is a standard +L14: benchmark. +L15: * 【3†BEIR: A Heterogeneous Benchmark for Zero-shot ...; publish_date: +L16: none†datasets-benchmarks-proceedings.neurips.cc】 by N Thakur · Cited by 1480 — +L17: BEIR is a robust, heterogeneous benchmark for information retrieval, using 18 +L18: datasets and 9 tasks to evaluate model generalization. +L19: * 【4†BEIR; publish_date: none†eval.ai】 BEIR is a heterogeneous zero-shot +L20: retrieval benchmark containing 18 datasets from diverse text retrieval tasks and +L21: domains.See more +L22: * 【5†[2409.15763] IRSC: A Zero-shot Evaluation Benchmark for ...; +L23: publish_date: none†arxiv.org】 by H Lin · 2024 · Cited by 2 — This paper +L24: introduces the IRSC benchmark for evaluating the performance of embedding models +L25: in multilingual RAG tasks.See more +L26: * 【6†FactIR: A Real-World Zero-shot Open-Domain Retrieval ...; publish_date: +L27: none†dl.acm.org】 May 23, 2025 — In this paper, we present a real-world retrieval +L28: benchmark FactIR, derived from Factiverse production logs, enhanced with human +L29: annotations. We ...See more +L30: * 【7†UniIR: Training and Benchmarking Universal Multimodal ...; publish_date: +L31: none†tiger-ai-lab.github.io】 At test time, we evaluated the zero-shot +L32: performance of all fine-tuned models, as well as SoTA pre-trained retrievers on +L33: the three held-out datasets. UniIR ...See more +L34: * 【8†Zero-Shot BEIR Tasks; publish_date: none†www.emergentmind.com】 Aug 26, +L35: 2025 — Zero-Shot BEIR Tasks are evaluation methodologies that assess IR models' +L36: ability to generalize to unseen query domains without task-specific ...See more +L37: * 【9†BEIR-PL: Zero Shot Information Retrieval Benchmark for ...; publish_date: +L38: none†aclanthology.org】 by K Wojtasik · 2024 · Cited by 12 — BEIR-PL is a new +L39: benchmark with 13 datasets for Polish Information Retrieval, created to advance +L40: research in this area. +[16] WebGPT benchmark (Search_Results/WebGPT benchmark) +**viewing lines [0 - 38] of 38** + +L0: +L1: URL: Search_Results/WebGPT benchmark +L2: # Search Results +L3: +L4: * 【0†WebGPT: Improving the factual accuracy of language ...; publish_date: +L5: none†openai.com】 Dec 16, 2021 — Our models outperform GPT‑3 on TruthfulQA and +L6: exhibit more favourable scaling properties. However, our models lag behind human +L7: performance, ... +L8: * 【1†A Simple Yet Challenging Benchmark for Browsing Agents; publish_date: +L9: none†arxiv.org】 by J Wei · 2025 · Cited by 124 — Abstract. We present +L10: BrowseComp, a simple yet challenging benchmark for measuring the ability for +L11: agents to browse the web. +L12: * 【2†openai/webgpt_comparisons · Datasets at Hugging Face; publish_date: +L13: none†huggingface.co】 This is the dataset of all comparisons that were marked as +L14: suitable for reward modeling by the end of the WebGPT project. There are 19,578 +L15: comparisons in total. +L16: * 【3†Evaluation & Limitations of WebGPT, WebVoyager & Agent-E; publish_date: +L17: none†deepsense.ai】 Oct 14, 2024 — WebArena benchmark features 812 tasks +L18: evaluated using metrics such as Exact Match, Must Include, and Fuzzy Match, +L19: focusing on outcomes rather ... +L20: * 【4†OpenAI Announces Question-Answering AI WebGPT; publish_date: +L21: none†www.infoq.com】 Jan 25, 2022 — On the TriviaQA benchmark, WebGPT +L22: outperformed GPT-3, producing answers that were true 75% of the time, and "both +L23: true and informative" 54% of ... +L24: * 【5†WebGPT: Improving the factual accuracy of language models ...; +L25: publish_date: none†kargarisaac.medium.com】 The top-performing model generated +L26: answers that were preferred over 56% of the time compared to answers produced by +L27: human demonstrators, with ... +L28: * 【6†Browser-assisted question-answering with human feedback; publish_date: +L29: none†www.alphaxiv.org】 WebGPT represents a significant advancement in long-form +L30: question answering by combining the language generation capabilities of GPT-3 +L31: with real-time web ... +L32: * 【7†Benchmarking Open-Source Large Language Models, GPT-4 ...; publish_date: +L33: none†ai.nejm.org】 by S Wu · 2024 · Cited by 69 — We show that the current widely +L34: used open-source LLMs have poor zero-shot reasoning ability in nephrology +L35: compared with GPT-4 and Claude 2. +L36: * 【8†0hq/WebGPT: Run GPT model on ...; publish_date: none†github.com】 WebGPT +L37: is a vanilla JS and HTML implementation of a transformer model, intended as a +L38: proof-of-concept as well as educational resource. +[17] WebShop benchmark (Search_Results/WebShop benchmark) +**viewing lines [0 - 41] of 41** + +L0: +L1: URL: Search_Results/WebShop benchmark +L2: # Search Results +L3: +L4: * 【0†WebShop: Towards Scalable Real-World Web Interaction ...; publish_date: +L5: none†arxiv.org】 by S Yao · 2022 · Cited by 710 — To bridge this gap, we develop +L6: WebShop -- a simulated e-commerce website environment with 1.18 million real- +L7: world products and 12,087 crowd- ... +L8: * 【1†WebShop; publish_date: none†webshop-pnlp.github.io】 To bridge this gap, +L9: we develop WebShop – a simulated e-commerce website environment with 1.18 +L10: million real-world products and 12,087 crowd-sourced text ... +L11: * 【2†princeton-nlp/WebShop; publish_date: none†github.com】 WebShop is a +L12: simulated e-commerce website environment with 1.18 million real-world products +L13: and 12,087 crowd-sourced text instructions. In this environment, an ... +L14: * 【3†WebShop: Towards Scalable Real-World Web Interaction ...; publish_date: +L15: none†papers.nips.cc】 by S Yao · 2022 · Cited by 710 — We collect over 1,600 +L16: human trajectories to first validate the benchmark, then train and evaluate a +L17: diverse range of agents using reinforcement learning, ... +L18: * 【4†WebShop: Towards Scalable Real-World Web Interaction ...; publish_date: +L19: none†proceedings.neurips.cc】 by S Yao · 2022 · Cited by 709 — We have developed +L20: WebShop, a new web-based benchmark for sequential decision making and language +L21: grounding, modeled on interaction with an e-commerce website. +L22: * 【5†Webshop & Benchmark Analysis | Documentation Infinity; publish_date: +L23: none†docs.fact-finder.com】 Aug 15, 2025 — Evaluation of your shop based on +L24: different categories in comparison, to your competitors/industry. Recommended +L25: when doing a shop relaunch. +L26: * 【6†A Multi-Shop Benchmark for Evaluating Web Agents; publish_date: +L27: none†arxiv.org】 by R Peeters · 2025 · Cited by 2 — Compared to existing +L28: e-commerce benchmarks, such as WebShop or ShoppingBench, WebMall introduces +L29: comparison-shopping tasks across multiple shops ... +L30: * 【7†WebShop: towards scalable real-world web interaction with ...; +L31: publish_date: none†dl.acm.org】 by S Yao · 2022 · Cited by 710 — To bridge this +L32: gap, we develop WebShop - a simulated e-commerce website environment with 1.18 +L33: million real-world products and 12, 087 crowd- ... +L34: * 【8†[PDF] WebShop: Towards Scalable Real-World Web ...; publish_date: +L35: none†www.semanticscholar.org】 It is shown that agents trained on WebShop exhibit +L36: non-trivial sim-to-real transfer when evaluated on amazon.com and ebay.com, +L37: indicating the potential ... +L38: * 【9†X-WebAgentBench: A Multilingual Interactive Web ...; publish_date: +L39: none†aclanthology.org】 by P Wang · 2025 · Cited by 3 — (2023) based on the +L40: English WebShop benchmark (Yao et al., 2022), while the multilingual task scores +L41: are ob- tained through evaluation on our own benchmark. +[18] ALFWorld benchmark (Search_Results/ALFWorld benchmark) +**viewing lines [0 - 31] of 31** + +L0: +L1: URL: Search_Results/ALFWorld benchmark +L2: # Search Results +L3: +L4: * 【0†ALFWorld; publish_date: none†alfworld.github.io】 ALFWorld contains +L5: interactive TextWorld environments (Côté et. al) that parallel embodied worlds +L6: in the ALFRED dataset (Shridhar et. al). +L7: * 【1†ALFWorld: Aligning Text and Embodied Environments for ...; publish_date: +L8: none†arxiv.org】 by M Shridhar · 2020 · Cited by 674 — ALFWorld enables the +L9: creation of a new BUTLER agent whose abstract knowledge, learned in TextWorld, +L10: corresponds directly to concrete, visually grounded actions. +L11: * 【2†ALFWorld: Aligning Text and Embodied Environments ...; publish_date: +L12: none†github.com】 ALFWorld contains interactive TextWorld environments (Côté et. +L13: al) that parallel embodied worlds in the ALFRED dataset (Shridhar et. al). +L14: * 【3†alfworld - benchmark's activity; publish_date: none†huggingface.co】 MM- +L15: IQ: Benchmarking Human-Like Abstraction and Reasoning in Multimodal Models Paper +L16: • 2502.00698 • Published Feb 1 • 24 +L17: * 【4†Tackling AlfWorld with Action Attention and Common ...; publish_date: +L18: none†neurips.cc】 On the Alfworld benchmark for indoor instruction following, we +L19: achieve a significantly higher success rate (50% over the baseline) with our +L20: novel object ... +L21: * 【5†ALFWORLD: ALIGNING TEXT AND EMBODIED ...; publish_date: +L22: none†openreview.net】 by M Shridhar · Cited by 674 — The ALFRED dataset (Shridhar +L23: et al., 2020), set in the THOR simulator (Kolve et al., 2017), is a benchmark +L24: for learning to com- plete embodied household tasks ... +L25: * 【6†AlfWorld; publish_date: none†primo.ai】 Mar 23, 2024 — A simulator that +L26: enables agents to learn abstract, text based policies in TextWorld (Côté et al., +L27: 2018) and then execute goals from the ALFRED benchmark. +L28: * 【7†AlfWorld performance across 134 tasks showing cumulative...; +L29: publish_date: none†www.researchgate.net】 In the AlfWorld benchmark, we defined +L30: hallucination as the occurrence of two or more consecutive identical actions in +L31: which the environment responded with ... +[19] Mind2Web benchmark (Search_Results/Mind2Web benchmark) +**viewing lines [0 - 40] of 40** + +L0: +L1: URL: Search_Results/Mind2Web benchmark +L2: # Search Results +L3: +L4: * 【0†Mind2Web: Towards a Generalist Agent for the Web; publish_date: none†osu- +L5: nlp-group.github.io】 Mind2Web is a dataset for developing and evaluating +L6: generalist agents for the web that can follow language instructions to complete +L7: complex tasks on any ... +L8: * 【1†Online-Mind2Web Leaderboard; publish_date: none†huggingface.co】 Online- +L9: Mind2Web is a benchmark designed to evaluate the real-world performance of web +L10: agents on live websites, featuring 300 tasks across 136 popular sites ... +L11: * 【2†Mind2Web: Towards a Generalist Agent for the Web; publish_date: +L12: none†github.com】 Mind2Web is the first dataset for developing and evaluating +L13: generalist agents for the web that can follow language instructions to complete +L14: complex tasks on any ... +L15: * 【3†HAL: Online Mind2Web Leaderboard; publish_date: +L16: none†hal.cs.princeton.edu】 Online Mind2Web leaderboard for evaluating AI agents' +L17: ability to complete tasks on real, changing webpages. +L18: * 【4†[2506.21506] Mind2Web 2: Evaluating Agentic Search with ...; +L19: publish_date: none†arxiv.org】 by B Gou · 2025 · Cited by 11 — In this paper, we +L20: introduce Mind2Web 2, a benchmark of 130 realistic, high-quality, and long- +L21: horizon tasks that require real-time web browsing and extensive ... +L22: * 【5†Mind2Web 2: Evaluating Agentic Search with Agent-as-a-Judge; +L23: publish_date: none†osu-nlp-group.github.io】 We introduce Mind2Web 2, a benchmark +L24: of 130 realistic, high-quality, long-horizon tasks that require real-time web +L25: browsing and extensive information ... +L26: * 【6†Mind2Web: The Benchmark for AI Agent Evaluation and ...; publish_date: +L27: none†www.enhans.ai】 Sep 26, 2025 — Mind2Web is a globally recognized web-based +L28: AI Agent evaluation benchmark introduced by the NLP group at Ohio State +L29: University at NeurIPS 2023. +L30: * 【7†Evaluating AI Web Agents: Insights from the WebCanvas ...; publish_date: +L31: none†medium.com】 Thanks to the comprehensive WebCanvas Benchmark, which +L32: incorporates a robust Mind2Web-Live data set of 542 live web tasks and 2,439 ... +L33: * 【8†Mind2Web: Towards a Generalist Agent for the Web; publish_date: +L34: none†proceedings.neurips.cc】 by X Deng · 2023 · Cited by 760 — We introduce +L35: Mind2Web, the first dataset for developing and evaluating generalist agents for +L36: the web that can follow language instructions to complete complex ... +L37: * 【9†Mind2Web: Towards a Generalist Agent for the Web; publish_date: +L38: none†openreview.net】 by X Deng · Cited by 760 — We introduce Mind2Web, the first +L39: dataset for developing and evaluating generalist agents for the web that can +L40: follow language instructions to complete complex ... +[20] VisualWebArena benchmark (Search_Results/VisualWebArena benchmark) +**viewing lines [0 - 38] of 38** + +L0: +L1: URL: Search_Results/VisualWebArena benchmark +L2: # Search Results +L3: +L4: * 【0†VisualWebArena is a benchmark for multimodal agents.; publish_date: +L5: none†github.com】 VisualWebArena is a realistic and diverse benchmark for +L6: evaluating multimodal autonomous language agents. +L7: * 【1†Evaluating Multimodal Agents on Realistic Visual Web Tasks; publish_date: +L8: none†arxiv.org】 by JY Koh · 2024 · Cited by 363 — To bridge this gap, we +L9: introduce VisualWebArena, a benchmark designed to assess the performance of +L10: multimodal web agents on realistic \textit{ ... +L11: * 【2†Evaluating Multimodal Agents on Realistic Visual Web Tasks; publish_date: +L12: none†jykoh.com】 To bridge this gap, we introduce VisualWebArena, a benchmark +L13: designed to assess the performance of multimodal web agents on realistic +L14: visually grounded tasks. +L15: * 【3†VisualWebArena: Evaluating Multimodal Agents on ...; publish_date: +L16: none†arxiv.org】 VisualWebArena is a research benchmark to measure and evaluate +L17: the progress of multimodal agents. It is primarily meant to act as a self- +L18: contained sandbox ... +L19: * 【4†Evaluating Multimodal Agents on Realistic Visual Web Tasks; publish_date: +L20: none†aclanthology.org】 by JY Koh · 2024 · Cited by 363 — To bridge this gap, we +L21: introduce VisualWebArena, a benchmark designed to assess the performance of +L22: multimodal web agents on *realistic visually grounded tasks*. +L23: * 【5†Evaluating Multimodal Agents on Realistic Visual Web Tasks; publish_date: +L24: none†www.semanticscholar.org】 VisualWebArena: Evaluating Multimodal Agents on +L25: Realistic Visual Web Tasks ... MMInA, a multihop and multimodal benchmark to +L26: evaluate the embodied agents ... +L27: * 【6†CMU Researchers Introduce VisualWebArena: An AI ...; publish_date: +L28: none†www.marktechpost.com】 Feb 9, 2024 — VisualWebArena, a benchmark designed +L29: and developed to evaluate the performance of multimodal web agents on realistic +L30: and visually stimulating challenges. +L31: * 【7†Evaluating Multimodal Agents on Realistic Visual Web Tasks; publish_date: +L32: none†www.themoonlight.io】 The paper "VisualWebArena: Evaluating Multimodal +L33: Agents on Realistic Visually Grounded Web Tasks" introduces a new benchmark, +L34: **VisualWebArena**, ... +L35: * 【8†WebArena: A Realistic Web Environment for Building ...; publish_date: +L36: none†webarena.dev】 Our benchmark is implemented in our fully interactable +L37: highly-realistic WebArena environment. It features diverse tasks human may +L38: encounter in their daily ... +[21] SearchBench benchmark (Search_Results/SearchBench benchmark) +**viewing lines [0 - 40] of 40** + +L0: +L1: URL: Search_Results/SearchBench benchmark +L2: # Search Results +L3: +L4: * 【0†Talc-AI/search-bench; publish_date: none†github.com】 A practical +L5: benchmark that focuses on every day helpfulness of LLM products, not just the +L6: underlying models. Searchbench is a benchmark that addresses these ... +L7: * 【1†Evaluating LLMs' Ability to Reason About Search Problems; publish_date: +L8: none†arxiv.org】 These capabilities are essential for robust reasoning, making +L9: SearchBench a valuable benchmark for evaluating LLMs' reasoning capabilities as +L10: they continue to ... +L11: * 【2†NasimBrz/SearchBench · Datasets at Hugging Face; publish_date: +L12: none†huggingface.co】 Dataset Summary. SearchBench is a benchmark designed to +L13: evaluate Language Models' (LLMs) ability to solve state-based problems that +L14: require combinatorial search ... +L15: * 【3†Evaluating LLMs' Ability to Reason About Search Problems; publish_date: +L16: none†openreview.net】 2025年10月22日 — To further investigate this, we introduce a +L17: new benchmark, SearchBench, which contains 11 unique search problems inspired by +L18: intuitive puzzles. +L19: * 【4†Navigating the Labyrinth: Evaluating and Enhancing LLMs' ...; +L20: publish_date: none†hub.baai.ac.cn】 2024年6月17日 — +L21: 论文提出了一个新的基准测试SearchBench,包含11种独特的搜索问题类型,并自动化生成任意数量的实例和分析解决方案的可行性、正确性和最优性。论文使用A* +L22: ... +L23: * 【5†Towards Unified Text-based Person Retrieval: A Large- ...; publish_date: +L24: none†blog.csdn.net】 2023年10月17日 — ... Search +L25: Benchmark(面向统一的基于文本的人物检索:一个大规模的多属性和语言搜索基准); 研究背景. 相关工作; BENCHMARK. 论文方法分析. 网络框架; +L26: 1、APTM ... +L27: * 【6†Desearch-ai/ai-search-benchmark; publish_date: none†github.com】 The +L28: SearchBench repository addresses common issues with traditional benchmarks by +L29: focusing on practical, everyday use cases rather than theoretical limits. It ... +L30: * 【7†o1 results for 3 benchmarks: PlanBench, SearchBench, ...; publish_date: +L31: none†www.reddit.com】 o1 results for 3 benchmarks: PlanBench, SearchBench, and +L32: Summary of a Haystack. AI. PlanBench: Paper "LLMs Still Can't Plan; Can LRMs? A +L33: ... +L34: * 【8†Evaluating LLMs' Ability to Reason About Search Problems; publish_date: +L35: none†ui.adsabs.harvard.edu】 To further investigate this, we introduce a new +L36: benchmark, SearchBench, which contains 11 unique search problems inspired by +L37: intuitive puzzles. Each SearchBench ... +L38: * 【9†Introducing SearchBench; publish_date: none†www.tag1consulting.com】 +L39: Toward this goal, over the weekend I launched a new project called SearchBench, +L40: a Drupal module for benchmarking Drupal's search performance. As the module ... +[22] WebVLN benchmark (Search_Results/WebVLN benchmark) +**viewing lines [0 - 42] of 42** + +L0: +L1: URL: Search_Results/WebVLN benchmark +L2: # Search Results +L3: +L4: * 【0†WebVLN: Vision-and-Language Navigation on Websites; publish_date: +L5: none†ojs.aaai.org】 by Q Chen · 2024 · Cited by 35 — the WebVLN-v1 dataset, where +L6: the performance is far from saturation, highlighting the utility of our +L7: WebVLN-v1 as a benchmark to assess progress in this field. +L8: * 【1†[2312.15820] WebVLN: Vision-and-Language Navigation on Websites; +L9: publish_date: none†ar5iv.labs.arxiv.org】 Experimental results show that WebVLN- +L10: Net outperforms current VLN and web-related navigation methods. ... Code is +L11: available at: https://github.com/WebVLN/WebVLN. +L12: * 【2†WebVLN: Vision-and-Language Navigation on Websites; publish_date: +L13: none†github.com】 Experimental results show that WebVLN-Net outperforms current +L14: VLN and web-related navigation methods. We believe that the introduction of the +L15: new WebVLN task ... +L16: * 【3†Vision-and-Language Navigation in the Real-World; publish_date: +L17: none†digital.library.adelaide.edu.au】 By leveraging our proposed WebVLN-v1 +L18: dataset, experimental results showcase the superior performance of WebVLN-Net +L19: compared to existing VLN and web-related ... +L20: * 【4†WebVLN: Vision-and-Language Navigation on Websites; publish_date: +L21: none†www.researchgate.net】 Experimental results show that WebVLN-Net outperforms +L22: current VLN and web-related navigation methods. We believe that the +L23: introduction of the newWebVLN task and ... +L24: * 【5†[PDF] WebVLN: Vision-and-Language Navigation on Websites; publish_date: +L25: none†www.semanticscholar.org】 A new task named Vision-and-Language Navigation on +L26: Websites (WebVLN), where question-based instructions are used to train an +L27: agent, emulating how users ... +L28: * 【6†WebVLN: Vision-and-Language Navigation on Websites; publish_date: +L29: none†arxiv.org】 by Q Chen · 2023 · Cited by 35 — Experimental results show that +L30: WebVLN-Net outperforms current VLN and web-related navigation methods. We +L31: believe that the introduction of the ... +L32: * 【7†Human-Aware Vision-and-Language Navigation; publish_date: +L33: none†proceedings.neurips.cc】 by H Li · 2024 · Cited by 19 — Vision-and-Language +L34: Navigation (VLN) [2, 7, 9, 40] has emerged as a key benchmark for evaluating. +L35: Sim2Real transfer [23], showing impressive performance in ... +L36: * 【8†LiveBench; publish_date: none†livebench.ai】 Introducing LiveBench: a +L37: benchmark for LLMs designed with test set contamination and objective evaluation +L38: in mind. +L39: * 【9†MG-VLN: Benchmarking Multi-Goal and Long-Horizon ...; publish_date: +L40: none†ieeexplore.ieee.org】 by J Zhang · 2024 — This task aims to provide a +L41: simulation benchmark to guide the design of lifelong and long-horizon navigation +L42: robots. +[23] WebNav benchmark (Search_Results/WebNav benchmark) +**viewing lines [0 - 36] of 36** + +L0: +L1: URL: Search_Results/WebNav benchmark +L2: # Search Results +L3: +L4: * 【0†WebNav: A New Large-Scale Task for Natural Language ...; publish_date: +L5: none†github.com】 WebNav is a benchmark task for evaluating an agent with +L6: abilities to understand natural language and plan on partially observed +L7: environments. +L8: * 【1†[1602.02261] End-to-End Goal-Driven Web Navigation; publish_date: +L9: none†arxiv.org】 by R Nogueira · 2016 · Cited by 39 — We propose a goal-driven +L10: web navigation as a benchmark task for evaluating an agent with abilities to +L11: understand natural language and plan on partially ... +L12: * 【2†nyu-dl/dl4ir-webnav; publish_date: none†github.com】 WebNav is a benchmark +L13: task for evaluating an agent with abilities to understand natural language and +L14: plan on partially observed environments. +L15: * 【3†WebNav: A New Large-Scale Task for Natural Language ...; publish_date: +L16: none†www.researchgate.net】 We propose a goal-driven web navigation as a +L17: benchmark task for evaluating an agent with abilities to understand natural +L18: language and plan on partially ... +L19: * 【4†WebNav: An Intelligent Agent for Voice-Controlled Web ...; publish_date: +L20: none†arxiv.org】 In this paper, we introduce WebNav, an innovative multi-modal +L21: agent designed to automate web tasks based on natural language. +L22: * 【5†WebCanvas: Benchmarking Web Agents in Online ...; publish_date: +L23: none†openreview.net】 by Y Pan · Cited by 78 — TL;DR: We introduce WebCanvas, an +L24: online evaluation framework for web agents designed to address the dynamic +L25: nature of web interactions. +L26: * 【6†WebNav: An Intelligent Agent for Voice-Controlled Web ...; publish_date: +L27: none†huggingface.co】 Mar 17, 2025 — Preliminary evaluations show that WebNav +L28: outperforms traditional screen readers in response time and task completion +L29: accuracy for the visually ... +L30: * 【7†Web Bench - A new way to compare AI Browser Agents; publish_date: +L31: none†www.skyvern.com】 May 29, 2025 — TL;DR: Web Bench is a new dataset to +L32: evaluate web browsing agents that consists of 5,750 tasks on 452 different +L33: websites, with 2,454 tasks ... +L34: * 【8†Web Framework Benchmarks; publish_date: none†www.techempower.com】 In the +L35: following tests, we have measured the performance of several web application +L36: platforms, full-stack frameworks, and micro-frameworks. + + +我们收到了大量关于各种基准测试的搜索结果。现在,我们需要整理一份针对搜索智能体的基准测试列表。这些基准测试可以分为几类:网络浏览智能体基准测试(例如BrowseComp、WebVoyager、WebArena、VisualWebArena、Mind2Web、WebVLN、WebNav)、工具使用智能体基准测试(例如ToolBench、GTA)、问答/检索基准测试(例如HotpotQA、FEVER、TriviaQA、Natural Questions、MS MARCO、BEIR、MIRACL),以及特定领域的搜索基准测试(例如HSCodeComp、WebShop、ALFWorld)。此外,还有一些通用智能体基准测试,例如GAIA、SuperCLUE-Agent等。用户要求“详细介绍各自的特点、使用场景、例题”。因此,我需要为每个基准测试提供详细的描述,包括其特点、使用场景,并举例说明。为了全面获取这些信息,我将同时打开以下基准测试的相关页面:WebVoyager(索引3)、WebArena(索引1)、GAIA(索引1)、ToolBench(索引0)、HotpotQA(索引0)、FEVER(索引1)、TriviaQA(索引0)、Natural Questions(索引0)、MS MARCO(索引0)、BEIR(索引1)、MIRACL(索引0)、WebShop(索引0)、ALFWorld(索引0)、Mind2Web(索引0)、VisualWebArena(索引0)、SearchBench(索引1)、WebVLN(索引1)和WebNav(索引0)。 + +<|DSML|function_calls> +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">4 +<|DSML|parameter name="id" string="false">3 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">5 +<|DSML|parameter name="id" string="false">1 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">6 +<|DSML|parameter name="id" string="false">1 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">7 +<|DSML|parameter name="id" string="false">0 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">8 +<|DSML|parameter name="id" string="false">0 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">9 +<|DSML|parameter name="id" string="false">1 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">10 +<|DSML|parameter name="id" string="false">0 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">11 +<|DSML|parameter name="id" string="false">0 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">12 +<|DSML|parameter name="id" string="false">0 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">13 +<|DSML|parameter name="id" string="false">1 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">14 +<|DSML|parameter name="id" string="false">0 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">17 +<|DSML|parameter name="id" string="false">0 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">18 +<|DSML|parameter name="id" string="false">0 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">19 +<|DSML|parameter name="id" string="false">0 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">20 +<|DSML|parameter name="id" string="false">0 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">21 +<|DSML|parameter name="id" string="false">1 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">22 +<|DSML|parameter name="id" string="false">1 + +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">23 +<|DSML|parameter name="id" string="false">0 + +<|end▁of▁sentence|> + + +[24] (https://ar5iv.org/abs/2401.13919) +**viewing lines [0 - 129] of 825** + +L0: +L1: URL: https://ar5iv.org/abs/2401.13919 +L2: # WebVoyager : Building an End-to-End Web Agent with +L3: Large Multimodal Models +L4: +L5: Hongliang He1,3 , Wenlin Yao2, Kaixin Ma2, Wenhao Yu2, Yong Dai2, +L6: +L7: Hongming Zhang2, +L8: Zhenzhong Lan3, +L9: Dong Yu2 +L10: 1Zhejiang University, +L11: 2Tencent AI Lab, +L12: 3Westlake University +L13: hehongliang@westlake.edu.cn, wenlinyao@global.tencent.com +L14: Work done during the internship at Tencent AI Lab. +L15: +L16: ###### Abstract +L17: +L18: The advancement of large language models (LLMs) leads to a new era marked by the +L19: development of autonomous applications in the real world, which drives +L20: innovation in the creation of advanced web-based agents. Existing web agents +L21: typically only handle one input modality and are evaluated only in simplified +L22: web simulators or static web snapshots, greatly limiting their applicability in +L23: real-world scenarios. To bridge this gap, we introduce WebVoyager, an innovative +L24: Large Multimodal Model (LMM) powered web agent that can complete user +L25: instructions end-to-end by interacting with real-world websites. Moreover, we +L26: propose a new evaluation protocol for web agents to address the challenges of +L27: automatic evaluation of open-ended web agent tasks, leveraging the robust +L28: multimodal comprehension capabilities of GPT-4V. We create a new benchmark by +L29: gathering real-world tasks from 15 widely used websites to evaluate our agents. +L30: We show that WebVoyager achieves a 55.7% task success rate, significantly +L31: surpassing the performance of both GPT-4 (All Tools) and the WebVoyager (text- +L32: only) setups, underscoring the exceptional capability of WebVoyager in practical +L33: applications. We found that our proposed automatic evaluation achieves 85.3% +L34: agreement with human judgment, paving the way for further development of web +L35: agents in a real-world setting.111Our code and data will be released at +L36: https://github.com/MinorJerry/WebVoyager +L37: +L38: ## 1 Introduction +L39: +L40: The recent advancement of large language models (LLMs), such as ChatGPT and +L41: GPT-4 (OpenAI, 2023), have sparked significant interest in developing LLM-based +L42: autonomous agents (AutoGPT, 2022) for complex task execution (Qin et al., 2023; +L43: Schick et al., 2023). Recent studies have explored the construction of text- +L44: based web browsing environments and how to instruct large language model agents +L45: to perform web navigation (Nakano et al., 2021; Gur et al., 2023; Zhou et al., +L46: 2023; Lu et al., 2023). The primary challenge in these works lies in managing +L47: complex and verbose HTML texts, and solutions include simplifying and +L48: structuring HTML (Nakano et al., 2021; Zhou et al., 2023; Gur et al., 2023; Deng +L49: et al., 2023). +L50: +L51: However, existing approaches overlook a critical functionality of browsing: +L52: rendering HTML into visual webpages. Particularly, vision capability is crucial +L53: for utilizing tools like web browsers, as rendered web pages are inherently +L54: designed with user experience (UX), emphasizing intuitive information and +L55: structured presentation. This design principle of rendering makes visual +L56: analysis more effective than mere HTML representation. At present, large +L57: multimodal models (LMMs), particularly GPT-4V(ision) (OpenAI, 2023) and Gemini +L58: (Team et al., 2023), demonstrate a remarkable ability to integrate intricate +L59: visual cues with textual information. Existing studies such as Pix2Struct (Lee +L60: et al., 2023) and WebArena (Zhou et al., 2023), have initiated explorations into +L61: using screenshots as inputs for decision-making in web navigation, yet these +L62: are preliminary and do not represent a deep exploration. Therefore, building +L63: multimodal web agents to leverage the environment rendered by browsers through +L64: screenshots, thus mimicking human web browsing behavior, is now a viable +L65: approach to enhance web navigation efficiency. +L66: +L67: We introduce WebVoyager, a multimodal web agent designed to handle web tasks +L68: online in an end-to-end manner, which denotes managing the process from start to +L69: finish autonomously without intermediate human intervention. We construct an +L70: online environment using Selenium for WebVoyager, feeding it with screenshots +L71: and textual content in interactive web elements. Inspired by Set-of-Mark +L72: Prompting (Yang et al., 2023a), we mark interactive web elements on screenshots +L73: (see Figure 2) to facilitate decision-making for WebVoyager. As a pioneer in +L74: combining vision and text information during web navigation, we advocate that +L75: autonomous end-to-end task completion, multimodal capabilities and online +L76: navigation constitute the essential trajectory toward the genuine intelligence +L77: of web agents. +L78: +L79: Another challenge arises when it comes to evaluating an end-to-end web agent +L80: with online navigation. Existing benchmarks, such as Mind2Web (Deng et al., +L81: 2023), primarily focus on stepwise and offline evaluation, where agents follow +L82: predefined “golden” trajectory for action selection. This approach, however, may +L83: not fully account for the variety of viable strategies to accomplish a task, as +L84: it only reflects one possible plan. This limitation could lead to a biased +L85: evaluation and difficulties in fairly comparing different methods. To more +L86: accurately gauge the capabilities of web agents in end-to-end task completion, +L87: we save screenshots throughout the online navigation process, and then use +L88: GPT-4V to evaluate these trajectories and the final results automatically. Human +L89: evaluations are also conducted to verify the results and confirm the +L90: reliability of GPT-4V as the evaluator. +L91: +L92: We conduct evaluations on a collected dataset, which is semi-automatically +L93: generated using a self-instruct (Wang et al., 2022) method, comprising 300 web +L94: tasks from 15 commonly accessed websites. Additionally, we extract 90 web- +L95: related tasks of level 1 and level 2 from the GAIA (Mialon et al., 2023) to +L96: enrich our evaluation. We compare our WebVoyager with 1) GPT-4 (All +L97: Tools)222GPT-4 (All Tools) is an integrated tool-based agent released by OpenAI +L98: in Oct. 2023. See https://chat.openai.com/, and 2) WebVoyager in a text-only +L99: setting, employing the accessibility tree proposed in WebArena (Zhou et al., +L100: 2023) to describe web pages. The results show that WebVoyager achieves a Task +L101: Success Rate of 55.7%, significantly outperforming GPT-4 (All Tools) with a rate +L102: of 32.7% and the text-only setting with a rate of 39.0%, demonstrating the +L103: effectiveness of our method. Furthermore, we report the consistency between +L104: human-human and human-GPT4V to ensure credibility. Our main contributions are as +L105: follows: +L106: +L107: - • +L108: We employ a multimodal web agent that integrates textual and visual information +L109: to address web tasks end-to-end and introduce a generalist planning approach for +L110: navigation. +L111: +L112: - • +L113: We build an online web browsing environment, offering a variety of tasks +L114: centered on widely used websites and introducing a method for expanding these +L115: tasks. +L116: +L117: - • +L118: We conduct manual evaluations of navigation trajectories and propose an +L119: automated evaluation protocol using GPT-4V. We present a comprehensive analysis +L120: of the evaluation results and show that GPT-4V can serve as a reliable evaluator +L121: for online agents. +L122: +L123: ## 2 Related Work +L124: +L125: ### 2.1 Web Agents +L126: +L127: Autonomous web navigation (Shi et al., 2017; Yang et al., 2023b) requires an +L128: agent to follow instructions, construct plans, comprehend complex web +L129: structures, and decompose tasks into step-by-step decisions (Weng, 2023). To +[25] (https://ar5iv.org/abs/2307.13854) +**viewing lines [0 - 124] of 1092** + +L0: +L1: URL: https://ar5iv.org/abs/2307.13854 +L2: # WebArena: A Realistic Web Environment for Building Autonomous Agents +L3: +L4: Shuyan Zhou Frank F. Xu11footnotemark: 1 Hao Zhu Xuhui Zhou22footnotemark: 2 +L5: +L6: Robert Lo22footnotemark: 2 Abishek Sridhar22footnotemark: 2 Xianyi Cheng Tianyue +L7: Ou +L8: Yonatan Bisk Daniel Fried Uri Alon Graham Neubig +L9: Carnegie Mellon University +L10: {shuyanzh, fangzhex, gneubig}@cs.cmu.edu +L11: Lead contributors.Equal contribution. +L12: +L13: ###### Abstract +L14: +L15: With advances in generative AI, there is now potential for autonomous agents to +L16: manage daily tasks via natural language commands. However, current agents are +L17: primarily created and tested in simplified synthetic environments, leading to a +L18: disconnect with real-world scenarios. In this paper, we build an environment for +L19: language-guided agents that is highly realistic and reproducible. Specifically, +L20: we focus on agents that perform tasks on the web, and create an environment +L21: with fully functional websites from four common domains: e-commerce, social +L22: forum discussions, collaborative software development, and content management. +L23: Our environment is enriched with tools (e.g., a map) and external knowledge +L24: bases (e.g., user manuals) to encourage human-like task-solving. Building upon +L25: our environment, we release a set of benchmark tasks focusing on evaluating the +L26: functional correctness of task completions. The tasks in our benchmark are +L27: diverse, long-horizon, and designed to emulate tasks that humans routinely +L28: perform on the internet. We experiment with several baseline agents, integrating +L29: recent techniques such as reasoning before acting. The results demonstrate that +L30: solving complex tasks is challenging: our best GPT-4-based agent only achieves +L31: an end-to-end task success rate of 14.41%, significantly lower than the human +L32: performance of 78.24%. These results highlight the need for further development +L33: of robust agents, that current state-of-the-art large language models are far +L34: from perfect performance in these real-life tasks, and that WebArena can be used +L35: to measure such progress. +L36: +L37: Our code, data, environment reproduction resources, and video demonstrations are +L38: publicly available at https://webarena.dev/. +L39: +L40: ## 1 Introduction +L41: +L42: Autonomous agents that perform everyday tasks via human natural language +L43: commands could significantly augment human capabilities, improve efficiency, and +L44: increase accessibility. Nonetheless, to fully leverage the power of autonomous +L45: agents, it is crucial to understand their behavior within an environment that is +L46: both authentic and reproducible. This will allow measurement of the ability of +L47: agents on tasks that human users care about in a fair and consistent manner. +L48: +L49: Current environments for evaluate agents tend to over-simplify real-world +L50: situations. As a result, the functionality of many environments is a limited +L51: version of their real-world counterparts, leading to a lack of task diversity +L52: (Shi et al., 2017; Anderson et al., 2018; Gordon et al., 2018; Misra et al., +L53: 2016; Shridhar et al., 2020; 2021; Yao et al., 2022a). In addition, these +L54: simplifications often lower the complexity of tasks as compared to their +L55: execution in the real world (Puig et al., 2018; Shridhar et al., 2020; Yao et +L56: al., 2022a). Finally, some environments are presented as a static resource (Shi +L57: et al., 2017; Deng et al., 2023) where agents are confined to accessing only +L58: those states that were previously cached during data collection, thus limiting +L59: the breadth and diversity of exploration. Dor evaluation, many environments +L60: focus on comparing the textual surface form of the predicted action sequences +L61: with reference action sequences, disregarding the functional correctness of the +L62: executions and possible alternative solutions (Puig et al., 2018; Jernite et +L63: al., 2019; Xu et al., 2021; Li et al., 2020; Deng et al., 2023). These +L64: limitations often result in a discrepancy between simulated environments and the +L65: real world, and can potentially impact the generalizability of AI agents to +L66: successfully understand, adapt, and operate within complex real-world +L67: situations. +L68: +L69: We introduce WebArena, a realistic and reproducible web environment designed to +L70: facilitate the development of autonomous agents capable of executing tasks (§2). +L71: An overview of WebArena is in Figure 1. Our environment comprises four fully +L72: operational, self-hosted web applications, each representing a distinct domain +L73: prevalent on the internet: online shopping, discussion forums, collaborative +L74: development, and business content management. Furthermore, WebArena incorporates +L75: several utility tools, such as map, calculator, and scratchpad, to best support +L76: possible human-like task executions. Lastly, WebArena is complemented by an +L77: extensive collection of documentation and knowledge bases that vary from general +L78: resources like English Wikipedia to more domain-specific references, such as +L79: manuals for using the integrated development tool (Fan et al., 2022). The +L80: content populating these websites is extracted from their real-world +L81: counterparts, preserving the authenticity of the content served on each +L82: platform. We deliver the hosting services using Docker containers with gym-APIs +L83: (Brockman et al., 2016), ensuring both the usability and the reproducibility of +L84: WebArena. +L85: +L86: Along with WebArena, we release a ready-to-use benchmark with 812 long-horizon +L87: web-based tasks (§3). Each task is described as a high-level natural language +L88: intent, emulating the abstract language usage patterns typically employed by +L89: humans (Bisk et al., 2019). Two example intents are shown in the upper left of +L90: Figure 1. We focus on evaluating the functional correctness of these tasks, +L91: i.e., does the result of the execution actually achieve the desired goal (§3.2). +L92: For instance, to evaluate the example in Figure 2, our evaluation method +L93: verifies the concrete contents in the designated repository. This evaluation is +L94: not only more reliable (Zhong et al., 2017; Chen et al., 2021; Wang et al., +L95: 2022) than comparing the textual surface-form action sequences (Puig et al., +L96: 2018; Deng et al., 2023) but also accommodate a range of potential valid paths +L97: to achieve the same goal, which is a ubiquitous phenomenon in sufficiently +L98: complex tasks. +L99: +L100: We use this benchmark to evaluate several agents that can follow NL command and +L101: perform web-based tasks (§4). These agents are implemented in a few-shot in- +L102: context learning fashion with powerful large language models (LLMs) such as +L103: GPT-4 and PALM-2. Experiment results show that the best GPT-4 agent performance +L104: is somewhat limited, with an end-to-end task success rate of only 14.41%, while +L105: the human performance is 78.24%. We hypothesize that the limited performance of +L106: current LLMs stems from a lack of crucial capabilities such as active +L107: exploration and failure recovery to successfully perform complex tasks (§5.2). +L108: These outcomes underscore the necessity for further development towards robust +L109: and effective agents (LeCun, 2022) in WebArena. +L110: +L111: ## 2 WebArena: Websites as an Environment for Autonomous Agents +L112: +L113: Our goal is to create a realistic and reproducible web environment. We achieve +L114: reproducibility by making the environment standalone, without relying on live +L115: websites. This circumvents technical challenges such as bots being subject to +L116: CAPTCHAs, unpredictable content modifications, and configuration changes, which +L117: obstruct a fair comparison across different systems over time. We achieve +L118: realism by using open-source libraries that underlie many in-use sites from +L119: several popular categories and importing data to our environment from their +L120: real-world counterparts. +L121: +L122: ### 2.1 Controlling Agents through High-level Natural Language +L123: +L124: The WebArena environment is denoted asℰ\mathcal{E} with state space +[26] (https://ar5iv.org/abs/2311.12983) +**viewing lines [0 - 118] of 1207** + +L0: +L1: URL: https://ar5iv.org/abs/2311.12983 +L2: 1]FAIR, Meta 2]HuggingFace 3]AutoGPT 4]GenAI, Meta +L3: +L4: # GAIA: A Benchmark for General AI Assistants +L5: +L6: Grégoire Mialon Clémentine Fourrier Craig Swift Thomas Wolf Yann LeCun Thomas +L7: Scialom [ [ [ [ {gmialon,tscialom}@meta.com clementine@huggingface.co +L8: +L9: ###### Abstract +L10: +L11: We introduce GAIA, a benchmark for General AI Assistants that, if solved, would +L12: represent a milestone in AI research. GAIA proposes real-world questions that +L13: require a set of fundamental abilities such as reasoning, multi-modality +L14: handling, web browsing, and generally tool-use proficiency. GAIA questions are +L15: conceptually simple for humans yet challenging for most advanced AIs: we show +L16: that human respondents obtain 92% vs. 15% for GPT-4 equipped with plugins. This +L17: notable performance disparity contrasts with the recent trend of LLMs +L18: outperforming humans on tasks requiring professional skills in e.g. law or +L19: chemistry. GAIA’s philosophy departs from the current trend in AI benchmarks +L20: suggesting to target tasks that are ever more difficult for humans. We posit +L21: that the advent of Artificial General Intelligence (AGI) hinges on a system’s +L22: capability to exhibit similar robustness as the average human does on such +L23: questions. Using GAIA’s methodology, we devise 466 questions and their answer. +L24: We release our questions while retaining answers to 300 of them to power a +L25: leader-board hereby accessible. +L26: +L27: \correspondence +L28: +L29: ## 1 Introduction +L30: +L31: Large Language Models (LLMs) arguably open the way to general purpose systems. +L32: Indeed, the latest among them (OpenAI, 2023; Anthropic, 2023; Anil et al., 2023; +L33: Touvron et al., 2023) are fluent, knowledgeable, aligned to some extent with +L34: human preferences (Ouyang et al., 2022), and can be augmented (Mialon et al., +L35: 2023) with tools such as web browsers or code interpreters in a zero or few-shot +L36: setting (Brown et al., 2020). However, evaluating these systems is an open +L37: problem: given their emerging new capabilities, LLMs are regularly breaking AI +L38: benchmarks, at an ever-increasing rate (Kiela et al., 2023). +L39: +L40: In search for more challenging benchmarks, current trend suggests to seek tasks +L41: that are ever more difficult for humans, and challenge LLMs with more intricate +L42: educational assessments, for example in STEM and Law, or target more complex +L43: realisations, such as writing a coherent book. But, tasks that are difficult for +L44: humans are not necessarily difficult for recent systems: the challenging MMLU +L45: or GSM8k benchmarks for example (Hendrycks et al., 2021; Cobbe et al., 2021) are +L46: already close to be solved,111GPT4 does 86.4% on MMLU. Human non-specialist +L47: accuracy on the benchmark is only 34.5% Expert-level human performance is +L48: estimated at 89.8%. due to rapid LLM improvement possibly combined with data +L49: contamination.222See for example the case of Hellaswag. Furthermore, open-ended +L50: generation generally requires human or model-based evaluation (Zheng et al., +L51: 2023). Human evaluation will become less and less feasible when increasing the +L52: task complexity, e.g. in terms of output length or required skills: how to +L53: evaluate a book generated by an AI, or solutions to maths problems that few +L54: people in the world can solve? Model-based evaluations on the other hand are by +L55: construction dependent of stronger models hence cannot evaluate new state-of- +L56: the-art models, without mentioning potential subtle biases such as preferring +L57: the first choice presented (Zheng et al., 2023). Overall, evaluating new AI +L58: systems requires to rethink benchmarks (Chollet, 2019). +L59: +L60: Alternatively to tasks that are harder for humans, AI systems could be asked to +L61: solve conceptually simple tasks yet that require accurate execution of complex +L62: sequences of actions, with large combinatorial spaces. The output could only be +L63: obtained upon successful completion of the task and be easy to validate, +L64: analogous to the Proof of Work algorithm (Jakobsson and Juels, 1999; Dwork and +L65: Naor, 1993), where a computer is asked to solve a complex problem whose solution +L66: is easy to verify. Tasks for AI assistants, given their need for access to a +L67: diverse and uncertain world, meet this criterion while being inherently rooted +L68: in practical use cases. +L69: +L70: We move in that direction by proposing GAIA, a benchmark for General AI +L71: Assistants featuring 466 carefully crafted questions and their answer, along +L72: with the associated design methodology. Our questions are easy to create, +L73: challenging for AI systems—for LLMs, most require complex generations—, yet +L74: admit a unique, factual answer, allowing a simple and robust automatic +L75: evaluation. +L76: +L77: GAIA attempts to avoid current pitfalls of LLMs evaluation by targeting: +L78: +L79: Real-world and challenging questions. For example, a LLM will typically need to +L80: browse the open and changing web, handle multi-modality, or reason over multiple +L81: steps to answer our questions. Conversely, many LLM benchmarks are quite +L82: specific and/or restricted to closed and synthetic environments. +L83: +L84: Easy interpretability through conceptually simple tasks—non experts annotators +L85: exhibit a near perfect score—, associated reasoning trace, and few but highly +L86: curated questions. This is in contrast with aggregated benchmarks that can lack +L87: efficiency and reliability (Perlitz et al., 2023). +L88: +L89: Non-gameability. Answering the questions requires successful completion of some +L90: number of steps, which cannot easily be brute forced due to their diversity. The +L91: possibility to check the reasoning trace, the accuracy required in the answers, +L92: their absence in plain text from the internet prevent a possible data +L93: contamination. In contrast, multiple choice answers (e.g., MMLU) make +L94: contamination assessment more difficult since a wrong reasoning trace can more +L95: easily get to the correct choice. +L96: +L97: Simplicity of use. Crucially, the answers to our questions are factoid, concise +L98: and unambiguous. These properties allow simple, fast and factual evaluation. Our +L99: questions are meant to be answered in zero shot, limiting the influence of the +L100: evaluation setup. By opposition, many LLM benchmarks require evaluations that +L101: are sensitive to the experimental setup such as the number and nature of prompts +L102: (Liang et al., 2022b) (Section 8.2), or the benchmark +L103: implementation.333https://huggingface.co/blog/evaluating-mmlu-leaderboard +L104: +L105: In spite of being successful at tasks that are difficult for humans, the most +L106: capable LLMs do poorly on GAIA. Even equipped with tools, GPT4 does not exceed a +L107: 30% success rate for the easiest of our tasks, and 0% for the hardest. In the +L108: meantime, the average success rate for human respondents is 92%. Consequently, a +L109: system capable of solving GAIA can be assessed in the context of t-AGI,444As +L110: defined in https://www.alignmentforum.org/posts/BoA3agdkAzL6HQtQP/clarifying- +L111: and-predicting-agi, a t-AGI beats, on most tasks, most human experts who are +L112: given time t to perform the task noting that humans typically take between 6 +L113: minutes for the simplest questions to 17 minutes for the most complex ones. From +L114: a related perspective, such system would arguably be a competent General AI +L115: within the framework recently proposed in Morris et al. (2023), which also +L116: appear to be the next milestone in AI research since ChatGPT (OpenAI, 2023) is +L117: one level below. This paper covers the composition of GAIA, its design choices, +L118: and explain how to craft questions and the associated challenges so that the +Error fetching URL `https://github.com/sambanova/toolbench` +Error fetching URL `https://hotpotqa.github.io/` +[27] (https://aclanthology.org/N18-1074.pdf) +**viewing lines [0 - 134] of 809** + +L0: +L1: URL: https://aclanthology.org/N18-1074.pdf +L2: ===== Page 1 ===== +L3: +L4: # FEVER: a large-scale dataset for Fact Extraction and VERification +L5: +L6: James Thorne\(^1\), Andreas Vlachos\(^1\), Christos Christodoulopoulos\(^2\), +L7: and Arpit Mittal\(^2\) +L8: +L9: \(^1\)Department of Computer Science, University of Sheffield +L10: \(^2\)Amazon Research Cambridge +L11: {j.thorne, a.vlachos}@sheffield.ac.uk +L12: {chrchrs, mitarpit}@amazon.co.uk +L13: +L14: ## Abstract +L15: +L16: In this paper we introduce a new publicly available dataset for verification +L17: against textual sources, FEVER: Fact Extraction and VERification. It consists of +L18: 185,445 claims generated by altering sentences extracted from Wikipedia and +L19: subsequently verified without knowledge of the sentence they were derived from. +L20: The claims are classified as Supported, Refuted or NotEnoughInfo by annotators +L21: achieving 0.6841 in Fleiss \(\kappa\). For the first two classes, the annotators +L22: also recorded the sentence(s) forming the necessary evidence for their +L23: judgment. To characterize the challenge of the dataset presented, we develop a +L24: pipeline approach and compare it to suitably designed oracles. The best accuracy +L25: we achieve on labeling a claim accompanied by the correct evidence is 31.87%, +L26: while if we ignore the evidence we achieve 50.91%. Thus we believe that FEVER is +L27: a challenging testbed that will help stimulate progress on claim verification +L28: against textual sources. +L29: +L30: ## 1 Introduction +L31: +L32: The ever-increasing amounts of textual information available combined with the +L33: ease in sharing it through the web has increased the demand for verification, +L34: also referred to as fact checking. While it has received a lot of attention in +L35: the context of journalism, verification is important for other domains, e.g. +L36: information in scientific publications, product reviews, etc. +L37: +L38: In this paper we focus on verification of textual claims against textual +L39: sources. When compared to textual entailment (TE)/natural language inference +L40: (Dagan et al., 2009; Bowman et al., 2015), the key difference is that in these +L41: tasks the passage to verify each claim is given, and in recent years it +L42: typically consists a single sentence, while in verification systems it is +L43: retrieved from a large set of documents in order to form the evidence. Another +L44: related task is question answering (QA), for which approaches have recently been +L45: extended to handle large-scale resources such as Wikipedia (Chen et al., 2017). +L46: However, questions typically provide the information needed to identify the +L47: answer, while information missing from a claim can often be crucial in +L48: retrieving refuting evidence. For example, a claim stating "Fiji's largest +L49: island is Kauai." can be refuted by retrieving "Kauai is the oldest Hawaiian +L50: Island." as evidence. +L51: +L52: Progress on the aforementioned tasks has benefited from the availability of +L53: large-scale datasets (Bowman et al., 2015; Rajpurkar et al., 2016). However, +L54: despite the rising interest in verification and fact checking among researchers, +L55: the datasets currently used for this task are limited to a few hundred claims. +L56: Indicatively, the recently conducted Fake News Challenge (Pomerleau and Rao, +L57: 2017) with 50 participating teams used a dataset consisting of 300 claims +L58: verified against 2,595 associated news articles which is orders of magnitude +L59: smaller than those used for TE and QA. +L60: +L61: In this paper we present a new dataset for claim verification, FEVER: Fact +L62: Extraction and VERification. It consists of 185,445 claims manually verified +L63: against the introductory sections of Wikipedia pages and classified as +L64: Supported, Refuted or NotEnoughInfo. For the first two classes, systems and +L65: annotators need to also return the combination of sentences forming the +L66: necessary evidence supporting or refuting the claim (see Figure 1). The claims +L67: were generated by human annotators extracting claims from Wikipedia and mutating +L68: them in a variety of ways, some of which were meaning-altering. The +L69: verification of each +L70: +L71: 809 +L72: +L73: Proceedings of NAACL-HLT 2018, pages 809–819 +L74: +L75: New Orleans, Louisiana, June 1 - 6, 2018. ©2018 Association for Computational +L76: Linguistics +L77: +L78: ===== Page 2 ===== +L79: +L80: claim was conducted in a separate annotation process by annotators who were +L81: aware of the page but not the sentence from which original claim was extracted +L82: and thus in 31.75% of the claims more than one sentence was considered +L83: appropriate evidence. Claims require composition of evidence from multiple +L84: sentences in 16.82% of cases. Furthermore, in 12.15% of the claims, this +L85: evidence was taken from multiple pages. +L86: +L87: To ensure annotation consistency, we developed suitable guidelines and user +L88: interfaces, resulting in inter-annotator agreement of 0.6841 in Fleiss (Fleiss, +L89: 1971) in claim verification classification, and 95.42% precision and 72.36% +L90: recall in evidence retrieval. +L91: +L92: To characterize the challenges posed by FEVER we develop a pipeline approach +L93: which, given a claim, first identifies relevant documents, then selects +L94: sentences forming the evidence from the documents and finally classifies the +L95: claim w.r.t. evidence. The best performing version achieves 31.87% accuracy in +L96: verification when requiring correct evidence to be retrieved for claims +L97: Supported or Refuted, and 50.91% if the correctness of the evidence is ignored, +L98: both indicating the difficulty but also the feasibility of the task. We also +L99: conducted oracle experiments in which components of the pipeline were replaced +L100: by the gold standard annotations, and observed that the most challenging part of +L101: the task is selecting the sentences containing the evidence. In addition to +L102: publishing the data via our website1, we also publish the annotation interfaces2 +L103: and the baseline system3 to stimulate further research on verification. +L104: +L105: Footnote 1: http://fever.ai +L106: +L107: Footnote 2: https://github.com/awslabs/fever +L108: +L109: Footnote 3: https://github.com/sheffieldnlp/fever-baselines +L110: +L111: ## 2 Related Works +L112: +L113: Vlachos and Riedel (2014) constructed a dataset for claim verification +L114: consisting of 106 claims, selecting data from fact-checking websites such as +L115: PolitiFact, taking advantage of the labelled claims available there. However, in +L116: order to develop claim verification components we typically require the +L117: justification for each verdict, including the sources used. While this +L118: information is usually available in justifications provided by the journalists, +L119: they are not in a machine-readable form. Thus, also considering the small number +L120: of claims, the task defined by the dataset proposed remains too challenging for +L121: the ML/NLP methods currently available. Wang (2017) extended this approach by +L122: including all 12.8K claims available by Politifact via its API, however the +L123: justification and the evidence contained in it was ignored in the experiments as +L124: it was not machine-readable. Instead, the claims were classified considering +L125: only the text and the metadata related to the person making the claim. While +L126: this rendered the task amenable to current NLP/ML methods, it does not allow for +L127: verification against any sources and no evidence needs to be returned to +L128: justify the verdicts. +L129: +L130: The Fake News challenge (Pomerleau and Rao, 2017) modelled verification as +L131: stance classification: given a claim and an article, predict whether the article +L132: supports, refutes, observes (neutrally states the claim) or is irrelevant to +L133: the claim. It consists of 50K labelled claim-article pairs, combining 300 claims +L134: with 2,582 articles. The claims and the articles were curated and labeled by +[28] (https://nlp.cs.washington.edu/triviaqa/) +**viewing lines [0 - 48] of 48** + +L0: +L1: URL: https://nlp.cs.washington.edu/triviaqa/ +L2: ## TriviaQA: A Large Scale Dataset for Reading Comprehension and Question +L3: Answering +L4: +L5: TriviaQA is a reading comprehension dataset containing over 650K question- +L6: answer-evidence triples. TriviaQA includes 95K question-answer pairs authored by +L7: trivia enthusiasts and independently gathered evidence documents, six per +L8: question on average, that provide high quality distant supervision for answering +L9: the questions. The details can be found in our ACL 17 paper TriviaQA: A Large +L10: Scale Distantly Supervised Challenge Dataset for Reading Comprehension +L11: +L12: Mandar Joshi, Eunsol Choi, Daniel Weld, Luke Zettlemoyer. TriviaQA: A Large +L13: Scale Distantly Supervised Challenge Dataset for Reading Comprehension +L14: +L15: In Association for Computational Linguistics (ACL) 2017, Vancouver, Canada. +L16: [bib] +L17: +L18: ## News +L19: +L20: ##### Jul 2017 +L21: +L22: The TriviaQA leaderboard is now live on Codalab. Submit your predictions for +L23: evaluation on the test set! +L24: +L25: ## Data +L26: +L27: If you are interested in the reading comprehension task motivated in the paper, +L28: click on the link below to download the data. +L29: +L30: If you are interested in open domain QA, click on the link below to download the +L31: data. It contains the unfiltered dataset with 110K question-answer pairs. The +L32: Wikipedia and top 10 search documents can be obtained from the RC version. The +L33: main difference between the RC version above and the unfiltered dataset is that +L34: not all documents (in the unfiltered set) for a given question contain the +L35: answer string(s). This makes the unfiltered dataset more appropriate for IR- +L36: style QA. +L37: +L38: The University of Washington does not own the copyright of the questions and +L39: documents included in TriviaQA. +L40: +L41: ## Code +L42: +L43: Check out our Github repository. +L44: +L45: ## Contact +L46: +L47: For any questions about the code or data, please contact Mandar Joshi -- {first +L48: name of the first author}90[at]cs[dot]washington[dot]edu +Error fetching URL `https://research.google/pubs/natural-questions-a-benchmark-for-question-answering-research/` +Error fetching URL `https://microsoft.github.io/msmarco/` +Error fetching URL `https://github.com/beir-cellar/beir` +Error fetching URL `https://project-miracl.github.io/` +[29] (https://ar5iv.org/abs/2207.01206) +**viewing lines [0 - 122] of 1392** + +L0: +L1: URL: https://ar5iv.org/abs/2207.01206 +L2: # WebShop: Towards Scalable Real-World Web Interaction with Grounded Language +L3: Agents +L4: +L5: Shunyu Yao111Equal contribution. Project site with code, data, and demos: +L6: https://webshop-pnlp.github.io. Howard Chen111Equal contribution. Project site +L7: with code, data, and demos: https://webshop-pnlp.github.io. John Yang Karthik +L8: Narasimhan +L9: +L10: Department of Computer Science, Princeton University +L11: {shunyuy, howardchen, jy1682, karthikn}@princeton.edu +L12: +L13: ###### Abstract +L14: +L15: Existing benchmarks for grounding language in interactive environments either +L16: lack real-world linguistic elements, or prove difficult to scale up due to +L17: substantial human involvement in the collection of data or feedback signals. To +L18: bridge this gap, we develop WebShop – a simulated e-commerce website environment +L19: with million real-world products and 1.181.18 crowd-sourced text instructions. +L20: Given a text instruction specifying a product requirement, an agent needs to +L21: navigate multiple types of webpages and issue diverse actions to find, +L22: customize, and purchase an item. WebShop provides several challenges for +L23: language grounding including understanding compositional instructions, query +L24: (re-)formulation, comprehending and acting on noisy text in webpages, and +L25: performing strategic exploration. We collect over 12,08712,087 human +L26: demonstrations for the task, and train and evaluate a diverse range of agents +L27: using reinforcement learning, imitation learning, and pre-trained image and +L28: language models. Our best model achieves a task success rate of 1,6001,600, +L29: which outperforms rule-based heuristics (29%29\%) but is far lower than human +L30: expert performance (9.6%9.6\%). We also analyze agent and human trajectories and +L31: ablate various model components to provide insights for developing future +L32: agents with stronger language understanding and decision making abilities. +L33: Finally, we show that agents trained on WebShop exhibit non-trivial sim-to-real +L34: transfer when evaluated on amazon.com and ebay.com , indicating the potential +L35: value of WebShop in developing practical web-based agents that can operate in +L36: the wild.59%59\% +L37: +L38: ## 1 Introduction +L39: +L40: Recent advances in natural language processing (NLP) and reinforcement learning +L41: (RL) have brought about several exciting developments in agents that can perform +L42: sequential decision making while making use of linguistic context [30, 50, 58]. +L43: On the other hand, large-scale language models like GPT-3 [6] and BERT [11] are +L44: excelling at traditional NLP benchmarks such as text classification, +L45: information extraction and question answering. While the former set of tasks are +L46: limited in their set of linguistic concepts and prove difficult to scale up, +L47: the latter tasks usually contain static, non-interactive datasets that lack +L48: adequate grounding to extra-linguistic concepts [4]. In order to make further +L49: progress in building grounded language models, we believe there is a need for +L50: scalable interactive environments that contain: (1) language elements that +L51: reflect rich, real-world usage and are collectible at scale, and (2) task +L52: feedback that is well-defined and automatically computable to facilitate +L53: interactive learning, without the constant need for expensive feedback from +L54: humans. +L55: +L56: The world wide web (WWW) is a massive open-domain interactive environment that +L57: inherently satisfies the first aforementioned requirement through its +L58: interconnected set of pages with natural text, images and interactive elements. +L59: By being simultaneously scalable, semantic, interactive, dynamic and realistic, +L60: the web is uniquely different from existing environments for autonomous agents +L61: like games or 3D navigation. Moreover, the web also provides a practical +L62: environment to deploy trained agents, with great potential for alleviating human +L63: efforts in tedious tasks (e.g. buying products, booking appointments). While +L64: there has been prior work on building web-based tasks, they either lack depth in +L65: the transition and action spaces, or prove difficult to scale up. Some +L66: benchmarks only contain either a single classification task [39, 46, 31] or +L67: interactions containing only a handful of different pages in each episode [43]. +L68: Others propose tasks with longer horizons but are either limited to following +L69: hyperlinks for web navigation [36] or require human-in-the-loop feedback due to +L70: the lack of an automated reward function [33]. +L71: +L72: In this paper, we introduce WebShop (Figure 1) – a large-scale interactive web- +L73: based environment for language understanding and decision making – and train +L74: autonomous agents to complete tasks on this benchmark. With the goals of being +L75: scalable and containing realistic language and visual elements, WebShop emulates +L76: the task of online shopping on an e-commerce website, where the agent’s goal is +L77: to understand a human-provided text instruction and purchase a product to match +L78: the specifications. To do so, the agent needs to query the website’s search +L79: engine, choose items to explore from search results, open and read their +L80: description and details, and select the necessary options (e.g. 32 oz., red +L81: color) before clicking the ‘Buy’ button. In order to pick the optimal product +L82: that matches user requirements, the agent may need to view and compare various +L83: products (including backtracking between pages), and potentially perform +L84: multiple searches. WebShop contains over one million products scraped from +L85: amazon.com, over thousand crowdsourced instructions, and a diverse semantic +L86: action space of searching text queries and choosing text buttons. It is packaged +L87: into a convenient OpenAI Gym [5] environment and can be rendered in two modes +L88: (HTML or simple) with parallel observation spaces that are easy for human and +L89: model respectively. Rewards are automatically computed using a combination of +L90: programmatic matching functions that consider the attributes, type, options and +L91: price of the chosen product, alleviating the need for human evaluation and +L92: providing a path to scaling up interactive learning.1212 +L93: +L94: We develop several agents to perform this task, using both reinforcement +L95: learning (RL) and imitation learning (IL). We also leverage the latest pre- +L96: trained language models [26, 11] for representing and generating text. Our +L97: modular architecture includes a factorized processing of state observations and +L98: action choices using ResNets (visual) and Transformers (text), followed by an +L99: attention fusion layer that helps the agent contextually score each action. Our +L100: best agent achieves an average score of (out of 62.462.4) and successfully +L101: completes the task 100100 of the time, significantly higher than a heuristic +L102: baseline that achieves 28.7%28.7\% and 45.645.6, respectively. While this +L103: demonstrates the potential for IL and RL, the agents are still much lower than +L104: human experts, who can achieve 9.6%9.6\% and 82.182.1 on this task.*** In our +L105: analysis (§5.3), we observe that the task requires patience and consistency, +L106: which is lacking in some crowdsource workers, leading to lower scores. Even with +L107: this caveat, the gap between human performance and the model remains +L108: significant. We perform several analyses and ablation studies to identify the +L109: cause of this gap and find several avenues for agent improvement in the future +L110: including more robust search generation, explicit memory modules, and better +L111: handling of noisy web text. Finally, we also demonstrate an instance of sim-to- +L112: real transfer by deploying agents trained with WebShop to operate on amazon.com +L113: and ebay.com, and find that they can achieve similar performances despite search +L114: engine and product differences, and consistently outperform the rule baseline +L115: of using the first result returned by the commercial search engines when +L116: directly searching the instruction texts. This demonstrates the practical +L117: potential of our work towards developing agents that can operate autonomously on +L118: the world wide web (WWW).59.6%59.6\% +L119: +L120: ## 2 Related Work +L121: +L122: Reinforcement learning on the web. Nogueira and Cho [36] introduced WikiNav as a +Error fetching URL `http://alfworld.github.io/` +Error fetching URL `https://osu-nlp-group.github.io/Mind2Web/` +Error fetching URL `https://github.com/web-arena-x/visualwebarena` +[30] (https://ar5iv.org/pdf/2406.12172) +**viewing lines [0 - 127] of 1478** + +L0: +L1: URL: https://ar5iv.org/pdf/2406.12172 +L2: # Navigating the Labyrinth: Evaluating and Enhancing LLMs’ Ability to Reason +L3: About Search Problems +L4: +L5: Nasim Borazjanizadeh +L6: +L7: Berkeley AI Research, UC Berkeley +L8: \AndRoei Herzig +L9: Berkeley AI Research, UC Berkeley +L10: \AndTrevor Darrell +L11: Berkeley AI Research, UC Berkeley +L12: \AndRogerio Feris +L13: MIT-IBM Watson AI Lab +L14: \AndLeonid Karlinsky +L15: MIT-IBM Watson AI Lab +L16: +L17: ###### Abstract +L18: +L19: Recently, Large Language Models (LLMs) attained impressive performance in math +L20: and reasoning benchmarks. However, they still often struggle with logic problems +L21: and puzzles that are relatively easy for humans. To further investigate this, +L22: we introduce a new benchmark, SearchBench, containing 11 unique search problems, +L23: each equipped with automated pipelines to generate an arbitrary number of +L24: instances and analyze the feasibility, correctness, and optimality of LLM- +L25: generated solutions. We show that even the most advanced LLMs fail to solve +L26: these problems end-to-end in text, e.g., GPT4 solves only 1.4%. SearchBench +L27: problems require considering multiple pathways to the solution as well as +L28: backtracking, posing a significant challenge to auto-regressive models. +L29: Instructing LLMs to generate code that solves the problem helps, but only +L30: slightly, e.g., GPT4’s performance rises to 11.7%. In this work, we show that +L31: in-context learning with A* algorithm implementations enhances performance. The +L32: full potential of this promoting approach emerges when combined with our +L33: proposed Multi-Stage-Multi-Try method, which breaks down the algorithm +L34: implementation into two stages and verifies the first stage against unit tests, +L35: raising GPT-4’s performance above 57%. +L36: +L37: \doparttoc\faketableofcontents +L38: +L39: ### 1 Introduction +L40: +L41: The advent of Large Language Models (LLMs) has revolutionized the field of +L42: natural language processing, with models like Gemini[18], GPT-4[26] +L43: demonstrating unprecedented performance on reasoning tasks such as GSM8k[8]. +L44: However, these models still exhibit surprising failures on some intuitive +L45: tasks[2, 30, 22] and struggle with multi-step compositional reasoning, +L46: combinatorial problems, and planning [9, 40, 44]. Inspired by these observations +L47: and to further investigate LLMs’ reasoning abilities, we offer a new benchmark +L48: of search problems, SearchBench. The problems in SearchBench are combinatorial, +L49: defined as tasks that involve finding an optimal object from a finite set of +L50: objects, where the set of feasible solutions is either discrete or can be +L51: reduced to a discrete set [43]. These problems are predominantly NP-hard and +L52: necessitate systematic exploration of action paths and backtracking to +L53: intermediate feasible states; thus, SearchBench implicitly investigates the +L54: LLM’s capacity for non-linear reasoning. +L55: +L56: SearchBench has five distinct problem categories: (i) pathfinding, (ii) puzzles, +L57: (iii) subset sum, (iv) sorting, and (v) under-determined systems; further +L58: divided into 11 unique problem types. Each problem type is inspired by known +L59: puzzles and combinatorial problems but augmented with modified rules and +L60: constraints to ensure substantial differences from similar problems LLMs +L61: encountered during their training. And the solution to each problem is a +L62: sequence of actions leading from the initial state to the goal state, while +L63: optimizing a cost. We generate100 instances of varying difficulty per problem +L64: type using an automatic pipeline, resulting in 1107 problem instances total. +L65: Each problem type in SearchBench is equipped with an automatic pipeline that +L66: evaluates LLM-generated solutions on three dimensions: feasibility, correctness, +L67: and optimality. Feasibility checks whether the actions taken follow the +L68: problem’s rules; correctness verifies if a feasible solution reaches the goal +L69: state; and optimality checks if the least cost solution was found.∼\sim +L70: +L71: SearchBench is challenging to LLMs due to several factors. Firstly, natural +L72: language is less suited for describing or updating accurate representations of +L73: complex intermediate states. Secondly, our experiments show LLMs struggle with +L74: exploring a combinatorial exponentially exploding state-space. Despite the fact +L75: that some methods were developed for long-context reasoning [4, 13, 50], +L76: SearchBench problems cannot be easily summarized [4], reasoned about [13], or +L77: processed in parallel due to their size [50, 45]. Our findings show that even +L78: the strongest LLMs [26] almost completely fail to solve SearchBench problems in +L79: text-only mode. +L80: +L81: To provide further insights, we show that LLMs’ performance on SearchBench +L82: improves by prompting the models to solve the problems using the A* search +L83: algorithm [11]. A* is a heuristic-based graph traversal algorithm known for its +L84: time efficiency and provable optimality guarantees, making it the most suitable +L85: search algorithm for solving the problems in our benchmark. This method +L86: leverages A*’s correctness and optimality, while offloading some of the non- +L87: linear computations involved in searching the state-space to code execution. +L88: Additionally, to improve the quality of generated A* codes, motivated that +L89: ensembling helps generation quality[41, 47, 21], we introduce the Multi-Stage- +L90: Multi-Try (MSMT) inference strategy. In the "Multi-Try" aspect of MSMT, before +L91: evaluating the solution returned by the code, we first verify whether the code +L92: generated by the model satisfies a set of unit tests: (i) it is executable; (ii) +L93: it returns a list as output; and (iii) data type of list elements is correct. +L94: If the code fails any of the tests, MSMT re-runs the LLM until a valid code is +L95: generated or allowed number of attempts is exhausted. The "Multi-Stage" aspect +L96: of MSMT generates the code in two steps: (i) ‘A* Implementation’ - the +L97: implementation of an instance-agnostic A* algorithm for the problem type; and +L98: (ii) Initialization - the instantiation of initial conditions and state +L99: variables of the problem instance. In MSMT ’Initialization’ is generated +L100: conditioned on the ‘A* Implementation’ (which is generated separately first and +L101: provided in ‘Intitialization’ prompt). We demonstrate that our MSMT A* method +L102: (Fig. 2) significantly enhances the LLMs’ ability to solve search problems, +L103: outperforming all other prompting strategies we used to evaluate models on +L104: SearchBench, including 0-shot text, 4-shot Chain-of-Thought (CoT)[42] text, +L105: 0-shot code generation, and 4-shot A* prompting with the naive greedy decoding +L106: strategy. +L107: +L108: To summarize, our main contributions are as follows: (i) We contribute the +L109: SearchBench benchmark designed to assess the capability of LLMs in solving +L110: state-based problems requiring combinatorial search; (ii) We introduce the A* +L111: Multi-Stage-Multi-Try (MSMT) framework that leverages the flexibility and +L112: natural language comprehension capabilities of LLMs, reliability and structured +L113: nature of the A* algorithm, along with unitests to verify LLMs’ code generation +L114: in order to address search problems. (iii) We show a strong advantage of MSMT A* +L115: prompting for solving search problems robustly with LLMs, suggesting +L116: interesting future work directions and yet leaving significant room for future +L117: improvement on SearchBench. +L118: +L119: ### 2 SearchBench Benchmark +L120: +L121: SearchBench includes five categories of problems: puzzles, subset sum, sorting, +L122: pathfinding, and under-determined systems, further subdivided into 11 problem +L123: types and instantiated as 1107 natural language problem instances. Each problem +L124: type requires a different algorithmic approach to solve. Our problems are +L125: inspired by well-known (mostly NP-hard) puzzles and search problems, though +L126: their rules were modified to ensure a solved instance of a problem similar to +L127: any SearchBench problem was not observed during LLMs’ massive internet-scale +Error fetching URL `https://ar5iv.labs.ar5iv.org/html/2312.15820` +Error fetching URL `https://github.com/jxwuyi/WebNav` + + +在同时打开多个基准测试相关页面后,我们获得了大量信息。WebVoyager 是一个多模态网页智能体基准测试,专注于在真实网站上进行端到端任务执行,其特点包括多模态输入(屏幕截图和文本)、在15个常用网站上的300个任务,以及使用GPT-4V进行自动评估。使用场景是评估网页智能体在真实世界中的导航和任务完成能力。例题可能是在亚马逊上找到特定产品并将其加入购物车。 + +WebArena 是一个真实且可复现的网页环境,用于评估自主智能体。其特点包括四个功能完整的网站(电子商务、论坛、协作开发、内容管理)、812项长视野任务,以及专注于功能正确性的评估。使用场景是评估智能体在复杂网页任务中的表现。例题可能是在电子商务网站上购买符合特定标准的商品。 + +GAIA 是一个针对通用AI助手的基准测试,要求具备推理、多模态处理、网页浏览和工具使用等能力。其特点包括466个问题,对人类来说概念简单但对AI具有挑战性,答案简洁明确,支持自动评估。使用场景是评估AI助手在真实世界问题中的表现。例题可能是“找出2022年法国总统选举中获胜者的出生城市”。 + +ToolBench 是一个用于评估LLM工具操作能力的基准测试。其特点包括多样化的软件工具、真实世界任务,以及易于使用的基础设施。使用场景是评估智能体使用工具完成任务的能力。例题可能是使用计算器API计算表达式,或使用搜索API查找信息。 + +HotpotQA 是一个多跳问答基准测试,其特点包括基于Wikipedia的113k个问题-答案对,需要多文档推理,并提供句子级别的支持事实。使用场景是评估模型在回答需要多跳推理的问题时的能力。例题可能是“《了不起的盖茨比》的作者还写过哪些小说?”。 + +FEVER 是一个事实提取和验证基准测试,其特点包括185,445个基于Wikipedia的声明,需要分类为支持、反驳或信息不足,并需要提供证据句子。使用场景是评估系统验证声称事实的能力。例题可能是“斐济最大的岛屿是考艾岛”,系统需要检索证据来支持或反驳这一说法。 + +TriviaQA 是一个阅读理解基准测试,其特点包括超过65万个问题-答案-证据三元组,问题由琐事爱好者编写,并附带证据文档。使用场景是评估模型在开放域问答中的能力。例题可能是“哪位演员在《星球大战:原力觉醒》中扮演凯洛·伦?”。 + +Natural Questions 是一个问答基准测试,其特点包括来自谷歌搜索的真实问题,答案来自Wikipedia,并包含长答案和短答案。使用场景是评估开放域问答系统。例题可能是“谁写了《傲慢与偏见》?”,答案可能是“简·奥斯汀”。 + +MS MARCO 是一个用于深度学习搜索的基准测试,其特点包括来自Bing的真实搜索查询、人工生成的答案,以及用于段落排序和文档排序的任务。使用场景是训练和评估检索模型。例题可能是“如何冲泡咖啡?”,系统需要检索相关段落来回答。 + +BEIR 是一个用于零样本评估的异构信息检索基准测试,其特点包括18个数据集、9项任务,涵盖事实核查、问答、对话检索等。使用场景是评估检索模型在零样本设置下的泛化能力。例题可能是给定一个查询,在文档集合中检索相关文档。 + +MIRACL 是一个多语言信息检索基准测试,其特点涵盖18种语言,专注于ad-hoc检索,并包含人工标注的相关性判断。使用场景是评估多语言检索模型。例题可能是用中文查询“全球变暖的原因”,并检索相关文档。 + +WebShop 是一个模拟电子商务网站环境,用于评估语言智能体。其特点包括118万个真实产品、12,087条众包文本指令,以及需要导航、搜索和购买的任务。使用场景是评估智能体在电子商务任务中的表现。例题可能是“找到一款价格低于50美元的无线蓝牙耳机”。 + +ALFWorld 是一个将文本和具身环境对齐的基准测试,用于在文本世界中训练智能体,然后将其转移到具身环境中。其特点包括基于文本的环境,与ALFRED数据集并行,并专注于指令跟随。使用场景是评估智能体在家庭任务中的规划能力。例题可能是“在厨房里找到一个杯子,把它拿到客厅”。 + +Mind2Web 是一个用于评估网页通用智能体的基准测试,其特点包括来自真实网站的任务、需要多步交互,并专注于语言指令跟随。使用场景是评估智能体在复杂网页任务中的表现。例题可能是“在GitHub上创建一个新的仓库,命名为‘test-repo’”。 + +VisualWebArena 是一个用于多模态智能体的基准测试,其特点包括真实视觉网页任务、需要理解视觉和文本元素,并评估功能正确性。使用场景是评估多模态智能体在视觉接地网页任务中的表现。例题可能是“在纽约时报网站上找到今天头条新闻的图片”。 + +SearchBench 是一个用于评估LLM解决搜索问题能力的基准测试,其特点包括11种搜索问题类型、自动生成实例,并评估可行性、正确性和最优性。使用场景是评估LLM在组合搜索问题中的推理能力。例题可能是“在一个迷宫中找到从起点到终点的最短路径”。 + +WebVLN 是一个网站上的视觉语言导航基准测试,其特点包括基于问题的指令、需要导航网站,并评估导航成功率。使用场景是评估智能体在网站中基于视觉和语言指令进行导航的能力。例题可能是“找到‘联系我们’页面并提取电子邮件地址”。 + +WebNav 是一个用于目标驱动网页导航的基准测试,其特点包括部分观察环境、需要规划和理解自然语言指令。使用场景是评估智能体在网页导航中的能力。例题可能是“导航到维基百科上关于‘人工智能’的页面”。 + +现在,我需要整理一份针对搜索智能体的基准测试列表。考虑到用户可能更关注中文资料,我会优先考虑那些有中文描述的基准测试,例如BrowseComp、HSCodeComp、GAIA、WebVoyager、WebArena、Mind2Web、WebShop、ALFWorld、HotpotQA、FEVER、TriviaQA、Natural Questions、MS MARCO、BEIR、MIRACL、SearchBench、WebVLN、WebNav。我会为每个基准测试提供特点、使用场景和例题。同时,我会引用相关来源,例如HSCodeComp论文、BrowseComp博客文章、WebVoyager论文、WebArena论文、GAIA论文、WebShop论文、SearchBench论文等。我的回答将用中文呈现,并确保内容全面、结构清晰。现在开始撰写回答。针对搜索智能体(Search Agent)的评估,已涌现出多个侧重点不同的基准测试。以下将分类介绍这些基准,涵盖其核心特点、典型使用场景及例题。 + +--- + +## 🔍 1. 网页浏览/导航类基准 +这类基准要求智能体在真实或模拟的网页环境中执行多步操作,以完成复杂任务。 + +| 基准 | 核心特点 | 使用场景 | 例题 | +|------|----------|----------|------| +| **BrowseComp**(OpenAI)【3†L16-L35】 | - 1,266 个挑战性问题,需持续导航、查找难以直接搜索的信息
- 问题设计保证人类在 10 分钟内无法解决,且现有模型(包括带浏览的 ChatGPT)也难以完成
- 答案为短字符串,便于自动验证 | 评估智能体在复杂网络浏览任务中的持久性、创造性搜索能力 | “找出 2018‑2023 年间在 EMNLP 会议上发表、第一作者本科毕业于达特茅斯学院、第四作者本科毕业于宾夕法尼亚大学的科学论文标题”【3†L45-L49】 | +| **WebVoyager**【24†L18-L33】 | - 多模态(视觉+文本)端到端网页智能体基准
- 包含 300 个真实任务,覆盖 15 个常用网站(如 Amazon、Wikipedia)
- 使用 GPT‑4V 进行自动评估,与人工评估一致性达 85.3% | 评估智能体在真实网站上的端到端任务完成能力(多模态输入) | “在 Amazon 上找到‘无线蓝牙耳机’并按价格从低到高排序,然后选择最便宜的一款加入购物车” | +| **WebArena**【25†L15-L35】 | - 完全可交互的、高度真实的网页环境(4 个领域:电商、论坛、协作开发、内容管理)
- 812 项长视野任务,评估功能正确性而非表面动作匹配
- 最佳 GPT‑4 智能体成功率仅 14.41%,远低于人类的 78.24% | 评估自主智能体在真实网页环境中的复杂任务执行能力 | “在电商网站中购买一双‘黑色运动鞋’,价格低于 $80,并将收货地址修改为纽约”【25†L86-L92】 | +| **Mind2Web**【19†L4-L14】 | - 首个面向通用网页智能体的数据集,任务来自真实网站
- 支持在线评估(Online‑Mind2Web),包含 300 个实时任务、136 个流行网站
- 强调多步指令跟随与复杂交互 | 评估智能体在任意网站上的通用任务完成能力 | “在 GitHub 上创建一个名为 ‘test‑repo’ 的公开仓库,并添加一个 README.md 文件” | +| **VisualWebArena**【20†L4-L10】 | - 多模态网页智能体基准,任务依赖视觉信息(截图)
- 包含丰富的视觉接地任务,如识别图片、按钮位置等
- 评估智能体在真实视觉网页环境中的表现 | 评估多模态智能体在视觉丰富的网页任务中的能力 | “在纽约时报网站找到今天头条新闻的配图,并描述图中主要物体” | +| **WebVLN**【22†L4-L11】 | - 网站上的视觉‑语言导航任务,智能体需根据问题指令在网站中导航
- 数据集 WebVLN‑v1 包含多种网站导航场景
- 评估导航成功率与任务完成度 | 评估智能体在网站中基于视觉与语言指令的导航能力 | “找到‘联系我们’页面,并提取其中的邮箱地址” | +| **WebNav**【23†L4-L11】 | - 目标驱动的网页导航基准,环境为部分观测
- 智能体需理解自然语言指令并在网页中规划动作序列
- 早期经典任务,常用于测试基础导航能力 | 评估智能体在部分观测网页环境中的规划与指令跟随能力 | “导航到 Wikipedia 上关于‘人工智能’的页面,并摘录第一段内容” | + +--- + +## 🛠️ 2. 工具使用/通用助手类基准 +这类基准测试智能体调用外部工具(如搜索、计算、API)来解决实际问题的能力。 + +| 基准 | 核心特点 | 使用场景 | 例题 | +|------|----------|----------|------| +| **GAIA**【26†L10-L25】 | - 466 个真实世界问题,需要推理、多模态处理、网页浏览、工具使用等综合能力
- 对人类简单(92% 准确率),但对最强 GPT‑4 仅 15%
- 答案唯一、事实性,支持自动评估 | 评估通用 AI 助手在需要多步工具使用的真实任务中的表现 | “2022 年法国总统选举的获胜者出生在哪个城市?”【26†L10-L25】 | +| **ToolBench**(OpenBMB)【7†L4-L10】 | - 包含多样软件工具的真实任务基准
- 提供易用的基础设施(虚拟 API 服务器等)
- 评估 LLM 调用工具完成复杂任务的能力 | 评估智能体在工具调用与组合上的熟练度 | “使用计算器 API 计算 (3.14 * 7.2)^2,然后使用搜索 API 查找‘圆周率’的最新精确值” | +| **HSCodeComp**【2†L16-L35】 | - 首个针对深度搜索智能体在分层规则应用上的专家级电商基准
- 任务:根据嘈杂的商品描述预测 10 位 HS 编码(关税规则)
- 632 个商品条目,标注由多位电商专家完成,最佳智能体准确率仅 46.8%(人类 95%) | 评估智能体在复杂规则(如关税分类)下的深层推理与搜索能力 | “给定商品描述‘男士纯棉针织T恤,领口为V领’,预测其 10 位 HS 编码”【2†L24-L32】 | + +--- + +## ❓ 3. 问答/检索类基准 +这类基准主要评估智能体在开放域或特定领域的信息检索与问答能力。 + +| 基准 | 核心特点 | 使用场景 | 例题 | +|------|----------|----------|------| +| **HotpotQA**【8†L4-L10】 | - 113k 个基于 Wikipedia 的多跳问答对
- 需要跨多个文档推理,并提供句子级支持事实
- 涵盖多样的问题类型(比较、列举、因果等) | 评估模型在需要多跳推理的开放域问答中的能力 | “《了不起的盖茨比》的作者还写过哪些小说?” | +| **FEVER**【27†L16-L28】 | - 185,445 个基于 Wikipedia 的声明,需分类为 Supported/Refuted/NotEnoughInfo
- 要求提供证据句子(可多句、多页面)
- 挑战性高(最佳系统仅 31.87% 准确率) | 评估系统在事实核查与证据检索上的能力 | “斐济最大的岛屿是考艾岛。”【27†L48-L50】 | +| **TriviaQA**【28†L5-L10】 | - 超过 650k 个问答‑证据三元组,问题由琐事爱好者编写
- 每个问题平均提供 6 篇证据文档,适合远程监督阅读
- 包含阅读理解和开放域 QA 两种任务设置 | 评估模型在开放域琐事问答中的检索与阅读理解能力 | “哪位演员在《星球大战:原力觉醒》中扮演凯洛·伦?” | +| **Natural Questions**【11†L4-L8】 | - 来自 Google 搜索的真实用户问题,答案来自 Wikipedia
- 包含长答案(段落)和短答案(实体/日期等)
- 训练集 307k,开发/测试集各 8k | 评估开放域问答系统对真实用户查询的响应能力 | “谁写了《傲慢与偏见》?”(答案:“简·奥斯汀”) | +| **MS MARCO**【12†L4-L10】 | - 基于 Bing 真实搜索查询的检索/问答基准
- 包含段落排序、文档排序、问答等任务
- 数据规模大(约 500k 查询,880 万段落) | 训练与评估检索模型在真实搜索场景下的表现 | “如何冲泡咖啡?”(系统需检索相关段落回答) | +| **BEIR**【13†L4-L10】 | - 异构零样本检索基准,包含 18 个数据集、9 种任务(事实核查、QA、对话检索等)
- 评估模型在未见任务上的泛化能力
- 已成为检索模型的标准评估套件 | 评估检索模型在零样本设置下的跨任务泛化能力 | 给定查询“全球变暖的主要原因”,在文档集合中检索相关文档 | +| **MIRACL**【14†L4-L10】 | - 多语言检索基准,覆盖 18 种语言(包括中文、阿拉伯语等)
- 专注于 ad‑hoc 检索,所有查询与文档均经过人工标注
- 提供多语言检索挑战(WSDM 2023 Cup) | 评估多语言检索模型在不同语言下的检索效果 | 中文查询:“全球变暖的原因”,检索相关中文文档 | + +--- + +## 🧩 4. 组合搜索/推理类基准 +这类基准专门测试智能体在组合搜索问题上的逻辑推理与规划能力。 + +| 基准 | 核心特点 | 使用场景 | 例题 | +|------|----------|----------|------| +| **SearchBench**【30†L22-L36】 | - 11 种独特的搜索问题类型(路径查找、谜题、子集和、排序、欠定系统等)
- 自动生成任意数量实例,并评估解决方案的可行性、正确性、最优性
- GPT‑4 在纯文本模式下仅解决 1.4%,通过 A* 代码生成提升至 11.7% | 评估 LLM 在组合搜索问题上的推理与规划能力 | “在一个 5x5 迷宫中,从起点 (0,0) 到终点 (4,4) 的最短路径是什么?(障碍物位置已知)”【30†L56-L64】 | + +--- + +## 🛒 5. 特定领域/仿真环境基准 +这类基准针对特定领域(如电商、家庭环境)设计,测试智能体在仿真环境中的交互能力。 + +| 基准 | 核心特点 | 使用场景 | 例题 | +|------|----------|----------|------| +| **WebShop**【29†L15-L30】 | - 模拟电商网站环境,包含 118 万真实商品、12,087 条众包指令
- 智能体需根据文本指令搜索、筛选、购买商品
- 奖励自动计算,最佳模型成功率 29%(人类 59%) | 评估语言智能体在电商任务中的导航、搜索与决策能力 | “找到一款价格低于 $50 的无线蓝牙耳机,评分至少 4 星,并加入购物车”【29†L20-L25】 | +| **ALFWorld**【18†L4-L10】 | - 将文本环境(TextWorld)与具身环境(ALFRED)对齐的基准
- 智能体先在文本中学习抽象策略,然后在视觉环境中执行
- 包含 134 个家庭任务(如“在厨房找到杯子并拿到客厅”) | 评估智能体在文本与具身环境之间的迁移学习与规划能力 | “在厨房找到‘杯子’,把它拿到客厅放在桌子上” | + +--- + +## 📊 总结 +以上基准覆盖了搜索智能体的不同能力维度:从**网页浏览、工具使用、问答检索**,到**组合搜索推理**以及**特定领域仿真**。选择合适的基准取决于您要评估的具体能力(如多模态交互、工具调用、多跳推理等)。这些基准大多提供公开的数据集与评估代码,可直接用于实验。 + +> 注:本文基于 2025 年 11 月前的公开资料整理,部分基准可能已有更新版本,建议访问对应官网获取最新信息。<|end▁of▁sentence|> \ No newline at end of file diff --git a/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_output_search_wo_date.txt b/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_output_search_wo_date.txt new file mode 100644 index 00000000000..bc99497a433 --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_output_search_wo_date.txt @@ -0,0 +1,1069 @@ +<|begin▁of▁sentence|><|User|> + +## Tools + +You have access to a set of tools you can use to answer the user's question. +You can invoke functions by writing a "<|DSML|function_calls>" block like the following as part of your reply to the user: +<|DSML|function_calls> +<|DSML|invoke name="$FUNCTION_NAME"> +<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<|DSML|invoke name="$FUNCTION_NAME2"> +... + + + +String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects). + +If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example: + +<|DSML|function_calls> +... + + + +... + + +...thinking about results + +Here are the functions available in JSONSchema format: + +{"name": "search", "description": "Searches for information related to query and displays topn results.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The search query string"}, "topn": {"type": "integer", "description": "Number of top results to display", "default": 10}, "source": {"type": "string", "description": "Source to search within", "enum": ["web", "news"], "default": "web"}}, "required": ["query"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}} +{"name": "open", "description": "Opens the link id from the page indicated by cursor starting at line number loc, showing num_lines lines. Valid link ids are displayed with the formatting: 【{id}†.*】. If cursor is not provided, the most recent page is implied. If id is a string, it is treated as a fully qualified URL associated with source. If loc is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available. Use this function without id to scroll to a new location of an opened page.", "parameters": {"type": "object", "properties": {"id": {"description": "Link ID (number) or fully qualified URL (string)", "anyOf": [{"type": "integer"}, {"type": "string"}], "default": -1}, "cursor": {"type": "integer", "description": "Page cursor indicator", "default": -1}, "loc": {"type": "integer", "description": "Starting line number", "default": -1}, "num_lines": {"type": "integer", "description": "Number of lines to show", "default": -1}, "view_source": {"type": "boolean", "description": "Whether to view source", "default": false}, "source": {"type": "string", "description": "Source associated with the URL", "enum": ["web", "news"], "default": "web"}}, "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}} +{"name": "find", "description": "Finds exact matches of pattern in the current page, or the page given by cursor.", "parameters": {"type": "object", "properties": {"pattern": {"type": "string", "description": "The exact pattern to search for"}, "cursor": {"type": "integer", "description": "Page cursor to search in", "default": -1}}, "required": ["pattern"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}} +{"name": "python", "description": "Use this tool to execute Python code. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files). When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The python code to execute in the jupyter environment."}}, "required": ["code"], "additionalProperties": false, "$schema": "http://json-schema.org/draft-07/schema#"}} + + + +# The user's message is: In a year before 2023, between December 1 and 15 (inclusive), a press release reported that as of the 5th of that month, a country had received a total of 16,201,670 vaccine doses, with 7,583,134 administered. That same year, a company's annual report revealed that its nearly 2.2 million members and supporters contributed over $49 million, with 12% of total revenue coming from contributed services. The following year, the annual report stated that the organization had played a role in rescinding a previous administration’s plan to open most of the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas drilling and that contributed services had decreased by 2% to 10%. Who was the emeritus board member that year? Name only, no titles.<|Assistant|>The user asks: "In a year before 2023, between December 1 and 15 (inclusive), a press release reported that as of the 5th of that month, a country had received a total of 16,201,670 vaccine doses, with 7,583,134 administered. That same year, a company's annual report revealed that its nearly 2.2 million members and supporters contributed over $49 million, with 12% of total revenue coming from contributed services. The following year, the annual report stated that the organization had played a role in rescinding a previous administration’s plan to open most of the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas drilling and that contributed services had decreased by 2% to 10%. Who was the emeritus board member that year? Name only, no titles." + +We need to parse the puzzle. + +We have clues about a press release about vaccine doses for a country, dated between Dec 1-15 inclusive, reporting as of 5th of that month total doses received 16,201,670 and administered 7,583,134. That is likely a COVID-19 vaccine update for some country. Which country? Might be a smaller country. Could be something like "Uganda", "Zimbabwe", "Kenya", "Philippines", etc. Let's search. + +Search for "16,201,670 vaccine doses" "7,583,134 administered". + +<|DSML|function_calls> +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">"16,201,670" vaccine doses +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|end▁of▁sentence|> + + +[0] "16,201,670" vaccine doses (Search_Results/"16,201,670" vaccine doses) +**viewing lines [0 - 38] of 38** + +L0: +L1: URL: Search_Results/"16,201,670" vaccine doses +L2: # Search Results +L3: +L4: * 【0†Kenya's Economy is Showing Resilience as Output Rises ...; publish_date: +L5: none†www.worldbank.org】 Dec 14, 2021 — As of December 5, 2021, Kenya had +L6: received a total of 16,201,670 vaccines, with 7,583,134 administered. While +L7: vaccine acceptance is ... +L8: * 【1†MINISTRY OF HEALTH; publish_date: none†covidhub.mediacouncil.or.ke】 Dec +L9: 1, 2021 — Total Doses Received 16,201,670. Table 10 gives the total vaccines +L10: received since the start of Covid -19 vaccination exercise in the country. +L11: * 【2†Output Result Page; publish_date: none†open.unicef.org】 ... 16,201,670 +L12: doses of multiple vaccines nationwide and full vaccination of 15.5 per cent with +L13: two doses of COVID-19 vaccine as of 31 December 2021. +L14: * 【3†rebased GDP; publish_date: none†documents1.worldbank.org】 Dec 7, 2021 — +L15: As of December 5, 2021,. Kenya had received a total of 16,201,670 vaccines, +L16: with. 7,583,134 administered. Vaccine acceptance is reportedly high. +L17: * 【4†Integrated Annual Report; publish_date: none†www.co-opbank.co.ke】 May 27, +L18: 2022 — ... doses of Covid-19 vaccines and administered close to 17 million ... +L19: 16,201,670, huku 7,583,134 zikiwa tayari zimedungwa watu. Bado kuna ... +L20: * 【5†World Bank lifts Kenya's growth prospect to 5% in 2021; publish_date: +L21: none†www.africa-press.net】 ... 16,201,670 vaccines, with 7,583,134 administered. +L22: While vaccine acceptance is reportedly high there is still a long way to go +L23: towards the government's ... +L24: * 【6†2020 Annual Report United States; publish_date: none†www.hi-us.org】 +L25: including medicine, vaccines, Covid-19 prevention kits, and kits of ... +L26: 16,201,670. 9,630,584. Prepaid expenses. 56,707. 22,601. Total current assets. +L27: * 【7†HOSPITAL CENTRAL DE LA DEFENSA; publish_date: +L28: none†contrataciondelestado.es】 16201670 X. URPA-URPA. 0. HCD-000393/00. KONTRON. +L29: 7000 ... LASER SHOT LBP-1120. CNEUR-C. NEUROFISIOLOGÍA. 0 ... VICTOREEN CT HEAD +L30: DOSE PHANTON. 43650-014. +L31: * 【8†body metal concentrations: Topics by ...; publish_date: +L32: none†www.science.gov】 PMID:16201670. Design Characteristics and Tobacco Metal +L33: Concentrations in Filtered Cigars. PubMed. Caruso, Rosalie V; O'Connor, Richard +L34: J; Travers, Mark J ... +L35: * 【9†Public Offices, Personal Demands: Capability in ...; publish_date: +L36: none†dokumen.pub】 ... (16201670).9 The Institutiones Politicae consists of two +L37: books. Book one deals with the origin, nature and goal of the commonwealth on +L38: the one hand, and the ... + + +Result shows Kenya. So the press release is about Kenya's vaccine doses as of Dec 5, 2021. The press release likely by World Bank or something. The search result shows "World Bank lifts Kenya's growth prospect..." maybe a press release dated Dec 14, 2021. But the clue says "between December 1 and 15 (inclusive), a press release reported that as of the 5th of that month, a country had received a total of 16,201,670 vaccine doses, with 7,583,134 administered." That matches Kenya on Dec 5, 2021. So the year is 2021. + +Now the next clue: "That same year, a company's annual report revealed that its nearly 2.2 million members and supporters contributed over $49 million, with 12% of total revenue coming from contributed services." + +This sounds like an environmental or conservation organization. "Nearly 2.2 million members and supporters contributed over $49 million." That might be something like "The Nature Conservancy", "World Wildlife Fund", "Sierra Club", "Environmental Defense Fund", "Natural Resources Defense Council", etc. Let's search. + +Search for "nearly 2.2 million members and supporters contributed over $49 million" "12% of total revenue coming from contributed services". + +<|DSML|function_calls> +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">"nearly 2.2 million members" "contributed over $49 million" +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|end▁of▁sentence|> + + +[1] "nearly 2.2 million members" "contributed over $49 million" (Search_Results/"nearly 2.2 million members" "contributed over $49 million") +**viewing lines [0 - 41] of 41** + +L0: +L1: URL: Search_Results/"nearly 2.2 million members" "contributed over $49 million" +L2: # Search Results +L3: +L4: * 【0†The Global Good Fund, Microsoft and Goodwill join forces ...; +L5: publish_date: none†globalgoodfund.org】 Sep 28, 2022 — ... more than 2.2 million +L6: people. They additionally contribute close to $49 billion in local, state and +L7: federal tax revenues. When COVID hit ... +L8: * 【1†Almost 22 billion American tax dollars spent to wipe out a ...; +L9: publish_date: none†www.facebook.com】 US military funding for Israel's war crimes +L10: in Lebanon and Gaza has now cost US taxpayers over $22 billion. When millions +L11: struggle to afford the ... +L12: * 【2†Corporate America has largely abandoned its post-January ...; +L13: publish_date: none†www.citizensforethics.org】 Jul 29, 2025 — Since the January 6 +L14: insurrection, over 2,000 corporate and industry group PACs have given over $174 +L15: million to members of the Sedition ... +L16: * 【3†Audit shows millions in questionable taxpayer spending at ...; +L17: publish_date: none†www.aol.com】 18 hours ago — ... nearly doubled from 1.3 +L18: million to about 2.2 million. That is more than one in four Washington state +L19: residents receiving Medicaid, and the ... +L20: * 【4†Incarceration and Poverty in the United States - AAF; publish_date: +L21: none†www.americanactionforum.org】 Jun 30, 2020 — The United States currently +L22: incarcerates 2.2 million people, nearly half of whom are non-violent drug +L23: offenders, accused people held pre ... +L24: * 【5†How the NRA Spends Revenue (2023); publish_date: none†paddockpost.com】 +L25: Apr 25, 2025 — In 2023, the NRA reported total revenue of $176 million (compared +L26: to $211 million in 2022, $227 million in 2021, $282 million in 2020, $292 ... +L27: * 【6†NCA Newsroom; publish_date: none†www.ncausa.org】 ... 2.2 million U.S. +L28: jobs—operating in every U.S. state and territory and contributing nearly $350 +L29: billion to the U.S. economy every year. For media inquiries ... +L30: * 【7†The Early Childhood Educator Workforce; publish_date: +L31: none†cscce.berkeley.edu】 About 2.2 million adults are paid to care for and +L32: educate more than 9.7 million children between birth and age five in center- +L33: based and home-based settings. +L34: * 【8†Bringing Out the Worst in Congress: CRA By the Numbers ...; publish_date: +L35: none†progressivereform.org】 Jul 14, 2025 — In fact, the top ten recipients (out +L36: of 25) account for almost $7 million, showing how these contributions are +L37: targeted and strategic. Real ... +L38: * 【9†Club 100: Over 100 Americans gave more than $100000 to ...; publish_date: +L39: none†publicintegrity.org】 Jan 7, 2022 — Through Internet investments during the +L40: 1990s, he attracted 2.2 million online accounts valued at $174 million. Schwab +L41: has been a vocal ... + + +Not that. Try search for "2.2 million members and supporters contributed over $49 million". + +<|DSML|function_calls> +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">"2.2 million members" "contributed over $49" +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|end▁of▁sentence|> + + +[2] "2.2 million members" "contributed over $49" (Search_Results/"2.2 million members" "contributed over $49") +**viewing lines [0 - 2] of 2** + +L0: +L1: URL: Search_Results/"2.2 million members" "contributed over $49" +L2: # Search Results + + +Maybe it's an annual report of a non-profit. Search for "annual report 2.2 million members contributed $49 million". + +<|DSML|function_calls> +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">"2.2 million members" "annual report" $49 million +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|end▁of▁sentence|> + + +[3] "2.2 million members" "annual report" $49 million (Search_Results/"2.2 million members" "annual report" $49 million) +**viewing lines [0 - 36] of 36** + +L0: +L1: URL: Search_Results/"2.2 million members" "annual report" $49 million +L2: # Search Results +L3: +L4: * 【0†20-F; publish_date: none†www.sec.gov】 ANNUAL REPORT PURSUANT TO SECTION +L5: ... Our membership grew from 2.1 million members as at December 31, 2023 to 2.2 +L6: million members as at December 31, 2024. +L7: * 【1†Oportun Reports Fourth Quarter and Full Year 2023 ...; publish_date: +L8: none†investor.oportun.com】 Mar 12, 2024 — Oportun (Nasdaq: OPRT) is a mission- +L9: driven fintech that puts its 2.2 million members' financial goals within reach. +L10: ... annual report on ... +L11: * 【2†2 0 21; publish_date: none†www.annualreports.com】 ANNUAL REPORT. 2. 0. +L12: 21. 2. 0. 21. Page 2. 2. DEFENDERS OF WILDLIFE. 2. 0. 21. 2. 0. 21 ... In 2021, +L13: Defenders of Wildlife's nearly 2.2 million members and. +L14: * 【3†Annual report and accounts 2020; publish_date: none†www.3i.com】 +L15: Disclaimer. The Annual report and accounts have been prepared solely to provide +L16: information to shareholders. ... 2.2 million members. In December 2019, we sold +L17: ... +L18: * 【4†united states securities and exchange commission; publish_date: +L19: none†s21.q4cdn.com】 Dec 14, 2018 — ☒ ANNUAL REPORT ... approximately 5.7 million +L20: Members, including approximately 2.2 million Members from the acquisition of +L21: Best Doctors. +L22: * 【5†securities and exchange commission; publish_date: none†www.sec.gov】 ... +L23: Annual Report on Form 10-K filed with the SEC on March 10, 2020. The ... 2.2 +L24: million members at June 30, 2020. The following table highlights the ... +L25: * 【6†2022-23 Annual Comprehensive Financial Report - CalPERS; publish_date: +L26: none†www.calpers.ca.gov】 Nov 17, 2023 — ... annual report is consistent with +L27: these accounting principles ... 2.2 million members and health benefit services +L28: to over 1.5 million ... +L29: * 【7†Medical Centers Report; publish_date: +L30: none†regents.universityofcalifornia.edu】 Oct 8, 2019 — which will allow its 2.2 +L31: million members to receive advanced and ... annual report of the University of +L32: California. For the Medical ... +L33: * 【8†Working TogeTher to make HealtH care better; publish_date: +L34: none†www.unitedhealthgroup.com】 Feb 8, 2012 — ANNUAL REPORT PURSUANT TO SECTION +L35: 13 OR 15(d) OF THE ... approximately 2.2 million members enrolled in its +L36: Medicare. Advantage ... + + +Result 2 shows "Defenders of Wildlife". That might be the organization. Let's open. + +<|DSML|function_calls> +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">3 +<|DSML|parameter name="id" string="false">2 + +<|end▁of▁sentence|> + + +[4] www.annualreports.com (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2021.pdf) +**viewing lines [0 - 53] of 261** + +L0: +L1: URL: https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders- +L2: wildlife_2021.pdf +L3: ===== Page 1 ===== 2021 ANNUAL REPORT ===== Page 2 ===== # DEFENDERS OF WILDLIFE +L4: made important progress for imperiled species and vital landscapes across the +L5: United States in 2021. \--- **LAWYERS** immediately **challenged** the premature +L6: and reckless decision to strip **gray wolves** of federal **Endangered Species +L7: Act (ESA)** protections. For many decades, Defenders has led the effort to +L8: protect and restore the gray wolf, and we will continue to fight the +L9: unscientific and hostile anti-wolf policies that impede conservation progress +L10: and will carry on our unrelenting battle to restore federal protections for this +L11: iconic keystone species. \--- **LOBBYISTS** worked around the clock to keep +L12: wildlife and climate priorities in the **Infrastructure Investment and Jobs +L13: Act**. We also continue fighting to keep important wildlife and habitat funding +L14: in relevant **appropriations bills**. \--- 2 DEFENDERS OF WILDLIFE ===== Page 3 +L15: ===== POLICY EXPERTS pushed forward on the urgent need for a National +L16: Biodiversity Strategy (NBS), an all-of-government approach to address the +L17: unprecedented loss of wildlife and habitat we are experiencing. We have coupled +L18: this with our new campaign to expand the National Wildlife Refuge System to +L19: preserve our nation’s only lands set aside for wildlife. By defending, funding +L20: and expanding our national wildlife refuges, we will directly address +L21: biodiversity loss and climate change while promoting increased equitable access +L22: to nature. FIELD TEAMS were on the ground helping to recover imperiled species. +L23: From panthers and sea turtles in Florida to wolves, bison and black-footed +L24: ferrets in Montana, Defenders’ conservation experts were in the field saving +L25: wildlife all over the country. CONSERVATION INNOVATION EXPERTS provided +L26: comprehensive analyses to guide policy and inform conservation strategies to +L27: reach the goal of protecting 30% of our terrestrial and marine systems by 2030 +L28: (“30x30”). Defenders’ Center for Conservation Innovation (CCI) produced a report +L29: which details actions we need to take to achieve 30x30 while protecting +L30: biodiversity and addressing the climate crisis. DEFENDERS.ORG ===== Page 4 ===== +L31: WE HAVE ACCOMPLISHED MUCH THIS YEAR WORKING WITH AN ADMINISTRATION THAT VALUES +L32: SCIENCE AND CONSERVATION. Our many successes include the return of protections +L33: to the Tongass National Forest in Alaska and the suspension of oil and gas +L34: leasing permits for the Arctic National Wildlife Refuge. Defenders also played a +L35: lead role in successfully defending the Migratory Bird Protection Act from +L36: attack and securing critical protections for migratory birds. Throughout 2021 +L37: Defenders made important progress for the wildlife and wild places we all love, +L38: yet our nation’s wildlife still face unprecedented challenges. More and more +L39: species, including manatees, Mexican gray wolves, polar bears and North Atlantic +L40: right whales, face the very real threat of extinction because of climate +L41: change, habitat loss, pollution and inadequate protections. In our work we +L42: continue to demonstrate success by following the science, building partnerships, +L43: leading with innovative solutions and focusing on collaborative and practical +L44: approaches in defense of nature. Together we can make a lasting difference for +L45: wildlife and habitats. Your support is essential to our work and we couldn’t do +L46: it without you. Thank you for being an important part of our efforts to protect +L47: imperiled wildlife across our nation. Jamie Rappaport Clark President and CEO +L48: John Dayton Board Chair OUR NEW TELEVISION SERIES, WILDLIFE NATION WITH JEFF +L49: CORWIN, premiered in October and airs every weekend on ABC stations nationwide. +L50: This groundbreaking series takes Jeff across the country to meet Defenders’ +L51: experts along with our science, tribal and community partners to educate and +L52: inspire viewers to join with Defenders to protect wildlife. While highlighting +L53: some of our important conservation work and raising awareness of important + + +Search within PDF for "2.2 million members". + +<|DSML|function_calls> +<|DSML|invoke name="find"> +<|DSML|parameter name="cursor" string="false">4 +<|DSML|parameter name="pattern" string="true">2.2 + +<|end▁of▁sentence|> + + +[5] Find results for text: `2.2` in `www.annualreports.com` (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2021.pdf/find?pattern=2.2) +**viewing lines [0 - 4] of 4** + +L0: # 【0†match at L187】 +L1: 2021, Defenders of Wildlife’s nearly 2.2 million members and supporters +L2: contributed more than $49 million for wildlife and wildlife habitat. This figure +L3: includes all those who generously donated their time and expertise. The +L4: steadfast support of our donors allows Defenders to sustain our program and + + +Open around L180. + +<|DSML|function_calls> +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">4 +<|DSML|parameter name="loc" string="false">170 + +<|end▁of▁sentence|> + + +[6] www.annualreports.com (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2021.pdf) +**viewing lines [170 - 215] of 261** + +L170: into the wild. **10 DEFENDERS OF WILDLIFE** ===== Page 11 ===== AFTER ANOTHER +L171: SUCCESSFUL DEFENDERS LAWSUIT ON BEHALF OF THE CRITICALLY ENDANGERED RED WOLF, +L172: FWS reversed its decision to limit the recovery area and committed to a robust +L173: release strategy. After years of inaction, FWS released eight wolves to the wild +L174: in North Carolina and plan to release nine more wolves in the spring of 2022. +L175: This is an incredible win for this critically endangered species whose +L176: population has dwindled down to single digits in the wild because of +L177: mismanagement, vehicle strikes and poaching. DEFENDERS CONTINUED TO LEAD EFFORTS +L178: TO PROTECT THE FLORIDA MANATEE, a beloved species that suffered the deadliest +L179: year on record in 2021, tragically surpassing 1,000 deaths because of water +L180: pollution and lack of warm water habitat. Defenders led advocacy and education +L181: aimed at restoring the natural flow of the dammed Ocklawaha River, which would +L182: provide critical warm-water habitat that manatees need to survive. Defenders’ +L183: legal team continued to fight for manatees in the courts, holding government +L184: agencies accountable for protecting critical habitat and addressing the +L185: devastating water pollution that is killing the seagrass and causing manatees to +L186: starve. DAVID TES | SAM FRENZY DRAW DEFENDERS.ORG 11 ===== Page 12 ===== In +L187: 2021, Defenders of Wildlife’s nearly 2.2 million members and supporters +L188: contributed more than $49 million for wildlife and wildlife habitat. This figure +L189: includes all those who generously donated their time and expertise. The +L190: steadfast support of our donors allows Defenders to sustain our program and +L191: public education efforts in the field, the courts and on Capitol Hill. 2021 +L192: SOURCES OF FUNDS Grants and contributions $29,057 Bequests, trusts and split +L193: interests $7,692 Income from investments, annuity reserve funds and trusts +L194: $3,354 Royalties and other income $3,576 Contributed services $6,140 Total +L195: Revenue $49,819 USES OF FUNDS Biodiversity conservation $22,420 Constituency +L196: mobilization $16,324 Fundraising $1,211 Management and general $5,865 Total +L197: Expenses $45,820 Change in net assets $3,999 Net assets, start of year $41,145 +L198: Net Assets, End of the Year $45,144 Dollars are in thousands. 12 DEFENDERS OF +L199: WILDLIFE Grants and contributions 58% Income from investments 7% Requests, +L200: trusts and split interests 15% Royalties and other income 7% Contributed +L201: services 12% Program and support services 97% Fundraising 3% ===== Page 13 ===== +L202: DEFENDERS.ORG 13 ===== Page 14 ===== John Dayton, Esq.* Chair Texas Sheryl +L203: Sachs* Vice Chair Connecticut Whitfield Bailey* Secretary Tennessee Mark Caylor* +L204: Treasurer District of Columbia Sylvia Baca* California Donald Barry, Esq. +L205: Oregon Dinah Bear, Esq. Arizona Kim O’Keefe Beck, Ph.D California Gabriela +L206: Chavarria, Ph.D Washington Jamie Rappaport Clark* President & CEO Virginia +L207: Caroline D. Gabel Maryland Keith M. Harper, Esq. Maryland John Donovan, Esq. +L208: Montana and California James M. Hecker, Esq. Virginia Judith Ponsikoff, Ph.D* +L209: California Terry L. Root, Ph.D Florida BOARD EMERITUS Terry C. Pelster, Esq. New +L210: York *Executive Committee member* ST A F F & A D V I S O R S SCIENCE ADVISORS +L211: Gabriela Chavarria, Ph.D Burke Museum, University of Washington Leah Gerber, +L212: Ph.D Arizona State University Kathayoon Khalil, Ph.D Oregon Zoo Rurik List, Ph.D +L213: Universidad Autónoma Metropolitana-Lerma Barry Noon, Ph.D Colorado State +L214: University Terry L. Root, Ph.D Stanford University Mark Schwartz, Ph.D +L215: University of California-Davis David S. Wilcox, Ph.D Princeton University Rae + + +Scrolling further for "Contributed services 12%". Already seen. So the annual report is Defenders of Wildlife 2021 annual report. + +Next clue: "The following year, the annual report stated that the organization had played a role in rescinding a previous administration’s plan to open most of the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas drilling and that contributed services had decreased by 2% to 10%." + +Thus, the following year (2022) annual report of Defenders of Wildlife includes mention of rescinding plan to open most of 22 million acres of NPR-A to oil and gas drilling. Also contributed services decreased to 10% (from 12% in 2021). So we need to find the 2022 annual report for Defenders of Wildlife. + +Search for "Defenders of Wildlife 2022 annual report National Petroleum Reserve Alaska". + +<|DSML|function_calls> +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">Defenders of Wildlife 2022 annual report National Petroleum Reserve Alaska +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|end▁of▁sentence|> + + +[7] Defenders of Wildlife 2022 annual report National Petroleum Reserve Alaska (Search_Results/Defenders of Wildlife 2022 annual report National Petroleum Reserve Alaska) +**viewing lines [0 - 37] of 37** + +L0: +L1: URL: Search_Results/Defenders of Wildlife 2022 annual report National Petroleum +L2: Reserve Alaska +L3: # Search Results +L4: +L5: * 【0†CELEBRATING YEARS; publish_date: none†www.annualreports.com】 With less +L6: than 340 right whales left in the wild, Defenders is fighting tirelessly to end +L7: deadly lobster gear entanglements and vessel strikes that are driving ... +L8: * 【1†Financials; publish_date: none†defenders.org】 We invite you to explore +L9: the reports below to learn more about our activities and accomplishments, and +L10: how we put your money to work for wildlife. +L11: * 【2†Alaska Program Looks Back on 2022; publish_date: none†defenders.org】 Feb +L12: 9, 2023 — Thanks to a lawsuit joined by Defenders, seven million acres were +L13: returned to protection within the National Petroleum Reserve-Alaska (NPR-A), ... +L14: * 【3†Defenders-of-Wildlife-2022-Financial-Statement. ...; publish_date: +L15: none†defenders.org】 We have audited the accompanying consolidated financial +L16: statements of Defenders of Wildlife and Affiliated Defenders of Wildlife Action +L17: Fund (collectively, ... +L18: * 【4†2022 Annual Report; publish_date: none†alaskaconservation.org】 Jun 13, +L19: 2023 — In 2022, we focused on three landscapes: the Arctic. National Wildlife +L20: Refuge, Bristol Bay, and the Tongass National Forest. In March 2022,. +L21: * 【5†Assessment of ecological and cultural values within the ...; +L22: publish_date: none†www.blm.gov】 This document was written to provide technical +L23: information regarding the ecological importance of the National Petroleum +L24: Reserve – Alaska (NPR-A). Several ... +L25: * 【6†Accomplishments Report; publish_date: none†defenders.org】 National +L26: Petroleum Reserve-Alaska. Identified for exceptional wildlife and cultural +L27: values, including critical habitat for polar bears and other species ... +L28: * 【7†2022 annual report; publish_date: none†dory- +L29: plantain-s2zc.squarespace.com】 These projects are made possible through +L30: collaborations with Defenders of. Wildlife, Cook Inletkeeper, Trustees for +L31: Alaska,. Environmental Investigation Agency, ... +L32: * 【8†23IMPACT REPORT; publish_date: none†www.annualreports.com】 Defenders of +L33: Wildlife made incredible progress protecting wildlife and wild places in 2023, +L34: helping shape a brighter future for imperiled species and vital ... +L35: * 【9†From Leasing to Land Protections; publish_date: none†defenders-cci.org】 +L36: Abstract. When you hear the name, National Petroleum Reserve - Alaska (NPR-A, or +L37: Reserve), you might think of a barren land filled with oil wells and ... + + +Open result 0. + +<|DSML|function_calls> +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">7 +<|DSML|parameter name="id" string="false">0 + +<|end▁of▁sentence|> + + +[8] www.annualreports.com (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2022.pdf) +**viewing lines [0 - 53] of 289** + +L0: +L1: URL: https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders- +L2: wildlife_2022.pdf +L3: ===== Page 1 ===== 2022 Impact Report C E L E B R A T I N G Y E A R S ===== Page +L4: 2 ===== 2022 Defenders of Wildlife made important progress for imperiled +L5: species and vital landscapes across the United States in 2022. GRAY WOLF | JIM +L6: GUMMERAL MAY STOCK PRIOR Lawyers successfully challenged the previous +L7: administration’s decision to delist the gray wolf and restored critical federal +L8: protections under the Endangered Species Act. This latest triumph in court is +L9: part of our ongoing battle to protect and restore gray wolves throughout their +L10: historical range and shield them from persecution by extremist legislators in +L11: Idaho, Montana and Wyoming. TWO MORE FATALIZED GRAY SWALLETS TO SEA TO SHARE +L12: ALLIANCE Lobbyists worked around the clock to expand funding for wildlife +L13: conservation in the FY2022 federal spending bill, which included $31 million (a +L14: 44% increase) for the Bureau of Land Management’s Threatened and Endangered +L15: Species Program, $2.5 million (an 81% increase) for the U.S. Department of +L16: Agriculture Wildlife Services’ Nonlethal Initiative to prevent human-wildlife +L17: conflicts and $21 million (a 320% increase) for North Atlantic right whale +L18: conservation. 2 DEFENDERS OF WILDLIFE ===== Page 3 ===== **Policy Experts** +L19: played a crucial role in securing international trade protections for 100 +L20: species of sharks and rays, all 158 species of glass frogs and 73 species of +L21: reptiles, including 21 species of desert horned lizards, at the Convention on +L22: International Trade in Endangered Species (CITES) in Panama. \--- **Field +L23: Teams** worked tirelessly to protect and restore imperiled species across the +L24: country. From Florida manatees and red wolves in the Southeast to belugas and +L25: grizzly bears in Alaska, Defenders’ conservation experts were on the ground +L26: saving species that need our help to survive and thrive. \--- **Conservation +L27: Innovation Experts** published more than 10 peer-reviewed studies on topics that +L28: include the Cook Inlet beluga whale, golden-cheeked warbler, global parrot +L29: biodiversity, the Endangered Species Act, the effects of mountaintop removal +L30: mining on endangered species, the ecological importance of panthers and the +L31: implementation of “30x30” – the globally recognized goal to which President +L32: Biden committed the U.S. to conserve 30% of our imperiled lands and waters by +L33: 2030. \--- **DEFENDERS.ORG** ===== Page 4 ===== THANK YOU Defenders celebrated +L34: our 75th anniversary in 2022—an exciting milestone that we attribute to our +L35: unwavering dedication to our wildlife conservation mission. From helping to pass +L36: the Endangered Species Act and other bedrock conservation laws to leading the +L37: advocacy on the reintroduction of wolves to Yellowstone National Park and other +L38: parts of the West, Defenders is proud to inspire, lead and ultimately celebrate +L39: some of the most consequential conservation wins for wildlife and wild places. +L40: It is wonderfully fitting that we also mark 2022 as a year when we successfully +L41: defended, yet again, the original focal species and symbol of our organization, +L42: the gray wolf. While the future for gray wolves and many other species remains +L43: uncertain, we will never stop fighting to protect the imperiled wildlife that +L44: need us and to restore vibrant and balanced ecosystems. Our mission is more +L45: important than ever due to the escalating biodiversity and climate crises. +L46: Defenders is determined to build on our accomplishments and seize new +L47: opportunities to address these challenges through our multifaceted approach to +L48: conservation. We focus on developing innovative and science-based solutions, +L49: coordinating on-the-ground species and habitat restoration, promoting human- +L50: wildlife coexistence, protecting and strengthening our bedrock environmental +L51: laws, investing in impactful collaborations and partnerships, building a more +L52: diverse and powerful conservation movement and challenging attacks on wildlife +L53: through advocacy, policy and litigation. Defenders’ continued success is made + + +Search within PDF for "National Petroleum Reserve". + +<|DSML|function_calls> +<|DSML|invoke name="find"> +<|DSML|parameter name="cursor" string="false">8 +<|DSML|parameter name="pattern" string="true">Petroleum + +<|end▁of▁sentence|> + + +[9] Find results for text: `petroleum` in `www.annualreports.com` (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2022.pdf/find?pattern=petroleum) +**viewing lines [0 - 4] of 4** + +L0: # 【0†match at L201】 +L1: the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas +L2: drilling and removed protections for important wildlife habitat. Protections +L3: have now been restored for nearly half of the reserve’s pristine lands, which +L4: are vital habitat for shorebirds, denning polar bears and tens of thousands of + + +Open around L190. + +<|DSML|function_calls> +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">8 +<|DSML|parameter name="loc" string="false">180 + +<|end▁of▁sentence|> + + +[10] www.annualreports.com (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2022.pdf) +**viewing lines [180 - 227] of 289** + +L180: the sixth successful transfer of bison to the Assiniboine and Sioux Tribes of +L181: Fort Peck since 2019. \--- **SWIFT FIX KITS | © ROSIMA PAELARINTSKIMMA MADDIAL +L182: 200 AND CONSERVATION BIOLOGY INSTITUTE** \--- **Celebrating our third year** of +L183: a collaborative program with the Aaniih and Nakoda Tribes and others to restore +L184: swift foxes to the Fort Belknap Indian Reservation in Montana, Defenders helped +L185: with the release of 28 more swift foxes. With over 100 foxes reintroduced +L186: through this program, monitoring efforts show that they are reproducing in the +L187: wild—a critical measure of success for a self-sustaining population. \--- +L188: **Defenders continued to lead the way** for conserving and recovering the +L189: endangered black-footed ferret, supporting the black-footed ferret survey for +L190: the Fort Belknap Indian community. Thirty-six ferrets were vaccinated against +L191: sylvatic plague and two dozen kits were released in the wild. \--- **10 +L192: DEFENDERS OF WILDLIFE** ===== Page 11 ===== Defenders helped to bring hope for +L193: recovery for the endangered military macaw, adding 11 fledglings to a growing +L194: wild population in Puerta Vallarta, Mexico, that is under pressure from habitat +L195: loss and poachers for the illegal pet trade. Accord- ing to our recent report, +L196: the 2008 parrot trade ban that Defenders fought to achieve is working. +L197: Preventing more than 30,000 parrots from being illegally trapped each year, the +L198: trade ban has resulted in a 47% decrease in the illegal trade of parrots and an +L199: 88% decrease in U.S. seizures of Mexican parrots. As a result of a Defenders +L200: lawsuit, BLM rescinded the previous administration’s plan that opened most of +L201: the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas +L202: drilling and removed protections for important wildlife habitat. Protections +L203: have now been restored for nearly half of the reserve’s pristine lands, which +L204: are vital habitat for shorebirds, denning polar bears and tens of thousands of +L205: caribou in the Teshekpuk caribou herd. Our new storymap illustrates why the +L206: reserve is so important to wildlife. Defenders marked the 20th anniversary of +L207: our Sea Otter Awareness Week. In celebration of this milestone, we sponsored +L208: state legislation, signed by Gov. Gavin Newsom, that formally recognizes the +L209: event’s anniversary and acknowledges the vital role that sea otters play in +L210: California’s nearshore ecosystems. This annual celebration provides a unique +L211: opportunity to raise awareness of the important role sea otters play in +L212: maintaining ecosystem health, our need to coexist with sea otters and the +L213: efforts of Defenders and others in advancing sea otter conservation. +L214: DEFENDERS.ORG ===== Page 12 ===== FINANCIAL REPORT In 2022, Defenders of +L215: Wildlife’s nearly 2.2 million members and supporters contributed more than $41 +L216: million for wildlife and wildlife habitat. This figure includes all those who +L217: generously donated their time and expertise. The steadfast support of our donors +L218: allows Defenders to sustain our program and public education efforts in the +L219: field, the courts and on Capitol Hill. 2022 SOURCES OF FUNDS Grants and +L220: contributions $28,946 Bequests, trusts and split interests $13,928 Income from +L221: investments, annuity reserve funds and trusts (-$7,495) Royalties and other +L222: income $744 Contributed services $5,047 Total Revenue $41,170 USES OF FUNDS +L223: Biodiversity conservation $18,970 Constituency mobilization $16,528 Fundraising +L224: $1,029 Management and general $6,354 Total Expenses $42,881 Change in net assets +L225: (-$1,711) Net assets, start of year $45,144 Net Assets, End of the Year $43,433 +L226: Dollars are in thousands. 12 DEFENDERS OF WILDLIFE Grants and contributions 59% +L227: Contributed services 10% Royalties and other income 2% Bequests, trusts and + + +Scrolling a bit more for "Contributed services 10%". Already seen. So indeed the 2022 annual report shows contributed services 10%, down from 12% in 2021. + +Thus the organization is Defenders of Wildlife. + +The question: "Who was the emeritus board member that year? Name only, no titles." + +We need to identify the emeritus board member listed in the 2022 annual report. From the 2021 annual report, we saw "BOARD EMERITUS Terry C. Pelster, Esq." However, that is for 2021. The question asks "Who was the emeritus board member that year?" That year refers to the following year (2022). So we need to check the 2022 annual report for the emeritus board member. + +Open the 2022 PDF near the board list. Search for "Emeritus". + +<|DSML|function_calls> +<|DSML|invoke name="find"> +<|DSML|parameter name="cursor" string="false">8 +<|DSML|parameter name="pattern" string="true">Emeritus + +<|end▁of▁sentence|> + + +[11] Find results for text: `emeritus` in `www.annualreports.com` (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2022.pdf/find?pattern=emeritus) +**viewing lines [0 - 10] of 10** + +L0: # 【0†match at L237】 +L1: Sisk, Ph.D British Columbia, Canada BOARD EMERITUS Terry C. Pelster, Esq. New +L2: York *Executive Committee member* STAFF & ADVISORS SCIENCE ADVISORS Gabriela +L3: Chavarria, Ph.D Burke Museum, University of Washington Leah Gerber, Ph.D Arizona +L4: State University Kathayoon Khalil, Ph.D Oregon Zoo Rurik List, Ph.D Universidad +L5: +L6: # 【1†match at L243】 +L7: Davis Thomas D. Sisk, Ph.D Emeritus Professor Northern Arizona University David +L8: S. Wilcox, Ph.D Princeton University Rae Wynn-Grant, Ph.D National Geographic +L9: Society SENIOR STAFF Jamie Rappaport Clark President & CEO Thu Pham Chief of +L10: Staff James Stofan Senior Vice President, Operations McCrystle Adams Vice + + +Open around L230. + +<|DSML|function_calls> +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">8 +<|DSML|parameter name="loc" string="false">220 + +<|end▁of▁sentence|> + + +[12] www.annualreports.com (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2022.pdf) +**viewing lines [220 - 267] of 289** + +L220: contributions $28,946 Bequests, trusts and split interests $13,928 Income from +L221: investments, annuity reserve funds and trusts (-$7,495) Royalties and other +L222: income $744 Contributed services $5,047 Total Revenue $41,170 USES OF FUNDS +L223: Biodiversity conservation $18,970 Constituency mobilization $16,528 Fundraising +L224: $1,029 Management and general $6,354 Total Expenses $42,881 Change in net assets +L225: (-$1,711) Net assets, start of year $45,144 Net Assets, End of the Year $43,433 +L226: Dollars are in thousands. 12 DEFENDERS OF WILDLIFE Grants and contributions 59% +L227: Contributed services 10% Royalties and other income 2% Bequests, trusts and +L228: split interests 29% TOTAL REVENUE Income from investments -15% Program and +L229: support services 98% USE OF FUNDS Fundraising 2% ===== Page 13 ===== +L230: DEFENDERS.ORG 13 ===== Page 14 ===== BOARD OF DIRECTORS John Dayton, Esq.* Chair +L231: Texas Sheryl Sachs* Vice Chair Connecticut Dinah Bear, Esq.* Secretary Arizona +L232: Mark Caylor* Treasurer District of Columbia Susan Alexander Connecticut Sylvia +L233: Baca* California Whitefield Bailey Tennessee Donald Barry, Esq. Oregon Kim +L234: O’Keefe Beck, Ph.D California Gabriela Chavarria, Ph.D Washington Jamie +L235: Rappaport Clark* President & CEO Virginia John Donovan, Esq.* Montana and +L236: California Caroline D. Gabel Maryland Terry L. Root, Ph.D New Mexico Thomas D. +L237: Sisk, Ph.D British Columbia, Canada BOARD EMERITUS Terry C. Pelster, Esq. New +L238: York *Executive Committee member* STAFF & ADVISORS SCIENCE ADVISORS Gabriela +L239: Chavarria, Ph.D Burke Museum, University of Washington Leah Gerber, Ph.D Arizona +L240: State University Kathayoon Khalil, Ph.D Oregon Zoo Rurik List, Ph.D Universidad +L241: Autónoma Metropolitana-Lerma Barry Noon, Ph.D Colorado State University Terry +L242: L. Root, Ph.D Stanford University Mark Schwartz, Ph.D University of California- +L243: Davis Thomas D. Sisk, Ph.D Emeritus Professor Northern Arizona University David +L244: S. Wilcox, Ph.D Princeton University Rae Wynn-Grant, Ph.D National Geographic +L245: Society SENIOR STAFF Jamie Rappaport Clark President & CEO Thu Pham Chief of +L246: Staff James Stofan Senior Vice President, Operations McCrystle Adams Vice +L247: President, Conservation Law Matt Braughler Vice President, Philanthropy Rachel +L248: Brittin Vice President, Communications Shawn Cantrell Vice President, Field +L249: Conservation Programs Robert L. Dewey Vice President, Government Relations & +L250: External Affairs Monica Goldberg Vice President, Landscape Conservation Sharon +L251: Morris-Smith Vice President of Human Resources Elizabeth Nichols Vice President, +L252: Finance/CFO Anthony Pegues Vice President, Diversity, Equity, Inclusion, +L253: Justice Lindsay Rosa, Ph.D Vice President, Conservation Research and Innovation +L254: Brandace Swimley Vice President, Integrated Marketing 14 DEFENDERS OF WILDLIFE +L255: ===== Page 15 ===== There are many ways to support Defenders of Wildlife and +L256: protect our imperiled wildlife. To learn more or donate, contact us at +L257: 1-800-385-9712, visit Defenders.org/donate or choose one of the following +L258: options. ➤ Transfer appreciated assets directly to Defenders and maximize +L259: potential tax benefits. Defenders.org/ donate-stocks-help-save-wildlife ➤ Become +L260: a Leadership Council member with a single donation of $1,000 or more and enjoy +L261: more personal involvement with Defenders. Defenders.org/leadershipcouncil ➤ If +L262: you are 70.5 or older, roll over a portion of your IRA to benefit wildlife. +L263: Remember, these gifts must come directly from your IRA account. +L264: Defenders.org/planned-giving/gifts-your-retirement-plan ➤ Get started on a +L265: contribution from your donor-advised fund Defenders.org/donor-advised-funds ➤ +L266: Check with your personnel office or that of your spouse/partner about your +L267: company’s Corporate Matching Gifts program or visit us at + + +Thus the emeritus board member is Terry C. Pelster. However, the question might ask for "Who was the emeritus board member that year?" Possibly referring to the year of the annual report that mentioned the rescinded plan and contributed services decrease. That is 2022. So the answer is Terry C. Pelster. + +But let's verify that the 2022 annual report indeed includes the mention of contributed services decreasing to 10%. We saw that. So answer is Terry C. Pelster. + +But maybe the question is tricky: "Who was the emeritus board member that year?" Could be the emeritus board member listed in the 2021 annual report? However, the question says "That same year, a company's annual report revealed that its nearly 2.2 million members and supporters contributed over $49 million, with 12% of total revenue coming from contributed services. The following year, the annual report stated that the organization had played a role in rescinding a previous administration’s plan to open most of the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas drilling and that contributed services had decreased by 2% to 10%." + +Thus the "following year" annual report is 2022. The question likely asks: "Who was the emeritus board member that year?" Meaning the emeritus board member listed in that annual report (2022). So answer: Terry C. Pelster. + +But let's confirm that the 2021 annual report also shows emeritus board member Terry C. Pelster. Indeed we saw that. But perhaps the 2022 report might have a different emeritus board member. We saw it's the same. So answer is Terry C. Pelster. + +However, the question might be about a different organization. Let's verify the details. + +First clue: Press release about Kenya vaccine doses on Dec 5, 2021. That is likely a World Bank press release. Search for "World Bank Kenya vaccine doses December 2021 press release". + +<|DSML|function_calls> +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">As of December 5, 2021, Kenya had received a total of 16,201,670 vaccines +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|end▁of▁sentence|> + + +[13] As of December 5, 2021, Kenya had received a total of 16,201,670 vaccines (Search_Results/As of December 5, 2021, Kenya had received a total of 16,201,670 vaccines) +**viewing lines [0 - 41] of 41** + +L0: +L1: URL: Search_Results/As of December 5, 2021, Kenya had received a total of +L2: 16,201,670 vaccines +L3: # Search Results +L4: +L5: * 【0†Kenya's Economy is Showing Resilience as Output Rises ...; publish_date: +L6: none†www.worldbank.org】 Dec 14, 2021 — As of December 5, 2021, Kenya had +L7: received a total of 16,201,670 vaccines, with 7,583,134 administered. While +L8: vaccine acceptance is ... +L9: * 【1†Unmet need for COVID-19 vaccination coverage in Kenya - PMC; +L10: publish_date: none†pmc.ncbi.nlm.nih.gov】 by SK Muchiri · 2022 · Cited by 42 — As +L11: of December 2021, six counties had a vaccination coverage of less than 5%. +L12: These counties include Garissa, Mandera, Marsabit, Tana River, Turkana, and ... +L13: * 【2†MINISTRY OF HEALTH; publish_date: none†covidhub.mediacouncil.or.ke】 Dec +L14: 1, 2021 — • Total Covid-19 Vaccines Received to date- 16,201,670 ... Table 10: +L15: Vaccine Logistics Received in the Country as at 5th, December 2021. +L16: * 【3†COVID-19 vaccination refusal trends in Kenya over 2021 - PMC; +L17: publish_date: none†pmc.ncbi.nlm.nih.gov】 by RT Rego · 2023 · Cited by 21 — We +L18: assessed vaccine refusal over time in Kenya, and characterized factors +L19: associated with changes in vaccination refusal. +L20: * 【4†Ciheb-Kenya on the Front Lines of the COVID-19 Vaccine ...; publish_date: +L21: none†www.ciheb.org】 Of the 721,509 persons vaccinated, 409,711 are aged 58 +L22: years and above, 143,050 are healthcare workers, 109,593 are teachers, and +L23: 59,155 are security officers. +L24: * 【5†proportion of fully vaccinated adults 23.3pc.; publish_date: +L25: none†www.facebook.com】 COVID-19: KENYA administers 150632 new vaccinations; +L26: total jabs given now 13782963; proportion of fully vaccinated adults 23.3pc. +L27: * 【6†National COVID-19 Vaccines Deployment and ...; publish_date: +L28: none†www.parliament.go.ke】 Kenya has a well-established immunization program and +L29: a robust disease surveillance and response unit (DSRU) which serves as an early +L30: warning system, to. +L31: * 【7†Kenya: Vaccine Requirements Violate Rights; publish_date: +L32: none†www.hrw.org】 Dec 13, 2021 — Kenya, with an estimated adult population of +L33: 27.2 million and a total population of 55 million, has received approximately 23 +L34: million doses as ... +L35: * 【8†Government gives 20 million a month to get vaccinated; publish_date: +L36: none†www.bbc.com】 Nov 22, 2021 — Although less than 10% of Kenyans are +L37: vaccinated, the government wants to avoid a surge over Christmas. +L38: * 【9†ministry of health kenya covid-19 immunization status ...; publish_date: +L39: none†www.health.go.ke】 May 15, 2022 — By the end of 2021 Kenya had administered +L40: 10.1 million doses of 5 types of Covid-19 vaccines. During this period 4.2 +L41: million people were ... + + +Open result 0. + +<|DSML|function_calls> +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">13 +<|DSML|parameter name="id" string="false">0 + +<|end▁of▁sentence|> + + +[14] www.worldbank.org (https://www.worldbank.org/en/news/press-release/2021/12/14/kenya-s-economy-is-showing-resilience-as-output-rises-above-pre-pandemic-levels-driven-by-a-rebound-in-the-services-sect) +**viewing lines [0 - 52] of 61** + +L0: +L1: URL: https://www.worldbank.org/en/news/press-release/2021/12/14/kenya-s-economy- +L2: is-showing-resilience-as-output-rises-above-pre-pandemic-levels-driven-by-a- +L3: rebound-in-the-services-sect +L4: NAIROBI, December 14, 2021 — Kenya’s economy has demonstrated resilience to the +L5: COVID-19 shock, with output in the first half of the year rising above pre- +L6: pandemic levels. In 2021 as a whole, gross domestic product (GDP) is expected to +L7: grow by 5%, one of the faster recoveries among Sub-Saharan African countries. +L8: Overall economic performance is expected to be robust at 4.9% per year in +L9: 2022-23, similar to the pre-pandemic pace (5% average annual growth from 2010 to +L10: 2019). According to the 24th edition of the Kenya Economic Update, “From +L11: Recovery to Better Jobs,” growth has been supported by rebounds in industry and, +L12: especially, services. Agricultural output, however, fell by 0.5% year on year +L13: in the first half of 2021 following a particularly strong performance in 2020, +L14: partly due to below-average rains. Demand-side recovery has been supported by a +L15: revival in private consumption, against a backdrop of improving employment +L16: conditions and household incomes. “Kenya’s economy has shown considerable +L17: resilience to the enormous shock of the pandemic, and this year is expected to +L18: post one of the stronger growth rebounds in the region thanks to diversified +L19: sources of growth and sound economic policies and management,” said Keith +L20: Hansen, World Bank Country Director for Kenya. “However, poverty has increased, +L21: and the buffers and coping mechanisms of households, firms, and the public +L22: finances have been depleted.” Economic activity in Kenya has continued to adapt +L23: to the pandemic and associated restrictions. A mix of containment measures, such +L24: as a nightly curfew, were in effect through most of 2021, while more +L25: economically disruptive measures such as lockdowns and travel restrictions were +L26: phased, limiting the impact on economic activities. The vaccine rollout, which +L27: had a slow start due to supply constraints, has picked up as new shipments of +L28: vaccines have arrived, particularly since September. This has supported economic +L29: recovery and growth through the third quarter of 2021. As of December 5, 2021, +L30: Kenya had received a total of 16,201,670 vaccines, with 7,583,134 administered. +L31: While vaccine acceptance is reportedly high there is still a long way to go +L32: towards the government’s target of fully inoculating the adult population of +L33: about 30 million by the end of 2022.As of December 6, 2021, about 10% of adults +L34: (2.9 million people) had been fully vaccinated while another 16% (4.9 million +L35: people) had received their first dose. The report also projects robust growth in +L36: the medium-term. This outlook takes into account that some sub-sectors, such as +L37: education, have bounced back strongly, but others such as international +L38: tourism, have only partially and face a much more protracted recovery. The +L39: anticipated further recovery of hotels and restaurants, trade, transport, and +L40: other services, depends on substantial vaccination progress to help prevent new +L41: waves of infections and associated containment measures. “Recent economic +L42: performance has been strong and the outlook is positive, but in Kenya, as +L43: everywhere, the future course of the pandemic remains a major source of +L44: uncertainty, said Alex Sienaert, Senior Economist for Kenya. “To support further +L45: recovery, reduce debt distress risks, and rebuild space for social and +L46: development spending, it is critical that Kenya continues to implement medium- +L47: term fiscal consolidation plans.” In addition to pandemic-related risks to the +L48: outlook, a second key domestic risk factor stems from the drought conditions +L49: which are affecting parts of the country and already causing severe hardship. +L50: Should the drought intensify or spread, this would weigh on the near-term +L51: economic outlook. Weaker global growth, higher-than-anticipated energy prices, +L52: and tighter external financing conditions are the primary external risks. As + + +Thus first clue confirmed. + +Second clue: Defenders of Wildlife 2021 annual report. Confirmed. + +Third clue: Defenders of Wildlife 2022 annual report. Confirmed. + +Thus the answer is Terry C. Pelster. + +But let's verify that the 2022 annual report indeed shows "Contributed services 10%". We saw that. Also confirm that 2021 report shows 12%. We saw that. + +Thus the question: "Who was the emeritus board member that year?" Refers to 2022. Answer: Terry C. Pelster. + +However, the question might be asking for the emeritus board member in the 2021 report? Let's read the question carefully: + +"In a year before 2023, between December 1 and 15 (inclusive), a press release reported that as of the 5th of that month, a country had received a total of 16,201,670 vaccine doses, with 7,583,134 administered. That same year, a company's annual report revealed that its nearly 2.2 million members and supporters contributed over $49 million, with 12% of total revenue coming from contributed services. The following year, the annual report stated that the organization had played a role in rescinding a previous administration’s plan to open most of the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas drilling and that contributed services had decreased by 2% to 10%. Who was the emeritus board member that year? Name only, no titles." + +The phrase "that year" could refer to "the following year" i.e., the year of the annual report that mentioned the rescinded plan. Or could refer to "that same year" (the year of the press release and the first annual report). However, the phrase "that year" appears at the end: "Who was the emeritus board member that year?" It likely refers to the year of the annual report that is being discussed in the preceding sentence. The preceding sentence talks about the following year's annual report. So indeed "that year" is the following year (2022). + +Thus answer: Terry C. Pelster. + +But to be thorough, we should check the 2022 annual report for any other emeritus board members. The search only showed "BOARD EMERITUS Terry C. Pelster, Esq." Maybe there are multiple emeritus board members. But likely only one. + +Check the 2021 annual report for emeritus board member. It also shows Terry C. Pelster. So maybe the question is asking for the emeritus board member in the 2021 report. But why would they ask that? The question is about connecting the clues to identify the organization and then find the emeritus board member. The clues point to Defenders of Wildlife. The question likely asks for the emeritus board member as of the 2022 report. + +Nevertheless, we should confirm that the 2022 report is indeed the one that mentions the rescinded plan. We saw that. Also confirm that contributed services decreased from 12% to 10%. We saw that. + +Thus answer is Terry C. Pelster. + +But let's search for "Defenders of Wildlife emeritus board member 2022 Terry Pelster". + +<|DSML|function_calls> +<|DSML|invoke name="search"> +<|DSML|parameter name="query" string="true">Terry C. Pelster Defenders of Wildlife emeritus board member +<|DSML|parameter name="topn" string="false">10 +<|DSML|parameter name="source" string="true">web + +<|end▁of▁sentence|> + + +[15] Terry C. Pelster Defenders of Wildlife emeritus board member (Search_Results/Terry C. Pelster Defenders of Wildlife emeritus board member) +**viewing lines [0 - 31] of 31** + +L0: +L1: URL: Search_Results/Terry C. Pelster Defenders of Wildlife emeritus board member +L2: # Search Results +L3: +L4: * 【0†2024 impact report; publish_date: none†defenders.org】 Terry Root, Ph.D. +L5: New Mexico. Thomas D. Sisk, Ph.D. British Columbia, Canada. BOARD EMERITUS. +L6: Terry C. Pelster, Esq. New York. SCIENCE ADVISORS. Leah Gerber, Ph ... +L7: * 【1†Defenders of Wildlife; publish_date: none†www.wrongkindofgreen.org】 +L8: Directors ; Victor M. Sher*, Chair – California ; Terry C. Pelster*, Vice Chair +L9: – New York ; Richard Kopcho*, Treasurer – California ; Adelaide P. Gomer*, +L10: Secretary ... +L11: * 【2†Exhibit 12; publish_date: none†www.uschamber.com】 DECLARATION OF TERRY +L12: PELSTER. I, Terry C. Pelster, declare as follows: 1. I am a current member of +L13: Defenders of Wildlife (“Defenders”) and have been a member. +L14: * 【3†2020 ANNUAL REPORT; publish_date: none†defenders.org】 Terry L. Root. +L15: Florida. BOARD. EMERITUS. Terry C. Pelster. New York. Alan Steinberg. Florida. +L16: SCIENCE. ADVISORS. Gabriela Chavarria,. Ph.D. Denver Museum of. +L17: * 【4†2019 Annual Report; publish_date: none†defenders.org】 BOARD EMERITUS. +L18: Terry C. Pelster. New York. Alan Steinberg. Florida. SCIENCE ADVISORS. Leah +L19: Gerber, Ph.D. Arizona State University. Kathayoon Khalil, Ph.D. +L20: * 【5†annual report; publish_date: none†defenders.org】 * Executive Committee +L21: member. BoARD EMERITuS. Terry C. Pelster. New York. Alan Steinberg. Florida. +L22: PRESIDENT. EMERITuS. Rodger Schlickeisen. Virginia. SCIENCE ... +L23: * 【6†BEING THERE FOR BELUGAS; publish_date: none†wendeenicole.com】 Mar 2, 2023 +L24: — Defenders of Wildlife is a national, nonprofit member- ship ... Terry C. +L25: Pelster, Esq. New York. *Executive Committee member outfront. +L26: * 【7†Annual Reports; publish_date: none†www.annualreports.com】 WAYS TO GIVE +L27: There are many ways to support Defenders of Wildlife and protect our imperiled +L28: wildlife. ... Terry C. Pelster. New York. Alan Steinberg. Florida. +L29: * 【8†2014 Annual Report; publish_date: none†defenders.org】 BOARD OF DIRECTORS +L30: defenders of wildlife is a national, nonprofit ... BOARD EMERITUS. Terry C. +L31: Pelster. New York. Alan Steinberg. Florida. Page 3. PO. + + +Thus Terry C. Pelster is indeed the emeritus board member. + +Now the question might want just "Terry C. Pelster". Or maybe "Terry Pelster". Or "Terry C. Pelster, Esq." but instructions say "Name only, no titles." So likely "Terry C. Pelster". However, maybe they want "Terry Pelster". Better to include full name as listed: "Terry C. Pelster". + +Let's read the question again: "Who was the emeritus board member that year? Name only, no titles." So we should answer just the name, no "Esq." or "Board Emeritus". So "Terry C. Pelster". Or maybe "Terry Pelster". But the document shows "Terry C. Pelster, Esq." So the name is "Terry C. Pelster". + +Better to verify that the 2022 annual report indeed lists "BOARD EMERITUS Terry C. Pelster, Esq." We saw that line. + +Thus answer: Terry C. Pelster. + +But let's double-check that the question might refer to the 2021 annual report. The phrase "that year" could be ambiguous. Let's parse the timeline: + +- Year 1: 2021. Press release in Dec 2021 about Kenya vaccine doses. Same year (2021) annual report of Defenders of Wildlife shows 12% contributed services. + +- Year 2: 2022. Annual report mentions rescinded plan and contributed services decreased to 10%. + +The question says "The following year, the annual report stated that the organization had played a role in rescinding a previous administration’s plan ... and that contributed services had decreased by 2% to 10%." This is year 2. + +Then asks "Who was the emeritus board member that year?" That likely refers to year 2. + +Thus answer: Terry C. Pelster. + +However, perhaps the emeritus board member changed between 2021 and 2022. We need to check 2022 annual report for any other emeritus board member. We saw only Terry C. Pelster. Also check 2021 annual report: same. So likely unchanged. + +Thus answer is Terry C. Pelster. + +Nevertheless, we should confirm that the 2022 annual report is indeed the one that includes the phrase "BLM rescinded the previous administration’s plan that opened most of the 22 million acres of the National Petroleum Reserve in Alaska to oil and gas drilling". We saw that in line 200-202. + +Thus answer is Terry C. Pelster. + +Now we need to provide the answer in the final message. Only the name. + +Thus final answer: Terry C. Pelster. + +But let's be thorough and search for "Defenders of Wildlife 2022 annual report contributed services 10%". + +<|DSML|function_calls> +<|DSML|invoke name="find"> +<|DSML|parameter name="cursor" string="false">8 +<|DSML|parameter name="pattern" string="true">Contributed services + +<|end▁of▁sentence|> + + +[16] Find results for text: `contributed services` in `www.annualreports.com` (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2022.pdf/find?pattern=contributed services) +**viewing lines [0 - 10] of 10** + +L0: # 【0†match at L222】 +L1: income $744 Contributed services $5,047 Total Revenue $41,170 USES OF FUNDS +L2: Biodiversity conservation $18,970 Constituency mobilization $16,528 Fundraising +L3: $1,029 Management and general $6,354 Total Expenses $42,881 Change in net assets +L4: (-$1,711) Net assets, start of year $45,144 Net Assets, End of the Year $43,433 +L5: +L6: # 【1†match at L227】 +L7: Contributed services 10% Royalties and other income 2% Bequests, trusts and +L8: split interests 29% TOTAL REVENUE Income from investments -15% Program and +L9: support services 98% USE OF FUNDS Fundraising 2% ===== Page 13 ===== +L10: DEFENDERS.ORG 13 ===== Page 14 ===== BOARD OF DIRECTORS John Dayton, Esq.* Chair + + +Thus confirmed. + +Now we should verify that the 2021 annual report shows contributed services 12%. + +<|DSML|function_calls> +<|DSML|invoke name="find"> +<|DSML|parameter name="cursor" string="false">4 +<|DSML|parameter name="pattern" string="true">Contributed services + +<|end▁of▁sentence|> + + +[17] Find results for text: `contributed services` in `www.annualreports.com` (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2021.pdf/find?pattern=contributed services) +**viewing lines [0 - 4] of 4** + +L0: # 【0†match at L194】 +L1: $3,354 Royalties and other income $3,576 Contributed services $6,140 Total +L2: Revenue $49,819 USES OF FUNDS Biodiversity conservation $22,420 Constituency +L3: mobilization $16,324 Fundraising $1,211 Management and general $5,865 Total +L4: Expenses $45,820 Change in net assets $3,999 Net assets, start of year $41,145 + + +Open around L190-200 to see percentage. + +<|DSML|function_calls> +<|DSML|invoke name="open"> +<|DSML|parameter name="cursor" string="false">4 +<|DSML|parameter name="loc" string="false">185 + +<|end▁of▁sentence|> + + +[18] www.annualreports.com (https://www.annualreports.com/HostedData/AnnualReportArchive/d/defenders-wildlife_2021.pdf) +**viewing lines [185 - 232] of 261** + +L185: devastating water pollution that is killing the seagrass and causing manatees to +L186: starve. DAVID TES | SAM FRENZY DRAW DEFENDERS.ORG 11 ===== Page 12 ===== In +L187: 2021, Defenders of Wildlife’s nearly 2.2 million members and supporters +L188: contributed more than $49 million for wildlife and wildlife habitat. This figure +L189: includes all those who generously donated their time and expertise. The +L190: steadfast support of our donors allows Defenders to sustain our program and +L191: public education efforts in the field, the courts and on Capitol Hill. 2021 +L192: SOURCES OF FUNDS Grants and contributions $29,057 Bequests, trusts and split +L193: interests $7,692 Income from investments, annuity reserve funds and trusts +L194: $3,354 Royalties and other income $3,576 Contributed services $6,140 Total +L195: Revenue $49,819 USES OF FUNDS Biodiversity conservation $22,420 Constituency +L196: mobilization $16,324 Fundraising $1,211 Management and general $5,865 Total +L197: Expenses $45,820 Change in net assets $3,999 Net assets, start of year $41,145 +L198: Net Assets, End of the Year $45,144 Dollars are in thousands. 12 DEFENDERS OF +L199: WILDLIFE Grants and contributions 58% Income from investments 7% Requests, +L200: trusts and split interests 15% Royalties and other income 7% Contributed +L201: services 12% Program and support services 97% Fundraising 3% ===== Page 13 ===== +L202: DEFENDERS.ORG 13 ===== Page 14 ===== John Dayton, Esq.* Chair Texas Sheryl +L203: Sachs* Vice Chair Connecticut Whitfield Bailey* Secretary Tennessee Mark Caylor* +L204: Treasurer District of Columbia Sylvia Baca* California Donald Barry, Esq. +L205: Oregon Dinah Bear, Esq. Arizona Kim O’Keefe Beck, Ph.D California Gabriela +L206: Chavarria, Ph.D Washington Jamie Rappaport Clark* President & CEO Virginia +L207: Caroline D. Gabel Maryland Keith M. Harper, Esq. Maryland John Donovan, Esq. +L208: Montana and California James M. Hecker, Esq. Virginia Judith Ponsikoff, Ph.D* +L209: California Terry L. Root, Ph.D Florida BOARD EMERITUS Terry C. Pelster, Esq. New +L210: York *Executive Committee member* ST A F F & A D V I S O R S SCIENCE ADVISORS +L211: Gabriela Chavarria, Ph.D Burke Museum, University of Washington Leah Gerber, +L212: Ph.D Arizona State University Kathayoon Khalil, Ph.D Oregon Zoo Rurik List, Ph.D +L213: Universidad Autónoma Metropolitana-Lerma Barry Noon, Ph.D Colorado State +L214: University Terry L. Root, Ph.D Stanford University Mark Schwartz, Ph.D +L215: University of California-Davis David S. Wilcox, Ph.D Princeton University Rae +L216: Wynn-Grant, Ph.D National Geographic Society SENIOR STAFF Jamie Rappaport Clark +L217: President & CEO Thu Pham Chief of Staff Renee Stone Senior Vice President, +L218: Conservation Programs & General Counsel James Stefan Senior Vice President, +L219: Operations Matt Braughler Vice President, Philanthropy Rachel Brittin Vice +L220: President, Communications Shawn Cantrell Vice President, Field Conservation +L221: Programs Robert L. Dewey Vice President, Government Relations & External Affairs +L222: Monica Goldberg Vice President, Landscape Conservation Sharon Morris-Smith Vice +L223: President of Human Resources Elizabeth Nichols Vice President, Finance/CFO +L224: Anthony Pegues Vice President, Diversity, Equity, Inclusion, Justice Lindsay +L225: Rosa, Ph.D Director, Center for Conservation Innovation Michael P. Senatore Vice +L226: President, Conservation Law Brandace Swimley Vice President, Integrated +L227: Marketing 14 DEFENDERS OF WILDLIFE ===== Page 15 ===== THERE ARE MANY WAYS TO +L228: SUPPORT DEFENDERS OF WILDLIFE and protect our imperiled wildlife. To learn more +L229: or donate, contact us at 1-800-385-9712, visit defenders.org/donate or choose +L230: one of the following options. ➤ Transfer appreciated assets directly to +L231: Defenders and maximize potential tax benefits. Defenders.org/donate-stocks-help- +L232: save-wildlife ➤ Become a Leadership Council member with a single donation of + + +Thus everything matches. + +Thus answer: Terry C. Pelster. + +Now we need to output only the name. + +Thus final answer: Terry C. Pelster. + +But perhaps the question expects "Terry Pelster". However, the document includes middle initial "C." It might be better to include the full name as listed: "Terry C. Pelster". The instruction "Name only, no titles" suggests we should not include "Esq." or "Board Emeritus". But including middle initial is likely okay. + +To be safe, we could output "Terry C. Pelster". Let's see similar puzzles: often they expect the full name as appears in the document. + +Thus answer: Terry C. Pelster. + +Now we will provide final answer.Terry C. Pelster<|end▁of▁sentence|> \ No newline at end of file diff --git a/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_output_vllm_parity.txt b/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_output_vllm_parity.txt new file mode 100644 index 00000000000..f377fad6564 --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v32/fixtures/test_output_vllm_parity.txt @@ -0,0 +1,112 @@ +<|begin▁of▁sentence|> + +## Tools + +You have access to a set of tools you can use to answer the user's question. +You can invoke functions by writing a "<|DSML|function_calls>" block like the following as part of your reply to the user: +<|DSML|function_calls> +<|DSML|invoke name="$FUNCTION_NAME"> +<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<|DSML|invoke name="$FUNCTION_NAME2"> +... + + + +String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects). + +If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example: + +<|DSML|function_calls> +... + + + +... + + +...thinking about results + +Here are the functions available in JSONSchema format: + +{"name": "get_datetime", "description": "Get the current date and time", "parameters": {"type": "object", "properties": {"timezone": {"type": "string", "description": "The timezone, e.g. Asia/Shanghai, UTC"}}, "required": ["timezone"]}} +{"name": "get_weather", "description": "Get the weather for a specific date and location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city name, e.g. Beijing, Hangzhou"}, "date": {"type": "string", "description": "The date in YYYY-MM-DD format"}}, "required": ["location", "date"]}} + +You are a helpful Assistant.<|User|>明天杭州和北京的天气怎么样?<|Assistant|> + +<|DSML|function_calls> +<|DSML|invoke name="get_datetime"> +<|DSML|parameter name="timezone" string="true">Asia/Shanghai + +<|end▁of▁sentence|> + + +{"current_date": "2024-01-15", "current_time": "14:30:00", "timezone": "Asia/Shanghai"} + + + + +<|DSML|function_calls> +<|DSML|invoke name="get_weather"> +<|DSML|parameter name="location" string="true">Hangzhou +<|DSML|parameter name="date" string="true">2024-01-16 + +<|DSML|invoke name="get_weather"> +<|DSML|parameter name="location" string="true">Beijing +<|DSML|parameter name="date" string="true">2024-01-16 + +<|end▁of▁sentence|> + + +{"location": "Hangzhou", "date": "2024-01-16", "temperature_high": "12", "temperature_low": "5", "weather": "多云", "humidity": "65%"} +{"location": "Beijing", "date": "2024-01-16", "temperature_high": "-2", "temperature_low": "-8", "weather": "晴", "humidity": "30%"} + + +根据查询结果,明天(2024年1月16日)的天气情况如下: + +**杭州**: +- 天气:多云 +- 最高温度:12°C +- 最低温度:5°C +- 湿度:65% + +**北京**: +- 天气:晴 +- 最高温度:-2°C +- 最低温度:-8°C +- 湿度:30% + +杭州明天会比较温暖但有些多云,而北京会很冷但是晴天。建议在北京的朋友要注意保暖!<|end▁of▁sentence|><|User|>谢谢!那后天呢?<|Assistant|>用户现在问后天的天气,后天是2024-01-17,我可以直接查询(因为已知今天日期)🗓️ + +<|DSML|function_calls> +<|DSML|invoke name="get_weather"> +<|DSML|parameter name="location" string="true">Hangzhou +<|DSML|parameter name="date" string="true">2024-01-17 + +<|DSML|invoke name="get_weather"> +<|DSML|parameter name="location" string="true">Beijing +<|DSML|parameter name="date" string="true">2024-01-17 + +<|end▁of▁sentence|> + + +{"location": "Hangzhou", "date": "2024-01-17", "temperature_high": "15", "temperature_low": "8", "weather": "小雨", "humidity": "80%"} +{"location": "Beijing", "date": "2024-01-17", "temperature_high": "0", "temperature_low": "-6", "weather": "多云", "humidity": "45%"} + + +获取到后天的天气数据,整理回复给用户📝后天(2024年1月17日)的天气情况: + +**杭州**: +- 天气:小雨 +- 最高温度:15°C +- 最低温度:8°C +- 湿度:80% + +**北京**: +- 天气:多云 +- 最高温度:0°C +- 最低温度:-6°C +- 湿度:45% + +杭州后天会有小雨,温度略有回升,记得带伞。北京会稍微暖和一点,但依然很冷,请继续做好保暖措施。<|end▁of▁sentence|> \ No newline at end of file diff --git a/rust/src/chat/src/renderer/deepseek_v32/mod.rs b/rust/src/chat/src/renderer/deepseek_v32/mod.rs new file mode 100644 index 00000000000..97225bbab09 --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v32/mod.rs @@ -0,0 +1,31 @@ +mod encoding; + +use vllm_text::Prompt; + +use super::{ChatRenderer, RenderedPrompt}; +use crate::Result; +use crate::request::ChatRequest; + +/// Dedicated DeepSeek V3.2 renderer. +#[derive(Debug, Clone, Copy, Default)] +pub struct DeepSeekV32ChatRenderer; + +impl DeepSeekV32ChatRenderer { + /// Create the dedicated DeepSeek V3.2 renderer. + pub fn new() -> Self { + Self + } +} + +impl ChatRenderer for DeepSeekV32ChatRenderer { + fn render(&self, request: &ChatRequest) -> Result { + request.validate()?; + + Ok(RenderedPrompt { + prompt: Prompt::Text(encoding::render_request(request)?), + }) + } +} + +#[cfg(test)] +mod tests; diff --git a/rust/src/chat/src/renderer/deepseek_v32/tests.rs b/rust/src/chat/src/renderer/deepseek_v32/tests.rs new file mode 100644 index 00000000000..0b8f2b09e11 --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v32/tests.rs @@ -0,0 +1,422 @@ +use std::fs; +use std::path::PathBuf; + +use expect_test::{ExpectFile, expect, expect_file}; +use serde::Deserialize; +use serde_json::{Value, json}; +use thiserror_ext::AsReport; + +use super::DeepSeekV32ChatRenderer; +use crate::error::Error; +use crate::event::{AssistantContentBlock, AssistantToolCall}; +use crate::request::{ + ChatContentPart, ChatMessage, ChatRequest, ChatTool, ChatToolChoice, GenerationPromptMode, +}; +use crate::{ChatRenderer, ChatRole}; + +#[derive(Debug, Deserialize)] +struct FixtureRequest { + #[serde(default)] + tools: Vec, + messages: Vec, +} + +#[derive(Debug, Deserialize)] +struct FixtureTool { + function: FixtureToolFunction, +} + +#[derive(Debug, Deserialize)] +struct FixtureToolFunction { + name: String, + description: Option, + parameters: Value, + #[serde(default)] + strict: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "role", rename_all = "snake_case")] +enum FixtureMessage { + System { + content: String, + }, + Developer { + content: String, + #[serde(default)] + tools: Vec, + }, + User { + content: String, + }, + Assistant { + #[serde(default)] + content: String, + #[serde(default)] + reasoning_content: String, + #[serde(default)] + tool_calls: Vec, + }, + Tool { + content: String, + #[serde(default)] + tool_call_id: Option, + }, +} + +#[derive(Debug, Deserialize)] +struct FixtureToolCall { + #[serde(default)] + id: Option, + function: FixtureToolCallFunction, +} + +#[derive(Debug, Deserialize)] +struct FixtureToolCallFunction { + name: String, + arguments: String, +} + +fn render_request(request: &ChatRequest) -> String { + DeepSeekV32ChatRenderer::new() + .render(request) + .unwrap() + .prompt + .into_text() + .expect("deepseek renderer should return text prompt") +} + +fn render_result(request: &ChatRequest) -> Result { + DeepSeekV32ChatRenderer::new().render(request).map(|rendered| { + rendered + .prompt + .into_text() + .expect("deepseek renderer should return text prompt") + }) +} + +fn thinking_request(messages: Vec) -> ChatRequest { + let mut request = ChatRequest { + request_id: "deepseek-v32-small-test".to_string(), + messages, + ..ChatRequest::for_test() + }; + if matches!( + request.messages.last().map(ChatMessage::role), + Some(ChatRole::Assistant) + ) { + request.chat_options.generation_prompt_mode = GenerationPromptMode::NoGenerationPrompt; + } + request + .chat_options + .template_kwargs + .insert("thinking".to_string(), Value::Bool(true)); + request +} + +fn fixture_request(input_name: &str) -> ChatRequest { + let fixture = fs::read_to_string(fixture_path(input_name)).unwrap(); + let fixture: FixtureRequest = serde_json::from_str(&fixture).unwrap(); + let mut request = ChatRequest { + request_id: "deepseek-v32-fixture".to_string(), + messages: fixture + .messages + .into_iter() + .enumerate() + .map(|(index, message)| match message { + FixtureMessage::System { content } => ChatMessage::system(content), + FixtureMessage::Developer { content, tools } => ChatMessage::developer( + content, + (!tools.is_empty()).then(|| to_chat_tools(&tools)), + ), + FixtureMessage::User { content } => ChatMessage::user(content), + FixtureMessage::Assistant { + content, + reasoning_content, + tool_calls, + } => { + let mut blocks = Vec::new(); + if !reasoning_content.is_empty() { + blocks.push(AssistantContentBlock::Reasoning { + text: reasoning_content, + }); + } + if !content.is_empty() { + blocks.push(AssistantContentBlock::Text { text: content }); + } + blocks.extend(tool_calls.into_iter().enumerate().map( + |(tool_index, tool_call)| { + AssistantContentBlock::ToolCall(AssistantToolCall { + id: tool_call.id.unwrap_or_else(|| { + format!("fixture-tool-call-{index}-{tool_index}") + }), + name: tool_call.function.name, + arguments: tool_call.function.arguments, + }) + }, + )); + ChatMessage::assistant_blocks(blocks) + } + FixtureMessage::Tool { + content, + tool_call_id, + } => ChatMessage::tool_response( + content, + tool_call_id.unwrap_or_else(|| format!("fixture-tool-response-{index}")), + ), + }) + .collect(), + tools: to_chat_tools(&fixture.tools), + tool_choice: if fixture.tools.is_empty() { + ChatToolChoice::None + } else { + ChatToolChoice::Auto + }, + ..ChatRequest::for_test() + }; + if matches!( + request.messages.last().map(ChatMessage::role), + Some(ChatRole::Assistant) + ) { + request.chat_options.generation_prompt_mode = GenerationPromptMode::NoGenerationPrompt; + } + request + .chat_options + .template_kwargs + .insert("thinking".to_string(), Value::Bool(true)); + request +} + +fn to_chat_tools(tools: &[FixtureTool]) -> Vec { + tools + .iter() + .map(|tool| ChatTool { + name: tool.function.name.clone(), + description: tool.function.description.clone(), + parameters: tool.function.parameters.clone(), + strict: tool.function.strict, + }) + .collect() +} + +fn fixture_path(name: &str) -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("src/renderer/deepseek_v32") + .join("fixtures") + .join(name) +} + +fn assert_fixture(input_name: &str, expected: ExpectFile) { + let request = fixture_request(input_name); + let rendered = render_request(&request); + expected.assert_eq(&rendered); +} + +#[test] +fn renders_vllm_parity_prompt_for_request_level_tools_fixture() { + assert_fixture( + "test_input.json", + expect_file!["fixtures/test_output_vllm_parity.txt"], + ); +} + +#[test] +fn renders_official_search_fixture_without_date() { + assert_fixture( + "test_input_search_wo_date.json", + expect_file!["fixtures/test_output_search_wo_date.txt"], + ); +} + +#[test] +fn renders_official_search_fixture_with_date() { + assert_fixture( + "test_input_search_w_date.json", + expect_file!["fixtures/test_output_search_w_date.txt"], + ); +} + +#[test] +fn request_level_tools_are_lowered_as_synthetic_leading_system_message() { + let mut request = ChatRequest { + request_id: "deepseek-v32-tools".to_string(), + messages: vec![ + ChatMessage::system("System prompt."), + ChatMessage::text(ChatRole::User, "Hello"), + ], + tools: vec![ChatTool { + name: "lookup".to_string(), + description: Some("Look things up".to_string()), + parameters: json!({ + "type": "object", + "properties": { + "query": { + "type": "string" + } + }, + "required": ["query"] + }), + strict: None, + }], + tool_choice: ChatToolChoice::Auto, + ..ChatRequest::for_test() + }; + request + .chat_options + .template_kwargs + .insert("thinking".to_string(), Value::Bool(true)); + + let rendered = render_request(&request); + + assert!(rendered.starts_with("<|begin▁of▁sentence|>\n\n## Tools\n")); + assert!(rendered.contains("
\nSystem prompt.")); + assert!(rendered.ends_with("<|User|>Hello<|Assistant|>")); +} + +#[test] +fn developer_turn_is_treated_as_last_user_like_turn() { + let request = thinking_request(vec![ChatMessage::developer("Follow policy.", None)]); + + let rendered = render_request(&request); + + assert!(rendered.contains("# The user's message is: Follow policy.")); + assert!(rendered.ends_with("<|Assistant|>")); +} + +#[test] +fn historical_assistant_reasoning_is_dropped_before_final_user_turn() { + let request = thinking_request(vec![ + ChatMessage::assistant_blocks(vec![ + AssistantContentBlock::Reasoning { + text: "internal reasoning".to_string(), + }, + AssistantContentBlock::Text { + text: "Visible answer.".to_string(), + }, + ]), + ChatMessage::user("What about the next one?"), + ]); + + let rendered = render_request(&request); + + assert!(!rendered.contains("internal reasoning")); + assert!(rendered.contains("Visible answer.<|end▁of▁sentence|>")); + assert!(rendered.ends_with("<|User|>What about the next one?<|Assistant|>")); +} + +#[test] +fn historical_assistant_reasoning_is_dropped_before_final_developer_turn() { + let request = thinking_request(vec![ + ChatMessage::assistant_blocks(vec![ + AssistantContentBlock::Reasoning { + text: "internal reasoning".to_string(), + }, + AssistantContentBlock::Text { + text: "Visible answer.".to_string(), + }, + ]), + ChatMessage::developer("Follow the rubric.", None), + ]); + + let rendered = render_request(&request); + + assert!(!rendered.contains("internal reasoning")); + assert!(rendered.contains("Visible answer.<|end▁of▁sentence|>")); + assert!(rendered.ends_with( + "<|User|>\n\n# The user's message is: Follow the rubric.<|Assistant|>" + )); +} + +#[test] +fn tool_results_after_last_user_resume_thinking() { + let request = thinking_request(vec![ + ChatMessage::user("Check the weather."), + ChatMessage::assistant_blocks(vec![AssistantContentBlock::ToolCall(AssistantToolCall { + id: "call-weather".to_string(), + name: "weather".to_string(), + arguments: "{\"city\":\"Hangzhou\"}".to_string(), + })]), + ChatMessage::tool_response("{\"ok\":true}", "call-weather"), + ]); + + let rendered = render_request(&request); + + assert!(rendered.contains( + "<|User|>Check the weather.<|Assistant|>\n\n<|DSML|function_calls>" + )); + assert!(rendered.ends_with("
\n\n")); +} + +#[test] +fn tool_results_follow_assistant_tool_call_id_order() { + let request = thinking_request(vec![ + ChatMessage::user("Check two cities."), + ChatMessage::assistant_blocks(vec![ + AssistantContentBlock::ToolCall(AssistantToolCall { + id: "call-hangzhou".to_string(), + name: "weather".to_string(), + arguments: "{\"city\":\"Hangzhou\"}".to_string(), + }), + AssistantContentBlock::ToolCall(AssistantToolCall { + id: "call-beijing".to_string(), + name: "weather".to_string(), + arguments: "{\"city\":\"Beijing\"}".to_string(), + }), + ]), + ChatMessage::tool_response("{\"city\":\"Beijing\"}", "call-beijing"), + ChatMessage::tool_response("{\"city\":\"Hangzhou\"}", "call-hangzhou"), + ]); + + let rendered = render_request(&request); + + assert!(rendered.contains( + "\n{\"city\":\"Hangzhou\"}\n{\"city\":\"Beijing\"}\n" + )); +} + +#[test] +fn tool_results_require_matching_tool_call_ids() { + let request = thinking_request(vec![ + ChatMessage::user("Check the weather."), + ChatMessage::assistant_blocks(vec![AssistantContentBlock::ToolCall(AssistantToolCall { + id: "call-weather".to_string(), + name: "weather".to_string(), + arguments: "{\"city\":\"Hangzhou\"}".to_string(), + })]), + ChatMessage::tool_response("{\"ok\":true}", "call-unknown"), + ]); + + let error = render_result(&request).unwrap_err(); + + expect!["chat template error: invalid DeepSeek V3.2 tool message: unknown tool_call_id `call-unknown`"] + .assert_eq(&error.to_report_string()); +} + +#[test] +fn assistant_after_last_user_requires_reasoning_or_tool_calls() { + let request = thinking_request(vec![ + ChatMessage::user("Hello"), + ChatMessage::assistant_text("Hi there."), + ]); + + let error = render_result(&request).unwrap_err(); + + expect!["chat template error: invalid DeepSeek V3.2 assistant message after last user message: expected reasoning or tool calls"] + .assert_eq(&error.to_report_string()); +} +#[test] +fn render_rejects_multimodal_input() { + let request = ChatRequest { + messages: vec![ChatMessage::user(vec![ChatContentPart::image_url( + "data:image/png;base64,test", + )])], + ..ChatRequest::for_test() + }; + + let error = DeepSeekV32ChatRenderer::new().render(&request).unwrap_err(); + + assert!(matches!( + error, + Error::UnsupportedMultimodalContent("image_url") + )); +} diff --git a/rust/src/chat/src/renderer/deepseek_v4/encoding.rs b/rust/src/chat/src/renderer/deepseek_v4/encoding.rs new file mode 100644 index 00000000000..54a69248618 --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v4/encoding.rs @@ -0,0 +1,558 @@ +//! DeepSeek V4 prompt renderer. +//! +//! Original Python implementation: +//! + +use std::collections::HashMap; +use std::fmt::Write as _; + +use serde::Serialize; +use serde_json::Value; +use serde_json_fmt::JsonFormat; + +use crate::error::{Error, Result}; +use crate::request::{ChatContent, ChatMessage, ChatRequest, ChatTool, ReasoningEffort}; +use crate::{AssistantContentBlock, AssistantMessageExt, AssistantToolCall}; + +const BOS_TOKEN: &str = "<|begin▁of▁sentence|>"; +const EOS_TOKEN: &str = "<|end▁of▁sentence|>"; +const THINKING_START_TOKEN: &str = ""; +const THINKING_END_TOKEN: &str = ""; +const DSML_TOKEN: &str = "|DSML|"; +const USER_SP_TOKEN: &str = "<|User|>"; +const ASSISTANT_SP_TOKEN: &str = "<|Assistant|>"; +const REASONING_EFFORT_MAX: &str = concat!( + "Reasoning Effort: Absolute maximum with no shortcuts permitted.\n", + "You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n", + "Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n", +); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ThinkingMode { + Chat, + Thinking, +} + +#[serde_with::skip_serializing_none] +#[derive(Debug, Serialize)] +struct RenderedToolSchema<'a> { + name: &'a str, + description: Option<&'a str>, + parameters: &'a Value, + strict: Option, +} + +/// Render one chat request into the final prompt string. +pub(super) fn render_request(request: &ChatRequest) -> Result { + let (thinking_mode, max_reasoning_effort) = resolve_thinking_options(request)?; + let request_tools = request_tools(request); + let synthetic_tool_system = needs_synthetic_tool_system(request, request_tools); + let drop_thinking = request.parse_template_bool("drop_thinking")?.unwrap_or(true) + && !rendered_tools_present(request, request_tools); + let last_user_render_index = + find_last_user_render_index(request.messages.as_slice(), synthetic_tool_system); + let mut out = String::from(BOS_TOKEN); + if thinking_mode == ThinkingMode::Thinking && max_reasoning_effort { + out.push_str(REASONING_EFFORT_MAX); + } + + let mut request_tools_attached = false; + let mut render_index = 0isize; + if synthetic_tool_system { + render_system_message(&mut out, None, request_tools)?; + request_tools_attached = true; + render_index += 1; + } + + for (message_index, message) in request.messages.iter().enumerate() { + if is_following_tool_response(request.messages.as_slice(), message_index) { + continue; + } + + let current_render_index = render_index; + render_index += 1; + + match message { + ChatMessage::System { content } => { + let tools = if !request_tools_attached { + request_tools_attached = true; + request_tools + } else { + &[] + }; + render_system_message(&mut out, Some(content), tools)?; + } + ChatMessage::Developer { content, tools } => { + render_developer_message(&mut out, content, tools.as_deref().unwrap_or(&[]))?; + } + ChatMessage::User { content } => render_user_message(&mut out, content)?, + ChatMessage::Assistant { content } => { + // Mirror Python: thinking block (reasoning + ) is + // emitted whenever thinking is active and reasoning isn't + // dropped - i.e. drop_thinking is off OR this turn lies + // strictly after the last user turn. + let emit_thinking_block = thinking_mode == ThinkingMode::Thinking + && (!drop_thinking || current_render_index > last_user_render_index); + let append_eos = !(message_index + 1 == request.messages.len() + && request.chat_options.continue_final_message()); + render_assistant_message(&mut out, emit_thinking_block, append_eos, content)?; + } + ChatMessage::ToolResponse { .. } => { + render_tool_response_block(&mut out, request.messages.as_slice(), message_index)?; + } + } + + if is_user_like_entry(message) + && next_rendered_entry_is_assistant_or_end(request.messages.as_slice(), message_index) + { + write_assistant_transition( + &mut out, + thinking_mode, + drop_thinking, + current_render_index >= last_user_render_index, + ); + } + } + + Ok(out) +} + +/// Resolve DeepSeek V4's thinking controls. Unlike the Python tokenizer +/// wrapper, the Rust renderer only consumes the typed top-level +/// `reasoning_effort`; the generic template-kwargs map is left for HF +/// templates. +fn resolve_thinking_options(request: &ChatRequest) -> Result<(ThinkingMode, bool)> { + let mut thinking_mode = match request.enable_thinking()?.unwrap_or(false) { + true => ThinkingMode::Thinking, + false => ThinkingMode::Chat, + }; + let mut max_reasoning_effort = false; + + match request.chat_options.reasoning_effort { + Some(ReasoningEffort::None) => thinking_mode = ThinkingMode::Chat, + Some(ReasoningEffort::Max | ReasoningEffort::XHigh) => max_reasoning_effort = true, + Some(_) | None => {} + } + + Ok((thinking_mode, max_reasoning_effort)) +} + +/// Return request-level tools only when native tool parsing is enabled. +fn request_tools(request: &ChatRequest) -> &[ChatTool] { + if request.tool_parsing_enabled() { + request.tools.as_slice() + } else { + &[] + } +} + +/// Return whether request tools need a synthetic leading system entry. +fn needs_synthetic_tool_system(request: &ChatRequest, request_tools: &[ChatTool]) -> bool { + !request_tools.is_empty() + && !request + .messages + .iter() + .any(|message| matches!(message, ChatMessage::System { .. })) +} + +/// Return whether any rendered message carries tool schemas. +fn rendered_tools_present(request: &ChatRequest, request_tools: &[ChatTool]) -> bool { + !request_tools.is_empty() + || request.messages.iter().any(|message| { + matches!( + message, + ChatMessage::Developer { + tools: Some(tools), + .. + } if !tools.is_empty() + ) + }) +} + +/// Find the last user-like turn after inline tool-response merging. +fn find_last_user_render_index(messages: &[ChatMessage], synthetic_tool_system: bool) -> isize { + let mut render_index = isize::from(synthetic_tool_system); + let mut last_user_index = -1; + + for (message_index, message) in messages.iter().enumerate() { + if is_following_tool_response(messages, message_index) { + continue; + } + + if is_user_like_entry(message) { + last_user_index = render_index; + } + render_index += 1; + } + + last_user_index +} + +/// Return whether this tool message is already covered by a previous tool run. +fn is_following_tool_response(messages: &[ChatMessage], message_index: usize) -> bool { + matches!(messages[message_index], ChatMessage::ToolResponse { .. }) + && message_index > 0 + && matches!( + messages[message_index - 1], + ChatMessage::ToolResponse { .. } + ) +} + +/// Return whether one rendered entry should be treated as user-like. +fn is_user_like_entry(message: &ChatMessage) -> bool { + matches!( + message, + ChatMessage::Developer { .. } | ChatMessage::User { .. } | ChatMessage::ToolResponse { .. } + ) +} + +/// Return whether the next rendered entry is assistant, or there is no next +/// entry. +fn next_rendered_entry_is_assistant_or_end(messages: &[ChatMessage], message_index: usize) -> bool { + let mut next_index = message_index + 1; + if matches!(messages[message_index], ChatMessage::ToolResponse { .. }) { + while next_index < messages.len() + && matches!(messages[next_index], ChatMessage::ToolResponse { .. }) + { + next_index += 1; + } + } + + messages + .get(next_index) + .map(|message| matches!(message, ChatMessage::Assistant { .. })) + .unwrap_or(true) +} + +/// Render the tool preamble shown to the model, V4 flavor. +fn render_tools(out: &mut String, tools: &[ChatTool]) -> Result<()> { + out.push_str( + r#"## Tools + +You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following: + +<|DSML|tool_calls> +<|DSML|invoke name="$TOOL_NAME"> +<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<|DSML|invoke name="$TOOL_NAME2"> +... + + + +String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`. + +If thinking_mode is enabled (triggered by ), you MUST output your complete reasoning inside ... BEFORE any tool calls or final response. + +Otherwise, output directly after with tool calls or final response. + +### Available Tool Schemas + +"#, + ); + + for (index, tool) in tools.iter().enumerate() { + if index > 0 { + out.push('\n'); + } + render_tool_schema(out, tool)?; + } + + out.push_str( + "\n\nYou MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.\n", + ); + Ok(()) +} + +/// Serialize one typed tool schema into the JSON shape embedded in the prompt. +fn render_tool_schema(out: &mut String, tool: &ChatTool) -> Result<()> { + out.push_str(&json_dumps(&RenderedToolSchema { + name: &tool.name, + description: tool.description.as_deref(), + parameters: &tool.parameters, + strict: tool.strict, + })?); + Ok(()) +} + +/// Render a system turn, optionally followed by the V4 tool preamble. +fn render_system_message( + out: &mut String, + content: Option<&ChatContent>, + tools: &[ChatTool], +) -> Result<()> { + if let Some(content) = content { + write_chat_content(out, content)?; + } + if !tools.is_empty() { + out.push_str("\n\n"); + render_tools(out, tools)?; + } + Ok(()) +} + +/// Developer messages are rendered as user-like turns with optional tools. +fn render_developer_message( + out: &mut String, + content: &ChatContent, + tools: &[ChatTool], +) -> Result<()> { + if content.is_empty() { + return Err(Error::ChatTemplate( + "invalid DeepSeek V4 developer message: empty content".to_string(), + )); + } + + out.push_str(USER_SP_TOKEN); + write_chat_content(out, content)?; + if !tools.is_empty() { + out.push_str("\n\n"); + render_tools(out, tools)?; + } + Ok(()) +} + +/// Render one plain user turn. +fn render_user_message(out: &mut String, content: &ChatContent) -> Result<()> { + out.push_str(USER_SP_TOKEN); + write_chat_content(out, content)?; + Ok(()) +} + +/// Render a contiguous tool-response run as one synthetic user turn. +fn render_tool_response_block( + out: &mut String, + messages: &[ChatMessage], + message_index: usize, +) -> Result<()> { + let (block_start, block_end) = tool_response_block_bounds(messages, message_index); + let sorted_indices = sorted_tool_response_indices(messages, block_start, block_end); + + out.push_str(USER_SP_TOKEN); + for (offset, message_index) in sorted_indices.iter().enumerate() { + if offset > 0 { + out.push_str("\n\n"); + } + let ChatMessage::ToolResponse { content, .. } = &messages[*message_index] else { + unreachable!("tool response block should only contain tool messages"); + }; + write_tool_result(out, content)?; + } + + Ok(()) +} + +/// Return the contiguous tool-response block containing `actual_index`. +fn tool_response_block_bounds(messages: &[ChatMessage], actual_index: usize) -> (usize, usize) { + let mut block_start = actual_index; + while block_start > 0 && matches!(messages[block_start - 1], ChatMessage::ToolResponse { .. }) { + block_start -= 1; + } + + let mut block_end = actual_index + 1; + while block_end < messages.len() + && matches!(messages[block_end], ChatMessage::ToolResponse { .. }) + { + block_end += 1; + } + + (block_start, block_end) +} + +fn sorted_tool_response_indices( + messages: &[ChatMessage], + block_start: usize, + block_end: usize, +) -> Vec { + let Some(tool_call_order) = last_tool_call_order_before(messages, block_start) else { + return (block_start..block_end).collect(); + }; + + let mut indices = (block_start..block_end).collect::>(); + indices.sort_by_key(|index| { + let ChatMessage::ToolResponse { tool_call_id, .. } = &messages[*index] else { + unreachable!("tool response block should only contain tool messages"); + }; + tool_call_order.get(tool_call_id.as_str()).copied().unwrap_or(0) + }); + indices +} + +fn last_tool_call_order_before( + messages: &[ChatMessage], + message_index: usize, +) -> Option> { + let mut tool_call_order = None; + for message in &messages[..message_index] { + if let ChatMessage::Assistant { content } = message { + let order = content + .tool_calls() + .enumerate() + .map(|(index, tool_call)| (tool_call.id.as_str(), index)) + .collect::>(); + if !order.is_empty() { + tool_call_order = Some(order); + } + } + } + tool_call_order +} + +/// Render one tool response payload inside a V4 `` block. +fn write_tool_result(out: &mut String, content: &ChatContent) -> Result<()> { + out.push_str(""); + write_chat_content(out, content)?; + out.push_str(""); + Ok(()) +} + +/// Append the assistant transition token after a user-like turn. +fn write_assistant_transition( + out: &mut String, + thinking_mode: ThinkingMode, + drop_thinking: bool, + opens_thinking: bool, +) { + out.push_str(ASSISTANT_SP_TOKEN); + if thinking_mode == ThinkingMode::Thinking && (!drop_thinking || opens_thinking) { + out.push_str(THINKING_START_TOKEN); + } else { + out.push_str(THINKING_END_TOKEN); + } +} + +/// Render one assistant turn, including optional reasoning, DSML tool calls, +/// and the trailing EOS marker. +fn render_assistant_message( + out: &mut String, + emit_thinking_block: bool, + append_eos: bool, + content: &[AssistantContentBlock], +) -> Result<()> { + let has_tool_calls = content.has_tool_calls(); + + if emit_thinking_block { + if content.has_reasoning() { + write_assistant_reasoning(out, content); + } + out.push_str(THINKING_END_TOKEN); + } + + write_assistant_text(out, content); + + if has_tool_calls { + out.push_str("\n\n<|DSML|tool_calls>\n"); + for (index, tool_call) in content.tool_calls().enumerate() { + if index > 0 { + out.push('\n'); + } + render_tool_call(out, tool_call)?; + } + out.push_str("\n"); + } + + if append_eos { + out.push_str(EOS_TOKEN); + } + Ok(()) +} + +/// Render one assistant tool call in DSML XML-like format. +fn render_tool_call(out: &mut String, tool_call: &AssistantToolCall) -> Result<()> { + writeln!(out, "<{DSML_TOKEN}invoke name=\"{}\">", tool_call.name) + .expect("writing to String cannot fail"); + encode_arguments_to_dsml(out, tool_call)?; + write!(out, "\n").expect("writing to String cannot fail"); + Ok(()) +} + +/// Convert one assistant tool-call arguments object into DSML parameter form. +/// +/// String values are emitted raw with `string="true"`, while all other JSON +/// values are rendered with JSON syntax and `string="false"`. +fn encode_arguments_to_dsml(out: &mut String, tool_call: &AssistantToolCall) -> Result<()> { + let arguments: Value = serde_json::from_str(&tool_call.arguments).map_err(|error| { + Error::ChatTemplate(format!( + "assistant tool call has invalid JSON arguments for DeepSeek V4: {error}" + )) + })?; + let Some(arguments) = arguments.as_object() else { + return Err(Error::ChatTemplate( + "assistant tool call arguments for DeepSeek V4 must be a JSON object".to_string(), + )); + }; + + let mut wrote_parameter = false; + for (key, value) in arguments { + if wrote_parameter { + out.push('\n'); + } + + let is_string = matches!(value, Value::String(_)); + write!( + out, + "<{DSML_TOKEN}parameter name=\"{key}\" string=\"{}\">", + if is_string { "true" } else { "false" } + ) + .expect("writing to String cannot fail"); + + match value { + Value::String(value) => out.push_str(value), + value => out.push_str(&json_dumps(value)?), + } + + write!(out, "").expect("writing to String cannot fail"); + wrote_parameter = true; + } + + Ok(()) +} + +/// Write chat content directly into the destination buffer without flattening +/// it into an intermediate `String`. +fn write_chat_content(out: &mut String, content: &ChatContent) -> Result<()> { + match content { + ChatContent::Text(text) => out.push_str(text), + ChatContent::Parts(parts) => { + for part in parts { + out.push_str(part.as_text()?); + } + } + } + Ok(()) +} + +/// Write all reasoning blocks in encounter order. +fn write_assistant_reasoning(out: &mut String, content: &[AssistantContentBlock]) { + for block in content { + if let AssistantContentBlock::Reasoning { text } = block { + out.push_str(text); + } + } +} + +/// Write all visible assistant text blocks in encounter order. +fn write_assistant_text(out: &mut String, content: &[AssistantContentBlock]) { + for block in content { + if let AssistantContentBlock::Text { text } = block { + out.push_str(text); + } + } +} + +/// Compact JSON serialization used by this renderer for exact prompt text. +fn json_dumps(value: &T) -> Result { + JsonFormat::new() + .comma(", ") + .expect("literal comma separator is valid JSON") + .colon(": ") + .expect("literal colon separator is valid JSON") + .ascii(false) + .format_to_string(value) + .map_err(|error| { + Error::ChatTemplate(format!( + "failed to serialize DeepSeek V4 JSON payload: {error}" + )) + }) +} diff --git a/rust/src/chat/src/renderer/deepseek_v4/fixtures/test_input_1.json b/rust/src/chat/src/renderer/deepseek_v4/fixtures/test_input_1.json new file mode 100644 index 00000000000..d423b221fa8 --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v4/fixtures/test_input_1.json @@ -0,0 +1,81 @@ +{ + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a specific location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature unit" + } + }, + "required": ["location"] + } + } + }, + { + "type": "function", + "function": { + "name": "search", + "description": "Search the web for information", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query" + }, + "num_results": { + "type": "integer", + "description": "Number of results to return" + } + }, + "required": ["query"] + } + } + } + ], + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What's the weather in Beijing?" + }, + { + "role": "assistant", + "reasoning_content": "The user wants to know the weather in Beijing. I should use the get_weather tool.", + "tool_calls": [ + { + "id": "call_001", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\": \"Beijing\", \"unit\": \"celsius\"}" + } + } + ] + }, + { + "role": "tool", + "tool_call_id": "call_001", + "content": "{\"temperature\": 22, \"condition\": \"sunny\", \"humidity\": 45}" + }, + { + "role": "assistant", + "reasoning_content": "Got the weather data. Let me format a nice response.", + "content": "The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity." + } + ] +} diff --git a/rust/src/chat/src/renderer/deepseek_v4/fixtures/test_input_2.json b/rust/src/chat/src/renderer/deepseek_v4/fixtures/test_input_2.json new file mode 100644 index 00000000000..13b3454a902 --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v4/fixtures/test_input_2.json @@ -0,0 +1,24 @@ +[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "reasoning_content": "The user said hello, I should greet back.", + "content": "Hi there! How can I help you?" + }, + { + "role": "user", + "content": "What is the capital of France?" + }, + { + "role": "assistant", + "reasoning_content": "The user asks about the capital of France. It is Paris.", + "content": "The capital of France is Paris." + } +] \ No newline at end of file diff --git a/rust/src/chat/src/renderer/deepseek_v4/fixtures/test_output_1.txt b/rust/src/chat/src/renderer/deepseek_v4/fixtures/test_output_1.txt new file mode 100644 index 00000000000..7e3c9bd5a39 --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v4/fixtures/test_output_1.txt @@ -0,0 +1,36 @@ +<|begin▁of▁sentence|>You are a helpful assistant. + +## Tools + +You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<|DSML|tool_calls>" block like the following: + +<|DSML|tool_calls> +<|DSML|invoke name="$TOOL_NAME"> +<|DSML|parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<|DSML|invoke name="$TOOL_NAME2"> +... + + + +String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`. + +If thinking_mode is enabled (triggered by ), you MUST output your complete reasoning inside ... BEFORE any tool calls or final response. + +Otherwise, output directly after with tool calls or final response. + +### Available Tool Schemas + +{"name": "get_weather", "description": "Get the weather for a specific location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city name"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "Temperature unit"}}, "required": ["location"]}} +{"name": "search", "description": "Search the web for information", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}, "num_results": {"type": "integer", "description": "Number of results to return"}}, "required": ["query"]}} + +You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls. +<|User|>What's the weather in Beijing?<|Assistant|>The user wants to know the weather in Beijing. I should use the get_weather tool. + +<|DSML|tool_calls> +<|DSML|invoke name="get_weather"> +<|DSML|parameter name="location" string="true">Beijing +<|DSML|parameter name="unit" string="true">celsius + +<|end▁of▁sentence|><|User|>{"temperature": 22, "condition": "sunny", "humidity": 45}<|Assistant|>Got the weather data. Let me format a nice response.The weather in Beijing is currently sunny with a temperature of 22°C and 45% humidity.<|end▁of▁sentence|> \ No newline at end of file diff --git a/rust/src/chat/src/renderer/deepseek_v4/fixtures/test_output_2.txt b/rust/src/chat/src/renderer/deepseek_v4/fixtures/test_output_2.txt new file mode 100644 index 00000000000..fc397ef5497 --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v4/fixtures/test_output_2.txt @@ -0,0 +1 @@ +<|begin▁of▁sentence|>You are a helpful assistant.<|User|>Hello<|Assistant|>Hi there! How can I help you?<|end▁of▁sentence|><|User|>What is the capital of France?<|Assistant|>The user asks about the capital of France. It is Paris.The capital of France is Paris.<|end▁of▁sentence|> \ No newline at end of file diff --git a/rust/src/chat/src/renderer/deepseek_v4/mod.rs b/rust/src/chat/src/renderer/deepseek_v4/mod.rs new file mode 100644 index 00000000000..7c3f4631d20 --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v4/mod.rs @@ -0,0 +1,30 @@ +mod encoding; + +use vllm_text::Prompt; + +use super::{ChatRenderer, RenderedPrompt}; +use crate::Result; +use crate::request::ChatRequest; + +/// Dedicated DeepSeek V4 renderer. +#[derive(Debug, Clone, Copy, Default)] +pub struct DeepSeekV4ChatRenderer; + +impl DeepSeekV4ChatRenderer { + pub fn new() -> Self { + Self + } +} + +impl ChatRenderer for DeepSeekV4ChatRenderer { + fn render(&self, request: &ChatRequest) -> Result { + request.validate()?; + + Ok(RenderedPrompt { + prompt: Prompt::Text(encoding::render_request(request)?), + }) + } +} + +#[cfg(test)] +mod tests; diff --git a/rust/src/chat/src/renderer/deepseek_v4/tests.rs b/rust/src/chat/src/renderer/deepseek_v4/tests.rs new file mode 100644 index 00000000000..78936d8e68e --- /dev/null +++ b/rust/src/chat/src/renderer/deepseek_v4/tests.rs @@ -0,0 +1,369 @@ +use std::fs; +use std::path::PathBuf; + +use expect_test::{ExpectFile, expect, expect_file}; +use serde::Deserialize; +use serde_json::Value; + +use super::DeepSeekV4ChatRenderer; +use crate::event::{AssistantContentBlock, AssistantToolCall}; +use crate::request::{ + ChatMessage, ChatRequest, ChatTool, ChatToolChoice, GenerationPromptMode, ReasoningEffort, +}; +use crate::{ChatRenderer, ChatRole}; + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum FixtureFile { + WithTools(FixtureRequest), + MessagesOnly(Vec), +} + +#[derive(Debug, Deserialize)] +struct FixtureRequest { + #[serde(default)] + tools: Vec, + messages: Vec, +} + +impl FixtureFile { + fn into_parts(self) -> (Vec, Vec) { + match self { + Self::WithTools(req) => (req.tools, req.messages), + Self::MessagesOnly(messages) => (Vec::new(), messages), + } + } +} + +#[derive(Debug, Deserialize)] +struct FixtureTool { + function: FixtureToolFunction, +} + +#[derive(Debug, Deserialize)] +struct FixtureToolFunction { + name: String, + description: Option, + parameters: Value, + #[serde(default)] + strict: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "role", rename_all = "snake_case")] +enum FixtureMessage { + System { + content: String, + }, + Developer { + content: String, + #[serde(default)] + tools: Vec, + }, + User { + content: String, + }, + Assistant { + #[serde(default)] + content: String, + #[serde(default)] + reasoning_content: String, + #[serde(default)] + tool_calls: Vec, + }, + Tool { + content: String, + #[serde(default)] + tool_call_id: Option, + }, +} + +#[derive(Debug, Deserialize)] +struct FixtureToolCall { + #[serde(default)] + id: Option, + function: FixtureToolCallFunction, +} + +#[derive(Debug, Deserialize)] +struct FixtureToolCallFunction { + name: String, + arguments: String, +} + +fn render_request(request: &ChatRequest) -> String { + DeepSeekV4ChatRenderer::new() + .render(request) + .unwrap() + .prompt + .into_text() + .expect("deepseek v4 renderer should return text prompt") +} + +fn fixture_request(input_name: &str) -> ChatRequest { + let fixture = fs::read_to_string(fixture_path(input_name)).unwrap(); + let fixture: FixtureFile = serde_json::from_str(&fixture).unwrap(); + let (fixture_tools, fixture_messages) = fixture.into_parts(); + let mut request = ChatRequest { + request_id: "deepseek-v4-fixture".to_string(), + messages: fixture_messages + .into_iter() + .enumerate() + .map(|(index, message)| match message { + FixtureMessage::System { content } => ChatMessage::system(content), + FixtureMessage::Developer { content, tools } => ChatMessage::developer( + content, + (!tools.is_empty()).then(|| to_chat_tools(&tools)), + ), + FixtureMessage::User { content } => ChatMessage::user(content), + FixtureMessage::Assistant { + content, + reasoning_content, + tool_calls, + } => { + let mut blocks = Vec::new(); + if !reasoning_content.is_empty() { + blocks.push(AssistantContentBlock::Reasoning { + text: reasoning_content, + }); + } + if !content.is_empty() { + blocks.push(AssistantContentBlock::Text { text: content }); + } + blocks.extend(tool_calls.into_iter().enumerate().map( + |(tool_index, tool_call)| { + AssistantContentBlock::ToolCall(AssistantToolCall { + id: tool_call.id.unwrap_or_else(|| { + format!("fixture-tool-call-{index}-{tool_index}") + }), + name: tool_call.function.name, + arguments: tool_call.function.arguments, + }) + }, + )); + ChatMessage::assistant_blocks(blocks) + } + FixtureMessage::Tool { + content, + tool_call_id, + } => ChatMessage::tool_response( + content, + tool_call_id.unwrap_or_else(|| format!("fixture-tool-response-{index}")), + ), + }) + .collect(), + tools: to_chat_tools(&fixture_tools), + tool_choice: if fixture_tools.is_empty() { + ChatToolChoice::None + } else { + ChatToolChoice::Auto + }, + ..ChatRequest::for_test() + }; + if matches!( + request.messages.last().map(ChatMessage::role), + Some(ChatRole::Assistant) + ) { + request.chat_options.generation_prompt_mode = GenerationPromptMode::NoGenerationPrompt; + } + request + .chat_options + .template_kwargs + .insert("thinking".to_string(), Value::Bool(true)); + request +} + +fn to_chat_tools(tools: &[FixtureTool]) -> Vec { + tools + .iter() + .map(|tool| ChatTool { + name: tool.function.name.clone(), + description: tool.function.description.clone(), + parameters: tool.function.parameters.clone(), + strict: tool.function.strict, + }) + .collect() +} + +fn fixture_path(name: &str) -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("src/renderer/deepseek_v4") + .join("fixtures") + .join(name) +} + +fn assert_fixture(input_name: &str, expected: ExpectFile) { + let request = fixture_request(input_name); + let rendered = render_request(&request); + expected.assert_eq(&rendered); +} + +#[test] +fn renders_v4_fixture_1_tool_call_round_trip() { + assert_fixture( + "test_input_1.json", + expect_file!["fixtures/test_output_1.txt"], + ); +} + +#[test] +fn renders_v4_fixture_2_multi_turn_drop_thinking() { + assert_fixture( + "test_input_2.json", + expect_file!["fixtures/test_output_2.txt"], + ); +} + +#[test] +fn reasoning_effort_max_adds_prefix_when_thinking_is_enabled() { + let mut request = ChatRequest { + messages: vec![ChatMessage::user("solve it")], + ..ChatRequest::for_test() + }; + request + .chat_options + .template_kwargs + .insert("thinking".to_string(), Value::Bool(true)); + request.chat_options.reasoning_effort = Some(ReasoningEffort::Max); + + let rendered = render_request(&request); + + expect![[r#" + <|begin▁of▁sentence|>Reasoning Effort: Absolute maximum with no shortcuts permitted. + You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios. + Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked. + + <|User|>solve it<|Assistant|>"#]] + .assert_eq(&rendered); +} + +#[test] +fn reasoning_effort_none_disables_thinking() { + let mut request = ChatRequest { + messages: vec![ChatMessage::user("answer directly")], + ..ChatRequest::for_test() + }; + request + .chat_options + .template_kwargs + .insert("thinking".to_string(), Value::Bool(true)); + request.chat_options.reasoning_effort = Some(ReasoningEffort::None); + + let rendered = render_request(&request); + + expect!["<|begin▁of▁sentence|><|User|>answer directly<|Assistant|>"] + .assert_eq(&rendered); +} + +#[test] +fn reasoning_effort_template_kwarg_is_ignored() { + let mut request = ChatRequest { + messages: vec![ChatMessage::user("solve it")], + ..ChatRequest::for_test() + }; + request + .chat_options + .template_kwargs + .insert("thinking".to_string(), Value::Bool(true)); + request.chat_options.template_kwargs.insert( + "reasoning_effort".to_string(), + Value::String("max".to_string()), + ); + + let rendered = render_request(&request); + + expect!["<|begin▁of▁sentence|><|User|>solve it<|Assistant|>"].assert_eq(&rendered); +} + +#[test] +fn tool_results_are_sorted_by_previous_assistant_tool_call_order() { + let request = ChatRequest { + messages: vec![ + ChatMessage::assistant_blocks(vec![ + AssistantContentBlock::ToolCall(AssistantToolCall { + id: "second".to_string(), + name: "second_tool".to_string(), + arguments: "{}".to_string(), + }), + AssistantContentBlock::ToolCall(AssistantToolCall { + id: "first".to_string(), + name: "first_tool".to_string(), + arguments: "{}".to_string(), + }), + ]), + ChatMessage::tool_response("first result", "first"), + ChatMessage::tool_response("second result", "second"), + ], + ..ChatRequest::for_test() + }; + + let rendered = render_request(&request); + + expect![[r#" + <|begin▁of▁sentence|> + + <|DSML|tool_calls> + <|DSML|invoke name="second_tool"> + + + <|DSML|invoke name="first_tool"> + + + <|end▁of▁sentence|><|User|>second result + + first result<|Assistant|>"#]] + .assert_eq(&rendered); +} + +#[test] +fn drop_thinking_false_keeps_prior_assistant_reasoning() { + let mut request = ChatRequest { + messages: vec![ + ChatMessage::assistant_blocks(vec![ + AssistantContentBlock::Reasoning { + text: "old reasoning".to_string(), + }, + AssistantContentBlock::Text { + text: "old answer".to_string(), + }, + ]), + ChatMessage::user("next"), + ], + ..ChatRequest::for_test() + }; + request + .chat_options + .template_kwargs + .insert("thinking".to_string(), Value::Bool(true)); + request + .chat_options + .template_kwargs + .insert("drop_thinking".to_string(), Value::Bool(false)); + + let rendered = render_request(&request); + + expect!( + "<|begin▁of▁sentence|>old reasoningold answer<|end▁of▁sentence|><|User|>next<|Assistant|>" + ) + .assert_eq(&rendered); +} + +#[test] +fn continue_final_assistant_omits_final_eos() { + let request = ChatRequest { + messages: vec![ + ChatMessage::user("write"), + ChatMessage::assistant_text("partial answer"), + ], + chat_options: crate::request::ChatOptions { + generation_prompt_mode: GenerationPromptMode::ContinueFinalAssistant, + ..Default::default() + }, + ..ChatRequest::for_test() + }; + + let rendered = render_request(&request); + + expect!["<|begin▁of▁sentence|><|User|>write<|Assistant|>partial answer"] + .assert_eq(&rendered); +} diff --git a/rust/src/chat/src/renderer/hf/error.rs b/rust/src/chat/src/renderer/hf/error.rs new file mode 100644 index 00000000000..fcc48c75aba --- /dev/null +++ b/rust/src/chat/src/renderer/hf/error.rs @@ -0,0 +1,15 @@ +use thiserror::Error as ThisError; + +#[derive(Debug, ThisError)] +pub(crate) enum TemplateError { + #[error("failed to render jinja template")] + Jinja(#[from] minijinja::Error), + #[error("failed to read chat template file")] + ReadTemplateFile(#[source] std::io::Error), + #[error("chat template looks like a file path but does not exist")] + MissingTemplatePath, + #[error("failed to parse chat_template.json")] + ParseTemplateJson(#[source] serde_json::Error), + #[error("chat_template.json does not contain a valid template")] + InvalidTemplateJson, +} diff --git a/rust/src/chat/src/renderer/hf/format.rs b/rust/src/chat/src/renderer/hf/format.rs new file mode 100644 index 00000000000..a9b35d0f41c --- /dev/null +++ b/rust/src/chat/src/renderer/hf/format.rs @@ -0,0 +1,400 @@ +use std::collections::{HashSet, VecDeque}; +use std::fmt; +use std::str::FromStr; + +use minijinja::machinery::ast::{Expr, ForLoop, Set, Stmt}; +use minijinja::machinery::{WhitespaceConfig, parse}; +use minijinja::syntax::SyntaxConfig; +use serde_with::DeserializeFromStr; + +/// Chat template content format. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ChatTemplateContentFormat { + /// Content is a simple string. + #[default] + String, + /// Content is a list of structured parts (OpenAI format). + OpenAi, +} + +/// Configurable chat-template content format selection. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, DeserializeFromStr)] +pub enum ChatTemplateContentFormatOption { + /// Detect the format from the template source. + #[default] + Auto, + /// Always flatten content into plain strings before rendering. + String, + /// Always pass content through in OpenAI-compatible structured form. + OpenAi, +} + +impl ChatTemplateContentFormatOption { + pub const AUTO_LITERAL: &str = "auto"; + pub const OPENAI_LITERAL: &str = "openai"; + pub const STRING_LITERAL: &str = "string"; +} + +impl FromStr for ChatTemplateContentFormatOption { + type Err = String; + + fn from_str(value: &str) -> Result { + if value.eq_ignore_ascii_case(Self::AUTO_LITERAL) { + Ok(Self::Auto) + } else if value.eq_ignore_ascii_case(Self::STRING_LITERAL) { + Ok(Self::String) + } else if value.eq_ignore_ascii_case(Self::OPENAI_LITERAL) { + Ok(Self::OpenAi) + } else { + Err(format!( + "invalid content format `{value}`; expected one of: {}, {}, {}", + Self::AUTO_LITERAL, + Self::STRING_LITERAL, + Self::OPENAI_LITERAL + )) + } + } +} + +impl fmt::Display for ChatTemplateContentFormatOption { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Auto => f.write_str(Self::AUTO_LITERAL), + Self::String => f.write_str(Self::STRING_LITERAL), + Self::OpenAi => f.write_str(Self::OPENAI_LITERAL), + } + } +} + +fn is_var_access(expr: &Expr, varname: &str) -> bool { + matches!(expr, Expr::Var(v) if v.id == varname) +} + +fn is_const_str(expr: &Expr, value: &str) -> bool { + matches!(expr, Expr::Const(c) if c.value.as_str() == Some(value)) +} + +fn is_attr_access(expr: &Expr, varname: &str, key: &str) -> bool { + match expr { + Expr::GetItem(g) => is_var_access(&g.expr, varname) && is_const_str(&g.subscript_expr, key), + Expr::GetAttr(g) => is_var_access(&g.expr, varname) && g.name == key, + _ => false, + } +} + +fn is_var_or_elems_access(expr: &Expr, varname: &str, key: Option<&str>) -> bool { + match expr { + Expr::Filter(f) => { + f.expr.as_ref().is_some_and(|inner| is_var_or_elems_access(inner, varname, key)) + } + Expr::Test(t) => is_var_or_elems_access(&t.expr, varname, key), + Expr::Slice(s) => is_var_or_elems_access(&s.expr, varname, key), + _ => key.map_or_else( + || is_var_access(expr, varname), + |key| is_attr_access(expr, varname, key), + ), + } +} + +fn visit_stmt<'a>( + stmt: &'a Stmt<'a>, + assignments: &mut Vec<&'a Set<'a>>, + loops: &mut Vec<&'a ForLoop<'a>>, +) { + match stmt { + Stmt::Template(t) => { + for child in &t.children { + visit_stmt(child, assignments, loops); + } + } + Stmt::ForLoop(fl) => { + loops.push(fl); + for child in &fl.body { + visit_stmt(child, assignments, loops); + } + for child in &fl.else_body { + visit_stmt(child, assignments, loops); + } + } + Stmt::IfCond(ic) => { + for child in &ic.true_body { + visit_stmt(child, assignments, loops); + } + for child in &ic.false_body { + visit_stmt(child, assignments, loops); + } + } + Stmt::WithBlock(wb) => { + for child in &wb.body { + visit_stmt(child, assignments, loops); + } + } + Stmt::Set(set_stmt) => assignments.push(set_stmt), + Stmt::SetBlock(sb) => { + for child in &sb.body { + visit_stmt(child, assignments, loops); + } + } + Stmt::AutoEscape(ae) => { + for child in &ae.body { + visit_stmt(child, assignments, loops); + } + } + Stmt::FilterBlock(fb) => { + for child in &fb.body { + visit_stmt(child, assignments, loops); + } + } + Stmt::Block(b) => { + for child in &b.body { + visit_stmt(child, assignments, loops); + } + } + Stmt::Macro(m) => { + for child in &m.body { + visit_stmt(child, assignments, loops); + } + } + Stmt::CallBlock(cb) => { + for child in &cb.macro_decl.body { + visit_stmt(child, assignments, loops); + } + } + _ => {} + } +} + +fn collect_assignments_and_loops<'a>( + root: &'a Stmt<'a>, +) -> (Vec<&'a Set<'a>>, Vec<&'a ForLoop<'a>>) { + let mut assignments = Vec::new(); + let mut loops = Vec::new(); + visit_stmt(root, &mut assignments, &mut loops); + (assignments, loops) +} + +fn iter_nodes_assign_var_or_elems(root: &Stmt<'_>, varname: &str) -> Vec { + let (assignments, _) = collect_assignments_and_loops(root); + + let mut discovered = vec![varname.to_string()]; + let mut seen = HashSet::from([varname.to_string()]); + let mut related = VecDeque::from([varname.to_string()]); + + while let Some(related_varname) = related.pop_front() { + for assign in &assignments { + let Expr::Var(lhs) = &assign.target else { + continue; + }; + + if is_var_or_elems_access(&assign.expr, &related_varname, None) { + let lhs_name = lhs.id.to_string(); + if seen.insert(lhs_name.clone()) { + discovered.push(lhs_name.clone()); + if lhs_name != related_varname { + related.push_back(lhs_name); + } + } + } + } + } + + discovered +} + +fn iter_nodes_assign_messages_item(root: &Stmt<'_>) -> Vec { + let message_varnames = iter_nodes_assign_var_or_elems(root, "messages"); + let (_, loops) = collect_assignments_and_loops(root); + + let mut discovered = Vec::new(); + let mut seen = HashSet::new(); + + for loop_ast in loops { + let Expr::Var(target) = &loop_ast.target else { + continue; + }; + + if message_varnames + .iter() + .any(|varname| is_var_or_elems_access(&loop_ast.iter, varname, None)) + { + let target_name = target.id.to_string(); + if seen.insert(target_name.clone()) { + discovered.push(target_name); + } + } + } + + discovered +} + +fn has_content_item_loop(root: &Stmt<'_>) -> bool { + let message_varnames = iter_nodes_assign_messages_item(root); + let (_, loops) = collect_assignments_and_loops(root); + + loops.into_iter().any(|loop_ast| { + matches!(loop_ast.target, Expr::Var(_)) + && message_varnames + .iter() + .any(|varname| is_var_or_elems_access(&loop_ast.iter, varname, Some("content"))) + }) +} + +/// Detect the content format expected by a Jinja2 chat template based on AST +/// analysis. +pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateContentFormat { + let ast = match parse( + template, + "template", + SyntaxConfig {}, + WhitespaceConfig::default(), + ) { + Ok(ast) => ast, + Err(_) => return ChatTemplateContentFormat::String, + }; + + if has_content_item_loop(&ast) { + ChatTemplateContentFormat::OpenAi + } else { + ChatTemplateContentFormat::String + } +} + +#[cfg(test)] +mod tests { + use std::fs; + use std::path::{Path, PathBuf}; + + use expect_test::expect; + + use super::{ChatTemplateContentFormat, detect_chat_template_content_format}; + + fn detect(template: &str) -> ChatTemplateContentFormat { + detect_chat_template_content_format(template) + } + + fn vllm_examples_dir() -> PathBuf { + Path::new(env!("CARGO_MANIFEST_DIR")) + .join("tests/templates/vllm_examples") + .canonicalize() + .expect("vLLM example template directory should exist locally") + } + + fn read_vllm_example(relative_path: &str) -> String { + fs::read_to_string(vllm_examples_dir().join(relative_path)) + .unwrap_or_else(|_| panic!("failed to read vLLM example template: {relative_path}")) + } + + fn iter_vllm_example_template_paths() -> impl Iterator { + let mut paths = fs::read_dir(vllm_examples_dir()) + .expect("failed to read vLLM example template directory") + .map(|entry| entry.expect("failed to read vLLM example template dir entry").path()) + .filter(|path| path.extension().is_some_and(|ext| ext == "jinja")) + .collect::>(); + paths.sort(); + paths.into_iter() + } + + #[test] + fn detects_string_template_without_content_loop() { + assert_eq!( + detect("{% for message in messages %}{{ message.content }}{% endfor %}"), + ChatTemplateContentFormat::String + ); + } + + #[test] + fn detects_openai_template_with_direct_content_loop() { + assert_eq!( + detect( + "{% for message in messages %}{% for content in message['content'] %}{{ content }}{% endfor %}{% endfor %}" + ), + ChatTemplateContentFormat::OpenAi + ); + } + + #[test] + fn detects_openai_template_with_messages_alias() { + assert_eq!( + detect( + "{% set msgs = messages %}{% for message in msgs %}{% for content in message.content %}{{ content }}{% endfor %}{% endfor %}" + ), + ChatTemplateContentFormat::OpenAi + ); + } + + #[test] + fn does_not_detect_content_alias_loop_as_openai() { + assert_eq!( + detect( + "{% for message in messages %}{% set parts = message.content %}{% for item in parts %}{{ item }}{% endfor %}{% endfor %}" + ), + ChatTemplateContentFormat::String + ); + } + + #[test] + fn does_not_treat_length_or_index_access_as_openai() { + assert_eq!( + detect("{% for message in messages %}{{ message.content|length }}{% endfor %}"), + ChatTemplateContentFormat::String + ); + assert_eq!( + detect("{% for message in messages %}{{ message.content[0] }}{% endfor %}"), + ChatTemplateContentFormat::String + ); + } + + #[test] + fn matches_vllm_example_template_formats() { + let snapshot = iter_vllm_example_template_paths() + .map(|path| { + let file_name = path + .file_name() + .and_then(|name| name.to_str()) + .expect("template file name should be valid UTF-8"); + let template = read_vllm_example(file_name); + let format = detect(&template); + format!("{file_name:50} => {format:?}") + }) + .collect::>() + .join("\n"); + + expect![[r#" + template_alpaca.jinja => String + template_baichuan.jinja => String + template_chatglm.jinja => String + template_chatglm2.jinja => String + template_chatml.jinja => String + template_falcon.jinja => String + template_falcon_180b.jinja => String + template_inkbot.jinja => String + template_teleflm.jinja => String + tool_chat_template_deepseekr1.jinja => String + tool_chat_template_deepseekv3.jinja => String + tool_chat_template_deepseekv31.jinja => String + tool_chat_template_functiongemma.jinja => String + tool_chat_template_gemma3_pythonic.jinja => OpenAi + tool_chat_template_gemma4.jinja => OpenAi + tool_chat_template_glm4.jinja => String + tool_chat_template_granite.jinja => String + tool_chat_template_granite_20b_fc.jinja => String + tool_chat_template_hermes.jinja => String + tool_chat_template_hunyuan_a13b.jinja => String + tool_chat_template_internlm2_tool.jinja => String + tool_chat_template_llama3.1_json.jinja => OpenAi + tool_chat_template_llama3.2_json.jinja => OpenAi + tool_chat_template_llama3.2_pythonic.jinja => String + tool_chat_template_llama4_json.jinja => OpenAi + tool_chat_template_llama4_pythonic.jinja => OpenAi + tool_chat_template_minimax_m1.jinja => OpenAi + tool_chat_template_mistral.jinja => String + tool_chat_template_mistral3.jinja => OpenAi + tool_chat_template_mistral_parallel.jinja => String + tool_chat_template_phi4_mini.jinja => String + tool_chat_template_qwen3coder.jinja => String + tool_chat_template_toolace.jinja => String + tool_chat_template_xlam_llama.jinja => String + tool_chat_template_xlam_qwen.jinja => String"#]] + .assert_eq(&snapshot); + } +} diff --git a/rust/src/chat/src/renderer/hf/mod.rs b/rust/src/chat/src/renderer/hf/mod.rs new file mode 100644 index 00000000000..851df5068f4 --- /dev/null +++ b/rust/src/chat/src/renderer/hf/mod.rs @@ -0,0 +1,970 @@ +use std::collections::HashMap; + +use serde::Serialize; +use serde_json::Value; +use thiserror_ext::AsReport as _; +use tracing::{info, trace, warn}; +use vllm_text::Prompt; +use vllm_text::backend::hf::{ + HfSpecialTokens, HfTokenizerConfig, ResolvedModelFiles, load_tokenizer_config, +}; + +use self::format::{ + ChatTemplateContentFormat, ChatTemplateContentFormatOption as ContentFormatOption, +}; +use self::template::{CompiledChatTemplate, TemplateContext}; +use super::{ChatRenderer, RenderedPrompt}; +use crate::error::Result; +use crate::request::{ChatContent, ChatContentPart, ChatMessage, ChatRequest}; +use crate::{ + AssistantContentBlock, AssistantMessageExt, ChatTool, Error, LoadModelBackendsOptions, +}; + +mod error; +mod format; +mod template; +mod tojson; + +pub use template::{load_chat_template, resolve_chat_template}; + +pub use self::format::ChatTemplateContentFormatOption; + +#[derive(Debug, Clone)] +pub struct MultimodalRenderInfo { + pub placeholder_token: String, +} + +/// Hugging Face chat-template renderer backed by the local Jinja chat-template +/// state. +pub struct HfChatRenderer { + default_template: Option, + default_template_kwargs: HashMap, + content_format: ContentFormatOption, + special_tokens: Option, + multimodal: Option, +} + +impl HfChatRenderer { + /// Create a renderer from the given template string. + pub fn new( + template: Option, + default_template_kwargs: HashMap, + content_format: ContentFormatOption, + ) -> Result { + Ok(Self { + default_template: template + .map(|template| { + CompiledChatTemplate::new(template, content_format) + .map_err(|error| Error::ChatTemplate(error.to_report_string())) + }) + .transpose()?, + default_template_kwargs, + content_format, + special_tokens: None, + multimodal: None, + }) + } + + pub fn with_special_tokens(mut self, special_tokens: Option) -> Self { + self.special_tokens = special_tokens; + self + } + + pub fn with_multimodal(mut self, multimodal: Option) -> Self { + self.multimodal = multimodal; + self + } + + /// Create a renderer from the given model files and loading options. + pub fn load( + files: &ResolvedModelFiles, + options: LoadModelBackendsOptions, + multimodal: Option, + ) -> Result { + let HfTokenizerConfig { + special_tokens, + chat_template, + .. + } = load_tokenizer_config(files.tokenizer_config_path.as_deref())?; + let mut template = chat_template; + let special_tokens = (!special_tokens.is_empty()).then_some(special_tokens); + + if let Some(configured_template) = options.chat_template.as_deref() { + template = Some( + resolve_chat_template(configured_template) + .map_err(|error| Error::ChatTemplate(error.to_report_string()))?, + ); + info!("using configured chat template override"); + } else if let Some(chat_template_path) = files.chat_template_path.as_deref() { + // If independent chat template file(s) exist and contain non-empty content, + // they take priority over template entries in the tokenizer config + let file_template = load_chat_template(chat_template_path) + .map_err(|error| Error::ChatTemplate(error.to_report_string()))?; + + if file_template.as_ref().is_some_and(|t| !t.trim().is_empty()) { + info!( + path = %chat_template_path.display(), + "loaded dedicated chat template file, overriding tokenizer_config chat_template" + ); + template = file_template; + } else { + warn!( + path = %chat_template_path.display(), + "ignoring empty dedicated chat template file and falling back to tokenizer_config chat_template" + ); + } + } + + Ok(Self::new( + template, + options.default_chat_template_kwargs, + options.chat_template_content_format, + )? + .with_special_tokens(special_tokens) + .with_multimodal(multimodal)) + } + + /// Apply the chat template to one chat request, rendering the prompt string + /// to be tokenized and submitted to the model. + /// + /// If the request carries a per-request `chat_template` override, a + /// temporary template is compiled from that string and used instead of + /// the model's default. + fn apply_chat_template(&self, request: &ChatRequest) -> Result { + let override_template = request + .chat_options + .chat_template + .as_ref() + .map(|template| { + CompiledChatTemplate::new(template.clone(), self.content_format) + .map_err(|error| Error::ChatTemplate(error.to_report_string())) + }) + .transpose()?; + let template = override_template + .as_ref() + .or(self.default_template.as_ref()) + .ok_or(Error::MissingChatTemplate)?; + + self.apply_chat_template_inner(template, request) + } + + fn apply_chat_template_inner( + &self, + effective_template: &CompiledChatTemplate, + request: &ChatRequest, + ) -> Result { + let messages = to_template_messages( + &request.messages, + effective_template.content_format(), + self.multimodal.as_ref(), + )?; + let tools = request.tool_parsing_enabled().then(|| to_template_tools(&request.tools)); + trace!( + message_count = messages.len(), + content_format = ?effective_template.content_format(), + ?messages, + ?tools, + "applying chat template" + ); + + let mut merged_template_kwargs = self.default_template_kwargs.clone(); + merged_template_kwargs.extend(request.chat_options.template_kwargs.clone()); + let prompt = effective_template + .apply(TemplateContext { + messages: &messages, + add_generation_prompt: request.chat_options.add_generation_prompt(), + continue_final_message: request.chat_options.continue_final_message(), + tools: tools.as_deref(), + documents: request.documents.as_deref(), + template_kwargs: Some(&merged_template_kwargs), + special_tokens: self.special_tokens.as_ref(), + reasoning_effort: request.chat_options.reasoning_effort, + }) + .map_err(|error| Error::ChatTemplate(error.to_report_string()))?; + + trace!( + prompt_len = prompt.len(), + prompt, "rendered chat template prompt" + ); + + Ok(RenderedPrompt { + prompt: Prompt::Text(prompt), + }) + } +} + +impl ChatRenderer for HfChatRenderer { + fn render(&self, request: &ChatRequest) -> Result { + self.apply_chat_template(request) + } +} + +/// Chat message in the JSON shape expected by Jinja chat templates. +// TODO: borrow more fields directly from the original `ChatMessage`. +#[serde_with::skip_serializing_none] +#[derive(Debug, Serialize)] +struct TemplateMessage { + role: &'static str, + content: TemplateContent, + // Developer-role messages may provide message-local tools in the same shape + // as top-level request tools. + tools: Option>, + // Reasoning-capable HF templates are inconsistent on the exact field name, + // so expose both variants for compatibility. + reasoning: Option, + reasoning_content: Option, + // Function-call-capable templates commonly expect assistant tool calls + // under this OpenAI-compatible field name. + tool_calls: Option>, + // Tool-role messages refer back to the assistant call they are answering. + tool_call_id: Option, +} + +/// Chat content in the two shapes HF templates commonly expect. +#[derive(Debug, Serialize)] +#[serde(untagged)] +enum TemplateContent { + String(String), + OpenAi(Vec), +} + +#[derive(Debug, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum TemplateContentPart { + Text { text: String }, + Image, +} + +#[derive(Debug, Serialize)] +struct TemplateToolCall { + id: String, + r#type: &'static str, // always "function" + function: TemplateToolFunction, +} + +#[derive(Debug, Serialize)] +struct TemplateToolFunction { + name: String, + arguments: Value, +} + +#[derive(Debug, Serialize)] +pub(super) struct TemplateTool { + #[serde(rename = "type")] + tool_type: &'static str, + function: TemplateToolDefinition, +} + +#[derive(Debug, Serialize)] +struct TemplateToolDefinition { + name: String, + description: Option, + parameters: Value, + strict: Option, +} + +/// Convert chat messages into the JSON shape expected by Jinja chat templates. +fn to_template_messages( + messages: &[ChatMessage], + content_format: ChatTemplateContentFormat, + multimodal: Option<&MultimodalRenderInfo>, +) -> Result> { + messages + .iter() + .map(|message| to_template_message(message, content_format, multimodal)) + .collect() +} + +fn to_template_message( + message: &ChatMessage, + content_format: ChatTemplateContentFormat, + multimodal: Option<&MultimodalRenderInfo>, +) -> Result { + Ok(match message { + ChatMessage::System { content } => TemplateMessage { + role: "system", + content: to_template_content(content, content_format, multimodal)?, + tools: None, + reasoning: None, + reasoning_content: None, + tool_calls: None, + tool_call_id: None, + }, + ChatMessage::Developer { content, tools } => TemplateMessage { + role: "developer", + content: to_template_content(content, content_format, multimodal)?, + tools: tools.as_deref().map(to_template_tools), + reasoning: None, + reasoning_content: None, + tool_calls: None, + tool_call_id: None, + }, + ChatMessage::User { content } => TemplateMessage { + role: "user", + content: to_template_content(content, content_format, multimodal)?, + tools: None, + reasoning: None, + reasoning_content: None, + tool_calls: None, + tool_call_id: None, + }, + ChatMessage::Assistant { content } => { + let text = content.text(); + let reasoning = content.reasoning(); + let tool_calls = to_template_tool_calls(content)?; + let content = + to_template_content(&ChatContent::Text(text), content_format, multimodal)?; + TemplateMessage { + role: "assistant", + content, + tools: None, + reasoning: reasoning.clone(), + reasoning_content: reasoning, + tool_calls, + tool_call_id: None, + } + } + ChatMessage::ToolResponse { + content, + tool_call_id, + } => TemplateMessage { + role: "tool", + content: to_template_content(content, content_format, multimodal)?, + tools: None, + reasoning: None, + reasoning_content: None, + tool_calls: None, + tool_call_id: Some(tool_call_id.clone()), + }, + }) +} + +fn to_template_tool_calls( + content: &[AssistantContentBlock], +) -> Result>> { + let mut tool_calls = Vec::new(); + + for tool_call in content.tool_calls() { + let arguments = serde_json::from_str::(&tool_call.arguments).map_err(|error| { + Error::ChatTemplate(format!( + "assistant tool call `{}` has invalid JSON arguments: {}", + tool_call.id, + error.as_report() + )) + })?; + + tool_calls.push(TemplateToolCall { + id: tool_call.id.clone(), + r#type: "function", + function: TemplateToolFunction { + name: tool_call.name.clone(), + arguments, + }, + }); + } + + Ok((!tool_calls.is_empty()).then_some(tool_calls)) +} + +fn to_template_content( + content: &ChatContent, + content_format: ChatTemplateContentFormat, + multimodal: Option<&MultimodalRenderInfo>, +) -> Result { + Ok(match content_format { + ChatTemplateContentFormat::String => { + TemplateContent::String(to_template_string_content(content, multimodal)?) + } + ChatTemplateContentFormat::OpenAi => { + TemplateContent::OpenAi(to_template_openai_content(content, multimodal)?) + } + }) +} + +fn to_template_openai_content( + content: &ChatContent, + multimodal: Option<&MultimodalRenderInfo>, +) -> Result> { + match content { + ChatContent::Text(text) => Ok(vec![TemplateContentPart::Text { text: text.clone() }]), + ChatContent::Parts(parts) => parts + .iter() + .map(|part| match part { + ChatContentPart::Text { text } => { + Ok(TemplateContentPart::Text { text: text.clone() }) + } + // All multimodal contents are normalized to `{ "type": }`. + ChatContentPart::ImageUrl { .. } => { + multimodal.ok_or(Error::UnsupportedMultimodalContent("image_url"))?; + Ok(TemplateContentPart::Image) + } + }) + .collect(), + } +} + +fn to_template_string_content( + content: &ChatContent, + multimodal: Option<&MultimodalRenderInfo>, +) -> Result { + match content { + ChatContent::Text(text) => Ok(text.clone()), + ChatContent::Parts(parts) => { + let mut out = String::new(); + for part in parts { + match part { + ChatContentPart::Text { text } => out.push_str(text), + ChatContentPart::ImageUrl { .. } => { + let multimodal = + multimodal.ok_or(Error::UnsupportedMultimodalContent("image_url"))?; + out.push_str(&multimodal.placeholder_token); + } + } + } + Ok(out) + } + } +} + +fn to_template_tools(tools: &[ChatTool]) -> Vec { + tools + .iter() + .map(|tool| TemplateTool { + tool_type: "function", + function: TemplateToolDefinition { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: tool.parameters.clone(), + strict: tool.strict, + }, + }) + .collect() +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use expect_test::expect; + use serde_json::Value; + use vllm_text::Prompt; + use vllm_text::backend::hf::{HfSpecialTokens, NamedSpecialToken}; + + use super::{ChatTemplateContentFormatOption, HfChatRenderer, MultimodalRenderInfo}; + use crate::request::{ + ChatContentPart, ChatMessage, ChatRequest, ChatRole, ChatTool, ChatToolChoice, + GenerationPromptMode, ReasoningEffort, + }; + use crate::{AssistantContentBlock, ChatRenderer, Error, Result}; + + const QWEN3_0_6B_TEMPLATE: &str = include_str!("../../../tests/templates/qwen3.jinja"); + const QWEN3_5_0_8B_TEMPLATE: &str = include_str!("../../../tests/templates/qwen35.jinja"); + + fn sample_request(messages: Vec) -> ChatRequest { + ChatRequest { + messages, + request_id: "render-test".to_string(), + ..ChatRequest::for_test() + } + } + + fn render(template: Option<&str>, request: &ChatRequest) -> Result { + HfChatRenderer::new( + template.map(str::to_owned), + HashMap::new(), + ChatTemplateContentFormatOption::Auto, + )? + .render(request)? + .prompt + .into_text() + .map_err(|_| unreachable!("HF renderer should return text prompt")) + } + + fn render_mm( + template: &str, + request: &ChatRequest, + content_format: ChatTemplateContentFormatOption, + ) -> Result { + HfChatRenderer::new(Some(template.to_string()), HashMap::new(), content_format)? + .with_multimodal(Some(MultimodalRenderInfo { + placeholder_token: "".to_string(), + })) + .render(request) + } + + fn image_request() -> ChatRequest { + sample_request(vec![ChatMessage::user(vec![ + ChatContentPart::text("a"), + ChatContentPart::image_url("data:image/png;base64,test"), + ChatContentPart::text("b"), + ])]) + } + + #[test] + fn string_content_format_replaces_image_with_placeholder_text() { + let rendered = render_mm( + "{{ messages[0].content }}", + &image_request(), + ChatTemplateContentFormatOption::String, + ) + .unwrap(); + + assert_eq!(rendered.prompt, Prompt::Text("ab".to_string())); + } + + #[test] + fn openai_content_format_normalizes_image_url_for_template() { + let rendered = render_mm( + "{% for item in messages[0].content %}{% if item.type == 'image' %}<|image_pad|>{% else %}{{ item.text }}{% endif %}{% endfor %}", + &image_request(), + ChatTemplateContentFormatOption::OpenAi, + ) + .unwrap(); + + assert_eq!(rendered.prompt, Prompt::Text("a<|image_pad|>b".to_string())); + } + + #[test] + fn chat_template_supports_pycompat_templates() { + let request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]); + + let rendered = render( + Some( + "{% for message in messages %}{% if message.content.startswith('') %}think{% else %}plain{% endif %}{% endfor %}", + ), + &request, + ) + .unwrap(); + + assert_eq!(rendered, "think"); + } + + #[test] + fn chat_template_passes_continue_final_message_to_template() { + let mut request = sample_request(vec![ChatMessage::text( + ChatRole::Assistant, + "The capital of", + )]); + + assert_eq!( + render( + Some("{% if continue_final_message %}continue{% else %}new{% endif %}"), + &request, + ) + .unwrap(), + "new" + ); + + request.chat_options.generation_prompt_mode = GenerationPromptMode::ContinueFinalAssistant; + + assert_eq!( + render( + Some("{% if continue_final_message %}continue{% else %}new{% endif %}"), + &request, + ) + .unwrap(), + "continue" + ); + } + + #[test] + fn chat_template_flattens_text_parts_for_string_templates() { + let request = sample_request(vec![ChatMessage::user(vec![ + ChatContentPart::text("hello"), + ChatContentPart::text(" world"), + ])]); + + let rendered = render(Some("{{ messages[0].content }}"), &request).unwrap(); + + assert_eq!(rendered, "hello world"); + } + + #[test] + fn chat_template_exposes_developer_tools() { + let request = sample_request(vec![ChatMessage::developer( + "policy", + Some(vec![ChatTool { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: serde_json::json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }), + strict: Some(true), + }]), + )]); + + let rendered = render( + Some("{{ messages[0].role }}|{{ messages[0].content }}|{{ messages[0].tools[0].function.name }}|{{ messages[0].tools[0].function.parameters.required[0] }}"), + &request, + ) + .unwrap(); + + assert_eq!(rendered, "developer|policy|get_weather|city"); + } + + #[test] + fn chat_template_keeps_string_text_for_openai_detected_templates() { + let request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]); + + let rendered = render( + Some( + "{%- for message in messages %}{%- if message.content is string %}{%- set content = message.content %}{{ content }}{%- endif %}{%- endfor %}", + ), + &request, + ) + .unwrap(); + + assert_eq!(rendered, "hello"); + } + + #[test] + fn chat_template_emits_openai_text_blocks_for_structured_templates() { + let request = sample_request(vec![ChatMessage::user(vec![ + ChatContentPart::text("hello"), + ChatContentPart::text("world"), + ])]); + + let rendered = render( + Some( + "{%- for message in messages %}{%- for item in message.content %}{{ item.text }}|{%- endfor %}{%- endfor %}", + ), + &request, + ) + .unwrap(); + + assert_eq!(rendered, "hello|world|"); + } + + #[test] + fn chat_template_per_request_override() { + let mut request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]); + + // Default template renders one way. + let default_rendered = render(Some("{{ messages[0].content }}"), &request).unwrap(); + assert_eq!(default_rendered, "hello"); + + // Per-request override replaces the default template entirely. + request.chat_options.chat_template = Some("override:{{ messages[0].content }}".to_string()); + let overridden = render(Some("{{ messages[0].content }}"), &request).unwrap(); + assert_eq!(overridden, "override:hello"); + } + + #[test] + fn chat_template_per_request_override_without_default_template() { + let mut request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]); + request.chat_options.chat_template = Some("override:{{ messages[0].content }}".to_string()); + + let rendered = render(None, &request).unwrap(); + + assert_eq!(rendered, "override:hello"); + } + + #[test] + fn chat_template_requires_a_template() { + let request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]); + let error = render(None, &request).unwrap_err(); + + assert!(matches!(error, Error::MissingChatTemplate)); + } + + #[test] + fn chat_template_injects_special_tokens_into_context() { + let request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]); + let special_tokens = HfSpecialTokens { + bos_token: Some(NamedSpecialToken::Text("".to_string())), + ..Default::default() + }; + + let rendered = HfChatRenderer::new( + Some("{{ bos_token }}|{{ bos_token is defined }}".to_string()), + HashMap::new(), + ChatTemplateContentFormatOption::Auto, + ) + .unwrap() + .with_special_tokens(Some(special_tokens)) + .apply_chat_template(&request) + .unwrap(); + + assert_eq!(rendered.prompt, Prompt::Text("|true".to_string())); + } + + #[test] + fn chat_template_exposes_assistant_reasoning_separately() { + let request = sample_request(vec![ChatMessage::assistant_blocks(vec![ + AssistantContentBlock::Reasoning { + text: "inner".to_string(), + }, + AssistantContentBlock::Text { + text: "outer".to_string(), + }, + ])]); + + let rendered = render( + Some("{{ messages[0].reasoning_content }}|{{ messages[0].content }}"), + &request, + ) + .unwrap(); + + assert_eq!(rendered, "inner|outer"); + } + + #[test] + fn chat_template_forces_string_content_format_when_configured() { + let request = sample_request(vec![ChatMessage::user(vec![ + ChatContentPart::text("hello"), + ChatContentPart::text(" world"), + ])]); + + let rendered = HfChatRenderer::new( + Some( + "{%- if messages[0].content is string -%}{{ messages[0].content }}{%- else -%}{%- for item in messages[0].content %}{{ item.text }}|{%- endfor -%}{%- endif -%}".to_string(), + ), + HashMap::new(), + ChatTemplateContentFormatOption::String, + ) + .unwrap() + .render(&request) + .unwrap() + .prompt; + + assert_eq!(rendered, Prompt::Text("hello world".to_string())); + } + + #[test] + fn chat_template_forces_openai_content_format_when_configured() { + let request = sample_request(vec![ChatMessage::user(vec![ + ChatContentPart::text("hello"), + ChatContentPart::text(" world"), + ])]); + + let rendered = HfChatRenderer::new( + Some("{{ messages[0].content[0].text }}{{ messages[0].content[1].text }}".to_string()), + HashMap::new(), + ChatTemplateContentFormatOption::OpenAi, + ) + .unwrap() + .render(&request) + .unwrap() + .prompt; + + assert_eq!(rendered, Prompt::Text("hello world".to_string())); + } + + #[test] + fn chat_template_merges_default_template_kwargs_before_request_kwargs() { + let mut request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]); + request + .chat_options + .template_kwargs + .insert("enable_thinking".to_string(), Value::Bool(true)); + + let renderer = HfChatRenderer::new( + Some("{{ enable_thinking }}|{{ default_only }}".to_string()), + HashMap::from([ + ("enable_thinking".to_string(), Value::Bool(false)), + ("default_only".to_string(), Value::String("x".to_string())), + ]), + ChatTemplateContentFormatOption::Auto, + ) + .unwrap(); + + let rendered = renderer.render(&request).unwrap().prompt; + + assert_eq!(rendered, Prompt::Text("true|x".to_string())); + } + + #[test] + fn chat_template_reasoning_effort_overrides_template_kwargs() { + let mut request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]); + request.chat_options.reasoning_effort = Some(ReasoningEffort::Max); + request.chat_options.template_kwargs.insert( + "reasoning_effort".to_string(), + Value::String("low".to_string()), + ); + + let renderer = HfChatRenderer::new( + Some("{{ reasoning_effort }}".to_string()), + HashMap::from([( + "reasoning_effort".to_string(), + Value::String("medium".to_string()), + )]), + ChatTemplateContentFormatOption::Auto, + ) + .unwrap(); + + let rendered = renderer.render(&request).unwrap().prompt; + + assert_eq!(rendered, Prompt::Text("max".to_string())); + } + + #[test] + fn qwen3_template_omits_reasoning_for_historical_assistant_messages() { + let request = sample_request(vec![ + ChatMessage::text( + ChatRole::User, + "Hi. Tell me about the capital of France in short", + ), + ChatMessage::assistant_blocks(vec![ + AssistantContentBlock::Reasoning { + text: "\nOkay, the user is asking... I think that's all.\n".to_string(), + }, + AssistantContentBlock::Text { + text: "Paris is the capital of France.".to_string(), + }, + ]), + ChatMessage::text(ChatRole::User, "Tell me about Paris more."), + ]); + + let rendered = render(Some(QWEN3_0_6B_TEMPLATE), &request).unwrap(); + + expect![[r#" + <|im_start|>user + Hi. Tell me about the capital of France in short<|im_end|> + <|im_start|>assistant + Paris is the capital of France.<|im_end|> + <|im_start|>user + Tell me about Paris more.<|im_end|> + <|im_start|>assistant + "#]] + .assert_eq(&rendered); + } + + #[test] + fn qwen3_template_keeps_reasoning_after_the_last_user_query() { + let mut request = sample_request(vec![ + ChatMessage::text(ChatRole::User, "What is 1 + 1?"), + ChatMessage::assistant_blocks(vec![ + AssistantContentBlock::Reasoning { + text: "need simple arithmetic".to_string(), + }, + AssistantContentBlock::Text { + text: "2".to_string(), + }, + ]), + ]); + request.chat_options.generation_prompt_mode = GenerationPromptMode::NoGenerationPrompt; + + let rendered = render(Some(QWEN3_0_6B_TEMPLATE), &request).unwrap(); + + expect![[r#" + <|im_start|>user + What is 1 + 1?<|im_end|> + <|im_start|>assistant + + need simple arithmetic + + + 2<|im_end|> + "#]] + .assert_eq(&rendered); + } + + #[test] + fn chat_template_exposes_tools_to_templates_when_auto_enabled() { + let mut request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]); + request.tools = vec![ChatTool { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: serde_json::json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }), + strict: None, + }]; + request.tool_choice = ChatToolChoice::Auto; + + let rendered = render( + Some("{{ tools[0].function.name }}|{{ tools[0].function.parameters.required[0] }}"), + &request, + ) + .unwrap(); + + assert_eq!(rendered, "get_weather|city"); + } + + #[test] + fn chat_template_exposes_assistant_tool_calls_and_tool_messages() { + let request = sample_request(vec![ + ChatMessage::assistant_blocks(vec![AssistantContentBlock::ToolCall( + crate::AssistantToolCall { + id: "call_1".to_string(), + name: "get_weather".to_string(), + arguments: r#"{"city":"Paris"}"#.to_string(), + }, + )]), + ChatMessage::tool_response("Sunny", "call_1"), + ]); + + let rendered = render( + Some( + "{{ messages[0].tool_calls[0].function.name }}|{{ messages[0].tool_calls[0].function.arguments.city }}|{{ messages[1].tool_call_id }}|{{ messages[1].content }}", + ), + &request, + ) + .unwrap(); + + assert_eq!(rendered, "get_weather|Paris|call_1|Sunny"); + } + + #[test] + fn qwen35_template_renders_prefilled_reasoning_start_when_thinking_enabled() { + let mut request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]); + request + .chat_options + .template_kwargs + .insert("enable_thinking".to_string(), Value::Bool(true)); + + let rendered = render(Some(QWEN3_5_0_8B_TEMPLATE), &request).unwrap(); + + expect![[r#" + <|im_start|>user + hello<|im_end|> + <|im_start|>assistant + + "#]] + .assert_eq(&rendered); + } + + #[test] + fn qwen35_template_renders_closed_empty_reasoning_span_when_thinking_disabled() { + let mut request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]); + request + .chat_options + .template_kwargs + .insert("enable_thinking".to_string(), Value::Bool(false)); + + let rendered = render(Some(QWEN3_5_0_8B_TEMPLATE), &request).unwrap(); + + expect![[r#" + <|im_start|>user + hello<|im_end|> + <|im_start|>assistant + + + + + "#]] + .assert_eq(&rendered); + } + + #[test] + fn qwen35_template_omits_assistant_reasoning_prefill_without_generation_prompt() { + let mut request = sample_request(vec![ChatMessage::text(ChatRole::User, "hello")]); + request.chat_options.generation_prompt_mode = GenerationPromptMode::NoGenerationPrompt; + request + .chat_options + .template_kwargs + .insert("enable_thinking".to_string(), Value::Bool(true)); + + let rendered = render(Some(QWEN3_5_0_8B_TEMPLATE), &request).unwrap(); + + expect![[r#" + <|im_start|>user + hello<|im_end|> + "#]] + .assert_eq(&rendered); + } +} diff --git a/rust/src/chat/src/renderer/hf/template.rs b/rust/src/chat/src/renderer/hf/template.rs new file mode 100644 index 00000000000..b71efc1e53e --- /dev/null +++ b/rust/src/chat/src/renderer/hf/template.rs @@ -0,0 +1,316 @@ +//! Chat template support for tokenizers using Jinja2 templates. +//! +//! This module is inlined from SMG's tokenizer crate with local adaptations: +//! - thinking-related detection/state is removed +//! - special tokens are wired to `vllm_text::backends::hf::HfSpecialTokens` + +use std::collections::HashMap; +use std::fs; +use std::path::Path; + +use minijinja::Environment; +use serde::{Deserialize, Serialize}; +use serde_json::{self}; +use vllm_text::backend::hf::HfSpecialTokens; + +use super::error::TemplateError; +use super::format::{ + ChatTemplateContentFormat, ChatTemplateContentFormatOption, detect_chat_template_content_format, +}; +use super::tojson::hf_tojson_filter; +use crate::renderer::hf::{TemplateMessage, TemplateTool}; +use crate::request::ReasoningEffort; + +type Result = std::result::Result; + +/// Build a pre-configured environment with the given template string. +fn build_environment(template: String) -> Result> { + let mut env = Environment::new(); + + env.set_trim_blocks(true); + env.set_lstrip_blocks(true); + + env.add_template_owned("chat".to_owned(), template)?; + + env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback); + env.add_filter("tojson", hf_tojson_filter); + + Ok(env) +} + +#[serde_with::skip_serializing_none] +#[derive(Default, Serialize)] +pub(super) struct TemplateContext<'a> { + pub(super) messages: &'a [TemplateMessage], + pub(super) add_generation_prompt: bool, + pub(super) continue_final_message: bool, + pub(super) tools: Option<&'a [TemplateTool]>, + pub(super) documents: Option<&'a [serde_json::Value]>, + #[serde(flatten)] + pub(super) special_tokens: Option<&'a HfSpecialTokens>, + #[serde(flatten)] + pub(super) template_kwargs: Option<&'a HashMap>, + // By putting top-level `reasoning_effort` after `template_kwargs`, this overrides any + // `reasoning_effort` value that might be present there. + pub(super) reasoning_effort: Option, +} + +/// Load chat template from a file (`.jinja` or `.json` containing Jinja). +pub fn load_chat_template(template_path: &Path) -> Result> { + let content = fs::read_to_string(template_path).map_err(TemplateError::ReadTemplateFile)?; + + if template_path.extension().is_some_and(|ext| ext == "json") { + #[derive(Deserialize)] + #[serde(untagged)] + enum ChatTemplateFile { + String(String), + Object { chat_template: String }, + } + + let json_value = + serde_json::from_str(&content).map_err(TemplateError::ParseTemplateJson)?; + let json_template = + serde_json::from_value(json_value).map_err(|_| TemplateError::InvalidTemplateJson)?; + + return Ok(Some(match json_template { + ChatTemplateFile::String(template) => template, + ChatTemplateFile::Object { chat_template } => chat_template, + })); + } + + let template = content.trim().replace("\\n", "\n"); + Ok(Some(template)) +} + +/// Resolve a configured chat template value into a template string. +pub fn resolve_chat_template(chat_template: &str) -> Result { + let path = Path::new(chat_template); + if path.exists() { + return load_chat_template(path).map(|template| template.unwrap_or_default()); + } + + const JINJA_CHARS: [char; 3] = ['{', '}', '\n']; + if chat_template.chars().any(|c| JINJA_CHARS.contains(&c)) { + return Ok(chat_template.to_string()); + } + + Err(TemplateError::MissingTemplatePath) +} + +/// One compiled chat template with its Jinja environment and detected content +/// format. +pub(super) struct CompiledChatTemplate { + /// Cached, fully-configured environment for one compiled template. + env: Environment<'static>, + content_format: ChatTemplateContentFormat, +} + +impl CompiledChatTemplate { + /// Compile the given chat template string into a [`CompiledChatTemplate`]. + pub fn new(template: String, content_format: ChatTemplateContentFormatOption) -> Result { + let content_format = match content_format { + ChatTemplateContentFormatOption::Auto => detect_chat_template_content_format(&template), + ChatTemplateContentFormatOption::String => ChatTemplateContentFormat::String, + ChatTemplateContentFormatOption::OpenAi => ChatTemplateContentFormat::OpenAi, + }; + let env = build_environment(template)?; + Ok(Self { + env, + content_format, + }) + } + + /// Apply the compiled template to the given context and return the rendered + /// prompt. + pub fn apply(&self, ctx: TemplateContext<'_>) -> Result { + let tmpl = self.env.get_template("chat")?; + tmpl.render(ctx).map_err(TemplateError::from) + } + + pub fn content_format(&self) -> ChatTemplateContentFormat { + self.content_format + } +} + +#[cfg(test)] +mod tests { + use std::fs; + + use tempfile::TempDir; + use vllm_text::backend::hf::{HfSpecialTokens, NamedSpecialToken}; + + use super::*; + + #[test] + fn test_chat_template_state_valid_template() { + let template = CompiledChatTemplate::new( + "{{ messages }}".to_string(), + ChatTemplateContentFormatOption::Auto, + ) + .unwrap(); + assert_eq!(template.content_format(), ChatTemplateContentFormat::String); + let result = template.apply(TemplateContext::default()).unwrap(); + assert_eq!(result, "[]"); + } + + #[test] + fn test_chat_template_state_invalid_template() { + let result = CompiledChatTemplate::new( + "{% invalid".to_string(), + ChatTemplateContentFormatOption::Auto, + ); + assert!(result.is_err()); + let err = result.err().unwrap().to_string(); + assert!( + err.contains("failed to render jinja template"), + "Error should explain parse failure, got: {err}" + ); + } + + #[test] + fn test_special_tokens_injected_into_context() { + let template = "{{ bos_token }}hello{{ eos_token }}"; + let template = + CompiledChatTemplate::new(template.to_string(), ChatTemplateContentFormatOption::Auto) + .unwrap(); + + let special_tokens = HfSpecialTokens { + bos_token: Some(NamedSpecialToken::Text("".to_string())), + eos_token: Some(NamedSpecialToken::Text("".to_string())), + ..Default::default() + }; + + let result = template + .apply(TemplateContext { + special_tokens: Some(&special_tokens), + ..Default::default() + }) + .unwrap(); + + assert_eq!(result, "hello"); + } + + #[test] + fn test_special_tokens_undefined_when_not_provided() { + let template = "{% if bos_token is defined %}{{ bos_token }}{% endif %}hello"; + let template = + CompiledChatTemplate::new(template.to_string(), ChatTemplateContentFormatOption::Auto) + .unwrap(); + + let result = template.apply(TemplateContext::default()).unwrap(); + assert_eq!(result, "hello"); + } + + #[test] + fn test_special_tokens_partial() { + let template = + "{{ bos_token }}hello{% if eos_token is defined %}{{ eos_token }}{% endif %}"; + let template = + CompiledChatTemplate::new(template.to_string(), ChatTemplateContentFormatOption::Auto) + .unwrap(); + + let special_tokens = HfSpecialTokens { + bos_token: Some(NamedSpecialToken::Text("".to_string())), + eos_token: None, + ..Default::default() + }; + + let result = template + .apply(TemplateContext { + special_tokens: Some(&special_tokens), + ..Default::default() + }) + .unwrap(); + + assert_eq!(result, "hello"); + } + + #[test] + fn test_tojson_filter_supports_indent_and_sort_keys() { + let template = CompiledChatTemplate::new( + "{{ payload | tojson(indent=2, sort_keys=true) }}".to_string(), + ChatTemplateContentFormatOption::Auto, + ) + .unwrap(); + let mut kwargs = HashMap::new(); + kwargs.insert("payload".to_string(), serde_json::json!({"b": 1, "a": 2})); + + let result = template + .apply(TemplateContext { + template_kwargs: Some(&kwargs), + ..Default::default() + }) + .unwrap(); + + assert_eq!(result, "{\n \"a\": 2,\n \"b\": 1\n}"); + } + + #[test] + fn test_load_chat_template_from_file_jinja() { + let dir = TempDir::new().unwrap(); + let path = dir.path().join("chat_template.jinja"); + fs::write(&path, "{{ messages }}").unwrap(); + + let template = load_chat_template(&path).unwrap(); + + assert_eq!(template.as_deref(), Some("{{ messages }}")); + } + + #[test] + fn test_resolve_chat_template_from_inline_literal() { + let template = resolve_chat_template("{{ messages }}").unwrap(); + + assert_eq!(template, "{{ messages }}"); + } + + #[test] + fn test_resolve_chat_template_from_existing_file() { + let dir = TempDir::new().unwrap(); + let path = dir.path().join("chat_template.jinja"); + fs::write(&path, "{{ messages }}").unwrap(); + + let template = resolve_chat_template(path.to_str().unwrap()).unwrap(); + + assert_eq!(template, "{{ messages }}"); + } + + #[test] + fn test_resolve_chat_template_rejects_missing_path_like_value() { + let error = resolve_chat_template("missing_template.jinja").unwrap_err(); + + assert!(matches!(error, TemplateError::MissingTemplatePath)); + } + + #[test] + fn test_chat_template_state_respects_explicit_content_format_override() { + let template = CompiledChatTemplate::new( + "{% for item in messages[0].content %}{{ item.text }}{% endfor %}".to_string(), + ChatTemplateContentFormatOption::String, + ) + .unwrap(); + + assert_eq!(template.content_format(), ChatTemplateContentFormat::String); + } + + #[test] + fn test_load_chat_template_from_file_json_string() { + let dir = TempDir::new().unwrap(); + let path = dir.path().join("chat_template.json"); + fs::write(&path, "\"{{ messages }}\"").unwrap(); + + let template = load_chat_template(&path).unwrap(); + + assert_eq!(template.as_deref(), Some("{{ messages }}")); + } + + #[test] + fn test_load_chat_template_from_file_json_object() { + let dir = TempDir::new().unwrap(); + let path = dir.path().join("chat_template.json"); + fs::write(&path, r#"{"chat_template":"{{ messages }}"}"#).unwrap(); + + let template = load_chat_template(&path).unwrap(); + + assert_eq!(template.as_deref(), Some("{{ messages }}")); + } +} diff --git a/rust/src/chat/src/renderer/hf/tojson.rs b/rust/src/chat/src/renderer/hf/tojson.rs new file mode 100644 index 00000000000..1c5c20f4765 --- /dev/null +++ b/rust/src/chat/src/renderer/hf/tojson.rs @@ -0,0 +1,277 @@ +use minijinja::value::{Kwargs, ViaDeserialize}; +use minijinja::{Error as MinijinjaError, ErrorKind, Value}; +use serde::Deserialize; +use serde_json::{self, Value as JsonValue}; +use serde_json_fmt::{JsonFormat, JsonSyntaxError}; +use thiserror_ext::AsReport; + +/// Hugging Face-compatible `tojson` filter for chat templates. +/// +/// We cannot use MiniJinja's built-in filter directly because HF relies on +/// Python `json.dumps` semantics: +/// - no HTML escaping +/// - extra kwargs such as `ensure_ascii`, `separators`, and `sort_keys` +/// - Python-style `indent` handling +pub(super) fn hf_tojson_filter( + value: Value, + kwargs: Kwargs, +) -> std::result::Result { + let ensure_ascii = kwargs.get::>("ensure_ascii")?.unwrap_or(false); + let indent = parse_indent( + kwargs.get::>>("indent")?.map(|value| value.0), + ); + let separators = parse_separators( + kwargs + .get::>>("separators")? + .map(|value| value.0), + indent.is_some(), + ); + let sort_keys = kwargs.get::>("sort_keys")?.unwrap_or(false); + + kwargs.assert_all_used()?; + + let json_value: serde_json::Value = serde_json::to_value(&value).map_err(|e| { + MinijinjaError::new( + ErrorKind::InvalidOperation, + format!("Failed to convert to JSON value: {e}"), + ) + })?; + + let json_str = { + let value_to_serialize = if sort_keys { + &sort_json_keys(&json_value) + } else { + &json_value + }; + + build_json_format(indent, separators.0, separators.1, ensure_ascii)? + .format_to_string(value_to_serialize) + .map_err(|e| { + MinijinjaError::new( + ErrorKind::InvalidOperation, + format!("Failed to serialize JSON: {}", e.as_report()), + ) + })? + }; + + Ok(Value::from_safe_string(json_str)) +} + +#[derive(Deserialize)] +#[serde(untagged)] +enum IndentArg { + // Python `json.dumps` accepts bool, int, and string indentation styles. + Bool(bool), + Integer(i64), + String(String), +} + +fn parse_indent(value: Option) -> Option { + match value? { + IndentArg::Bool(indent) => Some(if indent { + " ".to_owned() + } else { + String::new() + }), + IndentArg::Integer(indent) => Some(if indent > 0 { + " ".repeat(indent as usize) + } else { + String::new() + }), + IndentArg::String(indent) => Some(indent), + } +} + +#[derive(Deserialize)] +struct SeparatorsArg((String, String)); + +fn parse_separators(value: Option, pretty: bool) -> (String, String) { + let Some(SeparatorsArg((item_separator, key_separator))) = value else { + let default_item_separator = if pretty { "," } else { ", " }; + let default_key_separator = ": "; + + return ( + default_item_separator.to_owned(), + default_key_separator.to_owned(), + ); + }; + + (item_separator, key_separator) +} + +fn build_json_format( + indent: Option, + item_separator: String, + key_separator: String, + ensure_ascii: bool, +) -> std::result::Result { + JsonFormat::new() + .indent(indent) + .map_err(map_json_syntax_error("indent"))? + .comma(item_separator) + .map_err(map_json_syntax_error("separators (item)"))? + .colon(key_separator) + .map_err(map_json_syntax_error("separators (key)")) + .map(|format| format.ascii(ensure_ascii)) +} + +fn map_json_syntax_error( + field: &'static str, +) -> impl FnOnce(JsonSyntaxError) -> MinijinjaError + Copy { + move |error| { + MinijinjaError::new( + ErrorKind::InvalidOperation, + format!("invalid {field} value for tojson: {error}"), + ) + } +} + +/// Recursively sort all object keys in a JSON value. +fn sort_json_keys(value: &JsonValue) -> JsonValue { + match value { + JsonValue::Object(map) => { + let mut sorted: serde_json::Map = serde_json::Map::new(); + let mut keys: Vec<_> = map.keys().collect(); + keys.sort(); + for key in keys { + sorted.insert(key.clone(), sort_json_keys(&map[key])); + } + JsonValue::Object(sorted) + } + JsonValue::Array(arr) => JsonValue::Array(arr.iter().map(sort_json_keys).collect()), + _ => value.clone(), + } +} + +#[cfg(test)] +mod tests { + use expect_test::expect; + use minijinja::Environment; + use serde_json::json; + use thiserror_ext::AsReport; + + use super::hf_tojson_filter; + + fn render(template: &str, payload: serde_json::Value) -> String { + let mut env = Environment::new(); + env.add_filter("tojson", hf_tojson_filter); + env.render_str(template, json!({ "payload": payload })).unwrap() + } + + fn render_error(template: &str, payload: serde_json::Value) -> minijinja::Error { + let mut env = Environment::new(); + env.add_filter("tojson", hf_tojson_filter); + env.render_str(template, json!({ "payload": payload })).unwrap_err() + } + + #[test] + fn tojson_does_not_html_escape_like_minijinja_builtin() { + let rendered = render("{{ payload|tojson }}", json!("&'")); + assert_eq!(rendered, "\"&'\""); + } + + #[test] + fn tojson_supports_sort_keys_recursively() { + let rendered = render( + "{{ payload|tojson(sort_keys=true) }}", + json!({ + "z": {"b": 1, "a": 2}, + "a": 0 + }), + ); + + assert_eq!(rendered, "{\"a\": 0, \"z\": {\"a\": 2, \"b\": 1}}"); + } + + #[test] + fn tojson_supports_indent() { + let rendered = render("{{ payload|tojson(indent=2) }}", json!([1, 2])); + + assert_eq!(rendered, "[\n 1,\n 2\n]"); + } + + #[test] + fn tojson_supports_ensure_ascii_false() { + let rendered = render("{{ payload|tojson(ensure_ascii=false) }}", json!("中文")); + assert_eq!(rendered, "\"中文\""); + } + + #[test] + fn tojson_supports_ensure_ascii_true() { + let rendered = render("{{ payload|tojson(ensure_ascii=true) }}", json!("中文")); + assert_eq!(rendered, "\"\\u4e2d\\u6587\""); + } + + #[test] + fn tojson_supports_separators() { + let rendered = render( + "{{ payload|tojson(separators=[',', ':']) }}", + json!({ + "x": [1, 2] + }), + ); + + assert_eq!(rendered, "{\"x\":[1,2]}"); + } + + #[test] + fn tojson_supports_negative_indent_as_newline_only() { + let rendered = render("{{ payload|tojson(indent=-1) }}", json!([1, 2])); + assert_eq!(rendered, "[\n1,\n2\n]"); + } + + #[test] + fn tojson_supports_string_indent() { + let rendered = render("{{ payload|tojson(indent=' ') }}", json!([1, 2])); + assert_eq!(rendered, "[\n 1,\n 2\n]"); + } + + #[test] + fn tojson_supports_boolean_indent() { + let rendered_true = render("{{ payload|tojson(indent=true) }}", json!([1, 2])); + assert_eq!(rendered_true, "[\n 1,\n 2\n]"); + + let rendered_false = render("{{ payload|tojson(indent=false) }}", json!([1, 2])); + assert_eq!(rendered_false, "[\n1,\n2\n]"); + } + + #[test] + fn tojson_combines_indent_sort_keys_separators_and_ensure_ascii() { + let rendered = render( + "{{ payload|tojson(ensure_ascii=true, sort_keys=true, separators=[',', ':'], indent=' ') }}", + json!({ + "b": "<中>", + "a": [1, 2] + }), + ); + + assert_eq!( + rendered, + "{\n \"a\":[\n 1,\n 2\n ],\n \"b\":\"<\\u4e2d>\"\n}" + ); + } + + #[test] + fn tojson_rejects_invalid_indent() { + let error = render_error("{{ payload|tojson(indent='-->') }}", json!({"a": 1})); + expect!["invalid operation: invalid indent value for tojson: string contains unexpected character '-' (in :1)"] + .assert_eq(&error.to_report_string()); + } + + #[test] + fn tojson_rejects_invalid_separator_shape() { + let error = render_error("{{ payload|tojson(separators=':,') }}", json!({"a": 1})); + expect!["cannot deserialize: invalid type: string \":,\", expected a tuple of size 2 (in :1)"] + .assert_eq(&error.to_report_string()); + } + + #[test] + fn tojson_rejects_invalid_key_separator() { + let error = render_error( + "{{ payload|tojson(separators=[',', '=>']) }}", + json!({"a": 1}), + ); + expect!["invalid operation: invalid separators (key) value for tojson: string contains unexpected character '=' (in :1)"] + .assert_eq(&error.to_report_string()); + } +} diff --git a/rust/src/chat/src/renderer/mod.rs b/rust/src/chat/src/renderer/mod.rs new file mode 100644 index 00000000000..07ff5d0b6dd --- /dev/null +++ b/rust/src/chat/src/renderer/mod.rs @@ -0,0 +1,31 @@ +use std::sync::Arc; + +use vllm_text::Prompt; + +use crate::error::Result; +use crate::request::ChatRequest; + +pub mod deepseek_v32; +pub mod deepseek_v4; +pub mod hf; +mod selection; + +pub use deepseek_v4::DeepSeekV4ChatRenderer; +pub use deepseek_v32::DeepSeekV32ChatRenderer; +pub use selection::RendererSelection; + +/// Rendered chat prompt submitted to the text backend. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RenderedPrompt { + pub prompt: Prompt, +} + +/// Minimal chat-prompt renderer used by `vllm-chat`. +pub trait ChatRenderer: Send + Sync { + /// Render one chat request into the text prompt submitted to the text + /// backend. + fn render(&self, request: &ChatRequest) -> Result; +} + +/// Shared trait-object form of [`ChatRenderer`]. +pub type DynChatRenderer = Arc; diff --git a/rust/src/chat/src/renderer/selection.rs b/rust/src/chat/src/renderer/selection.rs new file mode 100644 index 00000000000..f4bd565bafd --- /dev/null +++ b/rust/src/chat/src/renderer/selection.rs @@ -0,0 +1,109 @@ +use std::fmt; +use std::str::FromStr; + +use serde_with::DeserializeFromStr; + +/// Specify which chat renderer implementation to use. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, DeserializeFromStr)] +pub enum RendererSelection { + /// Use model-based auto-detection. + #[default] + Auto, + /// Force the generic Hugging Face chat-template renderer. + Hf, + /// Force the DeepSeek V3.2 renderer. + DeepSeekV32, + /// Force the DeepSeek V4 renderer. + DeepSeekV4, +} + +impl RendererSelection { + pub const AUTO_LITERAL: &str = "auto"; + pub const DEEPSEEK_V32_LITERAL: &str = "deepseek_v32"; + pub const DEEPSEEK_V4_LITERAL: &str = "deepseek_v4"; + pub const HF_LITERAL: &str = "hf"; + + /// Resolve the renderer selection using the given model type string, if + /// it's `Auto`. + pub fn resolve(self, model_type: &str) -> Self { + match self { + Self::Auto => match model_type { + Self::DEEPSEEK_V32_LITERAL => Self::DeepSeekV32, + Self::DEEPSEEK_V4_LITERAL => Self::DeepSeekV4, + _ => Self::Hf, + }, + selection => selection, + } + } +} + +impl FromStr for RendererSelection { + type Err = String; + + fn from_str(value: &str) -> Result { + if value.eq_ignore_ascii_case(Self::AUTO_LITERAL) { + Ok(Self::Auto) + } else if value.eq_ignore_ascii_case(Self::HF_LITERAL) { + Ok(Self::Hf) + } else if value.eq_ignore_ascii_case(Self::DEEPSEEK_V32_LITERAL) { + Ok(Self::DeepSeekV32) + } else if value.eq_ignore_ascii_case(Self::DEEPSEEK_V4_LITERAL) { + Ok(Self::DeepSeekV4) + } else { + Err(format!( + "unknown renderer `{value}` (expected one of: auto, hf, deepseek_v32, deepseek_v4)" + )) + } + } +} + +impl fmt::Display for RendererSelection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Auto => f.write_str(Self::AUTO_LITERAL), + Self::Hf => f.write_str(Self::HF_LITERAL), + Self::DeepSeekV32 => f.write_str(Self::DEEPSEEK_V32_LITERAL), + Self::DeepSeekV4 => f.write_str(Self::DEEPSEEK_V4_LITERAL), + } + } +} + +#[cfg(test)] +mod tests { + use super::RendererSelection; + + #[test] + fn renderer_selection_parses_known_values() { + assert_eq!( + "auto".parse::().unwrap(), + RendererSelection::Auto + ); + assert_eq!( + "hf".parse::().unwrap(), + RendererSelection::Hf + ); + assert_eq!( + "deepseek_v32".parse::().unwrap(), + RendererSelection::DeepSeekV32 + ); + assert_eq!( + "deepseek_v4".parse::().unwrap(), + RendererSelection::DeepSeekV4 + ); + } + + #[test] + fn renderer_selection_display_round_trips() { + for selection in [ + RendererSelection::Auto, + RendererSelection::Hf, + RendererSelection::DeepSeekV32, + RendererSelection::DeepSeekV4, + ] { + assert_eq!( + selection.to_string().parse::().unwrap(), + selection + ); + } + } +} diff --git a/rust/src/chat/src/request.rs b/rust/src/chat/src/request.rs new file mode 100644 index 00000000000..c1cb83b8dc3 --- /dev/null +++ b/rust/src/chat/src/request.rs @@ -0,0 +1,662 @@ +use std::collections::HashMap; + +use llm_multimodal::ImageDetail; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +pub use vllm_text::SamplingParams; +use vllm_text::TextDecodeOptions; +pub use vllm_tool_parser::Tool as ChatTool; + +use crate::AssistantMessageExt; +use crate::error::{Error, Result}; +use crate::event::{AssistantContentBlock, AssistantMessage}; + +/// Role label for one text-only chat message. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ChatRole { + System, + Developer, + User, + Assistant, + ToolResponse, +} + +/// One text-only chat content part in OpenAI-style block format. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ChatContentPart { + /// One plain-text content block. + Text { text: String }, + /// One image URL/data URL content block. + ImageUrl { + image_url: String, + detail: Option, + uuid: Option, + }, + // ImageData... + // ImageEmbeds... +} + +impl ChatContentPart { + /// Construct one text content part with plain string content. + pub fn text(text: impl Into) -> Self { + Self::Text { text: text.into() } + } + + /// Construct one image URL content part with the given URL string. + pub fn image_url(image_url: impl Into) -> Self { + Self::ImageUrl { + image_url: image_url.into(), + detail: None, + uuid: None, + } + } + + /// Return the text content of this part when it's a text block, or an + /// "unsupported multimodal content" error otherwise. + pub(crate) fn as_text(&self) -> Result<&str> { + match self { + Self::Text { text } => Ok(text), + Self::ImageUrl { .. } => Err(Error::UnsupportedMultimodalContent("image_url")), + } + } + + /// Return whether this part is a text block with empty content. + pub(crate) fn is_empty_text(&self) -> bool { + matches!(self, Self::Text { text } if text.is_empty()) + } + + /// Return whether this part contains any multimodal content. + pub(crate) fn is_multimodal(&self) -> bool { + match self { + Self::Text { .. } => false, + Self::ImageUrl { .. } => true, + } + } +} + +/// Text-only chat content. +/// +/// This supports either a simple string or an OpenAI-style list of text blocks. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ChatContent { + /// Simple text content. + Text(String), + /// OpenAI-style blocks. + Parts(Vec), +} + +impl ChatContent { + /// Flatten the text content into one plain string without adding + /// separators. + // TODO: this method will be truly fallible once we add non-text content parts. + pub fn try_flatten_to_text(&self) -> Result { + Ok(match self { + Self::Text(text) => text.clone(), + Self::Parts(parts) => { + parts.iter().map(ChatContentPart::as_text).collect::>>()?.concat() + } + }) + } + + /// Return whether there's no text content or only empty text blocks. + pub fn is_empty(&self) -> bool { + match self { + Self::Text(text) => text.is_empty(), + Self::Parts(parts) => parts.iter().all(ChatContentPart::is_empty_text), + } + } + + /// Return whether this content contains any multimodal parts. + pub fn has_multimodal(&self) -> bool { + match self { + Self::Text(_) => false, + Self::Parts(parts) => parts.iter().any(ChatContentPart::is_multimodal), + } + } +} + +impl From for ChatContent { + fn from(value: String) -> Self { + Self::Text(value) + } +} + +impl From<&str> for ChatContent { + fn from(value: &str) -> Self { + Self::Text(value.to_string()) + } +} + +impl From> for ChatContent { + fn from(value: Vec) -> Self { + Self::Parts(value) + } +} + +/// One chat message. +/// +/// Original Python API reference: +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "role", rename_all = "snake_case")] +pub enum ChatMessage { + /// System message content. + System { content: ChatContent }, + /// Developer message content plus optional message-local tools. + Developer { + content: ChatContent, + tools: Option>, + }, + /// User message content. + User { content: ChatContent }, + /// Assistant history content assembled from structured assistant blocks. + Assistant { content: Vec }, + /// Tool response content associated with one prior assistant tool call. + ToolResponse { + content: ChatContent, + tool_call_id: String, + }, +} + +impl ChatMessage { + /// Construct one chat message with plain string content. + pub fn text(role: ChatRole, text: impl Into) -> Self { + let content: String = text.into(); + + match role { + ChatRole::System => Self::system(content), + ChatRole::Developer => Self::developer(content, None), + ChatRole::User => Self::user(content), + ChatRole::Assistant => Self::assistant_text(content), + ChatRole::ToolResponse => { + panic!( + "tool response messages require a tool_call_id; \ + use ChatMessage::tool_response() instead" + ) + } + } + } + + /// Construct one chat message with system role. + pub fn system(content: impl Into) -> Self { + Self::System { + content: content.into(), + } + } + + /// Construct one chat message with developer role. + pub fn developer(content: impl Into, tools: Option>) -> Self { + Self::Developer { + content: content.into(), + tools, + } + } + + /// Construct one chat message with user role. + pub fn user(content: impl Into) -> Self { + Self::User { + content: content.into(), + } + } + + /// Construct one chat message with assistant role and plain string content. + pub fn assistant_text(text: impl Into) -> Self { + Self::Assistant { + content: vec![AssistantContentBlock::Text { text: text.into() }], + } + } + + /// Construct one chat message with assistant role and structured content + /// blocks. + pub fn assistant_blocks(content: Vec) -> Self { + Self::Assistant { content } + } + + /// Construct one tool-role message. + pub fn tool_response(content: impl Into, tool_call_id: impl Into) -> Self { + Self::ToolResponse { + content: content.into(), + tool_call_id: tool_call_id.into(), + } + } + + /// Return the chat role of this message. + pub fn role(&self) -> ChatRole { + match self { + Self::System { .. } => ChatRole::System, + Self::Developer { .. } => ChatRole::Developer, + Self::User { .. } => ChatRole::User, + Self::Assistant { .. } => ChatRole::Assistant, + Self::ToolResponse { .. } => ChatRole::ToolResponse, + } + } + + /// Concatenate the visible text carried by this message. + pub fn text_content(&self) -> Result { + match self { + Self::System { content } + | Self::Developer { content, .. } + | Self::User { content } + | Self::ToolResponse { content, .. } => content.try_flatten_to_text(), + Self::Assistant { content } => Ok(content.text()), + } + } + + /// Concatenate assistant reasoning text when present. + pub fn reasoning_content(&self) -> Option { + match self { + Self::Assistant { content } => content.reasoning(), + Self::System { .. } + | Self::Developer { .. } + | Self::User { .. } + | Self::ToolResponse { .. } => None, + } + } + + /// Return whether this message contains any multimodal content. + pub fn has_multimodal(&self) -> bool { + match self { + Self::System { content } + | Self::Developer { content, .. } + | Self::User { content } + | Self::ToolResponse { content, .. } => content.has_multimodal(), + Self::Assistant { .. } => false, + } + } +} + +impl From for ChatMessage { + fn from(value: AssistantMessage) -> Self { + Self::Assistant { + content: value.content, + } + } +} + +/// Controls how prompt rendering should end after the existing chat history. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum GenerationPromptMode { + /// Append a generation prompt for a new assistant turn. + /// + /// Equivalent to `add_generation_prompt = true` and `continue_final_message + /// = false`. + #[default] + StartNewAssistant, + /// Leave the final assistant message open so generation continues it. + /// + /// Equivalent to `add_generation_prompt = false` and + /// `continue_final_message = true`. + ContinueFinalAssistant, + /// Render the existing chat history without adding any trailing generation + /// prompt. + /// + /// Equivalent to `add_generation_prompt = false` and + /// `continue_final_message = false`. + NoGenerationPrompt, +} + +/// Effort level for reasoning models. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ReasoningEffort { + None, + Minimal, + Low, + Medium, + High, + XHigh, + Max, +} + +impl ReasoningEffort { + pub fn as_str(self) -> &'static str { + match self { + Self::None => "none", + Self::Minimal => "minimal", + Self::Low => "low", + Self::Medium => "medium", + Self::High => "high", + Self::XHigh => "xhigh", + Self::Max => "max", + } + } +} + +/// Chat-template-related request options. +/// +/// These are the small subset of chat controls that currently affect prompt +/// rendering in `vllm-chat`. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ChatOptions { + /// Controls whether rendering starts a new assistant turn, continues the + /// final assistant message, or emits no trailing generation prompt at + /// all. + pub generation_prompt_mode: GenerationPromptMode, + + /// Per-request Jinja chat template override. When set, this template is + /// used instead of the model's default chat template. + pub chat_template: Option, + + /// Effort level exposed to chat templates for reasoning models. + pub reasoning_effort: Option, + + /// Additional keyword arguments exposed to the chat template. + pub template_kwargs: HashMap, +} + +impl Default for ChatOptions { + fn default() -> Self { + Self { + generation_prompt_mode: GenerationPromptMode::StartNewAssistant, + chat_template: None, + reasoning_effort: None, + template_kwargs: HashMap::new(), + } + } +} + +impl ChatOptions { + /// Whether to add a generation prompt for a new assistant turn after the + /// existing chat history. + pub fn add_generation_prompt(&self) -> bool { + matches!( + self.generation_prompt_mode, + GenerationPromptMode::StartNewAssistant + ) + } + + /// Whether to leave the final assistant message open so generation + /// continues it. + pub fn continue_final_message(&self) -> bool { + matches!( + self.generation_prompt_mode, + GenerationPromptMode::ContinueFinalAssistant + ) + } +} + +/// Tool-choice semantics supported by `vllm-chat`. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ChatToolChoice { + Auto, + #[default] + None, +} + +/// One chat request ready to be rendered into a prompt and lowered into a +/// generate request. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ChatRequest { + /// Stable caller-supplied request ID. + pub request_id: String, + /// Ordered chat history to render. + pub messages: Vec, + /// User-facing sampling parameters accepted by `vllm-chat`. + pub sampling_params: SamplingParams, + /// Chat-specific rendering options. + pub chat_options: ChatOptions, + /// Function tools made available to the model for this request. + pub tools: Vec, + /// Tool-choice behavior for this request. + pub tool_choice: ChatToolChoice, + /// Text decode options for incremental detokenization. + pub decode_options: TextDecodeOptions, + /// Whether to emit intermediate northbound content deltas before the + /// terminal result. + /// + /// If `false`, callers only observe the terminal accumulated assistant + /// output. If `true`, callers may receive zero or more incremental + /// content events before the final terminal one. + pub intermediate: bool, + /// Request scheduling priority (lower means earlier handling; default 0). + pub priority: i32, + /// Documents for RAG (retrieval-augmented generation), passed to the chat + /// template. + pub documents: Option>, + /// Salt for prefix cache isolation in multi-user environments. + pub cache_salt: Option, + /// Whether to add special tokens (e.g. BOS) during prompt tokenization. + pub add_special_tokens: bool, + /// Override data parallel rank. + #[serde(default)] + pub data_parallel_rank: Option, +} + +impl ChatRequest { + /// Return one minimal valid request fixture for tests. + pub fn for_test() -> Self { + Self { + request_id: "test-request".to_string(), + messages: vec![ChatMessage::text(ChatRole::User, "test")], + sampling_params: SamplingParams::default(), + chat_options: ChatOptions::default(), + tools: Vec::new(), + tool_choice: ChatToolChoice::None, + decode_options: TextDecodeOptions::default(), + intermediate: true, + priority: 0, + documents: None, + cache_salt: None, + add_special_tokens: false, + data_parallel_rank: None, + } + } + + /// Validate basic request invariants before rendering. + pub fn validate(&self) -> Result<()> { + if self.messages.is_empty() { + return Err(Error::EmptyMessages); + } + match ( + self.chat_options.generation_prompt_mode, + self.messages.last().map(ChatMessage::role), + ) { + (GenerationPromptMode::ContinueFinalAssistant, Some(ChatRole::Assistant)) => {} + (GenerationPromptMode::ContinueFinalAssistant, _) => { + return Err(Error::ContinueFinalAssistantWithoutFinalAssistant); + } + (GenerationPromptMode::NoGenerationPrompt, _) + | (GenerationPromptMode::StartNewAssistant, _) => {} + } + Ok(()) + } + + /// Return true if this request contains any multimodal content in its + /// messages. + pub fn has_multimodal(&self) -> bool { + self.messages.iter().any(ChatMessage::has_multimodal) + } + + /// Return true if this request should enable tool parsing based on the tool + /// choice and tool list. + pub(crate) fn tool_parsing_enabled(&self) -> bool { + matches!(self.tool_choice, ChatToolChoice::Auto) && !self.tools.is_empty() + } + + /// Return the request-level thinking toggle when explicitly requested. + /// + /// We currently accept the two request kwargs `thinking` and + /// `enable_thinking`. Both must be booleans when present. If both are + /// present, they must have the same value. If neither key is provided, + /// return `None`. + pub(crate) fn enable_thinking(&self) -> Result> { + let thinking = self.parse_template_bool("thinking")?; + let enable_thinking = self.parse_template_bool("enable_thinking")?; + + match (thinking, enable_thinking) { + (None, None) => Ok(None), + (Some(thinking), Some(enable_thinking)) if thinking != enable_thinking => { + Err(Error::ChatTemplate( + "template kwargs `thinking` and `enable_thinking` must match when both are set" + .to_string(), + )) + } + (Some(thinking), _) => Ok(Some(thinking)), + (None, Some(enable_thinking)) => Ok(Some(enable_thinking)), + } + } + + pub(crate) fn parse_template_bool(&self, key: &str) -> Result> { + match self.chat_options.template_kwargs.get(key) { + None => Ok(None), + Some(Value::Bool(value)) => Ok(Some(*value)), + Some(other) => Err(Error::ChatTemplate(format!( + "template kwarg `{key}` must be a boolean, got {other}" + ))), + } + } +} + +impl ChatRole { + /// Return the chat-template role string used by the current text-only chat + /// backend. + pub fn as_str(&self) -> &'static str { + match self { + Self::System => "system", + Self::Developer => "developer", + Self::User => "user", + Self::Assistant => "assistant", + Self::ToolResponse => "tool_response", + } + } +} + +#[cfg(test)] +mod tests { + use serde_json::{json, to_value}; + + use super::{ChatContent, ChatContentPart, ChatMessage, ChatRequest, ChatRole, ChatTool}; + use crate::Error; + use crate::event::AssistantContentBlock; + + #[test] + fn chat_content_deserializes_from_raw_string() { + let content: ChatContent = serde_json::from_value(json!("hello")).unwrap(); + assert_eq!(content, ChatContent::Text("hello".to_string())); + } + + #[test] + fn chat_content_deserializes_from_openai_text_blocks() { + let content: ChatContent = + serde_json::from_value(json!([{ "type": "text", "text": "hello" }])).unwrap(); + assert_eq!( + content, + ChatContent::Parts(vec![ChatContentPart::text("hello")]) + ); + } + + #[test] + fn chat_content_from_string_like_values_builds_text() { + assert_eq!( + ChatContent::from("hello"), + ChatContent::Text("hello".to_string()) + ); + assert_eq!( + ChatContent::from("hello".to_string()), + ChatContent::Text("hello".to_string()) + ); + } + + #[test] + fn chat_content_try_flattens_text_parts_without_separators() { + let content = ChatContent::Parts(vec![ + ChatContentPart::text("hello"), + ChatContentPart::text(" world"), + ]); + assert_eq!(content.try_flatten_to_text().unwrap(), "hello world"); + } + + #[test] + fn assistant_message_collects_visible_and_reasoning_text() { + let message = ChatMessage::assistant_blocks(vec![ + AssistantContentBlock::Reasoning { + text: "inner".to_string(), + }, + AssistantContentBlock::Text { + text: "outer".to_string(), + }, + ]); + + assert_eq!(message.role(), ChatRole::Assistant); + assert_eq!(message.text_content().unwrap(), "outer"); + assert_eq!(message.reasoning_content().as_deref(), Some("inner")); + } + + #[test] + fn developer_message_round_trips_through_serde() { + let message = ChatMessage::developer( + "hello", + Some(vec![ChatTool { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + }), + strict: Some(true), + }]), + ); + + let value = to_value(&message).unwrap(); + let decoded: ChatMessage = serde_json::from_value(value).unwrap(); + assert_eq!(decoded, message); + } + + #[test] + fn enable_thinking_is_none_when_no_kwargs_are_present() { + let request = ChatRequest::for_test(); + assert_eq!(request.enable_thinking().unwrap(), None); + } + + #[test] + fn enable_thinking_accepts_matching_duplicate_kwargs() { + let mut request = ChatRequest::for_test(); + request.chat_options.template_kwargs.insert("thinking".to_string(), json!(true)); + request + .chat_options + .template_kwargs + .insert("enable_thinking".to_string(), json!(true)); + + assert_eq!(request.enable_thinking().unwrap(), Some(true)); + } + + #[test] + fn enable_thinking_rejects_non_boolean_kwargs() { + let mut request = ChatRequest::for_test(); + request + .chat_options + .template_kwargs + .insert("thinking".to_string(), json!("yes")); + + assert!(matches!( + request.enable_thinking(), + Err(Error::ChatTemplate(message)) + if message.contains("`thinking` must be a boolean") + )); + } + + #[test] + fn enable_thinking_rejects_conflicting_duplicate_kwargs() { + let mut request = ChatRequest::for_test(); + request + .chat_options + .template_kwargs + .insert("thinking".to_string(), json!(false)); + request + .chat_options + .template_kwargs + .insert("enable_thinking".to_string(), json!(true)); + + assert!(matches!( + request.enable_thinking(), + Err(Error::ChatTemplate(message)) + if message.contains("`thinking` and `enable_thinking` must match") + )); + } +} diff --git a/rust/src/chat/src/stream.rs b/rust/src/chat/src/stream.rs new file mode 100644 index 00000000000..8a8dea46e6c --- /dev/null +++ b/rust/src/chat/src/stream.rs @@ -0,0 +1,237 @@ +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use futures::Stream; +use trait_set::trait_set; +use vllm_text::{DecodedLogprobs, DecodedPositionLogprobs, DecodedPromptLogprobs}; + +use crate::FinishReason; +use crate::error::{Error, Result}; +use crate::event::{AssistantContentBlock, AssistantMessage, ChatEvent}; + +/// Final structured assistant message plus terminal stream metadata. +#[derive(Debug, Clone, PartialEq)] +pub struct CollectedAssistantMessage { + pub message: AssistantMessage, + pub prompt_token_count: usize, + pub prompt_token_ids: Arc<[u32]>, + pub prompt_logprobs: Option, + pub logprobs: Option, + pub token_ids: Vec, + pub output_token_count: usize, + pub finish_reason: FinishReason, + /// Connector-specific KV transfer parameters for disaggregated serving. + pub kv_transfer_params: Option, +} + +/// Per-request stream of chat events. +pub struct ChatEventStream { + request_id: String, + inner: Pin> + Send>>, +} + +impl ChatEventStream { + pub(crate) fn new(request_id: String, inner: impl crate::output::ChatEventStream) -> Self { + Self { + request_id, + inner: Box::pin(inner), + } + } + + /// Return the request ID associated with this stream. + pub fn request_id(&self) -> &str { + &self.request_id + } + + /// Collect the stream to completion and return the final assembled + /// assistant message. + pub async fn collect_message(mut self) -> Result { + use futures::StreamExt as _; + + let mut message = AssistantMessage::default(); + let mut prompt_logprobs = None; + let mut prompt_token_ids: Arc<[u32]> = Arc::from([]); + let mut logprob_positions: Vec = Vec::new(); + let mut token_ids: Vec = Vec::new(); + while let Some(event) = self.next().await.transpose()? { + match event { + ChatEvent::Start { + prompt_logprobs: start_prompt_logprobs, + prompt_token_ids: start_prompt_token_ids, + } => { + prompt_logprobs = start_prompt_logprobs; + prompt_token_ids = start_prompt_token_ids; + } + ChatEvent::BlockEnd { block, .. } => message.push_block(block), + ChatEvent::LogprobsDelta { + logprobs, + token_ids: delta_ids, + } => { + if let Some(logprobs) = logprobs { + logprob_positions.extend(logprobs.positions); + } + token_ids.extend(delta_ids); + } + ChatEvent::Done { + message: done, + prompt_token_count, + output_token_count, + finish_reason, + kv_transfer_params, + } => { + return Ok(CollectedAssistantMessage { + message: done, + prompt_token_count, + prompt_token_ids, + prompt_logprobs, + logprobs: (!logprob_positions.is_empty()).then_some(DecodedLogprobs { + positions: logprob_positions, + }), + token_ids, + output_token_count, + finish_reason, + kv_transfer_params, + }); + } + ChatEvent::ToolCallEnd { call, .. } => { + message.push_block(AssistantContentBlock::ToolCall(call)); + } + ChatEvent::BlockStart { .. } + | ChatEvent::BlockDelta { .. } + | ChatEvent::ToolCallStart { .. } + | ChatEvent::ToolCallArgumentsDelta { .. } => {} + } + } + + // Note: this is actually unreachable, as the underlying stream always emit an + // error on unexpected close. + Err(Error::StreamClosedBeforeTerminalOutput { + request_id: self.request_id, + }) + } +} + +impl Stream for ChatEventStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_next(cx) + } +} + +trait_set! { + pub trait ChatEventStreamTrait = Stream> + Send + 'static; +} + +#[cfg(test)] +mod tests { + + use futures::stream; + use vllm_llm::FinishReason; + use vllm_text::{ + DecodedLogprobs, DecodedPositionLogprobs, DecodedPromptLogprobs, DecodedTokenLogprob, + }; + + use super::{ChatEventStream, CollectedAssistantMessage}; + use crate::error::Error; + use crate::event::ChatEvent; + + #[tokio::test] + async fn collect_message_requires_terminal_done_event() { + let stream = ChatEventStream::new( + "chat-missing-done".to_string(), + stream::iter([Ok(ChatEvent::Start { + prompt_token_ids: vec![].into(), + prompt_logprobs: None, + })]), + ); + + let error = stream.collect_message().await.expect_err("missing done"); + assert!(matches!( + error, + Error::StreamClosedBeforeTerminalOutput { request_id } + if request_id == "chat-missing-done" + )); + } + + #[tokio::test] + async fn collect_message_retains_prompt_and_sample_logprobs() { + let stream = ChatEventStream::new( + "chat-logprobs".to_string(), + stream::iter(vec![ + Ok(ChatEvent::Start { + prompt_token_ids: vec![10, 11].into(), + prompt_logprobs: Some(DecodedPromptLogprobs { + first_token_id: 0, + first_token: "o".to_string(), + scored_positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "p".to_string(), + logprob: -0.1, + rank: 1, + }], + }], + }), + }), + Ok(ChatEvent::LogprobsDelta { + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "a".to_string(), + logprob: -0.2, + rank: 1, + }], + }], + }), + token_ids: vec![], + }), + Ok(ChatEvent::Done { + message: Default::default(), + prompt_token_count: 2, + output_token_count: 1, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]), + ); + + let collected = stream.collect_message().await.unwrap(); + assert_eq!( + collected, + CollectedAssistantMessage { + message: Default::default(), + prompt_token_count: 2, + prompt_token_ids: vec![10, 11].into(), + prompt_logprobs: Some(DecodedPromptLogprobs { + first_token_id: 0, + first_token: "o".to_string(), + scored_positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "p".to_string(), + logprob: -0.1, + rank: 1, + }], + }], + }), + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "a".to_string(), + logprob: -0.2, + rank: 1, + }], + }], + }), + token_ids: vec![], + output_token_count: 1, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + } + ); + } +} diff --git a/rust/src/chat/tests/chat.rs b/rust/src/chat/tests/chat.rs new file mode 100644 index 00000000000..7c423561c85 --- /dev/null +++ b/rust/src/chat/tests/chat.rs @@ -0,0 +1,1671 @@ +use std::collections::BTreeSet; +use std::fmt; +use std::sync::Arc; +use std::time::Duration; + +use futures::StreamExt as _; +use tokio::time::timeout; +use vllm_chat::{ + AssistantBlockKind, AssistantContentBlock, AssistantMessageExt as _, ChatBackend, ChatEvent, + ChatLlm, ChatMessage, ChatRenderer, ChatRequest, ChatRole, ChatTextBackend, ChatTool, + ChatToolChoice, DefaultChatOutputProcessor, DynChatOutputProcessor, DynChatRenderer, + FinishReason, GenerationPromptMode, NewChatOutputProcessorOptions, ParserSelection, + RenderedPrompt, SamplingParams, +}; +use vllm_engine_core_client::protocol::logprobs::{ + Logprobs, MaybeWireLogprobs, PositionLogprobs, TokenLogprob, +}; +use vllm_engine_core_client::protocol::{ + EngineCoreFinishReason, EngineCoreOutput, EngineCoreOutputs, EngineCoreRequest, StopReason, +}; +use vllm_engine_core_client::test_utils::{IpcNamespace, spawn_mock_engine_task}; +use vllm_engine_core_client::{EngineCoreClient, EngineCoreClientConfig}; +use vllm_llm::Llm; +use vllm_text::tokenizer::{DynTokenizer, Tokenizer}; +use vllm_text::{ + DecodedLogprobs, DecodedPositionLogprobs, DecodedPromptLogprobs, DecodedTokenLogprob, Prompt, + TextBackend, +}; +use zeromq::prelude::{SocketRecv, SocketSend}; +use zeromq::{DealerSocket, PushSocket, ZmqMessage}; + +const SPECIAL_STOP_TOKEN_ID: u32 = 256; + +fn request_output( + request_id: &str, + new_token_ids: Vec, + finish_reason: Option, + stop_reason: Option, +) -> EngineCoreOutput { + EngineCoreOutput { + request_id: request_id.to_string(), + new_token_ids, + new_logprobs: None, + new_prompt_logprobs_tensors: None, + pooling_output: None, + finish_reason, + stop_reason, + events: None, + kv_transfer_params: None, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + } +} + +fn request_output_with_logprobs( + request_id: &str, + new_token_ids: Vec, + finish_reason: Option, + stop_reason: Option, + new_logprobs: Option, + new_prompt_logprobs_tensors: Option, +) -> EngineCoreOutput { + EngineCoreOutput { + request_id: request_id.to_string(), + new_token_ids, + new_logprobs: new_logprobs.map(MaybeWireLogprobs::Direct), + new_prompt_logprobs_tensors: new_prompt_logprobs_tensors.map(MaybeWireLogprobs::Direct), + pooling_output: None, + finish_reason, + stop_reason, + events: None, + kv_transfer_params: None, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + } +} + +fn sample_logprobs_for_token(token_id: u32, alternate_token_id: u32) -> Logprobs { + Logprobs { + positions: vec![PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id, + logprob: -0.1, + rank: 1, + }, + TokenLogprob { + token_id: alternate_token_id, + logprob: -0.2, + rank: 2, + }, + ], + }], + } +} + +fn prompt_logprobs_for_hi() -> Logprobs { + Logprobs { + positions: vec![PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id: b'i' as u32, + logprob: -0.3, + rank: 1, + }, + TokenLogprob { + token_id: b'!' as u32, + logprob: -0.4, + rank: 2, + }, + ], + }], + } +} + +fn bytes_to_token_ids(bytes: &[u8]) -> Vec { + bytes.iter().map(|byte| u32::from(*byte)).collect() +} + +fn bytes_with_special_stop_token(bytes: &[u8]) -> Vec { + let mut token_ids = bytes_to_token_ids(bytes); + token_ids.push(SPECIAL_STOP_TOKEN_ID); + token_ids +} + +async fn send_outputs(push: &mut PushSocket, outputs: EngineCoreOutputs) { + push.send(ZmqMessage::from(rmp_serde::to_vec_named(&outputs).unwrap())) + .await + .unwrap(); +} + +async fn recv_engine_message(dealer: &mut DealerSocket) -> Vec { + dealer.recv().await.unwrap().into_vec() +} + +async fn connect_chat_llm_with_ipc( + config: EngineCoreClientConfig, + ipc: &IpcNamespace, + backend: Arc, +) -> ChatLlm { + let client = EngineCoreClient::connect(config.with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + )) + .await + .unwrap(); + ChatLlm::from_shared_backend( + Llm::new(client).with_request_id_randomization(false), + backend, + ) +} + +#[derive(Clone)] +struct FakeChatBackend { + has_template: bool, + model_id: String, +} + +#[derive(Debug)] +struct FakeChatTokenizer; + +impl Tokenizer for FakeChatTokenizer { + fn encode(&self, text: &str, _add_special_tokens: bool) -> vllm_tokenizer::Result> { + Ok(text.bytes().map(u32::from).collect()) + } + + fn decode( + &self, + token_ids: &[u32], + skip_special_tokens: bool, + ) -> vllm_tokenizer::Result { + let bytes = token_ids + .iter() + .filter_map(|id| { + if skip_special_tokens && *id == SPECIAL_STOP_TOKEN_ID { + None + } else { + Some(*id as u8) + } + }) + .collect::>(); + Ok(String::from_utf8_lossy(&bytes).into_owned()) + } + + fn token_to_id(&self, token: &str) -> Option { + match token { + "" => Some(0xF001), + "" => Some(0xF002), + "<|START_THINKING|>" => Some(0xF003), + "<|END_THINKING|>" => Some(0xF004), + "◁think▷" => Some(0xF005), + "◁/think▷" => Some(0xF006), + _ => None, + } + } +} + +impl fmt::Debug for FakeChatBackend { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FakeChatBackend").finish_non_exhaustive() + } +} + +impl FakeChatBackend { + fn new() -> Self { + Self { + has_template: true, + model_id: "test-model".to_string(), + } + } + + fn without_template() -> Self { + Self { + has_template: false, + model_id: "test-model".to_string(), + } + } + + fn with_model_id(model_id: impl Into) -> Self { + Self { + has_template: true, + model_id: model_id.into(), + } + } +} + +impl TextBackend for FakeChatBackend { + fn tokenizer(&self) -> DynTokenizer { + Arc::new(FakeChatTokenizer) + } + + fn model_id(&self) -> &str { + &self.model_id + } +} + +impl ChatBackend for FakeChatBackend { + fn chat_renderer(&self) -> DynChatRenderer { + Arc::new(self.clone()) + } + + fn new_chat_output_processor( + &self, + request: &mut ChatRequest, + options: NewChatOutputProcessorOptions<'_>, + ) -> vllm_chat::Result { + Ok(Box::new(DefaultChatOutputProcessor::new( + request, + &self.model_id, + self.tokenizer(), + options.tool_call_parser, + options.reasoning_parser, + )?)) + } +} + +impl ChatRenderer for FakeChatBackend { + fn render(&self, request: &ChatRequest) -> vllm_chat::Result { + if !self.has_template { + return Err(vllm_chat::Error::MissingChatTemplate); + } + + let mut prompt = String::new(); + for message in &request.messages { + prompt.push_str(message.role().as_str()); + prompt.push_str(": "); + prompt.push_str(&message.text_content()?); + prompt.push('\n'); + } + if request.chat_options.add_generation_prompt() { + prompt.push_str("assistant:"); + } + + Ok(RenderedPrompt { + prompt: Prompt::Text(prompt), + }) + } +} + +#[derive(Clone, Debug)] +struct FailingDecodeBackend { + inner: FakeChatBackend, +} + +#[derive(Debug)] +struct FailingDecodeTokenizer; + +impl Tokenizer for FailingDecodeTokenizer { + fn encode(&self, text: &str, add_special_tokens: bool) -> vllm_tokenizer::Result> { + FakeChatTokenizer.encode(text, add_special_tokens) + } + + fn decode( + &self, + token_ids: &[u32], + skip_special_tokens: bool, + ) -> vllm_tokenizer::Result { + if token_ids.contains(&(b'i' as u32)) { + return Err(vllm_tokenizer::TokenizerError("decode failed".to_string())); + } + FakeChatTokenizer.decode(token_ids, skip_special_tokens) + } + + fn token_to_id(&self, token: &str) -> Option { + FakeChatTokenizer.token_to_id(token) + } +} + +impl TextBackend for FailingDecodeBackend { + fn tokenizer(&self) -> DynTokenizer { + Arc::new(FailingDecodeTokenizer) + } + + fn model_id(&self) -> &str { + self.inner.model_id() + } +} + +impl ChatBackend for FailingDecodeBackend { + fn chat_renderer(&self) -> DynChatRenderer { + Arc::new(self.clone()) + } + + fn new_chat_output_processor( + &self, + _request: &mut ChatRequest, + _options: NewChatOutputProcessorOptions<'_>, + ) -> vllm_chat::Result { + Ok(Box::new(DefaultChatOutputProcessor::plain_text_only())) + } +} + +impl ChatRenderer for FailingDecodeBackend { + fn render(&self, request: &ChatRequest) -> vllm_chat::Result { + self.inner.render(request) + } +} + +/// Skip `LogprobsDelta` events that carry only token_ids (no logprobs), +/// returning the next semantically interesting event. +async fn next_semantic(stream: &mut S) -> Option> +where + S: futures::Stream> + Unpin, +{ + loop { + match stream.next().await { + Some(Ok(ChatEvent::LogprobsDelta { logprobs: None, .. })) => continue, + other => return other, + } + } +} + +fn sample_request(request_id: &str) -> ChatRequest { + ChatRequest { + messages: vec![ + ChatMessage::text(ChatRole::System, "You are terse."), + ChatMessage::text(ChatRole::User, "Say hi"), + ], + sampling_params: SamplingParams { + max_tokens: Some(8), + ..Default::default() + }, + request_id: request_id.to_string(), + ..ChatRequest::for_test() + } +} + +fn sample_tool_request(request_id: &str) -> ChatRequest { + let mut request = sample_request(request_id); + request.tools = vec![ChatTool { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: serde_json::json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }), + strict: None, + }]; + request.tool_choice = ChatToolChoice::Auto; + request +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_streams_text_events() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add = recv_engine_message(dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + assert_eq!(request.request_id, "chat-1"); + // more fields here in the future + assert_eq!( + String::from_utf8( + request + .prompt_token_ids + .clone() + .unwrap() + .into_iter() + .map(|id| id as u8) + .collect() + ) + .unwrap(), + "system: You are terse.\nuser: Say hi\nassistant:" + ); + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![ + request_output("chat-1", vec![b'H' as u32], None, None), + request_output( + "chat-1", + vec![b'i' as u32, b'!' as u32], + Some(EngineCoreFinishReason::Stop), + Some(StopReason::TokenId(b'!' as u32)), + ), + ], + finished_requests: Some(BTreeSet::from(["chat-1".to_string()])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let backend: Arc = Arc::new(FakeChatBackend::new()); + let chat = connect_chat_llm_with_ipc( + EngineCoreClientConfig::new_single(handshake_address).with_model_name("test-model"), + &ipc, + backend, + ) + .await; + + let mut stream = chat.chat(sample_request("chat-1")).await.unwrap(); + + match next_semantic(&mut stream).await.unwrap().unwrap() { + ChatEvent::Start { + prompt_token_ids, + prompt_logprobs: None, + } => { + assert_eq!( + prompt_token_ids.len(), + "system: You are terse.\nuser: Say hi\nassistant:".len() + ); + assert!(!prompt_token_ids.is_empty()); + } + other => panic!("expected Start, got {other:?}"), + } + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockStart { + index: 0, + kind: AssistantBlockKind::Text, + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockDelta { + index: 0, + kind: AssistantBlockKind::Text, + delta: "H".to_string(), + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockDelta { + index: 0, + kind: AssistantBlockKind::Text, + delta: "i".to_string(), + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockEnd { + index: 0, + block: AssistantContentBlock::Text { + text: "Hi".to_string(), + }, + } + ); + + match next_semantic(&mut stream).await { + Some(Ok(ChatEvent::Done { + message, + output_token_count, + finish_reason, + .. + })) => { + assert_eq!(message.text(), "Hi"); + assert_eq!(output_token_count, 3); + assert_eq!( + finish_reason, + FinishReason::Stop(Some(StopReason::TokenId(b'!' as u32))) + ); + } + other => panic!("unexpected final event: {other:?}"), + } + assert!(next_semantic(&mut stream).await.is_none()); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + chat.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_stream_waits_for_complete_utf8_before_emitting() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat-utf8".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let _ = recv_engine_message(dealer).await; + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![ + request_output("chat-utf8", bytes_to_token_ids(&[0xe4]), None, None), + request_output( + "chat-utf8", + bytes_to_token_ids(&[0xbd, 0xa0, b'!']), + Some(EngineCoreFinishReason::Stop), + Some(StopReason::TokenId(b'!' as u32)), + ), + ], + finished_requests: Some(BTreeSet::from(["chat-utf8".to_string()])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let backend: Arc = Arc::new(FakeChatBackend::new()); + let chat = connect_chat_llm_with_ipc( + EngineCoreClientConfig::new_single(handshake_address), + &ipc, + backend, + ) + .await; + + let mut stream = chat.chat(sample_request("chat-utf8")).await.unwrap(); + + assert!(matches!( + next_semantic(&mut stream).await, + Some(Ok(ChatEvent::Start { + prompt_logprobs: None, + .. + })) + )); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockStart { + index: 0, + kind: AssistantBlockKind::Text, + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockDelta { + index: 0, + kind: AssistantBlockKind::Text, + delta: "你".to_string(), + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockEnd { + index: 0, + block: AssistantContentBlock::Text { + text: "你".to_string(), + }, + } + ); + + match next_semantic(&mut stream).await { + Some(Ok(ChatEvent::Done { + message, + output_token_count, + .. + })) => { + assert_eq!(message.text(), "你"); + assert_eq!(output_token_count, 4); + } + other => panic!("unexpected final event: {other:?}"), + } + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + chat.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_stream_flushes_held_text_on_finish() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat-final-flush".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let _ = recv_engine_message(dealer).await; + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![request_output( + "chat-final-flush", + bytes_to_token_ids(b"ok st"), + Some(EngineCoreFinishReason::Length), + None, + )], + finished_requests: Some(BTreeSet::from(["chat-final-flush".to_string()])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let backend: Arc = Arc::new(FakeChatBackend::new()); + let chat = connect_chat_llm_with_ipc( + EngineCoreClientConfig::new_single(handshake_address), + &ipc, + backend, + ) + .await; + + let mut stream = chat.chat(sample_request("chat-final-flush")).await.unwrap(); + + assert!(matches!( + next_semantic(&mut stream).await, + Some(Ok(ChatEvent::Start { + prompt_logprobs: None, + .. + })) + )); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockStart { + index: 0, + kind: AssistantBlockKind::Text, + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockDelta { + index: 0, + kind: AssistantBlockKind::Text, + delta: "ok st".to_string(), + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockEnd { + index: 0, + block: AssistantContentBlock::Text { + text: "ok st".to_string(), + }, + } + ); + + match next_semantic(&mut stream).await { + Some(Ok(ChatEvent::Done { + message, + output_token_count, + finish_reason, + .. + })) => { + assert_eq!(message.text(), "ok st"); + assert_eq!(output_token_count, 5); + assert_eq!(finish_reason, FinishReason::Length); + } + other => panic!("unexpected final event: {other:?}"), + } + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + chat.shutdown().await.unwrap(); +} + +#[test] +fn chat_request_rejects_conflicting_generation_modes() { + let mut request = sample_request("chat-2"); + request.chat_options.generation_prompt_mode = GenerationPromptMode::ContinueFinalAssistant; + let error = request.validate().unwrap_err(); + + assert!(matches!( + error, + vllm_chat::Error::ContinueFinalAssistantWithoutFinalAssistant + )); +} + +#[test] +fn chat_request_accepts_continue_final_assistant_mode_with_final_assistant() { + let mut request = sample_request("chat-2b"); + request.messages = vec![ChatMessage::assistant_text("hello")]; + request.chat_options.generation_prompt_mode = GenerationPromptMode::ContinueFinalAssistant; + + request.validate().unwrap(); +} + +#[test] +fn backend_requires_a_template() { + let request = sample_request("chat-3"); + let backend = FakeChatBackend::without_template(); + let error = backend.chat_renderer().render(&request).unwrap_err(); + assert!(matches!(error, vllm_chat::Error::MissingChatTemplate)); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_stream_reports_decode_failure_as_error_event() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat-decode-fail".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let _ = recv_engine_message(dealer).await; + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![request_output("chat-4", vec![b'i' as u32], None, None)], + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let backend: Arc = Arc::new(FailingDecodeBackend { + inner: FakeChatBackend::new(), + }); + let chat = connect_chat_llm_with_ipc( + EngineCoreClientConfig::new_single(handshake_address), + &ipc, + backend, + ) + .await; + + let mut stream = chat.chat(sample_request("chat-4")).await.unwrap(); + assert_eq!(stream.request_id(), "chat-4"); + assert!(matches!( + next_semantic(&mut stream).await, + Some(Ok(ChatEvent::Start { + prompt_logprobs: None, + .. + })) + )); + + match timeout(Duration::from_secs(2), stream.next()).await.unwrap() { + Some(Err(vllm_chat::Error::Text(vllm_text::Error::Tokenizer(message)))) => { + assert_eq!(message, "decode failed"); + } + other => panic!("unexpected event after close: {other:?}"), + } + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + chat.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_stream_preserves_terminal_stop_token_when_requested() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat-include-stop".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let _ = recv_engine_message(dealer).await; + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![request_output( + "chat-include-stop", + vec![b'H' as u32, b'i' as u32, b'!' as u32], + Some(EngineCoreFinishReason::Stop), + Some(StopReason::TokenId(b'!' as u32)), + )], + finished_requests: Some(BTreeSet::from(["chat-include-stop".to_string()])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let backend: Arc = Arc::new(FakeChatBackend::new()); + let chat = connect_chat_llm_with_ipc( + EngineCoreClientConfig::new_single(handshake_address), + &ipc, + backend, + ) + .await; + + let mut request = sample_request("chat-include-stop"); + request.decode_options.include_stop_str_in_output = true; + let mut stream = chat.chat(request).await.unwrap(); + + assert!(matches!( + next_semantic(&mut stream).await, + Some(Ok(ChatEvent::Start { + prompt_logprobs: None, + .. + })) + )); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockStart { + index: 0, + kind: AssistantBlockKind::Text, + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockDelta { + index: 0, + kind: AssistantBlockKind::Text, + delta: "Hi!".to_string(), + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockEnd { + index: 0, + block: AssistantContentBlock::Text { + text: "Hi!".to_string(), + }, + } + ); + + match next_semantic(&mut stream).await { + Some(Ok(ChatEvent::Done { + message, + output_token_count, + .. + })) => { + assert_eq!(message.text(), "Hi!"); + assert_eq!(output_token_count, 3); + } + other => panic!("unexpected final event: {other:?}"), + } + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + chat.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_stream_separates_reasoning_blocks_automatically() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat-reasoning".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let _ = recv_engine_message(dealer).await; + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![ + request_output( + "chat-reasoning", + bytes_to_token_ids(b""), + None, + None, + ), + request_output( + "chat-reasoning", + bytes_to_token_ids(b"reason "), + None, + None, + ), + request_output( + "chat-reasoning", + bytes_to_token_ids(b"more"), + None, + None, + ), + request_output( + "chat-reasoning", + bytes_to_token_ids(b"answer"), + Some(EngineCoreFinishReason::Length), + None, + ), + ], + finished_requests: Some(BTreeSet::from(["chat-reasoning".to_string()])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let backend: Arc = + Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")); + let chat = connect_chat_llm_with_ipc( + EngineCoreClientConfig::new_single(handshake_address), + &ipc, + backend, + ) + .await; + + let mut stream = chat.chat(sample_request("chat-reasoning")).await.unwrap(); + + assert!(matches!( + next_semantic(&mut stream).await, + Some(Ok(ChatEvent::Start { + prompt_logprobs: None, + .. + })) + )); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockStart { + index: 0, + kind: AssistantBlockKind::Reasoning, + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockDelta { + index: 0, + kind: AssistantBlockKind::Reasoning, + delta: "reason ".to_string(), + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockDelta { + index: 0, + kind: AssistantBlockKind::Reasoning, + delta: "more".to_string(), + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockEnd { + index: 0, + block: AssistantContentBlock::Reasoning { + text: "reason more".to_string(), + }, + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockStart { + index: 1, + kind: AssistantBlockKind::Text, + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockDelta { + index: 1, + kind: AssistantBlockKind::Text, + delta: "answer".to_string(), + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockEnd { + index: 1, + block: AssistantContentBlock::Text { + text: "answer".to_string(), + }, + } + ); + + match next_semantic(&mut stream).await { + Some(Ok(ChatEvent::Done { + message, + finish_reason, + .. + })) => { + assert_eq!(message.reasoning().unwrap(), "reason more"); + assert_eq!(message.text(), "answer"); + assert_eq!(finish_reason, FinishReason::Length); + } + other => panic!("unexpected final event: {other:?}"), + } + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + chat.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_collectors_return_structured_message_and_visible_text() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat-collect".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let _ = recv_engine_message(dealer).await; + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![request_output( + "chat-collect", + bytes_to_token_ids(b"innerouter"), + Some(EngineCoreFinishReason::Length), + None, + )], + finished_requests: Some(BTreeSet::from(["chat-collect".to_string()])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let backend: Arc = + Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")); + let chat = connect_chat_llm_with_ipc( + EngineCoreClientConfig::new_single(handshake_address.clone()), + &ipc, + backend.clone(), + ) + .await; + + let message = chat + .chat(sample_request("chat-collect")) + .await + .unwrap() + .collect_message() + .await + .unwrap(); + assert_eq!(message.message.reasoning().unwrap(), "inner"); + assert_eq!(message.message.text(), "outer"); + assert_eq!(message.finish_reason, FinishReason::Length); + assert_eq!( + message.prompt_token_count, + "system: You are terse.\nuser: Say hi\nassistant:".len() + ); + assert_eq!( + message.output_token_count, + "innerouter".len() + ); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + chat.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_explicitly_disables_reasoning_parser() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat-reasoning-disabled".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let _ = recv_engine_message(dealer).await; + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![ + request_output( + "chat-reasoning-disabled", + bytes_to_token_ids(b""), + None, + None, + ), + request_output( + "chat-reasoning-disabled", + bytes_to_token_ids(b"reason "), + None, + None, + ), + request_output( + "chat-reasoning-disabled", + bytes_to_token_ids(b"more"), + None, + None, + ), + request_output( + "chat-reasoning-disabled", + bytes_to_token_ids(b"answer"), + Some(EngineCoreFinishReason::Length), + None, + ), + ], + finished_requests: Some(BTreeSet::from([ + "chat-reasoning-disabled".to_string() + ])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let backend: Arc = + Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")); + let chat = connect_chat_llm_with_ipc( + EngineCoreClientConfig::new_single(handshake_address), + &ipc, + backend, + ) + .await + .with_reasoning_parser(ParserSelection::None); + + let message = chat + .chat(sample_request("chat-reasoning-disabled")) + .await + .unwrap() + .collect_message() + .await + .unwrap(); + assert_eq!(message.message.reasoning(), None); + assert_eq!(message.message.text(), "reason moreanswer"); + assert_eq!(message.finish_reason, FinishReason::Length); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + chat.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_stream_parses_tool_calls_automatically() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat-tool".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let _ = recv_engine_message(dealer).await; + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![ + request_output( + "chat-tool", + bytes_to_token_ids(b"Need tool."), + None, + None, + ), + request_output( + "chat-tool", + bytes_to_token_ids(b"\n{\"name\":\"get_weather\", "), + None, + None, + ), + request_output( + "chat-tool", + bytes_to_token_ids( + b"\"arguments\":{\"city\":\"Paris\"}}\n", + ), + Some(EngineCoreFinishReason::Stop), + Some(StopReason::TokenId(SPECIAL_STOP_TOKEN_ID)), + ), + ], + finished_requests: Some(BTreeSet::from(["chat-tool".to_string()])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let backend: Arc = + Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")); + let chat = connect_chat_llm_with_ipc( + EngineCoreClientConfig::new_single(handshake_address), + &ipc, + backend, + ) + .await; + let mut stream = chat.chat(sample_tool_request("chat-tool")).await.unwrap(); + + let mut saw_tool_start = false; + let mut saw_tool_args = false; + let mut saw_tool_end = false; + + while let Some(event) = stream.next().await { + match event.unwrap() { + ChatEvent::Start { .. } => {} + ChatEvent::LogprobsDelta { .. } => {} + ChatEvent::ToolCallStart { name, .. } => { + saw_tool_start = true; + assert_eq!(name, "get_weather"); + } + ChatEvent::ToolCallArgumentsDelta { delta, .. } => { + saw_tool_args = true; + assert!(delta.contains("Paris"), "{delta}"); + } + ChatEvent::ToolCallEnd { call, .. } => { + saw_tool_end = true; + assert_eq!(call.name, "get_weather"); + assert_eq!(call.arguments, r#"{"city":"Paris"}"#); + } + ChatEvent::Done { + message, + finish_reason, + .. + } => { + assert_eq!( + finish_reason, + FinishReason::Stop(Some(StopReason::TokenId(SPECIAL_STOP_TOKEN_ID))) + ); + assert_eq!(message.text(), ""); + let tool_calls = message.tool_calls().collect::>(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].name, "get_weather"); + assert_eq!(tool_calls[0].arguments, r#"{"city":"Paris"}"#); + break; + } + ChatEvent::BlockStart { .. } + | ChatEvent::BlockDelta { .. } + | ChatEvent::BlockEnd { .. } => {} + } + } + + assert!(saw_tool_start); + assert!(saw_tool_args); + assert!(saw_tool_end); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + chat.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_collect_message_preserves_tool_call_arguments_in_final_only_mode() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat-final-only-tool".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let _ = recv_engine_message(dealer).await; + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![ + request_output( + "chat-final-only-tool", + bytes_to_token_ids(b"Need tool."), + None, + None, + ), + request_output( + "chat-final-only-tool", + bytes_to_token_ids(b"\n{\"name\":\"get_weather\", "), + None, + None, + ), + request_output( + "chat-final-only-tool", + bytes_with_special_stop_token( + b"\"arguments\":{\"city\":\"Paris\"}}\n", + ), + Some(EngineCoreFinishReason::Stop), + Some(StopReason::TokenId(SPECIAL_STOP_TOKEN_ID)), + ), + ], + finished_requests: Some(BTreeSet::from([ + "chat-final-only-tool".to_string() + ])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let backend: Arc = + Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")); + let chat = connect_chat_llm_with_ipc( + EngineCoreClientConfig::new_single(handshake_address), + &ipc, + backend, + ) + .await; + let mut request = sample_tool_request("chat-final-only-tool"); + request.intermediate = false; + + let message = chat.chat(request).await.unwrap().collect_message().await.unwrap(); + + assert_eq!( + message.finish_reason, + FinishReason::Stop(Some(StopReason::TokenId(SPECIAL_STOP_TOKEN_ID))) + ); + assert_eq!(message.message.tool_calls().count(), 1); + assert_eq!( + message.message.tool_calls().next().unwrap().arguments, + r#"{"city":"Paris"}"# + ); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + chat.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_stream_and_collect_preserve_prompt_and_sample_logprobs() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat-logprobs".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + for _ in 0..2 { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![ + request_output_with_logprobs( + &request.request_id, + vec![b'H' as u32], + None, + None, + Some(sample_logprobs_for_token(b'H' as u32, b'h' as u32)), + Some(prompt_logprobs_for_hi()), + ), + request_output_with_logprobs( + &request.request_id, + vec![b'i' as u32], + Some(EngineCoreFinishReason::Length), + None, + Some(sample_logprobs_for_token(b'i' as u32, b'I' as u32)), + None, + ), + ], + finished_requests: Some(BTreeSet::from([request.request_id])), + ..Default::default() + }, + ) + .await; + } + }) + }, + ); + + let backend: Arc = Arc::new(FakeChatBackend::new()); + let chat = connect_chat_llm_with_ipc( + EngineCoreClientConfig::new_single(handshake_address.clone()), + &ipc, + backend, + ) + .await; + + let mut request = sample_request("chat-logprobs"); + request.sampling_params.logprobs = Some(1); + request.sampling_params.prompt_logprobs = Some(1); + + let mut stream = chat.chat(request.clone()).await.unwrap(); + match next_semantic(&mut stream).await.unwrap().unwrap() { + ChatEvent::Start { + prompt_token_ids, + prompt_logprobs, + } => { + assert_eq!( + prompt_token_ids.len(), + "system: You are terse.\nuser: Say hi\nassistant:".len() + ); + assert!(!prompt_token_ids.is_empty()); + assert_eq!( + prompt_logprobs, + Some(DecodedPromptLogprobs { + first_token_id: b's' as u32, + first_token: "s".to_string(), + scored_positions: vec![DecodedPositionLogprobs { + entries: vec![ + DecodedTokenLogprob { + token_id: b'i' as u32, + token: "i".to_string(), + logprob: -0.3, + rank: 1, + }, + DecodedTokenLogprob { + token_id: b'!' as u32, + token: "!".to_string(), + logprob: -0.4, + rank: 1, + }, + ], + }], + }) + ); + } + other => panic!("expected Start, got {other:?}"), + } + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockStart { + index: 0, + kind: AssistantBlockKind::Text, + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::BlockDelta { + index: 0, + kind: AssistantBlockKind::Text, + delta: "H".to_string(), + } + ); + assert_eq!( + next_semantic(&mut stream).await.unwrap().unwrap(), + ChatEvent::LogprobsDelta { + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![ + DecodedTokenLogprob { + token_id: b'H' as u32, + token: "H".to_string(), + logprob: -0.1, + rank: 1, + }, + DecodedTokenLogprob { + token_id: b'h' as u32, + token: "h".to_string(), + logprob: -0.2, + rank: 1, + }, + ], + }], + }), + token_ids: vec![b'H' as u32], + } + ); + while !matches!( + next_semantic(&mut stream).await, + Some(Ok(ChatEvent::Done { .. })) + ) {} + + request.request_id = "chat-logprobs-collect".to_string(); + let collected = chat.chat(request).await.unwrap().collect_message().await.unwrap(); + assert_eq!(collected.message.text(), "Hi"); + assert_eq!( + collected.prompt_logprobs, + Some(DecodedPromptLogprobs { + first_token_id: b's' as u32, + first_token: "s".to_string(), + scored_positions: vec![DecodedPositionLogprobs { + entries: vec![ + DecodedTokenLogprob { + token_id: b'i' as u32, + token: "i".to_string(), + logprob: -0.3, + rank: 1, + }, + DecodedTokenLogprob { + token_id: b'!' as u32, + token: "!".to_string(), + logprob: -0.4, + rank: 1, + }, + ], + }], + }) + ); + assert_eq!( + collected.logprobs, + Some(DecodedLogprobs { + positions: vec![ + DecodedPositionLogprobs { + entries: vec![ + DecodedTokenLogprob { + token_id: b'H' as u32, + token: "H".to_string(), + logprob: -0.1, + rank: 1, + }, + DecodedTokenLogprob { + token_id: b'h' as u32, + token: "h".to_string(), + logprob: -0.2, + rank: 1, + }, + ], + }, + DecodedPositionLogprobs { + entries: vec![ + DecodedTokenLogprob { + token_id: b'i' as u32, + token: "i".to_string(), + logprob: -0.1, + rank: 1, + }, + DecodedTokenLogprob { + token_id: b'I' as u32, + token: "I".to_string(), + logprob: -0.2, + rank: 1, + }, + ], + }, + ], + }) + ); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + chat.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_rejects_unknown_tool_parser_before_engine_request() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat-tool-no-model".to_vec(); + let (shutdown_tx, engine_task) = + spawn_mock_engine_task(handshake_address.clone(), engine_id, |dealer, _| { + Box::pin(async move { + assert!( + timeout(Duration::from_millis(100), recv_engine_message(dealer)).await.is_err(), + "chat request should fail before any engine request is sent" + ); + }) + }); + + let backend: Arc = Arc::new(FakeChatBackend::new()); + let chat = connect_chat_llm_with_ipc( + EngineCoreClientConfig::new_single(handshake_address), + &ipc, + backend, + ) + .await + .with_tool_call_parser(ParserSelection::Explicit( + "definitely_missing_tool_parser".into(), + )); + let error = match chat.chat(sample_tool_request("chat-tool-no-model")).await { + Ok(_) => panic!("unknown explicit tool parser should fail"), + Err(error) => error, + }; + + assert!(matches!( + error, + vllm_chat::Error::ParserUnavailableByName { name, .. } + if name == "definitely_missing_tool_parser" + )); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + chat.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_rejects_unknown_reasoning_parser_before_engine_request() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat-reasoning-no-model".to_vec(); + let (shutdown_tx, engine_task) = + spawn_mock_engine_task(handshake_address.clone(), engine_id, |dealer, _| { + Box::pin(async move { + assert!( + timeout(Duration::from_millis(100), recv_engine_message(dealer)).await.is_err(), + "chat request should fail before any engine request is sent" + ); + }) + }); + + let backend: Arc = Arc::new(FakeChatBackend::new()); + let chat = connect_chat_llm_with_ipc( + EngineCoreClientConfig::new_single(handshake_address), + &ipc, + backend, + ) + .await + .with_reasoning_parser(ParserSelection::Explicit( + "definitely_missing_reasoning_parser".into(), + )); + let error = match chat.chat(sample_request("chat-reasoning-no-model")).await { + Ok(_) => panic!("unknown explicit reasoning parser should fail"), + Err(error) => error, + }; + + assert!(matches!( + error, + vllm_chat::Error::ParserUnavailableByName { name, .. } + if name == "definitely_missing_reasoning_parser" + )); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + chat.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn chat_rejects_tool_requests_when_tool_parser_is_disabled() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat-tool-parser-disabled".to_vec(); + let (shutdown_tx, engine_task) = + spawn_mock_engine_task(handshake_address.clone(), engine_id, |dealer, _| { + Box::pin(async move { + assert!( + timeout(Duration::from_millis(100), recv_engine_message(dealer)).await.is_err(), + "chat request should fail before any engine request is sent" + ); + }) + }); + + let backend: Arc = Arc::new(FakeChatBackend::new()); + let chat = connect_chat_llm_with_ipc( + EngineCoreClientConfig::new_single(handshake_address), + &ipc, + backend, + ) + .await + .with_tool_call_parser(ParserSelection::None); + let error = match chat.chat(sample_tool_request("chat-tool-parser-disabled")).await { + Ok(_) => panic!("tool requests should fail when tool parsing is disabled"), + Err(error) => error, + }; + + assert!(matches!( + error, + vllm_chat::Error::ParserDisabled { kind: "tool" } + )); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + chat.shutdown().await.unwrap(); +} diff --git a/rust/src/chat/tests/templates/qwen3.jinja b/rust/src/chat/tests/templates/qwen3.jinja new file mode 100644 index 00000000000..9769e5cfde9 --- /dev/null +++ b/rust/src/chat/tests/templates/qwen3.jinja @@ -0,0 +1,89 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} diff --git a/rust/src/chat/tests/templates/qwen35.jinja b/rust/src/chat/tests/templates/qwen35.jinja new file mode 100644 index 00000000000..ae8ae364921 --- /dev/null +++ b/rust/src/chat/tests/templates/qwen35.jinja @@ -0,0 +1,154 @@ +{%- set image_count = namespace(value=0) %} +{%- set video_count = namespace(value=0) %} +{%- macro render_content(content, do_vision_count, is_system_content=false) %} + {%- if content is string %} + {{- content }} + {%- elif content is iterable and content is not mapping %} + {%- for item in content %} + {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain images.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set image_count.value = image_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Picture ' ~ image_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- elif 'video' in item or item.type == 'video' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain videos.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set video_count.value = video_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Video ' ~ video_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|video_pad|><|vision_end|>' }} + {%- elif 'text' in item %} + {{- item.text }} + {%- else %} + {{- raise_exception('Unexpected item type in content.') }} + {%- endif %} + {%- endfor %} + {%- elif content is none or content is undefined %} + {{- '' }} + {%- else %} + {{- raise_exception('Unexpected content type.') }} + {%- endif %} +{%- endmacro %} +{%- if not messages %} + {{- raise_exception('No messages provided.') }} +{%- endif %} +{%- if tools and tools is iterable and tools is not mapping %} + {{- '<|im_start|>system\n' }} + {{- "# Tools\n\nYou have access to the following functions:\n\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {%- if content %} + {{- '\n\n' + content }} + {%- endif %} + {%- endif %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" %} + {%- set content = render_content(message.content, false)|trim %} + {%- if not(content.startswith('') and content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if ns.multi_step_tool %} + {{- raise_exception('No user query found in messages.') }} +{%- endif %} +{%- for message in messages %} + {%- set content = render_content(message.content, true)|trim %} + {%- if message.role == "system" %} + {%- if not loop.first %} + {{- raise_exception('System message must be at the beginning.') }} + {%- endif %} + {%- elif message.role == "user" %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- set reasoning_content = reasoning_content|trim %} + {%- if loop.index0 > ns.last_query_index %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if loop.first %} + {%- if content|trim %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n\n' }} + {%- endif %} + {%- else %} + {{- '\n\n\n' }} + {%- endif %} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- raise_exception('Unexpected message role.') }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is true %} + {{- '\n' }} + {%- else %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} diff --git a/rust/src/chat/tests/templates/vllm_examples/README.md b/rust/src/chat/tests/templates/vllm_examples/README.md new file mode 100644 index 00000000000..7d8a9150f10 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/README.md @@ -0,0 +1,6 @@ +# vLLM Example Templates + +These fixtures are copied from `vllm/examples/`. + +They are currently used by `src/chat/src/renderers/hf/format.rs` tests to keep +our chat-template content format detection aligned with Python vLLM behavior. diff --git a/rust/src/chat/tests/templates/vllm_examples/template_alpaca.jinja b/rust/src/chat/tests/templates/vllm_examples/template_alpaca.jinja new file mode 100644 index 00000000000..60667acc3ef --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/template_alpaca.jinja @@ -0,0 +1,29 @@ +{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} + +{% for message in messages %} +{% if message['role'] == 'user' %} +### Instruction: +{{ message['content']|trim -}} +{% if not loop.last %} + + +{% endif %} +{% elif message['role'] == 'assistant' %} +### Response: +{{ message['content']|trim -}} +{% if not loop.last %} + + +{% endif %} +{% elif message['role'] == 'user_context' %} +### Input: +{{ message['content']|trim -}} +{% if not loop.last %} + + +{% endif %} +{% endif %} +{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %} +### Response: +{% endif %} \ No newline at end of file diff --git a/rust/src/chat/tests/templates/vllm_examples/template_baichuan.jinja b/rust/src/chat/tests/templates/vllm_examples/template_baichuan.jinja new file mode 100644 index 00000000000..42a8d9270a4 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/template_baichuan.jinja @@ -0,0 +1,13 @@ +{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} + +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {{- '' + message['content'] -}} + {%- elif message['role'] == 'assistant' -%} + {{- '' + message['content'] -}} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%} + {{- '' -}} +{% endif %} \ No newline at end of file diff --git a/rust/src/chat/tests/templates/vllm_examples/template_chatglm.jinja b/rust/src/chat/tests/templates/vllm_examples/template_chatglm.jinja new file mode 100644 index 00000000000..bf26f27274e --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/template_chatglm.jinja @@ -0,0 +1,18 @@ +{%- set counter = namespace(index=0) -%} +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {{- '[Round ' + counter.index|string + ']\n问:' + message['content'] -}} + {%- set counter.index = counter.index + 1 -%} + {%- endif -%} + {%- if message['role'] == 'assistant' -%} + {{- '\n答:' + message['content'] -}} + {%- if (loop.last and add_generation_prompt) or not loop.last -%} + {{- '\n' -}} + {%- endif -%} + {%- endif -%} +{%- endfor -%} + + +{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%} + {{- '\n答:' -}} +{%- endif -%} \ No newline at end of file diff --git a/rust/src/chat/tests/templates/vllm_examples/template_chatglm2.jinja b/rust/src/chat/tests/templates/vllm_examples/template_chatglm2.jinja new file mode 100644 index 00000000000..c155b7c23f6 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/template_chatglm2.jinja @@ -0,0 +1,18 @@ +{%- set counter = namespace(index=1) -%} +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {{- '[Round ' + counter.index|string + ']\n\n问:' + message['content'] -}} + {%- set counter.index = counter.index + 1 -%} + {%- endif -%} + {%- if message['role'] == 'assistant' -%} + {{- '\n\n答:' + message['content'] -}} + {%- if (loop.last and add_generation_prompt) or not loop.last -%} + {{- '\n\n' -}} + {%- endif -%} + {%- endif -%} +{%- endfor -%} + + +{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%} + {{- '\n\n答:' -}} +{%- endif -%} \ No newline at end of file diff --git a/rust/src/chat/tests/templates/vllm_examples/template_chatml.jinja b/rust/src/chat/tests/templates/vllm_examples/template_chatml.jinja new file mode 100644 index 00000000000..4844e681e1b --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/template_chatml.jinja @@ -0,0 +1,2 @@ +{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %} \ No newline at end of file diff --git a/rust/src/chat/tests/templates/vllm_examples/template_falcon.jinja b/rust/src/chat/tests/templates/vllm_examples/template_falcon.jinja new file mode 100644 index 00000000000..01cf0e2670d --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/template_falcon.jinja @@ -0,0 +1,15 @@ +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {{- 'User: ' + message['content'] -}} + {%- elif message['role'] == 'assistant' -%} + {{- 'Assistant: ' + message['content'] -}} + {%- endif -%} + {%- if (loop.last and add_generation_prompt) or not loop.last -%} + {{- '\n' -}} + {%- endif -%} +{%- endfor -%} + + +{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%} + {{- 'Assistant:' -}} +{% endif %} \ No newline at end of file diff --git a/rust/src/chat/tests/templates/vllm_examples/template_falcon_180b.jinja b/rust/src/chat/tests/templates/vllm_examples/template_falcon_180b.jinja new file mode 100644 index 00000000000..f08f7395b7f --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/template_falcon_180b.jinja @@ -0,0 +1,17 @@ +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {{- 'System: ' + message['content'] -}} + {%- elif message['role'] == 'user' -%} + {{- 'User: ' + message['content'] -}} + {%- elif message['role'] == 'assistant' -%} + {{- 'Falcon: ' + message['content'] -}} + {%- endif -%} + {%- if (loop.last and add_generation_prompt) or not loop.last -%} + {{- '\n' -}} + {%- endif -%} +{%- endfor -%} + + +{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%} + {{- 'Falcon:' -}} +{% endif %} \ No newline at end of file diff --git a/rust/src/chat/tests/templates/vllm_examples/template_inkbot.jinja b/rust/src/chat/tests/templates/vllm_examples/template_inkbot.jinja new file mode 100644 index 00000000000..33a817454df --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/template_inkbot.jinja @@ -0,0 +1,30 @@ +<#meta#> +- Date: {{ (messages|selectattr('role', 'equalto', 'meta-current_date')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'meta-current_date')|list) else '' }} +- Task: {{ (messages|selectattr('role', 'equalto', 'meta-task_name')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'meta-task_name')|list) else '' }} +<#system#> +{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} +<#chat#> +{% for message in messages %} +{% if message['role'] == 'user' %} +<#user#> +{{ message['content']|trim -}} +{% if not loop.last %} + +{% endif %} +{% elif message['role'] == 'assistant' %} +<#bot#> +{{ message['content']|trim -}} +{% if not loop.last %} + +{% endif %} +{% elif message['role'] == 'user_context' %} +<#user_context#> +{{ message['content']|trim -}} +{% if not loop.last %} + +{% endif %} +{% endif %} +{% endfor %} +{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %} +<#bot#> +{% endif %} \ No newline at end of file diff --git a/rust/src/chat/tests/templates/vllm_examples/template_teleflm.jinja b/rust/src/chat/tests/templates/vllm_examples/template_teleflm.jinja new file mode 100644 index 00000000000..0cb29ccbb84 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/template_teleflm.jinja @@ -0,0 +1,12 @@ +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {{- '<_user>' + message['content']|trim }} + {%- elif message['role'] == 'system' %} + {{- '<_system>' + message['content']|trim }} + {%- elif message['role'] == 'assistant' %} + {{- '<_bot>' + message['content'] }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<_bot>' }} +{%- endif %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_deepseekr1.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_deepseekr1.jinja new file mode 100644 index 00000000000..908574be9df --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_deepseekr1.jinja @@ -0,0 +1,92 @@ +{% if not add_generation_prompt is defined %} + {% set add_generation_prompt = false %} +{% endif %} +{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true, is_last_user=false) %} +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {%- if ns.is_first_sp %} + {% set ns.system_prompt = ns.system_prompt + message['content'] %} + {% set ns.is_first_sp = false %} + {%- else %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} + {%- endif %} + {%- endif %} +{%- endfor -%} + +{#- Adapted from https://github.com/sgl-project/sglang/blob/main/examples/chat_template/tool_chat_template_deepseekr1.jinja #} +{% if tools is defined and tools is not none %} + {% set tool_ns = namespace(text='You are a helpful assistant with tool calling capabilities. ' + 'When a tool call is needed, you MUST use the following format to issue the call:\n' + '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>FUNCTION_NAME\n' + '```json\n{"param1": "value1", "param2": "value2"}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n\n' + 'Make sure the JSON is valid.' + '## Tools\n\n### Function\n\nYou have the following functions available:\n\n') %} + {% for tool in tools %} + {% set tool_ns.text = tool_ns.text + '\n```json\n' + (tool | tojson) + '\n```\n' %} + {% endfor %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %} +{% endif %} + +{{- bos_token }} +{{- ns.system_prompt }} +{%- for message in messages %} + {% set content = message['content'] %} + {%- if message['role'] == 'user' %} + {%- set ns.is_tool = false -%} + {%- set ns.is_first = false -%} + {%- set ns.is_last_user = true -%} + {{'<|User|>' + content + '<|Assistant|>'}} + {%- endif %} + {%- if message['role'] == 'assistant' %} + {% if '' in content %} + {% set content = content.split('')[-1] %} + {% endif %} + {% endif %} + {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{- '<|tool▁outputs▁end|>'}} + {%- endif %} + {%- set ns.is_first = false %} + {%- set ns.is_tool = false -%} + {%- set ns.is_output_first = true %} + {%- for tool in message['tool_calls'] %} + {%- if not ns.is_first %} + {%- if content is none %} + {{- '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- else %} + {{- content + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- endif %} + {%- set ns.is_first = true -%} + {%- else %} + {{- '\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- endif %} + {%- endfor %} + {{- '<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {%- endif %} + {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{- '<|tool▁outputs▁end|>' + content + '<|end▁of▁sentence|>'}} + {%- set ns.is_tool = false -%} + {%- else %} + {{- content + '<|end▁of▁sentence|>'}} + {%- endif %} + {%- endif %} + {%- if message['role'] == 'tool' %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_tool = true -%} + {%- if ns.is_output_first %} + {{- '<|tool▁outputs▁begin|><|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}} + {%- set ns.is_output_first = false %} + {%- else %} + {{- '\n<|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}} + {%- endif %} + {%- endif %} +{%- endfor -%} +{% if ns.is_tool %} + {{- '<|tool▁outputs▁end|>'}} +{%- endif %} +{% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %} + {{- '<|Assistant|>'}} +{%- endif %} \ No newline at end of file diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_deepseekv3.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_deepseekv3.jinja new file mode 100644 index 00000000000..36f3781439e --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_deepseekv3.jinja @@ -0,0 +1,96 @@ +{% if not add_generation_prompt is defined %} + {% set add_generation_prompt = false %} +{% endif %} + +{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true, is_last_user=false) %} + +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {%- if ns.is_first_sp %} + {% set ns.system_prompt = ns.system_prompt + message['content'] %} + {% set ns.is_first_sp = false %} + {%- else %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} + {%- endif %} + {%- endif %} +{%- endfor %} + +{{ bos_token }} +{{ ns.system_prompt }} +{%- if tools %} + {{"\n\n# Tools\n\nYou may call one or more functions to assist with the user query." }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{"\n\n\n"}} + + {{"For function call returns, you should first print <|tool▁calls▁begin|>"}} + + {{"For each function call, you should return object like:\n" }} + {{"<|tool▁call▁begin|>function<|tool▁sep|>\n```json\n\n```<|tool▁call▁end|>"}} + + {{"At the end of function call returns, you should print <|tool▁calls▁end|><|end▁of▁sentence|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {%- set ns.is_tool = false -%} + {%- set ns.is_first = false -%} + {%- set ns.is_last_user = true -%} + {{'<|User|>' + message['content'] + '<|Assistant|>'}} + {%- endif %} + + {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{'<|tool▁outputs▁end|>'}} + {%- endif %} + {%- set ns.is_first = false %} + {%- set ns.is_tool = false -%} + {%- set ns.is_output_first = true %} + + {%- for tool in message['tool_calls'] %} + {%- if not ns.is_first %} + {%- if message['content'] is none %} + {{'<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- else %} + {{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- endif %} + {%- set ns.is_first = true -%} + {%- else %} + {{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- endif %} + {%- endfor %} + {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {%- endif %} + {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}} + {%- set ns.is_tool = false -%} + {%- else %} + {% set content = message['content'] %} + {{content + '<|end▁of▁sentence|>'}} + {%- endif %} + {%- endif %} + + {%- if message['role'] == 'tool' %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_tool = true -%} + {%- if ns.is_output_first %} + {{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- set ns.is_output_first = false %} + {%- else %} + {{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- endif %} + {%- endif %} +{%- endfor -%} + +{% if ns.is_tool %} + {{'<|tool▁outputs▁end|>'}} +{% endif %} + +{% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %} + {{'<|Assistant|>'}} +{% endif %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_deepseekv31.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_deepseekv31.jinja new file mode 100644 index 00000000000..863be69d60b --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_deepseekv31.jinja @@ -0,0 +1,91 @@ +{% if not add_generation_prompt is defined %} + {% set add_generation_prompt = false %} +{% endif %} +{% if not thinking is defined %} + {% set thinking = false %} +{% endif %} +{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %} +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {%- if ns.is_first_sp %} + {% set ns.system_prompt = ns.system_prompt + message['content'] %} + {% set ns.is_first_sp = false %} + {%- else %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} + {%- endif %} + {%- endif %} +{%- endfor %} + +{% if tools is defined and tools is not none %} + {% set tool_ns = namespace(text='## Tools\nYou have access to the following tools:\n') %} + {% for tool in tools %} + {% set tool_ns.text = tool_ns.text + '\n### ' + tool.function.name + '\nDescription: ' + tool.function.description + '\n\nParameters: ' + (tool.function.parameters | tojson) + '\n' %} + {% endfor %} + {% set tool_ns.text = tool_ns.text + "\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\nWhere:\n\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\n" %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %} +{% endif %} + +{{ bos_token }}{{ ns.system_prompt }} +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {%- set ns.is_tool = false -%} + {%- set ns.is_first = false -%} + {%- set ns.is_last_user = true -%} + {{'<|User|>' + message['content']}} + {%- endif %} + {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} + {%- if ns.is_last_user %} + {{'<|Assistant|>'}} + {%- endif %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_first = false %} + {%- set ns.is_tool = false -%} + {%- for tool in message['tool_calls'] %} + {%- if not ns.is_first %} + {%- if message['content'] is none %} + {{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- else %} + {{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- endif %} + {%- set ns.is_first = true -%} + {%- else %} + {{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- endif %} + {%- endfor %} + {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {%- endif %} + {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %} + {%- if ns.is_last_user %} + {{'<|Assistant|>'}} + {%- if message['prefix'] is defined and message['prefix'] and thinking %} + {{''}} + {%- else %} + {{''}} + {%- endif %} + {%- endif %} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{message['content'] + '<|end▁of▁sentence|>'}} + {%- set ns.is_tool = false -%} + {%- else %} + {%- set content = message['content'] -%} + {%- if '' in content %} + {%- set content = content.split('', 1)[1] -%} + {%- endif %} + {{content + '<|end▁of▁sentence|>'}} + {%- endif %} + {%- endif %} + {%- if message['role'] == 'tool' %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_tool = true -%} + {{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- endif %} +{%- endfor -%} +{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %} + {{'<|Assistant|>'}} + {%- if not thinking %} + {{''}} + {%- else %} + {{''}} + {%- endif %} +{% endif %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_functiongemma.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_functiongemma.jinja new file mode 100644 index 00000000000..63b5d336a76 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_functiongemma.jinja @@ -0,0 +1,54 @@ +{%- set ns = namespace(developer_content='', has_tools=false) -%} + +{%- if tools is defined and tools | length > 0 -%} + {%- set ns.has_tools = true -%} +{%- endif -%} + +{%- for message in messages -%} + {%- if message.role == 'developer' or message.role == 'system' -%} +user +{{ message.content }} +{%- if ns.has_tools %} + +Available functions: +{%- for tool in tools %} +{%- if tool.type == 'function' %} + +Function: {{ tool.function.name }} +Description: {{ tool.function.description | default('No description provided') }} +Parameters: {{ tool.function.parameters | tojson }} +{%- endif %} +{%- endfor %} +{%- endif %} + + {%- elif message.role == 'user' -%} +user +{{ message.content }} + {%- elif message.role == 'assistant' -%} + {%- if message.tool_calls is defined and message.tool_calls | length > 0 -%} +model +{%- for tool_call in message.tool_calls %} +call:{{ tool_call.function.name }}{ +{%- set args = tool_call.function.arguments -%} +{%- if args is string -%} +{%- set args = args | fromjson -%} +{%- endif -%} +{%- for key, value in args.items() -%} +{{ key }}:{{ value }}{% if not loop.last %},{% endif %} +{%- endfor -%} +} +{%- endfor %} + + {%- else -%} +model +{{ message.content }} + {%- endif -%} + {%- elif message.role == 'tool' -%} +user +Function result for {{ message.name | default('function') }}: {{ message.content }} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt -%} +model +{%- endif -%} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_gemma3_pythonic.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_gemma3_pythonic.jinja new file mode 100644 index 00000000000..5a20b019112 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_gemma3_pythonic.jinja @@ -0,0 +1,123 @@ +{#- Begin-of-sequence token to start the model prompt -#} +{{ bos_token }} +{#- Extracts the system message. Gemma does not support system messages so it will be prepended to first user message. -#} +{%- if messages[0]['role'] == 'system' -%} + {%- if messages[0]['content'] is string -%} + {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%} + {%- else -%} + {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} +{%- else -%} + {%- set first_user_prefix = "" -%} + {%- set loop_messages = messages -%} +{%- endif -%} +{#- Set tools to none if not defined for this ChatCompletion request (helps avoid errors later) -#} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{#- Validate alternating user/assistant messages (excluding 'tool' messages and ones with tool_calls) -#} +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | selectattr("tool_calls", "undefined") -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif -%} +{%- endfor -%} + +{#- Main loop over all messages in the conversation history -#} +{%- for message in loop_messages -%} + {#- Normalize roles for model prompt formatting -#} + {%- if (message['role'] == 'assistant') -%} + {%- set role = "model" -%} + {%- elif (message['role'] == 'tool') -%} + {%- set role = "user" -%} + {%- else -%} + {%- set role = message['role'] -%} + {%- endif -%} + {#- Mark the start of a message block with the appropriate role -#} + {{ '' + role + '\n' -}} + + {#- Insert system message content (if present) at the beginning of the first message. -#} + {%- if loop.first -%} + {{ first_user_prefix }} + {#- Append system message with tool information if using tools in message request. -#} + {%- if tools is not none -%} + {{- "Tools (functions) are available. If you decide to invoke one or more of the tools, you must respond with a python list of the function calls.\n" -}} + {{- "Example Format: [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] \n" -}} + {{- "Do not use variables. DO NOT USE MARKDOWN SYNTAX. You SHOULD NOT include any other text in the response if you call a function. If none of the functions can be used, point it out. If you lack the parameters required by the function, also point it out.\n" -}} + {{- "Here is a list of functions in JSON format that you can invoke.\n" -}} + {{- tools | tojson(indent=4) -}} + {{- "\n\n" -}} + {%- endif -%} + {%- endif -%} + + {#- Format model tool calls (turns where model indicates they want to call a tool) -#} + {%- if 'tool_calls' in message -%} + {#- Opening bracket for tool call list. -#} + {{- '[' -}} + {#- For each tool call -#} + {%- for tool_call in message.tool_calls -%} + {#- Get tool call function. -#} + {%- if tool_call.function is defined -%} + {%- set tool_call = tool_call.function -%} + {%- endif -%} + {#- Function name & opening parenthesis. -#} + {{- tool_call.name + '(' -}} + + {#-- Handle arguments as list (positional) or dict (named) --#} + {#-- Named arguments (dict) --#} + {%- if tool_call.arguments is iterable and tool_call.arguments is mapping -%} + {%- set first = true -%} + {%- for key, val in tool_call.arguments.items() -%} + {%- if not first %}, {% endif -%} + {{ key }}={{ val | tojson }} + {%- set first = false -%} + {%- endfor -%} + {#-- Positional arguments (list) --#} + {%- elif tool_call.arguments is iterable -%} + {{- tool_call.arguments | map('tojson') | join(', ') -}} + {#-- Fallback: single positional value --#} + {%- else -%} + {{- tool_call.arguments | tojson -}} + {#-- Closing parenthesis. --#} + {%- endif -%} + {{- ')' -}} + {#-- If more than one tool call, place comma and move to formatting next tool call --#} + {%- if not loop.last -%}, {% endif -%} + {%- endfor -%} + {#- Closing bracket for tool call list. -#} + {{- ']' -}} + {%- endif -%} + + {#- Tool response start tag (for messages from a tool) -#} + {%- if (message['role'] == 'tool') -%} + {{ '\n' -}} + {%- endif -%} + + {#- Render the message content: handle plain string or multimodal content like image/text -#} + {%- if message['content'] is string -%} + {{ message['content'] | trim }} + {%- elif message['content'] is iterable -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'image' -%} + {{ '' }} + {%- elif item['type'] == 'text' -%} + {{ item['text'] | trim }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ raise_exception("Invalid content type") }} + {%- endif -%} + + {#- Tool response end tag -#} + {%- if (message['role'] == 'tool') -%} + {{ '' -}} + {%- endif -%} + + {#- Mark end of a single turn -#} + {{ '\n' }} +{%- endfor -%} + +{#- If generation is to be triggered, add model prompt prefix -#} +{%- if add_generation_prompt -%} + {{'model\n'}} +{%- endif -%} \ No newline at end of file diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_gemma4.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_gemma4.jinja new file mode 100644 index 00000000000..15c5238ac33 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_gemma4.jinja @@ -0,0 +1,331 @@ +{%- macro format_parameters(properties, required) -%} + {%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in properties | dictsort -%} + {%- set add_comma = false -%} + {%- if key not in standard_keys -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {{ key }}:{ + {%- if value['description'] -%} + description:<|"|>{{ value['description'] }}<|"|> + {%- set add_comma = true -%} + {%- endif -%} + {%- if value['nullable'] %} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + nullable:true + {%- endif -%} + {%- if value['type'] | upper == 'STRING' -%} + {%- if value['enum'] -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + enum:{{ format_argument(value['enum']) }} + {%- endif -%} + {%- elif value['type'] | upper == 'OBJECT' -%} + ,properties:{ + {%- if value['properties'] is defined and value['properties'] is mapping -%} + {{- format_parameters(value['properties'], value['required'] | default([])) -}} + {%- elif value is mapping -%} + {{- format_parameters(value, value['required'] | default([])) -}} + {%- endif -%} + } + {%- if value['required'] -%} + ,required:[ + {%- for item in value['required'] | default([]) -%} + <|"|>{{- item -}}<|"|> + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ] + {%- endif -%} + {%- elif value['type'] | upper == 'ARRAY' -%} + {%- if value['items'] is mapping and value['items'] -%} + ,items:{ + {%- set ns_items = namespace(found_first=false) -%} + {%- for item_key, item_value in value['items'] | dictsort -%} + {%- if item_value is not none -%} + {%- if ns_items.found_first %},{% endif -%} + {%- set ns_items.found_first = true -%} + {%- if item_key == 'properties' -%} + properties:{ + {%- if item_value is mapping -%} + {{- format_parameters(item_value, value['items']['required'] | default([])) -}} + {%- endif -%} + } + {%- elif item_key == 'required' -%} + required:[ + {%- for req_item in item_value -%} + <|"|>{{- req_item -}}<|"|> + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ] + {%- elif item_key == 'type' -%} + {%- if item_value is string -%} + type:{{ format_argument(item_value | upper) }} + {%- else -%} + type:{{ format_argument(item_value | map('upper') | list) }} + {%- endif -%} + {%- else -%} + {{ item_key }}:{{ format_argument(item_value) }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + } + {%- endif -%} + {%- endif -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + type:<|"|>{{ value['type'] | upper }}<|"|>} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} +{%- macro format_function_declaration(tool_data) -%} + declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|> + {%- set params = tool_data['function']['parameters'] -%} + {%- if params -%} + ,parameters:{ + {%- if params['properties'] -%} + properties:{ {{- format_parameters(params['properties'], params['required']) -}} }, + {%- endif -%} + {%- if params['required'] -%} + required:[ + {%- for item in params['required'] -%} + <|"|>{{- item -}}<|"|> + {{- ',' if not loop.last -}} + {%- endfor -%} + ], + {%- endif -%} + {%- if params['type'] -%} + type:<|"|>{{- params['type'] | upper -}}<|"|>} + {%- endif -%} + {%- endif -%} + {%- if 'response' in tool_data['function'] -%} + {%- set response_declaration = tool_data['function']['response'] -%} + ,response:{ + {%- if response_declaration['description'] -%} + description:<|"|>{{- response_declaration['description'] -}}<|"|>, + {%- endif -%} + {%- if response_declaration['type'] | upper == 'OBJECT' -%} + type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>} + {%- endif -%} + {%- endif -%} + } +{%- endmacro -%} +{%- macro format_argument(argument, escape_keys=True) -%} + {%- if argument is string -%} + {{- '<|"|>' + argument + '<|"|>' -}} + {%- elif argument is boolean -%} + {{- 'true' if argument else 'false' -}} + {%- elif argument is mapping -%} + {{- '{' -}} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in argument | dictsort -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {%- if escape_keys -%} + {{- '<|"|>' + key + '<|"|>' -}} + {%- else -%} + {{- key -}} + {%- endif -%} + :{{- format_argument(value, escape_keys=escape_keys) -}} + {%- endfor -%} + {{- '}' -}} + {%- elif argument is sequence -%} + {{- '[' -}} + {%- for item in argument -%} + {{- format_argument(item, escape_keys=escape_keys) -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + {{- ']' -}} + {%- else -%} + {{- argument -}} + {%- endif -%} +{%- endmacro -%} +{%- macro strip_thinking(text) -%} + {%- set ns = namespace(result='') -%} + {%- for part in text.split('') -%} + {%- if '<|channel>' in part -%} + {%- set ns.result = ns.result + part.split('<|channel>')[0] -%} + {%- else -%} + {%- set ns.result = ns.result + part -%} + {%- endif -%} + {%- endfor -%} + {{- ns.result | trim -}} +{%- endmacro -%} + +{%- macro format_tool_response_block(tool_name, response) -%} + {{- '<|tool_response>' -}} + {%- if response is mapping -%} + {{- 'response:' + tool_name + '{' -}} + {%- for key, value in response | dictsort -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + {{- '}' -}} + {%- else -%} + {{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}} + {%- endif -%} + {{- '' -}} +{%- endmacro -%} + +{%- set ns = namespace(prev_message_type=None) -%} +{%- set loop_messages = messages -%} +{{ bos_token }} +{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%} + {{- '<|turn>system\n' -}} + + {%- if enable_thinking is defined and enable_thinking -%} + {{- '<|think|>' -}} + {%- set ns.prev_message_type = 'think' -%} + {%- endif -%} + + {%- if messages[0]['role'] in ['system', 'developer'] -%} + {{- messages[0]['content'] | trim -}} + {%- set loop_messages = messages[1:] -%} + {%- endif -%} + + {%- if tools -%} + {%- for tool in tools %} + {{- '<|tool>' -}} + {{- format_function_declaration(tool) | trim -}} + {{- '' -}} + {%- endfor %} + {%- set ns.prev_message_type = 'tool' -%} + {%- endif -%} + + {{- '\n' -}} +{%- endif %} + +{%- set ns_turn = namespace(last_user_idx=-1) -%} +{%- for i in range(loop_messages | length) -%} + {%- if loop_messages[i]['role'] == 'user' -%} + {%- set ns_turn.last_user_idx = i -%} + {%- endif -%} +{%- endfor -%} + +{%- for message in loop_messages -%} + {%- if message['role'] != 'tool' -%} + {%- set ns.prev_message_type = None -%} + {%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%} + {#- OpenAI may emit multiple assistant messages in one tool loop (user → asst → tool → asst → tool). + Only the first of those should open <|turn>model; later ones continue the same model turn. -#} + {%- set prev_nt = namespace(role=None, found=false) -%} + {%- if loop.index0 > 0 -%} + {%- for j in range(loop.index0 - 1, -1, -1) -%} + {%- if not prev_nt.found -%} + {%- if loop_messages[j]['role'] != 'tool' -%} + {%- set prev_nt.role = loop_messages[j]['role'] -%} + {%- set prev_nt.found = true -%} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%} + {%- if not continue_same_model_turn -%} + {{- '<|turn>' + role + '\n' }} + {%- endif -%} + + {%- if message.get('reasoning') and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%} + {{- '<|channel>thought\n' + message['reasoning'] + '\n'}} + {%- endif -%} + + {%- if message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {%- set function = tool_call['function'] -%} + {{- '<|tool_call>call:' + function['name'] + '{' -}} + {%- if function['arguments'] is mapping -%} + {%- set ns_args = namespace(found_first=false) -%} + {%- for key, value in function['arguments'] | dictsort -%} + {%- if ns_args.found_first %},{% endif -%} + {%- set ns_args.found_first = true -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- endfor -%} + {%- elif function['arguments'] is string -%} + {{- function['arguments'] -}} + {%- endif -%} + {{- '}' -}} + {%- endfor -%} + {%- set ns.prev_message_type = 'tool_call' -%} + {%- endif -%} + + {%- set ns_tr_out = namespace(flag=false) -%} + {%- if message.get('tool_responses') -%} + {#- Legacy: tool_responses embedded on the assistant message -#} + {%- for tool_response in message['tool_responses'] -%} + {{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}} + {%- set ns_tr_out.flag = true -%} + {%- set ns.prev_message_type = 'tool_response' -%} + {%- endfor -%} + {%- elif message.get('tool_calls') -%} + {#- OpenAI Chat Completions: consecutive following messages with role "tool" (no break/continue; range scan) -#} + {%- set ns_tool_scan = namespace(stopped=false) -%} + {%- for k in range(loop.index0 + 1, loop_messages | length) -%} + {%- if ns_tool_scan.stopped -%} + {%- elif loop_messages[k]['role'] != 'tool' -%} + {%- set ns_tool_scan.stopped = true -%} + {%- else -%} + {%- set follow = loop_messages[k] -%} + {%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%} + {%- for tc in message['tool_calls'] -%} + {%- if tc.get('id') == follow.get('tool_call_id') -%} + {%- set ns_tname.name = tc['function']['name'] -%} + {%- endif -%} + {%- endfor -%} + {%- set tool_body = follow.get('content') -%} + {%- if tool_body is string -%} + {{- format_tool_response_block(ns_tname.name, tool_body) -}} + {%- elif tool_body is sequence and tool_body is not string -%} + {%- set ns_txt = namespace(s='') -%} + {%- for part in tool_body -%} + {%- if part.get('type') == 'text' -%} + {%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%} + {%- endif -%} + {%- endfor -%} + {{- format_tool_response_block(ns_tname.name, ns_txt.s) -}} + {%- else -%} + {{- format_tool_response_block(ns_tname.name, tool_body) -}} + {%- endif -%} + {%- set ns_tr_out.flag = true -%} + {%- set ns.prev_message_type = 'tool_response' -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + + {%- if message['content'] is string -%} + {%- if role == 'model' -%} + {{- strip_thinking(message['content']) -}} + {%- else -%} + {{- message['content'] | trim -}} + {%- endif -%} + {%- elif message['content'] is sequence -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'text' -%} + {%- if role == 'model' -%} + {{- strip_thinking(item['text']) -}} + {%- else -%} + {{- item['text'] | trim -}} + {%- endif -%} + {%- elif item['type'] == 'image' -%} + {{- '\n\n<|image|>\n\n' -}} + {%- set ns.prev_message_type = 'image' -%} + {%- elif item['type'] == 'audio' -%} + {{- '<|audio|>' -}} + {%- set ns.prev_message_type = 'audio' -%} + {%- elif item['type'] == 'video' -%} + {{- '\n\n<|video|>\n\n' -}} + {%- set ns.prev_message_type = 'video' -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + + {%- if not (ns_tr_out.flag and not message.get('content')) -%} + {{- '\n' -}} + {%- endif -%} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt -%} + {%- if ns.prev_message_type != 'tool_response' -%} + {{- '<|turn>model\n' -}} + {%- endif -%} + {%- if not enable_thinking | default(false) -%} + {{- '<|channel>thought\n' -}} + {%- endif -%} +{%- endif -%} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_glm4.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_glm4.jinja new file mode 100644 index 00000000000..11f76b4d4af --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_glm4.jinja @@ -0,0 +1,54 @@ +{%- set counter = namespace(index=0) -%} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{%- if messages and messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "You are a helpful assistant." %} +{%- endif %} + +{%- if tools is not none %} + {%- set tool_instruction %} +You have access to the following tools. When you need to call a tool, you MUST use the following format: + +function_name +parameter_name +parameter_value + + +Important rules: +- Always wrap tool calls with ... tags +- Put the function name on the first line after +- Use and tags for each parameter +- If a parameter value is a string, keep it as-is. If it's a number or boolean, convert it appropriately +- You can make multiple tool calls if needed +- If no tool is suitable, respond with regular text + +Available tools: +{% endset %} + {{- tool_instruction + "\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} + +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {{- '[Round ' + counter.index|string + ']\n问:' + message['content'] -}} + {%- set counter.index = counter.index + 1 -%} + {%- endif -%} + {%- if message['role'] == 'assistant' -%} + {{- '\n答:' + message['content'] -}} + {%- if (loop.last and add_generation_prompt) or not loop.last -%} + {{- '\n' -}} + {%- endif -%} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%} + {{- '\n答:' -}} +{%- endif -%} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_granite.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_granite.jinja new file mode 100644 index 00000000000..467dcb2d102 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_granite.jinja @@ -0,0 +1,36 @@ +{%- if tools %} + {{- '<|start_of_role|>available_tools<|end_of_role|> +' }} + {%- for tool in tools %} + {{- tool | tojson(indent=4) }} + {%- if not loop.last %} + {{- ' + +' }} + {%- endif %} + {%- endfor %} + {{- '<|end_of_text|> +' }} +{%- endif %} + +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- elif message['role'] == 'user' %} + {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- elif message['role'] == 'assistant_tool_call' or (message['role'] == 'assistant' and message.tool_calls is defined) %} + {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message.tool_calls|map(attribute='function')|list|tojson(indent=4) + '<|end_of_text|> +' }} + {%- elif message['role'] == 'assistant' %} + {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- elif message['role'] == 'tool_response' or message['role'] == 'tool' %} + {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- endif %} + {%- if loop.last and add_generation_prompt %} + {{- '<|start_of_role|>assistant<|end_of_role|>' }} + {%- endif %} +{%- endfor %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_granite_20b_fc.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_granite_20b_fc.jinja new file mode 100644 index 00000000000..cb52188ec72 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_granite_20b_fc.jinja @@ -0,0 +1,130 @@ +{%- macro json_to_python_type(json_spec) %} + {%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + + {%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} + {%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]" }} + {%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']' }} + {%- else %} + {{- "dict" }} + {%- endif %} + {%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} + {%- else %} + {{- "Any" }} + {%- endif %} +{%- endmacro %} + +{%- if not full_function_description is defined %} + {%- set full_function_description = false %} +{%- endif %} + +{%- macro full_description(tool) %} + {{- tool.name + '(' }} + {%- if tool.parameters is defined %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + "\n\n" }} + {%- if tool.parameters is defined %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args:\n" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- endif %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- "\n Returns:\n " + tool.return.description }} + {%- endif %} + {{- '"' }} +{%- endmacro %} + +{%- macro simple_description(tool) %} + {{- tool.description }} +{%- endmacro %} + +{%- macro function_description(tool) %} + {%- if full_function_description %} + {{- full_description(tool) }} + {%- else %} + {{- simple_description(tool) }} + {%- endif %} +{%- endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set sys_prompt = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} + {% set sys_prompt = 'You are a helpful assistant with access to the following function calls. Your task is to understand the given conversation with function calls and responses and generate natural language response as the ASSISTANT to continue the conversation. You may use the following function calls to understand how to respond to the user query.' %} +{%- endif %} + +{{ 'SYSTEM: ' + sys_prompt }} +{% if tools is iterable and tools | length > 0 %} +<|function_call_library|> + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- '{"name": "' + tool.name + '", ' }} + {{- '"description": "' + function_description(tool) }} + {{- ', "parameters": ' }} + {%- if not tool.parameters is defined or tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- "\n" }} + {%- endif %} + {%- endfor %} +If none of the functions are relevant or the given question lacks the parameters required by the function, please output \" {\"name\": \"no_function\", \"arguments\": {}}\". +{%- endif %} + + + +{% for message in messages %} + {% if message['role'] == 'user' %} + {{- '\nUSER: ' + message['content'] }} + {% elif message['role'] == 'assistant' and message.tool_calls is defined %} + {{- '\nASSISTANT:' }} + {% for tc in message.tool_calls %} + {{- ' ' + {'name': tc.function.name, 'arguments': tc.function.arguments}|tojson }} + {% endfor %} + {{- '<|endoftext|>' }} + {% elif message['role'] == 'assistant' %} + {{- '\nASSISTANT: ' + message['content'] + ' <|endoftext|>' }} + {% elif message['role'] == 'tool' %} + {{- ' ' + message['content'] }} + {%- else %} + {{- raise_exception("Unexpected combination of role and message content") }} + {% endif %} + {% if loop.last and add_generation_prompt %} + {{- '\nASSISTANT: ' }} + {% endif %} +{% endfor %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_hermes.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_hermes.jinja new file mode 100644 index 00000000000..0b0902c8e74 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_hermes.jinja @@ -0,0 +1,130 @@ +{%- macro json_to_python_type(json_spec) %} + {%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + + {%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} + {%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]" }} + {%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']' }} + {%- else %} + {{- "dict" }} + {%- endif %} + {%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} + {%- else %} + {{- "Any" }} + {%- endif %} +{%- endmacro %} + + +{{- bos_token }} +{{- "<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} +{%- if tools is iterable and tools | length > 0 %} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- '{"type": "function", "function": ' }} + {{- '{"name": "' + tool.name + '", ' }} + {{- '"description": "' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + "\n\n" }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args:\n" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- "\n Returns:\n " + tool.return.description }} + {%- endif %} + {{- '"' }} + {{- ', "parameters": ' }} + {%- if tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- "\n" }} + {%- endif %} + {%- endfor %} +{%- endif %} +{{- " " }} +{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +' }} +{{- "For each function call return a json object with function name and arguments within XML tags as follows: +" }} +{{- " +" }} +{{- '{"name": , "arguments": } +' }} +{{- '<|im_end|>' }} +{%- for message in messages %} + {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" and message.tool_calls is defined %} + {{- '<|im_start|>' + message.role }} + {%- for tool_call in message.tool_calls %} + {{- '\n\n' }} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{' }} + {{- '"name": "' }} + {{- tool_call.name }} + {{- '"' }} + {%- if tool_call.arguments is defined %} + {{- ', ' }} + {{- '"arguments": ' }} + {{- tool_call.arguments|tojson }} + {%- endif %} + {{- '}' }} + {{- '\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>tool\n' }} + {%- endif %} + {{- '\n' }} + {{- message.content }} + {%- if not loop.last %} + {{- '\n\n' }} + {%- else %} + {{- '\n' }} + {%- endif %} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>' }} + {%- elif loop.last %} + {{- '<|im_end|>' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_hunyuan_a13b.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_hunyuan_a13b.jinja new file mode 100644 index 00000000000..a0808e44858 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_hunyuan_a13b.jinja @@ -0,0 +1,113 @@ +{% set loop_messages = messages %} +{% if tools %} + {% set weekday_map = {'Monday': '星期一', 'Tuesday': '星期二', 'Wednesday': '星期三', 'Thursday': '星期四', 'Friday': '星期五', 'Saturday': '星期六', 'Sunday': '星期日'} %} + {% set weekday_cn = weekday_map[strftime_now('%A')] %} + {% set datetime_str = strftime_now('%Y-%m-%d %H:%M:%S') %} + {% set datetime_str = datetime_str + ' ' + weekday_cn %} + {% for message in loop_messages %} + {% if 'content' in message %} + {% set content = message['content'] %} + {% else %} + {% set content = '' %} + {% endif %} + {% if loop.index0 == 0 %} + {% set content_tmp = '你是一位函数组合专家。你会得到一个问题和一组可能的函数。根据问题,你需要进行一个或多个函数/工具调用以实现目的。 +如果没有一个函数可以使用,请直接使用自然语言回复用户,以助手:开头。 +如果给定的问题缺少函数所需的参数,请使用自然语言进行提问,向用户询问必要信息,以助手:开头。 +如果调用结果已经足够回答用户问题,请对历史结果进行总结,使用自然语言回复用户,以助手:开头。 +你应该只在工具调用部分返回函数调用。如果你决定调用任何函数,你必须将其格式化为[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...]。你不应该在回复中包含任何其他文本。以下是你可以调用的函数列表,格式为JSON。 +' %} + {% set content_tmp = content_tmp + ' +' + tools | tojson + ' +' %} + {% if message['role'] == 'system' %} + {% set content_tmp = content_tmp + ' +额外要求: +' + content + ' + +如果你决定返回函数调用,请将其格式化为[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...],不得包含其他文本。如果额外要求里有格式要求,请忽略,以此处为准。 +否则,请参考开头说的三种情况,以助手:开头进行回复。 + +如果额外要求里有时间信息,就以额外要求里的时间为准,否则,参考当前时间:' + datetime_str %} + {% set content = '<|startoftext|>' + content_tmp + '<|extra_4|>' %} + {% elif message['role'] == 'user' %} + {% set content_tmp = content_tmp + ' +如果你决定返回函数调用,请将其格式化为[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...],不得包含其他文本。 +否则,请参考开头说的三种情况,以助手:开头进行回复。 + +当前时间:' + datetime_str %} + {% set content_tmp = '<|startoftext|>' + content_tmp + '<|extra_4|>'%} + {% set content = content_tmp + '用户:' + content + '<|extra_0|>' %} + {% endif %} + {% else %} + {% if message['role'] == 'user' %} + {% set content = '用户:' + content + '<|extra_0|>' %} + {% elif message['role'] == 'assistant' %} + {% if 'tool_calls' in message %} + {% set tool_calls = message['tool_calls'] %} + {% set ns = namespace(tool_calls="[") %} + {% for tool_call in tool_calls %} + {% set function = tool_call['function'] %} + {% set name = function['name'] %} + {% set ns.tool_calls = ns.tool_calls + '{"name": "' + name + '", '%} + {% set arguments = function['arguments'] %} + {% if arguments is not string %} + {% set arguments = arguments | tojson %} + {% endif %} + {% set ns.tool_calls = ns.tool_calls + '"arguments": ' + arguments + '}' %} + {% if not loop.last %} + {% set ns.tool_calls = ns.tool_calls + ', '%} + {% endif %} + {% endfor %} + {% set ns.tool_calls = ns.tool_calls + ']' %} + {% set content = content + '' + ns.tool_calls + '' %} + {% else %} + {% set content = '助手:' + content %} + {% endif %} + {% set content = content + '<|eos|>' %} + {% elif message['role'] == 'tool' %} + {% if content is not string %} + {set content = content | tojson } + {% endif %} + {% set content = '' + content + '' %} + {% set content = content + '<|extra_0|>' %} + {% endif %} + {% endif %} + {{- content -}} + {% endfor %} +{% else %} + {% set context = {'has_head': true} %} + {% for message in loop_messages %} + {% if 'content' in message %} + {% set content = message['content'] %} + {% else %} + {% set content = '' %} + {% endif %} + {% if loop.index0 == 0 %} + {% if content == '' %} + {% set _ = context.update({'has_head': false}) %} + {% elif message['role'] == 'system' %} + {% set content = '<|startoftext|>' + content + '<|extra_4|>' %} + {% endif %} + {% endif %} + {% if message['role'] == 'user' %} + {% if loop.index0 == 1 and not context.has_head %} + {% set content = '<|startoftext|>' + content %} + {% endif %} + {% if loop.index0 == 1 and context.has_head %} + {% set content = content + '<|extra_0|>' %} + {% else %} + {% set content = '<|startoftext|>' + content + '<|extra_0|>' %} + {% endif %} + {% elif message['role'] == 'assistant' %} + {% set content = content + '<|eos|>' %} + {% elif message['role'] == 'tool' %} + {% set content = content + '<|extra_0|>' %} + {% endif %} + {{- content -}} + {% endfor %} +{% endif %} +{%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n' }} +{%- endif %} + diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_internlm2_tool.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_internlm2_tool.jinja new file mode 100644 index 00000000000..ac99666e93b --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_internlm2_tool.jinja @@ -0,0 +1,60 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{{- bos_token }} +{%- if system_message is defined %} +{{- "<|im_start|>system\n" + system_message + "<|im_end|>\n" }} +{%- endif %} + +{%- if tools is not none %} + {{- "<|im_start|>system name=<|plugin|>\n[" }} + {%- for tool in tools %} + {{- tool.function|tojson }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "<|im_end|>\n" }} +{%- endif %} + +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {{- "<|im_start|>user\n" + message["content"] + "<|im_end|>\n"}} + {%- elif message.tool_calls is defined and message.tool_calls is not none %} + {%- set content = message["content"] if message["content"] else "" %} + {{- "<|im_start|>assistant\n" + content }} + {%- for tool_call in message.tool_calls %} + {%- set function=tool_call.function %} + {{- "<|action_start|><|plugin|>\n" }} + {{- '{"name": "' + function.name + '", '}} + {{- '"arguments": ' + function.arguments|tojson + '}' }} + {{- "<|action_end|>" }} + {%- endfor %} + {{- "<|im_end|>\n" }} + {%- elif message["role"] == "assistant" %} + {{- "<|im_start|>assistant\n" + message["content"] + "<|im_end|>\n"}} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" or message["role"] == "function" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- "<|im_start|>environment name=<|plugin|>\n" + content|string + "<|im_end|>\n" }} + {%- else %} + {{- raise_exception("Only user and assistant and tool_results and tool and function roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} +{{- '<|im_start|>assistant\n' }} +{%- endif %} \ No newline at end of file diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama3.1_json.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama3.1_json.jinja new file mode 100644 index 00000000000..033830936a5 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama3.1_json.jinja @@ -0,0 +1,120 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {#- Llama 3.1 doesn't pass all tests if the tools are in the system prompt #} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not date_string is defined %} + {%- if strftime_now is defined %} + {%- set date_string = strftime_now("%d %b %Y") %} + {%- else %} + {%- set date_string = "26 Jul 2024" %} + {%- endif %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- if messages[0]['content'] is string %} + {%- set system_message = messages[0]['content']|trim %} + {%- else %} + {%- set system_message = messages[0]['content'][0]['text']|trim %} + {%- endif %} + {%- set messages = messages[1:] %} +{%- else %} + {%- if tools is not none %} + {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} + {%- else %} + {%- set system_message = "" %} + {%- endif %} +{%- endif %} + +{#- System message #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if tools is not none %} + {{- "Environment: ipython\n" }} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n" }} +{{- "Today Date: " + date_string + "\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call. " }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- if messages[0]['content'] is string %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- else %} + {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %} + {%- endif %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} + {%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }} + {%- if message['content'] is string %} + {{- message['content'] | trim}} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'text' %} + {{- content['text'] | trim }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {%- if not message.tool_calls|length == 1 %} + {{- raise_exception("This model only supports single tool-calls at once!") }} + {%- endif %} + {%- set tool_call = message.tool_calls[0].function %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {{- "<|eot_id|>" }} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is string %} + {{- { "output": message.content } | tojson }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'text' %} + {{- { "output": content['text'] } | tojson }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama3.2_json.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama3.2_json.jinja new file mode 100644 index 00000000000..2b290c0eede --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama3.2_json.jinja @@ -0,0 +1,133 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = false %} +{%- endif %} +{%- if not date_string is defined %} + {%- if strftime_now is defined %} + {%- set date_string = strftime_now("%d %b %Y") %} + {%- else %} + {%- set date_string = "26 Jul 2024" %} + {%- endif %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- Find out if there are any images #} +{% set image_ns = namespace(has_images=false) %} +{%- for message in messages %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {%- set image_ns.has_images = true %} + {%- endif %} + {%- endfor %} +{%- endfor %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- if messages[0]['content'] is string %} + {%- set system_message = messages[0]['content']|trim %} + {%- else %} + {%- set system_message = messages[0]['content'][0]['text']|trim %} + {%- endif %} + {%- set messages = messages[1:] %} +{%- else %} + {%- if tools is not none %} + {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} + {%- else %} + {%- set system_message = "" %} + {%- endif %} +{%- endif %} + +{#- System message if there are no images, if the user supplied one, or if tools are used (default tool system message) #} +{%- if system_message or not image_ns.has_images %} + {{- "<|start_header_id|>system<|end_header_id|>\n\n" }} + {%- if tools is not none %} + {{- "Environment: ipython\n" }} + {%- endif %} + {{- "Cutting Knowledge Date: December 2023\n" }} + {{- "Today Date: " + date_string + "\n\n" }} + {%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call. " }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {%- endif %} + {{- system_message }} + {{- "<|eot_id|>" }} +{%- endif %} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- if messages[0]['content'] is string %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- else %} + {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %} + {%- endif %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} + {%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }} + {%- if message['content'] is string %} + {{- message['content'] | trim}} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {{- '<|image|>' }} + {%- elif content['type'] == 'text' %} + {{- content['text'] | trim }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {%- if not message.tool_calls|length == 1 %} + {{- raise_exception("This model only supports single tool-calls at once!") }} + {%- endif %} + {%- set tool_call = message.tool_calls[0].function %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {{- "<|eot_id|>" }} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is string %} + {{- { "output": message.content } | tojson }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'text' %} + {{- { "output": content['text'] } | tojson }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama3.2_pythonic.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama3.2_pythonic.jinja new file mode 100644 index 00000000000..e4ec2353b35 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama3.2_pythonic.jinja @@ -0,0 +1,98 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = false %} +{%- endif %} +{%- if not date_string is defined %} + {%- if strftime_now is defined %} + {%- set date_string = strftime_now("%d %b %Y") %} + {%- else %} + {%- set date_string = "26 Jul 2024" %} + {%- endif %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} +{%- endif %} + +{#- System message #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if tools is not none %} + {{- "Environment: ipython\n" }} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n" }} +{{- "Today Date: " + date_string + "\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call functions, please respond with a python list of the calls. " }} + {{- 'Respond in the format [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] ' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} + {%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a python list for function calls " }} + {{- "with their proper arguments to best answer the given prompt.\n\n" }} + {{- 'Respond in the format [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] ' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n[' -}} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- tool_call.name + '(' -}} + {%- for param in tool_call.arguments %} + {{- param + '=' -}} + {{- "%s" | format(tool_call.arguments[param]) -}} + {% if not loop.last %}, {% endif %} + {%- endfor %} + {{- ')' -}} + {% if not loop.last %}, {% endif %} + {%- endfor %} + {{- ']<|eot_id|>' -}} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping %} + {{- message.content | tojson }} + {%- else %} + {{- { "output": message.content } | tojson }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama4_json.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama4_json.jinja new file mode 100644 index 00000000000..759f1655443 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama4_json.jinja @@ -0,0 +1,116 @@ +{%- macro is_array_of_type_objects(var) -%} + {%- if var is iterable and var is not string -%} + {%- set valid = true -%} + {%- for item in var -%} + {%- if 'type' not in item -%} + {%- set valid = false -%} + {%- break -%} + {%- endif -%} + {%- endfor -%} + {{ valid }} + {%- else -%} + {{ false }} + {%- endif -%} +{%- endmacro %} + +{%- macro render_message(message) %} + {%- if message['content'] is string %} + {{- message['content']|trim }} + {%- elif is_array_of_type_objects(data) == 'True' %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {{- '<|image|>' }} + {%- elif content['type'] == 'text' %} + {{- content['text']|trim }} + {%- endif %} + {%- endfor %} + {%- else %} + {{- message['content']|tojson }} + {%- endif %} +{%- endmacro %} + +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0] %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = ({ "content": "You are a helpful assistant with tool calling " + "capabilities. Only reply with a tool call if the function exists in the " + "library provided by the user. If it doesn't exist, just reply directly in " + "natural language. When you receive a tool call response, use the output to " + "format an answer to the original user question."}) %} +{%- endif %} + +{%- set tool_lib_preamble = 'Tools: You have access to the following tools. You might need to use one ' + 'or more function/tool calls to fulfill the task. \n' + 'If none are needed, then proceed to the response.\n\n' + 'Tool Call Syntax: You can call tools using the following syntax:\n' + '{"name": function name, "parameters": dictionary of argument name and its value}.\n' + 'Separate multiple function calls by "; ". Do not use variables.\n' + 'Do not include anything else when calling the tools with the syntax above.\n\n' + 'Here is a list of functions in JSON format that you can invoke.\n' %} + +{{- "<|header_start|>system<|header_end|>\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- tool_lib_preamble }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- render_message(system_message) }} +{{ "<|eot|>\n" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0] %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} + {%- endif %} + {{- '<|header_start|>user<|header_end|>\n\n' }} + {{- tool_lib_preamble }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- render_message(first_user_message) + "\n<|eot|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} + {{- render_message(message) }} + {{- "\n<|eot|>" }} + {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %} + {{- '\n<|header_start|>assistant<|header_end|>\n\n' -}} + {{- render_message(message) }} + {%- for tool_call in message.tool_calls %} + {{- '{"name": "' + tool_call.function.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.function.arguments | tojson }} + {{- "}" }} + {%- endfor %} + {{- "\n<|eot|>" }} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "\n<|header_start|>ipython<|header_end|>\n\n" }} + {{- render_message(message) }} + {{- "\n<|eom|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '\n<|header_start|>assistant<|header_end|>\n\n' }} +{%- endif %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama4_pythonic.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama4_pythonic.jinja new file mode 100644 index 00000000000..bbed3d8205e --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_llama4_pythonic.jinja @@ -0,0 +1,111 @@ +{{- bos_token }} +{%- if custom_tools is defined and custom_tools%} + {%- set tools = custom_tools %} +{%- endif %} +{%- if tools is defined and tools %} + {%- set tool_definition = tool_definition ~ (tools | tojson(indent=4)) %} +{%- else %} + {%- set tools = none %} +{%- endif %} + + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set user_provided_system_message = true %} + {%- if messages[0]['content'] is string %} + {%- set system_message = messages[0]['content']|trim %} + {%- else %} + {%- set system_message = messages[0]['content'][0]['text']|trim %} + {%- endif %} + {%- set messages = messages[1:] %} +{%- else %} + {%- if tools is not none %} + {#- Since not system_message was provided by user, if tool is provided, system_message is now default tool system message #} + {#- This system message is from llama website:https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/ #} + {%- set system_message = "You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:\n\n1. FUNCTION CALLS:\n- ONLY use functions that are EXPLICITLY listed in the function list below\n- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If a function is not in the list, respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)\n- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\nExamples:\nCORRECT: [get_weather(location=\"Vancouver\"), calculate_route(start=\"Boston\", end=\"New York\")] <- Only if get_weather and calculate_route are in function list\nINCORRECT: get_weather(location=\"New York\")\nINCORRECT: Let me check the weather: [get_weather(location=\"New York\")]\nINCORRECT: [get_events(location=\"Singapore\")] <- If function not in list\n\n2. RESPONSE RULES:\n- For pure function requests matching a listed function: ONLY output the function call(s)\n- For knowledge questions: ONLY output text\n- For missing parameters: ONLY request the specific missing parameters\n- For unavailable services (not in function list): output ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\". Do NOT execute a function call.\n- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations\n- NEVER combine text and function calls in the same response\n- NEVER suggest alternative functions when the requested service is unavailable\n- NEVER create or invent new functions not listed below\n\n3. STRICT BOUNDARIES:\n- ONLY use functions from the list below - no exceptions\n- NEVER use a function as an alternative to unavailable information\n- NEVER call functions not present in the function list\n- NEVER add explanatory text to function calls\n- NEVER respond with empty brackets\n- Use proper Python/JSON syntax for function calls\n- Check the function list carefully before responding\n\n4. TOOL RESPONSE HANDLING:\n- When receiving tool responses: provide concise, natural language responses\n- Don't repeat tool response verbatim\n- Don't add supplementary information\n\nHere is a list of functions in JSON format that you can invoke:\n" %} + {%- else %} + {%- set system_message = "" %} + {%- endif %} +{%- endif %} +{#- Now writing the system message: use the user provided system message if user_provided_system_message, else default tool system message if tools presented #} +{%- if system_message %} + {#- always use user provided system message to override default tool system message #} + {{- "<|header_start|>system<|header_end|>\n\n" }} + {{- system_message }} + {%- if user_provided_system_message and tools %} + {{- "\nHere is a list of functions in JSON format that you can invoke. Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\n" }} + {{- tool_definition -}} + {%- elif tool_definition %} + {{- tool_definition -}} + {%- endif %} + {{- "<|eot|>" }} +{%- endif %} + +{#- Now deal with all other messages #} +{%- for message in messages %} + {#- Base case: messages that are not from tool role and has empty tool_call list #} + {%- if not (message.role == 'ipython' or message.role == 'tool' or ('tool_calls' in message and message.tool_calls|length != 0 )) %} + {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} + {%- if message['content'] is string %} + {{- message['content'] }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {{- '<|image|>' }} + {%- elif content['type'] == 'text' %} + {{- content['text'] | trim }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- "<|eot|>" }} + {#- Tool case: messages has non-empty tool_call list, must from assistant #} + {%- elif 'tool_calls' in message %} + {#- assume tool_calls are always coming from assistant #} + {%- if message.role == 'assistant' %} + {{- '<|header_start|>assistant<|header_end|>\n\n' -}} + {%- if message['content'] is string %} + {{- message['content'] }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {{- '<|image|>' }} + {%- elif content['type'] == 'text' %} + {{- content['text'] }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- "[" }} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- tool_call.name + '(' -}} + {%- for param in tool_call.arguments %} + {{- param + '="' -}} + {{- "%s" | format(tool_call.arguments[param]) -}} + {{- '"' -}} + {% if not loop.last %}, {% endif %} + {%- endfor %} + {{- ')' -}} + {% if not loop.last %}, {% endif %} + {%- endfor %} + {{- "]<|eot|>" }} +{%- endif %} +{#- Tool_response case: messages are from tool_response #} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|header_start|>ipython<|header_end|>\n\n" }} + {%- if message.content is string %} + {{- message.content | tojson }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'text' %} + {{- content['text'] | tojson }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- "<|eot|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|header_start|>assistant<|header_end|>\n\n' }} +{%- endif %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_minimax_m1.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_minimax_m1.jinja new file mode 100644 index 00000000000..2d5bbf4de56 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_minimax_m1.jinja @@ -0,0 +1,91 @@ +{{ '' -}} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- Extract system message #} +{% set ns = namespace(system_prompt='') -%} +{%- if messages[0]['role'] == 'system' %} + {%- if messages[0]['content'] is string %} + {%- set ns.system_prompt = messages[0]['content']|trim %} + {%- else %} + {%- set ns.system_prompt = messages[0]['content'][0]['text']|trim %} + {%- endif %} + {%- set messages = messages[1:] %} +{%- else %} + {%- if tools is not none %} + {%- set ns.system_prompt = "You are a helpful assistant created by Minimax based on MiniMax-M1 model." %} + {%- else %} + {%- set ns.system_prompt = "You are a helpful assistant created by Minimax based on MiniMax-M1 model." %} + {%- endif %} +{%- endif %} + +{#- System message #} +{%- if ns.system_prompt != '' %} +{{ 'system ai_setting=assistant\n' + ns.system_prompt + '\n' -}} +{%- endif %} + +{#- Tools configuration #} +{%- if tools is not none %} +{{ 'system tool_setting=tools\nYou are provided with these tools:\n\n' -}} +{%- for tool in tools %} +{{ tool | tojson ~ '\n' -}} +{%- endfor %} +{{ '\n\nIf you need to call tools, please respond with XML tags, and provide tool-name and json-object of arguments, following the format below:\n\n{"name": , "arguments": }\n...\n\n' -}} +{%- endif %} + +{#- Process messages #} +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {%- if message['role'] == 'user' %} +{{ 'user name=user\n' -}} +{%- if message['content'] is string %} +{{ message['content']|trim -}} +{%- else %} +{%- for content in message['content'] %} +{%- if content['type'] == 'text' %} +{{ content['text']|trim -}} +{%- endif %} +{%- endfor %} +{%- endif %} +{{ '\n' -}} + {%- elif message['role'] == 'assistant' %} +{{ 'ai name=assistant\n' -}} +{%- if message['content'] is string %} +{{ message['content']|trim -}} +{%- else %} +{%- for content in message['content'] | selectattr('type', 'equalto', 'text') %} +{{ content['text']|trim -}} +{%- endfor %} +{%- endif %} +{{ '\n' -}} + {%- endif %} + {%- elif 'tool_calls' in message %} +{{ 'ai name=assistant\n\n' -}} +{%- for tool_call in message.tool_calls %} +{{ '{"name": "' + tool_call.function.name + '", "arguments": ' + tool_call.function.arguments | tojson + '}\n' -}} +{%- endfor %} +{{ '\n' -}} + {%- elif message.role == "tool" or message.role == "ipython" %} +{{ 'tool name=tools\n' -}} +{%- if message.content is string %} +{{ 'tool result: ' + message.content + '\n\n' -}} +{%- else %} +{%- for content in message['content'] %} +{%- if content['type'] == 'text' %} +{{ 'tool result: ' + content['text'] + '\n\n' -}} +{%- elif content.get('name') %} +{{ 'tool name: ' + content['name'] + '\ntool result: ' + content['text'] + '\n\n' -}} +{%- endif %} +{%- endfor %} +{%- endif %} +{{ '\n' -}} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} +{{ 'ai name=assistant\n' -}} +{%- endif %} \ No newline at end of file diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_mistral.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_mistral.jinja new file mode 100644 index 00000000000..49691f59c2f --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_mistral.jinja @@ -0,0 +1,86 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS] [" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if loop.last and system_message is defined %} + {{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST] " + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} + {%- if message.tool_calls is defined %} + {%- set tool_calls = message.tool_calls %} + {%- else %} + {%- set tool_calls = message.content %} + {%- endif %} + {{- "[TOOL_CALLS] [" }} + {%- for tool_call in tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} + {%- endif %} + {{- ', "id": "' + tool_call.id[-9:] + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- " " + message["content"] + eos_token }} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_mistral3.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_mistral3.jinja new file mode 100644 index 00000000000..7c4249ec44c --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_mistral3.jinja @@ -0,0 +1,126 @@ +{%- set today = strftime_now("%Y-%m-%d") %} +{%- set default_system_message = "You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\nYour knowledge base was last updated on 2023-10-01. The current date is " + today + ".\n\nWhen you're not sure about some information, you say that you don't have the information and don't make up anything.\nIf the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to Tokyo\" => \"Where do you travel from?\")" %} + +{{- bos_token }} + +{%- if messages[0]['role'] == 'system' %} + {%- if messages[0]['content'] is string %} + {%- set system_message = messages[0]['content'] %} + {%- set loop_messages = messages[1:] %} + {%- else %} + {%- set system_message = messages[0]['content'][0]['text'] %} + {%- set loop_messages = messages[1:] %} + {%- endif %} +{%- else %} + {%- set system_message = default_system_message %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- elif tools is not none %} + {%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array or objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %} + {%- if system_message is defined %} + {%- set system_message = parallel_tool_prompt + "\n\n" + system_message %} + {%- else %} + {%- set system_message = parallel_tool_prompt %} + {%- endif %} +{%- endif %} +{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }} + +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{%- set filtered_messages = [] %} +{%- for message in loop_messages %} + {%- if message["role"] not in ["tool", "tool_results"] and not message.get("tool_calls") %} + {%- set filtered_messages = filtered_messages + [message] %} + {%- endif %} +{%- endfor %} + +{%- for message in filtered_messages %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS] [" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if message['content'] is string %} + {{- '[INST]' + message['content'] + '[/INST]' }} + {%- else %} + {{- '[INST]' }} + {%- for block in message['content'] %} + {%- if block['type'] == 'text' %} + {{- block['text'] }} + {%- elif block['type'] == 'image' or block['type'] == 'image_url' %} + {{- '[IMG]' }} + {%- else %} + {{- raise_exception('Only text and image blocks are supported in message content!') }} + {%- endif %} + {%- endfor %} + {{- '[/INST]' }} + {%- endif %} + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} + {%- if message.tool_calls is defined %} + {%- set tool_calls = message.tool_calls %} + {%- else %} + {%- set tool_calls = message.content %} + {%- endif %} + {{- "[TOOL_CALLS] [" }} + {%- for tool_call in tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} + {%- endif %} + {{- ', "id": "' + tool_call.id[-9:] + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message['role'] == 'assistant' %} + {%- if message['content'] is string %} + {{- message['content'] + eos_token }} + {%- else %} + {{- message['content'][0]['text'] + eos_token }} + {%- endif %} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_mistral_parallel.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_mistral_parallel.jinja new file mode 100644 index 00000000000..2ef4bedf862 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_mistral_parallel.jinja @@ -0,0 +1,93 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- elif tools is not none %} + {%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array or objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %} + {%- if system_message is defined %} + {%- set system_message = parallel_tool_prompt + "\n\n" + system_message %} + {%- else %} + {%- set system_message = parallel_tool_prompt %} + {%- endif %} +{%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS] [" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if loop.last and system_message is defined %} + {{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST] " + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} + {%- if message.tool_calls is defined %} + {%- set tool_calls = message.tool_calls %} + {%- else %} + {%- set tool_calls = message.content %} + {%- endif %} + {{- "[TOOL_CALLS] [" }} + {%- for tool_call in tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} + {%- endif %} + {{- ', "id": "' + tool_call.id[-9:] + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- " " + message["content"] + eos_token }} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_phi4_mini.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_phi4_mini.jinja new file mode 100644 index 00000000000..6f40c38c206 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_phi4_mini.jinja @@ -0,0 +1,62 @@ +{%- if messages and messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "You are a helpful assistant." %} +{%- endif %} + +{%- if messages %} +<|system|> +{{ system_message }} +{%- if tools %} +In addition to plain text responses, you can choose to call one or more of the provided functions. + +Use the following rule to decide when to call a function: + * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so + * if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls + +If you decide to call functions: + * prefix function calls with functools marker (no closing marker required) + * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] + * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples + * respect the argument type formatting. E.g., if the type is number and format is float, write value 7 as 7.0 + * make sure you pick the right functions that match the user intent + + + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %}<|end|> + + {%- for message in messages %} + {%- if message.role != "system" %} +<|{{ message.role }}|> + {%- if message.content and message.role == "tools" %} +{"result": {{ message.content }}} + {%- elif message.content %} +{{ message.content }} + {%- elif message.tool_calls %} + {%- for call in message.tool_calls %} +{"name": "{{ call.function.name }}", "arguments": {{ call.function.arguments }}} + {%- if not loop.last %},{% endif %} + {%- endfor %} + {%- endif %}<|end|> + {%- endif %} + {%- endfor %}<|assistant|> + +{%- else %} + {%- if system_message %} +<|system|> + +{{ system_message }}<|end|> + {%- endif %} + {%- if prompt %} +<|user|> + +{{ prompt }}<|end|> + {%- endif %}<|assistant|> + +{%- endif %} +{{ response }} +{%- if response %}<|user|>{% endif %} \ No newline at end of file diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_qwen3coder.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_qwen3coder.jinja new file mode 100644 index 00000000000..49b0e8d0ee7 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_qwen3coder.jinja @@ -0,0 +1,117 @@ +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '' }} + {%- else %} + {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '' }} + {%- endif %} + {%- endfor %} + {%- endif %} +{% endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = [] %} +{%- endif %} + +{%- if system_message is defined %} + {{- "<|im_start|>system\n" + system_message }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." }} + {%- endif %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }} + {{- "" }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- "\n\n" ~ tool.name ~ "" }} + {%- if tool.description is defined %} + {{- '\n' ~ (tool.description | trim) ~ '' }} + {%- endif %} + {{- '\n' }} + {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- '\n' }} + {{- '\n' ~ param_name ~ '' }} + {%- if param_fields.type is defined %} + {{- '\n' ~ (param_fields.type | string) ~ '' }} + {%- endif %} + {%- if param_fields.description is defined %} + {{- '\n' ~ (param_fields.description | trim) ~ '' }} + {%- endif %} + {%- set handled_keys = ['name', 'type', 'description'] %} + {{- render_extra_keys(param_fields, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {% set handled_keys = ['type', 'properties'] %} + {{- render_extra_keys(tool.parameters, handled_keys) }} + {{- '\n' }} + {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %} + {{- render_extra_keys(tool, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} +{%- endif %} +{%- if system_message is defined %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in loop_messages %} + {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %} + {{- '<|im_start|>' + message.role }} + {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %} + {{- '\n' + message.content | trim + '\n' }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n\n' }} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user\n' }} + {%- endif %} + {{- '\n' }} + {{- message.content }} + {{- '\n\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_toolace.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_toolace.jinja new file mode 100644 index 00000000000..da0f25cdcb3 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_toolace.jinja @@ -0,0 +1,65 @@ +{{- bos_token }} + +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language." %} +{%- endif %} + +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You are an expert in composing functions. You are given a question and a set of possible functions. Based on the question, you will need to make one or more function/tool calls to achieve the purpose.\n" }} + {{- "If none of the function can be used, point it out. If the given question lacks the parameters required by the function, also point it out.\n" }} + {{- "You should only return the function call in tools call sections.\n\n" }} + {{- "If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]\n" }} + {{- "You SHOULD NOT include any other text in the response.\n" }} + {{- "Here is a list of functions in JSON format that you can invoke.\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- "\n" }} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n[' -}} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- tool_call.name + '(' -}} + {%- for param in tool_call.arguments %} + {{- param + '=' -}} + {{- "%s" | format(tool_call.arguments[param]) -}} + {% if not loop.last %}, {% endif %} + {%- endfor %} + {{- ')' -}} + {% if not loop.last %}, {% endif %} + {%- endfor %} + {{- ']<|eot_id|>' -}} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping %} + {{- message.content | tojson }} + {%- else %} + {{- { "output": message.content } | tojson }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} + +{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_xlam_llama.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_xlam_llama.jinja new file mode 100644 index 00000000000..f97de4004f1 --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_xlam_llama.jinja @@ -0,0 +1,77 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- Extract system message #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content'] | trim %} + {%- set messages = messages[1:] %} + {{- system_message + "\n" }} +{%- else %} + {%- set system_message = "You are a helpful assistant. You are developed by Salesforce xLAM team." %} + {% set format_instruction %}You have access to a set of tools. When using tools, make calls in a single JSON array: + +[{"name": "tool_call_name", "arguments": {"arg1": "value1", "arg2": "value2"}}, ... (additional parallel tool calls as needed)] + +If no tool is suitable, state that explicitly. If the user's input lacks required parameters, ask for clarification. Do not interpret or respond until tool results are returned. Once they are available, process them or make additional calls if needed. For tasks that don't require tools, such as casual conversation or general advice, respond directly in plain text. The available tools are:{% endset %} + {{- system_message + "\n" }} + {%- if tools is not none %} + {{- format_instruction + "\n\n" }} + {%- endif %} +{%- endif %} + + +{%- if tools is not none %} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- "<|eot_id|>" }} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {%- if message['tool_calls'] %} + {{- "[" }} + {%- for tool_call_function in message.tool_calls %} + {%- set tool_call = tool_call_function.function %} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"arguments": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "]" }} + {{- "<|eot_id|>" }} + {%- elif message['content'] %} + {{- message['content'] | trim + '<|eot_id|>' }} + {%- else %} + {{- "[]\n" + '<|eot_id|>' }} + {%- endif %} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>" + "ipython" + "<|end_header_id|>\n\n" }} + {%- set content = message["content"] %} + {%- if content is mapping or (content is iterable and content is not string) %} + {{- content | tojson }} + {%- else %} + {{- content }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} \ No newline at end of file diff --git a/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_xlam_qwen.jinja b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_xlam_qwen.jinja new file mode 100644 index 00000000000..acf57cc4b2c --- /dev/null +++ b/rust/src/chat/tests/templates/vllm_examples/tool_chat_template_xlam_qwen.jinja @@ -0,0 +1,66 @@ +{# System message #} +{{- "<|im_start|>system\n" }} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content'] | trim %} + {%- set messages = messages[1:] %} + {{- system_message + "\n" }} +{%- else %} + {%- set system_message = "You are a helpful assistant. You are developed by Salesforce xLAM team." %} + {% set format_instruction %}You have access to a set of tools. When using tools, make calls in a single JSON array: + +[{"name": "tool_call_name", "arguments": {"arg1": "value1", "arg2": "value2"}}, ... (additional parallel tool calls as needed)] + +If no tool is suitable, state that explicitly. If the user's input lacks required parameters, ask for clarification. Do not interpret or respond until tool results are returned. Once they are available, process them or make additional calls if needed. For tasks that don't require tools, such as casual conversation or general advice, respond directly in plain text. The available tools are:{% endset %} + {{- system_message + "\n" }} + {%- if tools is not none %} + {{- format_instruction + "\n\n" }} + {%- endif %} +{%- endif %} + +{%- if tools is not none %} + {%- for func in tools %} + {{- func | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- "<|im_end|>\n" }} +{%- for message in messages %} + {%- if message['role'] == 'tool' %} + {{- "<|im_start|>tool\n" }} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {%- if content is mapping or content is iterable and content is not string %} + {{- content | tojson }} + {%- else %} + {{- content }} + {%- endif %} + {{- "<|im_end|>\n" }} + {%- elif 'tool_calls' in message %} + {{- "<|im_start|>assistant\n" }} + {%- if message['tool_calls'] %} + {{- "[" }} + {%- for tool_call in message.tool_calls %} + {%- set out = tool_call.function | tojson %} + {{- out }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "]"}} + {%- elif message['content'] %} + {{- message['content'] | trim }} + {%- else %} + {{- "[]\n" }} + {%- endif %} + {{- "<|im_end|>\n" }} + {%- else %} + {{- "<|im_start|>" + message['role'] + "\n" + message['content'] | trim + "<|im_end|>\n" }} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- "<|im_start|>assistant\n" }} +{%- endif %} diff --git a/rust/src/cmd/Cargo.toml b/rust/src/cmd/Cargo.toml new file mode 100644 index 00000000000..b684d072202 --- /dev/null +++ b/rust/src/cmd/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "vllm-cmd" +version.workspace = true +edition.workspace = true +license.workspace = true + +[[bin]] +name = "vllm-rs" +path = "src/main.rs" + +[features] +default = [] +native-tls-vendored = ["dep:native-tls-vendored"] + +[dependencies] +anyhow.workspace = true +clap.workspace = true +educe.workspace = true +itertools.workspace = true +native-tls-vendored = { workspace = true, optional = true } +serde.workspace = true +serde_json.workspace = true +serde_with.workspace = true +thiserror-ext.workspace = true +time.workspace = true +tokio = { workspace = true, features = ["signal"] } +tokio-util.workspace = true +tracing.workspace = true +tracing-subscriber.workspace = true +uuid.workspace = true +vllm-engine-core-client.workspace = true +vllm-managed-engine.workspace = true +vllm-server.workspace = true + +[dev-dependencies] +expect-test.workspace = true + +[lints] +workspace = true diff --git a/rust/src/cmd/examples/README.md b/rust/src/cmd/examples/README.md new file mode 100644 index 00000000000..ad328a73183 --- /dev/null +++ b/rust/src/cmd/examples/README.md @@ -0,0 +1,44 @@ +# `vllm-rs` CLI Quick Start + +Start Qwen3 with one managed `vllm-rs serve` command from the repo root: + +```bash +HF_HUB_OFFLINE=1 \ +VLLM_CPU_KVCACHE_SPACE=2 \ +VLLM_HOST_IP=127.0.0.1 \ +VLLM_LOOPBACK_IP=127.0.0.1 \ +cargo run --bin vllm-rs -- serve \ + Qwen/Qwen3-0.6B \ + --python ../vllm/.venv/bin/python \ + --max-model-len 512 \ + -- \ + --dtype float16 +``` + +This launches: + +- a managed headless Python `vllm` engine +- the Rust OpenAI-compatible frontend on `127.0.0.1:8000` + +All Python engine arguments must be placed after `--`. Arguments before `--` are parsed by the Rust +frontend itself. + +You can then send OpenAI-style requests to the Rust frontend: + +```bash +curl http://127.0.0.1:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen3-0.6B", + "messages": [{"role": "user", "content": "What is the capital of France?"}], + "stream": true + }' +``` + +If you already started headless `vllm` yourself, use `frontend` instead: + +```bash +cargo run --bin vllm-rs -- frontend \ + --handshake-address tcp://127.0.0.1:62100 \ + Qwen/Qwen3-0.6B +``` diff --git a/rust/src/cmd/src/cli.rs b/rust/src/cmd/src/cli.rs new file mode 100644 index 00000000000..70ac8440453 --- /dev/null +++ b/rust/src/cmd/src/cli.rs @@ -0,0 +1,434 @@ +//! CLI argument definitions for the `vllm-rs` binary. +//! +//! Python vLLM references: +//! - Engine args: +//! - Environment variables: + +mod unsupported; + +use std::collections::HashMap; +use std::ffi::{OsStr, OsString}; +use std::path::PathBuf; +use std::time::Duration; + +use clap::{Args, Parser, Subcommand}; +use educe::Educe; +use serde::Deserialize; +use serde::de::DeserializeOwned; +use serde_json::Value; +use thiserror_ext::AsReport as _; +use uuid::Uuid; +use vllm_engine_core_client::TransportMode; +use vllm_managed_engine::ManagedEngineConfig; +use vllm_managed_engine::cli::{ManagedEngineArgs, repartition_managed_engine_args}; +use vllm_server::{ + ChatTemplateContentFormatOption, Config, CoordinatorMode, HttpListenerMode, ParserSelection, + RendererSelection, +}; + +use crate::cli::unsupported::UnsupportedArgs; + +/// Top-level parser for the `vllm-rs` binary. +#[derive(Debug, Parser)] +#[command( + name = "vllm-rs", + about = "Rust frontend and managed-engine CLI for vLLM." +)] +pub struct Cli { + #[command(subcommand)] + pub command: Command, +} + +impl Cli { + pub fn parse() -> Self { + Self::try_parse_from(std::env::args_os()).unwrap_or_else(|error| error.exit()) + } + + pub fn try_parse_from(itr: I) -> Result + where + I: IntoIterator, + T: Into, + { + let args: Vec = itr.into_iter().map(Into::into).collect(); + let repartitioned_args = repartition_managed_engine_args::(&args, Some("serve"))?; + ::try_parse_from(&repartitioned_args).inspect(|cli| { + if let Command::Serve(serve) = &cli.command + && serve.debug_cli + { + println!( + "Original CLI args: {}\n", + args.join(OsStr::new(" ")).display() + ); + println!( + "Repartitioned CLI args: {}\n", + repartitioned_args.join(OsStr::new(" ")).display() + ); + println!( + "Passthrough Python args: {}", + serve.managed_engine.python_args.join(" ") + ); + std::process::exit(0); + } + }) + } +} + +/// Supported top-level CLI commands. +#[derive(Debug, Subcommand, PartialEq, Eq)] +pub enum Command { + /// Run the Rust OpenAI frontend as a Python-supervised worker. + Frontend(FrontendArgs), + /// Launch a managed Python headless engine, then run the Rust OpenAI + /// frontend. + Serve(ServeArgs), +} + +/// Runtime arguments shared by the external-engine and managed-engine paths. +#[derive(Educe, Clone, Args, PartialEq, Eq, Deserialize)] +#[educe(Debug)] +pub struct SharedRuntimeArgs { + #[serde(rename = "model_tag")] + /// Model identifier or local model directory used for backend loading and + /// public model ID. + pub model: String, + + /// Maximum time to wait for the expected engines to register on the + /// frontend transport. + #[arg( + long = "engine-ready-timeout-secs", + env = "VLLM_ENGINE_READY_TIMEOUT_S", + default_value_t = default_engine_ready_timeout_secs() + )] + #[serde(default = "default_engine_ready_timeout_secs")] + pub engine_ready_timeout_secs: u64, + + /// Select the tool call parser depending on the model that you're using. + /// Use `auto` to infer from the model or `none` to disable parsing. + #[arg(long, default_value_t)] + #[serde(default)] + pub tool_call_parser: ParserSelection, + /// Select the reasoning parser depending on the model that you're using. + /// Use `auto` to infer from the model or `none` to disable parsing. + #[arg(long, default_value_t)] + #[serde(default)] + pub reasoning_parser: ParserSelection, + /// Select the chat renderer implementation. + #[arg(long = "tokenizer-mode", default_value_t)] + #[serde(default, rename = "tokenizer_mode")] + pub renderer: RendererSelection, + /// Override the maximum model context length. When set, the frontend uses + /// this value instead of the model's `max_position_embeddings` from + /// `config.json`. + #[arg(long)] + pub max_model_len: Option, + /// TCP port for the gRPC Generate service. When not set, no gRPC server is + /// started. + #[arg(long)] + #[serde(default)] + pub grpc_port: Option, + /// Maximum time to wait for active requests to drain during shutdown. + #[arg(long, default_value_t = 0)] + #[serde(default)] + pub shutdown_timeout: u64, + + /// The file path to the chat template, or the template in single-line form + /// for the specified model. + #[arg(long)] + #[serde(default)] + pub chat_template: Option, + + /// Default keyword arguments to pass to the chat template renderer. + /// + /// These will be merged with request-level chat_template_kwargs, with + /// request values taking precedence. Useful for setting default + /// behavior for reasoning models. + /// + /// Example: `{"enable_thinking": false}` to disable thinking mode by + /// default for Qwen3/DeepSeek models. + #[arg(long, value_parser = parse_json::>, value_name = "JSON")] + #[serde(default)] + pub default_chat_template_kwargs: Option>, + + /// The format to render message content within a chat template. + /// + /// * "auto" detects the format from the template + /// * "string" renders content as a string. Example: `"Hello World"` + /// * "openai" renders content as a list of dictionaries, similar to OpenAI schema. Example: + /// `[{"type": "text", "text": "Hello world!"}]` + #[arg(long, default_value_t)] + #[serde(default)] + pub chat_template_content_format: ChatTemplateContentFormatOption, + + /// Log a summary line for each completed request, including prompt/output + /// token counts and finish reason. + #[arg(long)] + #[serde(default)] + pub enable_log_requests: bool, + + /// Disable periodic logging of engine statistics (throughput, queue depth, + /// cache usage). + #[arg(long)] + #[serde(default)] + pub disable_log_stats: bool, + + /// The model name(s) used in the API. If multiple names are provided, the + /// server will respond to any of the provided names. The model name in the + /// model field of a response will be the first name in this list. If not + /// specified, the model name will be the same as the `--model` argument. + /// Noted that this name(s) will also be used in `model_name` tag + /// content of prometheus metrics, if multiple names provided, metrics + /// tag will take the first one. + #[arg(long, num_args = 0..)] + #[serde(default)] + pub served_model_name: Vec, + + /// Unsupported Python vLLM frontend arguments recognized but not yet + /// implemented in Rust. + #[educe(Debug(ignore))] + #[command(flatten)] + #[serde(default, flatten)] + pub unsupported: UnsupportedArgs, +} + +impl SharedRuntimeArgs { + /// Maximum time to wait for the expected engines to register on the + /// frontend transport. + pub fn ready_timeout(&self) -> Duration { + Duration::from_secs(self.engine_ready_timeout_secs) + } + + /// Maximum time to wait for active requests to drain during shutdown. + pub fn shutdown_timeout(&self) -> Duration { + Duration::from_secs(self.shutdown_timeout) + } + + /// Build the OpenAI-server config for the Python-bootstrap worker contract. + /// + /// The resulting config binds the Python-supplied transport addresses and + /// inherits an already open HTTP listener from the supervisor process. + fn into_bootstrapped_config( + self, + listen_fd: i32, + input_address: String, + output_address: String, + coordinator_address: Option, + engine_count: usize, + ) -> Config { + let ready_timeout = self.ready_timeout(); + let shutdown_timeout = self.shutdown_timeout(); + + Config { + transport_mode: TransportMode::Bootstrapped { + input_address, + output_address, + engine_count, + ready_timeout, + }, + coordinator_mode: match coordinator_address { + Some(address) => CoordinatorMode::External { address }, + None => CoordinatorMode::None, + }, + model: self.model, + served_model_name: self.served_model_name, + listener_mode: HttpListenerMode::InheritedFd { fd: listen_fd }, + tool_call_parser: self.tool_call_parser, + reasoning_parser: self.reasoning_parser, + renderer: self.renderer, + chat_template: self.chat_template, + default_chat_template_kwargs: self.default_chat_template_kwargs, + chat_template_content_format: self.chat_template_content_format, + enable_log_requests: self.enable_log_requests, + disable_log_stats: self.disable_log_stats, + grpc_port: self.grpc_port, + shutdown_timeout, + } + } + + /// Build the OpenAI-server config for the managed `serve` path that still + /// owns the startup handshake and binds its own HTTP listener. + fn into_managed_config( + self, + listener_mode: HttpListenerMode, + handshake_address: String, + advertised_host: String, + engine_count: usize, + local_input_address: Option, + local_output_address: Option, + ) -> Config { + let ready_timeout = self.ready_timeout(); + let shutdown_timeout = self.shutdown_timeout(); + + Config { + transport_mode: TransportMode::HandshakeOwner { + handshake_address, + advertised_host, + engine_count, + ready_timeout, + local_input_address, + local_output_address, + }, + coordinator_mode: CoordinatorMode::MaybeInProc, + model: self.model, + served_model_name: self.served_model_name, + listener_mode, + tool_call_parser: self.tool_call_parser, + reasoning_parser: self.reasoning_parser, + renderer: self.renderer, + chat_template: self.chat_template, + default_chat_template_kwargs: self.default_chat_template_kwargs, + chat_template_content_format: self.chat_template_content_format, + enable_log_requests: self.enable_log_requests, + disable_log_stats: self.disable_log_stats, + grpc_port: self.grpc_port, + shutdown_timeout, + } + } +} + +fn default_engine_ready_timeout_secs() -> u64 { + 600 +} + +fn parse_json(value: &str) -> Result { + serde_json::from_str(value).map_err(|e| format!("invalid JSON object: {}", e.as_report())) +} + +fn parse_runtime_args_json(value: &str) -> Result { + let args: SharedRuntimeArgs = serde_json::from_str(value) + .map_err(|e| format!("invalid JSON arguments: {}", e.as_report()))?; + args.unsupported.check()?; + Ok(args) +} + +/// Arguments for running the Rust frontend as a Python-bootstrapped worker. +#[derive(Educe, Clone, Args, PartialEq, Eq)] +#[educe(Debug)] +pub struct FrontendArgs { + /// Inherited listening socket file descriptor passed by the Python + /// supervisor. + #[arg(long)] + pub listen_fd: i32, + /// Frontend input ROUTER socket address that the Python engines will + /// connect to. + #[arg(long)] + pub input_address: String, + /// Frontend output PULL socket address that the Python engines will push + /// responses to. + #[arg(long)] + pub output_address: String, + /// Optional Python-owned frontend-side DP coordinator socket address for + /// external coordinator mode in the bootstrapped frontend path, i.e., + /// `stats_update_address`. + #[arg(long)] + pub coordinator_address: Option, + /// Total number of data-parallel engines expected for this frontend. + #[arg(long, default_value_t = 1)] + pub engine_count: usize, + + /// Shared frontend arguments as one JSON object. + #[arg(long = "args-json", value_parser = parse_runtime_args_json, value_name = "JSON")] + pub runtime: SharedRuntimeArgs, +} + +impl FrontendArgs { + /// Convert the CLI arguments into the OpenAI server's runtime config. + pub fn into_config(self) -> Config { + self.runtime.into_bootstrapped_config( + self.listen_fd, + self.input_address, + self.output_address, + self.coordinator_address, + self.engine_count, + ) + } +} + +/// Arguments for the managed-engine mode that spawns Python on behalf of the +/// user. +#[derive(Educe, Clone, Args, PartialEq, Eq)] +#[educe(Debug)] +#[command(override_usage = "vllm-rs serve [OPTIONS] [-- ...]")] +pub struct ServeArgs { + /// Only launch the managed Python headless engine and do not start the Rust + /// frontend. + #[arg(long)] + pub headless: bool, + /// HTTP bind host for the OpenAI-compatible server. + #[arg(long, default_value = "127.0.0.1")] + pub host: String, + /// HTTP bind port for the OpenAI-compatible server. + #[arg(long, default_value_t = 8000)] + pub port: u16, + /// Unix domain socket path. If set, host and port arguments are ignored. + #[arg(long)] + pub uds: Option, + + /// Flag to print debug information about CLI argument parsing and exit. + #[educe(Debug(ignore))] + #[arg(long, hide = true, env = "VLLM_RS_DEBUG_CLI")] + pub debug_cli: bool, + + /// Shared frontend arguments. + #[command(flatten)] + pub runtime: SharedRuntimeArgs, + + /// Managed Python headless-engine arguments. + #[command(flatten)] + pub managed_engine: ManagedEngineArgs, +} + +impl ServeArgs { + /// Build the OpenAI-server runtime config used after the managed Python + /// engine starts. + pub fn to_frontend_config(&self, handshake_address: String) -> Config { + // Prefer IPC sockets for local engine input/output. + let (local_input_address, local_output_address) = + self.managed_engine.frontend_local_only().then(frontend_ipc_addresses).unzip(); + let listener_mode = match &self.uds { + Some(path) => HttpListenerMode::BindUnix { path: path.clone() }, + None => HttpListenerMode::BindTcp { + host: self.host.clone(), + port: self.port, + }, + }; + + self.runtime.clone().into_managed_config( + listener_mode, + handshake_address, + self.managed_engine.handshake_host.clone(), + self.managed_engine.data_parallel_size, + local_input_address, + local_output_address, + ) + } + + /// Build the managed Python-engine spawn configuration with the given + /// handshake port. + pub fn to_managed_engine_config(&self, handshake_port: u16) -> ManagedEngineConfig { + self.managed_engine.clone().into_config( + self.runtime.model.clone(), + self.runtime.max_model_len, + handshake_port, + ) + } +} + +/// Allocate fresh IPC endpoints for one managed frontend instance. +fn frontend_ipc_addresses() -> (String, String) { + let preferred_base_path = std::env::var_os("VLLM_RPC_BASE_PATH") + .map(PathBuf::from) + .unwrap_or_else(std::env::temp_dir); + let input_name = format!("vllm-rs-i-{}", Uuid::new_v4().simple()); + let output_name = format!("vllm-rs-o-{}", Uuid::new_v4().simple()); + + let input = preferred_base_path.join(input_name); + let output = preferred_base_path.join(output_name); + + ( + format!("ipc://{}", input.to_string_lossy()), + format!("ipc://{}", output.to_string_lossy()), + ) +} + +#[cfg(test)] +mod tests; diff --git a/rust/src/cmd/src/cli/tests.rs b/rust/src/cmd/src/cli/tests.rs new file mode 100644 index 00000000000..0762468456e --- /dev/null +++ b/rust/src/cmd/src/cli/tests.rs @@ -0,0 +1,905 @@ +use expect_test::expect; +use vllm_engine_core_client::TransportMode; +use vllm_server::{Config, HttpListenerMode, ParserSelection, RendererSelection}; + +use super::{Cli, Command}; + +#[test] +fn serve_args_forward_python_flags_with_separator() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--python", + "../vllm/.venv/bin/python", + "--max-model-len", + "512", + "--", + "--dtype", + "float16", + ]) + .unwrap(); + + expect![[r#" + Cli { + command: Serve( + ServeArgs { + headless: false, + host: "127.0.0.1", + port: 8000, + uds: None, + runtime: SharedRuntimeArgs { + model: "Qwen/Qwen3-0.6B", + engine_ready_timeout_secs: 600, + tool_call_parser: Auto, + reasoning_parser: Auto, + renderer: Auto, + max_model_len: Some( + 512, + ), + grpc_port: None, + shutdown_timeout: 0, + chat_template: None, + default_chat_template_kwargs: None, + chat_template_content_format: Auto, + enable_log_requests: false, + disable_log_stats: false, + served_model_name: [], + }, + managed_engine: ManagedEngineArgs { + python: "../vllm/.venv/bin/python", + handshake_host: "127.0.0.1", + handshake_port: None, + data_parallel_size: 1, + data_parallel_size_local: None, + python_args: [ + "--dtype", + "float16", + ], + }, + }, + ), + } + "#]] + .assert_debug_eq(&cli); +} + +#[test] +fn serve_args_auto_forward_python_flags_without_separator() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--python", + "python3", + "--quantization", + "awq", + ]) + .unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + assert_eq!( + args.managed_engine.python_args, + vec!["--quantization", "awq"] + ); +} + +#[test] +fn serve_args_auto_forward_python_multi_char_alias_without_separator() { + let cli = Cli::try_parse_from(["vllm-rs", "serve", "Qwen/Qwen3-0.6B", "-tp", "2"]).unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + assert_eq!( + args.managed_engine.python_args, + vec!["--tensor-parallel-size", "2"] + ); +} + +#[test] +fn serve_args_accept_explicit_deepseek_v32_renderer() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--tokenizer-mode", + "deepseek_v32", + ]) + .unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + assert_eq!(args.runtime.renderer, RendererSelection::DeepSeekV32); +} + +#[test] +fn serve_args_reject_unknown_renderer_value() { + let error = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--tokenizer-mode", + "definitely_missing", + ]) + .unwrap_err(); + + expect![[r#" + error: invalid value 'definitely_missing' for '--tokenizer-mode ': unknown renderer `definitely_missing` (expected one of: auto, hf, deepseek_v32, deepseek_v4) + + For more information, try '--help'. + "#]] + .assert_eq(&error.to_string()); +} + +#[test] +fn serve_args_reject_unsupported_flag_arg() { + let error = Cli::try_parse_from(["vllm-rs", "serve", "Qwen/Qwen3-0.6B", "--allow-credentials"]) + .unwrap_err(); + + expect![[r#" + error: invalid value 'true' for '--allow-credentials []': argument is not implemented in Rust frontend yet + + Remove this unsupported argument to continue. + + Alternatively, if you intend to pass it only to the Python engine, put it after `--` (e.g., `-- `). + This may lead to unexpected behavior as the Rust frontend will completely ignore that argument. + + For more information, try '--help'. + "#]] + .assert_eq(&error.to_string()); +} + +#[test] +fn serve_args_reject_unsupported_no_flag_alias() { + let error = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--no-enable-log-deltas", + ]) + .unwrap_err(); + + expect![[r#" + error: invalid value 'true' for '--enable-log-deltas []': argument is not implemented in Rust frontend yet + + Remove this unsupported argument to continue. + + Alternatively, if you intend to pass it only to the Python engine, put it after `--` (e.g., `-- `). + This may lead to unexpected behavior as the Rust frontend will completely ignore that argument. + + For more information, try '--help'. + "#]] + .assert_eq(&error.to_string()); +} + +#[test] +fn frontend_args_accept_json() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "frontend", + "--listen-fd", + "3", + "--input-address", + "ipc:///tmp/input.sock", + "--output-address", + "ipc:///tmp/output.sock", + "--coordinator-address", + "tcp://127.0.0.1:7000", + "--args-json", + r#"{"model_tag":"Qwen/Qwen3-0.6B","engine_count":2}"#, + ]) + .unwrap(); + + expect![[r#" + Cli { + command: Frontend( + FrontendArgs { + listen_fd: 3, + input_address: "ipc:///tmp/input.sock", + output_address: "ipc:///tmp/output.sock", + coordinator_address: Some( + "tcp://127.0.0.1:7000", + ), + engine_count: 1, + runtime: SharedRuntimeArgs { + model: "Qwen/Qwen3-0.6B", + engine_ready_timeout_secs: 600, + tool_call_parser: Auto, + reasoning_parser: Auto, + renderer: Auto, + max_model_len: None, + grpc_port: None, + shutdown_timeout: 0, + chat_template: None, + default_chat_template_kwargs: None, + chat_template_content_format: Auto, + enable_log_requests: false, + disable_log_stats: false, + served_model_name: [], + }, + }, + ), + } + "#]] + .assert_debug_eq(&cli); +} + +#[test] +fn frontend_args_json_applies_defaults() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "frontend", + "--listen-fd", + "3", + "--input-address", + "ipc:///tmp/input.sock", + "--output-address", + "ipc:///tmp/output.sock", + "--args-json", + r#"{"model_tag":"Qwen/Qwen3-0.6B"}"#, + ]) + .unwrap(); + + let Command::Frontend(args) = cli.command else { + panic!("expected frontend args"); + }; + assert_eq!(args.runtime.model, "Qwen/Qwen3-0.6B"); + assert_eq!(args.runtime.engine_ready_timeout_secs, 600); + assert_eq!(args.runtime.tool_call_parser, ParserSelection::Auto); + assert_eq!(args.runtime.reasoning_parser, ParserSelection::Auto); + assert_eq!(args.runtime.renderer, RendererSelection::Auto); + assert_eq!(args.runtime.max_model_len, None); + assert_eq!(args.runtime.shutdown_timeout, 0); +} + +#[test] +fn frontend_args_json_accepts_supported_non_default_fields() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "frontend", + "--listen-fd", + "3", + "--input-address", + "ipc:///tmp/input.sock", + "--output-address", + "ipc:///tmp/output.sock", + "--args-json", + r#"{"model_tag":"Qwen/Qwen3-0.6B","engine_ready_timeout_secs":42,"tool_call_parser":"hermes","reasoning_parser":"qwen3_thinking","tokenizer_mode":"deepseek_v32","max_model_len":8192,"shutdown_timeout":3}"#, + ]) + .unwrap(); + + let Command::Frontend(args) = cli.command else { + panic!("expected frontend args"); + }; + assert_eq!(args.runtime.engine_ready_timeout_secs, 42); + assert_eq!( + args.runtime.tool_call_parser, + ParserSelection::Explicit("hermes".to_string()) + ); + assert_eq!( + args.runtime.reasoning_parser, + ParserSelection::Explicit("qwen3_thinking".to_string()) + ); + assert_eq!(args.runtime.renderer, RendererSelection::DeepSeekV32); + assert_eq!(args.runtime.max_model_len, Some(8192)); + assert_eq!(args.runtime.shutdown_timeout, 3); +} + +#[test] +fn serve_args_accept_none_reasoning_parser() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--reasoning-parser", + "none", + ]) + .unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + assert_eq!(args.runtime.reasoning_parser, ParserSelection::None); + assert_eq!(args.runtime.tool_call_parser, ParserSelection::Auto); +} + +#[test] +fn frontend_args_json_ignores_unknown_fields() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "frontend", + "--listen-fd", + "3", + "--input-address", + "ipc:///tmp/input.sock", + "--output-address", + "ipc:///tmp/output.sock", + "--args-json", + r#"{"model_tag":"Qwen/Qwen3-0.6B","uds":"/tmp/vllm.sock","nested_unknown":{"x":1}}"#, + ]) + .unwrap(); + + let Command::Frontend(args) = cli.command else { + panic!("expected frontend args"); + }; + assert_eq!(args.runtime.model, "Qwen/Qwen3-0.6B"); +} + +#[test] +fn frontend_args_json_accepts_noop_fields() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "frontend", + "--listen-fd", + "3", + "--input-address", + "ipc:///tmp/input.sock", + "--output-address", + "ipc:///tmp/output.sock", + "--args-json", + r#"{"model_tag":"Qwen/Qwen3-0.6B","api_server_count":2}"#, + ]) + .unwrap(); + + let Command::Frontend(args) = cli.command else { + panic!("expected frontend args"); + }; + assert_eq!(args.runtime.model, "Qwen/Qwen3-0.6B"); +} + +#[test] +fn frontend_args_json_rejects_unsupported_fields() { + let error = Cli::try_parse_from([ + "vllm-rs", + "frontend", + "--listen-fd", + "3", + "--input-address", + "ipc:///tmp/input.sock", + "--output-address", + "ipc:///tmp/output.sock", + "--args-json", + r#"{"model_tag":"Qwen/Qwen3-0.6B","allow_credentials":true}"#, + ]) + .unwrap_err(); + + expect![[r#" + error: invalid value '{"model_tag":"Qwen/Qwen3-0.6B","allow_credentials":true}' for '--args-json ': + The following arguments are not implemented in Rust frontend yet: + - allow_credentials + + Remove these arguments to continue. + + For more information, try '--help'. + "#]].assert_eq(&error.to_string()); +} + +#[test] +fn frontend_args_json_aggregates_multiple_unsupported_fields() { + let error = Cli::try_parse_from([ + "vllm-rs", + "frontend", + "--listen-fd", + "3", + "--input-address", + "ipc:///tmp/input.sock", + "--output-address", + "ipc:///tmp/output.sock", + "--args-json", + r#"{"model_tag":"Qwen/Qwen3-0.6B","allow_credentials":true,"api_key":"secret"}"#, + ]) + .unwrap_err(); + + expect![[r#" + error: invalid value '{"model_tag":"Qwen/Qwen3-0.6B","allow_credentials":true,"api_key":"secret"}' for '--args-json ': + The following arguments are not implemented in Rust frontend yet: + - allow_credentials + - api_key + + Remove these arguments to continue. + + For more information, try '--help'. + "#]].assert_eq(&error.to_string()); +} + +#[test] +fn frontend_args_json_rejects_malformed_json() { + let error = Cli::try_parse_from([ + "vllm-rs", + "frontend", + "--listen-fd", + "3", + "--input-address", + "ipc:///tmp/input.sock", + "--output-address", + "ipc:///tmp/output.sock", + "--args-json", + r#"{"model_tag":"Qwen/Qwen3-0.6B""#, + ]) + .unwrap_err(); + + expect![[r#" + error: invalid value '{"model_tag":"Qwen/Qwen3-0.6B"' for '--args-json ': invalid JSON arguments: EOF while parsing an object at line 1 column 30 + + For more information, try '--help'. + "#]].assert_eq(&error.to_string()); +} + +#[test] +fn serve_args_reject_flags_before_model() { + let error = Cli::try_parse_from(["vllm-rs", "serve", "--python", "python3", "Qwen/Qwen3-0.6B"]) + .unwrap_err(); + + expect![[r#" + error: the model must appear immediately after the command + + Usage: vllm-rs serve [OPTIONS] [-- ...] + + For more information, try '--help'. + "#]] + .assert_eq(&error.to_string()); +} + +#[test] +fn serve_args_accept_headless_mode() { + let cli = Cli::try_parse_from(["vllm-rs", "serve", "Qwen/Qwen3-0.6B", "--headless"]).unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + assert!(args.headless); +} + +#[test] +fn serve_args_keep_python_passthrough_flags_after_separator() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--python", + "python3", + "--", + "--tensor-parallel-size", + "2", + "--dtype", + "float16", + ]) + .unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + assert_eq!( + args.managed_engine.python_args, + vec!["--tensor-parallel-size", "2", "--dtype", "float16"] + ); +} + +#[test] +fn serve_args_keep_python_multi_char_alias_after_separator() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--python", + "python3", + "--", + "-tp", + "2", + "--dtype", + "float16", + ]) + .unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + assert_eq!( + args.managed_engine.python_args, + vec!["-tp", "2", "--dtype", "float16"] + ); +} + +#[test] +fn serve_args_keep_frontend_arg_after_separator() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--", + "--uds", + "/tmp/vllm.sock", + ]) + .unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + assert_eq!( + args.managed_engine.python_args, + vec!["--uds", "/tmp/vllm.sock"] + ); +} + +#[test] +fn serve_args_keep_python_multi_char_engine_aliases_after_separator() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--", + "-dpr", + "1", + "-dpl", + "2", + ]) + .unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + assert_eq!( + args.managed_engine.python_args, + vec!["-dpr", "1", "-dpl", "2"] + ); +} + +#[test] +fn serve_args_auto_forward_unknown_flags_without_separator() { + let cli = Cli::try_parse_from(["vllm-rs", "serve", "Qwen/Qwen3-0.6B", "--foo", "bar"]).unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + assert_eq!(args.managed_engine.python_args, vec!["--foo", "bar"]); +} + +#[test] +fn serve_args_auto_forward_negative_value_without_separator() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--num-gpu-blocks-override", + "-1", + ]) + .unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + assert_eq!( + args.managed_engine.python_args, + vec!["--num-gpu-blocks-override", "-1"] + ); +} + +#[test] +fn serve_args_accept_handshake_aliases() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--python", + "python3", + "--handshake-host", + "10.99.48.128", + "--handshake-port", + "13345", + "--data-parallel-size", + "4", + ]) + .unwrap(); + + expect![[r#" + Cli { + command: Serve( + ServeArgs { + headless: false, + host: "127.0.0.1", + port: 8000, + uds: None, + runtime: SharedRuntimeArgs { + model: "Qwen/Qwen3-0.6B", + engine_ready_timeout_secs: 600, + tool_call_parser: Auto, + reasoning_parser: Auto, + renderer: Auto, + max_model_len: None, + grpc_port: None, + shutdown_timeout: 0, + chat_template: None, + default_chat_template_kwargs: None, + chat_template_content_format: Auto, + enable_log_requests: false, + disable_log_stats: false, + served_model_name: [], + }, + managed_engine: ManagedEngineArgs { + python: "python3", + handshake_host: "10.99.48.128", + handshake_port: Some( + 13345, + ), + data_parallel_size: 4, + data_parallel_size_local: None, + python_args: [], + }, + }, + ), + } + "#]] + .assert_debug_eq(&cli); +} + +#[test] +fn serve_args_accept_data_parallel_primary_flags() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--data-parallel-address", + "10.99.48.128", + "--data-parallel-rpc-port", + "13345", + "--data-parallel-size", + "4", + ]) + .unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + assert!(!args.headless); + assert_eq!(args.managed_engine.handshake_host, "10.99.48.128"); + assert_eq!(args.managed_engine.handshake_port, Some(13345)); + assert_eq!(args.managed_engine.data_parallel_size, 4); +} + +#[test] +fn serve_frontend_config_uses_dp_address_as_advertised_host() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--handshake-host", + "10.99.48.128", + "--data-parallel-size", + "4", + ]) + .unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + let config = args.to_frontend_config("tcp://10.99.48.128:29550".to_string()); + + let TransportMode::HandshakeOwner { + handshake_address, + advertised_host, + engine_count, + ready_timeout, + local_input_address, + local_output_address, + } = &config.transport_mode + else { + panic!("expected handshake-owned transport"); + }; + + assert_eq!(handshake_address, "tcp://10.99.48.128:29550"); + assert_eq!(advertised_host, "10.99.48.128"); + assert_eq!(*engine_count, 4); + assert!( + local_input_address + .as_deref() + .is_some_and(|address| address.starts_with("ipc://")) + ); + assert!( + local_output_address + .as_deref() + .is_some_and(|address| address.starts_with("ipc://")) + ); + assert_ne!(local_input_address, local_output_address); + + expect![[r#" + Config { + transport_mode: HandshakeOwner { + handshake_address: "tcp://10.99.48.128:29550", + advertised_host: "10.99.48.128", + engine_count: 4, + ready_timeout: 600s, + local_input_address: Some( + "", + ), + local_output_address: Some( + "", + ), + }, + coordinator_mode: MaybeInProc, + model: "Qwen/Qwen3-0.6B", + served_model_name: [], + listener_mode: BindTcp { + host: "127.0.0.1", + port: 8000, + }, + tool_call_parser: Auto, + reasoning_parser: Auto, + renderer: Auto, + chat_template: None, + default_chat_template_kwargs: None, + chat_template_content_format: Auto, + enable_log_requests: false, + disable_log_stats: false, + grpc_port: None, + shutdown_timeout: 0ns, + } + "#]] + .assert_debug_eq(&Config { + transport_mode: TransportMode::HandshakeOwner { + handshake_address: handshake_address.clone(), + advertised_host: advertised_host.clone(), + engine_count: *engine_count, + ready_timeout: *ready_timeout, + local_input_address: Some("".to_string()), + local_output_address: Some("".to_string()), + }, + ..config.clone() + }); +} + +#[test] +fn serve_frontend_config_keeps_tcp_transport_for_non_local_only_topology() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--data-parallel-address", + "10.99.48.128", + "--data-parallel-size", + "4", + "--data-parallel-size-local", + "2", + ]) + .unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + let config = args.to_frontend_config("tcp://10.99.48.128:29550".to_string()); + + expect![[r#" + Config { + transport_mode: HandshakeOwner { + handshake_address: "tcp://10.99.48.128:29550", + advertised_host: "10.99.48.128", + engine_count: 4, + ready_timeout: 600s, + local_input_address: None, + local_output_address: None, + }, + coordinator_mode: MaybeInProc, + model: "Qwen/Qwen3-0.6B", + served_model_name: [], + listener_mode: BindTcp { + host: "127.0.0.1", + port: 8000, + }, + tool_call_parser: Auto, + reasoning_parser: Auto, + renderer: Auto, + chat_template: None, + default_chat_template_kwargs: None, + chat_template_content_format: Auto, + enable_log_requests: false, + disable_log_stats: false, + grpc_port: None, + shutdown_timeout: 0ns, + } + "#]] + .assert_debug_eq(&config); +} + +#[test] +fn frontend_args_reject_legacy_handshake_flags() { + let error = Cli::try_parse_from([ + "vllm-rs", + "frontend", + "--listen-fd", + "3", + "--input-address", + "ipc:///tmp/input.sock", + "--output-address", + "ipc:///tmp/output.sock", + "--args-json", + r#"{"model_tag":"Qwen/Qwen3-0.6B"}"#, + "--handshake-address", + "tcp://127.0.0.1:62100", + ]) + .unwrap_err(); + + assert!(error.to_string().contains("--handshake-address")); +} + +#[test] +fn frontend_config_uses_external_coordinator_when_coordinator_address_is_present() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "frontend", + "--listen-fd", + "3", + "--input-address", + "ipc:///tmp/input.sock", + "--output-address", + "ipc:///tmp/output.sock", + "--coordinator-address", + "tcp://127.0.0.1:7000", + "--engine-count", + "2", + "--args-json", + r#"{"model_tag":"Qwen/Qwen3-0.6B"}"#, + ]) + .unwrap(); + + let Command::Frontend(args) = cli.command else { + panic!("expected frontend args"); + }; + let config = args.into_config(); + + expect![[r#" + Config { + transport_mode: Bootstrapped { + input_address: "ipc:///tmp/input.sock", + output_address: "ipc:///tmp/output.sock", + engine_count: 2, + ready_timeout: 600s, + }, + coordinator_mode: External { + address: "tcp://127.0.0.1:7000", + }, + model: "Qwen/Qwen3-0.6B", + served_model_name: [], + listener_mode: InheritedFd { + fd: 3, + }, + tool_call_parser: Auto, + reasoning_parser: Auto, + renderer: Auto, + chat_template: None, + default_chat_template_kwargs: None, + chat_template_content_format: Auto, + enable_log_requests: false, + disable_log_stats: false, + grpc_port: None, + shutdown_timeout: 0ns, + } + "#]] + .assert_debug_eq(&config); +} + +#[test] +fn serve_frontend_config_uses_unix_listener_when_uds_is_present() { + let cli = Cli::try_parse_from([ + "vllm-rs", + "serve", + "Qwen/Qwen3-0.6B", + "--uds", + "/tmp/vllm.sock", + ]) + .unwrap(); + + let Command::Serve(args) = cli.command else { + panic!("expected serve args"); + }; + let config = args.to_frontend_config("tcp://127.0.0.1:29550".to_string()); + + assert_eq!( + config.listener_mode, + HttpListenerMode::BindUnix { + path: "/tmp/vllm.sock".to_string(), + } + ); +} diff --git a/rust/src/cmd/src/cli/unsupported.rs b/rust/src/cmd/src/cli/unsupported.rs new file mode 100644 index 00000000000..eeaa0832888 --- /dev/null +++ b/rust/src/cmd/src/cli/unsupported.rs @@ -0,0 +1,660 @@ +#![allow(clippy::doc_lazy_continuation)] + +use std::fmt::Display; +use std::str::FromStr; + +use clap::Args; +use clap::builder::{TypedValueParser, ValueParserFactory}; +use itertools::Itertools; +use serde::{Deserialize, Deserializer, Serialize}; + +/// Marker type for frontend-owned `serve` arguments that `vllm-rs` recognizes +/// but does not support yet. +/// +/// When passed as JSON args, it can be deserialized from any value, and +/// serializes back to the original value. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct Unsupported(pub serde_json::Value); + +impl FromStr for Unsupported { + type Err = String; + + fn from_str(_s: &str) -> Result { + Err("argument is not implemented in Rust frontend yet + +Remove this unsupported argument to continue. + +Alternatively, if you intend to pass it only to the Python engine, put it after `--` (e.g., `-- `). +This may lead to unexpected behavior as the Rust frontend will completely ignore that argument." + .to_string()) + } +} + +/// Marker type for no-op arguments that are accepted by the Rust frontend but +/// have no effect. +/// +/// When passed as JSON args, it can be deserialized from any value, but always +/// serializes back to `null`. +#[derive(Clone, Debug, PartialEq, Eq, Serialize)] +pub struct Noop; + +impl<'de> Deserialize<'de> for Noop { + fn deserialize(_deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Ok(Noop) + } +} + +impl ValueParserFactory for Noop { + type Parser = NoopValueParser; + + fn value_parser() -> Self::Parser { + NoopValueParser + } +} + +#[derive(Copy, Clone, Debug)] +pub struct NoopValueParser; + +#[track_caller] +fn noop_warn(arg: impl Display) { + tracing::warn!("argument '{arg}' currently has no effect in Rust frontend, ignoring"); +} + +impl TypedValueParser for NoopValueParser { + type Value = Noop; + + fn parse_ref( + &self, + _cmd: &clap::Command, + arg: Option<&clap::Arg>, + _value: &std::ffi::OsStr, + ) -> Result { + if let Some(arg) = arg { + noop_warn(arg); + } + Ok(Noop) + } +} + +/// Frontend-owned Python `serve` arguments that `vllm-rs` recognizes but does +/// not support yet. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, PartialEq, Eq, Default, Args, Serialize, Deserialize)] +#[command(next_help_heading = "Options not implemented in Rust frontend yet")] +pub struct UnsupportedArgs { + #[command(flatten)] + #[serde(default, flatten)] + top_level: TopLevelUnsupportedArgs, + #[command(flatten)] + #[serde(default, flatten)] + engine: EngineUnsupportedArgs, + #[command(flatten)] + #[serde(default, flatten)] + server: ServerUnsupportedArgs, +} + +impl UnsupportedArgs { + /// Check whether any unsupported arguments are set, and if so, return an + /// error listing them. Also warn about any no-op arguments that are set + /// but will be ignored. + pub(crate) fn check(&self) -> Result<(), String> { + let value = serde_json::to_value(self).unwrap(); + let map = value.as_object().unwrap(); + let mut unsupported = Vec::new(); + + for (key, value) in map { + if value.is_null() { + noop_warn(key); + } else { + unsupported.push(key.as_str()); + } + } + + if !unsupported.is_empty() { + unsupported.sort_unstable(); + let bullets = unsupported.into_iter().map(|key| format!("- {key}")).join("\n"); + return Err(format!( + " +The following arguments are not implemented in Rust frontend yet: +{bullets} + +Remove these arguments to continue." + )); + } + + Ok(()) + } +} + +/// Frontend-owned Python `vllm serve` top-level arguments that `vllm-rs` +/// recognizes but does not support yet. +/// +/// Source of truth in Python vLLM: +/// - `vllm.entrypoints.openai.cli_args.make_arg_parser(...)` +/// - `vllm.entrypoints.cli.serve.ServeSubcommand.subparser_init(...)` +/// +/// These are not part of `EngineArgs`, `AsyncEngineArgs`, `BaseFrontendArgs`, +/// or `FrontendArgs`. They live on the `serve` command itself and control +/// managed-engine / multi-process orchestration rather than the shared frontend +/// runtime config. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, PartialEq, Eq, Default, Args, Serialize, Deserialize)] +pub struct TopLevelUnsupportedArgs { + /// How many API server processes to run. Defaults to data_parallel_size if + /// not specified. + #[arg(long, hide = true)] + pub api_server_count: Option, + + /// Read CLI options from a config file. Must be a YAML with the following + /// options: https://docs.vllm.ai/en/latest/configuration/serve_args.html + #[arg(long)] + pub config: Option, + + /// Launch a gRPC server instead of the HTTP OpenAI-compatible server. + /// Requires: pip install vllm[grpc]. + #[arg(long, default_missing_value = "true", num_args = 0..=1)] + pub grpc: Option, +} + +/// Frontend-owned Python engine arguments that `vllm-rs` recognizes but does +/// not support yet. +/// +/// Source of truth in Python vLLM: +/// - `vllm.engine.arg_utils.EngineArgs.add_cli_args(...)` +/// - `vllm.engine.arg_utils.AsyncEngineArgs.add_cli_args(...)` +/// +/// These arguments are declared through the Python engine-args surface, but +/// they are still frontend-owned: the API server / AsyncLLM layer reads them +/// for tokenizer setup, request validation, routing, logging, and other +/// frontend behavior, so Rust must recognize them rather than treating them as +/// pure engine passthrough. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, PartialEq, Eq, Default, Args, Serialize, Deserialize)] +pub struct EngineUnsupportedArgs { + /// Name or path of the Hugging Face tokenizer to use. If unspecified, model + /// name or path will be used. + #[arg(long)] + pub tokenizer: Option, + + /// Name or path of the Hugging Face config to use. If unspecified, model + /// name or path will be used. + #[arg(long)] + pub hf_config_path: Option, + + /// Allowing API requests to read local images or videos from directories + /// specified by the server file system. This is a security risk. Should + /// only be enabled in trusted environments. + #[arg(long)] + pub allowed_local_media_path: Option, + + /// If set, only media URLs that belong to this domain can be used for + /// multi-modal inputs. + #[arg(long)] + pub allowed_media_domains: Option, + + /// The specific revision to use for the tokenizer on the Hugging Face Hub. + /// It can be a branch name, a tag name, or a commit id. If unspecified, + /// will use the default version. + #[arg(long)] + pub tokenizer_revision: Option, + + /// Maximum number of log probabilities to return when `logprobs` is + /// specified in `SamplingParams`. The default value comes the default for + /// the OpenAI Chat Completions API. -1 means no cap, i.e. all + /// (output_length * vocab_size) logprobs are allowed to be returned and + /// it may cause OOM. + #[arg(long)] + pub max_logprobs: Option, + + /// Skip initialization of tokenizer and detokenizer. Expects valid + /// `prompt_token_ids` and `None` for prompt from the input. The generated + /// output will contain token ids. + #[arg( + long, + visible_alias = "no-skip-tokenizer-init", + default_missing_value = "true", + num_args = 0..=1 + )] + pub skip_tokenizer_init: Option, + + /// If `True`, enables passing text embeddings as inputs via the + /// `prompt_embeds` key. + /// + /// WARNING: The vLLM engine may crash if incorrect shape of embeddings is + /// passed. Only enable this flag for trusted users! + #[arg( + long, + visible_alias = "no-enable-prompt-embeds", + default_missing_value = "true", + num_args = 0..=1 + )] + pub enable_prompt_embeds: Option, + + /// The token to use as HTTP bearer authorization for remote files. If + /// `True`, will use the token generated when running `hf auth login` + /// (stored in `~/.cache/huggingface/token`). + #[arg(long, default_missing_value = "true", num_args = 0..=1)] + pub hf_token: Option, + + /// If a dictionary, contains arguments to be forwarded to the Hugging Face + /// config. If a callable, it is called to update the HuggingFace config. + #[arg(long)] + pub hf_overrides: Option, + + /// The folder path to the generation config. Defaults to `"auto"`, the + /// generation config will be loaded from model path. If set to `"vllm"`, no + /// generation config is loaded, vLLM defaults will be used. If set to a + /// folder path, the generation config will be loaded from the specified + /// folder path. If `max_new_tokens` is specified in generation config, + /// then it sets a server-wide limit on the number of output tokens for + /// all requests. + #[arg(long)] + pub generation_config: Option, + + /// IOProcessor plugin name to load at model startup + #[arg(long)] + pub io_processor_plugin: Option, + + /// Path to a dynamically reasoning parser plugin that can be dynamically + /// loaded and registered. + #[arg(long)] + pub reasoning_parser_plugin: Option, + + /// Rank of the data parallel group. + #[arg(long, env = "VLLM_DP_RANK")] + pub data_parallel_rank: Option, + + /// Whether to use "hybrid" DP LB mode. Applies only to online serving + /// and when data_parallel_size > 0. Enables running an AsyncLLM + /// and API server on a "per-node" basis where vLLM load balances + /// between local data parallel ranks, but an external LB balances + /// between vLLM nodes/replicas. Set explicitly in conjunction with + /// --data-parallel-start-rank. + #[arg( + long, + visible_alias = "no-data-parallel-hybrid-lb", + default_missing_value = "true", + num_args = 0..=1 + )] + pub data_parallel_hybrid_lb: Option, + + /// Whether to use "external" DP LB mode. Applies only to online serving + /// and when data_parallel_size > 0. This is useful for a "one-pod-per-rank" + /// wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank + /// is provided explicitly to vllm serve. + #[arg( + long, + visible_alias = "no-data-parallel-external-lb", + default_missing_value = "true", + num_args = 0..=1 + )] + pub data_parallel_external_lb: Option, + + /// This feature is work in progress and no prefill optimization takes place + /// with this flag enabled currently. + #[arg( + long, + visible_alias = "no-kv-sharing-fast-prefill", + default_missing_value = "true", + num_args = 0..=1 + )] + pub kv_sharing_fast_prefill: Option, + + /// The maximum number of input items and options allowed per + /// prompt for each modality. + #[arg(long)] + pub limit_mm_per_prompt: Option, + + /// Additional args passed to process media inputs, keyed by modalities. + #[arg(long)] + pub media_io_kwargs: Option, + + /// Arguments to be forwarded to the model's processor for multi-modal data, + /// e.g., image processor. + #[arg(long)] + pub mm_processor_kwargs: Option, + + /// The size (in GiB) of the multi-modal processor cache. + #[arg(long)] + pub mm_processor_cache_gb: Option, + + /// Type of cache to use for the multi-modal preprocessor/mapper. + #[arg(long)] + pub mm_processor_cache_type: Option, + + /// If True, enable handling of LoRA adapters. + #[arg( + long, + visible_alias = "no-enable-lora", + default_missing_value = "true", + num_args = 0..=1 + )] + pub enable_lora: Option, + + /// Dictionary mapping specific modalities to LoRA model paths. + #[arg(long)] + pub default_mm_loras: Option, + + /// Target URL to which OpenTelemetry traces will be sent. + #[arg(long)] + pub otlp_traces_endpoint: Option, + + /// It makes sense to set this only if `--otlp-traces-endpoint` is set. + #[arg(long)] + pub collect_detailed_traces: Option, + + /// The interval (or buffer size) for streaming in terms of token length. + #[arg(long)] + pub stream_interval: Option, + + /// Structured outputs configuration. + #[arg(long)] + pub structured_outputs_config: Option, + + /// Log aggregate rather than per-engine statistics when using data + /// parallelism. + #[arg(long, default_missing_value = "true", num_args = 0..=1)] + pub aggregate_engine_logging: Option, +} + +/// Frontend-owned Python OpenAI server arguments that `vllm-rs` recognizes but +/// does not support yet. +/// +/// Source of truth in Python vLLM: +/// - `vllm.entrypoints.openai.cli_args.BaseFrontendArgs` +/// - `vllm.entrypoints.openai.cli_args.FrontendArgs` +/// +/// These are not engine args. They belong to the Python OpenAI-compatible +/// frontend / API-server layer itself, for example chat-template configuration, +/// tool/frontend behavior, TLS / CORS / HTTP server settings, and other +/// northbound server knobs. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, PartialEq, Eq, Default, Args, Serialize, Deserialize)] +pub struct ServerUnsupportedArgs { + /// LoRA modules configurations in either 'name=path' format or JSON format + /// or JSON list format. Example (old format): `'name=path'` Example (new + /// format): `{"name": "name", "path": "lora_path", + /// "base_model_name": "id"}` + #[arg(long)] + pub lora_modules: Option, + + /// Whether to trust the chat template provided in the request. If False, + /// the server will always use the chat template specified by + /// `--chat-template` or the ones from tokenizer. + #[arg( + long, + visible_alias = "no-trust-request-chat-template", + default_missing_value = "true", + num_args = 0..=1 + )] + pub trust_request_chat_template: Option, + + /// The role name to return if `request.add_generation_prompt=true`. + #[arg(long)] + pub response_role: Option, + + /// When `--max-logprobs` is specified, represents single tokens as + /// strings of the form 'token_id:{token_id}' so that tokens that are not + /// JSON-encodable can be identified. + #[arg( + long, + visible_alias = "no-return-tokens-as-token-ids", + default_missing_value = "true", + num_args = 0..=1 + )] + pub return_tokens_as_token_ids: Option, + + /// Enable auto tool choice for supported models. Use `--tool-call-parser` + /// to specify which parser to use. + #[arg( + long, + visible_alias = "no-enable-auto-tool-choice", + default_missing_value = "true", + num_args = 0..=1, + hide = true + )] + pub enable_auto_tool_choice: Option, + + /// If specified, exclude tool definitions in prompts when + /// tool_choice='none'. + #[arg( + long, + visible_alias = "no-exclude-tools-when-tool-choice-none", + default_missing_value = "true", + num_args = 0..=1 + )] + pub exclude_tools_when_tool_choice_none: Option, + + /// Special the tool parser plugin write to parse the model-generated tool + /// into OpenAI API format, the name register in this plugin can be used in + /// `--tool-call-parser`. + #[arg(long)] + pub tool_parser_plugin: Option, + + /// Comma-separated list of host:port pairs (IPv4, IPv6, or hostname). + /// Examples: 127.0.0.1:8000, [::1]:8000, localhost:1234. Or `demo` for + /// built-in demo tools (browser and Python code interpreter). WARNING: + /// The `demo` Python tool executes model-generated code in Docker without + /// network isolation by default. See the security guide for more + /// information. + #[arg(long)] + pub tool_server: Option, + + /// Path to logging config JSON file for both vllm and uvicorn + #[arg(long, /* env = "VLLM_LOGGING_CONFIG_PATH" */)] + pub log_config_file: Option, + + /// Max number of prompt characters or prompt ID numbers being printed in + /// log. The default of None means unlimited. + #[arg(long)] + pub max_log_len: Option, + + /// If set to True, enable prompt_tokens_details in usage. + #[arg( + long, + visible_alias = "no-enable-prompt-tokens-details", + default_missing_value = "true", + num_args = 0..=1 + )] + pub enable_prompt_tokens_details: Option, + + /// If set to True, enable tracking server_load_metrics in the app state. + #[arg( + long, + visible_alias = "no-enable-server-load-tracking", + default_missing_value = "true", + num_args = 0..=1 + )] + pub enable_server_load_tracking: Option, + + /// If set to True, including usage on every request. + #[arg( + long, + visible_alias = "no-enable-force-include-usage", + default_missing_value = "true", + num_args = 0..=1 + )] + pub enable_force_include_usage: Option, + + /// Enable the `/tokenizer_info` endpoint. May expose chat + /// templates and other tokenizer configuration. + #[arg( + long, + visible_alias = "no-enable-tokenizer-info-endpoint", + default_missing_value = "true", + num_args = 0..=1 + )] + pub enable_tokenizer_info_endpoint: Option, + + /// If set to True, log model outputs (generations). + /// Requires `--enable-log-requests`. As with `--enable-log-requests`, + /// information is only logged at INFO level at maximum. + #[arg( + long, + visible_alias = "no-enable-log-outputs", + default_missing_value = "true", + num_args = 0..=1 + )] + pub enable_log_outputs: Option, + + /// If set to False, output deltas will not be logged. Relevant only if + /// --enable-log-outputs is set. + #[arg( + long, + visible_alias = "no-enable-log-deltas", + default_missing_value = "true", + num_args = 0..=1 + )] + pub enable_log_deltas: Option, + + /// If set to True, log the stack trace of error responses + #[arg( + long, + // env = "VLLM_SERVER_DEV_MODE", + visible_alias = "no-log-error-stack", + default_missing_value = "true", + num_args = 0..=1 + )] + pub log_error_stack: Option, + + /// If set to True, only enable the Tokens In<>Out endpoint. + /// This is intended for use in a Disaggregated Everything setup. + #[arg( + long, + visible_alias = "no-tokens-only", + default_missing_value = "true", + num_args = 0..=1 + )] + pub tokens_only: Option, + + /// Log level for uvicorn. + #[arg(long)] + pub uvicorn_log_level: Option, + + /// Disable uvicorn access log. + #[arg( + long, + visible_alias = "no-disable-uvicorn-access-log", + default_missing_value = "true", + num_args = 0..=1 + )] + pub disable_uvicorn_access_log: Option, + + /// Comma-separated list of endpoint paths to exclude from uvicorn access + /// logs. This is useful to reduce log noise from high-frequency endpoints + /// like health checks. Example: "/health,/metrics,/ping". + /// When set, access logs for requests to these paths will be suppressed + /// while keeping logs for other endpoints. + #[arg(long)] + pub disable_access_log_for_endpoints: Option, + + /// Allow credentials. + #[arg( + long, + visible_alias = "no-allow-credentials", + default_missing_value = "true", + num_args = 0..=1 + )] + pub allow_credentials: Option, + + /// Allowed origins. + #[arg(long)] + pub allowed_origins: Option, + + /// Allowed methods. + #[arg(long)] + pub allowed_methods: Option, + + /// Allowed headers. + #[arg(long)] + pub allowed_headers: Option, + + /// If provided, the server will require one of these keys to be presented + /// in the header. + #[arg(long)] + pub api_key: Option, + + /// The file path to the SSL key file. + #[arg(long)] + pub ssl_keyfile: Option, + + /// The file path to the SSL cert file. + #[arg(long)] + pub ssl_certfile: Option, + + /// The CA certificates file. + #[arg(long)] + pub ssl_ca_certs: Option, + + /// Refresh SSL Context when SSL certificate files change + #[arg( + long, + visible_alias = "no-enable-ssl-refresh", + default_missing_value = "true", + num_args = 0..=1 + )] + pub enable_ssl_refresh: Option, + + /// Whether client certificate is required (see stdlib ssl module's). + #[arg(long)] + pub ssl_cert_reqs: Option, + + /// SSL cipher suites for HTTPS (TLS 1.2 and below only). + /// Example: 'ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-CHACHA20-POLY1305' + #[arg(long)] + pub ssl_ciphers: Option, + + /// FastAPI root_path when app is behind a path based routing proxy. + #[arg(long)] + pub root_path: Option, + + /// Additional ASGI middleware to apply to the app. We accept multiple + /// --middleware arguments. The value should be an import path. If a + /// function is provided, vLLM will add it to the server using + /// `@app.middleware('http')`. If a class is provided, vLLM will + /// add it to the server using `app.add_middleware()`. + #[arg(long)] + pub middleware: Option, + + /// If specified, API server will add X-Request-Id header to responses. + #[arg( + long, + visible_alias = "no-enable-request-id-headers", + default_missing_value = "true", + num_args = 0..=1 + )] + pub enable_request_id_headers: Option, + + /// Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint. + #[arg( + long, + visible_alias = "no-disable-fastapi-docs", + default_missing_value = "true", + num_args = 0..=1 + )] + pub disable_fastapi_docs: Option, + + /// Maximum size (bytes) of an incomplete HTTP event (header or body) for + /// h11 parser. Helps mitigate header abuse. Default: 4194304 (4 MB). + #[arg(long)] + pub h11_max_incomplete_event_size: Option, + + /// Maximum number of HTTP headers allowed in a request for h11 parser. + /// Helps mitigate header abuse. Default: 256. + #[arg(long)] + pub h11_max_header_count: Option, + + /// Enable offline FastAPI documentation for air-gapped environments. + /// Uses vendored static assets bundled with vLLM. + #[arg( + long, + visible_alias = "no-enable-offline-docs", + default_missing_value = "true", + num_args = 0..=1 + )] + pub enable_offline_docs: Option, +} diff --git a/rust/src/cmd/src/logging.rs b/rust/src/cmd/src/logging.rs new file mode 100644 index 00000000000..aab080c9398 --- /dev/null +++ b/rust/src/cmd/src/logging.rs @@ -0,0 +1,328 @@ +use std::{env, fmt, process}; + +use time::UtcOffset; +use time::macros::format_description; +use tracing::level_filters::LevelFilter; +use tracing::{Event, Level, Subscriber}; +use tracing_subscriber::Layer as _; +use tracing_subscriber::filter::Targets; +use tracing_subscriber::fmt::format::{FormatEvent, FormatFields, Writer}; +use tracing_subscriber::fmt::time::FormatTime; +use tracing_subscriber::fmt::{FmtContext, FormattedFields}; +use tracing_subscriber::layer::SubscriberExt as _; +use tracing_subscriber::registry::LookupSpan; +use tracing_subscriber::util::SubscriberInitExt as _; + +const CYAN: &str = "\x1b[0;36m"; +const GREY: &str = "\x1b[90m"; +const GREEN: &str = "\x1b[32m"; +const YELLOW: &str = "\x1b[33m"; +const RED: &str = "\x1b[31m"; +const WHITE: &str = "\x1b[37m"; +const RESET: &str = "\x1b[0m"; +const VLLM_TIME_FORMAT: &[time::format_description::FormatItem<'static>] = + format_description!("[month]-[day] [hour]:[minute]:[second]"); + +const PROCESS_LABEL: &str = "RustFrontend"; + +/// Install the process-wide vLLM-style tracing subscriber for the CLI binary. +pub(crate) fn init_tracing() { + let filter = build_targets_filter( + env::var("VLLM_LOGGING_LEVEL").ok().as_deref(), + env::var("RUST_LOG").ok().as_deref(), + ); + let formatter = VllmEventFormatter::new(); + + let _ = tracing_subscriber::registry() + .with(tracing_subscriber::fmt::layer().event_format(formatter).with_filter(filter)) + .try_init(); +} + +/// Build the CLI log filter by merging the vLLM-style default level with +/// Rust-style target overrides. +/// +/// Precedence: +/// - Start from `VLLM_LOGGING_LEVEL` as the default level for all targets. +/// - If `RUST_LOG` contains a global default level such as `warn`, it overrides +/// `VLLM_LOGGING_LEVEL`. +/// - Any explicit target directives in `RUST_LOG`, such as `hyper=info`, override whichever default +/// level is active for those targets only. +fn build_targets_filter(vllm_logging_level: Option<&str>, rust_log: Option<&str>) -> Targets { + let mut filter = + Targets::new().with_default(map_python_log_level(vllm_logging_level.unwrap_or("INFO"))); + + if let Some(rust_log) = rust_log + && !rust_log.is_empty() + { + let rust_log_targets: Targets = rust_log.parse().expect("failed to parse `RUST_LOG`"); + if let Some(default_level) = rust_log_targets.default_level() { + filter = filter.with_default(default_level); + } + filter = filter.with_targets(rust_log_targets); + } + + filter +} + +#[derive(Debug, Clone, Copy)] +struct VllmLocalTimer { + local_offset: UtcOffset, +} + +impl Default for VllmLocalTimer { + fn default() -> Self { + let local_offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); + Self { local_offset } + } +} + +impl FormatTime for VllmLocalTimer { + fn format_time(&self, w: &mut Writer<'_>) -> fmt::Result { + let now = time::OffsetDateTime::now_utc().to_offset(self.local_offset); + let formatted = now.format(VLLM_TIME_FORMAT).map_err(|_| fmt::Error)?; + w.write_str(&formatted) + } +} + +#[derive(Debug, Clone)] +struct VllmEventFormatter { + prefix: String, + timer: VllmLocalTimer, +} + +impl VllmEventFormatter { + fn new() -> Self { + Self { + prefix: format!("({} pid={})", PROCESS_LABEL, process::id()), + timer: VllmLocalTimer::default(), + } + } + + fn write_process_prefix(&self, writer: &mut Writer<'_>, ansi: bool) -> fmt::Result { + write_colored(writer, ansi, Some(CYAN), &self.prefix)?; + writer.write_char(' ') + } + + fn write_level(&self, writer: &mut Writer<'_>, level: &Level, ansi: bool) -> fmt::Result { + let (text, color) = match *level { + Level::TRACE => ("TRACE", WHITE), + Level::DEBUG => ("DEBUG", WHITE), + Level::INFO => ("INFO", GREEN), + Level::WARN => ("WARNING", YELLOW), + Level::ERROR => ("ERROR", RED), + }; + write_colored(writer, ansi, Some(color), text) + } + + fn write_timestamp(&self, writer: &mut Writer<'_>, ansi: bool) -> fmt::Result { + if ansi { + writer.write_str(GREY)?; + } + if self.timer.format_time(writer).is_err() { + writer.write_str("")?; + } + if ansi { + writer.write_str(RESET)?; + } + Ok(()) + } + + fn write_location( + &self, + writer: &mut Writer<'_>, + file: Option<&str>, + line: Option, + full_path: bool, + ansi: bool, + ) -> fmt::Result { + let Some(file) = file else { + return Ok(()); + }; + let file = if full_path { + file + } else { + shorten_file_path(file) + }; + if ansi { + writer.write_str(GREY)?; + } + match line { + Some(line) => write!(writer, "[{file}:{line}]")?, + None => write!(writer, "[{file}]")?, + } + if ansi { + writer.write_str(RESET)?; + } + Ok(()) + } + + fn write_scope(&self, ctx: &FmtContext<'_, S, N>, writer: &mut Writer<'_>) -> fmt::Result + where + S: Subscriber + for<'lookup> LookupSpan<'lookup>, + N: for<'writer> FormatFields<'writer> + 'static, + { + let Some(scope) = ctx.event_scope() else { + return Ok(()); + }; + + let mut seen = false; + for span in scope.from_root() { + if seen { + writer.write_str(":")?; + } + seen = true; + writer.write_str(span.metadata().name())?; + + let ext = span.extensions(); + if let Some(fields) = ext.get::>() + && !fields.is_empty() + { + write!(writer, "{{{fields}}}")?; + } + } + + if seen { + writer.write_str(": ")?; + } + + Ok(()) + } +} + +impl FormatEvent for VllmEventFormatter +where + S: Subscriber + for<'lookup> LookupSpan<'lookup>, + N: for<'writer> FormatFields<'writer> + 'static, +{ + fn format_event( + &self, + ctx: &FmtContext<'_, S, N>, + mut writer: Writer<'_>, + event: &Event<'_>, + ) -> fmt::Result { + let meta = event.metadata(); + let ansi = writer.has_ansi_escapes(); + + self.write_process_prefix(&mut writer, ansi)?; + self.write_level(&mut writer, meta.level(), ansi)?; + writer.write_char(' ')?; + self.write_timestamp(&mut writer, ansi)?; + writer.write_char(' ')?; + // Use the full file path only when DEBUG (or more verbose) is enabled anywhere, + // independent of the level of this particular event. Filenames alone are often + // ambiguous, but full paths are too noisy for normal INFO-level operation. + let full_path = LevelFilter::current() >= LevelFilter::DEBUG; + self.write_location(&mut writer, meta.file(), meta.line(), full_path, ansi)?; + writer.write_char(' ')?; + self.write_scope(ctx, &mut writer)?; + ctx.format_fields(writer.by_ref(), event)?; + writer.write_char('\n') + } +} + +/// Shorten a source file path for log output while preserving enough context +/// for common Rust entrypoint and module filenames. +/// +/// - For `mod.rs`, keep the parent directory as `parent/mod.rs`. +/// - For `src/lib.rs` and `src/main.rs`, keep one additional component as `crate/src/lib.rs` or +/// `crate/src/main.rs` when available. +/// - Other files are displayed as just the basename. +fn shorten_file_path(file: &str) -> &str { + let mut parts = file.rsplit('/'); + let name = parts.next().unwrap_or(file); + let parent = parts.next(); + let grandparent = parts.next(); + + let Some(parent) = parent else { + return file; + }; + + if name == "mod.rs" { + return &file[file.len() - parent.len() - 1 - name.len()..]; + } + + if !matches!(name, "lib.rs" | "main.rs") || parent != "src" { + return name; + } + let Some(grandparent) = grandparent else { + return file; + }; + + &file[file.len() - grandparent.len() - 1 - parent.len() - 1 - name.len()..] +} + +fn write_colored( + writer: &mut Writer<'_>, + ansi: bool, + color: Option<&str>, + text: &str, +) -> fmt::Result { + if ansi { + if let Some(color) = color { + writer.write_str(color)?; + } + writer.write_str(text)?; + if color.is_some() { + writer.write_str(RESET)?; + } + return Ok(()); + } + + writer.write_str(text) +} + +/// Map a Python logging level name to the corresponding Rust tracing level. +fn map_python_log_level(level: &str) -> LevelFilter { + match level.to_ascii_uppercase().as_str() { + "CRITICAL" | "FATAL" => LevelFilter::ERROR, + "ERROR" => LevelFilter::ERROR, + "WARNING" | "WARN" => LevelFilter::WARN, + "INFO" => LevelFilter::INFO, + "DEBUG" => LevelFilter::DEBUG, + "NOTSET" => LevelFilter::TRACE, + _ => LevelFilter::INFO, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rust_log_target_overrides_are_merged_with_vllm_default_level() { + let filter = build_targets_filter(Some("DEBUG"), Some("hyper=warn,tower=error")); + + assert_eq!(filter.to_string(), "tower=error,hyper=warn,debug"); + } + + #[test] + fn rust_log_default_level_overrides_vllm_default_level() { + let filter = build_targets_filter(Some("DEBUG"), Some("warn,hyper=info")); + + assert_eq!(filter.to_string(), "hyper=info,warn"); + } + + #[test] + fn invalid_vllm_level_falls_back_to_info() { + let filter = build_targets_filter(Some("bogus"), None); + + assert_eq!(filter.to_string(), "info"); + } + + #[test] + fn location_path_uses_filename_for_non_ambiguous_files() { + assert_eq!(shorten_file_path("src/cmd/src/logging.rs"), "logging.rs"); + assert_eq!(shorten_file_path("src/chat/lib.rs"), "lib.rs"); + assert_eq!(shorten_file_path("src/chat/main.rs"), "main.rs"); + assert_eq!(shorten_file_path("src/chat/src/xmod.rs"), "xmod.rs"); + } + + #[test] + fn location_path_keeps_more_context_for_common_entrypoint_filenames() { + assert_eq!(shorten_file_path("src/lib.rs"), "src/lib.rs"); + assert_eq!(shorten_file_path("src/chat/src/lib.rs"), "chat/src/lib.rs"); + assert_eq!(shorten_file_path("src/cmd/src/main.rs"), "cmd/src/main.rs"); + assert_eq!(shorten_file_path("mod.rs"), "mod.rs"); + assert_eq!(shorten_file_path("src/chat/src/tool/mod.rs"), "tool/mod.rs"); + } +} diff --git a/rust/src/cmd/src/main.rs b/rust/src/cmd/src/main.rs new file mode 100644 index 00000000000..ce4e37e09bd --- /dev/null +++ b/rust/src/cmd/src/main.rs @@ -0,0 +1,189 @@ +mod cli; +mod logging; + +use std::env; +use std::process::ExitStatus; + +use anyhow::{Context, Result, anyhow, bail}; +use tokio_util::sync::CancellationToken; +use tracing::{info, warn}; +use vllm_managed_engine::ManagedEngineHandle; + +use crate::cli::{Cli, Command}; + +const TOKIO_WORKER_THREADS_ENV: &str = "TOKIO_WORKER_THREADS"; +const DEFAULT_MAX_TOKIO_WORKER_THREADS: usize = 32; + +/// Cap the default number of Tokio worker threads if the user did not +/// explicitly set `TOKIO_WORKER_THREADS` to avoid spawning too many threads on +/// machines with a large number of CPUs, which may lead to excessive context +/// switching and degraded performance. +fn tokio_worker_threads() -> Option { + if env::var_os(TOKIO_WORKER_THREADS_ENV).is_some() { + return None; + } + + std::thread::available_parallelism() + .map(|parallelism| { + let available = parallelism.get(); + let worker_threads = available.min(DEFAULT_MAX_TOKIO_WORKER_THREADS); + if worker_threads < available { + info!( + available_parallelism = available, + capped_worker_threads = worker_threads, + "capping tokio worker threads, set {TOKIO_WORKER_THREADS_ENV} to override" + ); + } + worker_threads + }) + .ok() +} + +/// Reason that caused a managed `serve` session to stop. +#[derive(Debug)] +enum ShutdownReason { + Signal, + Server(anyhow::Error), + EngineExited(ExitStatus), +} + +/// Cancellation token tripped by Ctrl-C or SIGTERM. +fn shutdown_signal() -> CancellationToken { + let token = CancellationToken::new(); + let shutdown = token.clone(); + + tokio::spawn(async move { + let ctrl_c = async { + tokio::signal::ctrl_c().await.expect("failed to install Ctrl-C signal handler"); + }; + + let sigterm = async { + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("failed to install SIGTERM signal handler") + .recv() + .await; + }; + + tokio::select! { + _ = ctrl_c => info!("received shutdown signal (Ctrl-C), shutting down..."), + _ = sigterm => info!("received shutdown signal (SIGTERM), shutting down..."), + } + + shutdown.cancel(); + }); + + token +} + +fn main() -> Result<()> { + logging::init_tracing(); + let cli = Cli::parse(); + + let mut runtime = tokio::runtime::Builder::new_multi_thread(); + runtime.enable_all(); + if let Some(worker_threads) = tokio_worker_threads() { + runtime.worker_threads(worker_threads); + } + + runtime + .build() + .context("failed to build Tokio runtime")? + .block_on(async_main(cli)) +} + +async fn async_main(cli: Cli) -> Result<()> { + match cli.command { + Command::Frontend(args) => vllm_server::serve(args.into_config(), shutdown_signal()).await, + Command::Serve(args) => { + let handshake_port = args.managed_engine.resolve_handshake_port()?; + + if args.managed_engine.data_parallel_size_local == Some(0) { + if args.headless { + bail!("cannot combine `--headless` with `--data-parallel-size-local 0`"); + } + + let handshake_address = args.managed_engine.handshake_address(handshake_port); + info!( + %handshake_address, + engine_count = args.managed_engine.data_parallel_size, + "running Rust frontend without a managed local Python engine" + ); + let config = args.to_frontend_config(handshake_address); + return vllm_server::serve(config, shutdown_signal()).await; + } + + let shutdown_timeout = args.runtime.shutdown_timeout(); + let engine_config = args.to_managed_engine_config(handshake_port); + let handshake_address = engine_config.handshake_address(); + + let engine = ManagedEngineHandle::spawn(engine_config) + .await + .context("failed to start managed Python headless engine")?; + + let shutdown = shutdown_signal(); + + let mut serve_task = if args.headless { + info!("running managed Python headless engine without Rust frontend"); + let shutdown = shutdown.clone(); + tokio::spawn(async move { + shutdown.cancelled().await; + Ok(()) + }) + } else { + let config = args.to_frontend_config(handshake_address); + let shutdown = shutdown.clone(); + tokio::spawn(async move { + let result = vllm_server::serve(config, shutdown).await; + if result.is_ok() { + info!("OpenAI server shut down gracefully"); + } + result + }) + }; + + let shutdown_reason = tokio::select! { + biased; + + // Received shutdown signal via Ctrl-C or SIGTERM. + _ = shutdown.cancelled() => ShutdownReason::Signal, + + // Engine process exited unexpectedly. + status = engine.wait_for_exit() => { + warn!(%status, "managed Python headless engine exited, shutting down..."); + ShutdownReason::EngineExited(status) + } + + // Serve task exited unexpectedly. + serve_result = &mut serve_task => { + let serve_result = serve_result.context("serve task join failed")?; + match serve_result { + Ok(()) => ShutdownReason::Server(anyhow!("OpenAI server shut down unexpectedly without error")), + Err(error) => ShutdownReason::Server(error), + } + } + }; + // Regardless of the shutdown reason, broadcast shutdown signal here to ensure + // that all serving tasks are notified. + shutdown.cancel(); + + // Shutdown begins. Terminate the managed engine first. + engine.shutdown(shutdown_timeout).await?; + info!("managed engine shut down gracefully"); + // Wait for the API server to shut down gracefully by draining in-flight + // requests. + if !matches!(shutdown_reason, ShutdownReason::Server(_)) { + serve_task.await.context("serve task join failed")??; + } + + match shutdown_reason { + ShutdownReason::Signal => Ok(()), + ShutdownReason::Server(error) => { + Err(error.context("OpenAI server shut down unexpectedly")) + } + ShutdownReason::EngineExited(status) => Err(anyhow!( + "managed Python headless engine exited unexpectedly with status {status}" + )), + } + } + } +} diff --git a/rust/src/engine-core-client/Cargo.toml b/rust/src/engine-core-client/Cargo.toml new file mode 100644 index 00000000000..14b9fe14234 --- /dev/null +++ b/rust/src/engine-core-client/Cargo.toml @@ -0,0 +1,49 @@ +[package] +name = "vllm-engine-core-client" +version.workspace = true +edition.workspace = true +license.workspace = true + +[features] +test-util = ["dep:tempfile"] + +[dependencies] +arc-swap.workspace = true +bytemuck.workspace = true +byteorder.workspace = true +bytes.workspace = true +easy-ext.workspace = true +enum-as-inner.workspace = true +futures.workspace = true +half.workspace = true +hex.workspace = true +itertools.workspace = true +parking_lot.workspace = true +rmp-serde.workspace = true +rmpv.workspace = true +serde.workspace = true +serde_default.workspace = true +serde_json.workspace = true +serde_repr.workspace = true +serde_tuple.workspace = true +serde_with.workspace = true +task-local.workspace = true +tempfile = { workspace = true, optional = true } +thiserror.workspace = true +thiserror-ext.workspace = true +tokio.workspace = true +tokio-util.workspace = true +tracing.workspace = true +vllm-metrics.workspace = true +zeromq.workspace = true + +[dev-dependencies] +anyhow.workspace = true +clap.workspace = true +expect-test.workspace = true +hex.workspace = true +tempfile.workspace = true +tracing-subscriber.workspace = true + +[lints] +workspace = true diff --git a/rust/src/engine-core-client/examples/README.md b/rust/src/engine-core-client/examples/README.md new file mode 100644 index 00000000000..4fa13133d47 --- /dev/null +++ b/rust/src/engine-core-client/examples/README.md @@ -0,0 +1,53 @@ +# Engine-Core Smoke Tests + +Start headless `vllm`: + +```bash +source ../vllm/.venv/bin/activate +HF_HUB_OFFLINE=1 \ +VLLM_LOGGING_LEVEL=DEBUG \ +VLLM_CPU_KVCACHE_SPACE=2 \ +VLLM_HOST_IP=127.0.0.1 \ +VLLM_LOOPBACK_IP=127.0.0.1 \ +python3 -m vllm.entrypoints.cli.main serve Qwen/Qwen3-0.6B \ + --headless \ + --enable-sleep-mode \ + --data-parallel-address 127.0.0.1 \ + --data-parallel-rpc-port 62100 \ + --data-parallel-size-local 1 \ + --max-model-len 512 \ + --dtype float16 +``` + +Run the Rust smoke test through the `vllm-engine-core-client` utility interface: + +```bash +cargo run -p vllm-engine-core-client --example external_engine_utility_call -- \ + --handshake-address tcp://127.0.0.1:62100 \ + --host 127.0.0.1 +``` + +If your current engine setup does not support sleep mode, skip the `sleep` / `wake_up` part of the +smoke: + +```bash +cargo run -p vllm-engine-core-client --example external_engine_utility_call -- \ + --handshake-address tcp://127.0.0.1:62100 \ + --host 127.0.0.1 \ + --skip-sleep-wake +``` + +Run the Rust smoke test for sample logprobs decoding through the raw engine-core request path: + +```bash +cargo run -p vllm-engine-core-client --example external_engine_logprobs -- \ + --handshake-address tcp://127.0.0.1:62100 \ + --host 127.0.0.1 +``` + +This smoke requests a small generated-token `logprobs` payload plus prompt logprobs over a much +longer prompt, so it exercises both the inline and aux-frame decode paths against a real engine. +The Rust client decodes those payloads into semantic per-position records rather than exposing the +raw ndarray/tensor wire shape. + +IMPORTANT: You must restart `vllm` each time you run the smoke test, as the vLLM engine cannot manage frontend closures and subsequent reconnects. In other words, do not reuse existing `vllm` instances, if any. diff --git a/rust/src/engine-core-client/examples/external_engine_logprobs.rs b/rust/src/engine-core-client/examples/external_engine_logprobs.rs new file mode 100644 index 00000000000..08290c69bf5 --- /dev/null +++ b/rust/src/engine-core-client/examples/external_engine_logprobs.rs @@ -0,0 +1,190 @@ +use std::time::Duration; + +use anyhow::{Context, Result, bail}; +use clap::Parser; +use futures::StreamExt as _; +use tokio::time::timeout; +use tracing_subscriber::EnvFilter; +use vllm_engine_core_client::protocol::{ + EngineCoreFinishReason, EngineCoreRequest, EngineCoreSamplingParams, +}; +use vllm_engine_core_client::{ + EngineCoreClient, EngineCoreClientConfig, EngineCoreStreamOutput, TransportMode, +}; + +const BASE_PROMPT_TOKEN_IDS: &[u32] = &[20841, 448, 6896, 25, 23811]; + +#[derive(Debug, Parser)] +#[command(about = "Smoke-test engine-core sample logprobs against an external vLLM engine.")] +struct Args { + #[arg(long)] + handshake_address: String, + #[arg(long, default_value_t = 1)] + engine_count: usize, + #[arg(long, default_value = "Qwen/Qwen3-0.6B")] + model: String, + #[arg(long, default_value = "127.0.0.1")] + host: String, + #[arg(long, default_value_t = 0)] + client_index: u32, + #[arg(long, default_value_t = 30)] + ready_timeout_secs: u64, + #[arg(long, default_value_t = 120)] + output_timeout_secs: u64, + #[arg(long, default_value_t = 1)] + max_tokens: u32, + #[arg(long, default_value_t = 2)] + logprobs: i32, + #[arg(long, default_value_t = 1)] + prompt_logprobs: i32, + #[arg(long, default_value_t = 96)] + prompt_repeats: usize, +} + +fn init_tracing() { + let filter = EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new("vllm_engine_core_client=debug")); + let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init(); +} + +fn unique_request_id() -> String { + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("system clock should be after unix epoch") + .as_nanos(); + format!("rust-engine-core-logprobs-{nanos}") +} + +fn build_prompt_token_ids(prompt_repeats: usize) -> Vec { + let repeats = prompt_repeats.max(1); + BASE_PROMPT_TOKEN_IDS.repeat(repeats) +} + +fn build_request( + request_id: String, + prompt_token_ids: Vec, + max_tokens: u32, + logprobs: i32, + prompt_logprobs: i32, + client_index: u32, +) -> EngineCoreRequest { + EngineCoreRequest { + request_id, + prompt_token_ids: Some(prompt_token_ids), + sampling_params: Some(EngineCoreSamplingParams { + max_tokens, + logprobs: Some(logprobs), + prompt_logprobs: Some(prompt_logprobs), + ..EngineCoreSamplingParams::for_test() + }), + arrival_time: 0.0, + client_index, + ..EngineCoreRequest::default() + } +} + +async fn wait_for_final_output( + mut stream: vllm_engine_core_client::EngineCoreOutputStream, +) -> Result { + while let Some(output) = stream.next().await { + let output = output.context("failed to receive engine-core output")?; + if output.finished() { + return Ok(output); + } + } + bail!("request stream ended without a final output") +} + +#[tokio::main(flavor = "multi_thread")] +async fn main() -> Result<()> { + init_tracing(); + let args = Args::parse(); + let ready_timeout = Duration::from_secs(args.ready_timeout_secs); + let output_timeout = Duration::from_secs(args.output_timeout_secs); + let request_id = unique_request_id(); + let prompt_token_ids = build_prompt_token_ids(args.prompt_repeats); + let client = EngineCoreClient::connect(EngineCoreClientConfig { + transport_mode: TransportMode::HandshakeOwner { + handshake_address: args.handshake_address.clone(), + advertised_host: args.host.clone(), + engine_count: args.engine_count, + ready_timeout, + local_input_address: None, + local_output_address: None, + }, + coordinator_mode: None, + model_name: args.model.clone(), + client_index: args.client_index, + }) + .await + .context("failed to connect to external vLLM engine")?; + + println!("model={}", args.model); + println!("handshake_address={}", args.handshake_address); + println!("engine_count={}", args.engine_count); + println!("input_address={}", client.input_address()); + println!("output_address={}", client.output_address()); + println!("engine_identities={:x?}", client.engine_identities()); + + let request = build_request( + request_id.clone(), + prompt_token_ids.clone(), + args.max_tokens, + args.logprobs, + args.prompt_logprobs, + args.client_index, + ); + println!("request_id={request_id}"); + println!("prompt_len={}", prompt_token_ids.len()); + println!("base_prompt_len={}", BASE_PROMPT_TOKEN_IDS.len()); + println!("prompt_repeats={}", args.prompt_repeats); + println!("requested_logprobs={}", args.logprobs); + println!("requested_prompt_logprobs={}", args.prompt_logprobs); + + let stream = client.call(request).await.context("failed to submit engine-core request")?; + let output = timeout(output_timeout, wait_for_final_output(stream)) + .await + .context("timed out waiting for final output")??; + + let finish_reason = output.finish_reason; + let token_ids = output.new_token_ids.clone(); + let logprobs = output + .new_logprobs + .as_ref() + .and_then(|value| value.as_direct()) + .context("engine output did not include decoded sample logprobs")?; + let prompt_logprobs = output + .new_prompt_logprobs_tensors + .as_ref() + .and_then(|value| value.as_direct()) + .context("engine output did not include decoded prompt logprobs")?; + + println!("token_ids={token_ids:?}"); + println!("finish_reason={finish_reason:?}"); + println!("new_logprobs={logprobs:#?}"); + println!("new_prompt_logprobs_tensors={prompt_logprobs:#?}"); + + client.shutdown().await.context("failed to shut down engine-core client")?; + + if finish_reason != Some(EngineCoreFinishReason::Length) { + bail!("unexpected finish_reason: expected Length, got {finish_reason:?}"); + } + if token_ids.is_empty() { + bail!("engine returned no generated token ids"); + } + if logprobs.is_empty() { + bail!("decoded logprobs payload is unexpectedly empty"); + } + if prompt_logprobs.is_empty() { + bail!("decoded prompt logprobs payload is unexpectedly empty"); + } + if prompt_logprobs.len() + 1 < prompt_token_ids.len() { + bail!( + "prompt logprobs rows look too short: prompt_len={}, rows={}", + prompt_token_ids.len(), + prompt_logprobs.len() + ); + } + + Ok(()) +} diff --git a/rust/src/engine-core-client/examples/external_engine_utility_call.rs b/rust/src/engine-core-client/examples/external_engine_utility_call.rs new file mode 100644 index 00000000000..ee2a4e57b7a --- /dev/null +++ b/rust/src/engine-core-client/examples/external_engine_utility_call.rs @@ -0,0 +1,143 @@ +use std::time::Duration; + +use anyhow::{Context, Result, bail}; +use clap::Parser; +use tracing_subscriber::EnvFilter; +use vllm_engine_core_client::{EngineCoreClient, EngineCoreClientConfig, TransportMode}; + +#[derive(Debug, Parser)] +#[command(about = "Smoke-test EngineCoreClient utility calls against an external vLLM engine.")] +struct Args { + #[arg(long)] + handshake_address: String, + #[arg(long, default_value_t = 1)] + engine_count: usize, + #[arg(long, default_value = "Qwen/Qwen3-0.6B")] + model: String, + #[arg(long, default_value = "127.0.0.1")] + host: String, + #[arg(long, default_value_t = 0)] + client_index: u32, + #[arg(long, default_value_t = 30)] + ready_timeout_secs: u64, + #[arg( + long, + default_value_t = false, + help = "Expected initial result of is_sleeping() before running the smoke steps." + )] + expected_is_sleeping: bool, + #[arg(long, default_value_t = false)] + reset_running_requests: bool, + #[arg(long, default_value_t = false)] + reset_external: bool, + #[arg(long, default_value_t = 1)] + sleep_level: u32, + #[arg(long, default_value = "abort")] + sleep_mode: String, + #[arg( + long, + default_value_t = false, + help = "Skip sleep/wake_up calls when the engine was not started with sleep-mode support." + )] + skip_sleep_wake: bool, +} + +fn init_tracing() { + let filter = EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new("vllm_engine_core_client=debug")); + let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init(); +} + +#[tokio::main(flavor = "multi_thread")] +async fn main() -> Result<()> { + init_tracing(); + let args = Args::parse(); + let client = EngineCoreClient::connect(EngineCoreClientConfig { + transport_mode: TransportMode::HandshakeOwner { + handshake_address: args.handshake_address.clone(), + advertised_host: args.host.clone(), + engine_count: args.engine_count, + ready_timeout: Duration::from_secs(args.ready_timeout_secs), + local_input_address: None, + local_output_address: None, + }, + coordinator_mode: None, + model_name: args.model.clone(), + client_index: args.client_index, + }) + .await + .context("failed to connect to external vLLM engine")?; + + println!("model={}", args.model); + println!("handshake_address={}", args.handshake_address); + println!("engine_count={}", args.engine_count); + println!("input_address={}", client.input_address()); + println!("output_address={}", client.output_address()); + println!("engine_identities={:x?}", client.engine_identities()); + + let initial_is_sleeping = + client.is_sleeping().await.context("failed to call is_sleeping utility")?; + + println!("is_sleeping(initial)={initial_is_sleeping}"); + + if initial_is_sleeping != args.expected_is_sleeping { + bail!( + "unexpected initial is_sleeping state: expected {}, got {}", + args.expected_is_sleeping, + initial_is_sleeping + ); + } + + let reset_prefix_cache = client + .reset_prefix_cache(args.reset_running_requests, args.reset_external) + .await + .context("failed to call reset_prefix_cache utility")?; + println!("reset_prefix_cache={reset_prefix_cache}"); + + client.reset_mm_cache().await.context("failed to call reset_mm_cache utility")?; + println!("reset_mm_cache=ok"); + + client + .reset_encoder_cache() + .await + .context("failed to call reset_encoder_cache utility")?; + println!("reset_encoder_cache=ok"); + + if args.skip_sleep_wake { + println!("sleep_wake=skipped"); + } else { + client.sleep(args.sleep_level, &args.sleep_mode).await.with_context(|| { + format!( + "failed to call sleep utility with level={} mode={}", + args.sleep_level, args.sleep_mode + ) + })?; + println!( + "sleep=ok level={} mode={}", + args.sleep_level, args.sleep_mode + ); + + let sleeping_after_sleep = + client.is_sleeping().await.context("failed to call is_sleeping after sleep")?; + println!("is_sleeping(after_sleep)={sleeping_after_sleep}"); + + if !sleeping_after_sleep { + bail!("engine should report sleeping=true after sleep()"); + } + + client.wake_up(None).await.context("failed to call wake_up utility")?; + println!("wake_up=ok"); + + let sleeping_after_wake = + client.is_sleeping().await.context("failed to call is_sleeping after wake_up")?; + println!("is_sleeping(after_wake)={sleeping_after_wake}"); + + if sleeping_after_wake { + bail!("engine should report sleeping=false after wake_up()"); + } + } + + client.shutdown().await.context("failed to shut down engine-core client")?; + + Ok(()) +} diff --git a/rust/src/engine-core-client/src/client.rs b/rust/src/engine-core-client/src/client.rs new file mode 100644 index 00000000000..a224e560167 --- /dev/null +++ b/rust/src/engine-core-client/src/client.rs @@ -0,0 +1,673 @@ +use std::sync::Arc; +use std::time::Duration; + +use futures::future::{join_all, try_join_all}; +use tokio::sync::mpsc; +use tokio_util::task::AbortOnDropHandle; +use tracing::{debug, info, trace}; + +use crate::client::imp::{ClientInner, run_abort_loop, run_output_dispatcher_loop}; +use crate::coordinator::CoordinatorHandle; +use crate::error::{Error, Result}; +use crate::protocol::handshake::EngineCoreReadyResponse; +use crate::protocol::{ + EngineCoreRequest, EngineCoreRequestType, EngineCoreUtilityRequest, ModelDtype, +}; +use crate::transport::{self, ConnectedEngine}; + +pub(crate) mod imp; +mod state; +mod stream; + +pub use stream::{EngineCoreOutputStream, EngineCoreStreamOutput}; + +/// How the frontend acquires its request/response transport with Python +/// `EngineCoreProc`s. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TransportMode { + /// The Rust process owns the startup handshake and allocates or binds the + /// frontend transport addresses itself before replying to engine + /// `HELLO` messages. + HandshakeOwner { + /// Shared handshake endpoint that engines dial during startup. + handshake_address: String, + /// Host/IP that engines should use to connect back to the frontend + /// transport sockets. + advertised_host: String, + /// Total number of engines expected to join this transport. + engine_count: usize, + /// Maximum time to wait for each startup phase to complete. + ready_timeout: Duration, + /// Optional explicit bind address for the input ROUTER socket. + local_input_address: Option, + /// Optional explicit bind address for the output PULL socket. + local_output_address: Option, + }, + + /// The Python supervisor has already chosen the frontend transport + /// addresses, and the Rust process only needs to bind them and wait for + /// engine registration frames. + Bootstrapped { + /// Input ROUTER socket address that engines will connect to for + /// requests. + input_address: String, + /// Output PULL socket address that engines will connect to for + /// responses. + output_address: String, + /// Total number of engines expected to register on this transport. + engine_count: usize, + /// Maximum time to wait for all expected engines to register. + ready_timeout: Duration, + }, +} + +/// Which coordinator implementation should be active when one is present for a +/// frontend client. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CoordinatorMode { + /// Run the Rust in-process coordinator for managed `serve` deployments. + InProc, + /// Connect to an external coordinator owned by another process. + External { address: String }, +} + +/// Configuration for connecting a Rust frontend client to an already running +/// Python `EngineCoreProc`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EngineCoreClientConfig { + /// Frontend-to-engine transport setup. + pub transport_mode: TransportMode, + /// Frontend-side coordinator behavior, or `None` when requests should flow + /// directly to engines without any coordinator involvement. + pub coordinator_mode: Option, + /// Model name used for frontend-side metrics labels. + pub model_name: String, + /// Frontend client index stamped onto every request. + pub client_index: u32, +} + +impl EngineCoreClientConfig { + /// Create a new client config with the given handshake address, expecting a + /// single engine, and default values for all other fields. + pub fn new_single(handshake_address: impl Into) -> Self { + Self { + transport_mode: TransportMode::HandshakeOwner { + handshake_address: handshake_address.into(), + advertised_host: "127.0.0.1".to_string(), + engine_count: 1, + ready_timeout: Duration::from_secs(30), + local_input_address: None, + local_output_address: None, + }, + coordinator_mode: None, + model_name: String::new(), + client_index: 0, + } + } + + /// Set the model name used by frontend-side metrics and diagnostics. + pub fn with_model_name(mut self, model_name: impl Into) -> Self { + self.model_name = model_name.into(); + self + } + + /// Override the client index stamped onto every outgoing request. + pub fn with_client_index(mut self, client_index: u32) -> Self { + self.client_index = client_index; + self + } + + /// Override the optional coordinator mode for this client config. + pub fn with_coordinator_mode(mut self, coordinator_mode: Option) -> Self { + self.coordinator_mode = coordinator_mode; + self + } + + /// Override the locally bound input/output addresses for handshake-owned + /// transport mode. + /// + /// This is primarily used by tests that want deterministic IPC endpoints + /// while still exercising the handshake-owned startup path. + pub fn with_local_input_output_addresses( + mut self, + local_input_address: Option, + local_output_address: Option, + ) -> Self { + let TransportMode::HandshakeOwner { + local_input_address: current_input, + local_output_address: current_output, + .. + } = &mut self.transport_mode + else { + panic!("local input/output overrides are only valid in handshake-owned mode"); + }; + *current_input = local_input_address; + *current_output = local_output_address; + self + } +} + +/// The reason a request stream is being aborted when its output stream is +/// dropped. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum AbortCause { + /// The consumer dropped the stream before the request reached a terminal + /// engine output. + #[default] + DroppedStream, + /// The frontend matched a stop string locally and intentionally stopped + /// consuming the stream. + StopStringMatched, +} + +task_local::task_local! { + static ABORT_CAUSE: AbortCause; +} + +impl AbortCause { + /// Return the abort cause currently associated with this task, or + /// [`AbortCause::DroppedStream`] by default. + pub fn current() -> Self { + ABORT_CAUSE.try_get().unwrap_or_default() + } + + /// Drop one value while marking the drop as happening for this abort cause. + pub fn drop_as(self, value: T) { + ABORT_CAUSE.sync_scope(self, move || drop(value)); + } +} + +/// Internal auto-abort work item sent from stream `Drop` handlers to the abort +/// worker. +#[derive(Debug, Clone)] +pub(crate) struct AbortRequest { + request_id: String, + cause: AbortCause, +} + +/// Default ZMQ-based implementation that talks directly to a Python +/// `EngineCoreProc`. +pub struct EngineCoreClient { + config: EngineCoreClientConfig, + input_address: String, + output_address: String, + engines: Vec, + inner: Arc, + coordinator: Option, + abort_tx: mpsc::UnboundedSender, + + // Background tasks + output_task: AbortOnDropHandle<()>, + dispatcher_task: AbortOnDropHandle<()>, + abort_task: AbortOnDropHandle<()>, + coordinator_output_task: Option>, + coordinator_task: Option>, +} + +impl EngineCoreClient { + /// Connect to Python `EngineCoreProc`s using the configured + /// transport/coordinator modes. + /// + /// In handshake-owned mode this method drives the full engine startup + /// handshake. In bootstrapped mode it binds the provided frontend + /// sockets and waits for the expected engine registration frames. + pub async fn connect(config: EngineCoreClientConfig) -> Result { + let connected = match &config.transport_mode { + TransportMode::HandshakeOwner { + handshake_address, + advertised_host, + engine_count, + ready_timeout, + local_input_address, + local_output_address, + } => { + let enable_inproc_coordinator = match config.coordinator_mode { + None => false, + Some(CoordinatorMode::InProc) => true, + Some(CoordinatorMode::External { .. }) => { + return Err(Error::UnsupportedExternalCoordinator); + } + }; + + transport::connect_handshake( + handshake_address, + *engine_count, + advertised_host, + local_input_address.as_deref(), + local_output_address.as_deref(), + enable_inproc_coordinator, + *ready_timeout, + ) + .await? + } + + TransportMode::Bootstrapped { + input_address, + output_address, + engine_count, + ready_timeout, + } => { + if let Some(CoordinatorMode::InProc) = config.coordinator_mode { + panic!("cannot use in-process coordinator with bootstrapped transport mode") + } + + transport::connect_bootstrapped( + input_address, + output_address, + *engine_count, + *ready_timeout, + ) + .await? + } + }; + + Self::from_connected(config, connected).await + } + + /// Create a new client instance from the connected transport state after + /// the startup handshake completes. + async fn from_connected( + config: EngineCoreClientConfig, + connected: transport::ConnectedTransport, + ) -> Result { + let (output_tx, output_rx) = mpsc::channel(64); + let (abort_tx, abort_rx) = mpsc::unbounded_channel(); + let engines = connected.engines; + let inner = Arc::new(ClientInner::new( + connected.input_send, + config.model_name.clone(), + &engines, + )); + let output_task = AbortOnDropHandle::new(tokio::spawn(transport::run_output_loop( + connected.output_socket, + output_tx, + ))); + let dispatcher_task = AbortOnDropHandle::new(tokio::spawn(run_output_dispatcher_loop( + inner.clone(), + output_rx, + ))); + let abort_task = + AbortOnDropHandle::new(tokio::spawn(run_abort_loop(inner.clone(), abort_rx))); + + // If any engine reported a dp_stats_address in its ready response, use it + // as the external coordinator address. + let dp_stats_address: Option = engines + .iter() + .filter_map(|e| e.ready_response.as_ref()) + .find_map(|r| r.dp_stats_address.clone()); + + let (coordinator, coordinator_output_task, coordinator_task) = + if let Some(coordinator_transport) = connected.coordinator { + let (handle, runner) = + CoordinatorHandle::new_inproc(coordinator_transport.input_socket); + let (coordinator_output_tx, coordinator_output_rx) = mpsc::channel(64); + let coordinator_output_task = + AbortOnDropHandle::new(tokio::spawn(transport::run_output_loop( + coordinator_transport.output_socket, + coordinator_output_tx, + ))); + let coordinator_task = AbortOnDropHandle::new(tokio::spawn( + runner.run(coordinator_output_rx, inner.clone()), + )); + ( + Some(handle), + Some(coordinator_output_task), + Some(coordinator_task), + ) + } else if let Some(address) = + dp_stats_address.as_deref().or(match config.coordinator_mode.as_ref() { + Some(CoordinatorMode::External { address }) => Some(address.as_str()), + _ => None, + }) + { + let (handle, service) = CoordinatorHandle::connect_external(address).await?; + let coordinator_task = + AbortOnDropHandle::new(tokio::spawn(service.run(inner.clone()))); + (Some(handle), None, Some(coordinator_task)) + } else { + (None, None, None) + }; + + Ok(Self { + config, + input_address: connected.input_address, + output_address: connected.output_address, + engines, + inner, + coordinator, + abort_tx, + output_task, + dispatcher_task, + abort_task, + coordinator_output_task, + coordinator_task, + }) + } + + /// Return the address of the input socket that the client uses to send + /// requests to the engine. + pub fn input_address(&self) -> &str { + &self.input_address + } + + /// Return the address of the output socket that the client listens on for + /// engine responses. + pub fn output_address(&self) -> &str { + &self.output_address + } + + /// Return the number of engines connected to this client. + pub fn engine_count(&self) -> usize { + self.engines.len() + } + + /// Return the engine identities of all engines connected to this client. + pub fn engine_identities(&self) -> Vec<&[u8]> { + self.engines.iter().map(|engine| &*engine.engine_id).collect() + } + + /// Return the ready responses received from all engines on the input + /// socket. + pub fn ready_responses(&self) -> Vec<&EngineCoreReadyResponse> { + self.engines + .iter() + .filter_map(|engine| engine.ready_response.as_ref()) + .collect() + } + + /// Return the engine-reported effective model dtype, when available. + pub fn model_dtype(&self) -> Option { + self.engines + .iter() + .filter_map(|engine| engine.ready_response.as_ref()) + .find_map(|response| response.dtype) + } + + /// Return the total number of GPU blocks summed across all connected + /// engines. + pub fn total_num_gpu_blocks(&self) -> u64 { + self.engines + .iter() + .filter_map(|engine| engine.ready_response.as_ref()) + .map(|r| r.num_gpu_blocks) + .sum() + } + + /// Return the minimum engine-reported `max_model_len` across all engines. + /// + /// This is the auto-fitted value after KV cache profiling and may differ + /// from the originally configured value. + pub fn max_model_len(&self) -> Option { + self.engines + .iter() + .filter_map(|e| e.ready_response.as_ref()) + .map(|r| r.max_model_len as u32) + .min() + } + + /// Get the model name associated with this client used for metrics + /// labeling. + pub fn model_name(&self) -> &str { + self.inner.model_name() + } + + /// Return whether the client still considers the engine healthy. + pub fn is_healthy(&self) -> bool { + self.inner.is_healthy() + } + + /// Return the first persistent health error observed by the client, if any. + pub fn health_error(&self) -> Option> { + self.inner.health_error() + } +} + +// Client API implementation. +impl EngineCoreClient { + /// Add a new request to the engine and return a per-request raw output + /// stream. + pub async fn call(&self, mut req: EngineCoreRequest) -> Result { + req.client_index = self.config.client_index; + req.validate()?; + trace!( + request_id = %req.request_id, + client_index = req.client_index, + current_wave = req.current_wave, + request = ?req, + "sending add request" + ); + + let request_id = req.request_id.clone(); + let data_parallel_rank = req.data_parallel_rank; + let (engine_id, rx) = + self.inner.register_request(request_id.clone(), data_parallel_rank)?; + + let result: Result<()> = async { + if let Some(coordinator) = self.coordinator.as_ref() { + let snapshot = coordinator.snapshot(); + req.current_wave = snapshot.current_wave; + if !snapshot.engines_running { + coordinator.notify_first_request(engine_id.clone())?; + } + } + + debug!( + request_id = req.request_id, + ?engine_id, + "registered request to engine" + ); + + self.inner.send_to_engine(&engine_id, EngineCoreRequestType::Add, &req).await?; + Ok(()) + } + .await; + + // Failed to send the request to the engine, roll back the registration. + if let Err(error) = result { + self.inner.rollback_request(&request_id); + return Err(error); + } + + Ok(EngineCoreOutputStream::new( + request_id, + self.abort_tx.clone(), + rx, + )) + } + + /// Abort currently in-flight requests by request ID. + pub async fn abort(&self, ids: &[String]) -> Result<()> { + let abortable = self.inner.abortable_request_ids(ids)?; + + trace!(request_ids = ?ids, abortable_request_ids = ?abortable, "sending abort request ids"); + + if abortable.is_empty() { + return Ok(()); + } + + for (engine_id, request_ids) in abortable { + self.inner.do_abort_requests(&engine_id, &request_ids).await?; + } + Ok(()) + } + + /// Call a typed utility method on all connected engines, returning one + /// decoded result per connected engine if all calls succeed or an error + /// if any call fails. + /// + /// Callers should pass utility arguments using Rust tuple semantics so the + /// encoded payload matches Python's `(client_index, call_id, + /// method_name, args)` contract: `()`, `(arg,)`, `(arg1, arg2)`, etc. + pub async fn call_utility(&self, method: &str, args: A) -> Result> + where + T: serde::de::DeserializeOwned, + A: serde::Serialize + std::fmt::Debug, + { + trace!( + method, + client_index = self.config.client_index, + engine_count = self.engines.len(), + "sending utility request" + ); + + // Phase 1: allocate one call id per engine and build the per-engine + // request payloads up-front. Any failure here (registry closed, encode + // error) must roll back the call ids already allocated so they do not + // leak in the utility registry until shutdown. + let mut pending_calls = Vec::with_capacity(self.engines.len()); + let mut prepared_sends = Vec::with_capacity(self.engines.len()); + for engine in &self.engines { + let (call_id, rx) = match self.inner.allocate_and_register_utility_call() { + Ok(pair) => pair, + Err(err) => { + self.inner.unregister_utility_calls(pending_calls.iter().map(|(id, _)| *id)); + return Err(err); + } + }; + let request = match EngineCoreUtilityRequest::new( + self.config.client_index, + call_id, + method, + &args, + ) { + Ok(request) => request, + Err(err) => { + self.inner.unregister_utility_calls( + pending_calls.iter().map(|(id, _)| *id).chain(std::iter::once(call_id)), + ); + return Err(err); + } + }; + pending_calls.push((call_id, rx)); + prepared_sends.push((&engine.engine_id, request)); + } + + // Phase 2: dispatch every utility request concurrently. `try_join_all` + // fails fast on the first transport error and drops the remaining send + // futures; any engines that already received the request will reply, + // but those replies are simply dropped because we roll back the call + // ids below. + let send_futures = prepared_sends.iter().map(|(engine_id, request)| { + self.inner.send_to_engine(engine_id, EngineCoreRequestType::Utility, request) + }); + if let Err(err) = try_join_all(send_futures).await { + self.inner.unregister_utility_calls(pending_calls.iter().map(|(id, _)| *id)); + return Err(err); + } + + // Phase 3: wait for all engines to respond and preserve the per-engine + // result list. + let futures = pending_calls.into_iter().map(|(call_id, rx)| async move { + rx.await + .map_err(|_| Error::UtilityCallClosed { + method: method.to_string(), + call_id, + })?? + .into_typed_result(method) + }); + try_join_all(futures).await + } + + /// Execute `collective_rpc` on all engines and flatten all engine results + /// into one list. + pub async fn collective_rpc( + &self, + method: &str, + timeout: Option, + args: A, + kwargs: K, + ) -> Result> + where + A: serde::Serialize + std::fmt::Debug, + K: serde::Serialize + std::fmt::Debug, + { + let results = self + .call_utility::("collective_rpc", (method, timeout, args, kwargs)) + .await?; + + Ok(results + .into_iter() + .flat_map(|result| match result { + // Each engine's `collective_rpc` result is itself the worker-level result list. + rmpv::Value::Array(results) => results, + other => vec![other], + }) + .collect()) + } + + /// Return whether the engine is currently sleeping at any level. + pub async fn is_sleeping(&self) -> Result { + // TODO: we only return the result of the first engine here. + Ok(self.call_utility("is_sleeping", ()).await?[0]) + } + + /// Reset the multi-modal cache. + pub async fn reset_mm_cache(&self) -> Result<()> { + self.call_utility::<(), _>("reset_mm_cache", ()).await?; + Ok(()) + } + + /// Reset the encoder cache. + pub async fn reset_encoder_cache(&self) -> Result<()> { + self.call_utility::<(), _>("reset_encoder_cache", ()).await?; + Ok(()) + } + + /// Reset the prefix cache and optionally the external connector cache. + pub async fn reset_prefix_cache( + &self, + reset_running_requests: bool, + reset_connector: bool, + ) -> Result { + // TODO: we only return the result of the first engine here. + Ok(self + .call_utility( + "reset_prefix_cache", + (reset_running_requests, reset_connector), + ) + .await?[0]) + } + + /// Put the engine to sleep. + pub async fn sleep(&self, level: u32, mode: &str) -> Result<()> { + self.call_utility::<(), _>("sleep", (level, mode)).await?; + Ok(()) + } + + /// Wake the engine from sleep, optionally limiting the wake-up to specific + /// tags. + pub async fn wake_up(&self, tags: Option>) -> Result<()> { + self.call_utility::<(), _>("wake_up", (tags,)).await?; + Ok(()) + } + + /// Shut down local client tasks and close transport state. + pub async fn shutdown(self) -> Result<()> { + let Self { + inner, + abort_tx, + output_task, + dispatcher_task, + abort_task, + coordinator_output_task, + coordinator_task, + .. + } = self; + + info!("shutting down engine-core client"); + inner.shutdown(); + drop(abort_tx); + + // Abort all client tasks first, then await them. + // Note the aborting orders here. + let mut tasks = vec![abort_task, dispatcher_task, output_task]; + tasks.extend(coordinator_task); + tasks.extend(coordinator_output_task); + + tasks.iter().for_each(|t| t.abort()); + join_all(tasks).await; + + info!("engine-core client shut down"); + Ok(()) + } +} diff --git a/rust/src/engine-core-client/src/client/imp.rs b/rust/src/engine-core-client/src/client/imp.rs new file mode 100644 index 00000000000..daf1a5f7d71 --- /dev/null +++ b/rust/src/engine-core-client/src/client/imp.rs @@ -0,0 +1,410 @@ +use std::collections::BTreeMap; +use std::slice; +use std::sync::Arc; + +use arc_swap::ArcSwapOption; +use parking_lot::Mutex; +use thiserror_ext::AsReport as _; +use tokio::sync::mpsc; +use tracing::{debug, info, trace, warn}; +use vllm_metrics::METRICS; +use zeromq::RouterSendHalf; + +use crate::client::state::{OutputReceiver, RequestRegistry, UtilityReceiver, UtilityRegistry}; +use crate::client::stream::EngineCoreStreamOutput; +use crate::client::{AbortCause, AbortRequest}; +use crate::error::{client_closed, dispatcher_closed, unexpected_dispatcher_output}; +use crate::metrics::record_scheduler_stats; +use crate::protocol::stats::SchedulerStats; +use crate::protocol::{ + ClassifiedEngineCoreOutputs, EngineCoreOutput, EngineCoreOutputs, EngineCoreRequestType, + UtilityOutput, encode_msgpack, +}; +use crate::transport::{ConnectedEngine, EngineId}; +use crate::{Error, Result, transport}; + +pub(crate) struct ClientInner { + input_send: RouterSendHalf, + model_name: String, + request_reg: Mutex, + utility_reg: Mutex, + health_error: ArcSwapOption, +} + +impl ClientInner { + /// Create a new instance with the given input send half after the startup + /// handshake completes. + pub fn new( + input_send: RouterSendHalf, + model_name: String, + engines: &[ConnectedEngine], + ) -> Self { + Self { + input_send, + model_name, + request_reg: Mutex::new(RequestRegistry::new(engines)), + utility_reg: Mutex::new(UtilityRegistry::default()), + health_error: ArcSwapOption::empty(), + } + } + + /// Get the model name associated with this client used for metrics + /// labeling. + pub fn model_name(&self) -> &str { + &self.model_name + } + + /// Register a newly added request. Return the selected engine id and the + /// per-request output channel bound to its `request_id`. + /// + /// When `data_parallel_rank` is provided, the request is routed to that + /// specific engine rank, bypassing load balancing. + pub fn register_request( + &self, + request_id: String, + data_parallel_rank: Option, + ) -> Result<(EngineId, OutputReceiver)> { + let mut registry = self.request_reg.lock(); + if registry.is_closed() { + return Err(self.closed_error()); + } + registry.register(request_id, data_parallel_rank) + } + + /// Allocate the next utility `call_id` and register its waiting receiver. + pub fn allocate_and_register_utility_call(&self) -> Result<(i64, UtilityReceiver)> { + let mut registry = self.utility_reg.lock(); + if registry.is_closed() { + return Err(self.closed_error()); + } + Ok(registry.allocate_and_register()) + } + + /// Undo a batch of utility call allocations when the fan-out send fails + /// partway through. Silently ignores unknown call ids so callers can pass + /// the full set without first filtering successful sends. + pub fn unregister_utility_calls(&self, call_ids: impl IntoIterator) { + self.utility_reg.lock().unregister_many(call_ids); + } + + /// Undo a request registration when `add_request()` fails. + pub fn rollback_request(&self, request_id: &str) { + let _ = self.request_reg.lock().remove(request_id); + } + + /// Filter the given request IDs to the subset that are still tracked as + /// active and can be aborted, grouped by the engine that originally + /// accepted them. + pub fn abortable_request_ids( + &self, + request_ids: &[String], + ) -> Result>> { + let registry = self.request_reg.lock(); + if registry.is_closed() { + return Err(self.closed_error()); + } + Ok(registry.abortable_request_ids(request_ids)) + } + + /// Obtain the stream sender for one output. If it indicates the request is + /// finished, it will be removed from the registry. + pub fn take_sender_for_output( + &self, + output: &EngineCoreOutput, + ) -> Option>> { + self.request_reg.lock().sender_for_output(output) + } + + /// Remove a batch of requests that have finished or aborted, returning + /// their stream senders. + pub fn finish_requests<'a>( + &self, + request_ids: impl IntoIterator, + ) -> Vec>> { + self.request_reg.lock().finish_many(request_ids) + } + + /// Apply one scheduler stats update for the given engine to the local + /// routing state. Returns `false` if the engine is unknown to the + /// client. + pub fn apply_scheduler_stats(&self, engine_index: u32, stats: &SchedulerStats) -> bool { + self.request_reg.lock().apply_scheduler_stats(engine_index, stats) + } + + /// Close all active request streams and utility calls with the first + /// persistent health error. + pub fn close_registries(&self, error: Arc) { + let persistent_error = self.record_health_error(error); + let request_senders = self.request_reg.lock().close(); + let utility_senders = self.utility_reg.lock().close(); + + // Notify all ongoing requests that the client is closed. + for sender in request_senders { + let _ = sender.send(Err(Error::Shared(persistent_error.clone()))); + } + for sender in utility_senders { + let _ = sender.send(Err(Error::Shared(persistent_error.clone()))); + } + } + + /// Return the first persistent health error observed by the client, if any. + pub fn health_error(&self) -> Option> { + self.health_error.load_full() + } + + /// Return whether the client still considers the engine healthy. + pub fn is_healthy(&self) -> bool { + self.health_error.load().is_none() + } + + /// Resolve one utility output to the waiting caller. Returns `true` if a + /// waiting caller existed. + pub fn resolve_utility_output(&self, output: UtilityOutput) -> bool { + match self.utility_reg.lock().resolve(&output.call_id) { + Some(sender) => { + sender.send(Ok(output)).unwrap_or_default(); + true + } + None => false, + } + } + + /// Send the given message to the engine. The request should be first + /// registered via `register_request()` to ensure the request stream is + /// tracked. + pub async fn send_to_engine( + &self, + engine_id: &EngineId, + request_type: EngineCoreRequestType, + payload: &T, + ) -> Result<()> + where + T: serde::Serialize + std::fmt::Debug, + { + // TODO: for `EngineCoreRequest`, split outbound tensor raw views into aux + // frames instead of always producing a single msgpack frame. + let payload = encode_msgpack(payload)?; + let mut input_send = self.input_send.clone(); + transport::send_message(&mut input_send, engine_id, request_type.to_frame(), payload) + .await?; + Ok(()) + } + + /// Handle an abort request by sending the abort message to the engine. + pub async fn do_abort_requests( + &self, + engine_id: &EngineId, + request_ids: &[String], + ) -> Result<()> { + self.send_to_engine(engine_id, EngineCoreRequestType::Abort, &request_ids).await + } + + /// Shut down by closing all active request streams and utility calls with a + /// sticky client closed error. + pub fn shutdown(&self) { + self.close_registries(Arc::new(client_closed!("engine-core client shut down"))); + } + + /// Remove the request from the active registry for auto-abort and return + /// the engine that the request was originally routed to, if it is still + /// active. + pub fn take_auto_abort_target(&self, request_id: &str) -> Option { + let mut registry = self.request_reg.lock(); + let (_, engine_id) = registry.remove(request_id)?; + if registry.is_closed() { + return None; + } + Some(engine_id) + } + + /// Publish the first persistent health error and return the sticky error + /// recorded for this client. Later failures do not overwrite the first + /// one so `/health` and post-close callers observe a stable cause. + fn record_health_error(&self, error: Arc) -> Arc { + if let Some(existing) = self.health_error.load_full() { + return existing; + } + self.health_error + .rcu(|current| current.clone().unwrap_or_else(|| error.clone())); + self.health_error + .load_full() + .expect("health error must be recorded before registries close") + } + + /// Assert there is a recorded health error and return a `Shared` variant + /// wrapping it for error returns when the client is already closed. + fn closed_error(&self) -> Error { + Error::Shared(self.health_error.load_full().expect( + "closed registry must have a recorded health error before rejecting new operations", + )) + } +} + +/// Background loop that listens for request IDs to abort and sends abort +/// messages to the engine. This is used to implement the auto-abort behavior +/// when a request stream is dropped without being properly terminated. +pub(crate) async fn run_abort_loop( + inner: Arc, + mut abort_rx: mpsc::UnboundedReceiver, +) { + // TODO: receive and abort requests in batch + while let Some(AbortRequest { request_id, cause }) = abort_rx.recv().await { + let Some(engine_id) = inner.take_auto_abort_target(&request_id) else { + debug!(request_id, "skip auto-abort for inactive request"); + continue; + }; + + match cause { + AbortCause::DroppedStream => { + info!(request_id, "auto-aborting request due to dropped stream") + } + AbortCause::StopStringMatched => { + debug!( + request_id, + "auto-aborting request due to stop string matched" + ) + } + } + + if let Err(error) = inner.do_abort_requests(&engine_id, slice::from_ref(&request_id)).await + { + warn!( + request_id, + ?engine_id, + error = %error.as_report(), + "failed to auto-abort dropped request stream" + ); + } + } +} + +/// Background loop that listens for engine-core outputs and dispatches them to +/// the corresponding request streams based on their `request_id`. +pub(crate) async fn run_output_dispatcher_loop( + inner: Arc, + mut output_rx: mpsc::Receiver>, +) { + let result: Result<()> = async { + loop { + let outputs = match output_rx.recv().await { + Some(outputs) => outputs, + None => Err(dispatcher_closed!( + "engine-core output dispatcher channel closed" + )), + }?; + + match outputs.classify() { + ClassifiedEngineCoreOutputs::RequestBatch(batch) => { + for output in batch.outputs { + let request_id = output.request_id.clone(); + let Some(sender) = inner.take_sender_for_output(&output) else { + debug!(request_id, "dropping output for inactive request"); + continue; + }; + + let wrapped_output = EngineCoreStreamOutput { + engine_index: batch.engine_index, + timestamp: batch.timestamp, + output, + }; + if sender.send(Ok(wrapped_output)).is_err() { + debug!(request_id, "request output stream receiver dropped"); + } + } + + // The sender for normally-finished requests should have already been removed + // from the registry when their final output was dispatched + // above. This serves as a safety net to capture any + // requests marked as finished by the engine. + if let Some(finished_requests) = batch.finished_requests.as_ref() { + for request_id in finished_requests { + trace!(request_id, "request completed via finished_requests"); + } + drop(inner.finish_requests(finished_requests)); + } + + if let Some(scheduler_stats) = batch.scheduler_stats.as_ref() { + if !inner.apply_scheduler_stats(batch.engine_index, scheduler_stats) { + debug!( + engine_index = batch.engine_index, + "dropping scheduler stats for unknown engine" + ); + } + record_scheduler_stats( + &METRICS.scheduler, + inner.model_name(), + batch.engine_index, + scheduler_stats, + ); + } + } + ClassifiedEngineCoreOutputs::Utility(utility) => { + let call_id = utility.output.call_id; + if inner.resolve_utility_output(utility.output) { + trace!( + call_id, + engine_index = utility.engine_index, + "resolved utility output" + ); + } else { + warn!( + call_id, + engine_index = utility.engine_index, + "dropping output for unexpected utility call" + ); + } + } + other @ (ClassifiedEngineCoreOutputs::DpControl { .. } + | ClassifiedEngineCoreOutputs::Other(_)) => { + Err::<(), _>(unexpected_dispatcher_output!( + "received unexpected output on main dispatcher path: {other:?}" + ))?; + } + } + } + } + .await; + let Err(error) = result else { return }; + + warn!(error = %error.as_report(), "output dispatcher exiting with error"); + inner.close_registries(Arc::new(error)); +} + +#[cfg(test)] +mod tests { + use zeromq::{RouterSocket, Socket}; + + use super::*; + + async fn test_inner() -> ClientInner { + let mut socket = RouterSocket::new(); + socket.bind("tcp://127.0.0.1:0").await.unwrap(); + let (send, _) = socket.split(); + ClientInner::new( + send, + "test-model".to_string(), + &[ConnectedEngine { + engine_id: EngineId::from(b"engine-0"), + ready_response: None, + }], + ) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn close_registries_records_first_health_error_only() { + let inner = test_inner().await; + + inner.close_registries(Arc::new(Error::EngineCoreDead)); + assert!(!inner.is_healthy()); + assert!(matches!( + inner.health_error().as_deref(), + Some(Error::EngineCoreDead) + )); + + inner.close_registries(Arc::new(client_closed!("shutdown"))); + assert!(matches!( + inner.health_error().as_deref(), + Some(Error::EngineCoreDead) + )); + } +} diff --git a/rust/src/engine-core-client/src/client/state.rs b/rust/src/engine-core-client/src/client/state.rs new file mode 100644 index 00000000000..87e0263f317 --- /dev/null +++ b/rust/src/engine-core-client/src/client/state.rs @@ -0,0 +1,670 @@ +use std::collections::BTreeMap; +use std::sync::atomic::{AtomicI64, Ordering}; + +use tokio::sync::{mpsc, oneshot}; +use tracing::trace; + +use crate::EngineId; +use crate::client::stream::EngineCoreStreamOutput; +use crate::error::{Error, Result}; +use crate::protocol::stats::SchedulerStats; +use crate::protocol::{EngineCoreOutput, UtilityOutput}; +use crate::transport::ConnectedEngine; + +pub type OutputSender = mpsc::UnboundedSender>; +pub type OutputReceiver = mpsc::UnboundedReceiver>; +pub type UtilitySender = oneshot::Sender>; +pub type UtilityReceiver = oneshot::Receiver>; + +#[derive(Debug)] +struct TrackedRequest { + sender: OutputSender, + engine_id: EngineId, +} + +/// The latest real scheduler-side load snapshot observed from one engine. +/// +/// These counters come from `scheduler_stats` on the normal engine output path +/// and are the preferred routing signal once available. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct EngineLoadSnapshot { + /// Requests still counted on the scheduler's waiting side. + waiting: usize, + /// Requests currently counted on the scheduler's running side. + running: usize, +} + +#[derive(Debug, Default)] +struct EngineRoutingState { + /// Requests admitted by this frontend that have not finished yet. + /// + /// This is used both as the bootstrap fallback before real scheduler stats + /// exist and as a lower bound afterwards so asynchronous scheduler + /// snapshots cannot erase frontend admission history. + inflight: usize, + /// The latest real scheduler snapshot received from this engine, if any. + last_scheduler_stats: Option, +} + +impl EngineRoutingState { + /// Compute the routing score used to pick the least-loaded engine. + /// + /// Scheduler stats can raise the load estimate above the frontend-local + /// view, but they should not lower it below requests this frontend has + /// already admitted. Waiting requests still get the same extra penalty + /// as the original `waiting * 4 + running` score. + fn routing_score(&self) -> usize { + const WAITING_WEIGHT: usize = 4; + + let Some(stats) = self.last_scheduler_stats else { + return self.inflight; + }; + + let scheduler_total = stats.running + stats.waiting; + self.inflight.max(scheduler_total) + stats.waiting * (WAITING_WEIGHT - 1) + } + + /// Replace the local routing view with a fresh real scheduler snapshot. + fn apply_scheduler_counts(&mut self, next: EngineLoadSnapshot) { + self.last_scheduler_stats = Some(next); + } +} + +/// Internal registry for tracking active requests and their output stream +/// senders. +/// +/// This is used to route incoming outputs to the correct request stream, and to +/// ensure proper cleanup of senders when requests finish or the client shuts +/// down. +#[derive(Debug)] +pub struct RequestRegistry { + closed: bool, + requests: BTreeMap, + routing_per_engine: BTreeMap, +} + +impl RequestRegistry { + pub fn new(engines: &[ConnectedEngine]) -> Self { + Self { + closed: false, + requests: BTreeMap::default(), + routing_per_engine: engines + .iter() + .map(|engine| (engine.engine_id.clone(), EngineRoutingState::default())) + .collect(), + } + } + + /// Register a newly added request. Create the per-request output channel + /// bound to its `request_id` and return the selected engine id. + /// + /// When `data_parallel_rank` is provided, the request is routed directly to + /// the engine at that rank index, bypassing load balancing. Otherwise + /// the engine with the fewest in-flight requests is chosen. + pub fn register( + &mut self, + request_id: String, + data_parallel_rank: Option, + ) -> Result<(EngineId, OutputReceiver)> { + if self.requests.contains_key(&request_id) { + return Err(Error::DuplicateRequestId { request_id }); + } + + let engine_id = self.choose_engine_for_request(data_parallel_rank)?; + let (tx, rx) = mpsc::unbounded_channel(); + self.requests.insert( + request_id, + TrackedRequest { + sender: tx, + engine_id: engine_id.clone(), + }, + ); + + let state = self + .routing_per_engine + .get_mut(&engine_id) + .expect("request registry must track all known engines"); + state.inflight += 1; + + Ok((engine_id, rx)) + } + + fn choose_engine_for_request(&mut self, data_parallel_rank: Option) -> Result { + if let Some(rank) = data_parallel_rank { + // Route to the engine at the specified rank index. + let engine_id = EngineId::from_engine_index(rank); + return self + .routing_per_engine + .contains_key(&engine_id) + .then_some(engine_id) + .ok_or_else(|| Error::InvalidDataParallelRank { + rank, + num_engines: self.routing_per_engine.len() as u32, + }); + } + + Ok(self + .routing_per_engine + .iter() + .min_by_key(|(_, state)| state.routing_score()) + .map(|(engine_id, _)| engine_id.clone()) + .expect("request registry must contain at least one engine")) + } + + /// Filter the given request IDs to the subset that are still tracked as + /// active and can be aborted, grouped by engine. + pub fn abortable_request_ids(&self, request_ids: &[String]) -> BTreeMap> { + let mut by_engine = BTreeMap::new(); + for request_id in request_ids { + let Some(tracked) = self.requests.get(request_id.as_str()) else { + continue; + }; + by_engine + .entry(tracked.engine_id.clone()) + .or_insert_with(Vec::new) + .push(request_id.clone()); + } + by_engine + } + + /// Obtain the stream sender for one output. If it indicates the request is + /// finished, it will be removed from the registry. + pub fn sender_for_output(&mut self, output: &EngineCoreOutput) -> Option { + if output.finished() { + self.remove(output.request_id.as_str()).map(|tracked| tracked.0) + } else { + self.requests + .get(output.request_id.as_str()) + .map(|tracked| tracked.sender.clone()) + } + } + + /// Remove a batch of requests that have finished or aborted, returning + /// their stream senders. + pub fn finish_many<'a>( + &mut self, + request_ids: impl IntoIterator, + ) -> Vec { + request_ids + .into_iter() + .filter_map(|request_id| self.remove(request_id.as_str()).map(|tracked| tracked.0)) + .collect() + } + + /// Apply one scheduler stats update for the given engine to the local + /// routing state. Returns `false` if the engine is unknown to the + /// client. + pub fn apply_scheduler_stats(&mut self, engine_index: u32, stats: &SchedulerStats) -> bool { + self.apply_scheduler_counts( + engine_index, + EngineLoadSnapshot { + waiting: stats.num_waiting_reqs as usize, + running: stats.num_running_reqs as usize, + }, + ) + } + + /// Mark the registry as closed, detach and return all tracked senders. + pub fn close(&mut self) -> Vec { + if self.closed { + return Vec::new(); + } + + self.closed = true; + std::mem::take(&mut self.requests) + .into_values() + .map(|tracked| tracked.sender) + .collect() + } + + /// Remove one request from the local registry. Returns the tracked entry if + /// it exists. + #[must_use] + pub fn remove(&mut self, request_id: &str) -> Option<(OutputSender, EngineId)> { + let tracked = self.requests.remove(request_id)?; + self.routing_per_engine + .get_mut(&tracked.engine_id) + .expect("request registry must track all known engines") + .inflight -= 1; + Some((tracked.sender, tracked.engine_id)) + } + + fn apply_scheduler_counts(&mut self, engine_index: u32, next: EngineLoadSnapshot) -> bool { + let engine_id = EngineId::from_engine_index(engine_index); + let Some(state) = self.routing_per_engine.get_mut(&engine_id) else { + return false; + }; + + let previous = state.last_scheduler_stats; + if previous != Some(next) { + trace!( + ?engine_id, + previous_waiting = previous.map(|stats| stats.waiting), + previous_running = previous.map(|stats| stats.running), + waiting = next.waiting, + running = next.running, + "updated scheduler routing counts", + ); + } + + state.apply_scheduler_counts(next); + true + } + + #[cfg(test)] + pub fn contains(&self, request_id: &str) -> bool { + self.requests.contains_key(request_id) + } + + pub fn is_closed(&self) -> bool { + self.closed + } +} + +/// Internal registry for tracking active utility calls and their waiting +/// receivers. +#[derive(Debug)] +pub struct UtilityRegistry { + closed: bool, + next_call_id: AtomicI64, + utility_calls: BTreeMap, +} + +impl Default for UtilityRegistry { + fn default() -> Self { + Self { + closed: false, + next_call_id: AtomicI64::new(1), + utility_calls: BTreeMap::default(), + } + } +} + +impl UtilityRegistry { + /// Allocate the next utility `call_id` and register a newly added utility + /// call. + pub fn allocate_and_register(&mut self) -> (i64, UtilityReceiver) { + let call_id = self.next_call_id.fetch_add(1, Ordering::Relaxed); + let (tx, rx) = oneshot::channel(); + self.utility_calls.insert(call_id, tx); + (call_id, rx) + } + + /// Resolve a utility output to its waiting receiver. + pub fn resolve(&mut self, call_id: &i64) -> Option { + self.utility_calls.remove(call_id) + } + + /// Drop a batch of registered utility calls without delivering a result. + /// Used to roll back allocations when the dispatch fan-out fails before + /// every engine could accept the request. + pub fn unregister_many(&mut self, call_ids: impl IntoIterator) { + for call_id in call_ids { + self.utility_calls.remove(&call_id); + } + } + + /// Mark the registry as closed, detach and return all tracked senders. + pub fn close(&mut self) -> Vec { + if self.closed { + return Vec::new(); + } + + self.closed = true; + std::mem::take(&mut self.utility_calls).into_values().collect() + } + + #[cfg(test)] + pub fn contains(&self, call_id: i64) -> bool { + self.utility_calls.contains_key(&call_id) + } + + pub fn is_closed(&self) -> bool { + self.closed + } +} + +#[cfg(test)] +mod tests { + use super::{EngineRoutingState, RequestRegistry, UtilityRegistry}; + use crate::EngineId; + use crate::client::state::EngineLoadSnapshot; + use crate::protocol::{EngineCoreFinishReason, EngineCoreOutput}; + use crate::transport::ConnectedEngine; + + #[test] + fn registry_rejects_duplicate_request_ids() { + let mut registry = RequestRegistry::new(&[ConnectedEngine { + engine_id: EngineId::from(b"engine-0"), + ready_response: None, + }]); + registry.register("req-1".to_string(), None).unwrap(); + let error = registry.register("req-1".to_string(), None).unwrap_err(); + assert!(matches!( + error, + crate::error::Error::DuplicateRequestId { request_id } if request_id == "req-1" + )); + } + + #[test] + fn registry_removes_finished_request_on_output() { + let mut registry = RequestRegistry::new(&[ConnectedEngine { + engine_id: EngineId::from(b"engine-0"), + ready_response: None, + }]); + registry.register("req-1".to_string(), None).unwrap(); + + let sender = registry.sender_for_output(&EngineCoreOutput { + request_id: "req-1".to_string(), + finish_reason: Some(EngineCoreFinishReason::Length), + ..Default::default() + }); + + assert!(sender.is_some()); + assert!(!registry.contains("req-1")); + } + + #[test] + fn registry_closes_all_requests_on_failure() { + let mut registry = RequestRegistry::new(&[ConnectedEngine { + engine_id: EngineId::from(b"engine-0"), + ready_response: None, + }]); + registry.register("req-1".to_string(), None).unwrap(); + registry.register("req-2".to_string(), None).unwrap(); + + let senders = registry.close(); + + assert_eq!(senders.len(), 2); + assert!(registry.is_closed()); + } + + #[test] + fn registry_tracks_engine_id_per_request() { + let engine_0 = EngineId::from_engine_index(0); + let engine_1 = EngineId::from_engine_index(1); + let mut registry = RequestRegistry::new(&[ + ConnectedEngine { + engine_id: engine_0.clone(), + ready_response: None, + }, + ConnectedEngine { + engine_id: engine_1.clone(), + ready_response: None, + }, + ]); + let (chosen_0, _) = registry.register("req-1".to_string(), None).unwrap(); + let (chosen_1, _) = registry.register("req-2".to_string(), None).unwrap(); + let (chosen_0_again, _) = registry.register("req-3".to_string(), None).unwrap(); + + assert_eq!(chosen_0, engine_0); + assert_eq!(chosen_1, engine_1); + assert_eq!(chosen_0_again, engine_0); + + let grouped = registry.abortable_request_ids(&[ + "req-1".to_string(), + "req-2".to_string(), + "req-3".to_string(), + ]); + assert_eq!( + grouped.get(&engine_0).unwrap(), + &vec!["req-1".to_string(), "req-3".to_string()] + ); + assert_eq!(grouped.get(&engine_1).unwrap(), &vec!["req-2".to_string()]); + } + + #[test] + fn registry_uses_inflight_as_waiting_fallback_before_stats_arrive() { + let engine_0 = EngineId::from_engine_index(0); + let engine_1 = EngineId::from_engine_index(1); + let mut registry = RequestRegistry::new(&[ + ConnectedEngine { + engine_id: engine_0.clone(), + ready_response: None, + }, + ConnectedEngine { + engine_id: engine_1.clone(), + ready_response: None, + }, + ]); + + let (chosen_0, _) = registry.register("req-1".to_string(), None).unwrap(); + let (chosen_1, _) = registry.register("req-2".to_string(), None).unwrap(); + let (chosen_0_again, _) = registry.register("req-3".to_string(), None).unwrap(); + + assert_eq!(chosen_0, engine_0); + assert_eq!(chosen_1, engine_1); + assert_eq!(chosen_0_again, engine_0); + } + + #[test] + fn routing_score_uses_inflight_before_stats_arrive() { + let state = EngineRoutingState { + inflight: 3, + last_scheduler_stats: None, + }; + + assert_eq!(state.routing_score(), 3); + } + + #[test] + fn routing_score_uses_inflight_as_scheduler_stats_lower_bound() { + let state = EngineRoutingState { + inflight: 7, + last_scheduler_stats: Some(EngineLoadSnapshot { + waiting: 0, + running: 2, + }), + }; + + assert_eq!(state.routing_score(), 7); + } + + #[test] + fn routing_score_keeps_extra_waiting_penalty() { + let state = EngineRoutingState { + inflight: 1, + last_scheduler_stats: Some(EngineLoadSnapshot { + waiting: 3, + running: 2, + }), + }; + + assert_eq!(state.routing_score(), 14); + } + + #[test] + fn registry_prefers_real_scheduler_stats_over_inflight() { + let engine_0 = EngineId::from_engine_index(0); + let engine_1 = EngineId::from_engine_index(1); + let mut registry = RequestRegistry::new(&[ + ConnectedEngine { + engine_id: engine_0.clone(), + ready_response: None, + }, + ConnectedEngine { + engine_id: engine_1.clone(), + ready_response: None, + }, + ]); + + assert!(registry.apply_scheduler_counts( + 0, + EngineLoadSnapshot { + waiting: 3, + running: 2 + } + )); + assert!(registry.apply_scheduler_counts( + 1, + EngineLoadSnapshot { + waiting: 0, + running: 1 + } + )); + + let (chosen, _) = registry.register("req-stats".to_string(), None).unwrap(); + assert_eq!(chosen, engine_1); + } + + #[test] + fn register_with_data_parallel_rank_routes_to_specified_engine() { + let engine_0 = EngineId::from_engine_index(0); + let engine_1 = EngineId::from_engine_index(1); + let engine_2 = EngineId::from_engine_index(2); + let mut registry = RequestRegistry::new(&[ + ConnectedEngine { + engine_id: engine_0.clone(), + ready_response: None, + }, + ConnectedEngine { + engine_id: engine_1.clone(), + ready_response: None, + }, + ConnectedEngine { + engine_id: engine_2.clone(), + ready_response: None, + }, + ]); + + // Explicitly target rank 2 (third engine). + let (chosen, _) = registry.register("req-1".to_string(), Some(2)).unwrap(); + assert_eq!(chosen, engine_2); + + // Explicitly target rank 0 (first engine). + let (chosen, _) = registry.register("req-2".to_string(), Some(0)).unwrap(); + assert_eq!(chosen, engine_0); + + // Explicitly target rank 1. + let (chosen, _) = registry.register("req-3".to_string(), Some(1)).unwrap(); + assert_eq!(chosen, engine_1); + } + + #[test] + fn register_with_data_parallel_rank_bypasses_load_balancing() { + let engine_0 = EngineId::from_engine_index(0); + let engine_1 = EngineId::from_engine_index(1); + let mut registry = RequestRegistry::new(&[ + ConnectedEngine { + engine_id: engine_0.clone(), + ready_response: None, + }, + ConnectedEngine { + engine_id: engine_1.clone(), + ready_response: None, + }, + ]); + + // Load-balance: first two go to engine_0 and engine_1. + registry.register("req-lb-0".to_string(), None).unwrap(); + + // Now engine_0 has 1 in-flight. Without dp_rank, next would go to engine_1. + // But with dp_rank=0, it should still go to engine_0. + let (chosen, _) = registry.register("req-dp".to_string(), Some(0)).unwrap(); + assert_eq!(chosen, engine_0); + } + + #[test] + fn register_with_out_of_range_rank_returns_error() { + let mut registry = RequestRegistry::new(&[ + ConnectedEngine { + engine_id: EngineId::from_engine_index(0), + ready_response: None, + }, + ConnectedEngine { + engine_id: EngineId::from_engine_index(1), + ready_response: None, + }, + ]); + + let error = registry.register("req-1".to_string(), Some(2)).unwrap_err(); + assert!(matches!( + error, + crate::error::Error::InvalidDataParallelRank { + rank: 2, + num_engines: 2, + } + )); + } + + #[test] + fn register_with_rank_on_single_engine_only_accepts_zero() { + let engine_0 = EngineId::from_engine_index(0); + let mut registry = RequestRegistry::new(&[ConnectedEngine { + engine_id: engine_0.clone(), + ready_response: None, + }]); + + let (chosen, _) = registry.register("req-ok".to_string(), Some(0)).unwrap(); + assert_eq!(chosen, engine_0); + + let error = registry.register("req-bad".to_string(), Some(1)).unwrap_err(); + assert!(matches!( + error, + crate::error::Error::InvalidDataParallelRank { + rank: 1, + num_engines: 1, + } + )); + } + + #[test] + fn utility_registry_tracks_and_removes_call_ids() { + let mut registry = UtilityRegistry::default(); + let (call_id_1, _) = registry.allocate_and_register(); + let (call_id_2, _) = registry.allocate_and_register(); + + assert_eq!(call_id_1, 1); + assert_eq!(call_id_2, 2); + assert!(registry.contains(1)); + assert!(registry.contains(2)); + assert!(registry.resolve(&1).is_some()); + assert!(!registry.contains(1)); + assert!(registry.contains(2)); + } + + #[test] + fn utility_registry_closes_all_waiters_on_failure() { + let mut registry = UtilityRegistry::default(); + registry.allocate_and_register(); + registry.allocate_and_register(); + + let senders = registry.close(); + + assert_eq!(senders.len(), 2); + assert!(!registry.contains(1)); + assert!(!registry.contains(2)); + assert!(registry.is_closed()); + } + + #[test] + fn utility_registry_unregister_many_drops_pending_calls() { + use tokio::sync::oneshot::error::TryRecvError; + + let mut registry = UtilityRegistry::default(); + let (call_id_1, mut rx_1) = registry.allocate_and_register(); + let (call_id_2, mut rx_2) = registry.allocate_and_register(); + let (call_id_3, _rx_3) = registry.allocate_and_register(); + + // Drop two of the three allocated calls; the third stays pending. + registry.unregister_many([call_id_1, call_id_2]); + + assert!(!registry.contains(call_id_1)); + assert!(!registry.contains(call_id_2)); + assert!(registry.contains(call_id_3)); + // The receivers must observe the sender being dropped (channel closed). + assert!(matches!(rx_1.try_recv(), Err(TryRecvError::Closed))); + assert!(matches!(rx_2.try_recv(), Err(TryRecvError::Closed))); + } + + #[test] + fn utility_registry_unregister_many_ignores_unknown_call_ids() { + let mut registry = UtilityRegistry::default(); + let (call_id, _rx) = registry.allocate_and_register(); + + // Unknown call ids are silently ignored — caller doesn't care which were live. + registry.unregister_many([call_id, 42, 9999]); + + assert!(!registry.contains(call_id)); + } +} diff --git a/rust/src/engine-core-client/src/client/stream.rs b/rust/src/engine-core-client/src/client/stream.rs new file mode 100644 index 00000000000..3cbb215b0ef --- /dev/null +++ b/rust/src/engine-core-client/src/client/stream.rs @@ -0,0 +1,153 @@ +use std::ops::Deref; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::Stream; +use futures::stream::FusedStream; +use thiserror_ext::AsReport as _; +use tokio::sync::mpsc; +use tracing::{debug, error, warn}; + +use crate::client::AbortRequest; +use crate::client::state::OutputReceiver; +use crate::protocol::{EngineCoreFinishReason, EngineCoreOutput}; +use crate::{AbortCause, Error, Result}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum State { + Running, + Finished, + ClosedWithError, + UnexpectedClose, +} + +/// One request-scoped engine-core output plus the enclosing batch metadata. +#[derive(Debug, Clone, PartialEq)] +pub struct EngineCoreStreamOutput { + pub engine_index: u32, + pub timestamp: f64, + pub output: EngineCoreOutput, +} + +impl Deref for EngineCoreStreamOutput { + type Target = EngineCoreOutput; + + fn deref(&self) -> &Self::Target { + &self.output + } +} + +/// Stream of raw engine-core outputs for one request. +/// +/// The stream yields only [`EngineCoreStreamOutput`] values whose embedded +/// output `request_id` matches the originating `add_request()` call. Normal +/// request completion is expected to include a final output object whose +/// `finish_reason` is non-`None`. +pub struct EngineCoreOutputStream { + request_id: String, + abort_tx: mpsc::UnboundedSender, + state: State, + rx: OutputReceiver, +} + +impl EngineCoreOutputStream { + pub(crate) fn new( + request_id: String, + abort_tx: mpsc::UnboundedSender, + rx: OutputReceiver, + ) -> Self { + Self { + request_id, + abort_tx, + state: State::Running, + rx, + } + } + + /// Return the engine-core `request_id` bound to this stream. + pub fn request_id(&self) -> &str { + &self.request_id + } +} + +impl Stream for EngineCoreOutputStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.is_terminated() { + return Poll::Ready(None); + } + + match Pin::new(&mut self.rx).poll_recv(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(item)) => { + match &item { + Ok(output) => { + // If the output indicates the request is finished, mark the stream as + // terminated with cleanly-finished state and expect no more outputs to + // come. + if output.finished() { + if output.finish_reason == Some(EngineCoreFinishReason::Error) { + error!( + self.request_id, + "request failed with an internal error during generation" + ); + } + debug!(self.request_id, "request completed via final output"); + self.state = State::Finished; + } + } + Err(error) => { + // If we get an error from the output stream, mark the stream as terminated + // with an error. + warn!(self.request_id, error = %error.as_report(), "request encountered an error"); + self.state = State::ClosedWithError; + } + } + Poll::Ready(Some(item)) + } + Poll::Ready(None) => { + // If we get a `None` without seeing a finished output, this is an unexpected + // close from the engine side. Mark the stream as terminated + // with an unexpected close state and send an error down the + // stream to notify the caller. + warn!(self.request_id, "request stream closed unexpectedly"); + self.state = State::UnexpectedClose; + + Poll::Ready(Some(Err(Error::RequestStreamClosed { + request_id: self.request_id.clone(), + }))) + } + } + } +} + +impl FusedStream for EngineCoreOutputStream { + fn is_terminated(&self) -> bool { + !matches!(self.state, State::Running) + } +} + +impl Drop for EngineCoreOutputStream { + fn drop(&mut self) { + if self.is_terminated() { + // If it's terminated, it means that the request either finished cleanly, or + // encountered an error or unexpected close from the engine. In any + // case, the request stream is already considered inactive and + // there's no need to abort it on the engine side. + return; + } + + let abort_req = AbortRequest { + request_id: self.request_id.clone(), + cause: AbortCause::current(), + }; + + if self.abort_tx.send(abort_req).is_err() { + warn!( + request_id = self.request_id, + "auto-abort worker already shut down; skip auto-abort" + ); + } + } +} diff --git a/rust/src/engine-core-client/src/coordinator/bootstrap.rs b/rust/src/engine-core-client/src/coordinator/bootstrap.rs new file mode 100644 index 00000000000..8c6855bfe77 --- /dev/null +++ b/rust/src/engine-core-client/src/coordinator/bootstrap.rs @@ -0,0 +1,91 @@ +use std::time::Duration; + +use bytes::Bytes; +use zeromq::prelude::{Socket, SocketRecv, SocketSend}; +use zeromq::{PullSocket, XPubSocket, ZmqMessage}; + +use crate::error::{Error, Result, bail_unexpected_handshake_message}; + +/// Engine-facing sockets owned by the in-process coordinator. +pub(crate) struct CoordinatorBootstrap { + pub input_address: String, + pub output_address: String, + pub input_socket: XPubSocket, + pub output_socket: PullSocket, +} + +impl CoordinatorBootstrap { + /// Bind the engine-facing coordinator sockets on the given host. + pub(crate) async fn bind(local_host: &str) -> Result { + let mut input_socket = XPubSocket::new(); + let input_address = input_socket.bind(&format!("tcp://{local_host}:0")).await?.to_string(); + + let mut output_socket = PullSocket::new(); + let output_address = + output_socket.bind(&format!("tcp://{local_host}:0")).await?.to_string(); + + Ok(Self { + input_address, + output_address, + input_socket, + output_socket, + }) + } + + /// Complete the engine-facing startup gate before engines are allowed to + /// send handshake READY. + pub(crate) async fn wait_for_startup_gate( + &mut self, + engine_count: usize, + ready_timeout: Duration, + ) -> Result<()> { + wait_for_engine_subscriptions(&mut self.input_socket, engine_count, ready_timeout).await?; + send_ready_to_engines(&mut self.input_socket).await?; + Ok(()) + } +} + +/// Wait until all engines subscribe to the coordinator broadcast socket. +async fn wait_for_engine_subscriptions( + input_socket: &mut XPubSocket, + engine_count: usize, + ready_timeout: Duration, +) -> Result<()> { + let mut received = 0; + while received < engine_count { + let message = + tokio::time::timeout(ready_timeout, input_socket.recv()).await.map_err(|_| { + Error::HandshakeTimeout { + stage: "coordinator engine subscriptions", + timeout: ready_timeout, + } + })??; + if message.len() != 1 { + bail_unexpected_handshake_message!( + "expected 1 frame for coordinator subscription, got {}", + message.len() + ); + } + + let frame = message + .into_vec() + .into_iter() + .next() + .expect("single-frame coordinator subscription message"); + if frame.as_ref() != [0x01] { + bail_unexpected_handshake_message!( + "expected coordinator subscription frame [0x01], got {:?}", + frame.as_ref() + ); + } + received += 1; + } + + Ok(()) +} + +/// Send the coordinator READY marker to all subscribed engines. +async fn send_ready_to_engines(input_socket: &mut XPubSocket) -> Result<()> { + input_socket.send(ZmqMessage::from(Bytes::from_static(b"READY"))).await?; + Ok(()) +} diff --git a/rust/src/engine-core-client/src/coordinator/external.rs b/rust/src/engine-core-client/src/coordinator/external.rs new file mode 100644 index 00000000000..eb447c4809e --- /dev/null +++ b/rust/src/engine-core-client/src/coordinator/external.rs @@ -0,0 +1,167 @@ +use std::sync::Arc; + +use serde_tuple::{Deserialize_tuple, Serialize_tuple}; +use thiserror_ext::AsReport; +use tokio::sync::mpsc; +use tracing::{debug, warn}; +use zeromq::prelude::{SocketRecv, SocketSend}; +use zeromq::{XSubSocket, ZmqMessage}; + +use crate::client::imp::ClientInner; +use crate::coordinator::handle::{CoordinatorCommand, CoordinatorState}; +use crate::error::{Error, Result, bail_unexpected_coordinator_output}; +use crate::protocol::{OpaqueValue, decode_msgpack, encode_msgpack}; + +/// Frontend-to-coordinator wakeup message sent when the first request arrives +/// while all engines are paused. +/// +/// This matches the frontend-side msgpack tuple sent by Python +/// `DPAsyncMPClient._ensure_stats_update_task` to the coordinator front socket. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize_tuple, Deserialize_tuple)] +struct CoordinatorWakeupMessage { + /// Engine index that already has the triggering request and should be + /// excluded from the coordinator's `START_DP_WAVE` rebroadcast. + exclude_engine_index: u32, + /// DP wave number observed by the frontend when the request was admitted. + wave: u32, +} + +/// Coordinator-to-frontend state publish received on the front-side coordinator +/// socket. +/// +/// This matches the msgpack tuple periodically published by Python +/// `DPCoordinatorProc.run_coordinator` to all connected frontends. +/// +/// Original Python definitions: +/// +/// +#[derive(Debug, Clone, PartialEq, Deserialize_tuple)] +struct CoordinatorStateUpdate { + /// Global per-engine request counts published by the coordinator. + /// + /// The Rust bootstrapped external-coordinator path preserves this field for + /// wire compatibility but intentionally ignores it for routing decisions. + counts: OpaqueValue, + /// Current global DP wave number stamped onto newly admitted requests. + wave: u32, + /// Whether engines are currently running (`true`) or paused (`false`). + engines_running: bool, +} + +/// Background half of an external Python-owned coordinator connection. +/// +/// This owns the command receiver and one frontend-facing XSUB socket. It +/// mirrors the subset of Python's coordinator protocol needed by the Rust +/// bootstrapped frontend: receive `(counts, wave, running)` publishes, ignore +/// `counts`, and send `(exclude_engine_index, wave)` wakeup messages when the +/// first request arrives while engines are paused. +pub(crate) struct ExternalCoordinatorService { + state: Arc, + command_rx: mpsc::UnboundedReceiver, + socket: XSubSocket, +} + +impl ExternalCoordinatorService { + pub(super) fn new( + state: Arc, + command_rx: mpsc::UnboundedReceiver, + socket: XSubSocket, + ) -> Self { + Self { + state, + command_rx, + socket, + } + } + + /// Apply one frontend-originated command to the external coordinator state + /// machine. + async fn handle_command(&mut self, command: CoordinatorCommand) -> Result<()> { + match command { + CoordinatorCommand::FirstRequest { + target_engine_id, + wave, + } => { + let target_engine_index = target_engine_id.engine_index().ok_or_else(|| { + Error::UnsupportedCoordinatorEngineId { + engine_id: target_engine_id.to_vec(), + } + })?; + debug!( + wave, + exclude_engine_index = target_engine_index, + "notifying external coordinator about first request while engines were paused" + ); + let payload = encode_msgpack(&CoordinatorWakeupMessage { + exclude_engine_index: target_engine_index, + wave, + })?; + self.socket.send(ZmqMessage::from(payload)).await?; + } + } + Ok(()) + } + + /// Apply one publish received from the xsub socket containing a coordinator + /// state update. + async fn handle_publish(&mut self, message: ZmqMessage) -> Result<()> { + let frames = message.into_vec(); + if frames.len() != 1 { + bail_unexpected_coordinator_output!( + "received malformed external coordinator publish with {} frame(s)", + frames.len() + ); + } + + let update: CoordinatorStateUpdate = decode_msgpack(&frames[0])?; + + let mut state = self.state.lock(); + let previous_wave = state.current_wave; + let previous_engines_running = state.engines_running; + state.current_wave = update.wave; + state.engines_running = update.engines_running; + debug!( + previous_wave, + wave = update.wave, + previous_engines_running, + engines_running = update.engines_running, + "applied external coordinator state update" + ); + Ok(()) + } + + /// Drive the coordinator event loop until either side of the control plane + /// is closed or a fatal error is observed. + pub(crate) async fn run(mut self, inner: Arc) { + let result: Result<()> = async { + loop { + tokio::select! { + // Received frontend-originated command from the handle. + command = self.command_rx.recv() => { + let Some(command) = command else { + warn!("external coordinator command channel closed, shutting down service"); + return Ok(()); + }; + self.handle_command(command).await?; + } + // Received publish from the external coordinator socket. + publish = self.socket.recv() => { + let publish = publish.map_err(Error::from)?; + self.handle_publish(publish).await?; + } + } + } + } + .await; + let Err(error) = result else { return }; + + warn!( + error = %error.as_report(), + "external coordinator service exiting with error" + ); + inner.close_registries(Arc::new(error)); + } +} diff --git a/rust/src/engine-core-client/src/coordinator/handle.rs b/rust/src/engine-core-client/src/coordinator/handle.rs new file mode 100644 index 00000000000..dca6f70de1f --- /dev/null +++ b/rust/src/engine-core-client/src/coordinator/handle.rs @@ -0,0 +1,123 @@ +use std::sync::Arc; + +use parking_lot::Mutex; +use tokio::sync::mpsc; +use zeromq::prelude::Socket; +use zeromq::{XPubSocket, XSubSocket}; + +use crate::coordinator::external::ExternalCoordinatorService; +use crate::coordinator::inproc::InProcCoordinatorRunner; +use crate::error::{Error, Result, bail_control_closed}; +use crate::transport::EngineId; + +/// Snapshot to the coordinator state for request routing and stamping. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct CoordinatorStateSnapshot { + /// The current DP wave, which will be stamped on outgoing requests. + pub current_wave: u32, + /// Whether the engines are currently running or paused, which determines if + /// the frontend must trigger a new wave on the next request. + pub engines_running: bool, +} + +/// Shared in-process coordinator state. +pub(crate) type CoordinatorState = Mutex; + +/// Commands sent from the frontend request path into the background runner. +#[derive(Debug)] +pub(crate) enum CoordinatorCommand { + /// The first request arrived while all engines were paused. + /// + /// The coordinator should broadcast `START_DP_WAVE` with the current wave + /// and the target engine index as the excluded engine. + FirstRequest { + target_engine_id: EngineId, + wave: u32, + }, +} + +/// Frontend-facing coordinator handle used by `EngineCoreClient::call()`. +/// +/// This side stays intentionally small: it can read the latest wave snapshot +/// and enqueue a `FirstRequest` transition when the request path observes the +/// system in the paused state. +#[derive(Clone)] +pub(crate) struct CoordinatorHandle { + state: Arc, + command_tx: mpsc::UnboundedSender, +} + +impl CoordinatorHandle { + fn new_parts() -> ( + Self, + Arc, + mpsc::UnboundedReceiver, + ) { + let state = Arc::new(Mutex::new(CoordinatorStateSnapshot { + current_wave: 0, + engines_running: false, + })); + let (command_tx, command_rx) = mpsc::unbounded_channel(); + ( + Self { + state: state.clone(), + command_tx, + }, + state, + command_rx, + ) + } + + /// Build the paired frontend handle and background runner around one + /// engine-facing coordinator broadcast socket. + pub(crate) fn new_inproc(coordinator_input: XPubSocket) -> (Self, InProcCoordinatorRunner) { + let (handle, state, command_rx) = Self::new_parts(); + ( + handle, + InProcCoordinatorRunner::new(state, command_rx, coordinator_input), + ) + } + + /// Build the paired frontend handle and background service around an + /// external Python-owned frontend-side coordinator socket. + pub(crate) async fn connect_external( + coordinator_address: &str, + ) -> Result<(Self, ExternalCoordinatorService)> { + let (handle, state, command_rx) = Self::new_parts(); + let mut socket = XSubSocket::new(); + socket.connect(coordinator_address).await?; + socket.subscribe("").await?; + Ok(( + handle, + ExternalCoordinatorService::new(state, command_rx, socket), + )) + } + + /// Snapshot the coordinator state for request routing and stamping. + pub(crate) fn snapshot(&self) -> CoordinatorStateSnapshot { + *self.state.lock() + } + + /// Notify the runner that a new request arrived while engines were paused. + /// + /// The handle flips `engines_running` optimistically so concurrent request + /// submissions coalesce behind one `START_DP_WAVE` broadcast instead of all + /// trying to trigger the wave independently. + pub(crate) fn notify_first_request(&self, target_engine_id: EngineId) -> Result<()> { + let mut state = self.state.lock(); + if state.engines_running { + return Ok(()); + } + + let command = CoordinatorCommand::FirstRequest { + target_engine_id, + wave: state.current_wave, + }; + if self.command_tx.send(command).is_err() { + bail_control_closed!("in-process coordinator command channel already shut down"); + } + + state.engines_running = true; + Ok(()) + } +} diff --git a/rust/src/engine-core-client/src/coordinator/inproc.rs b/rust/src/engine-core-client/src/coordinator/inproc.rs new file mode 100644 index 00000000000..26ab0c5a0e7 --- /dev/null +++ b/rust/src/engine-core-client/src/coordinator/inproc.rs @@ -0,0 +1,204 @@ +use std::sync::Arc; + +use serde_tuple::{Deserialize_tuple, Serialize_tuple}; +use thiserror_ext::AsReport; +use tokio::sync::mpsc; +use tracing::{debug, warn}; +use zeromq::prelude::SocketSend; +use zeromq::{XPubSocket, ZmqMessage}; + +use crate::client::imp::ClientInner; +use crate::coordinator::handle::{CoordinatorCommand, CoordinatorState}; +use crate::error::{Error, Result, bail_unexpected_coordinator_output}; +use crate::protocol::{ + ClassifiedEngineCoreOutputs, DpControlMessage, EngineCoreOutputs, EngineCoreRequestType, + encode_msgpack, +}; + +/// Coordinator-to-engine `START_DP_WAVE` control payload encoded on the +/// engine-facing coordinator socket. +/// +/// This matches the msgpack tuple broadcast by Python +/// `DPCoordinatorProc._send_start_wave`. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize_tuple, Deserialize_tuple)] +struct StartDpWaveMessage { + /// DP wave number that all engines should start processing. + wave: u32, + /// Engine index that already received the triggering request and should not + /// receive an extra wakeup notification. + exclude_engine_index: u32, +} + +/// Background half of the in-process coordinator. +/// +/// This owns the command receiver and the engine-facing coordinator input +/// socket. It is the single place where wave transitions are serialized and +/// where `START_DP_WAVE` broadcasts are emitted. +pub(crate) struct InProcCoordinatorRunner { + state: Arc, + command_rx: mpsc::UnboundedReceiver, + coordinator_input: XPubSocket, +} + +impl InProcCoordinatorRunner { + pub(super) fn new( + state: Arc, + command_rx: mpsc::UnboundedReceiver, + coordinator_input: XPubSocket, + ) -> Self { + Self { + state, + command_rx, + coordinator_input, + } + } + + /// Broadcast Python-compatible `START_DP_WAVE` to all connected engines. + async fn broadcast_start_wave(&mut self, wave: u32, exclude_engine_index: u32) -> Result<()> { + let payload = encode_msgpack(&StartDpWaveMessage { + wave, + exclude_engine_index, + })?; + self.coordinator_input + .send( + ZmqMessage::try_from(vec![ + EngineCoreRequestType::StartDpWave.to_frame(), + payload.into(), + ]) + .expect("coordinator START_DP_WAVE message must contain two frames"), + ) + .await?; + Ok(()) + } + + /// Apply one frontend-originated command to the coordinator state machine. + async fn handle_command(&mut self, command: CoordinatorCommand) -> Result<()> { + match command { + CoordinatorCommand::FirstRequest { + target_engine_id, + wave, + } => { + let target_engine_index = target_engine_id.engine_index().ok_or_else(|| { + Error::UnsupportedCoordinatorEngineId { + engine_id: target_engine_id.to_vec(), + } + })?; + self.state.lock().current_wave = wave; + debug!( + wave, + exclude_engine_index = target_engine_index, + "starting DP wave after first request while engines were paused" + ); + self.broadcast_start_wave(wave, target_engine_index).await?; + } + } + Ok(()) + } + + /// Apply one engine-originated control output to the coordinator state + /// machine. + async fn handle_outputs(&mut self, outputs: EngineCoreOutputs) -> Result<()> { + match outputs.classify() { + ClassifiedEngineCoreOutputs::RequestBatch(batch) + if batch.outputs.is_empty() && batch.finished_requests.is_none() => + { + // Stats-only output for coordinator. + // Ignore since the Rust coordinator doesn't track stats for + // routing decisions. + } + ClassifiedEngineCoreOutputs::DpControl { + engine_index, + control, + .. + } => match control { + // The engines signals they completed the current wave and are now paused. + // Advance the current wave and mark the state as paused. + DpControlMessage::WaveComplete(wave) => { + let mut state = self.state.lock(); + if wave >= state.current_wave { + let next_wave = wave + 1; + debug!( + wave, + next_wave, + "DP wave finished; pausing engines and advancing coordinator state" + ); + state.current_wave = wave + 1; + state.engines_running = false; + } + } + // An engine requests to start the wave. + // Rebroadcast the wave to all engines except for the originated one. + DpControlMessage::StartWave(wave) => { + let should_broadcast = { + let mut state = self.state.lock(); + if wave > state.current_wave + || (wave == state.current_wave && !state.engines_running) + { + state.current_wave = wave; + state.engines_running = true; + true + } else { + false + } + }; + if should_broadcast { + debug!( + wave, + exclude_engine_index = engine_index, + "starting DP wave after stale-wave notification from engine" + ); + self.broadcast_start_wave(wave, engine_index).await?; + } + } + }, + other => { + bail_unexpected_coordinator_output!( + "received non-control output on coordinator path: {other:?}" + ); + } + } + Ok(()) + } + + /// Drive the coordinator event loop until either side of the control plane + /// is closed or a fatal error is observed. + /// + /// Any fatal error closes the main client registries so request streams and + /// future calls observe a stable shutdown cause. + pub(crate) async fn run( + mut self, + mut output_rx: mpsc::Receiver>, + inner: Arc, + ) { + let result: Result<()> = async { + loop { + tokio::select! { + // Received frontend-originated command from the handle. + command = self.command_rx.recv() => { + let Some(command) = command else { + warn!("coordinator command channel closed, shutting down coordinator runner"); + return Ok(()); + }; + self.handle_command(command).await?; + } + // Received engine-originated control output from the coordinator socket. + outputs = output_rx.recv() => { + let Some(outputs) = outputs else { + warn!("coordinator output channel closed, shutting down coordinator runner"); + return Ok(()); + }; + self.handle_outputs(outputs?).await?; + } + } + } + } + .await; + let Err(error) = result else { return }; + + warn!(error = %error.as_report(), "coordinator runner exiting with error"); + inner.close_registries(Arc::new(error)); + } +} diff --git a/rust/src/engine-core-client/src/coordinator/mod.rs b/rust/src/engine-core-client/src/coordinator/mod.rs new file mode 100644 index 00000000000..e4740a3739f --- /dev/null +++ b/rust/src/engine-core-client/src/coordinator/mod.rs @@ -0,0 +1,7 @@ +mod bootstrap; +mod external; +mod handle; +mod inproc; + +pub(crate) use bootstrap::CoordinatorBootstrap; +pub(crate) use handle::CoordinatorHandle; diff --git a/rust/src/engine-core-client/src/error.rs b/rust/src/engine-core-client/src/error.rs new file mode 100644 index 00000000000..d9022bbb089 --- /dev/null +++ b/rust/src/engine-core-client/src/error.rs @@ -0,0 +1,88 @@ +use std::sync::Arc; +use std::time::Duration; + +use thiserror::Error; +use thiserror_ext::Macro; + +pub type Result = std::result::Result; + +/// Public error type for the Rust engine-core client. +#[derive(Debug, Error, Macro)] +pub enum Error { + #[error("messagepack encode failed for {target_type}: {message}")] + Encode { + target_type: &'static str, + message: String, + }, + #[error("messagepack decode failed for {target_type}: {message}")] + Decode { + target_type: &'static str, + message: String, + }, + #[error("messagepack value decode failed")] + ValueDecode(#[from] rmpv::decode::Error), + #[error("messagepack ext value decode failed: {message}")] + ExtValueDecode { message: String }, + #[error("io error")] + Io(#[from] std::io::Error), + #[error("transport error")] + Transport(#[from] zeromq::ZmqError), + #[error("engine core reported fatal failure")] + EngineCoreDead, + #[error("startup handshake timed out while waiting for {stage} after {timeout:?}")] + HandshakeTimeout { + stage: &'static str, + timeout: Duration, + }, + #[error("engine input registration timed out after {timeout:?}")] + InputRegistrationTimeout { timeout: Duration }, + #[error("unexpected engine id in startup handshake: expected {expected:?}, got {actual:?}")] + UnexpectedHandshakeIdentity { expected: Vec, actual: Vec }, + #[error("unexpected startup handshake message: {message}")] + UnexpectedHandshakeMessage { message: String }, + #[error("unexpected non-control output on coordinator path: {message}")] + UnexpectedCoordinatorOutput { message: String }, + #[error("unexpected output on main dispatcher path: {message}")] + UnexpectedDispatcherOutput { message: String }, + #[error("coordinator requires a Python-compatible two-byte engine id, got {engine_id:?}")] + UnsupportedCoordinatorEngineId { engine_id: Vec }, + #[error("unsupported auxiliary frame(s): expected 1 frame, got {frame_count}")] + UnsupportedAuxFrames { frame_count: usize }, + #[error("external coordinator mode is not implemented yet")] + UnsupportedExternalCoordinator, + #[error("unsupported field `{field}` in {context}")] + UnsupportedField { + context: &'static str, + field: &'static str, + }, + #[error("engine control channel closed unexpectedly: {message}")] + ControlClosed { message: String }, + #[error("request `{request_id}` is already in flight")] + DuplicateRequestId { request_id: String }, + #[error("data parallel rank {rank} is out of range for {num_engines} engine(s)")] + InvalidDataParallelRank { rank: u32, num_engines: u32 }, + #[error("engine-core output dispatcher closed: {message}")] + DispatcherClosed { message: String }, + #[error("engine-core client is closed: {message}")] + ClientClosed { message: String }, + #[error("request output stream for `{request_id}` closed unexpectedly")] + RequestStreamClosed { request_id: String }, + #[error("utility call `{method}` failed (call_id={call_id}): {message}")] + UtilityCallFailed { + method: String, + call_id: i64, + message: String, + }, + #[error("utility call `{method}` returned an invalid result (call_id={call_id}): {message}")] + UtilityResultDecode { + method: String, + call_id: i64, + message: String, + }, + #[error("utility call `{method}` closed unexpectedly (call_id={call_id})")] + UtilityCallClosed { method: String, call_id: i64 }, + + /// A special variant to allow cloning the same error. + #[error(transparent)] + Shared(Arc), +} diff --git a/rust/src/engine-core-client/src/lib.rs b/rust/src/engine-core-client/src/lib.rs new file mode 100644 index 00000000000..914ce874e9e --- /dev/null +++ b/rust/src/engine-core-client/src/lib.rs @@ -0,0 +1,18 @@ +mod client; +mod coordinator; +mod error; +mod metrics; +pub mod protocol; +#[cfg(any(test, feature = "test-util"))] +pub mod test_utils; +mod transport; + +pub use client::{ + AbortCause, CoordinatorMode, EngineCoreClient, EngineCoreClientConfig, EngineCoreOutputStream, + EngineCoreStreamOutput, TransportMode, +}; +pub use error::{Error, Result}; +pub use transport::{ENGINE_CORE_DEAD_SENTINEL, EngineId}; + +#[cfg(test)] +mod tests; diff --git a/rust/src/engine-core-client/src/metrics.rs b/rust/src/engine-core-client/src/metrics.rs new file mode 100644 index 00000000000..8f459396198 --- /dev/null +++ b/rust/src/engine-core-client/src/metrics.rs @@ -0,0 +1,131 @@ +use vllm_metrics::{EngineLabels, EnginePositionLabels, SchedulerMetrics, WaitingReasonLabels}; + +use crate::protocol::stats::SchedulerStats; + +const WAITING_REASON_CAPACITY: &str = "capacity"; +const WAITING_REASON_DEFERRED: &str = "deferred"; + +/// Record the scheduler-stats-backed metrics for one engine at one point in +/// time. +pub(crate) fn record_scheduler_stats( + metrics: &SchedulerMetrics, + model_name: impl Into, + engine: u32, + stats: &SchedulerStats, +) { + let model_name = model_name.into(); + let labels = EngineLabels { + model_name: model_name.clone(), + engine, + }; + + // Scheduler state gauges. + metrics.scheduler_running.get_or_create(&labels).set(stats.num_running_reqs); + metrics + .scheduler_waiting + .get_or_create(&labels) + .set(stats.num_waiting_reqs + stats.num_skipped_waiting_reqs); + metrics + .scheduler_waiting_by_reason + .get_or_create(&WaitingReasonLabels { + model_name: model_name.clone(), + engine, + reason: WAITING_REASON_CAPACITY, + }) + .set(stats.num_waiting_reqs); + metrics + .scheduler_waiting_by_reason + .get_or_create(&WaitingReasonLabels { + model_name: model_name.clone(), + engine, + reason: WAITING_REASON_DEFERRED, + }) + .set(stats.num_skipped_waiting_reqs); + metrics.kv_cache_usage.get_or_create(&labels).set(stats.kv_cache_usage); + + // Prefix-cache counters, including the connector-backed external cache path. + metrics + .prefix_cache_queries + .get_or_create(&labels) + .inc_by(stats.prefix_cache_stats.base.queries); + metrics + .prefix_cache_hits + .get_or_create(&labels) + .inc_by(stats.prefix_cache_stats.base.hits); + + if let Some(connector_prefix_cache_stats) = &stats.connector_prefix_cache_stats { + metrics + .external_prefix_cache_queries + .get_or_create(&labels) + .inc_by(connector_prefix_cache_stats.base.queries); + metrics + .external_prefix_cache_hits + .get_or_create(&labels) + .inc_by(connector_prefix_cache_stats.base.hits); + } + + // Speculative decoding counters. + if let Some(spec_decoding_stats) = &stats.spec_decoding_stats { + metrics + .spec_decode_num_drafts + .get_or_create(&labels) + .inc_by(spec_decoding_stats.num_drafts); + metrics + .spec_decode_num_draft_tokens + .get_or_create(&labels) + .inc_by(spec_decoding_stats.num_draft_tokens); + metrics + .spec_decode_num_accepted_tokens + .get_or_create(&labels) + .inc_by(spec_decoding_stats.num_accepted_tokens); + + for (position, accepted_tokens) in + spec_decoding_stats.num_accepted_tokens_per_pos.iter().copied().enumerate() + { + metrics + .spec_decode_num_accepted_tokens_per_pos + .get_or_create(&EnginePositionLabels { + model_name: model_name.clone(), + engine, + position: position as u32, + }) + .inc_by(accepted_tokens); + } + } + + // Per-engine performance / MFU counters. + if let Some(perf_stats) = &stats.perf_stats + && (perf_stats.num_flops_per_gpu != 0 + || perf_stats.num_read_bytes_per_gpu != 0 + || perf_stats.num_write_bytes_per_gpu != 0) + { + metrics + .estimated_flops_per_gpu + .get_or_create(&labels) + .inc_by(perf_stats.num_flops_per_gpu); + metrics + .estimated_read_bytes_per_gpu + .get_or_create(&labels) + .inc_by(perf_stats.num_read_bytes_per_gpu); + metrics + .estimated_write_bytes_per_gpu + .get_or_create(&labels) + .inc_by(perf_stats.num_write_bytes_per_gpu); + } + + // Sampled KV-cache residency histograms. + if !stats.kv_cache_eviction_events.is_empty() { + let kv_block_lifetime_seconds = metrics.kv_block_lifetime_seconds.get_or_create(&labels); + let kv_block_idle_before_evict_seconds = + metrics.kv_block_idle_before_evict_seconds.get_or_create(&labels); + let kv_block_reuse_gap_seconds = metrics.kv_block_reuse_gap_seconds.get_or_create(&labels); + + for event in &stats.kv_cache_eviction_events { + kv_block_lifetime_seconds.observe(event.lifetime_seconds); + kv_block_idle_before_evict_seconds.observe(event.idle_seconds); + for reuse_gap_seconds in &event.reuse_gaps_seconds { + kv_block_reuse_gap_seconds.observe(*reuse_gap_seconds); + } + } + } +} diff --git a/rust/src/engine-core-client/src/protocol/classified_outputs.rs b/rust/src/engine-core-client/src/protocol/classified_outputs.rs new file mode 100644 index 00000000000..4e7b88d3691 --- /dev/null +++ b/rust/src/engine-core-client/src/protocol/classified_outputs.rs @@ -0,0 +1,251 @@ +use std::collections::BTreeSet; + +use enum_as_inner::EnumAsInner; + +use super::{EngineCoreOutput, EngineCoreOutputs, UtilityOutput}; +use crate::protocol::stats::SchedulerStats; + +/// Data-parallel control notifications multiplexed through `EngineCoreOutputs`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum DpControlMessage { + WaveComplete(u32), + StartWave(u32), +} + +#[derive(Debug, Clone, PartialEq)] +pub struct RequestBatchOutputs { + pub engine_index: u32, + pub outputs: Vec, + pub scheduler_stats: Option>, + pub timestamp: f64, + pub finished_requests: Option>, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct UtilityCallOutput { + pub engine_index: u32, + pub timestamp: f64, + pub output: UtilityOutput, +} + +/// Semantic classification of a raw `EngineCoreOutputs` message. +/// +/// Python currently uses one product-shaped wire struct for several distinct +/// output families. This enum exposes those families more explicitly without +/// changing the wire format. +#[derive(Debug, Clone, PartialEq, EnumAsInner)] +pub enum ClassifiedEngineCoreOutputs { + RequestBatch(RequestBatchOutputs), + Utility(UtilityCallOutput), + DpControl { + engine_index: u32, + timestamp: f64, + control: DpControlMessage, + }, + /// Fallback for wire-shape combinations that do not map cleanly onto the + /// current semantic families. + Other(EngineCoreOutputs), +} + +impl EngineCoreOutputs { + /// Classify the raw wire message into a more semantic Rust enum. + pub fn classify(self) -> ClassifiedEngineCoreOutputs { + let has_request_payload = !self.outputs.is_empty() + || self.scheduler_stats.is_some() + || self.finished_requests.is_some(); + + match ( + has_request_payload, + &self.utility_output, + &self.wave_complete, + &self.start_wave, + ) { + (true, None, None, None) => { + ClassifiedEngineCoreOutputs::RequestBatch(RequestBatchOutputs { + engine_index: self.engine_index, + outputs: self.outputs, + scheduler_stats: self.scheduler_stats, + timestamp: self.timestamp, + finished_requests: self.finished_requests, + }) + } + (false, Some(_), None, None) => { + ClassifiedEngineCoreOutputs::Utility(UtilityCallOutput { + engine_index: self.engine_index, + timestamp: self.timestamp, + output: self.utility_output.unwrap(), + }) + } + (false, None, Some(_), None) => ClassifiedEngineCoreOutputs::DpControl { + engine_index: self.engine_index, + timestamp: self.timestamp, + control: DpControlMessage::WaveComplete(self.wave_complete.unwrap()), + }, + (false, None, None, Some(_)) => ClassifiedEngineCoreOutputs::DpControl { + engine_index: self.engine_index, + timestamp: self.timestamp, + control: DpControlMessage::StartWave(self.start_wave.unwrap()), + }, + _ => ClassifiedEngineCoreOutputs::Other(self), + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + + use super::*; + use crate::protocol::EngineCoreOutput; + + #[test] + fn engine_core_outputs_classify_request_batch() { + let outputs = EngineCoreOutputs { + outputs: vec![EngineCoreOutput { + request_id: "req-1".to_string(), + new_token_ids: vec![7], + ..Default::default() + }], + finished_requests: Some(BTreeSet::from(["req-1".to_string()])), + ..Default::default() + }; + + expect_test::expect![[r#" + RequestBatch( + RequestBatchOutputs { + engine_index: 0, + outputs: [ + EngineCoreOutput { + request_id: "req-1", + new_token_ids: [ + 7, + ], + new_logprobs: None, + new_prompt_logprobs_tensors: None, + pooling_output: None, + finish_reason: None, + stop_reason: None, + events: None, + kv_transfer_params: None, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + }, + ], + scheduler_stats: None, + timestamp: 0.0, + finished_requests: Some( + { + "req-1", + }, + ), + }, + ) + "#]] + .assert_debug_eq(&outputs.classify()); + } + + #[test] + fn engine_core_outputs_classify_utility() { + let outputs = EngineCoreOutputs { + utility_output: Some(UtilityOutput { + call_id: 42, + failure_message: None, + result: None, + }), + ..Default::default() + }; + + expect_test::expect![[r#" + Utility( + UtilityCallOutput { + engine_index: 0, + timestamp: 0.0, + output: UtilityOutput { + call_id: 42, + failure_message: None, + result: None, + }, + }, + ) + "#]] + .assert_debug_eq(&outputs.classify()); + } + + #[test] + fn engine_core_outputs_classify_control() { + let outputs = EngineCoreOutputs { + start_wave: Some(3), + ..Default::default() + }; + + expect_test::expect![[r#" + DpControl { + engine_index: 0, + timestamp: 0.0, + control: StartWave( + 3, + ), + } + "#]] + .assert_debug_eq(&outputs.classify()); + } + + #[test] + fn engine_core_outputs_classify_mixed_shape_as_raw() { + let outputs = EngineCoreOutputs { + outputs: vec![EngineCoreOutput { + request_id: "req-1".to_string(), + new_token_ids: vec![7], + ..Default::default() + }], + utility_output: Some(UtilityOutput { + call_id: 1, + failure_message: None, + result: None, + }), + ..Default::default() + }; + + expect_test::expect![[r#" + Other( + EngineCoreOutputs { + engine_index: 0, + outputs: [ + EngineCoreOutput { + request_id: "req-1", + new_token_ids: [ + 7, + ], + new_logprobs: None, + new_prompt_logprobs_tensors: None, + pooling_output: None, + finish_reason: None, + stop_reason: None, + events: None, + kv_transfer_params: None, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + }, + ], + scheduler_stats: None, + timestamp: 0.0, + utility_output: Some( + UtilityOutput { + call_id: 1, + failure_message: None, + result: None, + }, + ), + finished_requests: None, + wave_complete: None, + start_wave: None, + }, + ) + "#]] + .assert_debug_eq(&outputs.classify()); + } +} diff --git a/rust/src/engine-core-client/src/protocol/dtype.rs b/rust/src/engine-core-client/src/protocol/dtype.rs new file mode 100644 index 00000000000..081d1ce921c --- /dev/null +++ b/rust/src/engine-core-client/src/protocol/dtype.rs @@ -0,0 +1,40 @@ +use serde::{Deserialize, Serialize}; + +/// Effective model dtype reported by the engine after config resolution. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum ModelDtype { + #[serde(rename = "float16")] + Float16, + #[serde(rename = "bfloat16")] + BFloat16, + #[serde(rename = "float32")] + Float32, +} + +impl ModelDtype { + pub fn as_str(self) -> &'static str { + match self { + Self::Float16 => "float16", + Self::BFloat16 => "bfloat16", + Self::Float32 => "float32", + } + } +} + +#[cfg(test)] +mod tests { + use super::ModelDtype; + + #[test] + fn serde_uses_protocol_dtype_strings() { + assert_eq!( + serde_json::to_value(ModelDtype::Float16).unwrap(), + serde_json::json!("float16") + ); + assert_eq!( + serde_json::from_value::(serde_json::json!("bfloat16")).unwrap(), + ModelDtype::BFloat16 + ); + assert_eq!(ModelDtype::Float32.as_str(), "float32"); + } +} diff --git a/rust/src/engine-core-client/src/protocol/handshake.rs b/rust/src/engine-core-client/src/protocol/handshake.rs new file mode 100644 index 00000000000..622a032772e --- /dev/null +++ b/rust/src/engine-core-client/src/protocol/handshake.rs @@ -0,0 +1,90 @@ +use std::collections::BTreeMap; + +use serde::{Deserialize, Serialize}; + +use crate::protocol::{ModelDtype, OpaqueValue}; + +/// Decoded engine startup-handshake payload sent on the handshake socket. +/// +/// Original Python payload construction: +/// +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct ReadyMessage { + #[serde(default)] + pub status: Option, + #[serde(default)] + pub local: Option, + #[serde(default)] + pub headless: Option, + #[serde(default)] + pub parallel_config_hash: Option, +} + +/// Post-initialization configuration sent from each engine on the input socket +/// registration message, after the handshake completes. +/// +/// Contains values that may differ from the original config (e.g. +/// `max_model_len` after KV cache auto-fitting, `num_gpu_blocks` after +/// profiling). +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EngineCoreReadyResponse { + /// Engine-reported maximum model context length (auto-fitted after + /// KV cache profiling and may differ from the original config value). + pub max_model_len: u64, + /// Number of GPU blocks available for KV cache on this engine. + pub num_gpu_blocks: u64, + /// DP coordinator stats publish address, if applicable. + pub dp_stats_address: Option, + /// Effective model dtype after Python vLLM resolves `--dtype`. + // TODO: This is currently not wired up on the engine side. After it's added, remove `Option` + // and `serde(default)`. + #[serde(default)] + pub dtype: Option, +} + +/// Frontend-owned ZMQ addresses that are sent to the engine during startup +/// handshake initialization. +/// +/// Original Python definition (`EngineZmqAddresses`): +/// +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HandshakeAddresses { + pub inputs: Vec, + pub outputs: Vec, + pub coordinator_input: Option, + pub coordinator_output: Option, + pub frontend_stats_publish_address: Option, +} + +/// Startup handshake payload sent from the frontend to initialize an engine +/// after receiving `HELLO`. +/// +/// Original Python definition (`EngineHandshakeMetadata`): +/// +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HandshakeInitMessage { + pub addresses: HandshakeAddresses, + pub parallel_config: BTreeMap, +} + +#[cfg(test)] +mod tests { + use super::EngineCoreReadyResponse; + use crate::protocol::ModelDtype; + + #[test] + fn ready_response_accepts_effective_dtype() { + let response: EngineCoreReadyResponse = serde_json::from_value(serde_json::json!({ + "max_model_len": 4096, + "num_gpu_blocks": 2, + "dp_stats_address": null, + "dtype": "bfloat16" + })) + .unwrap(); + + assert_eq!(response.dtype, Some(ModelDtype::BFloat16)); + } +} diff --git a/rust/src/engine-core-client/src/protocol/logprobs.rs b/rust/src/engine-core-client/src/protocol/logprobs.rs new file mode 100644 index 00000000000..00c01df671c --- /dev/null +++ b/rust/src/engine-core-client/src/protocol/logprobs.rs @@ -0,0 +1,330 @@ +mod array; +#[cfg(test)] +mod tests; +mod wire; + +use std::ops::{Deref, DerefMut}; + +use enum_as_inner::EnumAsInner; +use serde::{Deserialize, Deserializer, Serialize}; + +use self::wire::*; +use super::{EngineCoreOutput, EngineCoreOutputs, decode_msgpack}; +use crate::error::{Error, Result, bail_ext_value_decode, ext_value_decode}; +use crate::protocol::tensor::{WireArrayData, WireNdArray}; + +/// One token candidate and its logprob metadata for a single sequence position. +/// +/// The first entry in a [`PositionLogprobs`] is always the sampled/selected +/// token for that position. Any remaining entries follow the engine's returned +/// top-k candidate order. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct TokenLogprob { + pub token_id: u32, + pub logprob: f32, + /// The sampled/selected token uses its actual vocab rank. Remaining entries + /// use 1-based top-k ranks matching the engine's returned candidate + /// order. + pub rank: u32, +} + +/// Logprob payload for one sequence position. +/// +/// This is the semantic Rust representation used by the public client API after +/// the lower-level ndarray/tensor wire payload has been decoded. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PositionLogprobs { + pub entries: Vec, +} + +impl PositionLogprobs { + /// Convert one decoded logprobs row into this per-position form by grouping + /// each token/logprob pair together with the sampled/selected token's + /// actual vocab rank. + fn from_decoded_row(token_ids: &[u32], logprobs: &[f32], sampled_rank: u32) -> Result { + if token_ids.len() != logprobs.len() { + bail_ext_value_decode!( + "logprobs row length mismatch: token_ids={}, logprobs={}", + token_ids.len(), + logprobs.len() + ); + } + if sampled_rank == 0 { + bail_ext_value_decode!("token_ranks must be >= 1 for decoded engine-core logprobs"); + } + + let mut entries = Vec::with_capacity(token_ids.len()); + for (index, (&token_id, &logprob)) in token_ids.iter().zip(logprobs.iter()).enumerate() { + let rank = if index == 0 { + sampled_rank + } else { + index as u32 + }; + entries.push(TokenLogprob { + token_id, + logprob, + rank, + }); + } + Ok(Self { entries }) + } +} + +/// Decoded per-request logprobs payload for one engine-core output. +/// +/// Unlike the Python wire payload, this public Rust type is already fully +/// semantic: one [`PositionLogprobs`] per scored position, each containing the +/// sampled/selected token plus any returned top-k alternatives for that same +/// position. +/// +/// The Python engine still sends logprobs as ndarray/tensor-shaped wire tuples. +/// Rust resolves that lower-level representation during decode and exposes only +/// this per-position form to callers. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct Logprobs { + /// One decoded logprobs record per scored position in this engine-core + /// output. + pub positions: Vec, +} + +impl Logprobs { + /// Returns the number of scored positions in this payload. + pub fn len(&self) -> usize { + self.positions.len() + } + + /// Returns whether the payload contains no scored positions. + pub fn is_empty(&self) -> bool { + self.positions.is_empty() + } +} + +/// Output field wrapper that is initially deserialized from the Python wire +/// shape, then resolved into [`Logprobs`] before the decoded message is +/// returned to callers. +#[derive(Clone, PartialEq, Debug, EnumAsInner)] +pub enum MaybeWireLogprobs { + /// The logprobs are still in the wire format and need to be resolved by + /// looking up aux frames and decoding raw views. Should only be used + /// internally during deserialization. + Wire(Box), + /// The actual decoded logprobs value, + Direct(Logprobs), +} + +impl Deref for MaybeWireLogprobs { + type Target = Logprobs; + + fn deref(&self) -> &Self::Target { + match self { + Self::Wire(_) => panic!("Logprobs is still in wire format"), + Self::Direct(value) => value, + } + } +} + +impl DerefMut for MaybeWireLogprobs { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + Self::Wire(_) => panic!("Logprobs is still in wire format"), + Self::Direct(value) => value, + } + } +} + +impl<'de> Deserialize<'de> for MaybeWireLogprobs { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'de>, + { + // When deserializing, it's always in the wire form. + WireLogprobs::deserialize(deserializer).map(|v| Self::Wire(Box::new(v))) + } +} + +impl Serialize for MaybeWireLogprobs { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + // For testing purposes only. We don't actually serialize it into aux frames. + match self { + Self::Wire(value) => value.serialize(serializer), + Self::Direct(value) => WireLogprobs::from_direct(value) + .map_err(serde::ser::Error::custom)? + .serialize(serializer), + } + } +} + +impl MaybeWireLogprobs { + /// Resolve the wire representation into decoded logprobs by looking up aux + /// frames and decoding raw views as needed. + fn resolve(self, frames: &[Frame], field_prefix: &str) -> Result + where + Frame: AsRef<[u8]>, + { + match self { + Self::Direct(value) => Ok(Self::Direct(value)), + Self::Wire(value) => value.resolve(frames, field_prefix).map(Self::Direct), + } + } +} + +impl EngineCoreOutputs { + /// Resolve all wire-format fields in-place by looking up aux frames and + /// decoding raw-view payloads as needed. + fn resolve_in_place(&mut self, frames: &[Frame]) -> Result<()> + where + Frame: AsRef<[u8]>, + { + for output in &mut self.outputs { + output.resolve_in_place(frames)?; + } + Ok(()) + } +} + +impl EngineCoreOutput { + /// Resolve all wire-format fields in-place by looking up aux frames and + /// decoding raw-view payloads as needed. + fn resolve_in_place(&mut self, frames: &[Frame]) -> Result<()> + where + Frame: AsRef<[u8]>, + { + self.new_logprobs = (self.new_logprobs.take()) + .map(|value| value.resolve(frames, "new_logprobs")) + .transpose()?; + self.new_prompt_logprobs_tensors = (self.new_prompt_logprobs_tensors.take()) + .map(|value| value.resolve(frames, "new_prompt_logprobs_tensors")) + .transpose()?; + Ok(()) + } +} + +impl WireLogprobs { + /// Convert semantic per-position logprobs into the Python wire tuple shape. + /// + /// This exists mainly so Rust-side tests can inject semantic logprobs into + /// mocked engine-core outputs without manually building ndarray + /// raw-view tuples. + fn from_direct(value: &Logprobs) -> std::result::Result { + let rows = value.positions.len(); + let cols = value.positions.first().map(|position| position.entries.len()).unwrap_or(0); + + let mut token_ids = Vec::with_capacity(rows.saturating_mul(cols).saturating_mul(8)); + let mut logprobs = Vec::with_capacity(rows.saturating_mul(cols).saturating_mul(4)); + let mut token_ranks = Vec::with_capacity(rows.saturating_mul(8)); + + for (row_index, position) in value.positions.iter().enumerate() { + if position.entries.len() != cols { + return Err(format!( + "logprobs row {row_index} length mismatch: expected {cols}, got {}", + position.entries.len() + )); + } + let Some((sampled, _)) = position.entries.split_first() else { + return Err(format!("logprobs row {row_index} is empty")); + }; + + token_ranks.extend_from_slice(&(sampled.rank as i64).to_le_bytes()); + for entry in &position.entries { + token_ids.extend_from_slice(&(entry.token_id as i64).to_le_bytes()); + logprobs.extend_from_slice(&entry.logprob.to_le_bytes()); + } + } + + Ok(Self { + logprob_token_ids: WireNdArray { + dtype: "(self, frames: &[Frame], field_prefix: &str) -> Result + where + Frame: AsRef<[u8]>, + { + if let Some(indices) = self.cu_num_generated_tokens { + bail_ext_value_decode!( + "{field_prefix}.cu_num_generated_tokens: \ + expected None for per-request engine-core logprobs payload, got {indices:?}" + ); + } + + let token_ids = array::decode_array2_u32( + self.logprob_token_ids, + &format!("{field_prefix}.logprob_token_ids"), + frames, + )?; + let logprobs = + array::decode_array2_f32(self.logprobs, &format!("{field_prefix}.logprobs"), frames)?; + let token_ranks = array::decode_array1_u32( + self.token_ranks, + &format!("{field_prefix}.token_ranks"), + frames, + )?; + + if token_ids.rows != logprobs.rows || token_ids.cols != logprobs.cols { + bail_ext_value_decode!( + "{field_prefix}: row shape mismatch between token ids ({}, {}) and logprobs ({}, {})", + token_ids.rows, + token_ids.cols, + logprobs.rows, + logprobs.cols + ); + } + if token_ids.rows != token_ranks.len() { + bail_ext_value_decode!( + "{field_prefix}: token_ranks length {} does not match row count {}", + token_ranks.len(), + token_ids.rows + ); + } + + let mut positions = Vec::with_capacity(token_ids.rows); + for ((token_ids_row, logprobs_row), sampled_rank) in token_ids + .data + .chunks(token_ids.cols) + .zip(logprobs.data.chunks(logprobs.cols)) + .zip(token_ranks) + { + positions.push(PositionLogprobs::from_decoded_row( + token_ids_row, + logprobs_row, + sampled_rank, + )?); + } + + Ok(Logprobs { positions }) + } +} + +/// Decode one ordinary or multipart engine-core output message into the strong +/// typed public protocol shape. +pub fn decode_engine_core_outputs(frames: &[Frame]) -> Result +where + Frame: AsRef<[u8]>, +{ + let first_frame = frames.first().ok_or_else(|| ext_value_decode!("missing output frame"))?; + + let mut outputs: EngineCoreOutputs = decode_msgpack(first_frame.as_ref())?; + outputs.resolve_in_place(frames)?; + Ok(outputs) +} diff --git a/rust/src/engine-core-client/src/protocol/logprobs/array.rs b/rust/src/engine-core-client/src/protocol/logprobs/array.rs new file mode 100644 index 00000000000..132428ffcd8 --- /dev/null +++ b/rust/src/engine-core-client/src/protocol/logprobs/array.rs @@ -0,0 +1,306 @@ +use std::io::Cursor; + +use byteorder::{BigEndian, LittleEndian, NativeEndian, ReadBytesExt}; +use itertools::Itertools as _; + +use crate::error::{Error, Result, ext_value_decode}; +use crate::protocol::tensor::{ShapeExt as _, WireArrayData, WireNdArray}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum ScalarType { + I32, + I64, + F32, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum Endianness { + Little, + Big, + Native, +} + +#[derive(Debug, Clone, PartialEq)] +pub(super) struct DecodedArray2 { + pub rows: usize, + pub cols: usize, + pub data: Vec, +} + +pub(super) fn decode_array2_u32( + value: WireNdArray, + field: &str, + frames: &[Frame], +) -> Result> +where + Frame: AsRef<[u8]>, +{ + let (shape, bytes, scalar, endianness) = + decode_array_metadata(value, field, frames, &[ScalarType::I32, ScalarType::I64])?; + if shape.len() != 2 { + return Err(decode_error( + field, + &format!("expected rank-2 array, got rank {}", shape.len()), + )); + } + + let data = match scalar { + ScalarType::I32 => decode_i32_vec(&bytes, endianness, field)? + .into_iter() + .map(|value| convert_to_u32(value, field)) + .try_collect()?, + ScalarType::I64 => decode_i64_vec(&bytes, endianness, field)? + .into_iter() + .map(|value| convert_to_u32(value, field)) + .try_collect()?, + ScalarType::F32 => unreachable!("scalar validation should reject f32"), + }; + Ok(DecodedArray2 { + rows: shape[0], + cols: shape[1], + data, + }) +} + +pub(super) fn decode_array1_u32( + value: WireNdArray, + field: &str, + frames: &[Frame], +) -> Result> +where + Frame: AsRef<[u8]>, +{ + let (shape, bytes, scalar, endianness) = + decode_array_metadata(value, field, frames, &[ScalarType::I32, ScalarType::I64])?; + if shape.len() != 1 { + return Err(decode_error( + field, + &format!("expected rank-1 array, got rank {}", shape.len()), + )); + } + + let data = match scalar { + ScalarType::I32 => decode_i32_vec(&bytes, endianness, field)? + .into_iter() + .map(|value| convert_to_u32(value, field)) + .try_collect()?, + ScalarType::I64 => decode_i64_vec(&bytes, endianness, field)? + .into_iter() + .map(|value| convert_to_u32(value, field)) + .try_collect()?, + ScalarType::F32 => unreachable!("scalar validation should reject f32"), + }; + Ok(data) +} + +pub(super) fn decode_array2_f32( + value: WireNdArray, + field: &str, + frames: &[Frame], +) -> Result> +where + Frame: AsRef<[u8]>, +{ + let (shape, bytes, _, endianness) = + decode_array_metadata(value, field, frames, &[ScalarType::F32])?; + if shape.len() != 2 { + return Err(decode_error( + field, + &format!("expected rank-2 array, got rank {}", shape.len()), + )); + } + + let data = decode_f32_vec(&bytes, endianness, field)?; + Ok(DecodedArray2 { + rows: shape[0], + cols: shape[1], + data, + }) +} + +pub(super) fn decode_array_metadata( + value: WireNdArray, + field: &str, + frames: &[Frame], + expected_scalars: &[ScalarType], +) -> Result<(Vec, Vec, ScalarType, Endianness)> +where + Frame: AsRef<[u8]>, +{ + let WireNdArray { dtype, shape, data } = value; + let (scalar, endianness) = parse_dtype(&dtype, field)?; + if !expected_scalars.contains(&scalar) { + return Err(decode_error( + field, + &format!("expected dtype in {:?}, got {}", expected_scalars, dtype), + )); + } + + let bytes = resolve_array_bytes(data, field, frames)?; + validate_byte_length(shape.as_slice(), bytes.len(), field, scalar)?; + Ok((shape, bytes, scalar, endianness)) +} + +pub(super) fn parse_dtype(dtype: &str, field: &str) -> Result<(ScalarType, Endianness)> { + let (endianness, body) = match dtype.as_bytes().first().copied() { + Some(b'<') => (Endianness::Little, &dtype[1..]), + Some(b'>') => (Endianness::Big, &dtype[1..]), + Some(b'=') => (Endianness::Native, &dtype[1..]), + Some(b'|') => (Endianness::Native, &dtype[1..]), + _ => (Endianness::Native, dtype), + }; + + let scalar = match body { + "i4" | "int32" => ScalarType::I32, + "i8" | "int64" => ScalarType::I64, + "f4" | "float32" => ScalarType::F32, + _ => { + return Err(decode_error( + field, + &format!("unsupported dtype string {dtype:?}"), + )); + } + }; + Ok((scalar, endianness)) +} + +pub(super) fn resolve_array_bytes( + value: WireArrayData, + field: &str, + frames: &[Frame], +) -> Result> +where + Frame: AsRef<[u8]>, +{ + match value { + WireArrayData::RawView(bytes) => Ok(bytes), + WireArrayData::AuxIndex(index) => { + let frame = frames.get(index).ok_or_else(|| { + decode_error( + field, + &format!( + "aux frame index {index} out of range for {} frames", + frames.len() + ), + ) + })?; + Ok(frame.as_ref().to_vec()) + } + } +} + +pub(super) fn validate_byte_length( + shape: &[usize], + byte_len: usize, + field: &str, + scalar: ScalarType, +) -> Result<()> { + let element_count = shape + .checked_numel() + .ok_or_else(|| decode_error(field, "shape element count overflowed usize"))?; + let element_size = match scalar { + ScalarType::I32 | ScalarType::F32 => 4, + ScalarType::I64 => 8, + }; + let expected = element_count + .checked_mul(element_size) + .ok_or_else(|| decode_error(field, "byte length overflowed usize"))?; + if expected != byte_len { + return Err(decode_error( + field, + &format!("byte length mismatch: expected {expected}, got {byte_len}"), + )); + } + Ok(()) +} + +pub(super) fn decode_i32_vec( + bytes: &[u8], + endianness: Endianness, + field: &str, +) -> Result> { + if !bytes.len().is_multiple_of(4) { + return Err(decode_error( + field, + &format!("byte length {} is not divisible by 4", bytes.len()), + )); + } + let mut cursor = Cursor::new(bytes); + let mut values = Vec::with_capacity(bytes.len() / 4); + while (cursor.position() as usize) < bytes.len() { + let value = match endianness { + Endianness::Little => cursor.read_i32::(), + Endianness::Big => cursor.read_i32::(), + Endianness::Native => cursor.read_i32::(), + } + .map_err(|error| decode_error(field, &format!("failed to read i32 payload: {error}")))?; + values.push(value); + } + Ok(values) +} + +pub(super) fn decode_f32_vec( + bytes: &[u8], + endianness: Endianness, + field: &str, +) -> Result> { + if !bytes.len().is_multiple_of(4) { + return Err(decode_error( + field, + &format!("byte length {} is not divisible by 4", bytes.len()), + )); + } + let mut cursor = Cursor::new(bytes); + let mut values = Vec::with_capacity(bytes.len() / 4); + while (cursor.position() as usize) < bytes.len() { + let value = match endianness { + Endianness::Little => cursor.read_f32::(), + Endianness::Big => cursor.read_f32::(), + Endianness::Native => cursor.read_f32::(), + } + .map_err(|error| decode_error(field, &format!("failed to read f32 payload: {error}")))?; + values.push(value); + } + Ok(values) +} + +pub(super) fn decode_i64_vec( + bytes: &[u8], + endianness: Endianness, + field: &str, +) -> Result> { + if !bytes.len().is_multiple_of(8) { + return Err(decode_error( + field, + &format!("byte length {} is not divisible by 8", bytes.len()), + )); + } + let mut cursor = Cursor::new(bytes); + let mut values = Vec::with_capacity(bytes.len() / 8); + while (cursor.position() as usize) < bytes.len() { + let value = match endianness { + Endianness::Little => cursor.read_i64::(), + Endianness::Big => cursor.read_i64::(), + Endianness::Native => cursor.read_i64::(), + } + .map_err(|error| decode_error(field, &format!("failed to read i64 payload: {error}")))?; + values.push(value); + } + Ok(values) +} + +fn convert_to_u32(value: I, field: &str) -> Result +where + I: TryInto + std::fmt::Display + Copy, +{ + value.try_into().map_err(|_| { + decode_error( + field, + &format!("expected non-negative token id/rank that fits in u32, got {value}"), + ) + }) +} + +pub(super) fn decode_error(field: &str, reason: &str) -> Error { + ext_value_decode!("{field}: {reason}") +} diff --git a/rust/src/engine-core-client/src/protocol/logprobs/tests.rs b/rust/src/engine-core-client/src/protocol/logprobs/tests.rs new file mode 100644 index 00000000000..7408b98f50c --- /dev/null +++ b/rust/src/engine-core-client/src/protocol/logprobs/tests.rs @@ -0,0 +1,302 @@ +use std::collections::BTreeSet; + +use bytes::Bytes; +use rmpv::Value; + +use super::{Logprobs, PositionLogprobs, TokenLogprob, decode_engine_core_outputs}; +use crate::protocol::EngineCoreFinishReason; + +fn encode_value(value: &Value) -> Vec { + let mut out = Vec::new(); + rmpv::encode::write_value(&mut out, value).unwrap(); + out +} + +fn output_wire_with_custom_fields( + new_logprobs: Option, + prompt_logprobs: Option, +) -> Value { + Value::Array(vec![ + Value::from(0), + Value::Array(vec![Value::Array(vec![ + Value::from("req-1"), + Value::Array(vec![Value::from(7), Value::from(8)]), + new_logprobs.unwrap_or(Value::Nil), + prompt_logprobs.unwrap_or(Value::Nil), + Value::Nil, + Value::from(EngineCoreFinishReason::Length as u8), + ])]), + Value::Nil, + Value::from(0.0), + Value::Nil, + Value::Array(vec![Value::from("req-1")]), + ]) +} + +fn ndarray_value(dtype: &str, shape: &[usize], data: Value) -> Value { + Value::Array(vec![ + Value::from(dtype), + Value::Array(shape.iter().copied().map(Value::from).collect()), + data, + ]) +} + +fn inline_logprobs_value() -> Value { + let ids = Value::Ext( + 3, + vec![ + 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, + 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, + ], + ); + let probs = Value::Ext( + 3, + vec![ + 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64, 0, 0, 192, 64, + ], + ); + let ranks = Value::Ext(3, vec![1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0]); + Value::Array(vec![ + ndarray_value(" Value { + let ids = Value::Ext( + 3, + vec![ + 10, 0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, + 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, + ], + ); + let probs = Value::Ext( + 3, + vec![ + 0, 0, 32, 65, 0, 0, 48, 65, 0, 0, 64, 65, 0, 0, 80, 65, 0, 0, 96, 65, 0, 0, 112, 65, + ], + ); + let ranks = Value::Ext(3, vec![3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0]); + Value::Array(vec![ + ndarray_value("int64", &[2, 3], ids), + ndarray_value("float32", &[2, 3], probs), + ndarray_value("int64", &[2], ranks), + Value::Nil, + ]) +} + +fn expected_sample_logprobs() -> Logprobs { + Logprobs { + positions: vec![ + PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id: 1, + logprob: 1.0, + rank: 1, + }, + TokenLogprob { + token_id: 2, + logprob: 2.0, + rank: 1, + }, + TokenLogprob { + token_id: 3, + logprob: 3.0, + rank: 2, + }, + ], + }, + PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id: 4, + logprob: 4.0, + rank: 2, + }, + TokenLogprob { + token_id: 5, + logprob: 5.0, + rank: 1, + }, + TokenLogprob { + token_id: 6, + logprob: 6.0, + rank: 2, + }, + ], + }, + ], + } +} + +fn expected_prompt_logprobs() -> Logprobs { + Logprobs { + positions: vec![ + PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id: 10, + logprob: 10.0, + rank: 3, + }, + TokenLogprob { + token_id: 11, + logprob: 11.0, + rank: 1, + }, + TokenLogprob { + token_id: 12, + logprob: 12.0, + rank: 2, + }, + ], + }, + PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id: 13, + logprob: 13.0, + rank: 4, + }, + TokenLogprob { + token_id: 14, + logprob: 14.0, + rank: 1, + }, + TokenLogprob { + token_id: 15, + logprob: 15.0, + rank: 2, + }, + ], + }, + ], + } +} + +#[test] +fn decodes_inline_new_logprobs() { + let frames = vec![Bytes::from(encode_value(&output_wire_with_custom_fields( + Some(inline_logprobs_value()), + None, + )))]; + let decoded = decode_engine_core_outputs(&frames).unwrap(); + + let logprobs = decoded.outputs[0].new_logprobs.clone().unwrap().into_direct().unwrap(); + assert_eq!(logprobs, expected_sample_logprobs()); + assert_eq!( + decoded.finished_requests, + Some(BTreeSet::from(["req-1".to_string()])) + ); +} + +#[test] +fn decodes_multipart_new_logprobs() { + let frames = vec![ + Bytes::from(encode_value(&output_wire_with_custom_fields( + Some(Value::Array(vec![ + ndarray_value("i4", &[1, 2], Value::Ext(3, vec![0, 0, 0, 1, 0, 0, 0, 2])), + ndarray_value( + ">f4", + &[1, 2], + Value::Ext(3, vec![63, 128, 0, 0, 64, 0, 0, 0]), + ), + ndarray_value(">i4", &[1], Value::Ext(3, vec![0, 0, 0, 3])), + Value::Nil, + ])), + None, + )))]; + let decoded = decode_engine_core_outputs(&frames).unwrap(); + let logprobs = decoded.outputs[0].new_logprobs.clone().unwrap().into_direct().unwrap(); + assert_eq!( + logprobs, + Logprobs { + positions: vec![PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id: 1, + logprob: 1.0, + rank: 3, + }, + TokenLogprob { + token_id: 2, + logprob: 2.0, + rank: 1, + }, + ], + }], + } + ); +} + +#[test] +fn rejects_non_none_cu_num_generated_tokens() { + let frames = vec![Bytes::from(encode_value(&output_wire_with_custom_fields( + Some(Value::Array(vec![ + ndarray_value(" +#[derive(Debug, Clone, PartialEq, Serialize_tuple, Deserialize_tuple)] +pub struct WireLogprobs { + /// Wire array with shape `[num_positions, max_num_logprobs + 1]`. + pub logprob_token_ids: WireNdArray, + /// Wire array with shape `[num_positions, max_num_logprobs + 1]`. + pub logprobs: WireNdArray, + /// Wire array with shape `[num_positions]`. + /// + /// Python uses the field name `sampled_token_ranks` for sample logprobs and + /// `selected_token_ranks` for prompt logprobs. Rust keeps one neutral field + /// because both payloads share the same wire representation. + pub token_ranks: WireNdArray, + /// Preserved only for wire compatibility with batch-level Python tensors. + /// Scheduler-sliced per-request outputs should emit `None` here, and + /// the semantic Rust decoder rejects any other value. + #[serde(default)] + pub cu_num_generated_tokens: Option>, +} diff --git a/rust/src/engine-core-client/src/protocol/mod.rs b/rust/src/engine-core-client/src/protocol/mod.rs new file mode 100644 index 00000000000..dfe9ecff7e5 --- /dev/null +++ b/rust/src/engine-core-client/src/protocol/mod.rs @@ -0,0 +1,756 @@ +use std::any::type_name; +use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::io::Cursor; + +use bytes::Bytes; +use rmpv::Value; +use serde::{Deserialize, Serialize}; +use serde_default::DefaultFromSerde; +use serde_repr::{Deserialize_repr, Serialize_repr}; +use serde_tuple::{Deserialize_tuple, Serialize_tuple}; +use thiserror_ext::AsReport; + +use crate::error::{Error, Result}; +use crate::protocol::logprobs::MaybeWireLogprobs; +use crate::protocol::multimodal::MmFeatures; +use crate::protocol::stats::{PrefillStats, SchedulerStats}; + +// TODO: This module currently mixes reusable frontend-facing semantic types +// (for example `FinishReason`, `StopReason`, `RequestOutputKind`, and future +// cleaned-up frontend sampling types) with engine-core-specific wire DTOs and +// handshake/control messages. While the Rust frontend is still evolving +// quickly, keep them co-located here for iteration speed. Once the higher-level +// API boundary stabilizes, move the truly reusable semantic types into a +// lower-level common crate and keep the engine transport/wire messages here. + +/// Dynamic msgpack value used for schema positions that are preserved but not +/// yet strongly typed in the early-stage Rust client. +pub type OpaqueValue = Value; + +fn default_opaque_value_nil() -> OpaqueValue { + Value::Nil +} + +fn is_false(v: &bool) -> bool { + !v +} + +mod classified_outputs; +pub mod dtype; +pub mod handshake; +pub mod logprobs; +pub mod multimodal; +pub mod stats; +pub mod tensor; +pub use classified_outputs::{ + ClassifiedEngineCoreOutputs, DpControlMessage, RequestBatchOutputs, UtilityCallOutput, +}; +pub use dtype::ModelDtype; +pub use logprobs::decode_engine_core_outputs; + +/// Request types are encoded as single-byte protocol constants so they can be +/// sent over the ZMQ socket without an extra encoding step. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum EngineCoreRequestType { + Add = 0, + Abort = 1, + StartDpWave = 2, + Utility = 3, +} + +impl EngineCoreRequestType { + pub fn to_frame(self) -> Bytes { + Bytes::from_static(match self { + Self::Add => b"\x00", + Self::Abort => b"\x01", + Self::StartDpWave => b"\x02", + Self::Utility => b"\x03", + }) + } +} + +/// Reason a request finished: stop, length, abort, error, or repetition. +/// +/// This mirrors the Python enum and uses integer encoding for compact wire +/// representation. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize_repr, Deserialize_repr)] +#[repr(u8)] +pub enum EngineCoreFinishReason { + /// A stop string was emitted. + Stop = 0, + /// `max_tokens` or `max_model_len` was reached. + Length = 1, + /// The request was aborted by the client. + Abort = 2, + /// A retryable request-level internal error occurred. + Error = 3, + /// A repetitive token pattern was detected. + Repetition = 4, +} + +/// Event types emitted by engine-core for one request. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize_repr, Deserialize_repr)] +#[repr(u8)] +pub enum EngineCoreEventType { + Queued = 1, + Scheduled = 2, + Preempted = 3, +} + +/// A timestamped engine-core event associated with one request. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct EngineCoreEvent { + pub r#type: EngineCoreEventType, + pub timestamp: f64, +} + +/// Controls how intermediate outputs are returned to the frontend. +/// +/// `Cumulative = 0` is intentionally not supported in Rust frontend. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize_repr, Deserialize_repr)] +#[repr(u8)] +pub enum RequestOutputKind { + /// Return only token deltas in each update. + #[default] + Delta = 1, + /// Suppress intermediate updates and return only the final output. + FinalOnly = 2, +} + +/// The stop reason associated with a finished output. +/// +/// Python models this as the union-typed `stop_reason: int | str | None` +/// field on `EngineCoreOutput`; the Rust client narrows it into a tagged enum. +/// +/// Original Python field: +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum StopReason { + TokenId(u32), + Text(String), +} + +/// Parameters for configuring structured outputs (guided decoding). +/// +/// Exactly one constraint field (`json`, `regex`, `choice`, `grammar`, +/// `json_object`, or `structural_tag`) should be set. The engine-core +/// backend selects the appropriate grammar compiler based on which field +/// is present. +/// +/// Original Python definition: +/// +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +#[serde(default)] +pub struct StructuredOutputsParams { + /// JSON schema (as a dict/object or JSON string) constraining the output. + pub json: Option, + /// Regular expression the output must match. + pub regex: Option, + /// List of allowed output strings (the model must produce one of these). + pub choice: Option>, + /// Context-free grammar (in EBNF-like notation) the output must conform to. + pub grammar: Option, + /// When `true`, output must be valid JSON (free-form, no schema). + pub json_object: Option, + /// Disable any additional whitespace in guided JSON output. + #[serde(skip_serializing_if = "crate::protocol::is_false")] + pub disable_any_whitespace: bool, + /// Disable `additionalProperties` in JSON schema output. + #[serde(skip_serializing_if = "crate::protocol::is_false")] + pub disable_additional_properties: bool, + /// Custom whitespace pattern for guided JSON output. + pub whitespace_pattern: Option, + /// Structural tag configuration (JSON-encoded string). + pub structural_tag: Option, +} + +/// Engine-core-facing sampling parameters for text generation. +/// +/// This is the normalized southbound subset used by the Rust frontend when it +/// talks to Python engine-core over the wire. User-facing request semantics +/// such as `stop` strings, `n`, `ignore_eos`, and output aggregation mode are +/// intentionally handled by higher layers before values reach this DTO. +/// +/// Original Python definition: +/// +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct EngineCoreSamplingParams { + /// Controls randomness. Lower values are more deterministic; zero means + /// greedy sampling. + pub temperature: f32, + /// Cumulative probability threshold for nucleus sampling. + pub top_p: f32, + /// Maximum number of top tokens to consider. `0` means all tokens. + pub top_k: u32, + /// Random seed used by the sampler when present. + pub seed: Option, + /// Maximum number of tokens to generate per output sequence. + pub max_tokens: u32, + /// Minimum number of tokens to generate before EOS or stop-token handling. + pub min_tokens: u32, + /// Number of log probabilities to return per generated token. + /// + /// `None` disables sample logprobs. `-1` requests the full vocabulary. + pub logprobs: Option, + /// Number of log probabilities to return per prompt token. + /// + /// `None` disables prompt logprobs. `-1` requests the full vocabulary. + pub prompt_logprobs: Option, + /// Minimum probability threshold for token sampling. + pub min_p: f32, + /// Frequency penalty applied by the sampler. + pub frequency_penalty: f32, + /// Presence penalty applied by the sampler. + pub presence_penalty: f32, + /// Repetition penalty applied by the sampler. + pub repetition_penalty: f32, + /// Token IDs that stop generation. + pub stop_token_ids: Vec, + /// Primary EOS token ID used by engine-core's dedicated EOS stop path. + /// + /// This mirrors Python's internal `_eos_token_id` field and is derived by + /// the frontend from tokenizer/model metadata rather than supplied directly + /// by end users. + #[serde(rename = "_eos_token_id")] + pub eos_token_id: Option, + /// Complete stop-token set used by engine-core for `min_tokens` masking. + /// + /// This mirrors Python's internal `_all_stop_token_ids` field and should + /// contain explicit `stop_token_ids` plus any frontend-derived EOS token + /// IDs. + #[serde(rename = "_all_stop_token_ids")] + pub all_stop_token_ids: BTreeSet, + /// Logit biases to apply during sampling. + /// Keys are token IDs + #[serde(default)] + pub logit_bias: Option>, + /// Restrict output to these token IDs only. + #[serde(default)] + pub allowed_token_ids: Option>, + /// Tokenized bad words to avoid during generation. + #[serde(default, rename = "_bad_words_token_ids")] + pub bad_words_token_ids: Option>>, + /// Parameters for configuring structured outputs (guided decoding). + #[serde(default)] + pub structured_outputs: Option, + /// Specific token IDs for which log probabilities should be returned at + /// each position. + /// + /// When set, the engine returns logprobs for exactly these tokens in + /// addition to the sampled/scored token. Mutually exclusive with the + /// `logprobs` count field in practice. + #[serde(default)] + pub logprob_token_ids: Option>, + /// If `Some(true)`, the request will not attempt to read from the prefix + /// cache; newly computed blocks may still populate the cache. `None` + /// defers to engine-core defaults. + #[serde(default)] + pub skip_reading_prefix_cache: Option, + /// Additional request parameters for custom extensions (from `vllm_xargs`). + #[serde(default)] + pub extra_args: Option>, +} + +impl EngineCoreSamplingParams { + /// Constructs a default sampling params for testing purposes only. + pub fn for_test() -> Self { + Self { + temperature: 1.0, + top_p: 1.0, + top_k: 0, + seed: None, + max_tokens: 65536, + min_tokens: 0, + logprobs: None, + prompt_logprobs: None, + min_p: 0.0, + frequency_penalty: 0.0, + presence_penalty: 0.0, + repetition_penalty: 1.0, + stop_token_ids: Vec::new(), + eos_token_id: None, + all_stop_token_ids: BTreeSet::new(), + logit_bias: None, + allowed_token_ids: None, + bad_words_token_ids: None, + structured_outputs: None, + logprob_token_ids: None, + skip_reading_prefix_cache: None, + extra_args: None, + } + } +} + +/// Engine-core add-request payload sent from frontend to engine. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, PartialEq, Serialize_tuple, Deserialize_tuple, DefaultFromSerde)] +pub struct EngineCoreRequest { + pub request_id: String, + pub prompt_token_ids: Option>, + /// Multimodal features attached to the request. + pub mm_features: Option, + pub sampling_params: Option, + /// Pooling parameters are preserved in the schema but not yet strongly + /// typed. + pub pooling_params: Option, + pub arrival_time: f64, + #[serde(default)] + pub lora_request: Option, + #[serde(default)] + pub cache_salt: Option, + #[serde(default)] + pub data_parallel_rank: Option, + /// Unsupported in the first-stage Rust client because Python uses a custom + /// tensor/aux-frame encoding path for this field. + #[serde(default)] + pub prompt_embeds: Option, + /// Per-position mask for mixed-mode inputs (e.g. chat completion with + /// `prompt_embeds` content parts). `Some(true)` means real token id; + /// `Some(false)` means the position uses a pre-computed entry from + /// `prompt_embeds`. `None` for pure-tokens and pure-embeds requests. + #[serde(default)] + pub prompt_is_token_ids: Option>, + /// Index of the client, used to ensure outputs are sent back to the same + /// client when scaling out the frontend. + #[serde(default)] + pub client_index: u32, + /// In DP mode, indicates which wave this request is expected to belong to. + #[serde(default)] + pub current_wave: u32, + #[serde(default)] + pub priority: i32, + #[serde(default)] + pub trace_headers: Option>, + #[serde(default)] + pub resumable: bool, + /// Original user-provided request ID, used for output reporting and aborts. + #[serde(default)] + pub external_req_id: Option, + #[serde(default)] + pub reasoning_ended: Option, + /// Opaque reasoning-parser kwargs forwarded from the frontend to the + /// structured-output backend. + #[serde(default)] + pub reasoning_parser_kwargs: Option, + /// If `true`, the request should be added to the scheduler's waiting queue + /// and immediately aborted, so connector-side cleanup runs via the + /// standard `request_finished` hook. + #[serde(default)] + pub abort_immediately: bool, +} + +impl EngineCoreRequest { + /// Validate fields intentionally not supported in the first-stage client. + pub fn validate(&self) -> Result<()> { + if self.prompt_embeds.is_some() { + return Err(Error::UnsupportedField { + context: "EngineCoreRequest", + field: "prompt_embeds", + }); + } + Ok(()) + } +} + +/// Engine-core utility call payload sent from frontend to engine. +/// +/// Original Python payload shape: +/// `(client_index, call_id, method_name, args)` +#[derive(Debug, Clone, PartialEq, Serialize_tuple)] +pub struct EngineCoreUtilityRequest { + pub client_index: u32, + pub call_id: i64, + pub method_name: String, + pub args: OpaqueValue, +} + +impl EngineCoreUtilityRequest { + /// Create a new utility request with the given strongly typed arguments, + /// encoding them into the expected msgpack value format. + pub fn new( + client_index: u32, + call_id: i64, + method_name: impl Into, + args: T, + ) -> Result + where + T: Serialize + std::fmt::Debug, + { + let args = rmpv::ext::to_value(&args).map_err(|error| Error::Encode { + target_type: type_name::(), + message: format!( + "failed to encode utility args `{args:?}`: {}", + error.to_report_string() + ), + })?; + let args = match args { + Value::Nil => Value::Array(Vec::new()), + other => other, + }; + + Ok(Self { + client_index, + call_id, + method_name: method_name.into(), + args, + }) + } +} + +/// Engine-core output for a single request. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, PartialEq, Serialize_tuple, Deserialize_tuple, DefaultFromSerde)] +pub struct EngineCoreOutput { + pub request_id: String, + pub new_token_ids: Vec, + /// Decoded sample logprobs for the newly generated positions in this + /// output. + #[serde(default)] + pub new_logprobs: Option, + /// Decoded prompt logprobs for the scored prompt positions emitted in this + /// output. + #[serde(default)] + pub new_prompt_logprobs_tensors: Option, + #[serde(default)] + pub pooling_output: Option, + #[serde(default)] + pub finish_reason: Option, + #[serde(default)] + pub stop_reason: Option, + #[serde(default)] + pub events: Option>, + #[serde(default)] + pub kv_transfer_params: Option, + #[serde(default)] + pub trace_headers: Option, + /// Breakdown of the scheduled prefill computation, set on the first output + /// of a newly scheduled prefill and elided for subsequent decode outputs. + #[serde(default)] + pub prefill_stats: Option, + #[serde(default)] + pub routed_experts: Option, + /// Number of NaNs seen in logits. Values above zero indicate corruption. + #[serde(default)] + pub num_nans_in_logits: u32, +} + +impl EngineCoreOutput { + /// Returns whether this output is terminal for the request. + pub fn finished(&self) -> bool { + self.finish_reason.is_some() + } +} + +/// Result of a utility call. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, PartialEq, Serialize_tuple, Deserialize_tuple, DefaultFromSerde)] +pub struct UtilityOutput { + pub call_id: i64, + /// Non-`None` implies the call failed and `result` should be ignored. + #[serde(default)] + pub failure_message: Option, + #[serde(default)] + pub result: Option, +} + +/// Python `UtilityResult` wrapper carried inside `UtilityOutput.result`. +/// +/// Upstream reference: +/// +#[derive(Debug, Clone, PartialEq, Serialize_tuple, Deserialize_tuple)] +pub struct UtilityResultEnvelope { + /// Recursive type information encoded on Python side, serving as the hint + /// for deserialization. We don't care it here as in Rust frontend all + /// utility calls are strongly-typed. + #[serde(default)] + type_info: Option, + /// The actual utility result. + #[serde(default = "default_opaque_value_nil")] + result: OpaqueValue, +} + +impl UtilityResultEnvelope { + /// Create a utility result envelope without type information. + pub fn without_type_info(result: OpaqueValue) -> Self { + Self { + type_info: None, + result, + } + } +} + +impl UtilityOutput { + /// Decode the typed result of a utility call. + pub fn into_typed_result(self, method: &str) -> Result + where + T: serde::de::DeserializeOwned, + { + if let Some(message) = self.failure_message { + return Err(Error::UtilityCallFailed { + method: method.to_string(), + call_id: self.call_id, + message, + }); + } + + let result = self.result.map(|e| e.result).unwrap_or(Value::Nil); + + rmpv::ext::from_value(result).map_err(|error| Error::UtilityResultDecode { + method: method.to_string(), + call_id: self.call_id, + message: error.to_report_string(), + }) + } +} + +/// Batch of engine-core outputs returned to a frontend client. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, PartialEq, Serialize_tuple, Deserialize_tuple, DefaultFromSerde)] +pub struct EngineCoreOutputs { + #[serde(default)] + pub engine_index: u32, + /// Outputs grouped for this client in the current engine tick. + #[serde(default)] + pub outputs: Vec, + #[serde(default)] + pub scheduler_stats: Option>, + #[serde(default)] + pub timestamp: f64, + #[serde(default)] + pub utility_output: Option, + #[serde(default)] + pub finished_requests: Option>, + /// In DP mode, signals that the current wave finished and engines are + /// paused. + #[serde(default)] + pub wave_complete: Option, + /// In DP mode, signals that a request arrived for an old wave and the next + /// wave needs to start in other engines. + #[serde(default)] + pub start_wave: Option, +} + +/// Encode a Rust value into msgpack using the protocol crate's serde model. +pub fn encode_msgpack(value: &T) -> Result> +where + T: Serialize + std::fmt::Debug, +{ + rmp_serde::to_vec_named(value).map_err(|error| Error::Encode { + target_type: type_name::(), + message: format!( + "failed to encode value `{:?}`: {}", + value, + error.to_report_string() + ), + }) +} + +/// Decode a msgpack payload into a strongly typed protocol value, with enhanced +/// error reporting. +pub fn decode_msgpack(bytes: &[u8]) -> Result +where + T: for<'de> Deserialize<'de>, +{ + fn decode_value_preview(bytes: &[u8]) -> String { + match decode_value(bytes) { + Ok(value) => format!("{value}"), + Err(error) => format!(""), + } + } + + rmp_serde::from_slice(bytes).map_err(|error| Error::Decode { + target_type: type_name::(), + message: format!("{error}; value fallback: {}", decode_value_preview(bytes)), + }) +} + +pub fn decode_value(bytes: &[u8]) -> Result { + Ok(rmpv::decode::read_value(&mut Cursor::new(bytes))?) +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + + use super::*; + + fn utility_result_value(value: T) -> UtilityResultEnvelope + where + T: Serialize, + { + UtilityResultEnvelope::without_type_info(rmpv::ext::to_value(value).unwrap()) + } + + #[test] + fn engine_core_request_serializes_as_full_array() { + let request = EngineCoreRequest { + request_id: "req-1".to_string(), + prompt_token_ids: Some(vec![1, 2, 3]), + sampling_params: Some(EngineCoreSamplingParams { + max_tokens: 8, + ..EngineCoreSamplingParams::for_test() + }), + arrival_time: 1234.5, + client_index: 7, + ..EngineCoreRequest::default() + }; + + let encoded = encode_msgpack(&request).unwrap(); + let value = decode_value(&encoded).unwrap(); + let array = match value { + Value::Array(array) => array, + other => panic!("expected array, got {other:?}"), + }; + + assert_eq!(array.len(), 20); + assert_eq!(array[0], Value::from("req-1")); + assert_eq!(array[2], Value::Nil); + assert_eq!(array[4], Value::Nil); + assert_eq!(array[10], Value::Nil); + assert_eq!(array[11], Value::from(7)); + } + + #[test] + fn engine_core_outputs_roundtrip_finished_fields() { + let outputs = EngineCoreOutputs { + outputs: vec![EngineCoreOutput { + request_id: "req-1".to_string(), + new_token_ids: vec![42], + new_logprobs: None, + new_prompt_logprobs_tensors: None, + pooling_output: None, + finish_reason: Some(EngineCoreFinishReason::Length), + stop_reason: Some(StopReason::Text("stop".to_string())), + events: None, + kv_transfer_params: None, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + }], + finished_requests: Some(BTreeSet::from(["req-1".to_string()])), + ..Default::default() + }; + + let encoded = encode_msgpack(&outputs).unwrap(); + let decoded: EngineCoreOutputs = decode_msgpack(&encoded).unwrap(); + + assert_eq!(decoded.outputs.len(), 1); + assert_eq!( + decoded.outputs[0].finish_reason, + Some(EngineCoreFinishReason::Length) + ); + assert_eq!( + decoded.finished_requests, + Some(BTreeSet::from(["req-1".to_string()])) + ); + } + + #[test] + fn utility_request_serializes_as_tuple_payload() { + let request = EngineCoreUtilityRequest::new(7, 42, "is_sleeping", ()).unwrap(); + + let encoded = encode_msgpack(&request).unwrap(); + let value = decode_value(&encoded).unwrap(); + let array = match value { + Value::Array(array) => array, + other => panic!("expected utility request array, got {other:?}"), + }; + + assert_eq!(array.len(), 4); + assert_eq!(array[0], Value::from(7)); + assert_eq!(array[1], Value::from(42)); + assert_eq!(array[2], Value::from("is_sleeping")); + assert_eq!(array[3], Value::Array(Vec::new())); + } + + #[test] + fn utility_output_decodes_typed_result() { + let output = UtilityOutput { + call_id: 9, + failure_message: None, + result: Some(utility_result_value(true)), + }; + + assert!(output.into_typed_result::("is_sleeping").unwrap()); + } + + #[test] + fn utility_output_reports_failure_message() { + let error = UtilityOutput { + call_id: 9, + failure_message: Some("boom".to_string()), + result: None, + } + .into_typed_result::("is_sleeping") + .unwrap_err(); + + assert!(matches!( + error, + Error::UtilityCallFailed { + method, + call_id, + message + } if method == "is_sleeping" && call_id == 9 && message == "boom" + )); + } + + #[test] + fn utility_output_decodes_missing_result_as_unit() { + UtilityOutput { + call_id: 3, + failure_message: None, + result: None, + } + .into_typed_result::<()>("reset_mm_cache") + .unwrap(); + } + + #[test] + fn utility_output_decodes_nil_result_as_unit() { + UtilityOutput { + call_id: 4, + failure_message: None, + result: Some(UtilityResultEnvelope::without_type_info(Value::Nil)), + } + .into_typed_result::<()>("sleep") + .unwrap(); + } + + #[test] + fn decode_msgpack_includes_type_name_and_value_fallback() { + let error = decode_msgpack::( + &rmp_serde::to_vec_named(&BTreeMap::from([("status", "READY")])).unwrap(), + ) + .unwrap_err(); + + expect_test::expect![[r#"messagepack decode failed for u64: wrong msgpack marker FixMap(1); value fallback: {"status": "READY"}"#]].assert_eq(&error.to_report_string()); + } +} diff --git a/rust/src/engine-core-client/src/protocol/multimodal.rs b/rust/src/engine-core-client/src/protocol/multimodal.rs new file mode 100644 index 00000000000..78d84740e2d --- /dev/null +++ b/rust/src/engine-core-client/src/protocol/multimodal.rs @@ -0,0 +1,282 @@ +use std::collections::BTreeMap; + +use serde::{Deserialize, Serialize}; +use serde_tuple::{Deserialize_tuple, Serialize_tuple}; + +use super::tensor::WireTensor; + +/// Multimodal feature payload accepted from higher-level frontend code. +/// +/// Original Python definition: +/// +pub type MmFeatures = Vec; + +/// Represents a single multimodal input with its processed data and metadata. +/// +/// Used to track multimodal data through processing and caching. A request +/// containing multiple multimodal items will have one `MmFeatureSpec` +/// per item. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MmFeatureSpec { + /// Represents multimodal data for this feature. + /// + /// Can be `None` if the item is cached, to skip IPC between API server + /// and engine core processes. + pub data: Option, + + /// The input modality, e.g., `"image"`, `"audio"`, `"video"`. + pub modality: String, + + /// The hash for caching encoder outputs (with LoRA prefix if applicable). + pub identifier: String, + + /// The location of the `modality` tokens corresponding to this item + /// in the prompt, e.g., `PlaceholderRange(offset=2, length=336)`. + pub mm_position: PlaceholderRange, + + /// The hash for caching processor outputs (without LoRA prefix). + #[serde(default)] + pub mm_hash: Option, +} + +/// Placeholder location information for multi-modal data. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PlaceholderRange { + /// The start index of the placeholder in the prompt. + pub offset: usize, + + /// The length of the placeholder. + pub length: usize, + + /// A boolean mask of shape `(length,)` indicating which positions + /// between `offset` and `offset + length` to assign embeddings to. + /// `None` means all positions. + #[serde(default)] + pub is_embed: Option, +} + +/// A dictionary of processed keyword arguments to pass to the model, +/// corresponding to a single item in `MultiModalDataItems`. +/// +/// Original Python definition: +/// +pub type MmKwargsItem = BTreeMap; + +/// Represents a processed keyword argument to pass to a model for a +/// `MmKwargsItem`. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MmFieldElem { + /// The processed value of this field in `MmKwargsItem`, i.e. the + /// keyword argument value to be passed to the model. + /// + /// It may be set to `None` if it is determined that the item is cached + /// in `EngineCore`. + pub data: Option, + + /// Defines how to combine this field's processed values with others in + /// order to batch multi-modal items together for model inference. + pub field: MmField, +} + +/// Processed multimodal keyword argument value. +/// +/// Original Python definition (`NestedTensors`) and wire encoding: +/// +/// +/// +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum MmKwargValue { + Tensor(WireTensor), + Int(i64), + Float(f64), + List(Vec), +} + +/// Defines how to interpret tensor data belonging to a keyword argument for +/// `MultiModalKwargsItems`, and vice versa. +/// +/// Original Python definitions and wire encoding: +/// +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(try_from = "MmFieldWire", into = "MmFieldWire")] +pub enum MmField { + Batched(MmBatchedField), + Flat(MmFlatField), + Shared(MmSharedField), +} + +/// Info: `MultiModalFieldConfig.batched`. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct MmBatchedField { + /// If `True`, then this field is excluded from being moved to the + /// accelerator when multimodal items are grouped and batched. + pub keep_on_cpu: bool, +} + +/// Info: `MultiModalFieldConfig.flat` and +/// `MultiModalFieldConfig.flat_from_sizes`. +/// +/// Original Python definition: +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct MmFlatField { + /// For each multi-modal item, a slice (`dim=0`) or a tuple of slices + /// (`dim>0`) that is used to extract the data corresponding to it. + pub slices: Vec, + + /// The dimension to extract data, default to 0. + pub dim: i32, + + /// If `True`, then this field is excluded from being moved to the + /// accelerator when multimodal items are grouped and batched. + pub keep_on_cpu: bool, +} + +/// Info: `MultiModalFieldConfig.shared`. +/// +/// Original Python definition: +/// +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct MmSharedField { + pub batch_size: usize, + + /// If `True`, then this field is excluded from being moved to the + /// accelerator when multimodal items are grouped and batched. + pub keep_on_cpu: bool, +} + +/// Python slice encoded as `(start, stop, step)`. +/// +/// Original Python wire encoding: +/// +#[derive(Debug, Clone, PartialEq, Eq, Serialize_tuple, Deserialize_tuple)] +pub struct SliceSpec { + pub start: Option, + pub stop: Option, + pub step: Option, +} + +/// A single slice or a tuple of slices used by `MmFlatField`. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum MmSlice { + Slice(SliceSpec), + Slices(Vec), +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize_tuple, Deserialize_tuple)] +struct MmFieldWire { + name: String, + inner: MmFieldWireInner, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +enum MmFieldWireInner { + Batched(MmBatchedField), + Flat(MmFlatField), + Shared(MmSharedField), +} + +impl TryFrom for MmField { + type Error = String; + + fn try_from(value: MmFieldWire) -> Result { + match (value.name.as_str(), value.inner) { + ("batched", MmFieldWireInner::Batched(kwargs)) => Ok(Self::Batched(kwargs)), + ("flat", MmFieldWireInner::Flat(kwargs)) => Ok(Self::Flat(kwargs)), + ("shared", MmFieldWireInner::Shared(kwargs)) => Ok(Self::Shared(kwargs)), + (name, _) => Err(format!( + "mismatched or unknown multimodal field factory {name:?}" + )), + } + } +} + +impl From for MmFieldWire { + fn from(value: MmField) -> Self { + match value { + MmField::Batched(kwargs) => Self { + name: "batched".to_string(), + inner: MmFieldWireInner::Batched(kwargs), + }, + MmField::Flat(kwargs) => Self { + name: "flat".to_string(), + inner: MmFieldWireInner::Flat(kwargs), + }, + MmField::Shared(kwargs) => Self { + name: "shared".to_string(), + inner: MmFieldWireInner::Shared(kwargs), + }, + } + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use rmpv::Value; + + use super::*; + + fn encode_value(value: &T) -> Value { + let bytes = rmp_serde::to_vec_named(value).expect("encode value"); + rmpv::decode::read_value(&mut Cursor::new(bytes)).expect("decode value") + } + + #[test] + fn multimodal_field_serializes_to_python_factory_tuple() { + let field = MmField::Flat(MmFlatField { + slices: vec![MmSlice::Slice(SliceSpec { + start: Some(0), + stop: Some(1200), + step: None, + })], + dim: 0, + keep_on_cpu: false, + }); + + let value = encode_value(&field); + let Value::Array(items) = value else { + panic!("field should encode as tuple array"); + }; + assert_eq!(items.len(), 2); + assert_eq!(items[0].as_str(), Some("flat")); + + let Value::Map(kwargs) = &items[1] else { + panic!("field kwargs should encode as map"); + }; + assert!(kwargs.iter().any(|(key, _)| key.as_str() == Some("slices"))); + assert!(kwargs.iter().any(|(key, _)| key.as_str() == Some("dim"))); + assert!(kwargs.iter().any(|(key, _)| key.as_str() == Some("keep_on_cpu"))); + } + + #[test] + fn multimodal_field_round_trips_python_factory_tuple() { + let field = MmField::Batched(MmBatchedField { keep_on_cpu: true }); + let encoded = rmp_serde::to_vec_named(&field).expect("encode field"); + let decoded: MmField = rmp_serde::from_slice(&encoded).expect("decode field"); + assert_eq!(decoded, field); + } +} diff --git a/rust/src/engine-core-client/src/protocol/stats.rs b/rust/src/engine-core-client/src/protocol/stats.rs new file mode 100644 index 00000000000..254efc31b24 --- /dev/null +++ b/rust/src/engine-core-client/src/protocol/stats.rs @@ -0,0 +1,192 @@ +use std::collections::BTreeMap; + +use serde::{Deserialize, Serialize}; + +use crate::protocol::OpaqueValue; + +/// Stores cache hit statistics. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +pub struct BaseCacheStats { + /// Whether the cache was reset. + pub reset: bool, + /// The number of requests in this update. + pub requests: u64, + /// The number of queries in these requests. + pub queries: u64, + /// The number of hits in these requests. + pub hits: u64, +} + +/// Stores prefix cache hit statistics. +/// - `reset`: Whether `reset_prefix_cache` was invoked. +/// - `queries`: Refers to the number of tokens that were queried. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +pub struct PrefixCacheStats { + /// Embedded base cache counters and reset flag. + #[serde(flatten)] + pub base: BaseCacheStats, + /// The number of previously preempted requests in this update. + pub preempted_requests: u64, + /// The `queries` number for preempted requests. + pub preempted_queries: u64, + /// The `hits` number for preempted requests. + pub preempted_hits: u64, +} + +/// Single KV cache block eviction sample. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +pub struct KvCacheEvictionEvent { + /// Lifetime from allocation to eviction. + pub lifetime_seconds: f64, + /// Idle time observed before eviction. + pub idle_seconds: f64, + /// Time gaps between consecutive accesses before eviction. + pub reuse_gaps_seconds: Vec, +} + +/// Per-step iteration decoding stats from scheduler. +/// +/// Each scheduler step, statistics on spec decoding performance are aggregated +/// across requests by the scheduler and returned to the frontend in +/// `EngineCoreOutputs -> SchedulerStats`. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +pub struct SpecDecodingStats { + /// Configured speculative token count for this scheduler. + pub num_spec_tokens: u64, + /// Number of drafted speculative decoding attempts. + pub num_drafts: u64, + /// Number of drafted tokens. + pub num_draft_tokens: u64, + /// Number of accepted drafted tokens. + pub num_accepted_tokens: u64, + /// Accepted drafted tokens counted by draft position. + pub num_accepted_tokens_per_pos: Vec, +} + +/// Breakdown of a scheduled prefill computation. +/// +/// Python models this as a plain `@dataclass`, so it is serialized by msgspec +/// as a map (named fields) rather than in the array-like form used by +/// `EngineCoreOutput` itself. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct PrefillStats { + /// Total number of tokens to be prefilled. + #[serde(default)] + pub num_prompt_tokens: u32, + /// Tokens to be prefilled locally (actual compute work). + #[serde(default)] + pub num_computed_tokens: u32, + /// Tokens to be prefilled without actual compute work. + #[serde(default)] + pub num_cached_tokens: u32, + /// Tokens to be prefilled from local prefix cache. + #[serde(default)] + pub num_local_cached_tokens: u32, + /// Tokens to be prefilled from external KV transfer. + #[serde(default)] + pub num_external_cached_tokens: u32, +} + +/// Stats for debugging the metrics calculation. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +pub struct DebugPerfStats { + /// Time spent calculating these stats. + pub calc_duration: f64, + /// Number of prefill requests included in the sampled batch. + pub num_prefill_requests: u64, + /// Number of decode requests included in the sampled batch. + pub num_decode_requests: u64, + /// Optional execution-context breakdown used for debugging. + pub context_breakdown: Option>, + /// Optional per-component FLOPs breakdown. + pub num_flops_per_gpu_breakdown: Option>, + /// Optional per-component memory-read breakdown. + pub num_read_bytes_per_gpu_breakdown: Option>, + /// Optional per-component memory-write breakdown. + pub num_write_bytes_per_gpu_breakdown: Option>, +} + +/// Original Python definition: +/// +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +pub struct PerfStats { + /// Estimated floating point operations per GPU. + pub num_flops_per_gpu: u64, + /// Estimated bytes read from memory per GPU. + pub num_read_bytes_per_gpu: u64, + /// Estimated bytes written to memory per GPU. + pub num_write_bytes_per_gpu: u64, + /// Optional debug-only perf derivation details. + pub debug_stats: Option, +} + +/// Original Python definition: +/// +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +pub struct CudagraphStat { + /// Number of real tokens in the captured batch before padding. + pub num_unpadded_tokens: u64, + /// Number of padded tokens in the captured batch. + pub num_padded_tokens: u64, + /// Number of padding positions added for capture/runtime shape alignment. + pub num_paddings: u64, + /// Runtime mode string associated with this CUDA graph sample. + pub runtime_mode: String, +} + +/// Stats associated with the scheduler. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +pub struct SchedulerStats { + /// Number of requests in model execution batches. + pub num_running_reqs: u64, + /// Length of the "waiting" request queue. + pub num_waiting_reqs: u64, + /// Length of the "skipped waiting" queue. + #[serde(default)] + pub num_skipped_waiting_reqs: u64, + /// Internal DP load-balancing step counter. + pub step_counter: u64, + /// Internal DP load-balancing wave number. + pub current_wave: u64, + /// KV-cache usage. `1.0` means 100% usage. + pub kv_cache_usage: f64, + /// Local prefix cache statistics. + pub prefix_cache_stats: PrefixCacheStats, + /// External connector prefix cache statistics, when configured. + pub connector_prefix_cache_stats: Option, + /// Sampled KV cache eviction events for residency metrics. + pub kv_cache_eviction_events: Vec, + /// Speculative decoding scheduler stats, when enabled. + pub spec_decoding_stats: Option, + /// Connector-specific KV transfer stats, kept opaque for now. + pub kv_connector_stats: Option>, + /// Waiting request counts per LoRA adapter. + pub waiting_lora_adapters: BTreeMap, + /// Running request counts per LoRA adapter. + pub running_lora_adapters: BTreeMap, + /// CUDA graph runtime stats when graph metrics are enabled. + pub cudagraph_stats: Option, + /// Estimated MFU/performance stats, when enabled. + pub perf_stats: Option, +} diff --git a/rust/src/engine-core-client/src/protocol/tensor.rs b/rust/src/engine-core-client/src/protocol/tensor.rs new file mode 100644 index 00000000000..b6711215481 --- /dev/null +++ b/rust/src/engine-core-client/src/protocol/tensor.rs @@ -0,0 +1,255 @@ +use bytemuck::allocation::pod_collect_to_vec; +use enum_as_inner::EnumAsInner; +use half::{bf16, f16}; +use rmpv::Value; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_tuple::{Deserialize_tuple, Serialize_tuple}; + +/// Tensors and ndarrays are encoded with this extension type in Python. +/// +/// Original Python definition: +/// +const CUSTOM_TYPE_RAW_VIEW: i8 = 3; + +#[easy_ext::ext(ShapeExt)] +impl [usize] { + /// Returned the total number of elements implied by this shape, or `None` + /// if the product of the dimensions overflows `usize`. + pub fn checked_numel(&self) -> Option { + self.iter().try_fold(1usize, |acc, dim| acc.checked_mul(*dim)) + } +} + +/// Python ndarray/tensor wire tuple encoded as `(dtype, shape, data)`. +/// +/// This matches the custom msgpack representation built by Python +/// `serial_utils.encode_ndarray` / `encode_tensor`. +/// +/// Original Python wire encoders: +/// +/// +#[derive(Debug, Clone, PartialEq, Serialize_tuple, Deserialize_tuple)] +pub struct WireNdArray { + pub dtype: String, + pub shape: Vec, + pub data: WireArrayData, +} + +impl WireNdArray { + /// Build a float32 tensor/ndarray backed by native-endian raw-view bytes. + pub fn from_f32(shape: Vec, data: Vec) -> Result { + validate_element_count(&shape, data.len())?; + Ok(Self { + dtype: "float32".to_string(), + shape, + data: WireArrayData::RawView(pod_collect_to_vec::(&data)), + }) + } + + /// Build a float16 tensor/ndarray backed by native-endian raw-view bytes. + pub fn from_f16(shape: Vec, data: Vec) -> Result { + validate_element_count(&shape, data.len())?; + Ok(Self { + dtype: "float16".to_string(), + shape, + data: WireArrayData::RawView(pod_collect_to_vec::(&data)), + }) + } + + /// Build a bfloat16 tensor/ndarray backed by native-endian raw-view bytes. + pub fn from_bf16(shape: Vec, data: Vec) -> Result { + validate_element_count(&shape, data.len())?; + Ok(Self { + dtype: "bfloat16".to_string(), + shape, + data: WireArrayData::RawView(pod_collect_to_vec::(&data)), + }) + } + + /// Build an int64 tensor/ndarray backed by native-endian raw-view bytes. + pub fn from_i64(shape: Vec, data: Vec) -> Result { + validate_element_count(&shape, data.len())?; + Ok(Self { + dtype: "int64".to_string(), + shape, + data: WireArrayData::RawView(pod_collect_to_vec::(&data)), + }) + } + + /// Build a uint32 tensor/ndarray backed by native-endian raw-view bytes. + pub fn from_u32(shape: Vec, data: Vec) -> Result { + validate_element_count(&shape, data.len())?; + Ok(Self { + dtype: "uint32".to_string(), + shape, + data: WireArrayData::RawView(pod_collect_to_vec::(&data)), + }) + } + + /// Build a bool tensor/ndarray backed by raw-view bytes. + /// + /// This matches `torch.bool` storage: one byte per element, not a packed + /// bitmap. Values are canonicalized as `false -> 0` and `true -> 1`. + pub fn from_bool(shape: Vec, data: Vec) -> Result { + validate_element_count(&shape, data.len())?; + Ok(Self { + dtype: "bool".to_string(), + shape, + data: WireArrayData::RawView(data.iter().map(|value| u8::from(*value)).collect()), + }) + } + + /// Build a tensor/ndarray from already-encoded raw-view bytes. + /// + /// Use this as an escape hatch when the caller already owns bytes that + /// match the requested `dtype` and `shape`. + pub fn from_raw(dtype: impl Into, shape: Vec, data: Vec) -> Self { + Self { + dtype: dtype.into(), + shape, + data: WireArrayData::RawView(data), + } + } +} + +/// Validate that the number of elements implied by the shape matches the length +/// of the data. +fn validate_element_count(shape: &[usize], len: usize) -> Result<(), String> { + let expected = shape + .checked_numel() + .ok_or_else(|| format!("tensor shape product overflows usize: {shape:?}"))?; + if expected == len { + Ok(()) + } else { + Err(format!( + "tensor data length {len} does not match shape {shape:?} product {expected}" + )) + } +} + +/// Python tensor wire tuple encoded as `(dtype, shape, data)`. +/// +/// This is the same wire shape as [`WireNdArray`]; multimodal request payloads +/// use it for `torch.Tensor` values. +pub type WireTensor = WireNdArray; + +/// Python array/tensor payload reference inside [`WireNdArray`]. +/// +/// The data can be either an inline msgpack raw-view extension or an index into +/// the multipart aux-frame list carried alongside the primary msgpack frame. +/// +/// Original Python wire encoders: +/// +/// +#[derive(Debug, Clone, PartialEq, EnumAsInner)] +pub enum WireArrayData { + /// The index of the aux frame where the raw bytes of this array/tensor are + /// stored. + AuxIndex(usize), + /// The raw bytes of this array/tensor. + RawView(Vec), +} + +impl<'de> Deserialize<'de> for WireArrayData { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + match value { + Value::Ext(tag, bytes) if tag == CUSTOM_TYPE_RAW_VIEW => Ok(Self::RawView(bytes)), + Value::Ext(tag, _) => Err(serde::de::Error::custom(format!( + "unsupported extension type code {tag}" + ))), + Value::Integer(index) => { + index.as_u64().map(|index| Self::AuxIndex(index as usize)).ok_or_else(|| { + serde::de::Error::custom("aux frame index must be a non-negative integer") + }) + } + other => Err(serde::de::Error::custom(format!( + "expected raw-view ext or aux frame index, got {other:?}" + ))), + } + } +} + +impl Serialize for WireArrayData { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + // TODO: outbound request serialization currently only supports inline + // raw-view bytes. Emitting aux frames needs transport-level plumbing; + // serializing `AuxIndex` here only preserves an already-built reference. + match self { + Self::AuxIndex(index) => serializer.serialize_u64(*index as u64), + Self::RawView(bytes) => { + Value::Ext(CUSTOM_TYPE_RAW_VIEW, bytes.clone()).serialize(serializer) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn constructors_build_raw_view_tensors() { + let f32_tensor = WireNdArray::from_f32(vec![2], vec![1.0, 2.5]).unwrap(); + assert_eq!(f32_tensor.dtype, "float32"); + assert_eq!(f32_tensor.shape, vec![2]); + assert_eq!( + f32_tensor.data.into_raw_view().expect("raw view"), + [1.0_f32, 2.5].into_iter().flat_map(f32::to_ne_bytes).collect::>() + ); + + let f16_tensor = + WireNdArray::from_f16(vec![2], vec![f16::from_f32(1.0), f16::from_f32(2.5)]).unwrap(); + assert_eq!(f16_tensor.dtype, "float16"); + assert_eq!(f16_tensor.shape, vec![2]); + assert_eq!(f16_tensor.data.into_raw_view().expect("raw view").len(), 4); + + let bf16_tensor = + WireNdArray::from_bf16(vec![2], vec![bf16::from_f32(1.0), bf16::from_f32(2.5)]) + .unwrap(); + assert_eq!(bf16_tensor.dtype, "bfloat16"); + assert_eq!(bf16_tensor.shape, vec![2]); + assert_eq!(bf16_tensor.data.into_raw_view().expect("raw view").len(), 4); + + let i64_tensor = WireNdArray::from_i64(vec![1], vec![-7]).unwrap(); + assert_eq!(i64_tensor.dtype, "int64"); + assert_eq!( + i64_tensor.data.into_raw_view().expect("raw view"), + (-7_i64).to_ne_bytes() + ); + + let u32_tensor = WireNdArray::from_u32(vec![1], vec![42]).unwrap(); + assert_eq!(u32_tensor.dtype, "uint32"); + assert_eq!( + u32_tensor.data.into_raw_view().expect("raw view"), + 42_u32.to_ne_bytes() + ); + + let bool_tensor = WireNdArray::from_bool(vec![2], vec![false, true]).unwrap(); + assert_eq!(bool_tensor.dtype, "bool"); + assert_eq!( + bool_tensor.data.into_raw_view().expect("raw view"), + vec![0, 1] + ); + + let raw_tensor = WireNdArray::from_raw("custom", vec![3], vec![1, 2, 3]); + assert_eq!(raw_tensor.dtype, "custom"); + assert_eq!(raw_tensor.shape, vec![3]); + assert_eq!( + raw_tensor.data.into_raw_view().expect("raw view"), + vec![1, 2, 3] + ); + } + + #[test] + fn constructors_validate_shape_product() { + let err = WireNdArray::from_f32(vec![2, 2], vec![1.0, 2.0]).unwrap_err(); + assert!(err.contains("does not match shape")); + } +} diff --git a/rust/src/engine-core-client/src/test_utils.rs b/rust/src/engine-core-client/src/test_utils.rs new file mode 100644 index 00000000000..99c266dc626 --- /dev/null +++ b/rust/src/engine-core-client/src/test_utils.rs @@ -0,0 +1,295 @@ +use std::future::Future; +use std::path::Path; +use std::pin::Pin; +use std::time::Duration; + +use tempfile::TempDir; +use tokio::sync::oneshot; +use zeromq::prelude::{Socket, SocketRecv, SocketSend}; +use zeromq::util::PeerIdentity; +use zeromq::{DealerSocket, PushSocket, SocketOptions, SubSocket, ZmqMessage}; + +use crate::EngineId; +use crate::protocol::ModelDtype; +use crate::protocol::handshake::{EngineCoreReadyResponse, HandshakeInitMessage, ReadyMessage}; + +/// Per-test IPC endpoint namespace backed by a unique temporary directory. +/// +/// Using one directory per test avoids endpoint collisions without requiring +/// ad-hoc unique-name generation at each call site. +#[derive(Debug)] +pub struct IpcNamespace { + dir: TempDir, +} + +impl IpcNamespace { + /// Create a fresh namespace for one test case. + pub fn new() -> std::io::Result { + Ok(Self { + dir: TempDir::new()?, + }) + } + + /// Build one `ipc://...` endpoint under this namespace. + pub fn endpoint(&self, name: impl AsRef) -> String { + let path = self.dir.path().join(name); + format!("ipc://{}", path.to_string_lossy()) + } + + /// Endpoint used for the initial READY/HELLO handshake. + pub fn handshake_endpoint(&self) -> String { + self.endpoint("handshake.sock") + } + + /// Endpoint used for engine-core request traffic. + pub fn input_endpoint(&self) -> String { + self.endpoint("input.sock") + } + + /// Endpoint used for engine-core output traffic. + pub fn output_endpoint(&self) -> String { + self.endpoint("output.sock") + } +} + +/// Construct a standard local READY message used by mock engines in tests. +fn ready_message(status: &str) -> ReadyMessage { + ReadyMessage { + status: Some(status.to_string()), + local: Some(true), + headless: Some(true), + parallel_config_hash: None, + } +} + +/// Construct a default ready response payload for mock engine input +/// registration. +fn ready_response_payload() -> Vec { + rmp_serde::to_vec_named(&EngineCoreReadyResponse { + max_model_len: 4096, + num_gpu_blocks: 0, + dp_stats_address: None, + dtype: Some(ModelDtype::Float32), + }) + .expect("encode ready response payload") +} + +/// Coordinator-side sockets connected by one mock engine when coordinator mode +/// is enabled. +pub struct MockCoordinatorConnections { + /// Subscription socket that receives coordinator broadcasts such as + /// `START_DP_WAVE`. + pub input_sub: SubSocket, + /// Push socket used to send coordinator-only `EngineCoreOutputs` back to + /// the frontend. + pub output_push: PushSocket, +} + +/// Fully connected mock engine transport state used by tests. +pub struct MockEngineConnections { + /// Decoded INIT message sent by the frontend during handshake. + pub init: HandshakeInitMessage, + /// Socket used to receive frontend requests. + pub dealer: DealerSocket, + /// Socket used to publish normal request outputs back to the frontend. + pub push: PushSocket, + /// Optional coordinator sockets when the client enabled the in-process + /// coordinator. + pub coordinator: Option, +} + +/// Complete the engine-core handshake and connect mock input/output sockets +/// plus optional coordinator sockets. +pub async fn setup_mock_engine_connections( + engine_handshake: String, + engine_id: impl Into, +) -> MockEngineConnections { + // Wait for the client to bind the handshake socket before connecting. + // A fixed sleep is racy under CI load; instead poll for the socket file. + let socket_path = engine_handshake + .strip_prefix("ipc://") + .expect("handshake address must be ipc://"); + for _ in 0..100 { + if Path::new(socket_path).exists() { + break; + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + + let peer_identity = PeerIdentity::try_from(engine_id.into()).expect("peer id"); + + let mut options = SocketOptions::default(); + options.peer_identity(peer_identity.clone()); + let mut handshake = DealerSocket::with_options(options); + handshake + .connect(&engine_handshake) + .await + .expect("connect mock engine handshake socket"); + handshake + .send(ZmqMessage::from( + rmp_serde::to_vec_named(&ready_message("HELLO")).expect("encode HELLO ready message"), + )) + .await + .expect("send HELLO ready message"); + + let init_frames = handshake.recv().await.expect("receive handshake init message").into_vec(); + assert_eq!(init_frames.len(), 1); + let init: HandshakeInitMessage = + rmp_serde::from_slice(init_frames[0].as_ref()).expect("decode handshake init message"); + + let mut input_options = SocketOptions::default(); + input_options.peer_identity(peer_identity); + let mut dealer = DealerSocket::with_options(input_options); + dealer + .connect(&init.addresses.inputs[0]) + .await + .expect("connect mock engine input socket"); + dealer + .send(ZmqMessage::from(ready_response_payload())) + .await + .expect("send mock engine input ready frame"); + + let mut push = PushSocket::new(); + push.connect(&init.addresses.outputs[0]) + .await + .expect("connect mock engine output socket"); + + let coordinator = match ( + init.addresses.coordinator_input.as_deref(), + init.addresses.coordinator_output.as_deref(), + ) { + (Some(coordinator_input), Some(coordinator_output)) => { + let mut input_sub = SubSocket::new(); + input_sub + .connect(coordinator_input) + .await + .expect("connect mock engine coordinator input socket"); + input_sub + .subscribe("") + .await + .expect("subscribe mock engine coordinator input socket"); + + let mut output_push = PushSocket::new(); + output_push + .connect(coordinator_output) + .await + .expect("connect mock engine coordinator output socket"); + + let ready = + input_sub.recv().await.expect("receive coordinator READY marker").into_vec(); + assert_eq!(ready.len(), 1); + assert_eq!(ready[0].as_ref(), b"READY"); + + Some(MockCoordinatorConnections { + input_sub, + output_push, + }) + } + (None, None) => None, + _ => panic!("coordinator handshake addresses must be both present or both absent"), + }; + + handshake + .send(ZmqMessage::from( + rmp_serde::to_vec_named(&ready_message("READY")).expect("encode READY ready message"), + )) + .await + .expect("send READY ready message"); + + MockEngineConnections { + init, + dealer, + push, + coordinator, + } +} + +/// Connect one mock engine directly to already-bootstrapped frontend +/// input/output sockets. +pub async fn setup_bootstrapped_mock_engine( + input_address: String, + output_address: String, + engine_id: impl Into, +) -> (DealerSocket, PushSocket) { + for endpoint in [&input_address, &output_address] { + if let Some(socket_path) = endpoint.strip_prefix("ipc://") { + for _ in 0..100 { + if Path::new(socket_path).exists() { + break; + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + } + } + + let peer_identity = PeerIdentity::try_from(engine_id.into()).expect("peer id"); + let mut input_options = SocketOptions::default(); + input_options.peer_identity(peer_identity); + let mut dealer = DealerSocket::with_options(input_options); + dealer.connect(&input_address).await.expect("connect mock engine input socket"); + dealer + .send(ZmqMessage::from(ready_response_payload())) + .await + .expect("send mock engine input ready frame"); + + let mut push = PushSocket::new(); + push.connect(&output_address).await.expect("connect mock engine output socket"); + + (dealer, push) +} + +/// Complete the engine-core handshake and connect mock input/output sockets. +/// +/// This returns the decoded handshake init message plus the `DealerSocket` used +/// to receive client requests and the `PushSocket` used to send engine outputs +/// back to the client. +pub async fn setup_mock_engine_with_init( + engine_handshake: String, + engine_id: impl Into, +) -> (HandshakeInitMessage, DealerSocket, PushSocket) { + let MockEngineConnections { + init, dealer, push, .. + } = setup_mock_engine_connections(engine_handshake, engine_id).await; + (init, dealer, push) +} + +/// Complete the engine-core handshake and connect mock input/output sockets. +/// +/// This returns the `DealerSocket` used to receive client requests and the +/// `PushSocket` used to send engine outputs back to the client. +pub async fn setup_mock_engine( + engine_handshake: String, + engine_id: impl Into, +) -> (DealerSocket, PushSocket) { + let (_, dealer, push) = setup_mock_engine_with_init(engine_handshake, engine_id).await; + (dealer, push) +} + +/// Spawn a mock engine task and keep its sockets alive until the returned +/// shutdown sender is triggered by the test. +/// +/// The script borrows the connected sockets mutably while it runs. After the +/// script completes, this helper keeps the sockets alive until the test +/// explicitly signals shutdown. +pub fn spawn_mock_engine_task( + engine_handshake: String, + engine_id: impl Into, + run: F, +) -> (oneshot::Sender<()>, tokio::task::JoinHandle<()>) +where + F: for<'a> FnOnce( + &'a mut DealerSocket, + &'a mut PushSocket, + ) -> Pin + Send + 'a>> + + Send + + 'static, +{ + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let engine_id = engine_id.into(); + let engine_task = tokio::spawn(async move { + let (mut dealer, mut push) = setup_mock_engine(engine_handshake, engine_id).await; + run(&mut dealer, &mut push).await; + let _ = shutdown_rx.await; + }); + (shutdown_tx, engine_task) +} diff --git a/rust/src/engine-core-client/src/tests/client.rs b/rust/src/engine-core-client/src/tests/client.rs new file mode 100644 index 00000000000..5ad307e51de --- /dev/null +++ b/rust/src/engine-core-client/src/tests/client.rs @@ -0,0 +1,2559 @@ +use std::collections::{BTreeMap, BTreeSet}; +use std::convert::TryFrom; +use std::io::Cursor; +use std::path::PathBuf; +use std::process::Command; +use std::sync::Once; +use std::time::Duration; + +use futures::StreamExt; +use rmpv::Value; +use thiserror_ext::AsReport as _; +use tokio::sync::{mpsc, oneshot}; +use tokio::time::timeout; +use tracing_subscriber::EnvFilter; +use zeromq::prelude::{Socket, SocketRecv, SocketSend}; +use zeromq::util::PeerIdentity; +use zeromq::{DealerSocket, PushSocket, SocketOptions, SubSocket, XPubSocket, ZmqMessage}; + +use crate::protocol::handshake::{HandshakeInitMessage, ReadyMessage}; +use crate::protocol::logprobs::MaybeWireLogprobs; +use crate::protocol::multimodal::{ + MmFeatureSpec, MmField, MmFieldElem, MmFlatField, MmKwargValue, MmSlice, PlaceholderRange, + SliceSpec, +}; +use crate::protocol::stats::SchedulerStats; +use crate::protocol::tensor::WireTensor; +use crate::protocol::{ + EngineCoreFinishReason, EngineCoreOutput, EngineCoreOutputs, EngineCoreRequest, + EngineCoreRequestType, EngineCoreSamplingParams, UtilityOutput, UtilityResultEnvelope, + decode_engine_core_outputs, +}; +use crate::test_utils::{ + IpcNamespace, setup_bootstrapped_mock_engine, setup_mock_engine_connections, + setup_mock_engine_with_init, spawn_mock_engine_task, +}; +use crate::{ + CoordinatorMode, ENGINE_CORE_DEAD_SENTINEL, EngineCoreClient, EngineCoreClientConfig, EngineId, + Error, TransportMode, +}; + +static TRACING: Once = Once::new(); + +fn expect_sample_logprobs(actual: &MaybeWireLogprobs) { + expect_test::expect![[r#" + Logprobs { + positions: [ + PositionLogprobs { + entries: [ + TokenLogprob { + token_id: 1, + logprob: 1.0, + rank: 1, + }, + TokenLogprob { + token_id: 2, + logprob: 2.0, + rank: 1, + }, + TokenLogprob { + token_id: 3, + logprob: 3.0, + rank: 2, + }, + ], + }, + PositionLogprobs { + entries: [ + TokenLogprob { + token_id: 4, + logprob: 4.0, + rank: 2, + }, + TokenLogprob { + token_id: 5, + logprob: 5.0, + rank: 1, + }, + TokenLogprob { + token_id: 6, + logprob: 6.0, + rank: 2, + }, + ], + }, + ], + } + "#]] + .assert_debug_eq(actual.as_direct().expect("logprobs resolved")); +} + +fn expect_prompt_logprobs(actual: &MaybeWireLogprobs) { + expect_test::expect![[r#" + Logprobs { + positions: [ + PositionLogprobs { + entries: [ + TokenLogprob { + token_id: 10, + logprob: 10.0, + rank: 3, + }, + TokenLogprob { + token_id: 11, + logprob: 11.0, + rank: 1, + }, + TokenLogprob { + token_id: 12, + logprob: 12.0, + rank: 2, + }, + ], + }, + PositionLogprobs { + entries: [ + TokenLogprob { + token_id: 13, + logprob: 13.0, + rank: 4, + }, + TokenLogprob { + token_id: 14, + logprob: 14.0, + rank: 1, + }, + TokenLogprob { + token_id: 15, + logprob: 15.0, + rank: 2, + }, + ], + }, + ], + } + "#]] + .assert_debug_eq(actual.as_direct().expect("prompt logprobs resolved")); +} + +fn sample_request() -> EngineCoreRequest { + sample_request_with_id("req-1") +} + +fn sample_request_with_id(request_id: &str) -> EngineCoreRequest { + EngineCoreRequest { + request_id: request_id.to_string(), + prompt_token_ids: Some(vec![11, 22]), + sampling_params: Some(EngineCoreSamplingParams { + temperature: 0.8, + top_p: 0.9, + top_k: 8, + max_tokens: 32, + min_tokens: 1, + stop_token_ids: vec![151643], + eos_token_id: Some(151645), + all_stop_token_ids: BTreeSet::from([151643, 151645]), + ..EngineCoreSamplingParams::for_test() + }), + arrival_time: 42.5, + ..EngineCoreRequest::default() + } +} + +fn sample_multimodal_request() -> EngineCoreRequest { + EngineCoreRequest { + request_id: "req-mm".to_string(), + prompt_token_ids: Some(vec![101, 102, 103, 104]), + mm_features: Some(vec![MmFeatureSpec { + data: Some(BTreeMap::from([( + "pixel_values".to_string(), + MmFieldElem { + data: Some(MmKwargValue::Tensor( + WireTensor::from_f32(vec![2, 2], vec![1.0, 2.0, 3.5, 4.25]) + .expect("valid tensor shape"), + )), + field: MmField::Flat(MmFlatField { + slices: vec![MmSlice::Slice(SliceSpec { + start: Some(0), + stop: Some(2), + step: None, + })], + dim: 0, + keep_on_cpu: false, + }), + }, + )])), + modality: "image".to_string(), + identifier: "mm-cache-key".to_string(), + mm_position: PlaceholderRange { + offset: 1, + length: 2, + is_embed: None, + }, + mm_hash: Some("processor-hash".to_string()), + }]), + sampling_params: None, + pooling_params: None, + arrival_time: 43.5, + ..EngineCoreRequest::default() + } +} + +fn ready_message(status: &str) -> ReadyMessage { + ReadyMessage { + status: Some(status.to_string()), + local: Some(true), + headless: Some(true), + parallel_config_hash: None, + } +} + +fn request_output( + request_id: &str, + new_token_ids: Vec, + finish_reason: Option, +) -> EngineCoreOutput { + EngineCoreOutput { + request_id: request_id.to_string(), + new_token_ids, + new_logprobs: None, + new_prompt_logprobs_tensors: None, + pooling_output: None, + finish_reason, + stop_reason: None, + events: None, + kv_transfer_params: None, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + } +} + +async fn send_outputs(push: &mut PushSocket, outputs: EngineCoreOutputs) { + push.send(ZmqMessage::from(rmp_serde::to_vec_named(&outputs).unwrap())) + .await + .unwrap(); +} + +async fn send_output_frames(push: &mut PushSocket, frames: Vec) { + push.send(ZmqMessage::try_from(frames).unwrap()).await.unwrap(); +} + +async fn recv_engine_message(dealer: &mut DealerSocket) -> Vec { + dealer.recv().await.unwrap().into_vec() +} + +async fn recv_start_dp_wave(sub: &mut SubSocket) -> (u32, u32) { + let frames = sub.recv().await.unwrap().into_vec(); + assert_eq!(frames.len(), 2); + assert_eq!( + frames[0].as_ref(), + EngineCoreRequestType::StartDpWave.to_frame().as_ref() + ); + rmp_serde::from_slice(&frames[1]).expect("decode START_DP_WAVE payload") +} + +async fn connect_client_with_ipc( + config: EngineCoreClientConfig, + ipc: &IpcNamespace, +) -> EngineCoreClient { + EngineCoreClient::connect( + config.with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .unwrap() +} + +fn handshake_test_config( + handshake_address: String, + engine_count: usize, + model_name: &str, + ready_timeout: Duration, + client_index: u32, + coordinator_mode: Option, +) -> EngineCoreClientConfig { + EngineCoreClientConfig { + transport_mode: TransportMode::HandshakeOwner { + handshake_address, + advertised_host: "127.0.0.1".to_string(), + engine_count, + ready_timeout, + local_input_address: None, + local_output_address: None, + }, + coordinator_mode, + model_name: model_name.to_string(), + client_index, + } +} + +fn bootstrapped_test_config( + input_address: String, + output_address: String, + engine_count: usize, + ready_timeout: Duration, + client_index: u32, + coordinator_mode: Option, +) -> EngineCoreClientConfig { + EngineCoreClientConfig { + transport_mode: TransportMode::Bootstrapped { + input_address, + output_address, + engine_count, + ready_timeout, + }, + coordinator_mode, + model_name: "test-model".to_string(), + client_index, + } +} + +async fn recv_xpub_message(xpub: &mut XPubSocket) -> Vec { + xpub.recv().await.unwrap().into_vec() +} + +async fn recv_xpub_subscription(xpub: &mut XPubSocket) { + let frames = recv_xpub_message(xpub).await; + assert_eq!(frames.len(), 1); + assert_eq!(frames[0].as_ref(), b"\x01"); +} + +async fn recv_external_coordinator_wakeup(xpub: &mut XPubSocket) -> (u32, u32) { + let frames = recv_xpub_message(xpub).await; + assert_eq!(frames.len(), 1); + rmp_serde::from_slice(&frames[0]).expect("decode external coordinator wakeup") +} + +async fn send_external_coordinator_publish( + xpub: &mut XPubSocket, + payload: &T, +) { + xpub.send(ZmqMessage::from(rmp_serde::to_vec_named(payload).unwrap())) + .await + .unwrap(); +} + +fn spawn_mock_engine_task_with_init( + engine_handshake: String, + engine_id: impl Into, + run: F, +) -> ( + oneshot::Receiver, + oneshot::Sender<()>, + tokio::task::JoinHandle<()>, +) +where + F: for<'a> FnOnce( + &'a mut DealerSocket, + &'a mut PushSocket, + ) + -> std::pin::Pin + Send + 'a>> + + Send + + 'static, +{ + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let (init_tx, init_rx) = oneshot::channel(); + let engine_id = engine_id.into(); + let engine_task = tokio::spawn(async move { + let (init, mut dealer, mut push) = + setup_mock_engine_with_init(engine_handshake, engine_id).await; + let _ = init_tx.send(init); + run(&mut dealer, &mut push).await; + let _ = shutdown_rx.await; + }); + (init_rx, shutdown_tx, engine_task) +} + +fn init_tracing() { + TRACING.call_once(|| { + let filter = EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new("vllm_engine_core_client=debug")); + let _ = tracing_subscriber::fmt().with_test_writer().with_env_filter(filter).try_init(); + }); +} + +fn is_dispatcher_closed(error: &Error) -> bool { + match error { + Error::DispatcherClosed { .. } => true, + Error::Shared(error) => is_dispatcher_closed(error), + _ => false, + } +} + +fn is_engine_core_dead(error: &Error) -> bool { + match error { + Error::EngineCoreDead => true, + Error::Shared(error) => is_engine_core_dead(error), + _ => false, + } +} + +fn is_decode_error(error: &Error) -> bool { + match error { + Error::Decode { .. } | Error::ExtValueDecode { .. } => true, + Error::Shared(error) => is_decode_error(error), + _ => false, + } +} + +fn is_unexpected_dispatcher_output(error: &Error) -> bool { + match error { + Error::UnexpectedDispatcherOutput { .. } => true, + Error::Shared(error) => is_unexpected_dispatcher_output(error), + _ => false, + } +} + +fn decode_value(bytes: &[u8]) -> Value { + rmpv::decode::read_value(&mut Cursor::new(bytes)).unwrap() +} + +fn encode_value(value: &Value) -> Vec { + let mut out = Vec::new(); + rmpv::encode::write_value(&mut out, value).unwrap(); + out +} + +fn ndarray_value(dtype: &str, shape: &[usize], data: Value) -> Value { + Value::Array(vec![ + Value::from(dtype), + Value::Array(shape.iter().copied().map(Value::from).collect()), + data, + ]) +} + +fn multipart_logprob_output_frames(request_id: &str) -> Vec { + let main = Value::Array(vec![ + Value::from(0), + Value::Array(vec![Value::Array(vec![ + Value::from(request_id), + Value::Array(vec![Value::from(7), Value::from(8)]), + Value::Array(vec![ + ndarray_value("(value: T) -> UtilityResultEnvelope +where + T: serde::Serialize, +{ + UtilityResultEnvelope::without_type_info(rmpv::ext::to_value(value).unwrap()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn coordinator_handshake_includes_engine_control_addresses() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = [0x00, 0x00]; + + let (init_tx, init_rx) = oneshot::channel(); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let engine_task = tokio::spawn(async move { + let connections = setup_mock_engine_connections(handshake_address, &engine_id).await; + let _ = init_tx.send(connections.init.clone()); + let _ = shutdown_rx.await; + }); + + let client = connect_client_with_ipc( + handshake_test_config( + ipc.handshake_endpoint(), + 1, + "test-model", + Duration::from_secs(2), + 0, + Some(CoordinatorMode::InProc), + ), + &ipc, + ) + .await; + + let init = init_rx.await.unwrap(); + assert!(init.addresses.coordinator_input.is_some()); + assert!(init.addresses.coordinator_output.is_some()); + assert!(init.addresses.frontend_stats_publish_address.is_none()); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn coordinator_wave_control_tracks_pause_running_and_rebroadcasts() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + + let (shutdown0_tx, shutdown0_rx) = oneshot::channel(); + let engine0_task = tokio::spawn({ + let handshake_address = handshake_address.clone(); + async move { + let mut engine = setup_mock_engine_connections(handshake_address, &[0x00, 0x00]).await; + let mut coordinator = + engine.coordinator.take().expect("coordinator sockets should be present"); + + let (wave, exclude_engine) = recv_start_dp_wave(&mut coordinator.input_sub).await; + assert_eq!((wave, exclude_engine), (0, 0)); + + let add = recv_engine_message(&mut engine.dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + assert_eq!(request.request_id, "req-1"); + assert_eq!(request.current_wave, 0); + + assert!( + timeout( + Duration::from_millis(200), + recv_start_dp_wave(&mut coordinator.input_sub) + ) + .await + .is_err() + ); + + send_outputs( + &mut engine.push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![request_output( + "req-1", + vec![], + Some(EngineCoreFinishReason::Length), + )], + finished_requests: Some(BTreeSet::from(["req-1".to_string()])), + ..Default::default() + }, + ) + .await; + + send_outputs( + &mut coordinator.output_push, + EngineCoreOutputs { + engine_index: 0, + wave_complete: Some(0), + ..Default::default() + }, + ) + .await; + + let (wave, exclude_engine) = recv_start_dp_wave(&mut coordinator.input_sub).await; + assert_eq!((wave, exclude_engine), (1, 0)); + + let add = recv_engine_message(&mut engine.dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + assert_eq!(request.request_id, "req-3"); + assert_eq!(request.current_wave, 1); + + send_outputs( + &mut engine.push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![request_output( + "req-3", + vec![], + Some(EngineCoreFinishReason::Length), + )], + finished_requests: Some(BTreeSet::from(["req-3".to_string()])), + ..Default::default() + }, + ) + .await; + + let _ = shutdown0_rx.await; + } + }); + + let (shutdown1_tx, shutdown1_rx) = oneshot::channel(); + let engine1_task = tokio::spawn({ + let handshake_address = handshake_address.clone(); + async move { + let mut engine = setup_mock_engine_connections(handshake_address, &[0x01, 0x00]).await; + let mut coordinator = + engine.coordinator.take().expect("coordinator sockets should be present"); + + let (wave, exclude_engine) = recv_start_dp_wave(&mut coordinator.input_sub).await; + assert_eq!((wave, exclude_engine), (0, 0)); + + let add = recv_engine_message(&mut engine.dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + assert_eq!(request.request_id, "req-2"); + assert_eq!(request.current_wave, 0); + + assert!( + timeout( + Duration::from_millis(200), + recv_start_dp_wave(&mut coordinator.input_sub) + ) + .await + .is_err() + ); + + send_outputs( + &mut engine.push, + EngineCoreOutputs { + engine_index: 1, + outputs: vec![request_output( + "req-2", + vec![], + Some(EngineCoreFinishReason::Length), + )], + finished_requests: Some(BTreeSet::from(["req-2".to_string()])), + ..Default::default() + }, + ) + .await; + + let (wave, exclude_engine) = recv_start_dp_wave(&mut coordinator.input_sub).await; + assert_eq!((wave, exclude_engine), (1, 0)); + + assert!( + timeout( + Duration::from_millis(200), + recv_engine_message(&mut engine.dealer) + ) + .await + .is_err() + ); + + let _ = shutdown1_rx.await; + } + }); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 2, + "test-model", + Duration::from_secs(2), + 0, + Some(CoordinatorMode::InProc), + ), + &ipc, + ) + .await; + + let mut stream_1 = client.call(sample_request_with_id("req-1")).await.unwrap(); + let mut stream_2 = client.call(sample_request_with_id("req-2")).await.unwrap(); + + let final_1 = timeout(Duration::from_secs(1), stream_1.next()) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(final_1.request_id, "req-1"); + assert_eq!(final_1.finish_reason, Some(EngineCoreFinishReason::Length)); + assert!(timeout(Duration::from_secs(1), stream_1.next()).await.unwrap().is_none()); + + let final_2 = timeout(Duration::from_secs(1), stream_2.next()) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(final_2.request_id, "req-2"); + assert_eq!(final_2.finish_reason, Some(EngineCoreFinishReason::Length)); + assert!(timeout(Duration::from_secs(1), stream_2.next()).await.unwrap().is_none()); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let mut stream_3 = client.call(sample_request_with_id("req-3")).await.unwrap(); + let final_3 = timeout(Duration::from_secs(1), stream_3.next()) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(final_3.request_id, "req-3"); + assert_eq!(final_3.finish_reason, Some(EngineCoreFinishReason::Length)); + assert!(timeout(Duration::from_secs(1), stream_3.next()).await.unwrap().is_none()); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let _ = shutdown0_tx.send(()); + let _ = shutdown1_tx.send(()); + engine0_task.await.unwrap(); + engine1_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn coordinator_rebroadcasts_engine_start_wave_control() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + + let (shutdown0_tx, shutdown0_rx) = oneshot::channel(); + let engine0_task = tokio::spawn({ + let handshake_address = handshake_address.clone(); + async move { + let mut engine = setup_mock_engine_connections(handshake_address, &[0x00, 0x00]).await; + let mut coordinator = + engine.coordinator.take().expect("coordinator sockets should be present"); + + let (wave, exclude_engine) = recv_start_dp_wave(&mut coordinator.input_sub).await; + assert_eq!((wave, exclude_engine), (4, 1)); + + let _ = shutdown0_rx.await; + } + }); + + let (shutdown1_tx, shutdown1_rx) = oneshot::channel(); + let engine1_task = tokio::spawn({ + let handshake_address = handshake_address.clone(); + async move { + let mut engine = setup_mock_engine_connections(handshake_address, &[0x01, 0x00]).await; + let mut coordinator = + engine.coordinator.take().expect("coordinator sockets should be present"); + + send_outputs( + &mut coordinator.output_push, + EngineCoreOutputs { + engine_index: 1, + start_wave: Some(4), + ..Default::default() + }, + ) + .await; + + let (wave, exclude_engine) = recv_start_dp_wave(&mut coordinator.input_sub).await; + assert_eq!((wave, exclude_engine), (4, 1)); + + let _ = shutdown1_rx.await; + } + }); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 2, + "test-model", + Duration::from_secs(2), + 0, + Some(CoordinatorMode::InProc), + ), + &ipc, + ) + .await; + + tokio::time::sleep(Duration::from_millis(200)).await; + + let _ = shutdown0_tx.send(()); + let _ = shutdown1_tx.send(()); + engine0_task.await.unwrap(); + engine1_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn coordinator_accepts_stats_only_outputs() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let engine_task = tokio::spawn(async move { + let mut engine = setup_mock_engine_connections(handshake_address, &[0x00, 0x00]).await; + let mut coordinator = + engine.coordinator.take().expect("coordinator sockets should be present"); + + let (wave, exclude_engine) = recv_start_dp_wave(&mut coordinator.input_sub).await; + assert_eq!((wave, exclude_engine), (0, 0)); + + send_outputs( + &mut coordinator.output_push, + EngineCoreOutputs { + engine_index: 0, + scheduler_stats: Some(Box::new(SchedulerStats { + num_running_reqs: 1, + current_wave: 0, + ..Default::default() + })), + ..Default::default() + }, + ) + .await; + + let add = recv_engine_message(&mut engine.dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + assert_eq!(request.request_id, "req-stats"); + + send_outputs( + &mut engine.push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![request_output( + "req-stats", + vec![], + Some(EngineCoreFinishReason::Length), + )], + finished_requests: Some(BTreeSet::from(["req-stats".to_string()])), + ..Default::default() + }, + ) + .await; + + let _ = shutdown_rx.await; + }); + + let client = connect_client_with_ipc( + handshake_test_config( + ipc.handshake_endpoint(), + 1, + "test-model", + Duration::from_secs(2), + 0, + Some(CoordinatorMode::InProc), + ), + &ipc, + ) + .await; + + let mut stream = client.call(sample_request_with_id("req-stats")).await.unwrap(); + let final_output = + timeout(Duration::from_secs(1), stream.next()).await.unwrap().unwrap().unwrap(); + assert_eq!(final_output.request_id, "req-stats"); + assert_eq!( + final_output.finish_reason, + Some(EngineCoreFinishReason::Length) + ); + assert!(client.is_healthy()); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_fail_closes_when_main_output_path_receives_dp_control() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-0".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add_1 = recv_engine_message(dealer).await; + assert_eq!(add_1[0].as_ref(), &[0x00]); + let request_1: EngineCoreRequest = rmp_serde::from_slice(&add_1[1]).unwrap(); + assert_eq!(request_1.client_index, 7); + assert_eq!(request_1.request_id, "req-1"); + + let add_2 = recv_engine_message(dealer).await; + assert_eq!(add_2[0].as_ref(), &[0x00]); + let request_2: EngineCoreRequest = rmp_serde::from_slice(&add_2[1]).unwrap(); + assert_eq!(request_2.client_index, 7); + assert_eq!(request_2.request_id, "req-2"); + + send_outputs( + push, + EngineCoreOutputs { + utility_output: Some(UtilityOutput { + call_id: 1, + failure_message: None, + result: None, + }), + ..Default::default() + }, + ) + .await; + send_outputs( + push, + EngineCoreOutputs { + start_wave: Some(3), + ..Default::default() + }, + ) + .await; + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![request_output("req-1", vec![999], None)], + ..Default::default() + }, + ) + .await; + + tokio::time::sleep(Duration::from_millis(50)).await; + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 1, + "test-model", + Duration::from_secs(2), + 7, + None, + ), + &ipc, + ) + .await; + assert_eq!(client.engine_identities()[0], b"engine-0"); + assert!(client.ready_responses()[0].max_model_len > 0); + + let mut stream_1 = client.call(sample_request_with_id("req-1")).await.unwrap(); + let mut stream_2 = client.call(sample_request_with_id("req-2")).await.unwrap(); + + let error_2 = timeout(Duration::from_secs(1), stream_2.next()) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!(is_unexpected_dispatcher_output(&error_2)); + + let error_1 = timeout(Duration::from_secs(1), stream_1.next()) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!(is_unexpected_dispatcher_output(&error_1)); + + assert!(matches!( + client.health_error().as_deref(), + Some(error) if is_unexpected_dispatcher_output(error) + )); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_fail_closes_when_main_output_path_receives_mixed_shape_output() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-0".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add_1 = recv_engine_message(dealer).await; + assert_eq!(add_1[0].as_ref(), &[0x00]); + let request_1: EngineCoreRequest = rmp_serde::from_slice(&add_1[1]).unwrap(); + assert_eq!(request_1.client_index, 7); + assert_eq!(request_1.request_id, "req-1"); + + let add_2 = recv_engine_message(dealer).await; + assert_eq!(add_2[0].as_ref(), &[0x00]); + let request_2: EngineCoreRequest = rmp_serde::from_slice(&add_2[1]).unwrap(); + assert_eq!(request_2.client_index, 7); + assert_eq!(request_2.request_id, "req-2"); + + send_outputs( + push, + EngineCoreOutputs { + utility_output: Some(UtilityOutput { + call_id: 1, + failure_message: None, + result: None, + }), + outputs: vec![request_output("req-1", vec![999], None)], + ..Default::default() + }, + ) + .await; + + tokio::time::sleep(Duration::from_millis(50)).await; + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 1, + "test-model", + Duration::from_secs(2), + 7, + None, + ), + &ipc, + ) + .await; + assert_eq!(client.engine_identities()[0], b"engine-0"); + assert!(client.ready_responses()[0].max_model_len > 0); + + let mut stream_1 = client.call(sample_request_with_id("req-1")).await.unwrap(); + let mut stream_2 = client.call(sample_request_with_id("req-2")).await.unwrap(); + + let error_2 = timeout(Duration::from_secs(1), stream_2.next()) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!(is_unexpected_dispatcher_output(&error_2)); + + let error_1 = timeout(Duration::from_secs(1), stream_1.next()) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!(is_unexpected_dispatcher_output(&error_1)); + + assert!(matches!( + client.health_error().as_deref(), + Some(error) if is_unexpected_dispatcher_output(error) + )); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn duplicate_request_ids_are_rejected_without_sending_a_second_add() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-dup".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add_1 = recv_engine_message(dealer).await; + assert_eq!(add_1[0].as_ref(), &[0x00]); + let request_1: EngineCoreRequest = rmp_serde::from_slice(&add_1[1]).unwrap(); + assert_eq!(request_1.request_id, "req-1"); + + assert!(timeout(Duration::from_millis(200), dealer.recv()).await.is_err()); + + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![request_output( + "req-1", + vec![], + Some(EngineCoreFinishReason::Length), + )], + finished_requests: Some(BTreeSet::from(["req-1".to_string()])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 1, + "test-model", + Duration::from_secs(2), + 0, + None, + ), + &ipc, + ) + .await; + + let mut stream = client.call(sample_request()).await.unwrap(); + let error = match client.call(sample_request()).await { + Ok(_) => panic!("expected duplicate request error"), + Err(error) => error, + }; + assert!(matches!( + error, + Error::DuplicateRequestId { request_id } if request_id == "req-1" + )); + + let final_output = + timeout(Duration::from_secs(1), stream.next()).await.unwrap().unwrap().unwrap(); + assert_eq!( + final_output.finish_reason, + Some(EngineCoreFinishReason::Length) + ); + assert!(timeout(Duration::from_secs(1), stream.next()).await.unwrap().is_none()); + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn finished_requests_without_final_output_is_treated_as_unexpected_close() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-finished-only".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add = recv_engine_message(dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + + send_outputs( + push, + EngineCoreOutputs { + finished_requests: Some(BTreeSet::from(["req-1".to_string()])), + ..Default::default() + }, + ) + .await; + + assert!(timeout(Duration::from_millis(200), dealer.recv()).await.is_err()); + let _ = push; + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 1, + "test-model", + Duration::from_secs(2), + 0, + None, + ), + &ipc, + ) + .await; + + let mut stream = client.call(sample_request()).await.unwrap(); + let error = timeout(Duration::from_secs(1), stream.next()) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!(matches!( + error, + Error::RequestStreamClosed { request_id } if request_id == "req-1" + )); + assert!(timeout(Duration::from_secs(1), stream.next()).await.unwrap().is_none()); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn dropping_a_live_stream_triggers_abort() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-drop".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add = recv_engine_message(dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![request_output("req-1", vec![99], None)], + ..Default::default() + }, + ) + .await; + + let abort = + timeout(Duration::from_secs(1), recv_engine_message(dealer)).await.unwrap(); + assert_eq!(abort[0].as_ref(), &[0x01]); + let aborted_ids: Vec = rmp_serde::from_slice(&abort[1]).unwrap(); + assert_eq!(aborted_ids, vec!["req-1".to_string()]); + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 1, + "test-model", + Duration::from_secs(2), + 0, + None, + ), + &ipc, + ) + .await; + + let mut stream = client.call(sample_request()).await.unwrap(); + let first = timeout(Duration::from_secs(1), stream.next()).await.unwrap().unwrap().unwrap(); + assert_eq!(first.new_token_ids, vec![99]); + drop(stream); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn dispatcher_failure_propagates_to_streams_and_future_calls() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-fail".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let _ = recv_engine_message(dealer).await; + let _ = recv_engine_message(dealer).await; + + push.send(ZmqMessage::from(vec![0xc1])).await.unwrap(); + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 1, + "test-model", + Duration::from_secs(2), + 0, + None, + ), + &ipc, + ) + .await; + + let mut stream_1 = client.call(sample_request_with_id("req-1")).await.unwrap(); + let mut stream_2 = client.call(sample_request_with_id("req-2")).await.unwrap(); + + let error_1 = timeout(Duration::from_secs(1), stream_1.next()) + .await + .unwrap() + .unwrap() + .unwrap_err(); + let error_2 = timeout(Duration::from_secs(1), stream_2.next()) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!(is_decode_error(&error_1)); + assert!(is_decode_error(&error_2)); + assert!(is_decode_error( + client.health_error().as_deref().expect("health error recorded") + )); + + let abort_error = client.abort(&["req-1".to_string()]).await.unwrap_err(); + assert!(is_decode_error(&abort_error)); + + let add_error = match client.call(sample_request_with_id("req-3")).await { + Ok(_) => panic!("expected dispatcher closed error"), + Err(error) => error, + }; + assert!(is_decode_error(&add_error)); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn is_sleeping_wrapper_sends_typed_request_and_returns_typed_response() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-utility-success".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let utility = recv_engine_message(dealer).await; + assert_eq!(utility[0].as_ref(), &[0x03]); + + let payload = decode_value(&utility[1]); + let array = match payload { + Value::Array(array) => array, + other => panic!("expected utility payload array, got {other:?}"), + }; + assert_eq!(array.len(), 4); + assert_eq!(array[0], Value::from(5)); + let call_id = array[1].as_i64().expect("call_id"); + assert_eq!(array[2], Value::from("is_sleeping")); + assert_eq!(array[3], Value::Array(Vec::new())); + + send_outputs( + push, + EngineCoreOutputs { + utility_output: Some(UtilityOutput { + call_id, + failure_message: None, + result: Some(utility_result_value(true)), + }), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 1, + "test-model", + Duration::from_secs(2), + 5, + None, + ), + &ipc, + ) + .await; + + let result = client.is_sleeping().await.unwrap(); + assert!(result); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn call_utility_failure_message_surfaces_as_error() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-utility-fail".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let utility = recv_engine_message(dealer).await; + assert_eq!(utility[0].as_ref(), &[0x03]); + let payload = decode_value(&utility[1]); + let call_id = + payload.as_array().and_then(|array| array[1].as_i64()).expect("call_id"); + + send_outputs( + push, + EngineCoreOutputs { + utility_output: Some(UtilityOutput { + call_id, + failure_message: Some("boom".to_string()), + result: None, + }), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 1, + "test-model", + Duration::from_secs(2), + 0, + None, + ), + &ipc, + ) + .await; + + let error = client.call_utility::("is_sleeping", ()).await.unwrap_err(); + assert!(matches!( + error, + Error::UtilityCallFailed { + method, + message, + .. + } if method == "is_sleeping" && message == "boom" + )); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn dispatcher_failure_propagates_to_waiting_utility_calls() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-utility-dispatcher-fail".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let utility = recv_engine_message(dealer).await; + assert_eq!(utility[0].as_ref(), &[0x03]); + + push.send(ZmqMessage::from(vec![0xc1])).await.unwrap(); + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 1, + "test-model", + Duration::from_secs(2), + 0, + None, + ), + &ipc, + ) + .await; + + let error = client.call_utility::("is_sleeping", ()).await.unwrap_err(); + assert!(is_decode_error(&error)); + assert!(is_decode_error( + client.health_error().as_deref().expect("health error recorded") + )); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn connect_times_out_without_ready_message() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_handshake = handshake_address.clone(); + let engine_task = tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(50)).await; + + let mut options = SocketOptions::default(); + options.peer_identity(PeerIdentity::try_from(b"engine-timeout".to_vec()).unwrap()); + let mut handshake = DealerSocket::with_options(options); + handshake.connect(&engine_handshake).await.unwrap(); + handshake + .send(ZmqMessage::from( + rmp_serde::to_vec_named(&ready_message("HELLO")).unwrap(), + )) + .await + .unwrap(); + + let _ = handshake.recv().await.unwrap(); + }); + + let result = EngineCoreClient::connect( + handshake_test_config( + handshake_address, + 1, + "test-model", + Duration::from_millis(100), + 0, + None, + ) + .with_local_input_output_addresses(Some(ipc.input_endpoint()), Some(ipc.output_endpoint())), + ) + .await; + + let error = match result { + Ok(_) => panic!("expected ready timeout"), + Err(error) => error, + }; + + let message = error.to_report_string(); + assert!(message.contains("timed out")); + assert!(message.contains("READY")); + engine_task.await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn engine_core_dead_sentinel_marks_client_unhealthy_and_sticks() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-dead".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |_dealer, push| { + Box::pin(async move { + push.send(ZmqMessage::from(ENGINE_CORE_DEAD_SENTINEL.to_vec())).await.unwrap(); + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 1, + "test-model", + Duration::from_secs(2), + 0, + None, + ), + &ipc, + ) + .await; + + timeout(Duration::from_secs(2), async { + while client.is_healthy() { + tokio::task::yield_now().await; + } + }) + .await + .expect("wait for unhealthy client"); + + assert!(!client.is_healthy()); + assert!(matches!( + client.health_error().as_deref(), + Some(Error::EngineCoreDead) + )); + + let error = client.call_utility::("is_sleeping", ()).await.unwrap_err(); + assert!( + is_dispatcher_closed(&error) || is_engine_core_dead(&error), + "unexpected error: {error:?}" + ); + assert!(matches!( + client.health_error().as_deref(), + Some(Error::EngineCoreDead) + )); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn output_loop_failure_marks_client_unhealthy_and_records_first_error() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-output-failure".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |_dealer, push| { + Box::pin(async move { + send_output_frames( + push, + vec![ + bytes::Bytes::from_static(b"frame-1"), + bytes::Bytes::from_static(b"frame-2"), + ], + ) + .await; + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 1, + "test-model", + Duration::from_secs(2), + 0, + None, + ), + &ipc, + ) + .await; + + timeout(Duration::from_secs(2), async { + while client.is_healthy() { + let _ = client.call_utility::("is_sleeping", ()).await; + tokio::task::yield_now().await; + } + }) + .await + .expect("wait for unhealthy client"); + + assert!(!client.is_healthy()); + assert!(is_decode_error( + client.health_error().as_deref().expect("health error recorded") + )); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_decodes_multipart_logprob_outputs() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-multipart-logprobs".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add = recv_engine_message(dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + assert_eq!(request.request_id, "req-1"); + + send_output_frames(push, multipart_logprob_output_frames("req-1")).await; + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 1, + "test-model", + Duration::from_secs(2), + 0, + None, + ), + &ipc, + ) + .await; + + let stream = client.call(sample_request()).await.unwrap(); + let outputs = stream.collect::>().await; + assert_eq!(outputs.len(), 1); + + let output = outputs.into_iter().next().unwrap().unwrap(); + assert_eq!(output.output.new_token_ids, vec![7, 8]); + assert_eq!( + output.output.finish_reason, + Some(EngineCoreFinishReason::Length) + ); + expect_sample_logprobs(output.output.new_logprobs.as_ref().expect("logprobs decoded")); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn multi_engine_client_shares_transport_and_routes_by_inflight_count() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let (engine_0_seen_tx, mut engine_0_seen_rx) = mpsc::unbounded_channel(); + let (engine_1_seen_tx, engine_1_seen_rx) = oneshot::channel(); + let (finish_req_1_tx, finish_req_1_rx) = oneshot::channel(); + let (finish_req_2_tx, finish_req_2_rx) = oneshot::channel(); + let (finish_req_3_tx, finish_req_3_rx) = oneshot::channel(); + + let (init_rx_0, shutdown_tx_0, engine_task_0) = spawn_mock_engine_task_with_init( + handshake_address.clone(), + b"engine-0".to_vec(), + |dealer, push| { + Box::pin(async move { + let add_1 = recv_engine_message(dealer).await; + assert_eq!(add_1[0].as_ref(), &[0x00]); + let request_1: EngineCoreRequest = rmp_serde::from_slice(&add_1[1]).unwrap(); + assert_eq!(request_1.request_id, "req-1"); + engine_0_seen_tx.send(request_1.request_id.clone()).unwrap(); + finish_req_1_rx.await.unwrap(); + send_outputs( + push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![request_output( + &request_1.request_id, + vec![10], + Some(EngineCoreFinishReason::Length), + )], + finished_requests: Some(BTreeSet::from([request_1.request_id.clone()])), + ..Default::default() + }, + ) + .await; + + let add_3 = recv_engine_message(dealer).await; + assert_eq!(add_3[0].as_ref(), &[0x00]); + let request_3: EngineCoreRequest = rmp_serde::from_slice(&add_3[1]).unwrap(); + assert_eq!(request_3.request_id, "req-3"); + engine_0_seen_tx.send(request_3.request_id.clone()).unwrap(); + finish_req_3_rx.await.unwrap(); + send_outputs( + push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![request_output( + &request_3.request_id, + vec![30], + Some(EngineCoreFinishReason::Length), + )], + finished_requests: Some(BTreeSet::from([request_3.request_id.clone()])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + let (init_rx_1, shutdown_tx_1, engine_task_1) = spawn_mock_engine_task_with_init( + handshake_address.clone(), + b"engine-1".to_vec(), + |dealer, push| { + Box::pin(async move { + let add_2 = recv_engine_message(dealer).await; + assert_eq!(add_2[0].as_ref(), &[0x00]); + let request_2: EngineCoreRequest = rmp_serde::from_slice(&add_2[1]).unwrap(); + assert_eq!(request_2.request_id, "req-2"); + let _ = engine_1_seen_tx.send(request_2.request_id.clone()); + finish_req_2_rx.await.unwrap(); + send_outputs( + push, + EngineCoreOutputs { + engine_index: 1, + outputs: vec![request_output( + &request_2.request_id, + vec![20], + Some(EngineCoreFinishReason::Length), + )], + finished_requests: Some(BTreeSet::from([request_2.request_id.clone()])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address.clone(), + 2, + "test-model", + Duration::from_secs(2), + 0, + None, + ), + &ipc, + ) + .await; + + let init_0 = timeout(Duration::from_secs(1), init_rx_0).await.unwrap().unwrap(); + let init_1 = timeout(Duration::from_secs(1), init_rx_1).await.unwrap().unwrap(); + assert_eq!(init_0.addresses.inputs, vec![ipc.input_endpoint()]); + assert_eq!(init_1.addresses.inputs, vec![ipc.input_endpoint()]); + assert_eq!(init_0.addresses.outputs, vec![ipc.output_endpoint()]); + assert_eq!(init_1.addresses.outputs, vec![ipc.output_endpoint()]); + + assert_eq!(client.input_address(), ipc.input_endpoint()); + assert_eq!(client.output_address(), ipc.output_endpoint()); + assert_eq!(client.engine_count(), 2); + assert_eq!( + client.engine_identities(), + vec![b"engine-0".as_slice(), b"engine-1".as_slice()] + ); + assert_eq!(client.ready_responses().len(), 2); + assert_eq!(client.engine_identities()[0], b"engine-0"); + + let mut stream_1 = client.call(sample_request_with_id("req-1")).await.unwrap(); + let mut stream_2 = client.call(sample_request_with_id("req-2")).await.unwrap(); + assert_eq!( + timeout(Duration::from_secs(1), engine_0_seen_rx.recv()).await.unwrap().unwrap(), + "req-1" + ); + assert_eq!( + timeout(Duration::from_secs(1), engine_1_seen_rx).await.unwrap().unwrap(), + "req-2" + ); + + let _ = finish_req_1_tx.send(()); + let final_1 = timeout(Duration::from_secs(1), stream_1.next()) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(final_1.engine_index, 0); + assert_eq!(final_1.new_token_ids, vec![10]); + assert_eq!(final_1.finish_reason, Some(EngineCoreFinishReason::Length)); + + let mut stream_3 = client.call(sample_request_with_id("req-3")).await.unwrap(); + assert_eq!( + timeout(Duration::from_secs(1), engine_0_seen_rx.recv()).await.unwrap().unwrap(), + "req-3" + ); + + let _ = finish_req_3_tx.send(()); + let final_3 = timeout(Duration::from_secs(1), stream_3.next()) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(final_3.engine_index, 0); + assert_eq!(final_3.new_token_ids, vec![30]); + assert_eq!(final_3.finish_reason, Some(EngineCoreFinishReason::Length)); + + let _ = finish_req_2_tx.send(()); + let final_2 = timeout(Duration::from_secs(1), stream_2.next()) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(final_2.engine_index, 1); + assert_eq!(final_2.new_token_ids, vec![20]); + assert_eq!(final_2.finish_reason, Some(EngineCoreFinishReason::Length)); + + assert!(timeout(Duration::from_secs(1), stream_1.next()).await.unwrap().is_none()); + assert!(timeout(Duration::from_secs(1), stream_2.next()).await.unwrap().is_none()); + assert!(timeout(Duration::from_secs(1), stream_3.next()).await.unwrap().is_none()); + + let _ = shutdown_tx_0.send(()); + let _ = shutdown_tx_1.send(()); + engine_task_0.await.unwrap(); + engine_task_1.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn multi_engine_abort_is_grouped_and_utility_fans_out_to_all_engines() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + + let (shutdown_tx_0, engine_task_0) = spawn_mock_engine_task( + handshake_address.clone(), + b"engine-0".to_vec(), + |dealer, push| { + Box::pin(async move { + let utility = recv_engine_message(dealer).await; + assert_eq!(utility[0].as_ref(), &[0x03]); + let payload = decode_value(&utility[1]); + let array = match payload { + Value::Array(array) => array, + other => panic!("expected utility payload array, got {other:?}"), + }; + let call_id = array[1].as_i64().expect("call_id"); + assert_eq!(array[2], Value::from("is_sleeping")); + send_outputs( + push, + EngineCoreOutputs { + utility_output: Some(UtilityOutput { + call_id, + failure_message: None, + result: Some(utility_result_value(true)), + }), + ..Default::default() + }, + ) + .await; + + let add_1 = recv_engine_message(dealer).await; + assert_eq!(add_1[0].as_ref(), &[0x00]); + let request_1: EngineCoreRequest = rmp_serde::from_slice(&add_1[1]).unwrap(); + assert_eq!(request_1.request_id, "req-1"); + + let abort = recv_engine_message(dealer).await; + assert_eq!(abort[0].as_ref(), &[0x01]); + let aborted_ids: Vec = rmp_serde::from_slice(&abort[1]).unwrap(); + assert_eq!(aborted_ids, vec!["req-1".to_string()]); + send_outputs( + push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![request_output( + "req-1", + vec![], + Some(EngineCoreFinishReason::Abort), + )], + finished_requests: Some(BTreeSet::from(["req-1".to_string()])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + tokio::time::sleep(Duration::from_millis(50)).await; + let (shutdown_tx_1, engine_task_1) = spawn_mock_engine_task( + handshake_address.clone(), + b"engine-1".to_vec(), + |dealer, push| { + Box::pin(async move { + let utility = recv_engine_message(dealer).await; + assert_eq!(utility[0].as_ref(), &[0x03]); + let payload = decode_value(&utility[1]); + let array = match payload { + Value::Array(array) => array, + other => panic!("expected utility payload array, got {other:?}"), + }; + let call_id = array[1].as_i64().expect("call_id"); + assert_eq!(array[2], Value::from("is_sleeping")); + send_outputs( + push, + EngineCoreOutputs { + utility_output: Some(UtilityOutput { + call_id, + failure_message: None, + result: Some(utility_result_value(true)), + }), + ..Default::default() + }, + ) + .await; + + let add_2 = recv_engine_message(dealer).await; + assert_eq!(add_2[0].as_ref(), &[0x00]); + let request_2: EngineCoreRequest = rmp_serde::from_slice(&add_2[1]).unwrap(); + assert_eq!(request_2.request_id, "req-2"); + + let abort = recv_engine_message(dealer).await; + assert_eq!(abort[0].as_ref(), &[0x01]); + let aborted_ids: Vec = rmp_serde::from_slice(&abort[1]).unwrap(); + assert_eq!(aborted_ids, vec!["req-2".to_string()]); + send_outputs( + push, + EngineCoreOutputs { + engine_index: 1, + outputs: vec![request_output( + "req-2", + vec![], + Some(EngineCoreFinishReason::Abort), + )], + finished_requests: Some(BTreeSet::from(["req-2".to_string()])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 2, + "test-model", + Duration::from_secs(2), + 5, + None, + ), + &ipc, + ) + .await; + + assert!(client.is_sleeping().await.unwrap()); + + let mut stream_1 = client.call(sample_request_with_id("req-1")).await.unwrap(); + let mut stream_2 = client.call(sample_request_with_id("req-2")).await.unwrap(); + + client + .abort(&[ + "req-2".to_string(), + "req-1".to_string(), + "unknown".to_string(), + ]) + .await + .unwrap(); + + let final_1 = timeout(Duration::from_secs(1), stream_1.next()) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(final_1.engine_index, 0); + assert_eq!(final_1.finish_reason, Some(EngineCoreFinishReason::Abort)); + + let final_2 = timeout(Duration::from_secs(1), stream_2.next()) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(final_2.engine_index, 1); + assert_eq!(final_2.finish_reason, Some(EngineCoreFinishReason::Abort)); + + let _ = shutdown_tx_0.send(()); + let _ = shutdown_tx_1.send(()); + engine_task_0.await.unwrap(); + engine_task_1.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn collective_rpc_flattens_results_from_all_engines() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + + let (shutdown_tx_0, engine_task_0) = spawn_mock_engine_task( + handshake_address.clone(), + b"engine-0".to_vec(), + |dealer, push| { + Box::pin(async move { + let utility = recv_engine_message(dealer).await; + assert_eq!(utility[0].as_ref(), &[0x03]); + let payload = decode_value(&utility[1]); + let array = match payload { + Value::Array(array) => array, + other => panic!("expected utility payload array, got {other:?}"), + }; + let call_id = array[1].as_i64().expect("call_id"); + assert_eq!(array[2], Value::from("collective_rpc")); + + send_outputs( + push, + EngineCoreOutputs { + utility_output: Some(UtilityOutput { + call_id, + failure_message: None, + result: Some(utility_result_value(vec!["engine-0-worker"])), + }), + ..Default::default() + }, + ) + .await; + }) + }, + ); + tokio::time::sleep(Duration::from_millis(50)).await; + let (shutdown_tx_1, engine_task_1) = spawn_mock_engine_task( + handshake_address.clone(), + b"engine-1".to_vec(), + |dealer, push| { + Box::pin(async move { + let utility = recv_engine_message(dealer).await; + assert_eq!(utility[0].as_ref(), &[0x03]); + let payload = decode_value(&utility[1]); + let array = match payload { + Value::Array(array) => array, + other => panic!("expected utility payload array, got {other:?}"), + }; + let call_id = array[1].as_i64().expect("call_id"); + assert_eq!(array[2], Value::from("collective_rpc")); + + send_outputs( + push, + EngineCoreOutputs { + utility_output: Some(UtilityOutput { + call_id, + failure_message: None, + result: Some(utility_result_value(vec!["engine-1-worker"])), + }), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let client = connect_client_with_ipc( + handshake_test_config( + handshake_address, + 2, + "test-model", + Duration::from_secs(2), + 5, + None, + ), + &ipc, + ) + .await; + + let results = client + .collective_rpc( + "get_model_name", + Option::::None, + Vec::::new(), + BTreeMap::::new(), + ) + .await + .unwrap(); + assert_eq!( + results, + vec![ + Value::from("engine-0-worker"), + Value::from("engine-1-worker") + ] + ); + + let _ = shutdown_tx_0.send(()); + let _ = shutdown_tx_1.send(()); + engine_task_0.await.unwrap(); + engine_task_1.await.unwrap(); + client.shutdown().await.unwrap(); +} + +#[test] +fn python_msgpack_fixtures_match_rust_encoding() { + init_tracing(); + let script = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("src/tests/python_compat.py"); + let output = Command::new(&script) + .output() + .unwrap_or_else(|error| panic!("failed to execute {:?}: {error}", script)); + assert!( + output.status.success(), + "python fixture script failed: status={:?}\nstdout:\n{}\nstderr:\n{}", + output.status.code(), + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr), + ); + + let stdout = String::from_utf8(output.stdout).unwrap(); + let mut lines = stdout.lines(); + let request_hex = lines.next().expect("missing request fixture line"); + let multimodal_request_hex = lines.next().expect("missing multimodal request fixture line"); + let outputs_hex = lines.next().expect("missing outputs fixture line"); + let inline_logprobs_frames = lines.next().expect("missing inline logprobs fixture line"); + let multipart_logprobs_frames = lines.next().expect("missing multipart logprobs fixture line"); + let inline_prompt_frames = lines.next().expect("missing inline prompt logprobs fixture line"); + let multipart_prompt_frames = + lines.next().expect("missing multipart prompt logprobs fixture line"); + + let request_bytes = hex::decode(request_hex).unwrap(); + let multimodal_request_bytes = hex::decode(multimodal_request_hex).unwrap(); + let outputs_bytes = hex::decode(outputs_hex).unwrap(); + + let decoded_request: EngineCoreRequest = rmp_serde::from_slice(&request_bytes).unwrap(); + let expected_request = sample_request(); + assert_eq!(decoded_request, expected_request); + + let decoded_multimodal_request: EngineCoreRequest = + rmp_serde::from_slice(&multimodal_request_bytes).unwrap(); + assert_eq!(decoded_multimodal_request, sample_multimodal_request()); + + // The decode assertion above proves Python wire -> Rust struct. Also compare + // Rust struct -> wire for the multimodal subtree, which is the frontend's + // production direction when sending requests to Python EngineCore. + let expected_multimodal_request = sample_multimodal_request(); + let decode_value = |bytes: &[u8]| { + rmpv::decode::read_value(&mut Cursor::new(bytes)).expect("decode msgpack value") + }; + let extract_mm_features = |value: Value| match value { + Value::Array(items) => items.get(2).cloned().expect("request mm_features slot"), + other => panic!("request should encode as tuple array, got {other:?}"), + }; + let python_mm_features = extract_mm_features(decode_value(&multimodal_request_bytes)); + let rust_mm_features = + decode_value(&rmp_serde::to_vec_named(&expected_multimodal_request.mm_features).unwrap()); + assert_eq!(python_mm_features, rust_mm_features); + + let decoded_outputs: EngineCoreOutputs = rmp_serde::from_slice(&outputs_bytes).unwrap(); + expect_test::expect![[r#" + EngineCoreOutputs { + engine_index: 0, + outputs: [ + EngineCoreOutput { + request_id: "req-1", + new_token_ids: [ + 7, + 8, + ], + new_logprobs: None, + new_prompt_logprobs_tensors: None, + pooling_output: None, + finish_reason: Some( + Length, + ), + stop_reason: None, + events: None, + kv_transfer_params: None, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + }, + ], + scheduler_stats: None, + timestamp: 0.0, + utility_output: None, + finished_requests: Some( + { + "req-1", + }, + ), + wave_complete: None, + start_wave: None, + } + "#]] + .assert_debug_eq(&decoded_outputs); + + let decode_frames = |line: &str| { + line.split_whitespace() + .map(|frame| bytes::Bytes::from(hex::decode(frame).unwrap())) + .collect::>() + }; + + let inline_logprobs = + decode_engine_core_outputs(&decode_frames(inline_logprobs_frames)).unwrap(); + expect_sample_logprobs( + inline_logprobs.outputs[0] + .new_logprobs + .as_ref() + .expect("inline logprobs decoded"), + ); + + let multipart_logprobs = + decode_engine_core_outputs(&decode_frames(multipart_logprobs_frames)).unwrap(); + expect_sample_logprobs( + multipart_logprobs.outputs[0] + .new_logprobs + .as_ref() + .expect("multipart logprobs decoded"), + ); + + let inline_prompt = decode_engine_core_outputs(&decode_frames(inline_prompt_frames)).unwrap(); + expect_prompt_logprobs( + inline_prompt.outputs[0] + .new_prompt_logprobs_tensors + .as_ref() + .expect("inline prompt logprobs decoded"), + ); + + let multipart_prompt = + decode_engine_core_outputs(&decode_frames(multipart_prompt_frames)).unwrap(); + expect_prompt_logprobs( + multipart_prompt.outputs[0] + .new_prompt_logprobs_tensors + .as_ref() + .expect("multipart prompt logprobs decoded"), + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn bootstrapped_connects_after_single_engine_registration() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let input_address = ipc.input_endpoint(); + let output_address = ipc.output_endpoint(); + + let client_task = tokio::spawn({ + let input_address = input_address.clone(); + let output_address = output_address.clone(); + async move { + EngineCoreClient::connect(bootstrapped_test_config( + input_address, + output_address, + 1, + Duration::from_secs(2), + 0, + None, + )) + .await + .unwrap() + } + }); + + let (_dealer, _push) = + setup_bootstrapped_mock_engine(input_address, output_address, &[0x00, 0x00]).await; + let client = client_task.await.unwrap(); + + assert_eq!(client.engine_count(), 1); + let engine_ids = + client.engine_identities().into_iter().map(|id| id.to_vec()).collect::>(); + assert_eq!(engine_ids, vec![vec![0x00, 0x00]]); + + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn bootstrapped_connects_with_contiguous_engine_ids() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let input_address = ipc.input_endpoint(); + let output_address = ipc.output_endpoint(); + + let client_task = tokio::spawn({ + let input_address = input_address.clone(); + let output_address = output_address.clone(); + async move { + EngineCoreClient::connect(bootstrapped_test_config( + input_address, + output_address, + 2, + Duration::from_secs(2), + 0, + None, + )) + .await + .unwrap() + } + }); + + let (_dealer0, _push0) = setup_bootstrapped_mock_engine( + input_address.clone(), + output_address.clone(), + &[0x00, 0x00], + ) + .await; + let (_dealer1, _push1) = + setup_bootstrapped_mock_engine(input_address, output_address, &[0x01, 0x00]).await; + let client = client_task.await.unwrap(); + + assert_eq!(client.engine_count(), 2); + let engine_ids = + client.engine_identities().into_iter().map(|id| id.to_vec()).collect::>(); + assert_eq!(engine_ids, vec![vec![0x00, 0x00], vec![0x01, 0x00]]); + + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn bootstrapped_connect_times_out_without_registration() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let result = EngineCoreClient::connect(bootstrapped_test_config( + ipc.input_endpoint(), + ipc.output_endpoint(), + 1, + Duration::from_millis(100), + 0, + None, + )) + .await; + + let error = match result { + Ok(_) => panic!("bootstrapped connect should time out"), + Err(error) => error, + }; + assert!(matches!(error, Error::InputRegistrationTimeout { .. })); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn bootstrapped_external_coordinator_connects_and_subscribes() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let input_address = ipc.input_endpoint(); + let output_address = ipc.output_endpoint(); + let coordinator_address = ipc.endpoint("stats.sock"); + + let mut stats_socket = XPubSocket::new(); + stats_socket.bind(&coordinator_address).await.unwrap(); + + let client_task = tokio::spawn({ + let input_address = input_address.clone(); + let output_address = output_address.clone(); + let coordinator_address = coordinator_address.clone(); + async move { + EngineCoreClient::connect(bootstrapped_test_config( + input_address, + output_address, + 1, + Duration::from_secs(2), + 0, + Some(CoordinatorMode::External { + address: coordinator_address, + }), + )) + .await + .unwrap() + } + }); + + let (_dealer, _push) = + setup_bootstrapped_mock_engine(input_address, output_address, &[0x00, 0x00]).await; + let client = client_task.await.unwrap(); + + timeout( + Duration::from_secs(1), + recv_xpub_subscription(&mut stats_socket), + ) + .await + .unwrap(); + + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn bootstrapped_external_coordinator_updates_wave_ignores_counts_and_sends_one_wakeup() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let input_address = ipc.input_endpoint(); + let output_address = ipc.output_endpoint(); + let coordinator_address = ipc.endpoint("stats.sock"); + + let mut stats_socket = XPubSocket::new(); + stats_socket.bind(&coordinator_address).await.unwrap(); + + let client_task = tokio::spawn({ + let input_address = input_address.clone(); + let output_address = output_address.clone(); + let coordinator_address = coordinator_address.clone(); + async move { + EngineCoreClient::connect(bootstrapped_test_config( + input_address, + output_address, + 1, + Duration::from_secs(2), + 0, + Some(CoordinatorMode::External { + address: coordinator_address, + }), + )) + .await + .unwrap() + } + }); + + let (mut dealer, mut push) = + setup_bootstrapped_mock_engine(input_address, output_address, &[0x00, 0x00]).await; + let client = client_task.await.unwrap(); + recv_xpub_subscription(&mut stats_socket).await; + + send_external_coordinator_publish(&mut stats_socket, &(vec![(11_u32, 3_u32)], 7_u32, false)) + .await; + tokio::time::sleep(Duration::from_millis(50)).await; + + let mut stream = client.call(sample_request()).await.unwrap(); + + let wakeup = timeout( + Duration::from_secs(1), + recv_external_coordinator_wakeup(&mut stats_socket), + ) + .await + .unwrap(); + assert_eq!(wakeup, (0, 7)); + + assert!( + timeout( + Duration::from_millis(200), + recv_external_coordinator_wakeup(&mut stats_socket) + ) + .await + .is_err() + ); + + let add = recv_engine_message(&mut dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + assert_eq!(request.request_id, "req-1"); + assert_eq!(request.current_wave, 7); + assert!(client.is_healthy()); + + send_outputs( + &mut push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![request_output( + "req-1", + vec![], + Some(EngineCoreFinishReason::Length), + )], + finished_requests: Some(BTreeSet::from(["req-1".to_string()])), + ..Default::default() + }, + ) + .await; + + let final_output = timeout(Duration::from_secs(1), stream.next()).await.unwrap(); + assert!(final_output.is_some()); + + client.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn bootstrapped_external_coordinator_running_state_suppresses_wakeup() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let input_address = ipc.input_endpoint(); + let output_address = ipc.output_endpoint(); + let coordinator_address = ipc.endpoint("stats.sock"); + + let mut stats_socket = XPubSocket::new(); + stats_socket.bind(&coordinator_address).await.unwrap(); + + let client_task = tokio::spawn({ + let input_address = input_address.clone(); + let output_address = output_address.clone(); + let coordinator_address = coordinator_address.clone(); + async move { + EngineCoreClient::connect(bootstrapped_test_config( + input_address, + output_address, + 1, + Duration::from_secs(2), + 0, + Some(CoordinatorMode::External { + address: coordinator_address, + }), + )) + .await + .unwrap() + } + }); + + let (mut dealer, mut push) = + setup_bootstrapped_mock_engine(input_address, output_address, &[0x00, 0x00]).await; + let client = client_task.await.unwrap(); + recv_xpub_subscription(&mut stats_socket).await; + + send_external_coordinator_publish(&mut stats_socket, &(Value::Nil, 5_u32, true)).await; + tokio::time::sleep(Duration::from_millis(50)).await; + + let mut stream = client.call(sample_request()).await.unwrap(); + + assert!( + timeout( + Duration::from_millis(200), + recv_external_coordinator_wakeup(&mut stats_socket) + ) + .await + .is_err() + ); + + let add = recv_engine_message(&mut dealer).await; + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + assert_eq!(request.current_wave, 5); + + send_outputs( + &mut push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![request_output( + "req-1", + vec![], + Some(EngineCoreFinishReason::Length), + )], + finished_requests: Some(BTreeSet::from(["req-1".to_string()])), + ..Default::default() + }, + ) + .await; + + let final_output = timeout(Duration::from_secs(1), stream.next()).await.unwrap(); + assert!(final_output.is_some()); + + client.shutdown().await.unwrap(); +} diff --git a/rust/src/engine-core-client/src/tests/mod.rs b/rust/src/engine-core-client/src/tests/mod.rs new file mode 100644 index 00000000000..b79c47fca36 --- /dev/null +++ b/rust/src/engine-core-client/src/tests/mod.rs @@ -0,0 +1 @@ +mod client; diff --git a/rust/src/engine-core-client/src/tests/python_compat.py b/rust/src/engine-core-client/src/tests/python_compat.py new file mode 100755 index 00000000000..bb81a6df1ad --- /dev/null +++ b/rust/src/engine-core-client/src/tests/python_compat.py @@ -0,0 +1,356 @@ +#!/usr/bin/env -S uv run +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "msgspec>=0.19,<1", +# "msgpack>=1,<2", +# "numpy>=2,<3", +# ] +# /// + +from enum import Enum, IntEnum + +import msgpack +import msgspec +import numpy as np + + +class RequestOutputKind(Enum): + DELTA = 1 + FINAL_ONLY = 2 + + +class FinishReason(IntEnum): + STOP = 0 + LENGTH = 1 + ABORT = 2 + ERROR = 3 + REPETITION = 4 + + +class EngineCoreSamplingParams(msgspec.Struct, dict=True): + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = 0 + seed: int | None = None + max_tokens: int = 65536 + min_tokens: int = 0 + min_p: float = 0.0 + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 + repetition_penalty: float = 1.0 + stop_token_ids: list[int] = [] + _eos_token_id: int | None = None + _all_stop_token_ids: set[int] = set() + output_kind: RequestOutputKind = RequestOutputKind.DELTA + + +class EngineCoreRequest( + msgspec.Struct, + array_like=True, + omit_defaults=True, +): + request_id: str + prompt_token_ids: list[int] | None + mm_features: object | None + sampling_params: EngineCoreSamplingParams | None + pooling_params: object | None + arrival_time: float + lora_request: object | None = None + cache_salt: str | None = None + data_parallel_rank: int | None = None + prompt_embeds: object | None = None + prompt_is_token_ids: list[bool] | None = None + client_index: int = 0 + current_wave: int = 0 + priority: int = 0 + trace_headers: dict[str, str] | None = None + resumable: bool = False + external_req_id: str | None = None + reasoning_ended: bool | None = None + reasoning_parser_kwargs: dict[str, object] | None = None + abort_immediately: bool = False + + +class EngineCoreOutput( + msgspec.Struct, + array_like=True, + omit_defaults=True, +): + request_id: str + new_token_ids: list[int] + new_logprobs: object | None = None + new_prompt_logprobs_tensors: object | None = None + pooling_output: object | None = None + finish_reason: FinishReason | None = None + stop_reason: int | str | None = None + events: object | None = None + kv_transfer_params: object | None = None + trace_headers: object | None = None + prefill_stats: object | None = None + routed_experts: object | None = None + num_nans_in_logits: int = 0 + + +class EngineCoreOutputs( + msgspec.Struct, + array_like=True, + omit_defaults=True, +): + engine_index: int = 0 + outputs: list[EngineCoreOutput] = [] + scheduler_stats: object | None = None + timestamp: float = 0.0 + utility_output: object | None = None + finished_requests: set[str] | None = None + wave_complete: int | None = None + start_wave: int | None = None + + +request = EngineCoreRequest( + request_id="req-1", + prompt_token_ids=[11, 22], + mm_features=None, + sampling_params=EngineCoreSamplingParams( + temperature=0.8, + top_p=0.9, + top_k=8, + seed=None, + max_tokens=32, + min_tokens=1, + min_p=0.0, + frequency_penalty=0.0, + presence_penalty=0.0, + repetition_penalty=1.0, + stop_token_ids=[151643], + _eos_token_id=151645, + _all_stop_token_ids={151643, 151645}, + output_kind=RequestOutputKind.FINAL_ONLY, + ), + pooling_params=None, + arrival_time=42.5, + client_index=0, +) + +multimodal_tensor = np.array([[1.0, 2.0], [3.5, 4.25]], dtype=np.float32) +multimodal_features = [ + { + "data": { + "pixel_values": { + "data": [ + "float32", + [2, 2], + msgpack.ExtType(3, multimodal_tensor.tobytes()), + ], + "field": [ + "flat", + { + "slices": [[0, 2, None]], + "dim": 0, + "keep_on_cpu": False, + }, + ], + } + }, + "modality": "image", + "identifier": "mm-cache-key", + "mm_position": { + "offset": 1, + "length": 2, + "is_embed": None, + }, + "mm_hash": "processor-hash", + } +] +multimodal_request_wire = [ + "req-mm", + [101, 102, 103, 104], + multimodal_features, + None, + None, + 43.5, +] + +outputs = EngineCoreOutputs( + outputs=[ + EngineCoreOutput( + request_id="req-1", + new_token_ids=[7, 8], + finish_reason=FinishReason.LENGTH, + ) + ], + finished_requests={"req-1"}, +) + + +def encode_ndarray( + array: np.ndarray, + buffers: list[bytes], + *, + size_threshold: int = 256, +): + arr_data = array.data if array.flags.c_contiguous else array.tobytes() + if not array.shape or array.nbytes < size_threshold: + data = msgpack.ExtType(3, bytes(arr_data)) + else: + data = len(buffers) + buffers.append(bytes(arr_data)) + return [array.dtype.str, list(array.shape), data] + + +def encode_tensor_like( + dtype: str, + shape: list[int], + payload: bytes, + buffers: list[bytes], + *, + size_threshold: int = 256, +): + if len(payload) < size_threshold: + data = msgpack.ExtType(3, payload) + else: + data = len(buffers) + buffers.append(payload) + return [dtype, shape, data] + + +def encode_output_frames(obj, *, size_threshold: int = 256) -> list[bytes]: + buffers = [b""] + + def transform(value): + if isinstance(value, np.ndarray): + return encode_ndarray(value, buffers, size_threshold=size_threshold) + if ( + isinstance(value, tuple) + and len(value) == 3 + and value[0] in ("int32", "int64", "float32") + ): + dtype, shape, payload = value + return encode_tensor_like( + dtype, + shape, + payload, + buffers, + size_threshold=size_threshold, + ) + if type(value) is list: + return [transform(v) for v in value] + if type(value) is tuple: + return [transform(v) for v in value] + if type(value) is dict: + return {k: transform(v) for k, v in value.items()} + return value + + buffers[0] = msgpack.packb(transform(obj), use_bin_type=True) + return buffers + + +def engine_output_wire( + request_id: str, + *, + new_logprobs=None, + new_prompt_logprobs_tensors=None, +): + return [ + request_id, + [7, 8], + new_logprobs, + new_prompt_logprobs_tensors, + None, + int(FinishReason.LENGTH), + ] + + +def engine_outputs_wire(output): + return [0, [output], None, 0.0, None, ["req-1"]] + + +inline_logprobs = engine_outputs_wire( + engine_output_wire( + "req-1", + new_logprobs=( + np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64), + np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32), + np.array([1, 2], dtype=np.int64), + None, + ), + ) +) + +multipart_logprobs = engine_outputs_wire( + engine_output_wire( + "req-1", + new_logprobs=( + np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64), + np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32), + np.array([1, 2], dtype=np.int64), + None, + ), + ) +) + +inline_prompt_logprobs = engine_outputs_wire( + engine_output_wire( + "req-1", + new_prompt_logprobs_tensors=( + ( + "int64", + [2, 3], + np.array([[10, 11, 12], [13, 14, 15]], dtype=np.int64).tobytes(), + ), + ( + "float32", + [2, 3], + np.array( + [[10, 11, 12], [13, 14, 15]], + dtype=np.float32, + ).tobytes(), + ), + ("int64", [2], np.array([3, 4], dtype=np.int64).tobytes()), + None, + ), + ) +) + +multipart_prompt_logprobs = engine_outputs_wire( + engine_output_wire( + "req-1", + new_prompt_logprobs_tensors=( + ( + "int64", + [2, 3], + np.array([[10, 11, 12], [13, 14, 15]], dtype=np.int64).tobytes(), + ), + ( + "float32", + [2, 3], + np.array( + [[10, 11, 12], [13, 14, 15]], + dtype=np.float32, + ).tobytes(), + ), + ("int64", [2], np.array([3, 4], dtype=np.int64).tobytes()), + None, + ), + ) +) + +print(msgspec.msgpack.encode(request).hex()) +print(msgpack.packb(multimodal_request_wire, use_bin_type=True).hex()) +print(msgspec.msgpack.encode(outputs).hex()) +print(" ".join(frame.hex() for frame in encode_output_frames(inline_logprobs))) +print( + " ".join( + frame.hex() + for frame in encode_output_frames(multipart_logprobs, size_threshold=1) + ) +) +print(" ".join(frame.hex() for frame in encode_output_frames(inline_prompt_logprobs))) +print( + " ".join( + frame.hex() + for frame in encode_output_frames(multipart_prompt_logprobs, size_threshold=1) + ) +) diff --git a/rust/src/engine-core-client/src/transport.rs b/rust/src/engine-core-client/src/transport.rs new file mode 100644 index 00000000000..0d6c49340af --- /dev/null +++ b/rust/src/engine-core-client/src/transport.rs @@ -0,0 +1,606 @@ +use std::collections::{BTreeMap, BTreeSet}; +use std::fmt::Debug; +use std::ops::Deref; +use std::time::Duration; + +use bytes::Bytes; +use enum_as_inner::EnumAsInner; +use thiserror_ext::AsReport; +use tokio::sync::mpsc; +use tokio::time::timeout; +use tracing::{debug, error, info, trace, warn}; +use zeromq::prelude::{Socket, SocketRecv, SocketSend}; +use zeromq::util::PeerIdentity; +use zeromq::{PullSocket, RouterSendHalf, RouterSocket, ZmqError, ZmqMessage}; + +use crate::coordinator::CoordinatorBootstrap; +use crate::error::{Error, Result, bail_unexpected_handshake_message}; +use crate::protocol::handshake::{ + EngineCoreReadyResponse, HandshakeAddresses, HandshakeInitMessage, ReadyMessage, +}; +use crate::protocol::{ + EngineCoreOutputs, decode_engine_core_outputs, decode_msgpack, encode_msgpack, +}; + +/// Dedicated single-frame sentinel emitted by Python `EngineCoreProc` when the +/// engine dies. +pub const ENGINE_CORE_DEAD_SENTINEL: &[u8] = b"ENGINE_CORE_DEAD"; + +/// Opaque routing identity of one engine on the frontend transport. +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct EngineId(Bytes); + +impl Debug for EngineId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Display the engine id as a hex string for easier debugging. + write!(f, "EngineId({})", hex::encode(&self.0)) + } +} + +impl EngineId { + /// Convert the engine id into a ZMQ frame for sending. + pub fn to_frame(&self) -> Bytes { + self.0.clone() + } + + /// Convert the engine id into a ZMQ frame for sending. + pub fn into_frame(self) -> Bytes { + self.0 + } + + /// Parse the Python-compatible engine index encoded in the routing + /// identity. + /// + /// Python `EngineCoreProc` currently uses a two-byte little-endian engine + /// index as its ROUTER/DEALER identity. Coordinator control messages + /// such as `START_DP_WAVE(exclude_engine_index)` need that engine-side + /// index rather than any frontend-local ordering. + pub fn engine_index(&self) -> Option { + if self.len() != 2 { + return None; + } + Some(u16::from_le_bytes([self[0], self[1]]) as u32) + } + + /// Construct an engine id from the Python-compatible engine index encoding + /// (two-byte little-endian). + pub fn from_engine_index(value: u32) -> Self { + Self(Bytes::copy_from_slice(&(value as u16).to_le_bytes())) + } +} + +impl Deref for EngineId { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl From> for EngineId { + fn from(value: Vec) -> Self { + Self(Bytes::from(value)) + } +} + +impl From<&[u8; N]> for EngineId { + fn from(value: &[u8; N]) -> Self { + Self(Bytes::copy_from_slice(value)) + } +} + +impl TryFrom for PeerIdentity { + type Error = ZmqError; + + fn try_from(value: EngineId) -> std::result::Result { + PeerIdentity::try_from(value.into_frame()) + } +} + +/// Per-engine handshake result collected while bootstrapping one shared +/// transport. +#[derive(Clone, Debug)] +pub struct ConnectedEngine { + /// The identity of the connected engine. + pub engine_id: EngineId, + /// Post-initialization configuration received from the engine on the input + /// socket registration message. `None` until the registration is received. + pub ready_response: Option, +} + +/// Represents the connected shared transport plus all registered engines after +/// a successful multi-engine startup handshake. +pub struct ConnectedTransport { + /// The local address of the shared input socket that all engines connect to + /// for receiving requests. + pub input_address: String, + /// The local address of the shared output socket that all engines connect + /// to for sending responses. + pub output_address: String, + /// All engines connected through the startup handshake. + pub engines: Vec, + /// Optional engine-facing coordinator transport used for in-process wave + /// coordination. + pub coordinator: Option, + + /// The sending half of the shared input socket. + pub input_send: RouterSendHalf, + /// The shared output socket for receiving responses from all engines. + pub output_socket: PullSocket, +} + +#[derive(Clone, Debug, EnumAsInner)] +enum EngineStartupState { + HelloReceived, + ReadyReceived, +} + +/// Connect to one or more engines through the startup handshake protocol, +/// returning the shared data-plane transport plus the registered engines. +pub async fn connect_handshake( + handshake_address: &str, + engine_count: usize, + local_host: &str, + local_input_address: Option<&str>, + local_output_address: Option<&str>, + enable_inproc_coordinator: bool, + ready_timeout: Duration, +) -> Result { + if engine_count == 0 { + bail_unexpected_handshake_message!("expected engine_count >= 1"); + } + + info!( + engine_count, + handshake_address, "waiting for engines to connect" + ); + + // 1. Bind shared local input/output sockets first so every engine receives the same data-plane + // addresses during handshake. + debug!( + local_host, + ?ready_timeout, + engine_count, + "binding shared transport sockets" + ); + let (input_address, mut input_socket, output_address, output_socket) = + bind_local_sockets(local_host, local_input_address, local_output_address).await?; + info!(%input_address, %output_address, "bound local transport sockets"); + + let mut coordinator = if enable_inproc_coordinator { + Some(CoordinatorBootstrap::bind(local_host).await?) + } else { + None + }; + + // 2. Bind the shared handshake socket once. All engines connect to this socket with their own + // identities, and startup order does not matter. + let mut handshake_socket = RouterSocket::new(); + handshake_socket.bind(handshake_address).await?; + + let mut engines = BTreeMap::new(); + + // 3. Receive HELLO from every engine and send a matching INIT. When coordinator mode is + // enabled, the engines will not emit READY until the coordinator barrier below completes. + while engines.len() < engine_count { + debug!( + handshake_address, + connected = engines.len(), + waiting_for = engine_count, + "waiting for engine HELLO" + ); + let message = timeout(ready_timeout, handshake_socket.recv()).await.map_err(|_| { + Error::HandshakeTimeout { + stage: "HELLO", + timeout: ready_timeout, + } + })??; + let (engine_id, handshake_message) = decode_handshake_message(message, None)?; + match handshake_message.status.as_deref() { + Some("HELLO") => { + if engines.contains_key(&engine_id) { + bail_unexpected_handshake_message!( + "duplicate engine id {engine_id:?} observed during startup handshake" + ); + } + debug!(handshake_address, ?engine_id, "received HELLO from engine"); + + send_init_message( + &mut handshake_socket, + &engine_id, + &input_address, + &output_address, + coordinator.as_ref(), + ) + .await?; + debug!(handshake_address, ?engine_id, "sent INIT to engine"); + + engines.insert(engine_id.clone(), EngineStartupState::HelloReceived); + } + Some("READY") => { + if coordinator.is_some() { + bail_unexpected_handshake_message!( + "received READY for engine id {engine_id:?} before coordinator startup gate completed" + ); + } + let state = match engines.get_mut(&engine_id) { + Some(state) if !state.is_ready_received() => state, + _ => { + bail_unexpected_handshake_message!( + "received READY for unexpected or duplicate engine id {engine_id:?}" + ); + } + }; + debug!( + handshake_address, + ?engine_id, + ?handshake_message, + "received overlapping READY from engine during HELLO phase" + ); + *state = EngineStartupState::ReadyReceived; + } + other => { + bail_unexpected_handshake_message!("unexpected handshake status {other:?}"); + } + } + } + + // 4. Optional coordinator startup gate. Without coordinator there is nothing to do. + if let Some(coordinator) = coordinator.as_mut() { + coordinator.wait_for_startup_gate(engine_count, ready_timeout).await?; + } + + // 5. After the optional gate has opened, every engine may now send READY. + while engines.values().any(|state| !state.is_ready_received()) { + debug!( + handshake_address, + connected = engines.len(), + ready = engines.values().filter(|state| state.is_ready_received()).count(), + waiting_for = engine_count, + "waiting for engine READY" + ); + let message = timeout(ready_timeout, handshake_socket.recv()).await.map_err(|_| { + Error::HandshakeTimeout { + stage: "READY", + timeout: ready_timeout, + } + })??; + let (engine_id, handshake_message) = decode_handshake_message(message, None)?; + match handshake_message.status.as_deref() { + Some("READY") => { + let state = match engines.get_mut(&engine_id) { + Some(state) if !state.is_ready_received() => state, + _ => { + bail_unexpected_handshake_message!( + "received READY for unexpected or duplicate engine id {engine_id:?}" + ); + } + }; + debug!( + handshake_address, + ?engine_id, + ?handshake_message, + "received READY from engine" + ); + *state = EngineStartupState::ReadyReceived; + } + Some("HELLO") => { + bail_unexpected_handshake_message!( + "received duplicate HELLO for engine id {engine_id:?} after INIT phase completed" + ); + } + other => { + bail_unexpected_handshake_message!("unexpected handshake status {other:?}"); + } + } + } + + // 4. Wait for every engine to connect to the shared input socket and register itself. The + // `ready_response` is a placeholder; it is populated for each engine by + // `wait_for_input_registrations` below. + let mut engines: Vec<_> = engines + .into_keys() + .map(|engine_id| ConnectedEngine { + engine_id, + ready_response: None, + }) + .collect(); + + wait_for_input_registrations(&mut input_socket, &mut engines, ready_timeout).await?; + debug!( + engine_count = engines.len(), + "all engines registered on shared input socket" + ); + + info!(engine_count = engines.len(), "engines connected"); + + let (input_send, _) = input_socket.split(); + + Ok(ConnectedTransport { + input_address, + output_address, + input_send, + output_socket, + engines, + coordinator, + }) +} + +/// Bind to Python-supplied frontend transport addresses and wait for +/// already-initialized engines to register themselves on the input socket. +/// +/// This path mirrors Python's externally managed `AsyncMPClient` bootstrap +/// model: the addresses are already fixed by the supervisor, and engine +/// identities are synthesized from contiguous rank order instead of being +/// discovered through a Rust-owned handshake. +pub async fn connect_bootstrapped( + input_address: &str, + output_address: &str, + engine_count: usize, + ready_timeout: Duration, +) -> Result { + if engine_count == 0 { + bail_unexpected_handshake_message!("expected engine_count >= 1"); + } + + let mut input_socket = RouterSocket::new(); + let input_address = input_socket.bind(input_address).await?.to_string(); + + let mut output_socket = PullSocket::new(); + let output_address = output_socket.bind(output_address).await?.to_string(); + + // TODO: follow start rank + let mut engines = (0..engine_count) + .map(|index| ConnectedEngine { + engine_id: EngineId::from((index as u16).to_le_bytes().to_vec()), + ready_response: None, + }) + .collect::>(); + + wait_for_input_registrations(&mut input_socket, &mut engines, ready_timeout).await?; + info!( + engine_count = engines.len(), + "bootstrapped engines connected" + ); + + let (input_send, _) = input_socket.split(); + + Ok(ConnectedTransport { + input_address, + output_address, + engines, + coordinator: None, + input_send, + output_socket, + }) +} + +/// Bind new input and output sockets. +async fn bind_local_sockets( + local_host: &str, + local_input_address: Option<&str>, + local_output_address: Option<&str>, +) -> Result<(String, RouterSocket, String, PullSocket)> { + let mut input_socket = RouterSocket::new(); + let input_bind_address = local_input_address + .map(str::to_owned) + .unwrap_or_else(|| format!("tcp://{local_host}:0")); + let input_address = input_socket.bind(&input_bind_address).await?.to_string(); + + let mut output_socket = PullSocket::new(); + let output_bind_address = local_output_address + .map(str::to_owned) + .unwrap_or_else(|| format!("tcp://{local_host}:0")); + let output_address = output_socket.bind(&output_bind_address).await?.to_string(); + + Ok((input_address, input_socket, output_address, output_socket)) +} + +/// Decode a handshake message and validate its structure and identity. +fn decode_handshake_message( + message: ZmqMessage, + expected_id: Option<&EngineId>, +) -> Result<(EngineId, ReadyMessage)> { + if message.len() != 2 { + bail_unexpected_handshake_message!("expected 2 frames, got {}", message.len()); + } + + let frames = message.into_vec(); + let actual_id = EngineId(frames[0].clone()); + if let Some(expected_id) = expected_id + && actual_id != *expected_id + { + return Err(Error::UnexpectedHandshakeIdentity { + expected: expected_id.to_vec(), + actual: actual_id.to_vec(), + }); + } + + let handshake_message: ReadyMessage = decode_msgpack(&frames[1])?; + Ok((actual_id, handshake_message)) +} + +/// Send an INIT message to the engine with the local socket addresses for the +/// engine to connect to, using the handshake socket. +async fn send_init_message( + handshake_socket: &mut RouterSocket, + engine_id: &EngineId, + input_address: &str, + output_address: &str, + coordinator: Option<&CoordinatorBootstrap>, +) -> Result<()> { + let init_message = HandshakeInitMessage { + addresses: HandshakeAddresses { + inputs: vec![input_address.to_string()], + outputs: vec![output_address.to_string()], + coordinator_input: coordinator.map(|c| c.input_address.clone()), + coordinator_output: coordinator.map(|c| c.output_address.clone()), + frontend_stats_publish_address: None, + }, + parallel_config: Default::default(), + }; + let payload = encode_msgpack(&init_message)?; + let message = ZmqMessage::try_from(vec![engine_id.to_frame(), Bytes::from(payload)]) + .expect("handshake router messages must contain identity and payload"); + handshake_socket.send(message).await?; + Ok(()) +} + +/// Receive the input registration message from each engine and validate its +/// identity. +/// +/// Each registration contains 2 frames: `[identity, ready-payload]`. +/// +/// Since vLLM commit `c8d98f81f676552c263f35bbde55e6edbe81b4e8` ("[Core] +/// Simplify API server handshake"), the payload is a msgpack-encoded +/// [`EngineCoreReadyResponse`] carrying post-initialization values such as +/// `max_model_len`. +/// +/// Older engines sent an empty second frame here just to establish the +/// ROUTER/DEALER backchannel, with no structured payload on the input socket. +/// We continue to tolerate that legacy shape so the frontend can still connect +/// to slightly older local engine checkouts. +async fn wait_for_input_registrations( + input_socket: &mut RouterSocket, + engines: &mut [ConnectedEngine], + ready_timeout: Duration, +) -> Result<()> { + let mut pending = engines.iter().map(|e| e.engine_id.clone()).collect::>(); + + while !pending.is_empty() { + let registration = timeout(ready_timeout, input_socket.recv()).await.map_err(|_| { + Error::InputRegistrationTimeout { + timeout: ready_timeout, + } + })??; + + if registration.len() != 2 { + bail_unexpected_handshake_message!( + "expected 2 frames for engine input registration, got {}", + registration.len() + ); + } + + let frames = registration.into_vec(); + let actual_id = EngineId(frames[0].clone()); + if !pending.remove(&actual_id) { + bail_unexpected_handshake_message!( + "received input registration for unexpected engine id {actual_id:?}" + ); + } + + let ready_response = if frames[1].is_empty() { + debug!( + ?actual_id, + "received legacy empty input registration from engine" + ); + None + } else { + let ready_response: EngineCoreReadyResponse = decode_msgpack(&frames[1])?; + debug!( + ?actual_id, + ?ready_response, + "received input registration from engine" + ); + Some(ready_response) + }; + + // Store the ready response in the corresponding engine entry. + if let Some(engine) = engines.iter_mut().find(|e| e.engine_id == actual_id) { + engine.ready_response = ready_response; + } + } + + Ok(()) +} + +/// Send an encoded message to the engine through the input socket. +pub async fn send_message( + input_send: &mut RouterSendHalf, + engine_id: &EngineId, + request_type: Bytes, + payload: Vec, +) -> Result<()> { + let message = ZmqMessage::try_from(vec![ + engine_id.to_frame(), + request_type, + Bytes::from(payload), + ]) + .expect("router messages must contain identity and payload"); + + trace!( + ?engine_id, + frame_count = message.len(), + "sending ZMQ message" + ); + input_send.send(message).await?; + Ok(()) +} + +/// Run the output loop to receive messages from the engine and send them to the +/// provided channel. +pub async fn run_output_loop( + mut output_socket: PullSocket, + tx: mpsc::Sender>, +) { + loop { + let message = match output_socket.recv().await { + Ok(message) => message, + Err(error) => { + // If we fail to receive a message from the engine, it's likely that the engine + // has crashed or become unreachable, so we should notify the + // client and shut down the output loop. + error!(error = %error.as_report(), "failed to receive output message"); + let _ = tx.send(Err(Error::Transport(error))).await; + return; + } + }; + + let frame_count = message.len(); + trace!(frame_count, "received output message"); + let frames = message.into_vec(); + let frame = frames.first().expect("output message must have at least one frame"); + let frame_len = frame.len(); + if frame.as_ref() == ENGINE_CORE_DEAD_SENTINEL { + warn!("received ENGINE_CORE_DEAD sentinel from engine"); + let _ = tx.send(Err(Error::EngineCoreDead)).await; + return; + } + let decoded = match decode_engine_core_outputs(&frames) { + Ok(decoded) => { + trace!(frame_len, outputs = ?decoded, "decoded output message"); + Ok(decoded) + } + Err(error) => { + // If we fail to decode the message from the engine, notify the client but keep + // the output loop running to continue processing future + // messages from the engine. + warn!(frame_len, error = %error.as_report(), "failed to decode output message"); + Err(error) + } + }; + + if tx.send(decoded).await.is_err() { + // If we fail to send the decoded message to the client, it's likely that the + // client has shut down, so we should shut down the output loop as + // well. + warn!("output loop rx dropped, shutting down output loop"); + return; + } + } +} + +#[cfg(test)] +mod tests { + use super::bind_local_sockets; + + #[tokio::test] + async fn bind_local_sockets_resolves_zero_port_bindings() { + let (input_address, _input_socket, output_address, _output_socket) = + bind_local_sockets("127.0.0.1", None, None).await.expect("bind local sockets"); + + assert!(input_address.starts_with("tcp://127.0.0.1:")); + assert!(output_address.starts_with("tcp://127.0.0.1:")); + assert_ne!(input_address, output_address); + } +} diff --git a/rust/src/llm/Cargo.toml b/rust/src/llm/Cargo.toml new file mode 100644 index 00000000000..c7924b85db7 --- /dev/null +++ b/rust/src/llm/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "vllm-llm" +version.workspace = true +edition.workspace = true +license.workspace = true + +[features] +test-util = [] + +[dependencies] +easy-ext.workspace = true +enum-as-inner.workspace = true +futures.workspace = true +serde.workspace = true +serde_json.workspace = true +thiserror.workspace = true +tokio.workspace = true +tokio-util.workspace = true +tracing.workspace = true +uuid.workspace = true +vllm-engine-core-client.workspace = true +vllm-metrics.workspace = true + +[dev-dependencies] +anyhow.workspace = true +bytes.workspace = true +clap.workspace = true +expect-test.workspace = true +rmp-serde.workspace = true +tokio.workspace = true +tracing-subscriber.workspace = true +uuid.workspace = true +vllm-engine-core-client = { workspace = true, features = ["test-util"] } +vllm-metrics.workspace = true +zeromq.workspace = true + +[lints] +workspace = true diff --git a/rust/src/llm/examples/README.md b/rust/src/llm/examples/README.md new file mode 100644 index 00000000000..4764dba3a49 --- /dev/null +++ b/rust/src/llm/examples/README.md @@ -0,0 +1,29 @@ +# LLM Smoke Test + +Start headless `vllm`: + +```bash +source ../vllm/.venv/bin/activate +HF_HUB_OFFLINE=1 \ +VLLM_LOGGING_LEVEL=DEBUG \ +VLLM_CPU_KVCACHE_SPACE=2 \ +VLLM_HOST_IP=127.0.0.1 \ +VLLM_LOOPBACK_IP=127.0.0.1 \ +python3 -m vllm.entrypoints.cli.main serve Qwen/Qwen3-0.6B \ + --headless \ + --data-parallel-address 127.0.0.1 \ + --data-parallel-rpc-port 62100 \ + --data-parallel-size-local 1 \ + --max-model-len 512 \ + --dtype float16 +``` + +Run the Rust smoke test through the `vllm-llm` generate interface: + +```bash +cargo run -p vllm-llm --example external_engine_smoke -- \ + --handshake-address tcp://127.0.0.1:62100 \ + --host 127.0.0.1 +``` + +IMPORTANT: You must restart `vllm` each time you run the smoke test, as the vLLM engine cannot manage frontend closures and subsequent reconnects. In other words, do not reuse existing `vllm` instances, if any. diff --git a/rust/src/llm/examples/external_engine_smoke.rs b/rust/src/llm/examples/external_engine_smoke.rs new file mode 100644 index 00000000000..c2d0e6bdfa8 --- /dev/null +++ b/rust/src/llm/examples/external_engine_smoke.rs @@ -0,0 +1,144 @@ +use std::time::Duration; + +use anyhow::{Context, Result, bail}; +use clap::Parser; +use futures::StreamExt as _; +use tokio::time::timeout; +use tracing_subscriber::EnvFilter; +use vllm_engine_core_client::protocol::EngineCoreSamplingParams; +use vllm_engine_core_client::{EngineCoreClient, EngineCoreClientConfig, TransportMode}; +use vllm_llm::{FinishReason, GenerateOutputStream, GenerateRequest, Llm}; + +const PROMPT_TOKEN_IDS: &[u32] = &[20841, 448, 6896, 25, 23811]; + +#[derive(Debug, Parser)] +#[command(about = "Smoke-test the Rust LLM facade against an external vLLM engine.")] +struct Args { + #[arg(long)] + handshake_address: String, + #[arg(long, default_value_t = 1)] + engine_count: usize, + #[arg(long, default_value = "Qwen/Qwen3-0.6B")] + model: String, + #[arg(long, default_value = "127.0.0.1")] + host: String, + #[arg(long, default_value_t = 0)] + client_index: u32, + #[arg(long, default_value_t = 30)] + ready_timeout_secs: u64, + #[arg(long, default_value_t = 120)] + output_timeout_secs: u64, + #[arg(long, default_value_t = 5)] + max_tokens: u32, +} + +fn unique_request_id() -> String { + format!("rust-llm-smoke-{}", uuid::Uuid::new_v4()) +} + +fn init_tracing() { + let filter = EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new("vllm_engine_core_client=debug")); + let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init(); +} + +fn build_request(request_id: String, max_tokens: u32) -> GenerateRequest { + GenerateRequest { + request_id, + prompt_token_ids: PROMPT_TOKEN_IDS.to_vec(), + sampling_params: EngineCoreSamplingParams { + max_tokens, + ..EngineCoreSamplingParams::for_test() + }, + mm_features: None, + arrival_time: None, + cache_salt: None, + trace_headers: None, + priority: 0, + data_parallel_rank: None, + reasoning_ended: None, + lora_request: None, + } +} + +#[derive(Debug)] +struct CompletedRequest { + token_ids: Vec, + finish_reason: FinishReason, +} + +async fn wait_for_request_completion(mut stream: GenerateOutputStream) -> Result { + let output = match stream.next().await { + Some(output) => output.context("failed to receive request output")?, + None => bail!("request stream ended without a final output"), + }; + + let none = stream.next().await; + assert!( + none.is_none(), + "expected final-only stream to end after the final output" + ); + + let finish_reason = output.finish_reason.expect("final-only output must have a finish reason"); + let token_ids = output.token_ids; + + Ok(CompletedRequest { + token_ids, + finish_reason, + }) +} + +async fn wait_for_timeout( + stream: GenerateOutputStream, + output_timeout: Duration, +) -> Result { + timeout(output_timeout, wait_for_request_completion(stream)) + .await + .context("timed out waiting for request output")? +} + +#[tokio::main(flavor = "multi_thread")] +async fn main() -> Result<()> { + init_tracing(); + let args = Args::parse(); + let ready_timeout = Duration::from_secs(args.ready_timeout_secs); + let output_timeout = Duration::from_secs(args.output_timeout_secs); + let request_id = unique_request_id(); + let client = EngineCoreClient::connect(EngineCoreClientConfig { + transport_mode: TransportMode::HandshakeOwner { + handshake_address: args.handshake_address.clone(), + advertised_host: args.host.clone(), + engine_count: args.engine_count, + ready_timeout, + local_input_address: None, + local_output_address: None, + }, + coordinator_mode: None, + model_name: args.model.clone(), + client_index: args.client_index, + }) + .await + .context("failed to connect to external vLLM engine")?; + + println!("model={}", args.model); + println!("handshake_address={}", args.handshake_address); + println!("engine_count={}", args.engine_count); + println!("input_address={}", client.input_address()); + println!("output_address={}", client.output_address()); + println!("engine_identities={:x?}", client.engine_identities()); + + let llm = Llm::new(client); + let request = build_request(request_id.clone(), args.max_tokens); + println!("request_id={request_id}"); + println!("prompt_token_ids={PROMPT_TOKEN_IDS:?}"); + + let stream = llm.generate(request).await.context("failed to submit generate request")?; + let output = wait_for_timeout(stream, output_timeout).await?; + + llm.shutdown().await.context("failed to shut down llm client")?; + + println!("token_ids={:?}", output.token_ids); + println!("finish_reason={:?}", output.finish_reason); + + Ok(()) +} diff --git a/rust/src/llm/src/error.rs b/rust/src/llm/src/error.rs new file mode 100644 index 00000000000..5865e1fcc39 --- /dev/null +++ b/rust/src/llm/src/error.rs @@ -0,0 +1,12 @@ +use thiserror::Error; + +pub type Result = std::result::Result; + +/// Public error type for the Rust `llm` facade. +#[derive(Debug, Error)] +pub enum Error { + #[error("generate request `{request_id}` has an empty prompt_token_ids")] + EmptyPromptTokenIds { request_id: String }, + #[error("engine-core error")] + EngineCoreClient(#[from] vllm_engine_core_client::Error), +} diff --git a/rust/src/llm/src/lib.rs b/rust/src/llm/src/lib.rs new file mode 100644 index 00000000000..d47935259b5 --- /dev/null +++ b/rust/src/llm/src/lib.rs @@ -0,0 +1,100 @@ +use tracing::Span; +use vllm_engine_core_client::EngineCoreClient; + +mod error; +mod log_stats; +mod output; +mod request; +mod request_metrics; + +pub use error::{Error, Result}; +pub use output::{ + CollectedGenerateOutput, FinishReason, GenerateOutput, GenerateOutputStream, + GenerateOutputStreamExt, GeneratePromptInfo, +}; +pub use request::GenerateRequest; +pub use vllm_engine_core_client::protocol::logprobs::{Logprobs, PositionLogprobs, TokenLogprob}; + +use crate::log_stats::StatsLogger; +use crate::request_metrics::RequestMetricsTracker; + +/// Thin generate-only facade over [`EngineCoreClient`]. +/// +/// This mirrors the narrow public shape of Python `AsyncLLM.generate()` and +/// `abort()`, but keeps the boundary close to raw engine-core requests and +/// outputs. +pub struct Llm { + client: EngineCoreClient, + randomize_request_id: bool, + stats_logger: Option, +} + +impl Llm { + /// Create a new minimal LLM facade from an already connected engine-core + /// client. + pub fn new(client: EngineCoreClient) -> Self { + Self { + client, + randomize_request_id: true, + stats_logger: None, + } + } + + /// Enable or disable periodic stats logging. + pub fn with_log_stats(mut self, enabled: bool) -> Self { + if enabled { + let stats_logger = StatsLogger::start( + self.client.model_name().to_string(), + self.client.engine_count(), + ); + self.stats_logger = Some(stats_logger); + } else { + self.stats_logger = None; + } + self + } + + /// Control whether external request ids are randomized before reaching + /// engine-core. + pub fn with_request_id_randomization(mut self, enabled: bool) -> Self { + self.randomize_request_id = enabled; + self + } + + /// Expose the underlying engine-core client for low-level utility/admin + /// calls. + pub fn engine_core_client(&self) -> &EngineCoreClient { + &self.client + } + + /// Submit one tokenized generate request and return a per-request output + /// stream. + pub async fn generate(&self, req: GenerateRequest) -> Result { + let prepared = req.prepare(self.randomize_request_id)?; + let prompt_token_ids = prepared.prompt_token_ids().into(); + + // Record internal engine-core request ID in the current tracing span. + Span::current().record("engine_request_id", &prepared.engine_request.request_id); + + let request_metrics = RequestMetricsTracker::new( + self.client.model_name().to_string(), + prepared.engine_request.arrival_time, + prepared.prompt_token_ids().len() as u32, + (prepared.engine_request.sampling_params.as_ref()).map(|p| p.max_tokens), + 1, + ); + let stream = self.client.call(prepared.engine_request).await?; + + Ok(GenerateOutputStream::new( + prompt_token_ids, + stream, + request_metrics, + )) + } + + /// Shut down the underlying engine-core client and its background tasks. + pub async fn shutdown(self) -> Result<()> { + self.client.shutdown().await?; + Ok(()) + } +} diff --git a/rust/src/llm/src/log_stats.rs b/rust/src/llm/src/log_stats.rs new file mode 100644 index 00000000000..7d3a149731b --- /dev/null +++ b/rust/src/llm/src/log_stats.rs @@ -0,0 +1,199 @@ +use std::fmt::Write; +use std::time::{Duration, Instant}; + +use tokio_util::task::AbortOnDropHandle; +use tracing::{debug, info}; +use vllm_metrics::{ + EngineLabels, F64Gauge, METRICS, PromptTokenSourceLabels, U64Counter, U64Gauge, +}; + +const LOG_STATS_INTERVAL: Duration = Duration::from_secs(10); + +/// Cached, cloned metric handles for one engine. Each clone shares the same +/// underlying `Arc` as the prometheus `Family` entry, so reads go +/// straight to the atomic with no lock. +struct EngineMetrics { + // Counters for throughput deltas. + prompt_tokens_computed: U64Counter, + generation_tokens: U64Counter, + prefix_cache_queries: U64Counter, + prefix_cache_hits: U64Counter, + + // Gauges for instantaneous scheduler state. + scheduler_running: U64Gauge, + scheduler_waiting: U64Gauge, + kv_cache_usage: F64Gauge, +} + +/// Accumulated snapshot values from the last logging interval, used to compute +/// deltas. +struct CounterSnapshot { + prompt_tokens: u64, + generation_tokens: u64, + prefix_cache_queries: u64, + prefix_cache_hits: u64, +} + +/// Periodic stats logger that mirrors Python vLLM's `LoggingStatLogger`. +/// +/// Spawns a background task that logs throughput and scheduler state at a fixed +/// interval. When idle (both current and previous throughputs are zero), logs +/// at DEBUG level. When load drops to zero, emits one final INFO-level line +/// before going quiet. +pub(crate) struct StatsLogger { + _task: AbortOnDropHandle<()>, +} + +impl StatsLogger { + /// Start the background stats logging task. + pub(crate) fn start(model_name: String, engine_count: usize) -> Self { + let task = AbortOnDropHandle::new(tokio::spawn(async move { + run_stats_logger(model_name, engine_count).await; + })); + Self { _task: task } + } +} + +/// Resolve and clone all metric handles once so the hot path is lock-free. +fn resolve_engine_metrics(model_name: &str, engine_count: usize) -> Vec { + let m = &METRICS; + (0..engine_count as u32) + .map(|engine| { + let el = EngineLabels { + model_name: model_name.to_string(), + engine, + }; + let pt = PromptTokenSourceLabels { + model_name: model_name.to_string(), + engine, + source: "local_compute", + }; + EngineMetrics { + // Use "local_compute" source for prompt throughput (excludes + // cached/transferred tokens), matching Python's + // `iteration_stats.prompt_token_stats.computed`. + prompt_tokens_computed: m.request.prompt_tokens_by_source.get_or_create_owned(&pt), + generation_tokens: m.request.generation_tokens.get_or_create_owned(&el), + prefix_cache_queries: m.scheduler.prefix_cache_queries.get_or_create_owned(&el), + prefix_cache_hits: m.scheduler.prefix_cache_hits.get_or_create_owned(&el), + scheduler_running: m.scheduler.scheduler_running.get_or_create_owned(&el), + scheduler_waiting: m.scheduler.scheduler_waiting.get_or_create_owned(&el), + kv_cache_usage: m.scheduler.kv_cache_usage.get_or_create_owned(&el), + } + }) + .collect() +} + +async fn run_stats_logger(model_name: String, engine_count: usize) { + let engines = resolve_engine_metrics(&model_name, engine_count); + + let mut interval = tokio::time::interval(LOG_STATS_INTERVAL); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + // The first tick fires immediately; skip it so the first log is after one full + // interval. + interval.tick().await; + + let mut prev = read_counters(&engines); + let mut last_log_time = Instant::now(); + let mut last_prompt_throughput: f64 = 0.0; + let mut last_generation_throughput: f64 = 0.0; + + let mut msg = String::new(); + loop { + interval.tick().await; + + let now = Instant::now(); + let elapsed = now.duration_since(last_log_time).as_secs_f64(); + if elapsed <= 0.0 { + continue; + } + + let curr = read_counters(&engines); + + let prompt_throughput = + curr.prompt_tokens.wrapping_sub(prev.prompt_tokens) as f64 / elapsed; + let generation_throughput = + curr.generation_tokens.wrapping_sub(prev.generation_tokens) as f64 / elapsed; + + // Idle = both current and previous throughputs are zero. + let is_idle = prompt_throughput == 0.0 + && generation_throughput == 0.0 + && last_prompt_throughput == 0.0 + && last_generation_throughput == 0.0; + + // Read scheduler gauges (aggregate across engines). + let (num_running, num_waiting, kv_cache_usage) = read_scheduler_gauges(&engines); + + // Compute prefix cache hit rate over this interval. + let delta_queries = curr.prefix_cache_queries.wrapping_sub(prev.prefix_cache_queries); + let prefix_cache_hit_rate = if delta_queries > 0 { + let delta_hits = curr.prefix_cache_hits.wrapping_sub(prev.prefix_cache_hits); + delta_hits as f64 / delta_queries as f64 * 100.0 + } else { + 0.0 + }; + + // Build the log line. + msg.clear(); + write!( + msg, + "Avg prompt tput: {prompt_throughput:.1} toks/s, \ + Avg generation tput: {generation_throughput:.1} toks/s, \ + Reqs Running: {num_running}, \ + Waiting: {num_waiting}, \ + GPU KV cache used: {:.1}%, \ + Prefix cache hit rate: {prefix_cache_hit_rate:.1}%", + kv_cache_usage * 100.0, + ) + .unwrap(); + + if is_idle { + debug!("{msg}"); + } else { + info!("{msg}"); + } + + last_prompt_throughput = prompt_throughput; + last_generation_throughput = generation_throughput; + last_log_time = now; + prev = curr; + } +} + +/// Read the current cumulative counter values for throughput computation. +fn read_counters(engines: &[EngineMetrics]) -> CounterSnapshot { + let mut snap = CounterSnapshot { + prompt_tokens: 0, + generation_tokens: 0, + prefix_cache_queries: 0, + prefix_cache_hits: 0, + }; + for e in engines { + snap.prompt_tokens += e.prompt_tokens_computed.get(); + snap.generation_tokens += e.generation_tokens.get(); + snap.prefix_cache_queries += e.prefix_cache_queries.get(); + snap.prefix_cache_hits += e.prefix_cache_hits.get(); + } + snap +} + +/// Read the current scheduler gauge values, aggregated across engines. +fn read_scheduler_gauges(engines: &[EngineMetrics]) -> (u64, u64, f64) { + let mut num_running = 0u64; + let mut num_waiting = 0u64; + let mut kv_cache_usage_sum = 0.0f64; + + for e in engines { + num_running += e.scheduler_running.get(); + num_waiting += e.scheduler_waiting.get(); + kv_cache_usage_sum += e.kv_cache_usage.get(); + } + + let kv_cache_usage = if !engines.is_empty() { + kv_cache_usage_sum / engines.len() as f64 + } else { + 0.0 + }; + + (num_running, num_waiting, kv_cache_usage) +} diff --git a/rust/src/llm/src/output.rs b/rust/src/llm/src/output.rs new file mode 100644 index 00000000000..94d9acb3fe8 --- /dev/null +++ b/rust/src/llm/src/output.rs @@ -0,0 +1,346 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll, ready}; + +use enum_as_inner::EnumAsInner; +use futures::stream::FusedStream; +use futures::{Stream, StreamExt as _, pin_mut}; +use serde::{Deserialize, Serialize}; +use vllm_engine_core_client::protocol::logprobs::Logprobs; +use vllm_engine_core_client::protocol::{EngineCoreFinishReason, StopReason}; +use vllm_engine_core_client::{AbortCause, EngineCoreOutputStream}; + +use crate::error::Result; +use crate::request_metrics::{RequestMetricsTracker, current_unix_timestamp_secs}; + +/// Final raw token output plus terminal stream metadata. +#[derive(Debug, Clone, PartialEq)] +pub struct CollectedGenerateOutput { + pub request_id: String, + pub prompt_token_ids: Vec, + pub prompt_logprobs: Option, + pub token_ids: Vec, + pub logprobs: Option, + pub finish_reason: FinishReason, + /// Connector-specific KV transfer parameters for disaggregated serving. + pub kv_transfer_params: Option, +} + +/// Prompt-scoped metadata emitted only once on the first [`GenerateOutput`] for +/// one request. +#[derive(Debug, Clone, PartialEq)] +pub struct GeneratePromptInfo { + /// Original prompt token IDs for this request. + pub prompt_token_ids: Arc<[u32]>, + /// Prompt logprobs returned by engine-core for scored prompt positions, + /// when requested. + pub prompt_logprobs: Option, +} + +/// The reason a request finished. +/// +/// This is a higher-level abstraction over engine-core's finish and stop +/// reasons. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, EnumAsInner)] +pub enum FinishReason { + /// Generation stopped for a stop string, stop token, or EOS. + /// + /// The inner stop reason is present for explicit stop strings or stop + /// tokens, and absent for EOS-driven stops. + Stop(Option), + /// `max_tokens` or `max_model_len` was reached. + Length, + /// The request was aborted by the client. + Abort, + /// A retryable request-level internal error occurred. + Error, + /// A repetitive token pattern was detected. + Repetition, +} + +impl FinishReason { + /// Construct a stop finish reason caused by EOS rather than an explicit + /// stop string/token. + pub fn stop_eos() -> Self { + Self::Stop(None) + } + + /// Returns a human-readable string for this finish reason, used for metrics + /// and reporting. + pub fn as_str(&self) -> &'static str { + match self { + Self::Stop(_) => "stop", + Self::Length => "length", + Self::Abort => "abort", + Self::Error => "error", + Self::Repetition => "repetition", + } + } + + /// If this is a stop finish reason, returns the inner stop reason if it + /// exists. + pub fn as_stop_reason(&self) -> Option<&StopReason> { + match self { + Self::Stop(stop_reason) => stop_reason.as_ref(), + _ => None, + } + } + + /// If this is a stop finish reason, returns the inner stop reason if it + /// exists. + pub fn into_stop_reason(self) -> Option { + match self { + Self::Stop(stop_reason) => stop_reason, + _ => None, + } + } +} + +fn finish_reason_from_engine( + finish_reason: Option, + stop_reason: Option, +) -> Option { + finish_reason.map(|reason| match reason { + EngineCoreFinishReason::Stop => FinishReason::Stop(stop_reason), + EngineCoreFinishReason::Length => FinishReason::Length, + EngineCoreFinishReason::Abort => FinishReason::Abort, + EngineCoreFinishReason::Error => FinishReason::Error, + EngineCoreFinishReason::Repetition => FinishReason::Repetition, + }) +} + +/// Token and logprob output item returned by [`GenerateOutputStream`]. +/// +/// Original Python output reference: +/// +#[derive(Debug, Clone, PartialEq)] +pub struct GenerateOutput { + /// Unique ID of the request that produced this output. + pub request_id: String, + /// One-time prompt metadata emitted only on the first output for this + /// request. + pub prompt_info: Option, + /// Newly produced token IDs for this step. + pub token_ids: Vec, + /// Sample logprobs for the generated positions in this step. + pub logprobs: Option, + /// Terminal finish reason, when this is the final output for the request. + pub finish_reason: Option, + /// Connector-specific KV transfer parameters for disaggregated serving. + pub kv_transfer_params: Option, +} + +impl GenerateOutput { + /// Returns the prompt token IDs when this output carries + /// [`GeneratePromptInfo`]. + /// + /// Only the first output for a request can return `Some`; all later outputs + /// return `None`. + pub fn prompt_token_ids(&self) -> Option<&Arc<[u32]>> { + self.prompt_info.as_ref().map(|info| &info.prompt_token_ids) + } + + /// Returns the prompt logprobs when this output carries + /// [`GeneratePromptInfo`]. + /// + /// Only the first output for a request can return `Some`; all later outputs + /// return `None`. + pub fn prompt_logprobs(&self) -> Option<&Logprobs> { + self.prompt_info.as_ref().and_then(|info| info.prompt_logprobs.as_ref()) + } + + /// Returns whether this output is terminal for the request. + pub fn finished(&self) -> bool { + self.finish_reason.is_some() + } +} + +#[cfg(any(test, feature = "test-util"))] +impl GenerateOutput { + /// Build a [`GenerateOutput`] for tests. + pub fn for_test( + prompt_token_ids: Option>, + token_ids: Vec, + finish_reason: Option, + ) -> Self { + Self { + request_id: String::new(), + prompt_info: prompt_token_ids.map(|ids| GeneratePromptInfo { + prompt_token_ids: ids, + prompt_logprobs: None, + }), + token_ids, + logprobs: None, + finish_reason, + kv_transfer_params: None, + } + } +} + +/// Stream of per-request generate outputs for one request. +/// +/// - A normal termination of the stream represents a clean completion of the request. +/// - For errors, unexpected closes, or explicit aborts, the stream terminates with an error. +pub struct GenerateOutputStream { + pending_prompt_info: Option, + raw_stream: EngineCoreOutputStream, + request_metrics: RequestMetricsTracker, +} + +impl GenerateOutputStream { + /// Create a new generate output stream by adapting one raw engine-core + /// output stream. + pub(crate) fn new( + prompt_token_ids: Arc<[u32]>, + raw_stream: EngineCoreOutputStream, + request_metrics: RequestMetricsTracker, + ) -> Self { + Self { + pending_prompt_info: Some(GeneratePromptInfo { + prompt_token_ids, + prompt_logprobs: None, + }), + raw_stream, + request_metrics, + } + } + + /// Return the internal engine request ID bound to this stream. + pub fn request_id(&self) -> &str { + self.raw_stream.request_id() + } +} + +impl Stream for GenerateOutputStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let raw = match ready!(Pin::new(&mut self.raw_stream).poll_next(cx)) { + Some(Ok(raw)) => raw, + Some(Err(error)) => return Poll::Ready(Some(Err(error.into()))), + None => return Poll::Ready(None), + }; + + let received_at = current_unix_timestamp_secs(); + self.request_metrics.observe_output( + raw.engine_index, + raw.timestamp, + received_at, + &raw.output, + ); + + let raw = raw.output; + + // Populate the one-time prompt info on the first output. + if let Some(info) = &mut self.pending_prompt_info + && info.prompt_logprobs.is_none() + { + info.prompt_logprobs = + raw.new_prompt_logprobs_tensors.map(|value| value.into_direct().unwrap()); + } + + let logprobs = raw.new_logprobs.map(|value| value.into_direct().unwrap()); + + let finish_reason = finish_reason_from_engine(raw.finish_reason, raw.stop_reason); + if let Some(finish_reason) = finish_reason.as_ref() { + self.request_metrics.record_finished(received_at, finish_reason.clone()); + } + + let output = GenerateOutput { + request_id: raw.request_id, + prompt_info: self.pending_prompt_info.take(), + token_ids: raw.new_token_ids, + logprobs, + finish_reason, + kv_transfer_params: raw.kv_transfer_params, + }; + + Poll::Ready(Some(Ok(output))) + } +} + +impl FusedStream for GenerateOutputStream { + fn is_terminated(&self) -> bool { + self.raw_stream.is_terminated() + } +} + +impl Drop for GenerateOutputStream { + fn drop(&mut self) { + if self.raw_stream.is_terminated() { + // Already terminated cleanly, no need to record abort metrics. + return; + } + + // If the user or the upper layer drops a live generate stream, + // `EngineCoreOutputStream::Drop` will trigger an engine-side abort. Record the + // matching terminal request metrics here so frontend-driven aborts are still + // visible as `finished_reason=...` instead of disappearing from observability + // entirely. + let finish_reason = match AbortCause::current() { + AbortCause::DroppedStream => FinishReason::Abort, + AbortCause::StopStringMatched => FinishReason::Stop(None), + }; + + self.request_metrics + .record_finished(current_unix_timestamp_secs(), finish_reason); + } +} + +#[allow(clippy::manual_async_fn, reason = "specify `Send` bound")] +#[easy_ext::ext(GenerateOutputStreamExt)] +impl> + Send> T { + /// Collect the raw generate stream to completion and return the final token + /// output. + pub fn collect_output(self) -> impl Future> + Send { + async move { + let stream = self; + pin_mut!(stream); + let mut prompt_token_ids = None; + let mut prompt_logprobs = None; + let mut collected: Option = None; + + while let Some(output) = stream.next().await.transpose()? { + if let Some(info) = output.prompt_info { + if prompt_token_ids.is_none() { + prompt_token_ids = Some(info.prompt_token_ids.to_vec()); + } + if prompt_logprobs.is_none() { + prompt_logprobs = info.prompt_logprobs; + } + } + + if let Some(existing) = collected.as_mut() { + existing.token_ids.extend(output.token_ids); + if let Some(step_logprobs) = output.logprobs { + if let Some(collected_logprobs) = existing.logprobs.as_mut() { + collected_logprobs.positions.extend(step_logprobs.positions); + } else { + existing.logprobs = Some(step_logprobs); + } + } + } else { + collected = Some(CollectedGenerateOutput { + request_id: output.request_id, + prompt_token_ids: prompt_token_ids.take().unwrap_or_default(), + prompt_logprobs: prompt_logprobs.take(), + token_ids: output.token_ids, + logprobs: output.logprobs, + finish_reason: FinishReason::Error, + kv_transfer_params: None, + }); + } + + if let Some(finish_reason) = output.finish_reason { + let mut collected = collected.expect("terminal output must exist"); + collected.finish_reason = finish_reason; + collected.kv_transfer_params = output.kv_transfer_params; + return Ok(collected); + } + } + + unreachable!("generate stream should yield an error instead of closing early") + } + } +} diff --git a/rust/src/llm/src/request.rs b/rust/src/llm/src/request.rs new file mode 100644 index 00000000000..b17035b0512 --- /dev/null +++ b/rust/src/llm/src/request.rs @@ -0,0 +1,201 @@ +use std::collections::BTreeMap; +use std::time::{SystemTime, UNIX_EPOCH}; + +use uuid::Uuid; +use vllm_engine_core_client::protocol::multimodal::MmFeatures; +use vllm_engine_core_client::protocol::{EngineCoreRequest, EngineCoreSamplingParams, OpaqueValue}; + +use crate::error::{Error, Result}; + +/// Tokenized decoder-only generate request accepted by [`crate::Llm`]. +/// +/// This is the first-stage Rust subset of the inputs that eventually flow into +/// Python `AsyncLLM.generate()`. The boundary is intentionally above +/// [`EngineCoreRequest`], but below higher-level text and multimodal +/// preprocessing. +/// +/// Original Python API reference: +/// +#[derive(Debug, Clone, PartialEq)] +pub struct GenerateRequest { + /// Unique ID of the request. + pub request_id: String, + /// Token IDs of the prompt. + pub prompt_token_ids: Vec, + /// Sampling parameters forwarded to engine-core. + pub sampling_params: EngineCoreSamplingParams, + /// Optional multimodal features already prepared by `vllm-chat`. + pub mm_features: Option, + + // Fields below are currently likely unused by callers. + pub arrival_time: Option, + pub cache_salt: Option, + pub trace_headers: Option>, + pub priority: i32, + pub data_parallel_rank: Option, + pub reasoning_ended: Option, + pub lora_request: Option, +} + +#[derive(Debug)] +pub(crate) struct PreparedGenerateRequest { + pub engine_request: EngineCoreRequest, +} + +impl GenerateRequest { + /// Validate and lower this request into the raw engine-core request format. + pub(crate) fn prepare(self, randomize_request_id: bool) -> Result { + if self.prompt_token_ids.is_empty() { + return Err(Error::EmptyPromptTokenIds { + request_id: self.request_id, + }); + } + let GenerateRequest { + request_id, + prompt_token_ids, + sampling_params, + mm_features, + arrival_time, + cache_salt, + trace_headers, + priority, + data_parallel_rank, + reasoning_ended, + lora_request, + } = self; + + let external_request_id = request_id; + let engine_request_id = if randomize_request_id { + let random_suffix = Uuid::new_v4().simple().to_string(); + format!("{external_request_id}-{}", &random_suffix[..8]) + } else { + external_request_id.clone() + }; + + Ok(PreparedGenerateRequest { + engine_request: EngineCoreRequest { + request_id: engine_request_id, + prompt_token_ids: Some(prompt_token_ids), + mm_features, + sampling_params: Some(sampling_params), + pooling_params: None, + arrival_time: arrival_time.unwrap_or_else(current_unix_timestamp_secs), + lora_request, + cache_salt, + data_parallel_rank, + prompt_embeds: None, + prompt_is_token_ids: None, + client_index: 0, + current_wave: 0, + priority, + trace_headers, + resumable: false, + external_req_id: Some(external_request_id), + reasoning_ended, + reasoning_parser_kwargs: None, + abort_immediately: false, + }, + }) + } +} + +impl PreparedGenerateRequest { + /// Return the original prompt token IDs copied into the raw engine request. + pub fn prompt_token_ids(&self) -> &[u32] { + self.engine_request + .prompt_token_ids + .as_ref() + .expect("prepared request must have prompt token ids") + } +} + +fn current_unix_timestamp_secs() -> f64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system clock is before unix epoch") + .as_secs_f64() +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + use vllm_engine_core_client::protocol::EngineCoreSamplingParams; + + use super::GenerateRequest; + use crate::error::Error; + + fn sample_request() -> GenerateRequest { + GenerateRequest { + request_id: "req-1".to_string(), + prompt_token_ids: vec![11, 22, 33], + sampling_params: EngineCoreSamplingParams::for_test(), + mm_features: None, + arrival_time: Some(42.5), + cache_salt: Some("salt".to_string()), + trace_headers: Some(BTreeMap::from([( + "x-trace-id".to_string(), + "abc".to_string(), + )])), + priority: 3, + data_parallel_rank: Some(2), + reasoning_ended: Some(true), + lora_request: None, + } + } + + #[test] + fn prepare_builds_engine_core_request() { + let prepared = sample_request().prepare(true).unwrap(); + + assert_eq!(prepared.prompt_token_ids(), &[11, 22, 33]); + + let request = prepared.engine_request; + assert_eq!(request.external_req_id.as_deref(), Some("req-1")); + assert!(request.request_id.starts_with("req-1-")); + assert_ne!(request.request_id, "req-1"); + assert_eq!(request.prompt_token_ids.as_deref(), Some(&[11, 22, 33][..])); + assert_eq!(request.arrival_time, 42.5); + assert_eq!(request.cache_salt.as_deref(), Some("salt")); + assert_eq!(request.data_parallel_rank, Some(2)); + assert_eq!( + request.trace_headers, + Some(BTreeMap::from([( + "x-trace-id".to_string(), + "abc".to_string(), + )])) + ); + assert_eq!(request.reasoning_ended, Some(true)); + } + + #[test] + fn prepare_rejects_empty_prompt_tokens() { + let mut request = sample_request(); + request.prompt_token_ids.clear(); + + let error = request.prepare(true).unwrap_err(); + assert!(matches!( + error, + Error::EmptyPromptTokenIds { request_id } if request_id == "req-1" + )); + } + + #[test] + fn prepare_can_preserve_external_request_id() { + let prepared = sample_request().prepare(false).unwrap(); + + let request = prepared.engine_request; + assert_eq!(request.external_req_id.as_deref(), Some("req-1")); + assert_eq!(request.request_id, "req-1"); + } + + #[test] + fn prepare_forwards_multimodal_features() { + let mut request = sample_request(); + request.mm_features = Some(Vec::new()); + + let prepared = request.prepare(false).unwrap(); + + assert_eq!(prepared.engine_request.mm_features, Some(Vec::new())); + } +} diff --git a/rust/src/llm/src/request_metrics.rs b/rust/src/llm/src/request_metrics.rs new file mode 100644 index 00000000000..d28b83be816 --- /dev/null +++ b/rust/src/llm/src/request_metrics.rs @@ -0,0 +1,391 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +use vllm_engine_core_client::protocol::stats::PrefillStats; +use vllm_engine_core_client::protocol::{EngineCoreEvent, EngineCoreEventType, EngineCoreOutput}; +use vllm_metrics::{ + EngineLabels, FinishedReasonLabels, METRICS, PromptTokenSourceLabels, RequestMetrics, +}; + +use crate::FinishReason; + +fn metrics() -> &'static RequestMetrics { + &METRICS.request +} + +const PROMPT_TOKEN_SOURCE_LOCAL_COMPUTE: &str = "local_compute"; +const PROMPT_TOKEN_SOURCE_LOCAL_CACHE_HIT: &str = "local_cache_hit"; +const PROMPT_TOKEN_SOURCE_EXTERNAL_KV_TRANSFER: &str = "external_kv_transfer"; + +/// Request-scoped metrics state tracked across streamed engine-core updates. +/// +/// This is the Rust-side counterpart of the Python frontend's request-lifecycle +/// bookkeeping, centered on `RequestStateStats` and the per-output/per-finished +/// update flow. +/// +/// Original Python definitions: +/// +/// +/// Original Python update flow: +/// +#[derive(Debug, Clone)] +pub(crate) struct RequestMetricsTracker { + model_name: String, + arrival_time: f64, + prompt_len: u32, + max_tokens_param: Option, + n_param: u32, + is_prefilling: bool, + queued_ts: f64, + scheduled_ts: f64, + first_token_ts: f64, + last_token_ts: f64, + first_token_latency: f64, + num_generation_tokens: u32, + latest_num_cached_tokens: u32, + last_seen_engine_index: u32, +} + +impl RequestMetricsTracker { + /// Create the per-request tracker from the normalized `llm`-layer request + /// context. + pub(crate) fn new( + model_name: String, + arrival_time: f64, + prompt_len: u32, + max_tokens_param: Option, + n_param: u32, + ) -> Self { + Self { + model_name, + arrival_time, + prompt_len, + max_tokens_param, + n_param, + is_prefilling: true, + queued_ts: 0.0, + scheduled_ts: 0.0, + first_token_ts: 0.0, + last_token_ts: 0.0, + first_token_latency: 0.0, + num_generation_tokens: 0, + latest_num_cached_tokens: 0, + last_seen_engine_index: 0, + } + } + + /// Update request-lifecycle state from one engine-core output item. + /// + /// Original Python stats logic: + /// + pub(crate) fn observe_output( + &mut self, + engine_index: u32, + batch_timestamp: f64, + received_at: f64, + output: &EngineCoreOutput, + ) { + self.last_seen_engine_index = engine_index; + if let Some(prefill_stats) = &output.prefill_stats { + self.latest_num_cached_tokens = prefill_stats.num_cached_tokens; + } + self.num_generation_tokens += output.new_token_ids.len() as u32; + metrics() + .generation_tokens + .get_or_create(&engine_labels(&self.model_name, engine_index)) + .inc_by(output.new_token_ids.len() as u64); + + if let Some(events) = &output.events { + self.observe_events(engine_index, events); + } + + if self.is_prefilling { + if let Some(prefill_stats) = &output.prefill_stats { + record_prompt_tokens(&self.model_name, engine_index, prefill_stats); + } + self.first_token_latency = received_at - self.arrival_time; + observe_time_to_first_token_seconds( + &self.model_name, + engine_index, + self.first_token_latency, + ); + self.first_token_ts = batch_timestamp; + self.is_prefilling = false; + } else if self.last_token_ts > 0.0 { + observe_inter_token_latency_seconds( + &self.model_name, + engine_index, + batch_timestamp - self.last_token_ts, + ); + } + + self.last_token_ts = batch_timestamp; + } + + /// Emit the terminal request metrics once a finished output has been + /// observed. + /// + /// Original Python finished-request stats: + /// + pub(crate) fn record_finished(&self, received_at: f64, finish_reason: FinishReason) { + let labels = engine_labels(&self.model_name, self.last_seen_engine_index); + let prefill_kv_computed_tokens = + self.prompt_len.saturating_sub(self.latest_num_cached_tokens); + let e2e_latency_seconds = received_at - self.arrival_time; + let queue_time_seconds = diff_or_zero(self.scheduled_ts, self.queued_ts); + let prefill_time_seconds = diff_or_zero(self.first_token_ts, self.scheduled_ts); + let decode_time_seconds = diff_or_zero(self.last_token_ts, self.first_token_ts); + let inference_time_seconds = diff_or_zero(self.last_token_ts, self.scheduled_ts); + let time_per_output_token_seconds = if self.num_generation_tokens > 1 { + diff_or_zero(self.last_token_ts, self.first_token_ts) + / (self.num_generation_tokens - 1) as f64 + } else { + 0.0 + }; + + record_request_success(&self.model_name, self.last_seen_engine_index, finish_reason); + metrics() + .request_prompt_tokens + .get_or_create(&labels) + .observe(self.prompt_len as f64); + metrics() + .request_generation_tokens + .get_or_create(&labels) + .observe(self.num_generation_tokens as f64); + metrics() + .request_max_num_generation_tokens + .get_or_create(&labels) + .observe(self.num_generation_tokens as f64); + if let Some(max_tokens_param) = self.max_tokens_param { + metrics() + .request_params_max_tokens + .get_or_create(&labels) + .observe(max_tokens_param as f64); + } + metrics().request_params_n.get_or_create(&labels).observe(self.n_param as f64); + metrics() + .request_prefill_kv_computed_tokens + .get_or_create(&labels) + .observe(prefill_kv_computed_tokens as f64); + metrics() + .e2e_request_latency_seconds + .get_or_create(&labels) + .observe(e2e_latency_seconds); + metrics() + .request_queue_time_seconds + .get_or_create(&labels) + .observe(queue_time_seconds); + metrics() + .request_prefill_time_seconds + .get_or_create(&labels) + .observe(prefill_time_seconds); + metrics() + .request_decode_time_seconds + .get_or_create(&labels) + .observe(decode_time_seconds); + metrics() + .request_inference_time_seconds + .get_or_create(&labels) + .observe(inference_time_seconds); + metrics() + .request_time_per_output_token_seconds + .get_or_create(&labels) + .observe(time_per_output_token_seconds); + } + + fn observe_events(&mut self, engine_index: u32, events: &[EngineCoreEvent]) { + for event in events { + match event.r#type { + EngineCoreEventType::Queued => { + self.queued_ts = event.timestamp; + } + EngineCoreEventType::Scheduled => { + if self.scheduled_ts == 0.0 { + self.scheduled_ts = event.timestamp; + } + } + EngineCoreEventType::Preempted => { + metrics() + .num_preemptions + .get_or_create(&engine_labels(&self.model_name, engine_index)) + .inc(); + } + } + } + } +} + +fn engine_labels(model_name: &str, engine: u32) -> EngineLabels { + EngineLabels { + model_name: model_name.to_string(), + engine, + } +} + +fn observe_time_to_first_token_seconds(model_name: &str, engine: u32, seconds: f64) { + metrics() + .time_to_first_token_seconds + .get_or_create(&engine_labels(model_name, engine)) + .observe(seconds); +} + +fn observe_inter_token_latency_seconds(model_name: &str, engine: u32, seconds: f64) { + metrics() + .inter_token_latency_seconds + .get_or_create(&engine_labels(model_name, engine)) + .observe(seconds); +} + +fn record_request_success(model_name: &str, engine: u32, finish_reason: FinishReason) { + metrics() + .request_success + .get_or_create(&FinishedReasonLabels { + model_name: model_name.to_string(), + engine, + finished_reason: finish_reason.as_str(), + }) + .inc(); +} + +fn prompt_token_source_labels( + model_name: &str, + engine: u32, + source: &'static str, +) -> PromptTokenSourceLabels { + PromptTokenSourceLabels { + model_name: model_name.to_string(), + engine, + source, + } +} + +fn record_prompt_tokens(model_name: &str, engine: u32, prefill_stats: &PrefillStats) { + let computed = prefill_stats.num_computed_tokens as u64; + let local_cache_hit = prefill_stats.num_local_cached_tokens as u64; + let external_kv_transfer = prefill_stats.num_external_cached_tokens as u64; + + metrics() + .prompt_tokens + .get_or_create(&engine_labels(model_name, engine)) + .inc_by(prefill_stats.num_prompt_tokens as u64); + metrics() + .prompt_tokens_by_source + .get_or_create(&prompt_token_source_labels( + model_name, + engine, + PROMPT_TOKEN_SOURCE_LOCAL_COMPUTE, + )) + .inc_by(computed); + metrics() + .prompt_tokens_by_source + .get_or_create(&prompt_token_source_labels( + model_name, + engine, + PROMPT_TOKEN_SOURCE_LOCAL_CACHE_HIT, + )) + .inc_by(local_cache_hit); + metrics() + .prompt_tokens_by_source + .get_or_create(&prompt_token_source_labels( + model_name, + engine, + PROMPT_TOKEN_SOURCE_EXTERNAL_KV_TRANSFER, + )) + .inc_by(external_kv_transfer); + metrics() + .prompt_tokens_cached + .get_or_create(&engine_labels(model_name, engine)) + .inc_by(prefill_stats.num_cached_tokens as u64); +} + +fn diff_or_zero(end: f64, start: f64) -> f64 { + if end > 0.0 && start > 0.0 && end >= start { + end - start + } else { + 0.0 + } +} + +/// Return the current wall-clock time in seconds since the Unix epoch. +/// +/// This is used for frontend-side latency measurements such as TTFT and E2E, +/// matching the Python frontend's use of wall-clock request arrival/iteration +/// timestamps rather than engine-core's monotonic scheduler timestamps. +/// +/// Original Python request timestamp source: +/// +pub(crate) fn current_unix_timestamp_secs() -> f64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system clock is before unix epoch") + .as_secs_f64() +} + +#[cfg(test)] +mod tests { + use vllm_engine_core_client::protocol::stats::PrefillStats; + use vllm_engine_core_client::protocol::{EngineCoreEvent, EngineCoreEventType}; + + use super::{RequestMetricsTracker, diff_or_zero}; + + #[test] + fn tracker_updates_timing_state_across_prefill_decode_and_finish() { + let mut tracker = RequestMetricsTracker::new("model".to_string(), 100.0, 64, Some(128), 1); + + tracker.observe_output( + 2, + 10.0, + 100.2, + &vllm_engine_core_client::protocol::EngineCoreOutput { + request_id: "req-1".to_string(), + new_token_ids: vec![1], + finish_reason: None, + events: Some(vec![ + EngineCoreEvent { + r#type: EngineCoreEventType::Queued, + timestamp: 8.0, + }, + EngineCoreEvent { + r#type: EngineCoreEventType::Scheduled, + timestamp: 9.0, + }, + ]), + prefill_stats: Some(PrefillStats { + num_prompt_tokens: 64, + num_computed_tokens: 60, + num_cached_tokens: 4, + num_local_cached_tokens: 4, + num_external_cached_tokens: 0, + }), + ..Default::default() + }, + ); + tracker.observe_output( + 2, + 11.5, + 100.4, + &vllm_engine_core_client::protocol::EngineCoreOutput { + request_id: "req-1".to_string(), + new_token_ids: vec![2, 3], + finish_reason: None, + events: Some(vec![EngineCoreEvent { + r#type: EngineCoreEventType::Preempted, + timestamp: 10.5, + }]), + ..Default::default() + }, + ); + + assert!(!tracker.is_prefilling); + assert_eq!(tracker.last_seen_engine_index, 2); + assert_eq!(tracker.num_generation_tokens, 3); + assert_eq!(tracker.queued_ts, 8.0); + assert_eq!(tracker.scheduled_ts, 9.0); + assert_eq!(tracker.first_token_ts, 10.0); + assert_eq!(tracker.last_token_ts, 11.5); + assert!((tracker.first_token_latency - 0.2).abs() < 1e-9); + assert_eq!( + diff_or_zero(tracker.last_token_ts, tracker.first_token_ts), + 1.5 + ); + } +} diff --git a/rust/src/llm/tests/generate.rs b/rust/src/llm/tests/generate.rs new file mode 100644 index 00000000000..8b1b98bdc48 --- /dev/null +++ b/rust/src/llm/tests/generate.rs @@ -0,0 +1,749 @@ +use std::collections::BTreeSet; +use std::sync::Once; +use std::time::Duration; + +use futures::StreamExt as _; +use tokio::time::timeout; +use tracing_subscriber::EnvFilter; +use uuid::Uuid; +use vllm_engine_core_client::protocol::logprobs::{ + Logprobs, MaybeWireLogprobs, PositionLogprobs, TokenLogprob, +}; +use vllm_engine_core_client::protocol::stats::PrefillStats; +use vllm_engine_core_client::protocol::{ + EngineCoreEvent, EngineCoreEventType, EngineCoreFinishReason, EngineCoreOutput, + EngineCoreOutputs, EngineCoreRequest, EngineCoreSamplingParams, +}; +use vllm_engine_core_client::test_utils::{IpcNamespace, spawn_mock_engine_task}; +use vllm_engine_core_client::{EngineCoreClient, EngineCoreClientConfig}; +use vllm_llm::{ + Error, FinishReason, GenerateOutputStreamExt as _, GeneratePromptInfo, GenerateRequest, Llm, +}; +use vllm_metrics::METRICS; +use zeromq::prelude::{SocketRecv, SocketSend}; +use zeromq::{DealerSocket, PushSocket, ZmqMessage}; + +static TRACING: Once = Once::new(); + +fn request_output( + request_id: &str, + new_token_ids: Vec, + finish_reason: Option, +) -> EngineCoreOutput { + request_output_with_events(request_id, new_token_ids, finish_reason, None) +} + +fn request_output_with_events( + request_id: &str, + new_token_ids: Vec, + finish_reason: Option, + events: Option>, +) -> EngineCoreOutput { + EngineCoreOutput { + request_id: request_id.to_string(), + new_token_ids, + new_logprobs: None, + new_prompt_logprobs_tensors: None, + pooling_output: None, + finish_reason, + stop_reason: None, + events, + kv_transfer_params: None, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + } +} + +fn request_output_with_logprobs( + request_id: &str, + new_token_ids: Vec, + finish_reason: Option, + new_logprobs: Option, + prompt_logprobs: Option, +) -> EngineCoreOutput { + EngineCoreOutput { + request_id: request_id.to_string(), + new_token_ids, + new_logprobs: new_logprobs.map(MaybeWireLogprobs::Direct), + new_prompt_logprobs_tensors: prompt_logprobs.map(MaybeWireLogprobs::Direct), + pooling_output: None, + finish_reason, + stop_reason: None, + events: None, + kv_transfer_params: None, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + } +} + +fn request_output_with_logprobs_and_kv( + request_id: &str, + new_token_ids: Vec, + finish_reason: Option, + new_logprobs: Option, + prompt_logprobs: Option, + kv_transfer_params: Option, +) -> EngineCoreOutput { + EngineCoreOutput { + request_id: request_id.to_string(), + new_token_ids, + new_logprobs: new_logprobs.map(MaybeWireLogprobs::Direct), + new_prompt_logprobs_tensors: prompt_logprobs.map(MaybeWireLogprobs::Direct), + pooling_output: None, + finish_reason, + stop_reason: None, + events: None, + kv_transfer_params, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + } +} + +fn logprobs_for_position( + sampled_token_id: u32, + sampled_logprob: f32, + sampled_rank: u32, + top_token_id: u32, + top_logprob: f32, +) -> Logprobs { + Logprobs { + positions: vec![PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id: sampled_token_id, + logprob: sampled_logprob, + rank: sampled_rank, + }, + TokenLogprob { + token_id: top_token_id, + logprob: top_logprob, + rank: 1, + }, + ], + }], + } +} + +fn prompt_logprobs() -> Logprobs { + Logprobs { + positions: vec![ + PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id: 11, + logprob: -0.1, + rank: 2, + }, + TokenLogprob { + token_id: 7, + logprob: -0.05, + rank: 1, + }, + ], + }, + PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id: 22, + logprob: -0.2, + rank: 3, + }, + TokenLogprob { + token_id: 8, + logprob: -0.1, + rank: 1, + }, + ], + }, + ], + } +} + +fn sample_generate_request(request_id: &str, max_tokens: u32) -> GenerateRequest { + GenerateRequest { + request_id: request_id.to_string(), + prompt_token_ids: vec![11, 22], + sampling_params: EngineCoreSamplingParams { + max_tokens, + ..EngineCoreSamplingParams::for_test() + }, + mm_features: None, + arrival_time: Some(42.5), + cache_salt: None, + trace_headers: None, + priority: 0, + data_parallel_rank: None, + reasoning_ended: None, + lora_request: None, + } +} + +async fn send_outputs(push: &mut PushSocket, outputs: EngineCoreOutputs) { + push.send(ZmqMessage::from(rmp_serde::to_vec_named(&outputs).unwrap())) + .await + .unwrap(); +} + +async fn recv_engine_message(dealer: &mut DealerSocket) -> Vec { + dealer.recv().await.unwrap().into_vec() +} + +async fn connect_async_llm_with_ipc( + handshake_address: String, + client_index: u32, + model_name: &str, + ipc: &IpcNamespace, +) -> Llm { + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name(model_name) + .with_client_index(client_index) + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .unwrap(); + Llm::new(client) +} + +fn request_metrics_model_name(prefix: &str) -> String { + format!("{prefix}-{}", Uuid::new_v4().simple()) +} + +fn init_tracing() { + TRACING.call_once(|| { + let filter = EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new("vllm_engine_core_client=debug")); + let _ = tracing_subscriber::fmt().with_test_writer().with_env_filter(filter).try_init(); + }); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn generate_streams_outputs() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-delta".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add = recv_engine_message(dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + assert_eq!(request.external_req_id.as_deref(), Some("req-delta")); + assert!(request.request_id.starts_with("req-delta-")); + assert_ne!(request.request_id, "req-delta"); + assert_eq!(request.client_index, 7); + assert_eq!(request.prompt_token_ids, Some(vec![11, 22])); + + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![ + request_output_with_logprobs( + &request.request_id, + vec![1, 2], + None, + Some(logprobs_for_position(1, -0.3, 4, 9, -0.1)), + Some(prompt_logprobs()), + ), + request_output_with_logprobs( + &request.request_id, + vec![3], + Some(EngineCoreFinishReason::Length), + Some(logprobs_for_position(3, -0.4, 5, 10, -0.2)), + None, + ), + ], + finished_requests: Some(BTreeSet::from([request.request_id.clone()])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let llm = connect_async_llm_with_ipc(handshake_address, 7, "test-model", &ipc).await; + let mut stream = llm.generate(sample_generate_request("req-delta", 3)).await.unwrap(); + let internal_id = stream.request_id().to_string(); + + let first = stream.next().await.unwrap().unwrap(); + assert_eq!(first.request_id, internal_id); + assert_eq!( + first.prompt_info, + Some(GeneratePromptInfo { + prompt_token_ids: vec![11, 22].into(), + prompt_logprobs: Some(prompt_logprobs()), + }) + ); + assert_eq!(first.token_ids, vec![1, 2]); + assert_eq!( + first.logprobs, + Some(logprobs_for_position(1, -0.3, 4, 9, -0.1)) + ); + assert_eq!(first.finish_reason, None); + + let second = stream.next().await.unwrap().unwrap(); + assert_eq!(second.prompt_info, None); + assert_eq!(second.token_ids, vec![3]); + assert_eq!( + second.logprobs, + Some(logprobs_for_position(3, -0.4, 5, 10, -0.2)) + ); + assert_eq!(second.finish_reason, Some(FinishReason::Length)); + assert!(stream.next().await.is_none()); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + llm.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn collect_output_aggregates_raw_tokens_logprobs_and_terminal_metadata() { + init_tracing(); + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-collect-output".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + assert_eq!(request.external_req_id.as_deref(), Some("req-collect")); + assert!(request.request_id.starts_with("req-collect-")); + + send_outputs( + push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![ + request_output_with_logprobs( + &request.request_id, + vec![33], + None, + Some(logprobs_for_position(33, -0.1, 1, 99, -0.2)), + Some(prompt_logprobs()), + ), + request_output_with_logprobs_and_kv( + &request.request_id, + vec![44], + Some(EngineCoreFinishReason::Stop), + Some(logprobs_for_position(44, -0.3, 1, 88, -0.4)), + None, + Some(serde_json::json!({"connector": "x"})), + ), + ], + scheduler_stats: None, + timestamp: 0.0, + utility_output: None, + finished_requests: None, + wave_complete: None, + start_wave: None, + }, + ) + .await; + }) + }, + ); + + let llm = connect_async_llm_with_ipc(handshake_address, 7, "test-model", &ipc).await; + let stream = llm.generate(sample_generate_request("req-collect", 4)).await.unwrap(); + let internal_id = stream.request_id().to_string(); + let collected = stream.collect_output().await.unwrap(); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + + assert_eq!(collected.request_id, internal_id); + assert_eq!(collected.prompt_token_ids, vec![11, 22]); + assert_eq!(collected.token_ids, vec![33, 44]); + assert_eq!(collected.finish_reason, FinishReason::stop_eos()); + assert_eq!(collected.prompt_logprobs, Some(prompt_logprobs())); + assert_eq!( + collected.logprobs.as_ref().map(|lp| lp.positions.len()), + Some(2) + ); + assert_eq!( + collected.kv_transfer_params, + Some(serde_json::json!({"connector": "x"})) + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn generate_propagates_unexpected_close_errors() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-close".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add = recv_engine_message(dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + + send_outputs( + push, + EngineCoreOutputs { + finished_requests: Some(BTreeSet::from([request.request_id])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let llm = connect_async_llm_with_ipc(handshake_address, 0, "test-model", &ipc).await; + let mut stream = llm.generate(sample_generate_request("req-close", 1)).await.unwrap(); + let internal_id = stream.request_id().to_string(); + + let error = stream.next().await.unwrap().unwrap_err(); + assert!(matches!( + error, + Error::EngineCoreClient(vllm_engine_core_client::Error::RequestStreamClosed { + request_id + }) if request_id == internal_id + )); + assert!(stream.next().await.is_none()); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + llm.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn dropping_a_live_generate_stream_triggers_abort() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-drop".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add = recv_engine_message(dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + assert_eq!(request.external_req_id.as_deref(), Some("req-drop")); + assert!(request.request_id.starts_with("req-drop-")); + + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![request_output(&request.request_id, vec![99], None)], + ..Default::default() + }, + ) + .await; + + let abort = + timeout(Duration::from_secs(1), recv_engine_message(dealer)).await.unwrap(); + assert_eq!(abort[0].as_ref(), &[0x01]); + let aborted_ids: Vec = rmp_serde::from_slice(&abort[1]).unwrap(); + assert_eq!(aborted_ids, vec![request.request_id]); + }) + }, + ); + + let llm = connect_async_llm_with_ipc(handshake_address, 0, "test-model", &ipc).await; + let mut stream = llm.generate(sample_generate_request("req-drop", 4)).await.unwrap(); + + let output = stream.next().await.unwrap().unwrap(); + assert_eq!(output.token_ids, vec![99]); + drop(stream); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + llm.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn duplicate_external_request_ids_are_randomized_before_reaching_engine_core_client() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-dup".to_vec(); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add_1 = recv_engine_message(dealer).await; + assert_eq!(add_1[0].as_ref(), &[0x00]); + let request_1: EngineCoreRequest = rmp_serde::from_slice(&add_1[1]).unwrap(); + assert_eq!(request_1.external_req_id.as_deref(), Some("req-dup")); + assert!(request_1.request_id.starts_with("req-dup-")); + + let add_2 = recv_engine_message(dealer).await; + assert_eq!(add_2[0].as_ref(), &[0x00]); + let request_2: EngineCoreRequest = rmp_serde::from_slice(&add_2[1]).unwrap(); + assert_eq!(request_2.external_req_id.as_deref(), Some("req-dup")); + assert!(request_2.request_id.starts_with("req-dup-")); + assert_ne!(request_1.request_id, request_2.request_id); + + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![request_output( + &request_1.request_id, + vec![], + Some(EngineCoreFinishReason::Length), + )], + finished_requests: Some(BTreeSet::from([request_1.request_id.clone()])), + ..Default::default() + }, + ) + .await; + + send_outputs( + push, + EngineCoreOutputs { + outputs: vec![request_output( + &request_2.request_id, + vec![], + Some(EngineCoreFinishReason::Length), + )], + finished_requests: Some(BTreeSet::from([request_2.request_id])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let llm = connect_async_llm_with_ipc(handshake_address, 0, "test-model", &ipc).await; + let stream_1 = llm.generate(sample_generate_request("req-dup", 1)).await.unwrap(); + let stream_2 = llm.generate(sample_generate_request("req-dup", 1)).await.unwrap(); + let internal_id_1 = stream_1.request_id().to_string(); + let internal_id_2 = stream_2.request_id().to_string(); + let collected_1 = stream_1.collect_output().await.unwrap(); + let collected_2 = stream_2.collect_output().await.unwrap(); + assert_eq!(collected_1.request_id, internal_id_1); + assert_eq!(collected_2.request_id, internal_id_2); + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + llm.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn generate_records_request_metrics_in_prometheus_output() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-metrics".to_vec(); + let model_name = request_metrics_model_name("metrics-model"); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add = recv_engine_message(dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + + send_outputs( + push, + EngineCoreOutputs { + engine_index: 4, + timestamp: 10.0, + outputs: vec![EngineCoreOutput { + prefill_stats: Some(PrefillStats { + num_prompt_tokens: 2, + num_computed_tokens: 2, + ..Default::default() + }), + ..request_output_with_events( + &request.request_id, + vec![1], + None, + Some(vec![ + EngineCoreEvent { + r#type: EngineCoreEventType::Queued, + timestamp: 8.0, + }, + EngineCoreEvent { + r#type: EngineCoreEventType::Scheduled, + timestamp: 9.0, + }, + ]), + ) + }], + ..Default::default() + }, + ) + .await; + + send_outputs( + push, + EngineCoreOutputs { + engine_index: 4, + timestamp: 11.5, + outputs: vec![request_output_with_events( + &request.request_id, + vec![2, 3], + Some(EngineCoreFinishReason::Length), + Some(vec![EngineCoreEvent { + r#type: EngineCoreEventType::Preempted, + timestamp: 10.5, + }]), + )], + finished_requests: Some(BTreeSet::from([request.request_id])), + ..Default::default() + }, + ) + .await; + }) + }, + ); + + let llm = connect_async_llm_with_ipc(handshake_address, 0, &model_name, &ipc).await; + let mut request = sample_generate_request("req-metrics", 8); + request.arrival_time = None; + let mut stream = llm.generate(request).await.unwrap(); + + assert_eq!(stream.next().await.unwrap().unwrap().token_ids, vec![1]); + let final_output = stream.next().await.unwrap().unwrap(); + assert_eq!(final_output.token_ids, vec![2, 3]); + assert_eq!(final_output.finish_reason, Some(FinishReason::Length)); + assert!(stream.next().await.is_none()); + + let rendered = METRICS.render().unwrap(); + assert!(rendered.contains(&format!( + "vllm:request_success_total{{model_name=\"{model_name}\",engine=\"4\",finished_reason=\"length\"}} 1" + ))); + assert!(rendered.contains(&format!( + "vllm:prompt_tokens_total{{model_name=\"{model_name}\",engine=\"4\"}} 2" + ))); + assert!(rendered.contains(&format!( + "vllm:prompt_tokens_by_source_total{{model_name=\"{model_name}\",engine=\"4\",source=\"local_compute\"}} 2" + ))); + assert!(rendered.contains(&format!( + "vllm:prompt_tokens_by_source_total{{model_name=\"{model_name}\",engine=\"4\",source=\"local_cache_hit\"}} 0" + ))); + assert!(rendered.contains(&format!( + "vllm:prompt_tokens_by_source_total{{model_name=\"{model_name}\",engine=\"4\",source=\"external_kv_transfer\"}} 0" + ))); + assert!(rendered.contains(&format!( + "vllm:prompt_tokens_cached_total{{model_name=\"{model_name}\",engine=\"4\"}} 0" + ))); + assert!(rendered.contains(&format!( + "vllm:generation_tokens_total{{model_name=\"{model_name}\",engine=\"4\"}} 3" + ))); + assert!(rendered.contains(&format!( + "vllm:num_preemptions_total{{model_name=\"{model_name}\",engine=\"4\"}} 1" + ))); + assert!(rendered.contains(&format!( + "vllm:time_to_first_token_seconds_count{{model_name=\"{model_name}\",engine=\"4\"}} 1" + ))); + assert!(rendered.contains(&format!( + "vllm:inter_token_latency_seconds_count{{model_name=\"{model_name}\",engine=\"4\"}} 1" + ))); + assert!(rendered.contains(&format!( + "vllm:e2e_request_latency_seconds_count{{model_name=\"{model_name}\",engine=\"4\"}} 1" + ))); + assert!(rendered.contains(&format!( + "vllm:request_prompt_tokens_count{{model_name=\"{model_name}\",engine=\"4\"}} 1" + ))); + assert!(rendered.contains(&format!( + "vllm:request_generation_tokens_count{{model_name=\"{model_name}\",engine=\"4\"}} 1" + ))); + assert!(rendered.contains(&format!( + "vllm:request_prefill_kv_computed_tokens_count{{model_name=\"{model_name}\",engine=\"4\"}} 1" + ))); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + llm.shutdown().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn dropping_stream_records_abort_terminal_request_metrics() { + let ipc = IpcNamespace::new().unwrap(); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-metrics-drop".to_vec(); + let model_name = request_metrics_model_name("metrics-drop-model"); + + let (shutdown_tx, engine_task) = spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add = recv_engine_message(dealer).await; + assert_eq!(add[0].as_ref(), &[0x00]); + let request: EngineCoreRequest = rmp_serde::from_slice(&add[1]).unwrap(); + assert_eq!(request.external_req_id.as_deref(), Some("req-metrics-drop")); + assert!(request.request_id.starts_with("req-metrics-drop-")); + + send_outputs( + push, + EngineCoreOutputs { + engine_index: 5, + timestamp: 10.0, + outputs: vec![request_output_with_events( + &request.request_id, + vec![99], + None, + Some(vec![ + EngineCoreEvent { + r#type: EngineCoreEventType::Queued, + timestamp: 8.0, + }, + EngineCoreEvent { + r#type: EngineCoreEventType::Scheduled, + timestamp: 9.0, + }, + ]), + )], + ..Default::default() + }, + ) + .await; + + let abort = + timeout(Duration::from_secs(1), recv_engine_message(dealer)).await.unwrap(); + assert_eq!(abort[0].as_ref(), &[0x01]); + let aborted_ids: Vec = rmp_serde::from_slice(&abort[1]).unwrap(); + assert_eq!(aborted_ids, vec![request.request_id]); + }) + }, + ); + + let llm = connect_async_llm_with_ipc(handshake_address, 0, &model_name, &ipc).await; + let mut request = sample_generate_request("req-metrics-drop", 8); + request.arrival_time = None; + let mut stream = llm.generate(request).await.unwrap(); + assert_eq!(stream.next().await.unwrap().unwrap().token_ids, vec![99]); + drop(stream); + + let _ = shutdown_tx.send(()); + engine_task.await.unwrap(); + let rendered = METRICS.render().unwrap(); + assert!(rendered.contains(&format!( + "vllm:request_success_total{{model_name=\"{model_name}\",engine=\"5\",finished_reason=\"abort\"}} 1" + ))); + assert!(rendered.contains(&format!( + "vllm:e2e_request_latency_seconds_count{{model_name=\"{model_name}\",engine=\"5\"}} 1" + ))); + + llm.shutdown().await.unwrap(); +} diff --git a/rust/src/managed-engine/Cargo.toml b/rust/src/managed-engine/Cargo.toml new file mode 100644 index 00000000000..2bc108b2c53 --- /dev/null +++ b/rust/src/managed-engine/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "vllm-managed-engine" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +anyhow.workspace = true +clap.workspace = true +libc.workspace = true +tokio = { workspace = true, features = ["process"] } +tracing.workspace = true + +[dev-dependencies] +expect-test.workspace = true + +[lints] +workspace = true diff --git a/rust/src/managed-engine/src/cli.rs b/rust/src/managed-engine/src/cli.rs new file mode 100644 index 00000000000..302737dbd88 --- /dev/null +++ b/rust/src/managed-engine/src/cli.rs @@ -0,0 +1,349 @@ +use std::collections::HashSet; +use std::ffi::OsString; + +use clap::error::ErrorKind; +use clap::{Args, CommandFactory}; + +use crate::{ManagedEngineConfig, allocate_handshake_port}; + +/// Managed Python headless-engine CLI arguments. +#[derive(Debug, Clone, Args, PartialEq, Eq)] +pub struct ManagedEngineArgs { + /// Python executable used to launch the managed headless vLLM engine. + #[arg(long, env = "VLLM_RS_PYTHON", default_value = "python3")] + pub python: String, + /// Host/IP used both for the managed-engine handshake endpoint and the + /// frontend-advertised input/output ZMQ socket addresses. + #[arg( + long = "data-parallel-address", + visible_alias = "handshake-host", + default_value = "127.0.0.1" + )] + pub handshake_host: String, + /// Optional TCP port for the managed-engine handshake / data-parallel RPC + /// endpoint. + /// + /// When omitted, the CLI allocates an ephemeral port automatically. + #[arg( + long = "data-parallel-rpc-port", + visible_alias = "handshake-port", + value_parser = clap::value_parser!(u16).range(1..) + )] + pub handshake_port: Option, + /// Number of data parallel replicas across the whole deployment. + #[arg(long, default_value_t = 1)] + pub data_parallel_size: usize, + /// Number of data parallel replicas to run on this node. + #[arg(long)] + pub data_parallel_size_local: Option, + + /// Additional arguments forwarded to `python -m vllm.entrypoints.cli.main + /// serve ...`. + /// + /// Arguments after an explicit `--` are forwarded verbatim. Before `--`, + /// `vllm-rs serve` automatically keeps recognized frontend options on + /// the Rust side and forwards everything else to Python. + #[arg( + last = true, + allow_hyphen_values = true, + help_heading = "Passthrough arguments" + )] + pub python_args: Vec, +} + +impl ManagedEngineArgs { + /// Build the handshake address shared by the Rust frontend and managed + /// Python engine. + pub fn handshake_address(&self, handshake_port: u16) -> String { + format!("tcp://{}:{}", self.handshake_host, handshake_port) + } + + /// Resolve the handshake port, either from the CLI argument (if specified) + /// or by allocating a fresh port. + pub fn resolve_handshake_port(&self) -> anyhow::Result { + self.handshake_port + .map(Ok) + .unwrap_or_else(|| allocate_handshake_port(&self.handshake_host)) + } + + /// Build the managed Python-engine spawn configuration. + pub fn into_config( + self, + model: String, + max_model_len: Option, + handshake_port: u16, + ) -> ManagedEngineConfig { + let mut python_args = self.python_args; + // Manually forward some args to the Python engine. + if let Some(max_model_len) = max_model_len { + python_args.push("--max-model-len".to_string()); + python_args.push(max_model_len.to_string()); + } + if let Some(data_parallel_size_local) = self.data_parallel_size_local { + python_args.push("--data-parallel-size-local".to_string()); + python_args.push(data_parallel_size_local.to_string()); + } + + ManagedEngineConfig { + python: self.python, + model, + handshake_host: self.handshake_host, + handshake_port, + data_parallel_size: self.data_parallel_size, + python_args, + } + } + + /// Return the number of engines that the Rust frontend should expect to + /// coordinate with. + fn local_engine_count(&self) -> usize { + self.data_parallel_size_local.unwrap_or(self.data_parallel_size) + } + + /// Return whether the managed Rust frontend only needs to communicate with + /// colocated engines. + pub fn frontend_local_only(&self) -> bool { + self.data_parallel_size_local != Some(0) + && self.local_engine_count() == self.data_parallel_size + } +} + +/// Python `argparse` accepts these multi-character single-dash aliases, but +/// `clap` cannot model them directly. +const PYTHON_MULTI_CHAR_ALIASES: &[(&str, &str)] = &[ + ("-asc", "--api-server-count"), + ("-pp", "--pipeline-parallel-size"), + ("-tp", "--tensor-parallel-size"), + ("-dcp", "--decode-context-parallel-size"), + ("-pcp", "--prefill-context-parallel-size"), + ("-dp", "--data-parallel-size"), + ("-dpn", "--data-parallel-rank"), + ("-dpr", "--data-parallel-start-rank"), + ("-dpl", "--data-parallel-size-local"), + ("-dpa", "--data-parallel-address"), + ("-dpp", "--data-parallel-rpc-port"), + ("-dpb", "--data-parallel-backend"), + ("-dph", "--data-parallel-hybrid-lb"), + ("-dpe", "--data-parallel-external-lb"), + ("-ep", "--enable-expert-parallel"), + ("-cc", "--compilation-config"), + ("-ac", "--attention-config"), +]; + +/// Repartition managed-engine argv so Rust-owned flags stay before `--`, while +/// everything else is forwarded to Python. +pub fn repartition_managed_engine_args( + args: &[OsString], + subcommand: Option<&str>, +) -> Result, clap::Error> +where + C: CommandFactory, +{ + let command = C::command(); + let (prefix, real_args, command) = match subcommand { + Some(subcommand) => { + if !matches_subcommand(args, subcommand) { + return Ok(args.to_vec()); + }; + + let subcommand = command + .find_subcommand(subcommand) + .expect("managed-engine subcommand should exist"); + + (args[..2].to_vec(), &args[2..], subcommand) + } + None => { + let Some(program) = args.first() else { + return Ok(args.to_vec()); + }; + + (vec![program.clone()], &args[1..], &command) + } + }; + + let mut repartitioned = prefix; + repartitioned.extend(repartition_real_managed_engine_args(real_args, command)?); + Ok(repartitioned) +} + +fn repartition_real_managed_engine_args( + args: &[OsString], + command: &clap::Command, +) -> Result, clap::Error> { + let Some(model) = args.first() else { + return Ok(args.to_vec()); + }; + + let model = model.to_string_lossy(); + if is_help_flag(&model) { + return Ok(args.to_vec()); + } + if model == "--" || is_option_like(&model) { + return Err(build_missing_model_error(command)); + } + + let (long_flags, short_flags) = collect_option_names(command); + let (front_args, explicit_passthrough, had_separator) = split_managed_engine_args(&args[1..]); + let normalized_front_args = normalize_python_arg_aliases(front_args); + + let mut frontend_chunks = Vec::new(); + let mut python_chunks = Vec::new(); + let mut current_chunk = Vec::new(); + + for arg in normalized_front_args { + let text = arg.to_string_lossy(); + if is_option_like(&text) && !current_chunk.is_empty() { + push_chunk( + &mut frontend_chunks, + &mut python_chunks, + std::mem::take(&mut current_chunk), + &long_flags, + &short_flags, + ); + } + current_chunk.push(arg); + } + if !current_chunk.is_empty() { + push_chunk( + &mut frontend_chunks, + &mut python_chunks, + current_chunk, + &long_flags, + &short_flags, + ); + } + + let mut repartitioned = vec![args[0].clone()]; + repartitioned.extend(frontend_chunks); + if had_separator || !python_chunks.is_empty() || !explicit_passthrough.is_empty() { + repartitioned.push("--".into()); + repartitioned.extend(python_chunks); + repartitioned.extend(explicit_passthrough.iter().cloned()); + } + + Ok(repartitioned) +} + +fn matches_subcommand(args: &[OsString], subcommand: &str) -> bool { + args.get(1) + .and_then(|arg| arg.to_str()) + .is_some_and(|candidate| candidate == subcommand) +} + +fn split_managed_engine_args(args: &[OsString]) -> (&[OsString], &[OsString], bool) { + if let Some(index) = args.iter().position(|arg| arg == "--") { + (&args[..index], &args[index + 1..], true) + } else { + (args, &[], false) + } +} + +fn normalize_python_arg_aliases(args: &[OsString]) -> Vec { + args.iter() + .map(|arg| { + let text = arg.to_string_lossy(); + normalize_python_multi_char_alias(&text) + .map(Into::into) + .unwrap_or_else(|| arg.clone()) + }) + .collect() +} + +fn normalize_python_multi_char_alias(arg: &str) -> Option { + find_python_multi_char_alias(arg).map(|canonical| match arg.split_once('=') { + Some((_, value)) => format!("{canonical}={value}"), + None => canonical.to_string(), + }) +} + +fn find_python_multi_char_alias(arg: &str) -> Option<&'static str> { + PYTHON_MULTI_CHAR_ALIASES.iter().find_map(|&(alias, canonical)| { + (arg == alias || arg.starts_with(&format!("{alias}="))).then_some(canonical) + }) +} + +fn push_chunk( + frontend_chunks: &mut Vec, + python_chunks: &mut Vec, + chunk: Vec, + long_flags: &HashSet, + short_flags: &HashSet, +) { + if chunk_head_is_frontend_owned(&chunk, long_flags, short_flags) { + frontend_chunks.extend(chunk); + } else { + python_chunks.extend(chunk); + } +} + +fn chunk_head_is_frontend_owned( + chunk: &[OsString], + long_flags: &HashSet, + short_flags: &HashSet, +) -> bool { + let Some(head) = chunk.first() else { + return false; + }; + let head = head.to_string_lossy(); + + if let Some(rest) = head.strip_prefix("--") { + let name = rest.split_once('=').map_or(rest, |(name, _)| name); + return long_flags.contains(name); + } + + let Some(rest) = head.strip_prefix('-') else { + return false; + }; + let Some(short) = rest.chars().next() else { + return false; + }; + short_flags.contains(&short) +} + +fn collect_option_names(command: &clap::Command) -> (HashSet, HashSet) { + let mut long_flags = HashSet::new(); + let mut short_flags = HashSet::new(); + for arg in command.get_arguments() { + if let Some(names) = arg.get_long_and_visible_aliases() { + long_flags.extend(names.into_iter().map(str::to_owned)); + } + if let Some(short) = arg.get_short() { + short_flags.insert(short); + } + if let Some(short_aliases) = arg.get_visible_short_aliases() { + short_flags.extend(short_aliases); + } + } + + long_flags.insert("help".to_string()); + short_flags.insert('h'); + + (long_flags, short_flags) +} + +fn is_option_like(arg: &str) -> bool { + if arg == "--" { + return false; + } + + if let Some(rest) = arg.strip_prefix("--") { + return rest.chars().next().is_some_and(char::is_alphabetic); + } + + if let Some(rest) = arg.strip_prefix('-') { + return rest.chars().next().is_some_and(char::is_alphabetic); + } + + false +} + +fn is_help_flag(arg: &str) -> bool { + arg == "-h" || arg == "--help" +} + +fn build_missing_model_error(command: &clap::Command) -> clap::Error { + command.clone().error( + ErrorKind::MissingRequiredArgument, + "the model must appear immediately after the command", + ) +} diff --git a/rust/src/managed-engine/src/lib.rs b/rust/src/managed-engine/src/lib.rs new file mode 100644 index 00000000000..e9812104cb4 --- /dev/null +++ b/rust/src/managed-engine/src/lib.rs @@ -0,0 +1,4 @@ +pub mod cli; +mod process; + +pub use process::{ManagedEngineConfig, ManagedEngineHandle, allocate_handshake_port}; diff --git a/rust/src/managed-engine/src/process.rs b/rust/src/managed-engine/src/process.rs new file mode 100644 index 00000000000..0a506244dc3 --- /dev/null +++ b/rust/src/managed-engine/src/process.rs @@ -0,0 +1,263 @@ +use std::io; +use std::net::TcpListener; +use std::process::{Command as StdCommand, ExitStatus, Stdio}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; + +use anyhow::{Context, Result}; +use tokio::process::{Child, Command}; +use tokio::sync::Mutex; +use tokio::time::interval; +use tracing::info; + +const CHILD_POLL_INTERVAL: Duration = Duration::from_millis(200); +const MIN_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); + +/// Allocate one ephemeral TCP port for the managed headless-engine handshake on +/// the given host. +pub fn allocate_handshake_port(host: &str) -> Result { + let listener = TcpListener::bind((host, 0)).context("failed to allocate handshake port")?; + let port = listener + .local_addr() + .context("failed to inspect allocated handshake listener address")? + .port(); + Ok(port) +} + +/// Spawn configuration for one managed headless Python vLLM engine. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ManagedEngineConfig { + /// Python executable used to launch `vllm.entrypoints.cli.main`. + pub python: String, + /// Model identifier passed to `vllm ... serve `. + pub model: String, + /// Host portion of the headless-engine handshake endpoint. + pub handshake_host: String, + /// Port portion of the headless-engine handshake endpoint. + pub handshake_port: u16, + /// Number of data parallel replicas across the whole deployment. + /// + /// The per-node replica count is forwarded separately in `python_args` as + /// `--data-parallel-size-local`. + pub data_parallel_size: usize, + /// Extra CLI arguments forwarded verbatim to Python vLLM. + pub python_args: Vec, +} + +impl ManagedEngineConfig { + /// Render the handshake address that the Rust frontend should dial. + pub fn handshake_address(&self) -> String { + format!("tcp://{}:{}", self.handshake_host, self.handshake_port) + } + + /// Build the concrete Python command line for the managed headless engine. + pub fn to_command(&self) -> StdCommand { + let mut command = StdCommand::new(&self.python); + command + .arg("-m") + .arg("vllm.entrypoints.cli.main") + .arg("serve") + .arg(&self.model) + .arg("--headless") + .arg("--data-parallel-address") + .arg(&self.handshake_host) + .arg("--data-parallel-rpc-port") + .arg(self.handshake_port.to_string()) + .arg("--data-parallel-size") + .arg(self.data_parallel_size.to_string()) + .args(&self.python_args); + command + } +} + +/// RAII-style handle for one managed Python headless engine subprocess. +#[derive(Clone)] +pub struct ManagedEngineHandle { + child: Arc>, + shutdown_started: Arc, +} + +impl ManagedEngineHandle { + /// Spawn one managed Python headless engine and return a handle for + /// monitoring it. + pub async fn spawn(config: ManagedEngineConfig) -> Result { + let command = config.to_command(); + info!( + handshake_address = %config.handshake_address(), + ?command, + "starting managed Python headless engine" + ); + + let mut command = Command::from(command); + command.stdin(Stdio::null()).stdout(Stdio::inherit()).stderr(Stdio::inherit()); + + process_group::configure(&mut command); + + let child = command.spawn().context("failed to spawn managed engine")?; + + Ok(Self { + child: Arc::new(Mutex::new(child)), + shutdown_started: Arc::new(AtomicBool::new(false)), + }) + } + + /// Poll whether the managed engine has exited yet. + pub async fn try_wait(&self) -> Option { + let mut child = self.child.lock().await; + child.try_wait().expect("failed to poll the status of managed engine") + } + + /// Wait until the managed engine exits. + pub async fn wait_for_exit(&self) -> ExitStatus { + let mut interval = interval(CHILD_POLL_INTERVAL); + loop { + interval.tick().await; + if let Some(status) = self.try_wait().await { + return status; + } + } + } + + /// Terminate the managed engine process group and wait for it to stop. + pub async fn shutdown(&self, timeout: Duration) -> Result<()> { + if self.shutdown_started.swap(true, Ordering::SeqCst) { + return Ok(()); + } + + let Some(pid) = self.child.lock().await.id() else { + return Ok(()); + }; + + // Enforce a minimum shutdown timeout to give the engine process enough time to + // clean up. + let shutdown_timeout = std::cmp::max(timeout, MIN_SHUTDOWN_TIMEOUT); + + // First, try to gracefully terminate. + info!( + pid, + ?shutdown_timeout, + "shutting down managed engine with SIGTERM" + ); + process_group::terminate(pid)?; + + // Wait for the process to exit on its own. + if tokio::time::timeout(shutdown_timeout, self.wait_for_exit()).await.is_ok() { + return Ok(()); + } + + // If it doesn't exit within the timeout, force kill it. + info!( + pid, + "managed engine did not exit within timeout, sending SIGKILL" + ); + process_group::kill(pid)?; + + let _ = self.wait_for_exit().await; + Ok(()) + } +} + +/// Process group helper functions for managing the Python subprocess and its +/// potential children in a platform-aware way. +mod process_group { + use super::*; + + /// Place the Python child into its own process group so `serve` can tear + /// down the whole subtree rather than just the immediate shell process. + pub fn configure(command: &mut Command) { + unsafe { + command.pre_exec(|| { + if libc::setpgid(0, 0) != 0 { + return Err(io::Error::last_os_error()); + } + Ok(()) + }); + } + } + + /// Send SIGTERM to the managed Python process group. + pub fn terminate(pid: u32) -> Result<()> { + signal(pid, libc::SIGTERM) + } + + /// Send SIGKILL to the managed Python process group. + pub fn kill(pid: u32) -> Result<()> { + signal(pid, libc::SIGKILL) + } + + /// Deliver one signal to the managed Python process group. + fn signal(pid: u32, signal: i32) -> Result<()> { + let rc = unsafe { libc::kill(-(pid as i32), signal) }; + if rc == 0 { + return Ok(()); + } + + let error = io::Error::last_os_error(); + if matches!(error.raw_os_error(), Some(code) if code == libc::ESRCH) { + return Ok(()); + } + Err(error).context("failed to signal managed engine process group") + } +} + +#[cfg(test)] +mod tests { + use expect_test::expect; + + use super::{ManagedEngineConfig, allocate_handshake_port}; + + #[test] + fn command_snapshot() { + let config = ManagedEngineConfig { + python: "python3".to_string(), + model: "Qwen/Qwen3-0.6B".to_string(), + handshake_host: "127.0.0.1".to_string(), + handshake_port: 62100, + data_parallel_size: 4, + python_args: vec![ + "--data-parallel-size-local".to_string(), + "2".to_string(), + "--data-parallel-start-rank".to_string(), + "2".to_string(), + "--dtype".to_string(), + "float16".to_string(), + "--max-model-len".to_string(), + "512".to_string(), + ], + }; + let command = config.to_command(); + let args = command.get_args().collect::>(); + + expect![[r#" + [ + "-m", + "vllm.entrypoints.cli.main", + "serve", + "Qwen/Qwen3-0.6B", + "--headless", + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "62100", + "--data-parallel-size", + "4", + "--data-parallel-size-local", + "2", + "--data-parallel-start-rank", + "2", + "--dtype", + "float16", + "--max-model-len", + "512", + ] + "#]] + .assert_debug_eq(&args); + } + + #[test] + fn allocate_handshake_port_returns_non_zero_port() { + let port = allocate_handshake_port("127.0.0.1").unwrap(); + assert_ne!(port, 0); + } +} diff --git a/rust/src/metrics/Cargo.toml b/rust/src/metrics/Cargo.toml new file mode 100644 index 00000000000..e6b579b97a4 --- /dev/null +++ b/rust/src/metrics/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "vllm-metrics" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +prometheus-client.workspace = true + +[lints] +workspace = true diff --git a/rust/src/metrics/src/api_server.rs b/rust/src/metrics/src/api_server.rs new file mode 100644 index 00000000000..5ef6600a385 --- /dev/null +++ b/rust/src/metrics/src/api_server.rs @@ -0,0 +1,78 @@ +use prometheus_client::encoding::EncodeLabelSet; +use prometheus_client::metrics::family::Family; +use prometheus_client::metrics::histogram::Histogram; +use prometheus_client::registry::Registry; + +use crate::U64Counter; + +const HTTP_REQUEST_DURATION_BUCKETS: [f64; 3] = [0.1, 0.5, 1.0]; +const HTTP_REQUEST_DURATION_HIGHR_BUCKETS: [f64; 21] = [ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, + 7.5, 10.0, 30.0, 60.0, +]; + +fn http_request_duration_histogram() -> Histogram { + Histogram::new(HTTP_REQUEST_DURATION_BUCKETS.iter().copied()) +} + +fn http_request_duration_highr_histogram() -> Histogram { + Histogram::new(HTTP_REQUEST_DURATION_HIGHR_BUCKETS.iter().copied()) +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] +pub struct HttpRequestLabels { + pub method: String, + pub status: &'static str, + pub handler: String, +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] +pub struct HttpHandlerLabels { + pub method: String, + pub handler: String, +} + +pub(crate) type HttpRequestCounterFamily = Family; +pub(crate) type HttpHandlerHistogramFamily = + Family Histogram>; + +/// API-server Prometheus families exported from the HTTP middleware layer. +pub struct ApiServerMetrics { + pub http_requests: HttpRequestCounterFamily, + pub http_request_duration_seconds: HttpHandlerHistogramFamily, + pub http_request_duration_highr_seconds: Histogram, +} + +impl ApiServerMetrics { + /// Register the API-server metric families into the shared registry. + pub(crate) fn register(registry: &mut Registry) -> Self { + let http_requests = HttpRequestCounterFamily::default(); + registry.register( + "http_requests", + "Total number of HTTP requests by method, status, and handler.", + http_requests.clone(), + ); + + let http_request_duration_seconds = HttpHandlerHistogramFamily::new_with_constructor( + http_request_duration_histogram as fn() -> Histogram, + ); + registry.register( + "http_request_duration_seconds", + "Duration of HTTP requests in seconds grouped by method and handler.", + http_request_duration_seconds.clone(), + ); + + let http_request_duration_highr_seconds = http_request_duration_highr_histogram(); + registry.register( + "http_request_duration_highr_seconds", + "High-resolution duration of HTTP requests in seconds.", + http_request_duration_highr_seconds.clone(), + ); + + Self { + http_requests, + http_request_duration_seconds, + http_request_duration_highr_seconds, + } + } +} diff --git a/rust/src/metrics/src/lib.rs b/rust/src/metrics/src/lib.rs new file mode 100644 index 00000000000..8f0db53d3ff --- /dev/null +++ b/rust/src/metrics/src/lib.rs @@ -0,0 +1,76 @@ +use std::fmt; +use std::sync::LazyLock; +use std::sync::atomic::AtomicU64; + +use prometheus_client::encoding::text::encode; +use prometheus_client::metrics::counter::Counter; +use prometheus_client::metrics::family::Family; +use prometheus_client::metrics::gauge::Gauge; +use prometheus_client::metrics::histogram::Histogram; +use prometheus_client::registry::Registry; + +mod api_server; +mod request; +mod scheduler; + +pub use api_server::*; +pub use request::*; +pub use scheduler::*; + +// Note: `prometheus-client` appends the `_total` suffix automatically when +// encoding counters, so all counter family registration names in this crate +// must use the base metric name without a trailing `_total`. +pub type U64Counter = Counter; +pub type U64Gauge = Gauge; +pub type F64Gauge = Gauge; +pub(crate) type HistogramFamily = Family Histogram>; + +/// Shared Prometheus registry for frontend metrics. +/// +/// Original Python definition: +/// +pub struct Metrics { + registry: Registry, + pub scheduler: SchedulerMetrics, + pub request: RequestMetrics, + pub api_server: ApiServerMetrics, +} + +impl Metrics { + /// Construct a new metrics registry. + pub fn new() -> Self { + let mut registry = Registry::default(); + let scheduler = SchedulerMetrics::register(&mut registry); + let request = RequestMetrics::register(&mut registry); + let api_server = ApiServerMetrics::register(&mut registry); + + Self { + registry, + scheduler, + request, + api_server, + } + } + + /// Render the current metrics registry into Prometheus/OpenMetrics text + /// format. + pub fn render(&self) -> Result { + let mut output = String::new(); + encode(&mut output, &self.registry)?; + Ok(output) + } + + /// Return the registry owned by this metrics object. + pub fn registry(&self) -> &Registry { + &self.registry + } +} + +impl Default for Metrics { + fn default() -> Self { + Self::new() + } +} + +/// Process-global metrics registry shared by the frontend crates. +pub static METRICS: LazyLock = LazyLock::new(Metrics::new); diff --git a/rust/src/metrics/src/request.rs b/rust/src/metrics/src/request.rs new file mode 100644 index 00000000000..421ff303491 --- /dev/null +++ b/rust/src/metrics/src/request.rs @@ -0,0 +1,298 @@ +use prometheus_client::encoding::EncodeLabelSet; +use prometheus_client::metrics::family::Family; +use prometheus_client::metrics::histogram::Histogram; +use prometheus_client::registry::Registry; + +use crate::{EngineLabels, HistogramFamily, U64Counter}; + +const TTFT_BUCKETS: [f64; 22] = [ + 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, + 20.0, 40.0, 80.0, 160.0, 640.0, 2560.0, +]; +const ITL_BUCKETS: [f64; 19] = [ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, + 40.0, 80.0, +]; +const REQUEST_LATENCY_BUCKETS: [f64; 21] = [ + 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0, 120.0, 240.0, + 480.0, 960.0, 1920.0, 7680.0, +]; +const REQUEST_PARAMS_N_BUCKETS: [f64; 5] = [1.0, 2.0, 5.0, 10.0, 20.0]; + +fn build_1_2_5_buckets(max_value: u32) -> Vec { + let mut buckets = Vec::new(); + let mut exponent = 0; + loop { + for mantissa in [1_u32, 2, 5] { + let value = mantissa * 10_u32.pow(exponent); + if value <= max_value { + buckets.push(value as f64); + } else { + if buckets.last().copied() != Some(max_value as f64) { + buckets.push(max_value as f64); + } + return buckets; + } + } + exponent += 1; + } +} + +fn time_to_first_token_histogram() -> Histogram { + Histogram::new(TTFT_BUCKETS.iter().copied()) +} + +fn inter_token_latency_histogram() -> Histogram { + Histogram::new(ITL_BUCKETS.iter().copied()) +} + +fn request_time_per_output_token_histogram() -> Histogram { + Histogram::new(ITL_BUCKETS.iter().copied()) +} + +fn request_latency_histogram() -> Histogram { + Histogram::new(REQUEST_LATENCY_BUCKETS.iter().copied()) +} + +fn request_token_count_histogram() -> Histogram { + // TODO: determine max value based on `max_model_len`. + Histogram::new(build_1_2_5_buckets(131_072)) +} + +fn request_params_n_histogram() -> Histogram { + Histogram::new(REQUEST_PARAMS_N_BUCKETS.iter().copied()) +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] +pub struct FinishedReasonLabels { + pub model_name: String, + pub engine: u32, + pub finished_reason: &'static str, +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] +pub struct PromptTokenSourceLabels { + pub model_name: String, + pub engine: u32, + pub source: &'static str, +} + +pub(crate) type FinishedReasonCounterFamily = Family; +pub(crate) type PromptTokenSourceCounterFamily = Family; + +/// Request-lifecycle Prometheus families exported from the `llm` layer. +pub struct RequestMetrics { + // Request-derived counters. + pub num_preemptions: Family, + pub prompt_tokens: Family, + pub prompt_tokens_by_source: PromptTokenSourceCounterFamily, + pub prompt_tokens_cached: Family, + pub generation_tokens: Family, + + // We intentionally don't support iteration-level histograms for now, since it seems to make + // more sense if the engine maintains these metrics and frontend simply forwards. + // + // pub iteration_tokens_total: HistogramFamily, + + // Request lifecycle counters and histograms. + pub request_success: FinishedReasonCounterFamily, + pub request_prompt_tokens: HistogramFamily, + pub request_generation_tokens: HistogramFamily, + pub request_max_num_generation_tokens: HistogramFamily, + pub request_params_max_tokens: HistogramFamily, + pub request_params_n: HistogramFamily, + pub request_prefill_kv_computed_tokens: HistogramFamily, + pub time_to_first_token_seconds: HistogramFamily, + pub inter_token_latency_seconds: HistogramFamily, + pub e2e_request_latency_seconds: HistogramFamily, + pub request_queue_time_seconds: HistogramFamily, + pub request_prefill_time_seconds: HistogramFamily, + pub request_decode_time_seconds: HistogramFamily, + pub request_inference_time_seconds: HistogramFamily, + pub request_time_per_output_token_seconds: HistogramFamily, +} + +impl RequestMetrics { + /// Register the request-oriented metric families into the shared registry. + pub(crate) fn register(registry: &mut Registry) -> Self { + // Request-derived counters. + let num_preemptions = Family::default(); + registry.register( + "vllm:num_preemptions", + "Cumulative number of preemption events.", + num_preemptions.clone(), + ); + + let prompt_tokens = Family::default(); + registry.register( + "vllm:prompt_tokens", + "Number of prefill tokens processed.", + prompt_tokens.clone(), + ); + + let prompt_tokens_by_source = Family::default(); + registry.register( + "vllm:prompt_tokens_by_source", + "Number of prompt tokens by source.", + prompt_tokens_by_source.clone(), + ); + + let prompt_tokens_cached = Family::default(); + registry.register( + "vllm:prompt_tokens_cached", + "Number of prompt tokens with prefix cache hits.", + prompt_tokens_cached.clone(), + ); + + let generation_tokens = Family::default(); + registry.register( + "vllm:generation_tokens", + "Number of generation tokens processed.", + generation_tokens.clone(), + ); + + // Request lifecycle counters and histograms. + let request_success = Family::default(); + registry.register( + "vllm:request_success", + "Count of successfully processed requests.", + request_success.clone(), + ); + + let request_prompt_tokens = + Family::new_with_constructor(request_token_count_histogram as fn() -> Histogram); + registry.register( + "vllm:request_prompt_tokens", + "Number of prefill tokens processed.", + request_prompt_tokens.clone(), + ); + + let request_generation_tokens = + Family::new_with_constructor(request_token_count_histogram as fn() -> Histogram); + registry.register( + "vllm:request_generation_tokens", + "Number of generation tokens processed.", + request_generation_tokens.clone(), + ); + + let request_max_num_generation_tokens = + Family::new_with_constructor(request_token_count_histogram as fn() -> Histogram); + registry.register( + "vllm:request_max_num_generation_tokens", + "Histogram of maximum number of requested generation tokens.", + request_max_num_generation_tokens.clone(), + ); + + let request_params_max_tokens = + Family::new_with_constructor(request_token_count_histogram as fn() -> Histogram); + registry.register( + "vllm:request_params_max_tokens", + "Histogram of the max_tokens request parameter.", + request_params_max_tokens.clone(), + ); + + let request_params_n = + Family::new_with_constructor(request_params_n_histogram as fn() -> Histogram); + registry.register( + "vllm:request_params_n", + "Histogram of the n request parameter.", + request_params_n.clone(), + ); + + let request_prefill_kv_computed_tokens = + Family::new_with_constructor(request_token_count_histogram as fn() -> Histogram); + registry.register( + "vllm:request_prefill_kv_computed_tokens", + "Histogram of new KV tokens computed during prefill (excluding cached tokens).", + request_prefill_kv_computed_tokens.clone(), + ); + + let time_to_first_token_seconds = + Family::new_with_constructor(time_to_first_token_histogram as fn() -> Histogram); + registry.register( + "vllm:time_to_first_token_seconds", + "Histogram of time to first token in seconds.", + time_to_first_token_seconds.clone(), + ); + + let inter_token_latency_seconds = + Family::new_with_constructor(inter_token_latency_histogram as fn() -> Histogram); + registry.register( + "vllm:inter_token_latency_seconds", + "Histogram of inter-token latency in seconds.", + inter_token_latency_seconds.clone(), + ); + + let e2e_request_latency_seconds = + Family::new_with_constructor(request_latency_histogram as fn() -> Histogram); + registry.register( + "vllm:e2e_request_latency_seconds", + "Histogram of e2e request latency in seconds.", + e2e_request_latency_seconds.clone(), + ); + + let request_queue_time_seconds = + Family::new_with_constructor(request_latency_histogram as fn() -> Histogram); + registry.register( + "vllm:request_queue_time_seconds", + "Histogram of time spent in WAITING phase for request.", + request_queue_time_seconds.clone(), + ); + + let request_prefill_time_seconds = + Family::new_with_constructor(request_latency_histogram as fn() -> Histogram); + registry.register( + "vllm:request_prefill_time_seconds", + "Histogram of time spent in PREFILL phase for request.", + request_prefill_time_seconds.clone(), + ); + + let request_decode_time_seconds = + Family::new_with_constructor(request_latency_histogram as fn() -> Histogram); + registry.register( + "vllm:request_decode_time_seconds", + "Histogram of time spent in DECODE phase for request.", + request_decode_time_seconds.clone(), + ); + + let request_inference_time_seconds = + Family::new_with_constructor(request_latency_histogram as fn() -> Histogram); + registry.register( + "vllm:request_inference_time_seconds", + "Histogram of time spent in RUNNING phase for request.", + request_inference_time_seconds.clone(), + ); + + let request_time_per_output_token_seconds = Family::new_with_constructor( + request_time_per_output_token_histogram as fn() -> Histogram, + ); + registry.register( + "vllm:request_time_per_output_token_seconds", + "Histogram of time_per_output_token_seconds per request.", + request_time_per_output_token_seconds.clone(), + ); + + Self { + num_preemptions, + prompt_tokens, + prompt_tokens_by_source, + prompt_tokens_cached, + generation_tokens, + request_success, + request_prompt_tokens, + request_generation_tokens, + request_max_num_generation_tokens, + request_params_max_tokens, + request_params_n, + request_prefill_kv_computed_tokens, + time_to_first_token_seconds, + inter_token_latency_seconds, + e2e_request_latency_seconds, + request_queue_time_seconds, + request_prefill_time_seconds, + request_decode_time_seconds, + request_inference_time_seconds, + request_time_per_output_token_seconds, + } + } +} diff --git a/rust/src/metrics/src/scheduler.rs b/rust/src/metrics/src/scheduler.rs new file mode 100644 index 00000000000..0acbdf0fa75 --- /dev/null +++ b/rust/src/metrics/src/scheduler.rs @@ -0,0 +1,272 @@ +use prometheus_client::encoding::EncodeLabelSet; +use prometheus_client::metrics::family::Family; +use prometheus_client::metrics::histogram::Histogram; +use prometheus_client::registry::Registry; + +use crate::{F64Gauge, HistogramFamily, U64Counter, U64Gauge}; + +const KV_CACHE_RESIDENCY_BUCKETS: [f64; 21] = [ + 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0, 20.0, 30.0, 60.0, + 120.0, 300.0, 600.0, 1200.0, 1800.0, +]; + +fn kv_block_lifetime_histogram() -> Histogram { + Histogram::new(KV_CACHE_RESIDENCY_BUCKETS.iter().copied()) +} + +fn kv_block_idle_before_evict_histogram() -> Histogram { + Histogram::new(KV_CACHE_RESIDENCY_BUCKETS.iter().copied()) +} + +fn kv_block_reuse_gap_histogram() -> Histogram { + Histogram::new(KV_CACHE_RESIDENCY_BUCKETS.iter().copied()) +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] +pub struct EngineLabels { + pub model_name: String, + pub engine: u32, +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] +pub struct EnginePositionLabels { + pub model_name: String, + pub engine: u32, + pub position: u32, +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] +pub struct WaitingReasonLabels { + pub model_name: String, + pub engine: u32, + pub reason: &'static str, +} + +/// Scheduler/batch-scoped Prometheus families exported from `SchedulerStats`. +pub struct SchedulerMetrics { + // Scheduler state gauges. + pub scheduler_running: Family, + pub scheduler_waiting: Family, + pub scheduler_waiting_by_reason: Family, + pub kv_cache_usage: Family, + + // Prefix-cache counters, including the connector-backed external cache path. + pub prefix_cache_queries: Family, + pub prefix_cache_hits: Family, + pub external_prefix_cache_queries: Family, + pub external_prefix_cache_hits: Family, + + // Speculative decoding counters. + pub spec_decode_num_drafts: Family, + pub spec_decode_num_draft_tokens: Family, + pub spec_decode_num_accepted_tokens: Family, + pub spec_decode_num_accepted_tokens_per_pos: Family, + + // Per-engine performance / MFU counters. + pub estimated_flops_per_gpu: Family, + pub estimated_read_bytes_per_gpu: Family, + pub estimated_write_bytes_per_gpu: Family, + + // Sampled KV-cache residency histograms. + pub kv_block_lifetime_seconds: HistogramFamily, + pub kv_block_idle_before_evict_seconds: HistogramFamily, + pub kv_block_reuse_gap_seconds: HistogramFamily, +} + +impl SchedulerMetrics { + /// Register the scheduler-oriented metric families into the shared + /// registry. + pub(crate) fn register(registry: &mut Registry) -> Self { + // Scheduler state gauges. + let scheduler_running = Family::default(); + registry.register( + "vllm:num_requests_running", + "Number of requests in model execution batches", + scheduler_running.clone(), + ); + + let scheduler_waiting = Family::default(); + registry.register( + "vllm:num_requests_waiting", + "Number of requests waiting to be processed", + scheduler_waiting.clone(), + ); + + let scheduler_waiting_by_reason = Family::default(); + registry.register( + "vllm:num_requests_waiting_by_reason", + "Number of waiting requests by reason. \ + Reason labels: 'capacity' = waiting for scheduling capacity; \ + 'deferred' = deferred by transient constraints (LoRA budget, KV transfer, \ + blocked status). Sum of all reasons equals vllm:num_requests_waiting.", + scheduler_waiting_by_reason.clone(), + ); + + let kv_cache_usage = Family::default(); + registry.register( + "vllm:kv_cache_usage_perc", + "KV-cache usage. 1 means 100 percent usage", + kv_cache_usage.clone(), + ); + + // Prefix-cache counters, including the connector-backed external cache path. + let prefix_cache_queries = Family::default(); + registry.register( + "vllm:prefix_cache_queries", + "Prefix cache queries, in terms of number of queried tokens", + prefix_cache_queries.clone(), + ); + + let prefix_cache_hits = Family::default(); + registry.register( + "vllm:prefix_cache_hits", + "Prefix cache hits, in terms of number of cached tokens.", + prefix_cache_hits.clone(), + ); + + let external_prefix_cache_queries = Family::default(); + registry.register( + "vllm:external_prefix_cache_queries", + "External prefix cache queries from KV connector cross-instance cache sharing, in terms of number of queried tokens.", + external_prefix_cache_queries.clone(), + ); + + let external_prefix_cache_hits = Family::default(); + registry.register( + "vllm:external_prefix_cache_hits", + "External prefix cache hits from KV connector cross-instance cache sharing, in terms of number of cached tokens.", + external_prefix_cache_hits.clone(), + ); + + // Speculative decoding counters. + let spec_decode_num_drafts = Family::default(); + registry.register( + "vllm:spec_decode_num_drafts", + "Number of spec decoding drafts.", + spec_decode_num_drafts.clone(), + ); + + let spec_decode_num_draft_tokens = Family::default(); + registry.register( + "vllm:spec_decode_num_draft_tokens", + "Number of draft tokens.", + spec_decode_num_draft_tokens.clone(), + ); + + let spec_decode_num_accepted_tokens = Family::default(); + registry.register( + "vllm:spec_decode_num_accepted_tokens", + "Number of accepted tokens.", + spec_decode_num_accepted_tokens.clone(), + ); + + let spec_decode_num_accepted_tokens_per_pos = Family::default(); + registry.register( + "vllm:spec_decode_num_accepted_tokens_per_pos", + "Accepted tokens per draft position.", + spec_decode_num_accepted_tokens_per_pos.clone(), + ); + + // Per-engine performance / MFU counters. + let estimated_flops_per_gpu = Family::default(); + registry.register( + "vllm:estimated_flops_per_gpu", + "Estimated number of floating point operations per GPU (for Model Flops Utilization calculations).", + estimated_flops_per_gpu.clone(), + ); + + let estimated_read_bytes_per_gpu = Family::default(); + registry.register( + "vllm:estimated_read_bytes_per_gpu", + "Estimated number of bytes read from memory per GPU (for Model Flops Utilization calculations).", + estimated_read_bytes_per_gpu.clone(), + ); + + let estimated_write_bytes_per_gpu = Family::default(); + registry.register( + "vllm:estimated_write_bytes_per_gpu", + "Estimated number of bytes written to memory per GPU (for Model Flops Utilization calculations).", + estimated_write_bytes_per_gpu.clone(), + ); + + // Sampled KV-cache residency histograms. + let kv_block_lifetime_seconds = + Family::new_with_constructor(kv_block_lifetime_histogram as fn() -> Histogram); + registry.register( + "vllm:kv_block_lifetime_seconds", + "Histogram of KV cache block lifetime from allocation to eviction. Sampled metrics (controlled by --kv-cache-metrics-sample).", + kv_block_lifetime_seconds.clone(), + ); + + let kv_block_idle_before_evict_seconds = + Family::new_with_constructor(kv_block_idle_before_evict_histogram as fn() -> Histogram); + registry.register( + "vllm:kv_block_idle_before_evict_seconds", + "Histogram of idle time before KV cache block eviction. Sampled metrics (controlled by --kv-cache-metrics-sample).", + kv_block_idle_before_evict_seconds.clone(), + ); + + let kv_block_reuse_gap_seconds = + Family::new_with_constructor(kv_block_reuse_gap_histogram as fn() -> Histogram); + registry.register( + "vllm:kv_block_reuse_gap_seconds", + "Histogram of time gaps between consecutive KV cache block accesses. Only the most recent accesses are recorded (ring buffer). Sampled metrics (controlled by --kv-cache-metrics-sample).", + kv_block_reuse_gap_seconds.clone(), + ); + + Self { + scheduler_running, + scheduler_waiting, + scheduler_waiting_by_reason, + kv_cache_usage, + prefix_cache_queries, + prefix_cache_hits, + external_prefix_cache_queries, + external_prefix_cache_hits, + spec_decode_num_drafts, + spec_decode_num_draft_tokens, + spec_decode_num_accepted_tokens, + spec_decode_num_accepted_tokens_per_pos, + estimated_flops_per_gpu, + estimated_read_bytes_per_gpu, + estimated_write_bytes_per_gpu, + kv_block_lifetime_seconds, + kv_block_idle_before_evict_seconds, + kv_block_reuse_gap_seconds, + } + } +} + +#[cfg(test)] +mod tests { + use crate::{EngineLabels, Metrics}; + + #[test] + fn perf_counters_render_with_a_single_total_suffix() { + let metrics = Metrics::new(); + let labels = EngineLabels { + model_name: "model".to_string(), + engine: 0, + }; + + metrics.scheduler.estimated_flops_per_gpu.get_or_create(&labels).inc(); + metrics.scheduler.estimated_read_bytes_per_gpu.get_or_create(&labels).inc(); + metrics.scheduler.estimated_write_bytes_per_gpu.get_or_create(&labels).inc(); + + let rendered = metrics.render().unwrap(); + assert!( + rendered.contains( + "vllm:estimated_flops_per_gpu_total{model_name=\"model\",engine=\"0\"} 1" + ) + ); + assert!(rendered.contains( + "vllm:estimated_read_bytes_per_gpu_total{model_name=\"model\",engine=\"0\"} 1" + )); + assert!(rendered.contains( + "vllm:estimated_write_bytes_per_gpu_total{model_name=\"model\",engine=\"0\"} 1" + )); + assert!(!rendered.contains("vllm:estimated_flops_per_gpu_total_total")); + assert!(!rendered.contains("vllm:estimated_read_bytes_per_gpu_total_total")); + assert!(!rendered.contains("vllm:estimated_write_bytes_per_gpu_total_total")); + } +} diff --git a/rust/src/reasoning-parser/Cargo.toml b/rust/src/reasoning-parser/Cargo.toml new file mode 100644 index 00000000000..d6500a7b0c1 --- /dev/null +++ b/rust/src/reasoning-parser/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "vllm-reasoning-parser" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +thiserror.workspace = true +vllm-tokenizer.workspace = true + +[lints] +workspace = true diff --git a/rust/src/reasoning-parser/src/cohere_cmd.rs b/rust/src/reasoning-parser/src/cohere_cmd.rs new file mode 100644 index 00000000000..9acaf3c5f1e --- /dev/null +++ b/rust/src/reasoning-parser/src/cohere_cmd.rs @@ -0,0 +1,45 @@ +use vllm_tokenizer::DynTokenizer; + +use super::{DelimitedReasoningParser, ReasoningDelta, ReasoningParser, Result}; + +/// Reasoning parser for Cohere Command models that use explicit START/END tags. +pub struct CohereCmdReasoningParser { + inner: DelimitedReasoningParser, +} + +impl CohereCmdReasoningParser { + /// Create a Cohere Command parser backed by the shared delimited state + /// machine. + pub fn new(tokenizer: DynTokenizer) -> Result { + Ok(Self { + inner: DelimitedReasoningParser::new( + tokenizer, + "<|START_THINKING|>", + "<|END_THINKING|>", + false, + )?, + }) + } +} + +impl ReasoningParser for CohereCmdReasoningParser { + fn create(tokenizer: DynTokenizer) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tokenizer)?)) + } + + fn initialize(&mut self, prompt_token_ids: &[u32]) -> Result<()> { + self.inner.initialize(prompt_token_ids); + Ok(()) + } + + fn push(&mut self, delta: &str) -> Result { + Ok(self.inner.push(delta)) + } + + fn finish(&mut self) -> Result { + Ok(self.inner.finish()) + } +} diff --git a/rust/src/reasoning-parser/src/deepseek_r1.rs b/rust/src/reasoning-parser/src/deepseek_r1.rs new file mode 100644 index 00000000000..069de478b7a --- /dev/null +++ b/rust/src/reasoning-parser/src/deepseek_r1.rs @@ -0,0 +1,44 @@ +use vllm_tokenizer::DynTokenizer; + +use super::{DelimitedReasoningParser, ReasoningDelta, ReasoningParser, Result}; + +/// Reasoning parser for DeepSeek R1 style outputs. +/// +/// DeepSeek R1 may begin generating directly inside a reasoning span and only +/// emit the closing `` delimiter, so the no-boundary fallback defaults +/// to `in_reasoning = true`. +pub struct DeepSeekR1ReasoningParser { + inner: DelimitedReasoningParser, +} + +impl DeepSeekR1ReasoningParser { + /// Create a DeepSeek R1 parser backed by the shared delimited state + /// machine. + pub fn new(tokenizer: DynTokenizer) -> Result { + Ok(Self { + inner: DelimitedReasoningParser::new(tokenizer, "", "", true)?, + }) + } +} + +impl ReasoningParser for DeepSeekR1ReasoningParser { + fn create(tokenizer: DynTokenizer) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tokenizer)?)) + } + + fn initialize(&mut self, prompt_token_ids: &[u32]) -> Result<()> { + self.inner.initialize(prompt_token_ids); + Ok(()) + } + + fn push(&mut self, delta: &str) -> Result { + Ok(self.inner.push(delta)) + } + + fn finish(&mut self) -> Result { + Ok(self.inner.finish()) + } +} diff --git a/rust/src/reasoning-parser/src/delimited.rs b/rust/src/reasoning-parser/src/delimited.rs new file mode 100644 index 00000000000..485202e3e2e --- /dev/null +++ b/rust/src/reasoning-parser/src/delimited.rs @@ -0,0 +1,161 @@ +use vllm_tokenizer::{DynTokenizer, Tokenizer}; + +use super::{ReasoningDelta, ReasoningError, Result}; + +/// Shared incremental state machine for tag-delimited reasoning protocols. +/// +/// This helper is intentionally not a public parser type. Model-family parser +/// wrappers own one `DelimitedReasoningParser` internally and expose the +/// request-facing [`super::ReasoningParser`] trait. +/// +/// The shared state machine stays generic by deriving its initial +/// `current_in_reasoning` state from the prompt token boundary instead of +/// hardcoding model-family conventions. That means families with the same +/// delimiters can often reuse this implementation even if their chat templates +/// prefill different prompts. +pub(crate) struct DelimitedReasoningParser { + tokenizer: DynTokenizer, + current_in_reasoning: bool, + buffer: String, + start_token: String, + end_token: String, + start_token_id: u32, + end_token_id: u32, + default_in_reasoning: bool, +} + +impl DelimitedReasoningParser { + /// Create one delimited parser state machine. + /// + /// `default_in_reasoning` is only used when prompt initialization sees no + /// reasoning boundary token at all. If the prompt contains either the + /// start or end delimiter, that prompt boundary always wins. + pub(crate) fn new( + tokenizer: DynTokenizer, + start_token: &'static str, + end_token: &'static str, + default_in_reasoning: bool, + ) -> Result { + let start_token_id = + tokenizer.token_to_id(start_token).ok_or_else(|| ReasoningError::MissingToken { + token: start_token.to_string(), + })?; + let end_token_id = + tokenizer.token_to_id(end_token).ok_or_else(|| ReasoningError::MissingToken { + token: end_token.to_string(), + })?; + + Ok(Self { + tokenizer, + current_in_reasoning: default_in_reasoning, + buffer: String::new(), + start_token: start_token.to_string(), + end_token: end_token.to_string(), + start_token_id, + end_token_id, + default_in_reasoning, + }) + } + + /// Initialize the starting state from prompt token IDs. + pub(crate) fn initialize(&mut self, prompt_token_ids: &[u32]) { + self.current_in_reasoning = last_reasoning_boundary( + prompt_token_ids, + self.start_token_id, + self.end_token_id, + self.tokenizer.as_ref(), + ) + .unwrap_or(self.default_in_reasoning); + } + + /// Parse one decoded text delta and return its reasoning/content split. + pub(crate) fn push(&mut self, delta: &str) -> ReasoningDelta { + self.buffer.push_str(delta); + + let partial_suffix_len = self.partial_suffix_len(&self.buffer); + let stable_len = self.buffer.len() - partial_suffix_len; + let pending_suffix = self.buffer.split_off(stable_len); + let stable_text = std::mem::replace(&mut self.buffer, pending_suffix); + + self.parse_stable_text(&stable_text) + } + + /// Flush any buffered partial delimiter suffix at end of stream. + pub(crate) fn finish(&mut self) -> ReasoningDelta { + let stable_text = std::mem::take(&mut self.buffer); + self.parse_stable_text(&stable_text) + } + + /// Parse text that is known not to end with a partial delimiter suffix. + fn parse_stable_text(&mut self, mut stable: &str) -> ReasoningDelta { + let mut delta = ReasoningDelta::default(); + + while !stable.is_empty() { + if self.current_in_reasoning { + if let Some(end_idx) = stable.find(&self.end_token) { + delta.push_reasoning(&stable[..end_idx]); + stable = &stable[end_idx + self.end_token.len()..]; + self.current_in_reasoning = false; + } else { + delta.push_reasoning(stable); + break; + } + } else if let Some(start_idx) = stable.find(&self.start_token) { + delta.push_content(&stable[..start_idx]); + stable = &stable[start_idx + self.start_token.len()..]; + self.current_in_reasoning = true; + } else { + delta.push_content(stable); + break; + } + } + + delta + } + + /// Return the longest trailing suffix that could still complete a + /// delimiter. + fn partial_suffix_len(&self, text: &str) -> usize { + let mut best = 0; + for idx in text.char_indices().map(|(idx, _)| idx).skip(1) { + let suffix = &text[idx..]; + if self.start_token.starts_with(suffix) && self.start_token != suffix { + best = best.max(text.len() - idx); + } + if self.end_token.starts_with(suffix) && self.end_token != suffix { + best = best.max(text.len() - idx); + } + } + + if self.start_token.starts_with(text) && self.start_token != text { + best = best.max(text.len()); + } + if self.end_token.starts_with(text) && self.end_token != text { + best = best.max(text.len()); + } + + best + } +} + +/// Determine the reasoning state implied by the last prompt boundary, if any. +fn last_reasoning_boundary( + prompt_token_ids: &[u32], + start_token_id: u32, + end_token_id: u32, + tokenizer: &dyn Tokenizer, +) -> Option { + for token_id in prompt_token_ids.iter().rev() { + if *token_id == start_token_id { + return Some(true); + } + if *token_id == end_token_id { + return Some(false); + } + if tokenizer.is_special_id(*token_id) { + return None; + } + } + + None +} diff --git a/rust/src/reasoning-parser/src/gemma4.rs b/rust/src/reasoning-parser/src/gemma4.rs new file mode 100644 index 00000000000..86824f2ad40 --- /dev/null +++ b/rust/src/reasoning-parser/src/gemma4.rs @@ -0,0 +1,273 @@ +use vllm_tokenizer::DynTokenizer; + +use super::{DelimitedReasoningParser, ReasoningDelta, ReasoningParser, Result}; + +const THOUGHT_PREFIX: &str = "thought\n"; + +/// Reasoning parser for Google Gemma4 thinking models. +/// +/// Gemma4 emits reasoning inside `<|channel> ... ` spans and adds a +/// structural `thought\n` label at the beginning of the reasoning channel. +/// This parser keeps the delimiter handling in the shared delimited parser and +/// only layers on Gemma4-specific request adjustment plus prefix stripping. +/// +/// Original Python implementation: +/// +pub struct Gemma4ReasoningParser { + inner: DelimitedReasoningParser, + reasoning_text: String, + prefix_stripped: bool, +} + +impl Gemma4ReasoningParser { + /// Create a Gemma4 parser. + pub fn new(tokenizer: DynTokenizer) -> Result { + Ok(Self { + inner: DelimitedReasoningParser::new(tokenizer, "<|channel>", "", false)?, + reasoning_text: String::new(), + prefix_stripped: false, + }) + } + + /// Apply Gemma4's `thought\n` stripping rule to one reasoning delta. + /// + /// Early reasoning text is buffered until we can decide whether it begins + /// with the structural channel label. + fn strip_thought_prefix(&mut self, reasoning: &str) -> Option { + if self.prefix_stripped { + return Some(reasoning.to_string()); + } + + self.reasoning_text.push_str(reasoning); + + if self.reasoning_text.starts_with(THOUGHT_PREFIX) { + let prefix_len = THOUGHT_PREFIX.len(); + let previous_len = self.reasoning_text.len() - reasoning.len(); + if previous_len >= prefix_len { + self.reasoning_text.clear(); + self.prefix_stripped = true; + return Some(reasoning.to_string()); + } + + let prefix_chars_in_delta = prefix_len - previous_len; + let stripped = &reasoning[prefix_chars_in_delta.min(reasoning.len())..]; + if stripped.is_empty() { + if self.reasoning_text.len() >= prefix_len { + self.reasoning_text.clear(); + self.prefix_stripped = true; + } + return None; + } + + self.reasoning_text.clear(); + self.prefix_stripped = true; + return Some(stripped.to_string()); + } + + if THOUGHT_PREFIX.starts_with(&self.reasoning_text) { + return None; + } + + self.prefix_stripped = true; + Some(std::mem::take(&mut self.reasoning_text)) + } + + /// Apply Gemma4-specific reasoning post-processing to one parsed delta. + fn post_process(&mut self, mut result: ReasoningDelta) -> ReasoningDelta { + if let Some(reasoning) = result.reasoning.take() { + result.reasoning = + self.strip_thought_prefix(&reasoning).filter(|text| !text.is_empty()); + } + result + } +} + +impl ReasoningParser for Gemma4ReasoningParser { + fn create(tokenizer: DynTokenizer) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tokenizer)?)) + } + + fn preserve_special_tokens(&self) -> bool { + true + } + + fn initialize(&mut self, prompt_token_ids: &[u32]) -> Result<()> { + self.inner.initialize(prompt_token_ids); + self.reasoning_text.clear(); + self.prefix_stripped = false; + Ok(()) + } + + fn push(&mut self, delta: &str) -> Result { + let result = self.inner.push(delta); + Ok(self.post_process(result)) + } + + fn finish(&mut self) -> Result { + let result = self.inner.finish(); + Ok(self.post_process(result)) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use vllm_tokenizer::Tokenizer; + + use super::Gemma4ReasoningParser; + use crate::ReasoningParser; + + struct FakeTokenizer; + + impl Tokenizer for FakeTokenizer { + fn encode( + &self, + text: &str, + _add_special_tokens: bool, + ) -> vllm_tokenizer::Result> { + Ok(text.chars().map(u32::from).collect()) + } + + fn decode( + &self, + token_ids: &[u32], + _skip_special_tokens: bool, + ) -> vllm_tokenizer::Result { + Ok(token_ids + .iter() + .map(|token_id| char::from_u32(*token_id).unwrap_or('\u{FFFD}')) + .collect()) + } + + fn token_to_id(&self, token: &str) -> Option { + match token { + "<|channel>" => Some(1000), + "" => Some(1001), + _ => None, + } + } + } + + fn run_streaming(output: &[&str]) -> (Option, Option) { + let tokenizer = Arc::new(FakeTokenizer); + let mut parser = Gemma4ReasoningParser::new(tokenizer).unwrap(); + let mut reasoning = String::new(); + let mut content = String::new(); + + for delta in output { + let result = parser.push(delta).unwrap(); + if let Some(next) = result.reasoning { + reasoning.push_str(&next); + } + if let Some(next) = result.content { + content.push_str(&next); + } + } + + let final_delta = parser.finish().unwrap(); + if let Some(next) = final_delta.reasoning { + reasoning.push_str(&next); + } + if let Some(next) = final_delta.content { + content.push_str(&next); + } + + ( + (!reasoning.is_empty()).then_some(reasoning), + (!content.is_empty()).then_some(content), + ) + } + + #[test] + fn gemma4_reasoning_streaming_handles_channel_delimited_outputs() { + let cases = [ + ( + "no_reasoning", + vec!["This is content"], + None, + Some("This is content"), + ), + ( + "reasoning_and_content", + vec!["<|channel>This is a reasoning sectionThis is the rest"], + Some("This is a reasoning section"), + Some("This is the rest"), + ), + ( + "complete_reasoning", + vec!["<|channel>This is a reasoning section"], + Some("This is a reasoning section"), + None, + ), + ( + "multiple_lines", + vec!["<|channel>This\nThatThis is the rest\nThat"], + Some("This\nThat"), + Some("This is the rest\nThat"), + ), + ( + "no_end", + vec!["<|channel>This is a reasoning section"], + Some("This is a reasoning section"), + None, + ), + ("empty", vec![""], None, None), + ( + "newline_around_reasoning", + vec!["Before\n<|channel>This is a reasoning section\nThis is the rest"], + Some("This is a reasoning section"), + Some("Before\n\nThis is the rest"), + ), + ( + "thought_prefix", + vec!["<|channel>thought\nActual reasoning hereFinal answer"], + Some("Actual reasoning here"), + Some("Final answer"), + ), + ( + "thought_prefix_only", + vec!["<|channel>thought\n"], + None, + None, + ), + ( + "thought_prefix_multiline", + vec!["<|channel>thought\nLine1\nLine2Answer"], + Some("Line1\nLine2"), + Some("Answer"), + ), + ( + "thought_prefix_diverge", + vec!["<|channel>thousand reasonsDone"], + Some("thousand reasons"), + Some("Done"), + ), + ]; + + for (name, output, expected_reasoning, expected_content) in cases { + let (reasoning, content) = run_streaming(&output); + assert_eq!(reasoning.as_deref(), expected_reasoning, "{name}"); + assert_eq!(content.as_deref(), expected_content, "{name}"); + } + } + + #[test] + fn gemma4_strips_thought_prefix_even_when_split_across_deltas() { + let (reasoning, content) = + run_streaming(&["<|channel>thou", "ght", "\nabc", "done"]); + assert_eq!(reasoning.as_deref(), Some("abc")); + assert_eq!(content.as_deref(), Some("done")); + } + + #[test] + fn gemma4_preserves_special_tokens() { + let tokenizer = Arc::new(FakeTokenizer); + let parser = Gemma4ReasoningParser::new(tokenizer).unwrap(); + + assert!(parser.preserve_special_tokens()); + } +} diff --git a/rust/src/reasoning-parser/src/kimi.rs b/rust/src/reasoning-parser/src/kimi.rs new file mode 100644 index 00000000000..6c042d0d8f1 --- /dev/null +++ b/rust/src/reasoning-parser/src/kimi.rs @@ -0,0 +1,39 @@ +use vllm_tokenizer::DynTokenizer; + +use super::{DelimitedReasoningParser, ReasoningDelta, ReasoningParser, Result}; + +/// Reasoning parser for legacy Kimi models that use Unicode thinking tags. +pub struct KimiReasoningParser { + inner: DelimitedReasoningParser, +} + +impl KimiReasoningParser { + /// Create a Kimi parser backed by the shared delimited state machine. + pub fn new(tokenizer: DynTokenizer) -> Result { + Ok(Self { + inner: DelimitedReasoningParser::new(tokenizer, "◁think▷", "◁/think▷", false)?, + }) + } +} + +impl ReasoningParser for KimiReasoningParser { + fn create(tokenizer: DynTokenizer) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tokenizer)?)) + } + + fn initialize(&mut self, prompt_token_ids: &[u32]) -> Result<()> { + self.inner.initialize(prompt_token_ids); + Ok(()) + } + + fn push(&mut self, delta: &str) -> Result { + Ok(self.inner.push(delta)) + } + + fn finish(&mut self) -> Result { + Ok(self.inner.finish()) + } +} diff --git a/rust/src/reasoning-parser/src/lib.rs b/rust/src/reasoning-parser/src/lib.rs new file mode 100644 index 00000000000..084168ab2f1 --- /dev/null +++ b/rust/src/reasoning-parser/src/lib.rs @@ -0,0 +1,129 @@ +//! Streaming reasoning parsers for chat completions. +//! +//! The key design choice here is that parser initialization prefers the +//! *actual rendered prompt state* over model-family conventions. When a stream +//! starts, each parser receives the prompt token IDs and inspects the last +//! reasoning boundary that is already present in the prompt. In practice this +//! is a more faithful signal than hardcoding assumptions such as "this model +//! always starts in reasoning" or "this model always emits `` itself". +//! +//! That prompt-first initialization lets multiple model families share the +//! same incremental parser implementation even when older Python parsers split +//! them apart. If two families use the same textual delimiters and differ +//! mostly in how their chat templates prefill `` / ``, they can +//! usually reuse one parser here because the prompt token IDs already tell us +//! which state the stream is entering with. + +mod cohere_cmd; +mod deepseek_r1; +mod delimited; +mod gemma4; +mod kimi; +mod qwen3; + +use thiserror::Error; +use vllm_tokenizer::DynTokenizer; + +pub use self::cohere_cmd::CohereCmdReasoningParser; +pub use self::deepseek_r1::DeepSeekR1ReasoningParser; +pub(crate) use self::delimited::DelimitedReasoningParser; +pub use self::gemma4::Gemma4ReasoningParser; +pub use self::kimi::KimiReasoningParser; +pub use self::qwen3::Qwen3ReasoningParser; + +/// DeepSeek V3 currently shares the standard `...` parser. +pub type DeepSeekV3ReasoningParser = Qwen3ReasoningParser; +/// DeepSeek V4 currently shares the standard `...` parser. +pub type DeepSeekV4ReasoningParser = Qwen3ReasoningParser; +/// GLM45 currently shares the standard `...` parser. +pub type Glm45ReasoningParser = Qwen3ReasoningParser; +/// Kimi K2 currently shares the standard `...` parser. +// TODO: kimi k2 may implicitly end reasoning by starting a tool call section +// using <|tool_calls_section_begin|>, we should support that. +pub type KimiK2ReasoningParser = Qwen3ReasoningParser; +/// MiniMax M2 currently shares the standard `...` parser. +pub type MiniMaxM2ReasoningParser = Qwen3ReasoningParser; +/// Nemotron V3 currently shares the standard `...` parser. +pub type NemotronV3ReasoningParser = Qwen3ReasoningParser; +/// Step3 currently shares the standard `...` parser. +pub type Step3ReasoningParser = Qwen3ReasoningParser; + +/// Result alias for reasoning parser operations. +pub type Result = std::result::Result; + +/// One parsed streaming delta split into reasoning and visible content. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct ReasoningDelta { + pub reasoning: Option, + pub content: Option, +} + +impl ReasoningDelta { + /// Return true when this delta carries neither reasoning nor content text. + pub fn is_empty(&self) -> bool { + self.reasoning.is_none() && self.content.is_none() + } + + /// Append text to the reasoning portion, creating it on first use. + pub(crate) fn push_reasoning(&mut self, text: &str) { + if text.is_empty() { + return; + } + match &mut self.reasoning { + Some(existing) => existing.push_str(text), + None => self.reasoning = Some(text.to_string()), + } + } + + /// Append text to the visible content portion, creating it on first use. + pub(crate) fn push_content(&mut self, text: &str) { + if text.is_empty() { + return; + } + match &mut self.content { + Some(existing) => existing.push_str(text), + None => self.content = Some(text.to_string()), + } + } +} + +/// Incremental parser that splits decoded text deltas into reasoning and +/// content. +pub trait ReasoningParser: Send { + /// Construct a boxed parser instance for one request stream. + fn create(tokenizer: DynTokenizer) -> Result> + where + Self: Sized + 'static; + + /// Initialize parser state from prompt token IDs before output deltas + /// arrive. + fn initialize(&mut self, _prompt_token_ids: &[u32]) -> Result<()> { + Ok(()) + } + + /// Return whether decoded output must preserve tokenizer special tokens. + /// + /// Some model families emit reasoning sentinels as special tokens. Those + /// parsers need `skip_special_tokens = false` while parsing is enabled. + fn preserve_special_tokens(&self) -> bool { + false + } + + /// Feed one decoded text delta into the parser. + fn push(&mut self, delta: &str) -> Result; + + /// Flush any buffered partial delimiter state at end of stream. + fn finish(&mut self) -> Result { + Ok(ReasoningDelta::default()) + } +} + +/// Errors produced while creating or running reasoning parsers. +#[derive(Debug, Error)] +pub enum ReasoningError { + #[error("tokenizer is missing reasoning delimiter token `{token}`")] + MissingToken { token: String }, +} + +#[cfg(test)] +mod tests; diff --git a/rust/src/reasoning-parser/src/qwen3.rs b/rust/src/reasoning-parser/src/qwen3.rs new file mode 100644 index 00000000000..c50ab2d0fbd --- /dev/null +++ b/rust/src/reasoning-parser/src/qwen3.rs @@ -0,0 +1,43 @@ +use vllm_tokenizer::DynTokenizer; + +use super::{DelimitedReasoningParser, ReasoningDelta, ReasoningParser, Result}; + +/// Reasoning parser for the Qwen3/Qwen3.5 family. +/// +/// This parser uses standard `...` delimiters and defaults to +/// waiting for an explicit start token when prompt initialization finds no +/// reasoning boundary. +pub struct Qwen3ReasoningParser { + inner: DelimitedReasoningParser, +} + +impl Qwen3ReasoningParser { + /// Create a Qwen3 parser backed by the shared delimited state machine. + pub fn new(tokenizer: DynTokenizer) -> Result { + Ok(Self { + inner: DelimitedReasoningParser::new(tokenizer, "", "", false)?, + }) + } +} + +impl ReasoningParser for Qwen3ReasoningParser { + fn create(tokenizer: DynTokenizer) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tokenizer)?)) + } + + fn initialize(&mut self, prompt_token_ids: &[u32]) -> Result<()> { + self.inner.initialize(prompt_token_ids); + Ok(()) + } + + fn push(&mut self, delta: &str) -> Result { + Ok(self.inner.push(delta)) + } + + fn finish(&mut self) -> Result { + Ok(self.inner.finish()) + } +} diff --git a/rust/src/reasoning-parser/src/tests.rs b/rust/src/reasoning-parser/src/tests.rs new file mode 100644 index 00000000000..da602d9fddd --- /dev/null +++ b/rust/src/reasoning-parser/src/tests.rs @@ -0,0 +1,161 @@ +use std::sync::Arc; + +use vllm_tokenizer::Tokenizer; + +use super::{ + DeepSeekR1ReasoningParser, DelimitedReasoningParser, Qwen3ReasoningParser, ReasoningParser, +}; + +struct FakeTokenizer; + +impl Tokenizer for FakeTokenizer { + fn encode(&self, text: &str, _add_special_tokens: bool) -> vllm_tokenizer::Result> { + Ok(text.chars().map(u32::from).collect()) + } + + fn decode( + &self, + token_ids: &[u32], + _skip_special_tokens: bool, + ) -> vllm_tokenizer::Result { + Ok(token_ids + .iter() + .map(|token_id| char::from_u32(*token_id).unwrap_or('\u{FFFD}')) + .collect()) + } + + fn token_to_id(&self, token: &str) -> Option { + match token { + "" => Some(1), + "" => Some(2), + "<|START_THINKING|>" => Some(3), + "<|END_THINKING|>" => Some(4), + "◁think▷" => Some(5), + "◁/think▷" => Some(6), + _ => None, + } + } + + fn is_special_id(&self, token_id: u32) -> bool { + token_id == 7 + } +} + +#[test] +fn delimited_content_only_stream() { + let tokenizer = Arc::new(FakeTokenizer); + let mut parser = + DelimitedReasoningParser::new(tokenizer, "", "", false).unwrap(); + + assert_eq!( + parser.push("plain content").content.as_deref(), + Some("plain content") + ); +} + +#[test] +fn delimited_single_chunk_with_reasoning_and_content() { + let tokenizer = Arc::new(FakeTokenizer); + let mut parser = + DelimitedReasoningParser::new(tokenizer, "", "", false).unwrap(); + + let delta = parser.push("reasonanswer"); + assert_eq!(delta.reasoning.as_deref(), Some("reason")); + assert_eq!(delta.content.as_deref(), Some("answer")); +} + +#[test] +fn delimited_partial_tokens_across_chunks() { + let tokenizer = Arc::new(FakeTokenizer); + let mut parser = + DelimitedReasoningParser::new(tokenizer, "", "", false).unwrap(); + + assert!(parser.push("reasonanswer"); + assert_eq!(delta.reasoning.as_deref(), Some("reason")); + assert_eq!(delta.content.as_deref(), Some("answer")); +} + +#[test] +fn delimited_finish_flushes_buffer() { + let tokenizer = Arc::new(FakeTokenizer); + let mut parser = + DelimitedReasoningParser::new(tokenizer, "", "", false).unwrap(); + parser.initialize(&[1]); + + let delta = parser.push("unfinishedanswer").unwrap(); + assert_eq!(delta.reasoning, None); + assert_eq!(delta.content.as_deref(), Some("reasonanswer")); +} + +#[test] +fn qwen3_prompt_end_marker_starts_in_content() { + let tokenizer = Arc::new(FakeTokenizer); + let mut parser = Qwen3ReasoningParser::new(tokenizer).unwrap(); + parser.initialize(&[2]).unwrap(); + + let delta = parser.push("answer").unwrap(); + assert_eq!(delta.reasoning, None); + assert_eq!(delta.content.as_deref(), Some("answer")); +} + +#[test] +fn qwen3_tolerates_old_and_new_formats() { + let tokenizer = Arc::new(FakeTokenizer); + + let mut old_parser = Qwen3ReasoningParser::new(tokenizer.clone()).unwrap(); + let old = old_parser.push("reasonanswer").unwrap(); + assert_eq!(old.reasoning.as_deref(), Some("reason")); + assert_eq!(old.content.as_deref(), Some("answer")); + + let mut new_parser = Qwen3ReasoningParser::new(tokenizer).unwrap(); + new_parser.initialize(&[1]).unwrap(); + let new = new_parser.push("reasonanswer").unwrap(); + assert_eq!(new.reasoning.as_deref(), Some("reason")); + assert_eq!(new.content.as_deref(), Some("answer")); +} + +#[test] +fn qwen3_stops_scanning_at_last_special_token() { + let tokenizer = Arc::new(FakeTokenizer); + let mut parser = Qwen3ReasoningParser::new(tokenizer).unwrap(); + + parser.initialize(&[1, 7]).unwrap(); + + let delta = parser.push("answer").unwrap(); + assert_eq!(delta.reasoning, None); + assert_eq!(delta.content.as_deref(), Some("answer")); +} + +#[test] +fn deepseek_r1_defaults_to_reasoning_without_prompt_boundary() { + let tokenizer = Arc::new(FakeTokenizer); + let mut parser = DeepSeekR1ReasoningParser::new(tokenizer).unwrap(); + + let delta = parser.push("reasonanswer").unwrap(); + assert_eq!(delta.reasoning.as_deref(), Some("reason")); + assert_eq!(delta.content.as_deref(), Some("answer")); +} + +#[test] +fn deepseek_r1_stops_scanning_at_last_special_token() { + let tokenizer = Arc::new(FakeTokenizer); + let mut parser = DeepSeekR1ReasoningParser::new(tokenizer).unwrap(); + + parser.initialize(&[2, 7]).unwrap(); + + let delta = parser.push("reasonanswer").unwrap(); + assert_eq!(delta.reasoning.as_deref(), Some("reason")); + assert_eq!(delta.content.as_deref(), Some("answer")); +} diff --git a/rust/src/server/Cargo.toml b/rust/src/server/Cargo.toml new file mode 100644 index 00000000000..6030f972a9f --- /dev/null +++ b/rust/src/server/Cargo.toml @@ -0,0 +1,57 @@ +[package] +name = "vllm-server" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +anyhow.workspace = true +asynk-strim-attr.workspace = true +axum.workspace = true +futures.workspace = true +http-body.workspace = true +itertools.workspace = true +libc.workspace = true +llm-multimodal.workspace = true +prost.workspace = true +prost-types.workspace = true +rmpv.workspace = true +serde.workspace = true +serde_json.workspace = true +serde_with.workspace = true +socket2.workspace = true +thiserror-ext.workspace = true +tokio.workspace = true +tokio-stream.workspace = true +tokio-util.workspace = true +tonic.workspace = true +tonic-prost.workspace = true +tower-http.workspace = true +tracing.workspace = true +tracing-futures.workspace = true +tracing-subscriber.workspace = true +uuid.workspace = true +validator.workspace = true +vllm-chat.workspace = true +vllm-engine-core-client.workspace = true +vllm-llm.workspace = true +vllm-metrics.workspace = true +vllm-text.workspace = true + +[build-dependencies] +tonic-prost-build.workspace = true + +[dev-dependencies] +anyhow.workspace = true +async-openai = { workspace = true, features = ["full"] } +bytes.workspace = true +clap.workspace = true +expect-test.workspace = true +rmp-serde.workspace = true +serial_test.workspace = true +tower.workspace = true +vllm-engine-core-client = { workspace = true, features = ["test-util"] } +zeromq.workspace = true + +[lints] +workspace = true diff --git a/rust/src/server/build.rs b/rust/src/server/build.rs new file mode 100644 index 00000000000..86a44aac7bb --- /dev/null +++ b/rust/src/server/build.rs @@ -0,0 +1,12 @@ +fn main() -> Result<(), Box> { + let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap(); + let proto_dir = format!("{manifest_dir}/../../proto"); + + tonic_prost_build::configure() + .build_server(true) + .build_client(true) + .protoc_arg("--experimental_allow_proto3_optional") // be compatible with old compilers + .compile_protos(&[format!("{proto_dir}/vllm_grpc.proto")], &[proto_dir])?; + + Ok(()) +} diff --git a/rust/src/server/examples/README.md b/rust/src/server/examples/README.md new file mode 100644 index 00000000000..f4f152082d9 --- /dev/null +++ b/rust/src/server/examples/README.md @@ -0,0 +1,39 @@ +# Server Smoke Test + +Start a fresh headless `vllm` engine: + +```bash +source ../vllm/.venv/bin/activate +HF_HUB_OFFLINE=1 \ +VLLM_LOGGING_LEVEL=DEBUG \ +VLLM_CPU_KVCACHE_SPACE=2 \ +VLLM_HOST_IP=127.0.0.1 \ +VLLM_LOOPBACK_IP=127.0.0.1 \ +python3 -m vllm.entrypoints.cli.main serve Qwen/Qwen3-0.6B \ + --headless \ + --data-parallel-address 127.0.0.1 \ + --data-parallel-rpc-port 62100 \ + --data-parallel-size-local 1 \ + --max-model-len 512 \ + --dtype float16 +``` + +Run the Rust server smoke test: + +```bash +cargo run -p vllm-server --example external_engine_openai_qwen -- \ + --handshake-address tcp://127.0.0.1:62100 +``` + +The example starts the Rust OpenAI-compatible server on an ephemeral local port, +connects to it via the `async-openai` Rust client, lists models, and then checks +that a streamed chat completion yields the assistant role chunk, final-answer +content chunks, and a terminal finish chunk. This example intentionally uses +`async-openai`'s standard typed `create_stream` API instead of BYOT, so it does +not inspect the nonstandard `reasoning_content` field even though the Rust +server may emit it for reasoning-capable models such as Qwen3. For reasoning +behavior itself, use the `vllm-chat` smoke test or the `vllm-server` +route tests. + +IMPORTANT: Restart `vllm` each time you run the smoke test. The current headless +engine cannot safely handle frontend reconnects after the client shuts down. diff --git a/rust/src/server/examples/external_engine_openai_qwen.rs b/rust/src/server/examples/external_engine_openai_qwen.rs new file mode 100644 index 00000000000..6ef2e1a883e --- /dev/null +++ b/rust/src/server/examples/external_engine_openai_qwen.rs @@ -0,0 +1,215 @@ +use std::time::Duration; + +use anyhow::{Context, Result, bail}; +use async_openai::Client; +use async_openai::config::OpenAIConfig; +use async_openai::types::chat::{ + ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequest, + CreateChatCompletionRequestArgs, +}; +use async_openai::types::models::ListModelResponse; +use clap::Parser; +use futures::StreamExt as _; +use tokio_util::sync::CancellationToken; +use tracing_subscriber::EnvFilter; +use vllm_engine_core_client::TransportMode; +use vllm_server::{ + ChatTemplateContentFormatOption, Config, CoordinatorMode, HttpListenerMode, ParserSelection, + RendererSelection, serve, +}; + +#[derive(Debug, Parser)] +#[command( + about = "Smoke-test the Rust OpenAI server with async-openai against an external Qwen vLLM engine." +)] +struct Args { + #[arg(long)] + handshake_address: String, + #[arg(long, default_value_t = 1)] + engine_count: usize, + #[arg(long, default_value = "Qwen/Qwen3-0.6B")] + model: String, + #[arg(long, default_value = "127.0.0.1")] + host: String, + #[arg(long, default_value_t = 30)] + ready_timeout_secs: u64, + #[arg( + long, + default_value = "What is the capital of France? Answer with one word." + )] + prompt: String, +} + +#[tokio::main(flavor = "multi_thread")] +async fn main() -> Result<()> { + init_tracing(); + let args = Args::parse(); + let port = unique_local_port()?; + let config = Config { + transport_mode: TransportMode::HandshakeOwner { + handshake_address: args.handshake_address, + advertised_host: args.host, + engine_count: args.engine_count, + ready_timeout: Duration::from_secs(args.ready_timeout_secs), + local_input_address: None, + local_output_address: None, + }, + coordinator_mode: CoordinatorMode::MaybeInProc, + model: args.model, + served_model_name: vec![], + listener_mode: HttpListenerMode::BindTcp { + host: "127.0.0.1".to_string(), + port, + }, + tool_call_parser: ParserSelection::Auto, + reasoning_parser: ParserSelection::Auto, + renderer: RendererSelection::Auto, + chat_template: None, + default_chat_template_kwargs: None, + chat_template_content_format: ChatTemplateContentFormatOption::Auto, + enable_log_requests: false, + disable_log_stats: false, + grpc_port: None, + shutdown_timeout: Duration::ZERO, + }; + + let bind_address = format!("127.0.0.1:{port}"); + let shutdown = CancellationToken::new(); + let server_config = config.clone(); + let server_shutdown = shutdown.clone(); + let server_task = tokio::spawn(async move { serve(server_config, server_shutdown).await }); + + let client = Client::with_config( + OpenAIConfig::new() + .with_api_key("unused") + .with_api_base(format!("http://{bind_address}/v1")), + ); + + print_models(&client).await?; + let final_text = stream_completion(&client, &config.model, &args.prompt).await?; + + println!(); + println!("final_text={final_text:?}"); + + shutdown_server(server_task, shutdown).await +} + +fn init_tracing() { + let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); + let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init(); +} + +fn unique_local_port() -> Result { + let listener = std::net::TcpListener::bind("127.0.0.1:0") + .context("failed to allocate local smoke-test port")?; + let port = listener.local_addr().context("failed to read local smoke-test port")?.port(); + drop(listener); + Ok(port) +} + +async fn print_models(client: &Client) -> Result<()> { + let models = wait_for_models(client).await?; + let model_ids = models.data.into_iter().map(|model| model.id).collect::>(); + println!("models={model_ids:?}"); + Ok(()) +} + +async fn wait_for_models(client: &Client) -> Result { + let mut last_error = None; + for _ in 0..240 { + match client.models().list().await { + Ok(models) => return Ok(models), + Err(error) => { + last_error = Some(error); + tokio::time::sleep(Duration::from_millis(500)).await; + } + } + } + + match last_error { + Some(error) => Err(error).context("OpenAI server did not become ready in time"), + None => bail!("OpenAI server readiness loop finished without a result"), + } +} + +async fn stream_completion( + client: &Client, + model: &str, + prompt: &str, +) -> Result { + // Keep this smoke test on async-openai's standard `create_stream` path so it + // exercises the ordinary typed chat-completions client without BYOT + // request/response types. + // + // The current async-openai chat-completions stream delta type does not expose + // our OpenAI-compatible `reasoning_content` extension field, so this + // example only validates the assistant role chunk, visible `content` + // deltas, and terminal finish chunk. Reasoning coverage lives in our own + // route tests and in the `vllm-chat` smoke example. + let request: CreateChatCompletionRequest = CreateChatCompletionRequestArgs::default() + .model(model) + .stream(true) + .temperature(0.0) + .max_completion_tokens(128u32) + .messages([ChatCompletionRequestUserMessageArgs::default() + .content(prompt) + .build() + .context("failed to build user chat message")? + .into()]) + .build() + .context("failed to build chat completion request")?; + + let mut stream = client + .chat() + .create_stream(request) + .await + .context("failed to create streaming chat completion")?; + + let mut final_text = String::new(); + let mut saw_role = false; + let mut saw_finish_reason = false; + let mut saw_text = false; + + while let Some(chunk) = stream.next().await { + let chunk = chunk.context("streaming chat completion failed")?; + for choice in chunk.choices { + if choice.delta.role.is_some() { + saw_role = true; + } + if let Some(delta) = choice.delta.content { + if !saw_text { + print!("[answer] "); + } + print!("{delta}"); + final_text.push_str(&delta); + saw_text = true; + } + if choice.finish_reason.is_some() { + saw_finish_reason = true; + } + } + } + + if !saw_role { + bail!("stream ended without an assistant role chunk"); + } + if !saw_finish_reason { + bail!("stream ended without a terminal finish reason"); + } + if final_text.is_empty() { + bail!("stream ended without any content deltas"); + } + + Ok(final_text) +} + +async fn shutdown_server( + server_task: tokio::task::JoinHandle>, + shutdown: CancellationToken, +) -> Result<()> { + shutdown.cancel(); + server_task + .await + .context("server task join failed")? + .context("server task failed") +} diff --git a/rust/src/server/src/config.rs b/rust/src/server/src/config.rs new file mode 100644 index 00000000000..522133427f4 --- /dev/null +++ b/rust/src/server/src/config.rs @@ -0,0 +1,110 @@ +use std::collections::HashMap; +use std::time::Duration; + +use anyhow::Result; +use serde_json::Value; +use vllm_chat::{ChatTemplateContentFormatOption, ParserSelection, RendererSelection}; +use vllm_engine_core_client::{CoordinatorMode as EngineCoreCoordinatorMode, TransportMode}; + +/// How the HTTP server obtains its listening socket. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum HttpListenerMode { + /// Bind a fresh TCP listener on the given host/port. + BindTcp { host: String, port: u16 }, + /// Bind a fresh Unix domain listener on the given filesystem path. + BindUnix { path: String }, + /// Adopt an already-open listening socket inherited from a supervisor + /// process. + InheritedFd { fd: i32 }, +} + +/// Which coordinator implementation should be active when one is present for a +/// frontend client. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CoordinatorMode { + /// Do not run a coordinator at all. + None, + /// Run the Rust in-process coordinator for managed `serve` deployments, if + /// there are multiple engines and the model is MoE. + MaybeInProc, + /// Connect to an external coordinator owned by another process. + External { address: String }, +} + +/// Normalized runtime configuration for the minimal OpenAI-compatible server. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Config { + /// Frontend-to-engine transport setup. + pub transport_mode: TransportMode, + /// Requested frontend-side coordinator behavior. + pub coordinator_mode: CoordinatorMode, + /// Backend model identifier used for engine-core loading. + pub model: String, + /// Model name(s) exposed to clients via the OpenAI API. When non-empty, + /// the first entry is used as the primary ID in responses and all entries + /// are accepted in requests. When empty, falls back to `model`. + pub served_model_name: Vec, + /// HTTP listener setup. + pub listener_mode: HttpListenerMode, + /// Tool-call parser selection. + pub tool_call_parser: ParserSelection, + /// Reasoning parser selection. + pub reasoning_parser: ParserSelection, + /// Chat renderer selection. + pub renderer: RendererSelection, + /// Server-default chat template override, as a file path or inline + /// template. + pub chat_template: Option, + /// Server-default keyword arguments merged into every chat-template render. + pub default_chat_template_kwargs: Option>, + /// How to serialize `message.content` for chat-template rendering. + pub chat_template_content_format: ChatTemplateContentFormatOption, + /// Log a summary line for each completed request. + pub enable_log_requests: bool, + /// When `true`, suppress periodic stats logging (throughput, queue depth, + /// cache usage). + pub disable_log_stats: bool, + /// TCP port for the gRPC Generate service. When `None`, no gRPC server is + /// started. + pub grpc_port: Option, + /// Maximum time to wait for active HTTP/gRPC requests to drain on shutdown. + pub shutdown_timeout: Duration, +} + +impl Config { + /// Validate frontend configuration that can be checked before engine + /// startup. + pub fn validate(&self) -> Result<()> { + vllm_chat::validate_parser_overrides(&self.tool_call_parser, &self.reasoning_parser)?; + + Ok(()) + } + + /// Return the number of engines implied by the configured transport mode. + pub fn engine_count(&self) -> usize { + match &self.transport_mode { + TransportMode::HandshakeOwner { engine_count, .. } + | TransportMode::Bootstrapped { engine_count, .. } => *engine_count, + } + } + + /// Resolve the effective coordinator mode. + pub fn effective_coordinator_mode( + &self, + model_is_moe: bool, + ) -> Option { + match &self.coordinator_mode { + CoordinatorMode::None => None, + CoordinatorMode::MaybeInProc => { + if model_is_moe && self.engine_count() > 1 { + Some(EngineCoreCoordinatorMode::InProc) + } else { + None + } + } + CoordinatorMode::External { address } => Some(EngineCoreCoordinatorMode::External { + address: address.clone(), + }), + } + } +} diff --git a/rust/src/server/src/error.rs b/rust/src/server/src/error.rs new file mode 100644 index 00000000000..cc425ca076f --- /dev/null +++ b/rust/src/server/src/error.rs @@ -0,0 +1,74 @@ +use axum::Json; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use thiserror_ext::{Construct, Macro}; + +use crate::routes::openai::utils::types::{ErrorDetail, ErrorResponse}; + +/// Small OpenAI-style error family used by the minimal HTTP layer. +#[derive(Debug, Construct, Macro)] +pub enum ApiError { + /// The request is syntactically valid OpenAI JSON but asks for unsupported + /// behavior. + InvalidRequest { + message: String, + param: Option<&'static str>, + }, + /// The requested model name does not match the single configured model. + ModelNotFound { model: String }, + /// The request body could not be parsed as valid JSON. + JsonParseError { message: String }, + /// An unexpected internal failure happened before streaming started. + ServerError { message: String }, +} + +impl ApiError { + /// Return the HTTP status code associated with this API error. + pub fn status_code(&self) -> StatusCode { + match self { + Self::InvalidRequest { .. } => StatusCode::BAD_REQUEST, + Self::ModelNotFound { .. } => StatusCode::NOT_FOUND, + Self::ServerError { .. } => StatusCode::INTERNAL_SERVER_ERROR, + Self::JsonParseError { .. } => StatusCode::BAD_REQUEST, + } + } + + /// Convert this error into the standard OpenAI-compatible JSON error + /// payload. + pub fn to_error_response(&self) -> ErrorResponse { + let error = match self { + Self::InvalidRequest { message, param } => ErrorDetail { + message: message.clone(), + error_type: "invalid_request_error".to_string(), + param: param.map(|p| p.to_string()), + code: Some("invalid_request_error".to_string()), + }, + Self::ModelNotFound { model } => ErrorDetail { + message: format!("The model `{model}` does not exist."), + error_type: "invalid_request_error".to_string(), + param: Some("model".to_string()), + code: Some("model_not_found".to_string()), + }, + Self::ServerError { message } => ErrorDetail { + message: message.clone(), + error_type: "server_error".to_string(), + param: None, + code: Some("server_error".to_string()), + }, + Self::JsonParseError { message } => ErrorDetail { + message: message.clone(), + error_type: "invalid_request_error".to_string(), + param: None, + code: Some("json_parse_error".to_string()), + }, + }; + + ErrorResponse { error } + } +} + +impl IntoResponse for ApiError { + fn into_response(self) -> Response { + (self.status_code(), Json(self.to_error_response())).into_response() + } +} diff --git a/rust/src/server/src/grpc/convert.rs b/rust/src/server/src/grpc/convert.rs new file mode 100644 index 00000000000..ed21dd3d339 --- /dev/null +++ b/rust/src/server/src/grpc/convert.rs @@ -0,0 +1,679 @@ +//! Conversion between gRPC protobuf types and internal `vllm-text` +//! request/response types. + +use tonic::Status; +use uuid::Uuid; +use vllm_engine_core_client::protocol::{StopReason, StructuredOutputsParams}; +use vllm_text::{ + DecodedLogprobs, DecodedPromptLogprobs, FinishReason, Finished, Prompt, SamplingParams, + TextDecodeOptions, TextRequest, +}; + +use super::pb; + +// ======================================================================================== +// Request conversion +// ======================================================================================== + +/// Convert a gRPC `GenerateRequest` into the internal `TextRequest`. +/// +/// If `req.model` is non-empty, it must match one of `served_model_names`; +/// otherwise the request is rejected with `NotFound`. An empty string is +/// treated as "unset" (proto3 default) and accepted. +pub fn to_text_request( + req: pb::GenerateRequest, + stream: bool, + served_model_names: &[String], +) -> Result { + if !req.model.is_empty() && !served_model_names.iter().any(|n| n == &req.model) { + return Err(Status::not_found(format!( + "model `{}` not found", + req.model + ))); + } + + if req.truncate_prompt_tokens != 0 { + return Err(Status::invalid_argument( + "truncate_prompt_tokens is not supported", + )); + } + + let prompt = match req.prompt { + Some(pb::generate_request::Prompt::Text(text)) => Prompt::Text(text), + Some(pb::generate_request::Prompt::TokenIds(ids)) => Prompt::TokenIds(ids.ids), + None => return Err(Status::invalid_argument("prompt is required")), + }; + + let request_id = if req.request_id.is_empty() { + Uuid::new_v4().to_string() + } else { + req.request_id + }; + + let sampling = req.sampling.as_ref(); + let decoding = req.decoding.as_ref(); + let stopping = req.stopping.as_ref(); + let response = req.response.as_ref(); + let kv = req.kv.as_ref(); + + let mut sampling_params = + build_sampling_params(req.temperature, sampling, decoding, stopping, response)?; + + // Thread KVCacheParameters → SamplingParams fields. + if let Some(kv) = kv { + // Thread kv_transfer_params through vllm_xargs, matching the HTTP route + // convention. + if let Some(kv_struct) = kv.kv_transfer_params.as_ref() { + let kv_json = proto_struct_to_json(kv_struct); + let map = sampling_params.vllm_xargs.get_or_insert_with(Default::default); + map.insert("kv_transfer_params".to_string(), kv_json); + } + if kv.bypass_prefix_cache { + sampling_params.skip_reading_prefix_cache = Some(true); + } + } + + let decode_options = TextDecodeOptions { + skip_special_tokens: true, + include_stop_str_in_output: stopping.is_some_and(|s| s.include_stop_strings), + stop_strings: stopping.map(|s| &s.stop_strings).filter(|ss| !ss.is_empty()).cloned(), + min_tokens: stopping.map_or(0, |s| s.min_new_tokens), + }; + + Ok(TextRequest { + request_id, + prompt, + mm_features: None, + sampling_params, + decode_options, + intermediate: stream, + priority: req.priority, + cache_salt: kv.map(|k| &k.cache_salt).filter(|s| !s.is_empty()).cloned(), + add_special_tokens: true, + data_parallel_rank: None, + }) +} + +fn build_sampling_params( + temperature: Option, + sampling: Option<&pb::RandomSampling>, + decoding: Option<&pb::DecodingParameters>, + stopping: Option<&pb::StoppingCriteria>, + response: Option<&pb::ResponseOptions>, +) -> Result { + // Temperature is a top-level GenerateRequest field. Default to greedy (0.0) for + // the gRPC API when the caller does not specify a value. This differs from + // the HTTP/OpenAI API (which defaults to 1.0) and matches the convention of + // programmatic generation APIs. + let temperature = temperature.or(Some(0.0)); + let mut params = SamplingParams { + temperature, + ..SamplingParams::default() + }; + + // RandomSampling: for every remaining sampling field the protobuf default (`0`) + // is treated as "unset" and leaves the resolved value to the lowering + // stage, which falls back to the model-provided default or a + // neutral/disabled value otherwise. + if let Some(s) = sampling { + // num_sequences (n > 1) is not supported yet by the TextLlm layer; the response + // path also hardcodes SequenceOutput.index = 0, so accepting >1 would silently + // truncate output cardinality. Reject explicitly. + if s.num_sequences > 1 { + return Err(Status::invalid_argument( + "num_sequences > 1 is not supported", + )); + } + if s.top_k != 0 { + params.top_k = Some(s.top_k); + } + if s.top_p != 0.0 { + params.top_p = Some(s.top_p); + } + if s.min_p != 0.0 { + params.min_p = Some(s.min_p); + } + params.seed = s.seed; + } + + // DecodingParameters + if let Some(d) = decoding { + if d.presence_penalty != 0.0 { + params.presence_penalty = Some(d.presence_penalty); + } + if d.frequency_penalty != 0.0 { + params.frequency_penalty = Some(d.frequency_penalty); + } + if d.repetition_penalty != 0.0 { + params.repetition_penalty = Some(d.repetition_penalty); + } + if !d.logit_bias.is_empty() { + params.logit_bias = Some(d.logit_bias.clone()); + } + if !d.allowed_token_ids.is_empty() { + params.allowed_token_ids = Some(d.allowed_token_ids.clone()); + } + params.structured_outputs = convert_structured_output(d)?; + } + + // StoppingCriteria + if let Some(s) = stopping { + if s.max_new_tokens != 0 { + params.max_tokens = Some(s.max_new_tokens); + } + if s.min_new_tokens != 0 { + params.min_tokens = Some(s.min_new_tokens); + } + if !s.stop_token_ids.is_empty() { + params.stop_token_ids = Some(s.stop_token_ids.clone()); + } + params.ignore_eos = s.ignore_eos; + } + + // ResponseOptions → logprobs + if let Some(r) = response { + if r.output_logprobs { + let (count, token_ids) = candidate_logprob_spec(r.output_candidates.as_ref()); + params.logprobs = Some(count); + params.logprob_token_ids = token_ids; + } + if r.prompt_logprobs { + // The engine-core protocol has only one shared `logprob_token_ids` field + // for output and prompt logprobs, so a per-token-id selector for prompt + // candidates can't be honored independently. Reject it instead of silently + // dropping the list. + if matches!( + r.prompt_candidates.as_ref().and_then(|c| c.select.as_ref()), + Some(pb::candidate_tokens::Select::TokenIds(_)) + ) { + return Err(Status::invalid_argument( + "prompt_candidates token_ids selector is not supported", + )); + } + let (count, _) = candidate_logprob_spec(r.prompt_candidates.as_ref()); + params.prompt_logprobs = Some(count); + } + } + + Ok(params) +} + +/// Map the proto `CandidateTokens` selector to a `(logprobs_count, +/// logprob_token_ids)` pair. +/// +/// - `top_n(k)` → `(k, None)` — return top-k candidates by probability +/// - `all` → `(-1, None)` — return the full vocabulary +/// - `token_ids(n)` → `(1, Some(vec of n token ids))` — return logprobs for specific tokens (the +/// count `n` is stored in the proto as the number of token IDs that follow, but the actual IDs +/// are carried via `logprob_token_ids` on `SamplingParams`) +/// - absent → `(1, None)` — just the sampled/scored token +fn candidate_logprob_spec(candidates: Option<&pb::CandidateTokens>) -> (i32, Option>) { + match candidates.and_then(|c| c.select.as_ref()) { + Some(pb::candidate_tokens::Select::TopN(n)) => (*n as i32, None), + Some(pb::candidate_tokens::Select::All(true)) => (-1, None), + Some(pb::candidate_tokens::Select::TokenIds(ids)) => (1, Some(ids.ids.clone())), + _ => (1, None), + } +} + +fn convert_structured_output( + d: &pb::DecodingParameters, +) -> Result, Status> { + let so = match d.structured_output.as_ref() { + None => return Ok(None), + Some(so) => so, + }; + use pb::decoding_parameters::StructuredOutput; + let params = match so { + StructuredOutput::Json(schema) => { + let json: serde_json::Value = serde_json::from_str(schema) + .map_err(|e| Status::invalid_argument(format!("invalid json schema: {e}")))?; + StructuredOutputsParams { + json: Some(json), + ..Default::default() + } + } + StructuredOutput::Regex(regex) => StructuredOutputsParams { + regex: Some(regex.clone()), + ..Default::default() + }, + StructuredOutput::Choice(choices) => StructuredOutputsParams { + choice: Some(choices.choices.clone()), + ..Default::default() + }, + StructuredOutput::Grammar(grammar) => StructuredOutputsParams { + grammar: Some(grammar.clone()), + ..Default::default() + }, + StructuredOutput::JsonObject(true) => StructuredOutputsParams { + json_object: Some(true), + ..Default::default() + }, + StructuredOutput::JsonObject(false) => return Ok(None), + StructuredOutput::StructuralTag(tag) => StructuredOutputsParams { + structural_tag: Some(tag.clone()), + ..Default::default() + }, + }; + Ok(Some(params)) +} + +// ======================================================================================== +// Response conversion +// ======================================================================================== + +/// Convert a `DecodedTextEvent::Start` into the prompt info portion of a gRPC +/// response. +pub fn to_prompt_info( + prompt_token_ids: &[u32], + prompt_logprobs: Option<&DecodedPromptLogprobs>, + opts: &ResponseOpts, +) -> pb::PromptInfo { + let token_ids = if opts.prompt_token_ids { + prompt_token_ids.to_vec() + } else { + vec![] + }; + + let (logprobs, ranks, candidate_tokens) = match prompt_logprobs { + Some(plp) if opts.prompt_logprobs => prompt_logprobs_to_proto(plp), + _ => (vec![], vec![], vec![]), + }; + + pb::PromptInfo { + num_prompt_tokens: prompt_token_ids.len() as u32, + token_ids, + logprobs, + ranks, + candidate_tokens, + } +} + +/// Convert a `DecodedTextEvent::TextDelta` into a gRPC `SequenceOutput`. +pub fn to_sequence_output( + delta: &str, + token_ids: &[u32], + logprobs: Option<&DecodedLogprobs>, + finished: Option<&Finished>, + opts: &ResponseOpts, +) -> pb::SequenceOutput { + let (lp_values, rank_values, candidates) = match logprobs { + Some(lp) if opts.output_logprobs => output_logprobs_to_proto(lp), + _ => (vec![], vec![], vec![]), + }; + + pb::SequenceOutput { + index: 0, // TODO: multi-sequence (n > 1) not supported + text: if opts.output_text { + delta.to_string() + } else { + String::new() + }, + num_tokens: token_ids.len() as u32, + token_ids: if opts.output_token_ids { + token_ids.to_vec() + } else { + vec![] + }, + logprobs: lp_values, + ranks: rank_values, + candidate_tokens: candidates, + finish_info: finished.map(|f| to_finish_info(f, token_ids)), + } +} + +fn to_finish_info(finished: &Finished, token_ids: &[u32]) -> pb::FinishInfo { + use pb::finish_info::FinishReason as PbFinishReason; + + let (finish_reason, stop_reason) = match &finished.finish_reason { + FinishReason::Stop(reason) => { + let sr = match reason { + Some(StopReason::TokenId(id)) => { + Some(pb::finish_info::StopReason::StopTokenId(*id)) + } + Some(StopReason::Text(s)) => { + Some(pb::finish_info::StopReason::StopString(s.clone())) + } + // EOS-driven stop: engine-core matched the primary EOS token id but did not + // echo it back as a `stop_reason`. The matched token is, by construction, the + // last token of the terminal output batch (see vllm's `check_stop` in + // vllm/v1/core/sched/utils.py), so we recover it from there. + None => token_ids.last().copied().map(pb::finish_info::StopReason::EosTokenId), + }; + (PbFinishReason::Stop as i32, sr) + } + FinishReason::Length => (PbFinishReason::Length as i32, None), + FinishReason::Abort | FinishReason::Error | FinishReason::Repetition => { + (PbFinishReason::Aborted as i32, None) + } + }; + + pb::FinishInfo { + num_output_tokens: finished.output_token_count as u32, + finish_reason, + stop_reason, + kv_transfer_params: finished.kv_transfer_params.as_ref().and_then(json_to_proto_struct), + } +} + +// ======================================================================================== +// Logprobs helpers +// ======================================================================================== + +/// Convert output logprobs to the flat proto representation. +/// +/// Returns (logprob_values, ranks, candidate_tokens) — all parallel arrays +/// indexed by position. +fn output_logprobs_to_proto( + lp: &DecodedLogprobs, +) -> (Vec, Vec, Vec) { + positions_to_proto(&lp.positions) +} + +/// Convert prompt logprobs to the flat proto representation. +fn prompt_logprobs_to_proto( + plp: &DecodedPromptLogprobs, +) -> (Vec, Vec, Vec) { + // The proto PromptInfo has flat parallel arrays covering all prompt positions. + // DecodedPromptLogprobs has first_token separately + scored_positions for the + // rest. The first prompt position has no scores, so we emit zeros for it. + let (mut logprobs, mut ranks, mut candidates) = positions_to_proto(&plp.scored_positions); + logprobs.insert(0, 0.0); + ranks.insert(0, 0); + candidates.insert(0, pb::CandidateTokenInfo { tokens: vec![] }); + (logprobs, ranks, candidates) +} + +/// Shared helper: convert a slice of decoded position logprobs to flat proto +/// arrays. +fn positions_to_proto( + positions: &[vllm_text::DecodedPositionLogprobs], +) -> (Vec, Vec, Vec) { + let mut logprobs = Vec::with_capacity(positions.len()); + let mut ranks = Vec::with_capacity(positions.len()); + let mut candidates = Vec::with_capacity(positions.len()); + + for pos in positions { + // First entry is the sampled/scored token. + if let Some(first) = pos.entries.first() { + logprobs.push(first.logprob); + ranks.push(first.rank); + } + + // Extra candidates beyond the first. + let entries = pos.entries.iter().skip(1); + candidates.push(pb::CandidateTokenInfo { + tokens: entries + .map(|e| pb::candidate_token_info::TokenInfo { + id: e.token_id, + logprob: e.logprob, + rank: e.rank, + }) + .collect(), + }); + } + + (logprobs, ranks, candidates) +} + +// ======================================================================================== +// KV transfer params conversion (serde_json::Value ↔ prost_types::Struct) +// ======================================================================================== + +fn proto_struct_to_json(s: &prost_types::Struct) -> serde_json::Value { + serde_json::Value::Object( + s.fields.iter().map(|(k, v)| (k.clone(), proto_value_to_json(v))).collect(), + ) +} + +fn proto_value_to_json(v: &prost_types::Value) -> serde_json::Value { + use prost_types::value::Kind; + match v.kind.as_ref() { + None | Some(Kind::NullValue(_)) => serde_json::Value::Null, + Some(Kind::BoolValue(b)) => serde_json::Value::Bool(*b), + Some(Kind::NumberValue(n)) => serde_json::json!(*n), + Some(Kind::StringValue(s)) => serde_json::Value::String(s.clone()), + Some(Kind::ListValue(list)) => { + serde_json::Value::Array(list.values.iter().map(proto_value_to_json).collect()) + } + Some(Kind::StructValue(s)) => proto_struct_to_json(s), + } +} + +fn json_to_proto_struct(value: &serde_json::Value) -> Option { + match value { + serde_json::Value::Object(map) => Some(prost_types::Struct { + fields: map.iter().map(|(k, v)| (k.clone(), json_to_proto_value(v))).collect(), + }), + _ => None, + } +} + +fn json_to_proto_value(v: &serde_json::Value) -> prost_types::Value { + use prost_types::value::Kind; + let kind = match v { + serde_json::Value::Null => Kind::NullValue(0), + serde_json::Value::Bool(b) => Kind::BoolValue(*b), + serde_json::Value::Number(n) => Kind::NumberValue(n.as_f64().unwrap_or(0.0)), + serde_json::Value::String(s) => Kind::StringValue(s.clone()), + serde_json::Value::Array(arr) => Kind::ListValue(prost_types::ListValue { + values: arr.iter().map(json_to_proto_value).collect(), + }), + serde_json::Value::Object(map) => Kind::StructValue(prost_types::Struct { + fields: map.iter().map(|(k, v)| (k.clone(), json_to_proto_value(v))).collect(), + }), + }; + prost_types::Value { kind: Some(kind) } +} + +// ======================================================================================== +// Options extracted from the request for response building +// ======================================================================================== + +/// Response-shaping options extracted from the proto `ResponseOptions`. +#[derive(Default)] +pub struct ResponseOpts { + pub prompt_token_ids: bool, + pub prompt_logprobs: bool, + pub output_text: bool, + pub output_token_ids: bool, + pub output_logprobs: bool, +} + +impl ResponseOpts { + pub fn from_proto(r: Option<&pb::ResponseOptions>) -> Self { + match r { + Some(r) => Self { + prompt_token_ids: r.prompt_token_ids, + prompt_logprobs: r.prompt_logprobs, + output_text: r.output_text.unwrap_or(true), + output_token_ids: r.output_token_ids, + output_logprobs: r.output_logprobs, + }, + None => Self { + output_text: true, + ..Default::default() + }, + } + } +} + +#[cfg(test)] +mod tests { + use vllm_engine_core_client::protocol::StopReason; + use vllm_text::{FinishReason, Finished, Prompt}; + + use super::pb::finish_info::{FinishReason as PbFinishReason, StopReason as PbStopReason}; + use super::{ResponseOpts, pb, to_finish_info, to_sequence_output, to_text_request}; + + fn base_request() -> pb::GenerateRequest { + pb::GenerateRequest { + request_id: "req".to_string(), + model: "test-model".to_string(), + prompt: Some(pb::generate_request::Prompt::Text("hi".to_string())), + ..Default::default() + } + } + + #[test] + fn temperature_propagates_from_top_level_request_field() { + let req = pb::GenerateRequest { + temperature: Some(0.7), + ..base_request() + }; + let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok"); + assert_eq!(text.sampling_params.temperature, Some(0.7)); + } + + #[test] + fn unset_temperature_defaults_to_greedy() { + let text = to_text_request(base_request(), false, &["test-model".to_string()]) + .expect("convert ok"); + // The gRPC API defaults to greedy (0.0) when temperature is not specified. + assert_eq!(text.sampling_params.temperature, Some(0.0)); + } + + #[test] + fn absent_seed_is_none() { + let req = pb::GenerateRequest { + sampling: Some(pb::RandomSampling { + seed: None, + ..Default::default() + }), + ..base_request() + }; + let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok"); + assert_eq!(text.sampling_params.seed, None); + } + + #[test] + fn zero_seed_is_valid() { + let req = pb::GenerateRequest { + sampling: Some(pb::RandomSampling { + seed: Some(0), + ..Default::default() + }), + ..base_request() + }; + let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok"); + assert_eq!(text.sampling_params.seed, Some(0)); + } + + #[test] + fn bypass_prefix_cache_maps_to_skip_reading_prefix_cache() { + let req = pb::GenerateRequest { + kv: Some(pb::KvCacheParameters { + bypass_prefix_cache: true, + ..Default::default() + }), + ..base_request() + }; + let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok"); + assert_eq!(text.sampling_params.skip_reading_prefix_cache, Some(true)); + } + + #[test] + fn bypass_prefix_cache_false_leaves_field_unset() { + let req = pb::GenerateRequest { + kv: Some(pb::KvCacheParameters { + bypass_prefix_cache: false, + ..Default::default() + }), + ..base_request() + }; + let text = to_text_request(req, false, &["test-model".to_string()]).expect("convert ok"); + assert_eq!(text.sampling_params.skip_reading_prefix_cache, None); + // Prompt conversion still succeeds and reaches the expected variant. + assert!(matches!(text.prompt, Prompt::Text(s) if s == "hi")); + } + + fn finished(reason: FinishReason) -> Finished { + Finished { + prompt_token_count: 0, + output_token_count: 0, + finish_reason: reason, + kv_transfer_params: None, + } + } + + #[test] + fn eos_stop_reports_last_output_token_as_eos_id() { + let fin = finished(FinishReason::Stop(None)); + let token_ids = [1_u32, 2, 3, 151643]; + + let info = to_finish_info(&fin, &token_ids); + + assert_eq!(info.finish_reason, PbFinishReason::Stop as i32); + assert_eq!(info.stop_reason, Some(PbStopReason::EosTokenId(151643))); + } + + #[test] + fn eos_stop_with_empty_token_ids_leaves_stop_reason_unset() { + let fin = finished(FinishReason::Stop(None)); + + let info = to_finish_info(&fin, &[]); + + assert_eq!(info.finish_reason, PbFinishReason::Stop as i32); + assert_eq!(info.stop_reason, None); + } + + #[test] + fn explicit_stop_token_id_is_preserved() { + let fin = finished(FinishReason::Stop(Some(StopReason::TokenId(42)))); + // Terminal token list should be ignored when an explicit stop reason is + // present. + let info = to_finish_info(&fin, &[7, 42]); + + assert_eq!(info.finish_reason, PbFinishReason::Stop as i32); + assert_eq!(info.stop_reason, Some(PbStopReason::StopTokenId(42))); + } + + #[test] + fn explicit_stop_string_is_preserved() { + let fin = finished(FinishReason::Stop(Some(StopReason::Text("".into())))); + + let info = to_finish_info(&fin, &[1, 2, 3]); + + assert_eq!(info.finish_reason, PbFinishReason::Stop as i32); + assert_eq!( + info.stop_reason, + Some(PbStopReason::StopString("".into())) + ); + } + + #[test] + fn length_finish_has_no_stop_reason() { + let fin = finished(FinishReason::Length); + + let info = to_finish_info(&fin, &[1, 2, 3]); + + assert_eq!(info.finish_reason, PbFinishReason::Length as i32); + assert_eq!(info.stop_reason, None); + } + + #[test] + fn abort_finish_is_mapped_to_aborted() { + let fin = finished(FinishReason::Abort); + + let info = to_finish_info(&fin, &[]); + + assert_eq!(info.finish_reason, PbFinishReason::Aborted as i32); + assert_eq!(info.stop_reason, None); + } + + #[test] + fn to_sequence_output_threads_token_ids_into_eos_id() { + let fin = finished(FinishReason::Stop(None)); + let opts = ResponseOpts { + output_text: true, + output_token_ids: true, + ..Default::default() + }; + + let out = to_sequence_output("hello", &[10, 20, 30], None, Some(&fin), &opts); + + let finish = out.finish_info.expect("finish_info should be present"); + assert_eq!(finish.finish_reason, PbFinishReason::Stop as i32); + assert_eq!(finish.stop_reason, Some(PbStopReason::EosTokenId(30))); + } +} diff --git a/rust/src/server/src/grpc/mod.rs b/rust/src/server/src/grpc/mod.rs new file mode 100644 index 00000000000..2f648aa6ce0 --- /dev/null +++ b/rust/src/server/src/grpc/mod.rs @@ -0,0 +1,158 @@ +//! gRPC Generate service backed by the shared [`vllm_text::TextLlm`] facade. + +mod convert; + +use std::pin::Pin; +use std::sync::Arc; + +use futures::{Stream, StreamExt as _}; +use thiserror_ext::AsReport as _; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tonic::{Request, Response, Status}; +use tracing::info; +use vllm_text::{DecodedTextEvent, TextOutputStreamExt as _}; + +use self::convert::ResponseOpts; +use crate::state::AppState; + +/// Generated protobuf/gRPC types for the `vllm` package. +pub mod pb { + tonic::include_proto!("vllm"); +} + +pub use pb::generate_server::GenerateServer; + +#[cfg(test)] +mod tests; + +/// gRPC Generate service implementation backed by the shared application state. +pub struct GenerateServiceImpl { + state: Arc, +} + +impl GenerateServiceImpl { + pub fn new(state: Arc) -> Self { + Self { state } + } +} + +#[tonic::async_trait] +impl pb::generate_server::Generate for GenerateServiceImpl { + type GenerateStreamStream = + Pin> + Send>>; + + /// Unary generate: collect all output and return a single response. + async fn generate( + &self, + request: Request, + ) -> Result, Status> { + let proto_req = request.into_inner(); + let response_opts = ResponseOpts::from_proto(proto_req.response.as_ref()); + let text_request = + convert::to_text_request(proto_req, false, self.state.served_model_names())?; + + let request_id = text_request.request_id.clone(); + info!(%request_id, "grpc generate (unary)"); + + let stream = self.state.chat.text().generate(text_request).await; + let stream = stream.map_err(|e| Status::internal(e.to_report_string()))?; + + let collected = stream + .collect_output() + .await + .map_err(|e| Status::internal(e.to_report_string()))?; + + // Build the single aggregated response. + let prompt_info = convert::to_prompt_info( + &collected.prompt_token_ids, + collected.prompt_logprobs.as_ref(), + &response_opts, + ); + + let finish_info = vllm_text::Finished { + prompt_token_count: collected.prompt_token_ids.len(), + output_token_count: collected.token_ids.len(), + finish_reason: collected.finish_reason, + kv_transfer_params: collected.kv_transfer_params, + }; + + let outputs = convert::to_sequence_output( + &collected.text, + &collected.token_ids, + collected.logprobs.as_ref(), + Some(&finish_info), + &response_opts, + ); + + Ok(Response::new(pb::GenerateResponse { + prompt_info: Some(prompt_info), + outputs: Some(outputs), + })) + } + + /// Streaming generate: yield incremental responses as tokens are produced. + async fn generate_stream( + &self, + request: Request, + ) -> Result, Status> { + let proto_req = request.into_inner(); + let response_opts = ResponseOpts::from_proto(proto_req.response.as_ref()); + let text_request = + convert::to_text_request(proto_req, true, self.state.served_model_names())?; + + let request_id = text_request.request_id.clone(); + info!(%request_id, "grpc generate (stream)"); + + let stream = self.state.chat.text().generate(text_request).await; + let stream = stream.map_err(|e| Status::internal(e.to_report_string()))?; + + let (tx, rx) = mpsc::channel(32); + + tokio::spawn(async move { + futures::pin_mut!(stream); + while let Some(event) = stream.next().await { + let response = match event { + Err(e) => Err(Status::internal(e.to_report_string())), + Ok(DecodedTextEvent::Start { + prompt_token_ids, + prompt_logprobs, + }) => { + let prompt_info = convert::to_prompt_info( + &prompt_token_ids, + prompt_logprobs.as_ref(), + &response_opts, + ); + Ok(pb::GenerateResponse { + prompt_info: Some(prompt_info), + outputs: None, + }) + } + Ok(DecodedTextEvent::TextDelta { + delta, + token_ids, + logprobs, + finished, + }) => Ok(pb::GenerateResponse { + prompt_info: None, + outputs: Some(convert::to_sequence_output( + &delta, + &token_ids, + logprobs.as_ref(), + finished.as_ref(), + &response_opts, + )), + }), + }; + + if tx.send(response).await.is_err() { + // Client disconnected. + break; + } + } + }); + + let response_stream = ReceiverStream::new(rx); + Ok(Response::new(Box::pin(response_stream))) + } +} diff --git a/rust/src/server/src/grpc/tests.rs b/rust/src/server/src/grpc/tests.rs new file mode 100644 index 00000000000..17361ae0e86 --- /dev/null +++ b/rust/src/server/src/grpc/tests.rs @@ -0,0 +1,662 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use futures::StreamExt as _; +use serial_test::serial; +use tonic::transport::Server as TonicServer; +use vllm_chat::{ + ChatBackend, ChatLlm, ChatRenderer, ChatRequest, ChatTextBackend, DefaultChatOutputProcessor, + DynChatOutputProcessor, DynChatRenderer, NewChatOutputProcessorOptions, RenderedPrompt, +}; +use vllm_engine_core_client::protocol::{ + EngineCoreFinishReason, EngineCoreOutput, EngineCoreOutputs, EngineCoreRequest, +}; +use vllm_engine_core_client::test_utils::{IpcNamespace, spawn_mock_engine_task}; +use vllm_engine_core_client::{EngineCoreClient, EngineCoreClientConfig, EngineId}; +use vllm_llm::Llm; +use vllm_text::tokenizer::{DynTokenizer, Tokenizer}; +use vllm_text::{Prompt, TextBackend}; +use zeromq::prelude::{SocketRecv, SocketSend}; +use zeromq::{DealerSocket, PushSocket, ZmqMessage}; + +use super::pb::generate_client::GenerateClient; +use super::{GenerateServer, GenerateServiceImpl, pb}; +use crate::state::AppState; + +// ======================================================================================== +// Helpers (mirrors the patterns in routes/tests.rs) +// ======================================================================================== + +type TestFuture<'a> = Pin + Send + 'a>>; + +fn boxed_test_future<'a>(future: impl Future + Send + 'a) -> TestFuture<'a> { + Box::pin(future) +} + +struct MockEngineTask { + shutdown_tx: Option>, + join_handle: Option>, +} + +impl MockEngineTask { + fn new( + (shutdown_tx, join_handle): ( + tokio::sync::oneshot::Sender<()>, + tokio::task::JoinHandle<()>, + ), + ) -> Self { + Self { + shutdown_tx: Some(shutdown_tx), + join_handle: Some(join_handle), + } + } +} + +impl Future for MockEngineTask { + type Output = Result<(), tokio::task::JoinError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Some(shutdown_tx) = self.shutdown_tx.take() { + let _ = shutdown_tx.send(()); + } + match self.join_handle.as_mut() { + Some(join_handle) => Pin::new(join_handle).poll(cx), + None => Poll::Ready(Ok(())), + } + } +} + +impl Drop for MockEngineTask { + fn drop(&mut self) { + if let Some(join_handle) = &self.join_handle { + join_handle.abort(); + } + } +} + +fn request_output( + request_id: &str, + new_token_ids: Vec, + finish_reason: Option, +) -> EngineCoreOutput { + EngineCoreOutput { + request_id: request_id.to_string(), + new_token_ids, + new_logprobs: None, + new_prompt_logprobs_tensors: None, + pooling_output: None, + finish_reason, + stop_reason: None, + events: None, + kv_transfer_params: None, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + } +} + +fn engine_outputs_for_request( + request_id: &str, + output_specs: Vec<(Vec, Option)>, +) -> EngineCoreOutputs { + EngineCoreOutputs { + engine_index: 0, + outputs: output_specs + .into_iter() + .map(|(token_ids, finish_reason)| request_output(request_id, token_ids, finish_reason)) + .collect(), + scheduler_stats: None, + timestamp: 0.0, + utility_output: None, + finished_requests: None, + wave_complete: None, + start_wave: None, + } +} + +fn default_stream_output_specs() -> Vec<(Vec, Option)> { + vec![ + (vec![b'h' as u32], None), + (vec![b'i' as u32], None), + (vec![b'!' as u32], Some(EngineCoreFinishReason::Stop)), + ] +} + +async fn send_outputs(push: &mut PushSocket, outputs: EngineCoreOutputs) { + push.send(ZmqMessage::from( + rmp_serde::to_vec_named(&outputs).expect("encode outputs"), + )) + .await + .expect("send outputs"); +} + +async fn recv_engine_message(dealer: &mut DealerSocket) -> Vec { + dealer.recv().await.expect("recv engine message").into_vec() +} + +fn test_llm(client: EngineCoreClient) -> Llm { + Llm::new(client).with_request_id_randomization(false) +} + +#[derive(Clone, Debug)] +struct FakeTextBackend; + +#[derive(Debug)] +struct FakeTokenizer; + +impl Tokenizer for FakeTokenizer { + fn encode( + &self, + text: &str, + _add_special_tokens: bool, + ) -> vllm_text::tokenizer::Result> { + Ok(text.bytes().map(u32::from).collect()) + } + + fn decode( + &self, + token_ids: &[u32], + _skip_special_tokens: bool, + ) -> vllm_text::tokenizer::Result { + Ok( + String::from_utf8_lossy(&token_ids.iter().map(|id| *id as u8).collect::>()) + .into_owned(), + ) + } + + fn token_to_id(&self, token: &str) -> Option { + token.bytes().next().map(u32::from) + } +} + +impl TextBackend for FakeTextBackend { + fn tokenizer(&self) -> DynTokenizer { + Arc::new(FakeTokenizer) + } + + fn model_id(&self) -> &str { + "test-model" + } +} + +impl ChatBackend for FakeTextBackend { + fn chat_renderer(&self) -> DynChatRenderer { + Arc::new(self.clone()) + } + + fn new_chat_output_processor( + &self, + request: &mut ChatRequest, + options: NewChatOutputProcessorOptions<'_>, + ) -> vllm_chat::Result { + Ok(Box::new(DefaultChatOutputProcessor::new( + request, + self.model_id(), + self.tokenizer(), + options.tool_call_parser, + options.reasoning_parser, + )?)) + } +} + +impl ChatRenderer for FakeTextBackend { + fn render(&self, _request: &ChatRequest) -> vllm_chat::Result { + Ok(RenderedPrompt { + prompt: Prompt::Text(String::new()), + }) + } +} + +/// Spin up a gRPC server backed by a mock engine that serves a single request +/// with the given output specs. Returns the client, the gRPC server task, and +/// the mock engine task. +async fn grpc_test_server( + engine_id: impl Into, + output_specs: Vec<(Vec, Option)>, +) -> ( + GenerateClient, + tokio::task::JoinHandle<()>, + MockEngineTask, +) { + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = engine_id.into(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + move |dealer, push| { + boxed_test_future(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = + rmp_serde::from_slice(&add[1]).expect("decode request"); + send_outputs( + push, + engine_outputs_for_request(&request.request_id, output_specs), + ) + .await; + }) + }, + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + + let chat = ChatLlm::from_shared_backend( + test_llm(client), + Arc::new(FakeTextBackend) as Arc, + ); + let state = Arc::new(AppState::new(vec!["test-model".to_string()], chat)); + let svc = GenerateServer::new(GenerateServiceImpl::new(state)); + + // Bind to an OS-assigned port. + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.expect("bind grpc listener"); + let addr = listener.local_addr().expect("local addr"); + + let server_task = tokio::spawn(async move { + let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener); + TonicServer::builder() + .add_service(svc) + .serve_with_incoming(incoming) + .await + .expect("grpc server"); + }); + + // Connect the client. + let grpc_client = GenerateClient::connect(format!("http://{addr}")) + .await + .expect("connect grpc client"); + + (grpc_client, server_task, engine_task) +} + +// ======================================================================================== +// Tests +// ======================================================================================== + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn unary_generate_returns_collected_text() { + let (mut client, server_task, engine_task) = + grpc_test_server(b"engine-grpc-unary", default_stream_output_specs()).await; + + let response = client + .generate(pb::GenerateRequest { + request_id: "test-unary-1".to_string(), + model: "test-model".to_string(), + prompt: Some(pb::generate_request::Prompt::Text("hello".to_string())), + stopping: Some(pb::StoppingCriteria { + max_new_tokens: 10, + ..Default::default() + }), + response: Some(pb::ResponseOptions { + output_text: Some(true), + ..Default::default() + }), + ..Default::default() + }) + .await + .expect("unary generate") + .into_inner(); + + // Unary collects all tokens into one response. + let outputs = response.outputs.expect("outputs present"); + assert_eq!(outputs.text, "hi"); + + let finish = outputs.finish_info.expect("finish_info present"); + assert_eq!( + finish.finish_reason, + pb::finish_info::FinishReason::Stop as i32 + ); + assert_eq!(finish.num_output_tokens, 3); + + let prompt = response.prompt_info.expect("prompt_info present"); + assert_eq!(prompt.num_prompt_tokens, 5); // "hello" = 5 bytes + + engine_task.await.expect("mock engine task"); + server_task.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn unary_generate_with_token_ids_prompt() { + let (mut client, server_task, engine_task) = + grpc_test_server(b"engine-grpc-token-ids", default_stream_output_specs()).await; + + let response = client + .generate(pb::GenerateRequest { + request_id: "test-token-ids".to_string(), + model: "test-model".to_string(), + prompt: Some(pb::generate_request::Prompt::TokenIds(pb::TokenIds { + ids: vec![1, 2, 3], + })), + stopping: Some(pb::StoppingCriteria { + max_new_tokens: 10, + ..Default::default() + }), + ..Default::default() + }) + .await + .expect("unary generate with token ids") + .into_inner(); + + let outputs = response.outputs.expect("outputs present"); + assert_eq!(outputs.text, "hi"); + assert_eq!( + response.prompt_info.expect("prompt_info").num_prompt_tokens, + 3 + ); + + engine_task.await.expect("mock engine task"); + server_task.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn unary_generate_returns_token_ids_when_requested() { + let (mut client, server_task, engine_task) = + grpc_test_server(b"engine-grpc-tok-resp", default_stream_output_specs()).await; + + let response = client + .generate(pb::GenerateRequest { + request_id: "test-tok-resp".to_string(), + model: "test-model".to_string(), + prompt: Some(pb::generate_request::Prompt::Text("hi".to_string())), + stopping: Some(pb::StoppingCriteria { + max_new_tokens: 10, + ..Default::default() + }), + response: Some(pb::ResponseOptions { + output_text: Some(true), + output_token_ids: true, + prompt_token_ids: true, + ..Default::default() + }), + ..Default::default() + }) + .await + .expect("unary generate") + .into_inner(); + + let outputs = response.outputs.expect("outputs present"); + assert_eq!( + outputs.token_ids, + vec![b'h' as u32, b'i' as u32, b'!' as u32] + ); + + let prompt = response.prompt_info.expect("prompt_info present"); + assert_eq!(prompt.token_ids, vec![b'h' as u32, b'i' as u32]); + + engine_task.await.expect("mock engine task"); + server_task.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn unary_generate_missing_prompt_returns_invalid_argument() { + let (mut client, server_task, _engine_task) = + grpc_test_server(b"engine-grpc-no-prompt", default_stream_output_specs()).await; + + let status = client + .generate(pb::GenerateRequest { + request_id: "test-no-prompt".to_string(), + model: "test-model".to_string(), + prompt: None, + ..Default::default() + }) + .await + .expect_err("should fail without prompt"); + + assert_eq!(status.code(), tonic::Code::InvalidArgument); + assert!(status.message().contains("prompt")); + + server_task.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn streaming_generate_yields_incremental_responses() { + let (mut client, server_task, engine_task) = + grpc_test_server(b"engine-grpc-stream", default_stream_output_specs()).await; + + let stream = client + .generate_stream(pb::GenerateRequest { + request_id: "test-stream-1".to_string(), + model: "test-model".to_string(), + prompt: Some(pb::generate_request::Prompt::Text("hello".to_string())), + stopping: Some(pb::StoppingCriteria { + max_new_tokens: 10, + ..Default::default() + }), + response: Some(pb::ResponseOptions { + output_text: Some(true), + ..Default::default() + }), + ..Default::default() + }) + .await + .expect("streaming generate") + .into_inner(); + + let responses: Vec = + stream.map(|r| r.expect("stream item")).collect().await; + + // First response carries prompt info, subsequent ones carry output deltas. + assert!( + responses.len() >= 2, + "expected at least 2 streamed responses, got {}", + responses.len() + ); + + // First message should have prompt info. + let first = &responses[0]; + let prompt_info = first.prompt_info.as_ref().expect("first response has prompt_info"); + assert_eq!(prompt_info.num_prompt_tokens, 5); // "hello" + + // Collect all text deltas. + let full_text: String = responses + .iter() + .filter_map(|r| r.outputs.as_ref()) + .map(|o| o.text.as_str()) + .collect(); + assert_eq!(full_text, "hi"); + + // Last output response should have finish info. + let last_output = responses + .iter() + .rev() + .find_map(|r| r.outputs.as_ref()) + .expect("at least one output"); + let finish = last_output.finish_info.as_ref().expect("finish_info on last output"); + assert_eq!( + finish.finish_reason, + pb::finish_info::FinishReason::Stop as i32 + ); + + engine_task.await.expect("mock engine task"); + server_task.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn streaming_generate_missing_prompt_returns_invalid_argument() { + let (mut client, server_task, _engine_task) = grpc_test_server( + b"engine-grpc-stream-no-prompt", + default_stream_output_specs(), + ) + .await; + + let status = client + .generate_stream(pb::GenerateRequest { + request_id: "test-stream-no-prompt".to_string(), + model: "test-model".to_string(), + prompt: None, + ..Default::default() + }) + .await + .expect_err("should fail without prompt"); + + assert_eq!(status.code(), tonic::Code::InvalidArgument); + + server_task.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn unary_generate_with_sampling_params() { + let (mut client, server_task, engine_task) = + grpc_test_server(b"engine-grpc-sampling", default_stream_output_specs()).await; + + let response = client + .generate(pb::GenerateRequest { + request_id: "test-sampling".to_string(), + model: "test-model".to_string(), + prompt: Some(pb::generate_request::Prompt::Text("test".to_string())), + temperature: Some(0.7), + sampling: Some(pb::RandomSampling { + top_k: 50, + top_p: 0.9, + seed: Some(42), + ..Default::default() + }), + stopping: Some(pb::StoppingCriteria { + max_new_tokens: 5, + ..Default::default() + }), + ..Default::default() + }) + .await + .expect("generate with sampling params") + .into_inner(); + + // Verify the request was accepted and produced output. + let outputs = response.outputs.expect("outputs present"); + assert_eq!(outputs.text, "hi"); + + engine_task.await.expect("mock engine task"); + server_task.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn unary_generate_rejects_wrong_model() { + let (mut client, server_task, _engine_task) = + grpc_test_server(b"engine-grpc-wrong-model", default_stream_output_specs()).await; + + let status = client + .generate(pb::GenerateRequest { + request_id: "test-wrong-model".to_string(), + model: "other-model".to_string(), + prompt: Some(pb::generate_request::Prompt::Text("hi".to_string())), + stopping: Some(pb::StoppingCriteria { + max_new_tokens: 10, + ..Default::default() + }), + ..Default::default() + }) + .await + .expect_err("should fail with wrong model"); + + assert_eq!(status.code(), tonic::Code::NotFound); + assert!(status.message().contains("other-model")); + + server_task.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn streaming_generate_rejects_wrong_model() { + let (mut client, server_task, _engine_task) = grpc_test_server( + b"engine-grpc-stream-wrong-model", + default_stream_output_specs(), + ) + .await; + + let status = client + .generate_stream(pb::GenerateRequest { + request_id: "test-stream-wrong-model".to_string(), + model: "other-model".to_string(), + prompt: Some(pb::generate_request::Prompt::Text("hi".to_string())), + stopping: Some(pb::StoppingCriteria { + max_new_tokens: 10, + ..Default::default() + }), + ..Default::default() + }) + .await + .expect_err("should fail with wrong model"); + + assert_eq!(status.code(), tonic::Code::NotFound); + assert!(status.message().contains("other-model")); + + server_task.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn unary_generate_accepts_empty_model() { + let (mut client, server_task, engine_task) = + grpc_test_server(b"engine-grpc-empty-model", default_stream_output_specs()).await; + + // Empty `model` (proto3 default) is treated as "unset" and should be accepted. + let response = client + .generate(pb::GenerateRequest { + request_id: "test-empty-model".to_string(), + model: String::new(), + prompt: Some(pb::generate_request::Prompt::Text("hi".to_string())), + stopping: Some(pb::StoppingCriteria { + max_new_tokens: 10, + ..Default::default() + }), + ..Default::default() + }) + .await + .expect("unary generate with empty model") + .into_inner(); + + let outputs = response.outputs.expect("outputs present"); + assert_eq!(outputs.text, "hi"); + + engine_task.await.expect("mock engine task"); + server_task.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn unary_generate_output_text_defaults_to_true() { + let (mut client, server_task, engine_task) = + grpc_test_server(b"engine-grpc-default-text", default_stream_output_specs()).await; + + // No response options at all — output_text should default to true. + let response = client + .generate(pb::GenerateRequest { + request_id: "test-default-text".to_string(), + model: "test-model".to_string(), + prompt: Some(pb::generate_request::Prompt::Text("x".to_string())), + stopping: Some(pb::StoppingCriteria { + max_new_tokens: 10, + ..Default::default() + }), + ..Default::default() + }) + .await + .expect("unary generate") + .into_inner(); + + let outputs = response.outputs.expect("outputs present"); + assert_eq!(outputs.text, "hi"); + + engine_task.await.expect("mock engine task"); + server_task.abort(); +} diff --git a/rust/src/server/src/lib.rs b/rust/src/server/src/lib.rs new file mode 100644 index 00000000000..2b684287ba2 --- /dev/null +++ b/rust/src/server/src/lib.rs @@ -0,0 +1,237 @@ +//! Minimal OpenAI-compatible HTTP server above [`vllm_chat`]. + +mod config; +mod error; +mod grpc; +mod listener; +mod middleware; +mod routes; +mod state; +mod utils; + +use std::sync::{Arc, OnceLock}; + +use anyhow::{Context as _, Result}; +use axum::serve::ListenerExt as _; +pub use config::{Config, CoordinatorMode, HttpListenerMode}; +use tokio::net::TcpListener; +use tokio::time::{Instant, sleep_until}; +use tokio_stream::wrappers::TcpListenerStream; +use tokio_util::either::Either; +use tokio_util::sync::CancellationToken; +use tonic::transport::Server as TonicServer; +use tracing::{info, trace, warn}; +use vllm_chat::{ChatLlm, LoadModelBackendsOptions, load_model_backends}; +pub use vllm_chat::{ChatTemplateContentFormatOption, ParserSelection, RendererSelection}; +use vllm_engine_core_client::{EngineCoreClient, EngineCoreClientConfig}; +use vllm_llm::Llm; +use vllm_text::TextLlm; + +use crate::listener::Listener; +use crate::routes::build_router; +use crate::state::AppState; + +/// Build the shared application state for one configured model and one engine +/// client. +async fn build_state(config: &Config) -> Result> { + // Load both backends from the same model metadata so they stay in sync. + let loaded = load_model_backends( + &config.model, + LoadModelBackendsOptions { + renderer: config.renderer, + chat_template: config.chat_template.clone(), + chat_template_content_format: config.chat_template_content_format, + default_chat_template_kwargs: config + .default_chat_template_kwargs + .clone() + .unwrap_or_default(), + }, + ) + .await + .context("failed to create chat/text backends")?; + let text_backend = loaded.text_backend; + let chat_backend = loaded.chat_backend; + + let coordinator_mode = config.effective_coordinator_mode(text_backend.is_moe()); + info!( + engine_count = config.engine_count(), + model_is_moe = text_backend.is_moe(), + ?coordinator_mode, + "resolved coordinator mode" + ); + + let client = EngineCoreClient::connect(EngineCoreClientConfig { + transport_mode: config.transport_mode.clone(), + coordinator_mode, + model_name: config.model.clone(), + client_index: 0, + }) + .await + .context("failed to connect to engine core")?; + + let llm = Llm::new(client).with_log_stats(!config.disable_log_stats); + let text = TextLlm::new(llm, text_backend); + + let chat = ChatLlm::new(text, chat_backend) + .with_tool_call_parser(config.tool_call_parser.clone()) + .with_reasoning_parser(config.reasoning_parser.clone()); + + // If no served names are specified, fall back to the backend model path so + // that the API always has at least one valid model ID. + let served_model_names = if config.served_model_name.is_empty() { + vec![config.model.clone()] + } else { + config.served_model_name.clone() + }; + + Ok(Arc::new( + AppState::new(served_model_names, chat).with_log_requests(config.enable_log_requests), + )) +} + +/// Run the OpenAI-compatible HTTP server until the supplied shutdown token is +/// cancelled. +/// +/// The server owns one `vllm-chat` facade, which in turn owns the lower +/// `vllm-text` and `vllm-llm` layers, and shuts them down before returning. +pub async fn serve(config: Config, shutdown: CancellationToken) -> Result<()> { + config.validate().context("invalid OpenAI frontend configuration")?; + + // Also check shutdown during the (potentially long) startup handshake. + let state = tokio::select! { + result = build_state(&config) => result?, + _ = shutdown.cancelled() => return Ok(()), + }; + let listener = Listener::bind(&config.listener_mode) + .await + .context("failed to bind listener for OpenAI server")?; + let bind_address = listener.local_addr()?; + let model = state.primary_model_name().to_owned(); + let app = build_router(state.clone()); + + // Optionally bind the gRPC Generate server on a separate port. Bind + // synchronously here so bind errors (port in use, permission denied, ...) + // surface before we start serving, rather than being deferred until + // shutdown. The gRPC listener follows the same host as the HTTP listener so + // that enabling --grpc-port does not accidentally expose the service on all + // interfaces when HTTP is intentionally local-only. + let grpc_setup = if let Some(grpc_port) = config.grpc_port { + let grpc_host = match &config.listener_mode { + HttpListenerMode::BindTcp { host, .. } => host.as_str(), + HttpListenerMode::BindUnix { .. } | HttpListenerMode::InheritedFd { .. } => "0.0.0.0", + }; + let grpc_listener = TcpListener::bind((grpc_host, grpc_port)) + .await + .with_context(|| format!("failed to bind gRPC listener on {grpc_host}:{grpc_port}"))?; + let addr = grpc_listener.local_addr()?; + let svc = grpc::GenerateServer::new(grpc::GenerateServiceImpl::new(state.clone())); + info!(%addr, "starting gRPC server"); + Some((grpc_listener, svc)) + } else { + None + }; + + info!(%bind_address, %model, "starting OpenAI server"); + + // Set TCP_NODELAY on accepted connections to reduce latency. + // By `tap_io` we will do this on every accepted connection. + let listener = listener.tap_io(|io| { + if let Either::Left(tcp_stream) = io + && let Err(err) = tcp_stream.set_nodelay(true) + { + trace!(error = %err, "failed to enable TCP_NODELAY on accepted HTTP connection"); + } + }); + + // Run HTTP and gRPC concurrently under a child token of the caller's shutdown + // token. Caller cancellation propagates into both protocols; if either + // protocol exits first, we cancel this child token so its sibling also + // begins a graceful drain. + let server_shutdown = shutdown.child_token(); + let force_shutdown = CancellationToken::new(); + let shutdown_deadline = Arc::new(OnceLock::new()); + + // Spawn a task to trigger `force_shutdown` after shutdown deadline elapses. + tokio::spawn({ + let shutdown = server_shutdown.clone(); + let force_shutdown = force_shutdown.clone(); + let shutdown_deadline = shutdown_deadline.clone(); + let shutdown_timeout = config.shutdown_timeout; + + async move { + shutdown.cancelled().await; + let deadline = Instant::now() + shutdown_timeout; + let _ = shutdown_deadline.set(deadline); + + if shutdown_timeout.is_zero() { + force_shutdown.cancel(); + } else { + sleep_until(deadline).await; + force_shutdown.cancel(); + } + } + }); + + let http_fut = { + let shutdown = server_shutdown.child_token(); + let server_shutdown = server_shutdown.clone(); + let force_shutdown = force_shutdown.clone(); + async move { + let server = + axum::serve(listener, app).with_graceful_shutdown(shutdown.cancelled_owned()); + + let result = tokio::select! { + result = server => { + result.context("HTTP server failed") + } + _ = force_shutdown.cancelled() => { + warn!("HTTP graceful shutdown deadline elapsed; aborting server"); + Ok(()) + } + }; + + server_shutdown.cancel(); + result + } + }; + + let grpc_fut = { + let shutdown = server_shutdown.child_token(); + let server_shutdown = server_shutdown.clone(); + let force_shutdown = force_shutdown.clone(); + async move { + let Some((grpc_listener, svc)) = grpc_setup else { + // No gRPC configured: just wait for shutdown so we do not race the + // join! by resolving early and tripping the cancellation token. + shutdown.cancelled().await; + return Ok(()); + }; + let server = TonicServer::builder().add_service(svc).serve_with_incoming_shutdown( + TcpListenerStream::new(grpc_listener), + shutdown.cancelled_owned(), + ); + + let result = tokio::select! { + result = server => { + result.context("gRPC server failed") + } + _ = force_shutdown.cancelled() => { + warn!("gRPC graceful shutdown deadline elapsed; aborting server"); + Ok(()) + } + }; + + server_shutdown.cancel(); + result + } + }; + + let (http_res, grpc_res) = tokio::join!(http_fut, grpc_fut); + http_res.and(grpc_res)?; + + let shutdown_deadline = shutdown_deadline + .get() + .copied() + .unwrap_or_else(|| Instant::now() + config.shutdown_timeout); + state.shutdown(shutdown_deadline).await +} diff --git a/rust/src/server/src/listener.rs b/rust/src/server/src/listener.rs new file mode 100644 index 00000000000..b7b715b0ebd --- /dev/null +++ b/rust/src/server/src/listener.rs @@ -0,0 +1,135 @@ +//! Unified HTTP listener wrapper for the Rust frontend. +//! +//! This module hides the difference between TCP and Unix-domain listeners so +//! the rest of the server can bind or inherit one socket and pass it to +//! `axum::serve(...)` through a single type. + +use std::io::Result; +use std::net::TcpListener as StdTcpListener; +use std::os::fd::{FromRawFd, IntoRawFd, OwnedFd}; +use std::os::unix::net::UnixListener as StdUnixListener; + +use socket2::Socket; +use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream}; +use tokio_util::either::Either; + +use crate::HttpListenerMode; + +/// Runtime listener type used by the OpenAI-compatible HTTP server, which is +/// either a TCP listener or a Unix-domain listener. +#[derive(Debug)] +pub enum Listener { + Tcp(TcpListener), + Unix(UnixListener), +} + +impl Listener { + /// Bind or adopt the listener described by the frontend configuration. + /// + /// For inherited sockets, the concrete listener kind is detected from the + /// socket family of the supplied file descriptor. + pub async fn bind(mode: &HttpListenerMode) -> Result { + match mode { + HttpListenerMode::BindTcp { host, port } => { + Ok(Self::Tcp(TcpListener::bind((host.as_str(), *port)).await?)) + } + HttpListenerMode::BindUnix { path } => Ok(Self::Unix(UnixListener::bind(path)?)), + HttpListenerMode::InheritedFd { fd } => Self::from_inherited_fd(*fd), + } + } + + /// Return a log-friendly local address string for either TCP or Unix + /// sockets. + pub fn local_addr(&self) -> Result { + match self { + Self::Tcp(listener) => Ok(listener.local_addr()?.to_string()), + Self::Unix(listener) => Ok(match listener.local_addr()?.as_pathname() { + Some(path) => format!("unix:{}", path.display()), + None => "unix:".to_string(), + }), + } + } + + fn from_inherited_fd(fd: i32) -> Result { + // SAFETY: We trust the caller to only pass valid listener fds, and we only use + // this fd once to create a single listener. + let owned_fd = unsafe { OwnedFd::from_raw_fd(fd) }; + let socket = Socket::from(owned_fd); + + // The Python supervisor pre-binds the socket to reserve the endpoint early, but + // Rust is responsible for transitioning inherited stream sockets into + // the listening state before accepting connections. + socket.listen(libc::SOMAXCONN)?; + socket.set_nonblocking(true)?; + + if socket.local_addr()?.is_unix() { + let std_listener = unsafe { StdUnixListener::from_raw_fd(socket.into_raw_fd()) }; + Ok(Self::Unix(UnixListener::from_std(std_listener)?)) + } else { + let std_listener = unsafe { StdTcpListener::from_raw_fd(socket.into_raw_fd()) }; + Ok(Self::Tcp(TcpListener::from_std(std_listener)?)) + } + } +} + +/// Allow the unified listener to plug directly into `axum::serve(...)`. +impl axum::serve::Listener for Listener { + type Addr = Either; + type Io = Either; + + async fn accept(&mut self) -> (Self::Io, Self::Addr) { + match self { + Self::Tcp(listener) => { + let (io, addr) = listener.accept().await; + (Either::Left(io), Either::Left(addr)) + } + Self::Unix(listener) => { + let (io, addr) = listener.accept().await; + (Either::Right(io), Either::Right(addr)) + } + } + } + + fn local_addr(&self) -> Result { + match self { + Self::Tcp(listener) => listener.local_addr().map(Either::Left), + Self::Unix(listener) => listener.local_addr().map(Either::Right), + } + } +} + +#[cfg(test)] +mod tests { + use std::net::{Ipv4Addr, SocketAddrV4}; + use std::os::fd::IntoRawFd; + + use socket2::{Domain, SockAddr, Socket, Type}; + use uuid::Uuid; + + use super::Listener; + use crate::HttpListenerMode; + + #[tokio::test(flavor = "current_thread")] + async fn inherited_fd_detects_tcp_listener_without_uds_hint() { + let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap(); + socket.bind(&SockAddr::from(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))).unwrap(); + let fd = socket.into_raw_fd(); + + let listener = Listener::bind(&HttpListenerMode::InheritedFd { fd }).await.unwrap(); + + assert!(matches!(listener, Listener::Tcp(_))); + } + + #[tokio::test(flavor = "current_thread")] + async fn inherited_fd_detects_unix_listener_from_fd() { + let path = std::env::temp_dir().join(format!("vllm-rs-{}.sock", Uuid::new_v4())); + let socket = Socket::new(Domain::UNIX, Type::STREAM, None).unwrap(); + socket.bind(&SockAddr::unix(&path).unwrap()).unwrap(); + let fd = socket.into_raw_fd(); + + let listener = Listener::bind(&HttpListenerMode::InheritedFd { fd }).await.unwrap(); + + assert!(matches!(listener, Listener::Unix(_))); + let _ = std::fs::remove_file(path); + } +} diff --git a/rust/src/server/src/middleware/load.rs b/rust/src/server/src/middleware/load.rs new file mode 100644 index 00000000000..d03b36fb5fe --- /dev/null +++ b/rust/src/server/src/middleware/load.rs @@ -0,0 +1,110 @@ +use std::pin::Pin; +use std::sync::{Arc, Weak}; +use std::task::{Context, Poll}; + +use axum::body::{Body, Bytes, HttpBody}; +use axum::extract::{MatchedPath, Request, State}; +use axum::middleware::Next; +use axum::response::Response; +use http_body::{Frame, SizeHint}; + +use crate::state::AppState; + +/// Endpoints that will be tracked for server load. +/// +/// Derived from the Python frontend's actual `@load_aware_call` coverage. This +/// includes alias paths that delegate into decorated handlers, such as +/// `/v1/rerank` and `/v2/rerank`. +const TRACKED_HANDLERS: &[&str] = &[ + "/v1/responses", + "/v1/responses/{response_id}", + "/v1/responses/{response_id}/cancel", + "/v1/messages", + "/v1/messages/count_tokens", + "/v1/chat/completions", + "/v1/completions", + "/v1/audio/transcriptions", + "/v1/audio/translations", + "/v1/embeddings", + "/pooling", + "/classify", + "/score", + "/v1/score", + "/rerank", + "/v1/rerank", + "/v2/rerank", + "/inference/v1/generate", +]; + +/// Track frontend-local in-flight inference requests for the `/load` endpoint. +pub async fn track_server_load( + State(state): State>, + req: Request, + next: Next, +) -> Response { + let handler = req + .extensions() + .get::() + .map_or_else(|| "none", |path| path.as_str()); + + if !TRACKED_HANDLERS.contains(&handler) { + return next.run(req).await; + } + + state.increment_server_load(); + let guard = ServerLoadGuard { + state: Arc::downgrade(&state), + }; + let response = next.run(req).await; + + let (parts, body) = response.into_parts(); + Response::from_parts( + parts, + Body::new(LoadTrackedBody { + inner: body, + _guard: guard, + }), + ) +} + +/// A guard that decrements the server load when dropped. +struct ServerLoadGuard { + state: Weak, +} + +impl Drop for ServerLoadGuard { + fn drop(&mut self) { + if let Some(state) = self.state.upgrade() { + state.decrement_server_load(); + } + } +} + +/// A wrapper around response bodies that tracks server load by holding a +/// `ServerLoadGuard`, which will decrement the load when the body is fully +/// consumed and dropped. +struct LoadTrackedBody { + inner: Body, + _guard: ServerLoadGuard, +} + +// Simply delegate all `HttpBody` methods to the inner body. +impl HttpBody for LoadTrackedBody { + type Data = Bytes; + type Error = axum::Error; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + Pin::new(&mut self.inner).poll_frame(cx) + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + fn size_hint(&self) -> SizeHint { + self.inner.size_hint() + } +} diff --git a/rust/src/server/src/middleware/metrics.rs b/rust/src/server/src/middleware/metrics.rs new file mode 100644 index 00000000000..366d73dd3dd --- /dev/null +++ b/rust/src/server/src/middleware/metrics.rs @@ -0,0 +1,79 @@ +use std::time::Instant; + +use axum::extract::{MatchedPath, Request}; +use axum::middleware::Next; +use axum::response::Response; +use vllm_metrics::{HttpHandlerLabels, HttpRequestLabels, METRICS}; + +/// Endpoints that will be excluded from HTTP metrics tracking. +/// +/// Original Python definition: +/// +const EXCLUDED_HANDLERS: &[&str] = &[ + "/metrics", + "/health", + "/load", + "/ping", + "/version", + "/server_info", + // Rust frontend extra: + "/reset_prefix_cache", + "/reset_mm_cache", + "/reset_encoder_cache", + "/collective_rpc", + "/sleep", + "/wake_up", + "/is_sleeping", +]; + +/// Record API-server HTTP metrics with Python-compatible +/// (`PrometheusFastApiInstrumentator` style) family names and labels. +pub async fn track_http_metrics(req: Request, next: Next) -> Response { + let method = req.method().as_str().to_string(); + let handler = req + .extensions() + .get::() + .map_or_else(|| "none".to_string(), |path| path.as_str().to_string()); + let excluded = EXCLUDED_HANDLERS.contains(&handler.as_str()); + let started_at = Instant::now(); + + let response = next.run(req).await; + + if excluded { + return response; + } + + let elapsed = started_at.elapsed().as_secs_f64(); + let status = status_group(response.status().as_u16()); + + let metrics = &METRICS.api_server; + + metrics + .http_requests + .get_or_create(&HttpRequestLabels { + method: method.clone(), + status, + handler: handler.clone(), + }) + .inc(); + + metrics + .http_request_duration_seconds + .get_or_create(&HttpHandlerLabels { method, handler }) + .observe(elapsed); + + metrics.http_request_duration_highr_seconds.observe(elapsed); + + response +} + +fn status_group(status: u16) -> &'static str { + match status / 100 { + 1 => "1xx", + 2 => "2xx", + 3 => "3xx", + 4 => "4xx", + 5 => "5xx", + _ => "unknown", + } +} diff --git a/rust/src/server/src/middleware/mod.rs b/rust/src/server/src/middleware/mod.rs new file mode 100644 index 00000000000..acb3dd1fdb7 --- /dev/null +++ b/rust/src/server/src/middleware/mod.rs @@ -0,0 +1,5 @@ +mod load; +mod metrics; + +pub use load::track_server_load; +pub use metrics::track_http_metrics; diff --git a/rust/src/server/src/routes.rs b/rust/src/server/src/routes.rs new file mode 100644 index 00000000000..ccf90db9aa8 --- /dev/null +++ b/rust/src/server/src/routes.rs @@ -0,0 +1,68 @@ +mod cache; +mod collective_rpc; +mod health; +mod inference; +mod load; +mod metrics; +pub(crate) mod openai; +mod sleep; + +use std::sync::Arc; + +use axum::Router; +use axum::middleware::{from_fn, from_fn_with_state}; +use axum::routing::{get, post}; +use tower_http::trace::TraceLayer; + +use crate::middleware; +use crate::state::AppState; + +fn server_dev_mode_enabled() -> bool { + std::env::var("VLLM_SERVER_DEV_MODE") + .ok() + .and_then(|value| value.parse::().ok()) + .is_some_and(|value| value != 0) +} + +/// Build the minimal OpenAI-compatible router for one configured model. +pub fn build_router(state: Arc) -> Router { + build_router_with_dev_mode(state, server_dev_mode_enabled()) +} + +fn build_router_with_dev_mode(state: Arc, dev_mode_enabled: bool) -> Router { + let mut router = Router::new() + // Health & monitoring + .route("/health", get(health::health)) + .route("/metrics", get(metrics::scrape)) + .route("/load", get(load::load)) + // OpenAI-compatible endpoints + .route("/v1/models", get(openai::list_models)) + .route("/v1/completions", post(openai::completions)) + .route("/v1/chat/completions", post(openai::chat_completions)) + // vLLM specific inference endpoints + .route("/inference/v1/generate", post(inference::generate)); + + if dev_mode_enabled { + // Development-only + router = router + .route("/reset_prefix_cache", post(cache::reset_prefix_cache)) + .route("/reset_mm_cache", post(cache::reset_mm_cache)) + .route("/reset_encoder_cache", post(cache::reset_encoder_cache)) + .route("/collective_rpc", post(collective_rpc::collective_rpc)) + .route("/sleep", post(sleep::sleep)) + .route("/wake_up", post(sleep::wake_up)) + .route("/is_sleeping", get(sleep::is_sleeping)) + } + + router + .with_state(state.clone()) + .layer(from_fn_with_state(state, middleware::track_server_load)) + .layer(from_fn(middleware::track_http_metrics)) + .layer(TraceLayer::new_for_http()) +} + +#[cfg(test)] +mod tests; + +#[cfg(test)] +mod http_client_tests; diff --git a/rust/src/server/src/routes/cache.rs b/rust/src/server/src/routes/cache.rs new file mode 100644 index 00000000000..580b91d4131 --- /dev/null +++ b/rust/src/server/src/routes/cache.rs @@ -0,0 +1,56 @@ +use std::sync::Arc; + +use axum::extract::{Query, State}; +use axum::http::StatusCode; +use serde::Deserialize; + +use crate::error::ApiError; +use crate::state::AppState; +use crate::utils::utility_call_error; + +#[derive(Debug, Default, Deserialize)] +pub(crate) struct ResetPrefixCacheParams { + #[serde(default)] + reset_running_requests: bool, + #[serde(default)] + reset_external: bool, +} + +/// Reset the local prefix cache and optionally the connector-managed external +/// cache. +pub async fn reset_prefix_cache( + State(state): State>, + Query(params): Query, +) -> Result { + state + .engine_core_client() + .reset_prefix_cache(params.reset_running_requests, params.reset_external) + .await + .map_err(|error| utility_call_error("reset_prefix_cache", error))?; + + Ok(StatusCode::OK) +} + +/// Reset the multi-modal cache. +pub async fn reset_mm_cache(State(state): State>) -> Result { + state + .engine_core_client() + .reset_mm_cache() + .await + .map_err(|error| utility_call_error("reset_mm_cache", error))?; + + Ok(StatusCode::OK) +} + +/// Reset the encoder cache. +pub async fn reset_encoder_cache( + State(state): State>, +) -> Result { + state + .engine_core_client() + .reset_encoder_cache() + .await + .map_err(|error| utility_call_error("reset_encoder_cache", error))?; + + Ok(StatusCode::OK) +} diff --git a/rust/src/server/src/routes/collective_rpc.rs b/rust/src/server/src/routes/collective_rpc.rs new file mode 100644 index 00000000000..965797f3019 --- /dev/null +++ b/rust/src/server/src/routes/collective_rpc.rs @@ -0,0 +1,51 @@ +use std::collections::BTreeMap; +use std::sync::Arc; + +use axum::Json; +use axum::extract::State; +use axum::extract::rejection::JsonRejection; +use rmpv::Value as MsgpackValue; +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; + +use crate::error::ApiError; +use crate::state::AppState; +use crate::utils::utility_call_error; + +#[derive(Debug, Deserialize)] +pub(crate) struct CollectiveRpcRequest { + method: Option, + #[serde(default)] + timeout: Option, + #[serde(default)] + args: Vec, + #[serde(default)] + kwargs: BTreeMap, +} + +#[derive(Debug, Serialize)] +pub(crate) struct CollectiveRpcResponse { + results: Vec, +} + +/// Execute a development-only collective RPC on the connected engine(s). +pub async fn collective_rpc( + State(state): State>, + body: Result, JsonRejection>, +) -> Result, ApiError> { + let Json(body) = body.map_err(|error| ApiError::json_parse_error(error.body_text()))?; + let method = body.method.ok_or_else(|| { + ApiError::invalid_request( + "Missing 'method' in request body".to_string(), + Some("method"), + ) + })?; + + let results = state + .engine_core_client() + .collective_rpc(&method, body.timeout, body.args, body.kwargs) + .await + .map_err(|error| utility_call_error("collective_rpc", error))?; + + Ok(Json(CollectiveRpcResponse { results })) +} diff --git a/rust/src/server/src/routes/health.rs b/rust/src/server/src/routes/health.rs new file mode 100644 index 00000000000..ff91f857007 --- /dev/null +++ b/rust/src/server/src/routes/health.rs @@ -0,0 +1,14 @@ +use std::sync::Arc; + +use axum::extract::State; +use axum::http::StatusCode; + +use crate::state::AppState; + +pub async fn health(State(state): State>) -> StatusCode { + if state.chat.engine_core_client().is_healthy() { + StatusCode::OK + } else { + StatusCode::SERVICE_UNAVAILABLE + } +} diff --git a/rust/src/server/src/routes/http_client_tests.rs b/rust/src/server/src/routes/http_client_tests.rs new file mode 100644 index 00000000000..8055ada9794 --- /dev/null +++ b/rust/src/server/src/routes/http_client_tests.rs @@ -0,0 +1,392 @@ +//! Integration tests that exercise the OpenAI-compatible HTTP API through a +//! real TCP connection using the `async-openai` client library, backed by a +//! mock engine. + +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use async_openai::Client; +use async_openai::config::OpenAIConfig; +use async_openai::types::chat::{ + ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, +}; +use futures::StreamExt as _; +use serial_test::serial; +use vllm_chat::{ + ChatBackend, ChatLlm, ChatRenderer, ChatRequest, ChatTextBackend, DefaultChatOutputProcessor, + DynChatOutputProcessor, DynChatRenderer, NewChatOutputProcessorOptions, RenderedPrompt, +}; +use vllm_engine_core_client::protocol::{ + EngineCoreFinishReason, EngineCoreOutput, EngineCoreOutputs, EngineCoreRequest, +}; +use vllm_engine_core_client::test_utils::{IpcNamespace, spawn_mock_engine_task}; +use vllm_engine_core_client::{EngineCoreClient, EngineCoreClientConfig, EngineId}; +use vllm_llm::Llm; +use vllm_text::tokenizer::{DynTokenizer, Tokenizer}; +use vllm_text::{Prompt, TextBackend}; +use zeromq::prelude::{SocketRecv, SocketSend}; +use zeromq::{DealerSocket, PushSocket, ZmqMessage}; + +use crate::routes::build_router; +use crate::state::AppState; + +// ======================================================================================== +// Test infrastructure (mirrors routes/tests.rs helpers) +// ======================================================================================== + +type TestFuture<'a> = Pin + Send + 'a>>; + +fn boxed_test_future<'a>(future: impl Future + Send + 'a) -> TestFuture<'a> { + Box::pin(future) +} + +struct MockEngineTask { + shutdown_tx: Option>, + join_handle: Option>, +} + +impl MockEngineTask { + fn new( + (shutdown_tx, join_handle): ( + tokio::sync::oneshot::Sender<()>, + tokio::task::JoinHandle<()>, + ), + ) -> Self { + Self { + shutdown_tx: Some(shutdown_tx), + join_handle: Some(join_handle), + } + } +} + +impl Future for MockEngineTask { + type Output = Result<(), tokio::task::JoinError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Some(shutdown_tx) = self.shutdown_tx.take() { + let _ = shutdown_tx.send(()); + } + match self.join_handle.as_mut() { + Some(join_handle) => Pin::new(join_handle).poll(cx), + None => Poll::Ready(Ok(())), + } + } +} + +impl Drop for MockEngineTask { + fn drop(&mut self) { + if let Some(join_handle) = &self.join_handle { + join_handle.abort(); + } + } +} + +fn request_output( + request_id: &str, + new_token_ids: Vec, + finish_reason: Option, +) -> EngineCoreOutput { + EngineCoreOutput { + request_id: request_id.to_string(), + new_token_ids, + new_logprobs: None, + new_prompt_logprobs_tensors: None, + pooling_output: None, + finish_reason, + stop_reason: None, + events: None, + kv_transfer_params: None, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + } +} + +fn engine_outputs_for_request( + request_id: &str, + output_specs: Vec<(Vec, Option)>, +) -> EngineCoreOutputs { + EngineCoreOutputs { + engine_index: 0, + outputs: output_specs + .into_iter() + .map(|(token_ids, finish_reason)| request_output(request_id, token_ids, finish_reason)) + .collect(), + scheduler_stats: None, + timestamp: 0.0, + utility_output: None, + finished_requests: None, + wave_complete: None, + start_wave: None, + } +} + +fn default_stream_output_specs() -> Vec<(Vec, Option)> { + vec![ + (vec![b'h' as u32], None), + (vec![b'i' as u32], None), + (vec![b'!' as u32], Some(EngineCoreFinishReason::Stop)), + ] +} + +async fn send_outputs(push: &mut PushSocket, outputs: EngineCoreOutputs) { + push.send(ZmqMessage::from( + rmp_serde::to_vec_named(&outputs).expect("encode outputs"), + )) + .await + .expect("send outputs"); +} + +async fn recv_engine_message(dealer: &mut DealerSocket) -> Vec { + dealer.recv().await.expect("recv engine message").into_vec() +} + +fn test_llm(client: EngineCoreClient) -> Llm { + Llm::new(client).with_request_id_randomization(false) +} + +#[derive(Clone, Debug)] +struct FakeChatBackend; + +#[derive(Debug)] +struct FakeChatTokenizer; + +impl Tokenizer for FakeChatTokenizer { + fn encode( + &self, + text: &str, + _add_special_tokens: bool, + ) -> vllm_text::tokenizer::Result> { + Ok(text.bytes().map(u32::from).collect()) + } + + fn decode( + &self, + token_ids: &[u32], + _skip_special_tokens: bool, + ) -> vllm_text::tokenizer::Result { + Ok( + String::from_utf8_lossy(&token_ids.iter().map(|id| *id as u8).collect::>()) + .into_owned(), + ) + } + + fn token_to_id(&self, token: &str) -> Option { + token.bytes().next().map(u32::from) + } +} + +impl TextBackend for FakeChatBackend { + fn tokenizer(&self) -> DynTokenizer { + Arc::new(FakeChatTokenizer) + } + + fn model_id(&self) -> &str { + "test-model" + } +} + +impl ChatBackend for FakeChatBackend { + fn chat_renderer(&self) -> DynChatRenderer { + Arc::new(self.clone()) + } + + fn new_chat_output_processor( + &self, + request: &mut ChatRequest, + options: NewChatOutputProcessorOptions<'_>, + ) -> vllm_chat::Result { + Ok(Box::new(DefaultChatOutputProcessor::new( + request, + self.model_id(), + self.tokenizer(), + options.tool_call_parser, + options.reasoning_parser, + )?)) + } +} + +impl ChatRenderer for FakeChatBackend { + fn render(&self, request: &ChatRequest) -> vllm_chat::Result { + let mut prompt = String::new(); + for message in &request.messages { + prompt.push_str(message.role().as_str()); + prompt.push_str(": "); + prompt.push_str(&message.text_content()?); + prompt.push('\n'); + } + if request.chat_options.add_generation_prompt() { + prompt.push_str("assistant:"); + } + Ok(RenderedPrompt { + prompt: Prompt::Text(prompt), + }) + } +} + +/// Spin up an HTTP server on a random port backed by a mock engine. +/// Returns the `async-openai` client, the HTTP server task, and the mock engine +/// task. +async fn http_test_server( + engine_id: impl Into, + output_specs: Vec<(Vec, Option)>, +) -> ( + Client, + tokio::task::JoinHandle<()>, + MockEngineTask, +) { + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = engine_id.into(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + move |dealer, push| { + boxed_test_future(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = + rmp_serde::from_slice(&add[1]).expect("decode request"); + send_outputs( + push, + engine_outputs_for_request(&request.request_id, output_specs), + ) + .await; + }) + }, + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + + let chat = ChatLlm::from_shared_backend( + test_llm(client), + Arc::new(FakeChatBackend) as Arc, + ); + let state = Arc::new(AppState::new(vec!["test-model".to_string()], chat)); + let app = build_router(state); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.expect("bind http listener"); + let addr = listener.local_addr().expect("local addr"); + + let server_task = tokio::spawn(async move { + axum::serve(listener, app).await.expect("http server"); + }); + + let openai_client = Client::with_config( + OpenAIConfig::new() + .with_api_key("unused") + .with_api_base(format!("http://{addr}/v1")), + ); + + (openai_client, server_task, engine_task) +} + +// ======================================================================================== +// Tests +// ======================================================================================== + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn list_models_via_http_client() { + let (client, server_task, _engine_task) = + http_test_server(b"engine-http-models", default_stream_output_specs()).await; + + let models = client.models().list().await.expect("list models"); + let model_ids: Vec<&str> = models.data.iter().map(|m| m.id.as_str()).collect(); + assert_eq!(model_ids, vec!["test-model"]); + + server_task.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_streaming_chat_via_http_client() { + let (client, server_task, engine_task) = + http_test_server(b"engine-http-chat", default_stream_output_specs()).await; + + let request = CreateChatCompletionRequestArgs::default() + .model("test-model") + .stream(false) + .max_completion_tokens(10u32) + .messages([ChatCompletionRequestUserMessageArgs::default() + .content("hello") + .build() + .expect("build user message") + .into()]) + .build() + .expect("build request"); + + let response = client.chat().create(request).await.expect("chat completion"); + + assert_eq!(response.model, "test-model"); + assert_eq!(response.choices.len(), 1); + let choice = &response.choices[0]; + // The stop token `!` is suppressed from text. + assert_eq!(choice.message.content.as_deref(), Some("hi")); + assert_eq!( + choice.finish_reason, + Some(async_openai::types::chat::FinishReason::Stop) + ); + + engine_task.await.expect("mock engine task"); + server_task.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn streaming_chat_via_http_client() { + let (client, server_task, engine_task) = + http_test_server(b"engine-http-stream", default_stream_output_specs()).await; + + let request = CreateChatCompletionRequestArgs::default() + .model("test-model") + .stream(true) + .max_completion_tokens(10u32) + .messages([ChatCompletionRequestUserMessageArgs::default() + .content("hello") + .build() + .expect("build user message") + .into()]) + .build() + .expect("build request"); + + let mut stream = client.chat().create_stream(request).await.expect("streaming chat completion"); + + let mut full_text = String::new(); + let mut saw_role = false; + let mut saw_finish_reason = false; + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk"); + for choice in &chunk.choices { + if choice.delta.role.is_some() { + saw_role = true; + } + if let Some(ref delta) = choice.delta.content { + full_text.push_str(delta); + } + if choice.finish_reason.is_some() { + saw_finish_reason = true; + } + } + } + + assert!(saw_role, "expected an assistant role chunk"); + assert!(saw_finish_reason, "expected a terminal finish reason"); + assert_eq!(full_text, "hi"); + + engine_task.await.expect("mock engine task"); + server_task.abort(); +} diff --git a/rust/src/server/src/routes/inference/generate.rs b/rust/src/server/src/routes/inference/generate.rs new file mode 100644 index 00000000000..ff7ea3c6302 --- /dev/null +++ b/rust/src/server/src/routes/inference/generate.rs @@ -0,0 +1,215 @@ +mod convert; +mod types; +mod validate; + +use std::collections::HashMap; +use std::sync::Arc; + +use axum::Json; +use axum::extract::State; +use axum::http::HeaderMap; +use axum::response::{IntoResponse, Response}; +use thiserror_ext::AsReport as _; +use tracing::info; +use tracing_futures::Instrument as _; +use vllm_engine_core_client::protocol::logprobs::{Logprobs, PositionLogprobs}; +use vllm_llm::{CollectedGenerateOutput, GenerateOutputStreamExt as _}; + +use self::convert::prepare_generate_request; +use self::types::{GenerateLogprob, GenerateRequest, GenerateResponse, GenerateResponseChoice}; +use crate::error::{ApiError, server_error}; +use crate::routes::openai::utils::logprobs::clamp_logprob; +use crate::routes::openai::utils::types::{ChatLogProbs, ChatLogProbsContent, TopLogProb}; +use crate::routes::openai::utils::validated_json::ValidatedJson; +use crate::state::AppState; +use crate::utils::resolve_request_context; + +/// Validate one token-in/token-out request and proxy it into the shared +/// `vllm-text` stack. +pub async fn generate( + State(state): State>, + headers: HeaderMap, + ValidatedJson(body): ValidatedJson, +) -> Response { + let request_context = resolve_request_context(&headers, body.request_id.as_deref()); + let prepared = match prepare_generate_request(body, state.served_model_names(), request_context) + { + Ok(prepared) => prepared, + Err(error) => return error.into_response(), + }; + let request_span = tracing::info_span!( + "generate", + request_id = %prepared.request_id, + engine_request_id = tracing::field::Empty, + ); + + let log_request = state.enable_log_requests; + let include_logprobs = prepared.include_logprobs; + let include_prompt_logprobs = prepared.include_prompt_logprobs; + + let raw_stream = match state + .chat + .text() + .generate_raw(prepared.text_request) + .instrument(request_span.clone()) + .await + { + Ok(stream) => stream, + Err(error) => { + return server_error!( + "failed to submit raw generate request: {}", + error.to_report_string() + ) + .into_response(); + } + }; + + let collected = match raw_stream.collect_output().instrument(request_span.clone()).await { + Ok(collected) => collected, + Err(error) => { + return server_error!( + "failed to collect raw generate response: {}", + error.to_report_string() + ) + .into_response(); + } + }; + + if log_request { + info!( + parent: &request_span, + prompt_tokens = collected.prompt_token_ids.len(), + output_tokens = collected.token_ids.len(), + finish_reason = collected.finish_reason.as_str(), + "generate finished" + ); + } + + let response = match collect_generate( + collected, + prepared.request_id, + include_logprobs, + include_prompt_logprobs, + ) { + Ok(response) => response, + Err(error) => return error.into_response(), + }; + + Json(response).into_response() +} + +fn collect_generate( + collected: CollectedGenerateOutput, + request_id: String, + include_logprobs: bool, + include_prompt_logprobs: bool, +) -> Result { + let logprobs = if include_logprobs { + let logprobs = collected.logprobs.as_ref().ok_or_else(|| { + ApiError::server_error( + "raw generate response requested logprobs but generation returned none".to_string(), + ) + })?; + Some(raw_logprobs_to_openai_chat(logprobs)?) + } else { + None + }; + let prompt_logprobs = if include_prompt_logprobs { + let prompt_logprobs = collected.prompt_logprobs.as_ref().ok_or_else(|| { + ApiError::server_error( + "raw generate response requested prompt_logprobs but generation returned none" + .to_string(), + ) + })?; + Some(raw_prompt_logprobs_to_maps(prompt_logprobs)) + } else { + None + }; + + Ok(GenerateResponse { + request_id, + choices: vec![GenerateResponseChoice { + index: 0, + logprobs, + finish_reason: Some(collected.finish_reason.as_str().to_string()), + token_ids: collected.token_ids, + }], + prompt_logprobs, + kv_transfer_params: collected.kv_transfer_params, + }) +} + +fn raw_logprobs_to_openai_chat(logprobs: &Logprobs) -> Result { + let content = logprobs + .positions + .iter() + .map(position_to_chat_logprobs_content) + .collect::, _>>()?; + + Ok(ChatLogProbs { + content: Some(content), + }) +} + +fn raw_prompt_logprobs_to_maps( + prompt_logprobs: &Logprobs, +) -> Vec>> { + std::iter::once(None) + .chain( + prompt_logprobs + .positions + .iter() + .map(|position| Some(position_to_logprob_map(position))), + ) + .collect() +} + +fn position_to_chat_logprobs_content( + position: &PositionLogprobs, +) -> Result { + let chosen = position.entries.first().ok_or_else(|| { + ApiError::server_error( + "raw generate logprobs position unexpectedly had no token candidates".to_string(), + ) + })?; + let token = format_token_id(chosen.token_id); + + Ok(ChatLogProbsContent { + token: token.clone(), + logprob: clamp_logprob(chosen.logprob), + bytes: Some(token.as_bytes().to_vec()), + top_logprobs: position + .entries + .iter() + .map(|entry| { + let token = format_token_id(entry.token_id); + TopLogProb { + token: token.clone(), + logprob: clamp_logprob(entry.logprob), + bytes: Some(token.into_bytes()), + } + }) + .collect(), + }) +} + +fn position_to_logprob_map(position: &PositionLogprobs) -> HashMap { + position + .entries + .iter() + .map(|entry| { + ( + entry.token_id, + GenerateLogprob { + logprob: clamp_logprob(entry.logprob), + rank: Some(entry.rank), + decoded_token: Some(format_token_id(entry.token_id)), + }, + ) + }) + .collect() +} + +fn format_token_id(token_id: u32) -> String { + format!("token_id:{token_id}") +} diff --git a/rust/src/server/src/routes/inference/generate/convert.rs b/rust/src/server/src/routes/inference/generate/convert.rs new file mode 100644 index 00000000000..70374c2f1eb --- /dev/null +++ b/rust/src/server/src/routes/inference/generate/convert.rs @@ -0,0 +1,112 @@ +use vllm_text::{Prompt, TextDecodeOptions, TextRequest}; + +use super::types::GenerateRequest; +use super::validate; +use crate::error::ApiError; +use crate::utils::{ResolvedRequestContext, merge_kv_transfer_params}; + +/// Lowered generate request plus the response request ID. +#[derive(Debug, Clone, PartialEq)] +pub struct PreparedRequest { + pub request_id: String, + pub text_request: TextRequest, + pub include_logprobs: bool, + pub include_prompt_logprobs: bool, +} + +/// Validate and lower one raw generate request into the internal +/// text-generation format. +pub fn prepare_generate_request( + request: GenerateRequest, + served_model_names: &[String], + ctx: ResolvedRequestContext, +) -> Result { + validate::validate_request_compat(&request, served_model_names)?; + + let include_logprobs = request.sampling_params.logprobs.is_some(); + let include_prompt_logprobs = request.sampling_params.prompt_logprobs.is_some(); + let mut sampling_params = request.sampling_params; + sampling_params.vllm_xargs = merge_kv_transfer_params( + sampling_params.vllm_xargs, + request.kv_transfer_params.as_ref(), + ); + + let text_request = TextRequest { + request_id: ctx.request_id.clone(), + prompt: Prompt::TokenIds(request.token_ids), + mm_features: None, + sampling_params, + decode_options: TextDecodeOptions::default(), + intermediate: false, + priority: request.priority, + cache_salt: request.cache_salt, + add_special_tokens: false, + data_parallel_rank: ctx.data_parallel_rank, + }; + + Ok(PreparedRequest { + request_id: ctx.request_id, + text_request, + include_logprobs, + include_prompt_logprobs, + }) +} + +#[cfg(test)] +mod tests { + use serde_json::json; + use vllm_text::Prompt; + + use super::prepare_generate_request; + use crate::routes::inference::generate::types::GenerateRequest; + use crate::utils::ResolvedRequestContext; + + #[test] + fn prepare_generate_request_maps_token_prompt_and_sampling_params() { + let request: GenerateRequest = serde_json::from_value(json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "token_ids": [11, 22, 33], + "priority": -3, + "cache_salt": "salt", + "sampling_params": { + "max_tokens": 7, + "logprobs": 2, + "prompt_logprobs": 1, + "ignore_eos": true + }, + "kv_transfer_params": { + "connector": "x" + } + })) + .expect("parse request"); + + let prepared = prepare_generate_request( + request, + &["Qwen/Qwen1.5-0.5B-Chat".to_string()], + ResolvedRequestContext::default(), + ) + .expect("prepare"); + + assert_eq!( + prepared.text_request.prompt, + Prompt::TokenIds(vec![11, 22, 33]) + ); + assert_eq!(prepared.text_request.sampling_params.max_tokens, Some(7)); + assert_eq!(prepared.text_request.sampling_params.logprobs, Some(2)); + assert_eq!( + prepared.text_request.sampling_params.prompt_logprobs, + Some(1) + ); + assert!(prepared.text_request.sampling_params.ignore_eos); + assert_eq!(prepared.text_request.priority, -3); + assert_eq!(prepared.text_request.cache_salt.as_deref(), Some("salt")); + assert_eq!( + prepared + .text_request + .sampling_params + .vllm_xargs + .and_then(|mut xargs| xargs.remove("kv_transfer_params")), + Some(json!({"connector": "x"})) + ); + } +} diff --git a/rust/src/server/src/routes/inference/generate/types.rs b/rust/src/server/src/routes/inference/generate/types.rs new file mode 100644 index 00000000000..de7a196c3c6 --- /dev/null +++ b/rust/src/server/src/routes/inference/generate/types.rs @@ -0,0 +1,57 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; +use validator::Validate; +use vllm_text::SamplingParams; + +use crate::routes::openai::utils::types::{ChatLogProbs, Normalizable}; + +/// vLLM-compatible request type for the token-in/token-out generate API. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Deserialize, Serialize, Validate)] +pub struct GenerateRequest { + pub request_id: Option, + pub model: Option, + pub token_ids: Vec, + pub sampling_params: SamplingParams, + #[serde(default)] + pub stream: bool, + pub cache_salt: Option, + #[serde(default)] + pub priority: i32, + pub kv_transfer_params: Option>, + #[serde(flatten)] + pub other: Map, +} + +impl Normalizable for GenerateRequest {} + +/// Mirrors the Python vLLM `GenerateResponseChoice` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub(super) struct GenerateResponseChoice { + pub index: u32, + pub logprobs: Option, + pub finish_reason: Option, + pub token_ids: Vec, +} + +/// Mirrors the Python vLLM `GenerateResponse` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub(super) struct GenerateResponse { + pub request_id: String, + pub choices: Vec, + pub prompt_logprobs: Option>>>, + pub kv_transfer_params: Option, +} + +/// Mirrors the Python vLLM `Logprob` class used in prompt-logprobs payloads. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub(super) struct GenerateLogprob { + pub logprob: f32, + pub rank: Option, + pub decoded_token: Option, +} diff --git a/rust/src/server/src/routes/inference/generate/validate.rs b/rust/src/server/src/routes/inference/generate/validate.rs new file mode 100644 index 00000000000..74a5bbb690a --- /dev/null +++ b/rust/src/server/src/routes/inference/generate/validate.rs @@ -0,0 +1,84 @@ +use super::types::GenerateRequest; +use crate::error::{ApiError, bail_invalid_request}; + +/// Enforce the minimal compatibility contract for the Rust token generate +/// route. +pub(super) fn validate_request_compat( + request: &GenerateRequest, + served_model_names: &[String], +) -> Result<(), ApiError> { + if let Some(model) = request.model.as_ref() + && !served_model_names.iter().any(|n| n == model) + { + return Err(ApiError::model_not_found(model.clone())); + } + + if request.stream { + bail_invalid_request!(param = "stream", "stream=true is not supported."); + } + + if request.token_ids.is_empty() { + bail_invalid_request!( + param = "token_ids", + "token_ids must contain at least one token ID." + ); + } + + if request.sampling_params.max_tokens == Some(0) { + bail_invalid_request!( + param = "sampling_params", + "max_tokens must be greater than 0." + ); + } + + if let Some(prompt_logprobs) = request.sampling_params.prompt_logprobs + && prompt_logprobs < 0 + && prompt_logprobs != -1 + { + bail_invalid_request!( + param = "sampling_params", + "`prompt_logprobs` must be a non-negative value or -1." + ); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::validate_request_compat; + use crate::routes::inference::generate::types::GenerateRequest; + + fn base_request() -> GenerateRequest { + serde_json::from_value(json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "token_ids": [11, 22], + "sampling_params": {} + })) + .expect("parse request") + } + + fn served(names: &[&str]) -> Vec { + names.iter().map(|s| s.to_string()).collect() + } + + #[test] + fn validate_request_compat_rejects_streaming() { + let request = GenerateRequest { + stream: true, + ..base_request() + }; + assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err()); + } + + #[test] + fn validate_request_compat_rejects_empty_token_ids() { + let request = GenerateRequest { + token_ids: Vec::new(), + ..base_request() + }; + assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err()); + } +} diff --git a/rust/src/server/src/routes/inference/mod.rs b/rust/src/server/src/routes/inference/mod.rs new file mode 100644 index 00000000000..d601d038745 --- /dev/null +++ b/rust/src/server/src/routes/inference/mod.rs @@ -0,0 +1,3 @@ +pub mod generate; + +pub use generate::generate; diff --git a/rust/src/server/src/routes/load.rs b/rust/src/server/src/routes/load.rs new file mode 100644 index 00000000000..0f666b24e40 --- /dev/null +++ b/rust/src/server/src/routes/load.rs @@ -0,0 +1,18 @@ +use std::sync::Arc; + +use axum::Json; +use axum::extract::State; +use serde::Serialize; + +use crate::state::AppState; + +#[derive(Serialize)] +pub(crate) struct ServerLoadResponse { + server_load: u64, +} + +pub async fn load(State(state): State>) -> Json { + Json(ServerLoadResponse { + server_load: state.server_load(), + }) +} diff --git a/rust/src/server/src/routes/metrics.rs b/rust/src/server/src/routes/metrics.rs new file mode 100644 index 00000000000..f7017fadccb --- /dev/null +++ b/rust/src/server/src/routes/metrics.rs @@ -0,0 +1,26 @@ +use axum::http::header::CONTENT_TYPE; +use axum::http::{HeaderValue, StatusCode}; +use axum::response::{IntoResponse, Response}; +use thiserror_ext::AsReport; +use vllm_metrics::METRICS; + +const OPENMETRICS_CONTENT_TYPE: &str = "application/openmetrics-text; version=1.0.0; charset=utf-8"; + +pub async fn scrape() -> Response { + match METRICS.render() { + Ok(body) => ( + [( + CONTENT_TYPE, + HeaderValue::from_static(OPENMETRICS_CONTENT_TYPE), + )], + body, + ) + .into_response(), + + Err(error) => ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("failed to render metrics: {}", error.as_report()), + ) + .into_response(), + } +} diff --git a/rust/src/server/src/routes/openai/chat_completions.rs b/rust/src/server/src/routes/openai/chat_completions.rs new file mode 100644 index 00000000000..c0894bb70c9 --- /dev/null +++ b/rust/src/server/src/routes/openai/chat_completions.rs @@ -0,0 +1,1046 @@ +pub mod convert; +mod types; +mod validate; + +use std::convert::Infallible; +use std::result::Result; +use std::sync::Arc; + +use asynk_strim_attr::{TryYielder, try_stream}; +use axum::Json; +use axum::extract::State; +use axum::http::HeaderMap; +use axum::response::sse::{Event, Sse}; +use axum::response::{IntoResponse, Response}; +use futures::{Stream, StreamExt as _, pin_mut}; +use serde_json::Value; +use thiserror_ext::AsReport as _; +use tracing::{debug, error, info, trace}; +use tracing_futures::Instrument as _; +use vllm_chat::{ + AssistantBlockKind, AssistantMessageExt as _, ChatEvent, ChatEventStream, ChatEventStreamTrait, + CollectedAssistantMessage, FinishReason, +}; +use vllm_engine_core_client::protocol::StopReason; + +use crate::error::{ApiError, bail_server_error, server_error}; +use crate::routes::openai::chat_completions::convert::prepare_chat_request; +use crate::routes::openai::chat_completions::types::{ + AssistantRole, ChatCompletionChoice, ChatCompletionMessage, ChatCompletionRequest, + ChatCompletionResponse, ChatCompletionStreamChoice, ChatCompletionStreamResponse, + ChatMessageDelta, +}; +use crate::routes::openai::utils::logprobs::{ + decoded_logprobs_to_openai_chat, decoded_prompt_logprobs_to_maps, +}; +use crate::routes::openai::utils::types::{ + ChatLogProbs, FunctionCallDelta, FunctionCallResponse, ToolCall, ToolCallDelta, Usage, +}; +use crate::routes::openai::utils::validated_json::ValidatedJson; +use crate::state::AppState; +use crate::utils::{resolve_request_context, unix_timestamp}; + +/// Validate one chat completion request and proxy it into the shared +/// `vllm-chat` stack. +pub async fn chat_completions( + State(state): State>, + headers: HeaderMap, + ValidatedJson(body): ValidatedJson, +) -> Response { + let stream = body.stream; + let request_context = resolve_request_context(&headers, body.request_id.as_deref()); + + let prepared = match prepare_chat_request(body, state.served_model_names(), request_context) { + Ok(prepared) => prepared, + Err(error) => return error.into_response(), + }; + let request_span = tracing::info_span!( + "chat_completions", + request_id = %prepared.request_id, + engine_request_id = tracing::field::Empty, + ); + + let created = unix_timestamp(); + let log_request = state.enable_log_requests; + + let chat_stream = + match state.chat.chat(prepared.chat_request).instrument(request_span.clone()).await { + Ok(stream) => stream, + Err(error) => { + return server_error!( + "failed to submit chat request: {}", + error.to_report_string() + ) + .into_response(); + } + }; + + if stream { + let chunk_stream = chat_completion_chunk_stream( + chat_stream, + prepared.request_id, + prepared.response_model, + created, + log_request, + prepared.include_usage, + prepared.requested_logprobs, + prepared.echo, + prepared.return_token_ids, + prepared.return_tokens_as_token_ids, + ); + let sse_stream = chat_completion_sse_stream(chunk_stream).instrument(request_span); + + Sse::new(sse_stream).into_response() + } else { + let response = match collect_chat_completion( + chat_stream, + prepared.request_id, + prepared.response_model, + created, + prepared.requested_logprobs, + prepared.include_prompt_logprobs, + prepared.echo, + prepared.return_token_ids, + prepared.return_tokens_as_token_ids, + ) + .instrument(request_span.clone()) + .await + { + Ok(response) => response, + Err(error) => return error.into_response(), + }; + + if log_request { + let usage = response.usage.as_ref(); + info!( + parent: &request_span, + model = %response.model, + prompt_tokens = usage.map_or(0, |u| u.prompt_tokens), + output_tokens = usage.and_then(|u| u.completion_tokens).unwrap_or(0), + finish_reason = response.choices.first().and_then(|c| c.finish_reason.as_deref()).unwrap_or("unknown"), + "chat completion finished" + ); + } + + Json(response).into_response() + } +} + +async fn collect_chat_completion( + stream: ChatEventStream, + request_id: String, + response_model: String, + created: u64, + requested_logprobs: bool, + include_prompt_logprobs: bool, + echo: Option, + return_token_ids: bool, + return_tokens_as_token_ids: bool, +) -> Result { + let collected = stream.collect_message().await.map_err(|error| { + server_error!( + "failed to collect chat completion response: {}", + error.to_report_string() + ) + })?; + let CollectedAssistantMessage { + message, + prompt_token_count, + prompt_token_ids, + prompt_logprobs, + logprobs, + token_ids, + output_token_count, + finish_reason, + kv_transfer_params, + } = collected; + let stop_reason = finish_reason.as_stop_reason().map(stop_reason_to_json); + let saw_tool_calls = message.tool_calls().next().is_some(); + let finish_reason = chat_finish_reason_to_openai(&finish_reason, saw_tool_calls)?.to_string(); + let tool_calls = message + .tool_calls() + .map(|call| ToolCall { + id: call.id.clone(), + tool_type: "function".to_string(), + function: FunctionCallResponse { + name: call.name.clone(), + arguments: Some(call.arguments.clone()), + }, + }) + .collect::>(); + let logprobs = if requested_logprobs { + Some(decoded_logprobs_to_openai_chat( + logprobs.as_ref().ok_or_else(|| { + server_error!("chat response requested logprobs but generation returned none") + })?, + return_tokens_as_token_ids, + )?) + } else { + None + }; + let prompt_logprobs = if include_prompt_logprobs { + Some(decoded_prompt_logprobs_to_maps( + prompt_logprobs.as_ref().ok_or_else(|| { + server_error!( + "chat response requested prompt_logprobs but generation returned none" + ) + })?, + return_tokens_as_token_ids, + )) + } else { + None + }; + let usage = Usage::from_counts(prompt_token_count as u32, output_token_count as u32); + + Ok(ChatCompletionResponse { + id: request_id, + object: "chat.completion".to_string(), + created, + model: response_model, + choices: vec![ChatCompletionChoice { + index: 0, + message: ChatCompletionMessage { + role: AssistantRole, + content: match &echo { + Some(prefix) => Some(format!("{prefix}{}", message.text())), + None => Some(message.text()).filter(|t| !t.is_empty()), + }, + tool_calls: Some(tool_calls).filter(|calls| !calls.is_empty()), + reasoning: message.reasoning(), + }, + logprobs, + finish_reason: Some(finish_reason), + stop_reason, + token_ids: return_token_ids.then_some(token_ids), + }], + usage: Some(usage), + system_fingerprint: None, + prompt_logprobs, + prompt_token_ids: return_token_ids.then(|| prompt_token_ids.to_vec()), + kv_transfer_params, + }) +} + +/// Convert one internal chat event stream into OpenAI chat-completion chunks. +#[try_stream] +async fn chat_completion_chunk_stream( + mut stream: impl ChatEventStreamTrait + Unpin, + request_id: String, + response_model: String, + created: u64, + log_request: bool, + include_usage: bool, + requested_logprobs: bool, + echo: Option, + return_token_ids: bool, + return_tokens_as_token_ids: bool, + mut y: TryYielder, +) -> Result<(), ApiError> { + let mut saw_tool_calls = false; + + // If the client requested logprobs or token_ids, we need to buffer chunks until + // we receive the separate `LogprobsDelta` event, so that we can emit one + // combined chunk with both the semantic delta and its per-update metadata. + let mut pending_chunk = + (requested_logprobs || return_token_ids).then(PendingChatChunk::default); + + while let Some(next) = stream.next().await { + match next { + Ok(ChatEvent::Start { + prompt_token_ids, .. + }) => { + let mut chunk = start_chunk(&request_id, &response_model, created); + if return_token_ids { + chunk.prompt_token_ids = Some(prompt_token_ids.to_vec()); + } + y.yield_ok(chunk).await; + // When echo=true, emit the last assistant message content as a delta chunk. + if let Some(echo_text) = &echo { + y.yield_ok(block_delta_chunk( + &request_id, + &response_model, + created, + AssistantBlockKind::Text, + echo_text.clone(), + )) + .await; + } + } + Ok(ChatEvent::BlockDelta { kind, delta, .. }) => { + if let Some(pending_chunk) = pending_chunk.as_mut() { + pending_chunk.push_block_delta(kind, delta); + } else { + y.yield_ok(block_delta_chunk( + &request_id, + &response_model, + created, + kind, + delta, + )) + .await; + } + } + Ok(ChatEvent::LogprobsDelta { + logprobs, + token_ids, + }) => { + let openai_logprobs = logprobs + .as_ref() + .map(|lp| decoded_logprobs_to_openai_chat(lp, return_tokens_as_token_ids)) + .transpose()?; + let openai_token_ids = + return_token_ids.then_some(token_ids).filter(|t| !t.is_empty()); + if let Some(pending_chunk) = pending_chunk.as_mut() { + pending_chunk.logprobs = openai_logprobs; + pending_chunk.token_ids = openai_token_ids; + if let Some(chunk) = + pending_chunk.take_chunk(&request_id, &response_model, created) + { + y.yield_ok(chunk).await; + } + } else if let Some(logprobs) = openai_logprobs { + y.yield_ok(logprobs_only_chunk( + &request_id, + &response_model, + created, + logprobs, + )) + .await; + } + } + Ok(ChatEvent::BlockStart { kind, .. }) => { + debug!(?kind, "starting new block"); + } + Ok(ChatEvent::BlockEnd { .. }) => { + debug!("ending current block"); + } + Ok(ChatEvent::ToolCallStart { index, id, name }) => { + let tool_index = index as u32; + saw_tool_calls = true; + debug!( + tool_call_id = %id, + tool_call_name = %name, + "starting new tool call" + ); + if let Some(pending_chunk) = pending_chunk.as_mut() { + pending_chunk.push_tool_call_start(tool_index, id, name); + } else { + y.yield_ok(tool_call_start_chunk( + &request_id, + &response_model, + created, + tool_index, + id, + name, + )) + .await; + } + } + Ok(ChatEvent::ToolCallArgumentsDelta { index, delta }) => { + let tool_index = index as u32; + if let Some(pending_chunk) = pending_chunk.as_mut() { + pending_chunk.push_tool_call_arguments(tool_index, delta); + } else { + y.yield_ok(tool_call_arguments_chunk( + &request_id, + &response_model, + created, + tool_index, + delta, + )) + .await; + } + } + Ok(ChatEvent::ToolCallEnd { .. }) => { + debug!("ending current tool call"); + } + Ok(ChatEvent::Done { + prompt_token_count, + finish_reason, + output_token_count, + .. + }) => { + if log_request { + info!( + stream = true, + model = %response_model, + prompt_tokens = prompt_token_count, + output_tokens = output_token_count, + finish_reason = finish_reason.as_str(), + "chat completion finished" + ); + } + + if let Some(pending_chunk) = pending_chunk.as_mut() + && let Some(chunk) = + pending_chunk.take_chunk(&request_id, &response_model, created) + { + y.yield_ok(chunk).await; + } + + match final_chunk( + &request_id, + &response_model, + created, + finish_reason, + saw_tool_calls, + ) { + Ok(chunk) => y.yield_ok(chunk).await, + Err(error) => { + error!( + error = %error.to_error_response().error.message, + "invalid terminal finish reason" + ); + return Err(error); + } + } + + if include_usage { + y.yield_ok(usage_chunk( + &request_id, + &response_model, + created, + Usage::from_counts(prompt_token_count as u32, output_token_count as u32), + )) + .await; + } + + return Ok(()); + } + Err(error) => { + error!( + error = %error.as_report(), + "chat stream failed" + ); + bail_server_error!("{}", error.to_report_string()); + } + } + } + Ok(()) +} + +fn usage_chunk( + request_id: &str, + response_model: &str, + created: u64, + usage: Usage, +) -> ChatCompletionStreamResponse { + let mut chunk = ChatCompletionStreamResponse::new(request_id, response_model, created); + chunk.usage = Some(usage); + chunk +} + +/// One in-flight chat-completions SSE chunk being assembled at the route layer. +/// +/// `vllm-chat` emits semantic chat events first and `LogprobsDelta` separately, +/// because one decoded update may be rewritten into multiple chat events. +/// The OpenAI chat API, though, wants one streamed chunk to optionally carry +/// both the delta and its logprobs. +/// +/// This small buffer accumulates the semantic delta first, then attaches the +/// following `LogprobsDelta` and flushes one combined chunk. It relies on the +/// current `vllm-chat` invariant that all semantic events from one decoded +/// update are emitted before that update's `LogprobsDelta`. +#[derive(Debug, Default)] +struct PendingChatChunk { + /// The currently buffered OpenAI delta payload assembled from one or more + /// chat semantic events belonging to the same decoded update. + delta: ChatMessageDelta, + /// The token-aligned logprobs for that same decoded update. + logprobs: Option, + /// Per-update output token IDs for the same decoded update. + token_ids: Option>, +} + +impl PendingChatChunk { + /// Append one assistant text/reasoning block delta to the buffered OpenAI + /// delta payload. + fn push_block_delta(&mut self, kind: AssistantBlockKind, delta: String) { + match kind { + AssistantBlockKind::Text => append_delta_text(&mut self.delta.content, delta), + AssistantBlockKind::Reasoning => append_delta_text(&mut self.delta.reasoning, delta), + AssistantBlockKind::ToolCall => { + unreachable!("tool calls must flow through dedicated tool-call chunks") + } + } + } + + /// Append the OpenAI tool-call-start representation to the buffered delta. + fn push_tool_call_start(&mut self, index: u32, id: String, name: String) { + self.delta.tool_calls.get_or_insert_with(Vec::new).push(ToolCallDelta { + index, + id: Some(id), + tool_type: Some("function".to_string()), + function: Some(FunctionCallDelta { + name: Some(name), + arguments: None, + }), + }); + } + + /// Append one incremental tool-call arguments update to the buffered delta. + fn push_tool_call_arguments(&mut self, index: u32, delta: String) { + self.delta.tool_calls.get_or_insert_with(Vec::new).push(ToolCallDelta { + index, + id: None, + tool_type: None, + function: Some(FunctionCallDelta { + name: None, + arguments: Some(delta), + }), + }); + } + + /// Finalize the currently buffered SSE chunk, if it contains either a + /// semantic delta or a logprobs payload. + /// + /// This may produce: + /// - a combined delta + logprobs chunk + /// - a delta-only chunk + /// - a logprobs-only chunk + /// + /// The logprobs-only case is intentional: token-level metadata in one + /// decoded update is correlated with the same update boundary, not + /// necessarily with a visible/chat-semantic delta. + fn take_chunk( + &mut self, + request_id: &str, + response_model: &str, + created: u64, + ) -> Option { + let has_delta = self.delta.content.is_some() + || self.delta.reasoning.is_some() + || self.delta.tool_calls.is_some(); + let logprobs = self.logprobs.take(); + let token_ids = self.token_ids.take(); + if !has_delta && logprobs.is_none() && token_ids.is_none() { + return None; + } + + let mut chunk = ChatCompletionStreamResponse::new(request_id, response_model, created); + chunk.choices.push(ChatCompletionStreamChoice { + delta: self.take_delta(), + logprobs, + token_ids, + ..Default::default() + }); + Some(chunk) + } + + /// Take the currently buffered OpenAI delta payload and leave this pending + /// chunk empty for the next decoded update. + fn take_delta(&mut self) -> ChatMessageDelta { + ChatMessageDelta { + role: self.delta.role.take(), + content: self.delta.content.take(), + tool_calls: self.delta.tool_calls.take(), + reasoning: self.delta.reasoning.take(), + } + } +} + +/// Append one text fragment to an optional OpenAI delta string field. +fn append_delta_text(slot: &mut Option, delta: String) { + match slot { + Some(existing) => existing.push_str(&delta), + None => *slot = Some(delta), + } +} + +/// Convert one chunk stream into OpenAI-style SSE events. +/// +/// OpenAI-style streaming errors are encoded as ordinary `data: {"error": ...}` +/// events followed by `data: [DONE]`, so the transport stream itself stays +/// infallible even when generation fails after the HTTP response has started. +#[try_stream] +async fn chat_completion_sse_stream( + stream: impl Stream>, + mut y: TryYielder, +) -> Result<(), Infallible> { + pin_mut!(stream); + + while let Some(next) = stream.next().await { + match next { + Ok(chunk) => y.yield_ok(to_sse_event(&chunk)).await, + Err(error) => { + y.yield_ok(to_error_sse_event(&error)).await; + break; + } + } + } + + y.yield_ok(done_sse_event()).await; + Ok(()) +} + +/// Serialize one OpenAI chunk payload into one SSE `data:` event. +fn to_sse_event(chunk: &ChatCompletionStreamResponse) -> Event { + let payload = + serde_json::to_string(chunk).expect("ChatCompletionStreamResponse must serialize to JSON"); + trace!(payload, "chat completion emitting chunk"); + Event::default().data(payload) +} + +/// Serialize one OpenAI error payload into one SSE `data:` event. +fn to_error_sse_event(error: &ApiError) -> Event { + let payload = serde_json::to_string(&error.to_error_response()) + .expect("ErrorResponse must serialize to JSON"); + trace!(payload, "chat completion emitting error"); + Event::default().data(payload) +} + +/// Build the terminal OpenAI SSE sentinel event. +fn done_sse_event() -> Event { + trace!("chat completion emitting done"); + Event::default().data("[DONE]") +} + +/// Build the initial assistant-role SSE chunk required by the OpenAI streaming +/// protocol. +fn start_chunk( + request_id: &str, + response_model: &str, + created: u64, +) -> ChatCompletionStreamResponse { + let mut chunk = ChatCompletionStreamResponse::new(request_id, response_model, created); + chunk.choices.push(ChatCompletionStreamChoice { + delta: ChatMessageDelta { + role: Some(AssistantRole), + ..Default::default() + }, + ..Default::default() + }); + chunk +} + +/// Build one content-delta SSE chunk from one internal assistant block delta. +fn block_delta_chunk( + request_id: &str, + response_model: &str, + created: u64, + kind: AssistantBlockKind, + delta: String, +) -> ChatCompletionStreamResponse { + let delta = match kind { + AssistantBlockKind::Text => ChatMessageDelta { + content: Some(delta), + ..Default::default() + }, + AssistantBlockKind::Reasoning => ChatMessageDelta { + reasoning: Some(delta), + ..Default::default() + }, + AssistantBlockKind::ToolCall => { + unreachable!("tool calls must flow through dedicated tool-call chunks") + } + }; + + let mut chunk = ChatCompletionStreamResponse::new(request_id, response_model, created); + chunk.choices.push(ChatCompletionStreamChoice { + delta, + ..Default::default() + }); + chunk +} + +fn tool_call_start_chunk( + request_id: &str, + response_model: &str, + created: u64, + tool_index: u32, + id: String, + name: String, +) -> ChatCompletionStreamResponse { + let mut chunk = ChatCompletionStreamResponse::new(request_id, response_model, created); + chunk.choices.push(ChatCompletionStreamChoice { + delta: ChatMessageDelta { + tool_calls: Some(vec![ToolCallDelta { + index: tool_index, + id: Some(id), + tool_type: Some("function".to_string()), + function: Some(FunctionCallDelta { + name: Some(name), + arguments: None, + }), + }]), + ..Default::default() + }, + ..Default::default() + }); + chunk +} + +fn tool_call_arguments_chunk( + request_id: &str, + response_model: &str, + created: u64, + tool_index: u32, + delta: String, +) -> ChatCompletionStreamResponse { + let mut chunk = ChatCompletionStreamResponse::new(request_id, response_model, created); + chunk.choices.push(ChatCompletionStreamChoice { + delta: ChatMessageDelta { + tool_calls: Some(vec![ToolCallDelta { + index: tool_index, + id: None, + tool_type: None, + function: Some(FunctionCallDelta { + name: None, + arguments: Some(delta), + }), + }]), + ..Default::default() + }, + ..Default::default() + }); + chunk +} + +fn logprobs_only_chunk( + request_id: &str, + response_model: &str, + created: u64, + logprobs: ChatLogProbs, +) -> ChatCompletionStreamResponse { + let mut chunk = ChatCompletionStreamResponse::new(request_id, response_model, created); + chunk.choices.push(ChatCompletionStreamChoice { + logprobs: Some(logprobs), + ..Default::default() + }); + chunk +} + +/// Build the terminal SSE chunk carrying the OpenAI finish reason. +fn final_chunk( + request_id: &str, + response_model: &str, + created: u64, + finish_reason: FinishReason, + saw_tool_calls: bool, +) -> Result { + let stop_reason = finish_reason.as_stop_reason().map(stop_reason_to_json); + let finish_reason = chat_finish_reason_to_openai(&finish_reason, saw_tool_calls)?; + + debug!( + finish_reason = %finish_reason, + stop_reason = ?stop_reason, + "chat stream finished" + ); + + let mut chunk = ChatCompletionStreamResponse::new(request_id, response_model, created); + chunk.choices.push(ChatCompletionStreamChoice { + finish_reason: Some(finish_reason.to_string()), + stop_reason, + ..Default::default() + }); + Ok(chunk) +} + +fn chat_finish_reason_to_openai( + finish_reason: &FinishReason, + saw_tool_calls: bool, +) -> Result<&'static str, ApiError> { + match finish_reason { + FinishReason::Stop(_) if saw_tool_calls => Ok("tool_calls"), + FinishReason::Stop(_) => Ok("stop"), + FinishReason::Length => Ok("length"), + FinishReason::Abort => Ok("abort"), + FinishReason::Repetition => Ok("stop"), + FinishReason::Error => { + bail_server_error!("Internal server error"); + } + } +} + +/// Convert one internal stop reason into the OpenAI-compatible `stop_reason` +/// JSON shape. +fn stop_reason_to_json(stop_reason: &StopReason) -> Value { + serde_json::to_value(stop_reason).expect("StopReason must serialize to JSON") +} + +#[cfg(test)] +mod tests { + use futures::{StreamExt as _, stream}; + use serde_json::json; + use vllm_chat::{AssistantBlockKind, AssistantToolCall, ChatEvent, FinishReason}; + use vllm_engine_core_client::protocol::StopReason; + use vllm_text::{DecodedLogprobs, DecodedPositionLogprobs, DecodedTokenLogprob}; + + use super::{block_delta_chunk, chat_completion_chunk_stream, final_chunk}; + + #[test] + fn text_chunk_uses_content_only_delta() { + let chunk = block_delta_chunk( + "chatcmpl-1", + "model", + 1, + AssistantBlockKind::Text, + "hello".to_string(), + ); + assert_eq!(chunk.choices[0].delta.role, None); + assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("hello")); + assert_eq!(chunk.choices[0].delta.reasoning, None); + } + + #[test] + fn reasoning_chunk_uses_reasoning_only_delta() { + let chunk = block_delta_chunk( + "chatcmpl-1", + "model", + 1, + AssistantBlockKind::Reasoning, + "thinking".to_string(), + ); + assert_eq!(chunk.choices[0].delta.role, None); + assert_eq!(chunk.choices[0].delta.content, None); + assert_eq!( + chunk.choices[0].delta.reasoning.as_deref(), + Some("thinking") + ); + } + + #[test] + fn final_chunk_maps_stop_finish_reason_and_stop_reason() { + let chunk = final_chunk( + "chatcmpl-1", + "model", + 1, + FinishReason::Stop(Some(StopReason::Text("stop".to_string()))), + false, + ) + .expect("finish reason is valid"); + + assert_eq!(chunk.choices[0].finish_reason.as_deref(), Some("stop")); + assert_eq!(chunk.choices[0].stop_reason, Some(json!("stop"))); + } + + #[test] + fn final_chunk_maps_length_finish_reason() { + let chunk = final_chunk("chatcmpl-1", "model", 1, FinishReason::Length, false) + .expect("finish reason is valid"); + + assert_eq!(chunk.choices[0].finish_reason.as_deref(), Some("length")); + assert_eq!(chunk.choices[0].stop_reason, None); + } + + #[test] + fn final_chunk_maps_abort_finish_reason() { + let chunk = final_chunk("chatcmpl-1", "model", 1, FinishReason::Abort, false) + .expect("abort is a valid finish reason"); + + assert_eq!(chunk.choices[0].finish_reason.as_deref(), Some("abort")); + assert_eq!(chunk.choices[0].stop_reason, None); + } + + #[test] + fn final_chunk_rejects_error_finish_reason() { + assert!(final_chunk("chatcmpl-1", "model", 1, FinishReason::Error, false).is_err()); + } + + #[test] + fn final_chunk_maps_stop_to_tool_calls_when_tool_calls_were_streamed() { + let chunk = final_chunk("chatcmpl-1", "model", 1, FinishReason::stop_eos(), true) + .expect("finish reason is valid"); + + assert_eq!( + chunk.choices[0].finish_reason.as_deref(), + Some("tool_calls") + ); + } + + #[tokio::test] + async fn chunk_stream_coalesces_text_delta_with_logprobs() { + let stream = stream::iter(vec![ + Ok(ChatEvent::Start { + prompt_token_ids: vec![].into(), + prompt_logprobs: None, + }), + Ok(ChatEvent::BlockStart { + index: 0, + kind: AssistantBlockKind::Text, + }), + Ok(ChatEvent::BlockDelta { + index: 0, + kind: AssistantBlockKind::Text, + delta: "hi".to_string(), + }), + Ok(ChatEvent::LogprobsDelta { + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "hi".to_string(), + logprob: -0.1, + rank: 1, + }], + }], + }), + token_ids: vec![], + }), + Ok(ChatEvent::Done { + message: Default::default(), + prompt_token_count: 1, + output_token_count: 1, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]); + + let chunks = chat_completion_chunk_stream( + stream, + "chatcmpl-1".to_string(), + "model".to_string(), + 1, + false, + false, + true, + None, + false, + false, + ) + .collect::>() + .await + .into_iter() + .collect::, _>>() + .expect("stream chunks"); + + assert_eq!(chunks.len(), 3); + assert_eq!(chunks[1].choices[0].delta.content.as_deref(), Some("hi")); + let logprobs = chunks[1].choices[0].logprobs.as_ref().expect("logprobs"); + let content = logprobs.content.as_ref().expect("logprobs content"); + assert_eq!(content[0].token, "hi"); + } + + #[tokio::test] + async fn chunk_stream_coalesces_reasoning_delta_with_logprobs() { + let stream = stream::iter(vec![ + Ok(ChatEvent::Start { + prompt_token_ids: vec![].into(), + prompt_logprobs: None, + }), + Ok(ChatEvent::BlockStart { + index: 0, + kind: AssistantBlockKind::Reasoning, + }), + Ok(ChatEvent::BlockDelta { + index: 0, + kind: AssistantBlockKind::Reasoning, + delta: "think".to_string(), + }), + Ok(ChatEvent::LogprobsDelta { + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "think".to_string(), + logprob: -0.1, + rank: 1, + }], + }], + }), + token_ids: vec![], + }), + Ok(ChatEvent::Done { + message: Default::default(), + prompt_token_count: 1, + output_token_count: 1, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]); + + let chunks = chat_completion_chunk_stream( + stream, + "chatcmpl-1".to_string(), + "model".to_string(), + 1, + false, + false, + true, + None, + false, + false, + ) + .collect::>() + .await + .into_iter() + .collect::, _>>() + .expect("stream chunks"); + + assert_eq!(chunks.len(), 3); + assert_eq!( + chunks[1].choices[0].delta.reasoning.as_deref(), + Some("think") + ); + assert!(chunks[1].choices[0].logprobs.is_some()); + } + + #[tokio::test] + async fn chunk_stream_preserves_tool_call_index_and_omits_id_from_arguments_delta() { + let stream = stream::iter(vec![ + Ok(ChatEvent::Start { + prompt_token_ids: vec![].into(), + prompt_logprobs: None, + }), + Ok(ChatEvent::ToolCallStart { + index: 3, + id: "call_1".to_string(), + name: "get_weather".to_string(), + }), + Ok(ChatEvent::ToolCallArgumentsDelta { + index: 3, + delta: r#"{"city":"Paris"}"#.to_string(), + }), + Ok(ChatEvent::ToolCallEnd { + index: 3, + call: AssistantToolCall { + id: "call_1".to_string(), + name: "get_weather".to_string(), + arguments: r#"{"city":"Paris"}"#.to_string(), + }, + }), + Ok(ChatEvent::Done { + message: Default::default(), + prompt_token_count: 1, + output_token_count: 1, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + ]); + + let chunks = chat_completion_chunk_stream( + stream, + "chatcmpl-1".to_string(), + "model".to_string(), + 1, + false, + false, + false, + None, + false, + false, + ) + .collect::>() + .await + .into_iter() + .collect::, _>>() + .expect("stream chunks"); + + assert_eq!( + chunks[1].choices[0].delta.tool_calls.as_ref().unwrap()[0].index, + 3 + ); + assert_eq!( + chunks[1].choices[0].delta.tool_calls.as_ref().unwrap()[0].id, + Some("call_1".to_string()) + ); + assert_eq!( + chunks[2].choices[0].delta.tool_calls.as_ref().unwrap()[0].index, + 3 + ); + assert_eq!( + chunks[2].choices[0].delta.tool_calls.as_ref().unwrap()[0].id, + None + ); + } +} diff --git a/rust/src/server/src/routes/openai/chat_completions/convert.rs b/rust/src/server/src/routes/openai/chat_completions/convert.rs new file mode 100644 index 00000000000..a7884b11e52 --- /dev/null +++ b/rust/src/server/src/routes/openai/chat_completions/convert.rs @@ -0,0 +1,992 @@ +use itertools::Itertools as _; +use vllm_chat::{ + AssistantContentBlock, AssistantToolCall, ChatContent, ChatContentPart, + ChatMessage as VllmChatMessage, ChatOptions, ChatRequest, ChatTool, ChatToolChoice, + GenerationPromptMode, SamplingParams, +}; + +use super::types::ChatCompletionRequest; +use super::validate; +use crate::error::{ApiError, bail_invalid_request}; +use crate::routes::openai::utils::structured_outputs::convert_from_response_format; +use crate::routes::openai::utils::types::{ + ChatMessage, ContentPart, MessageContent, Tool, ToolChoice, ToolChoiceValue, +}; +use crate::utils::{ResolvedRequestContext, convert_logit_bias, merge_kv_transfer_params}; + +/// Lowered chat request plus the public response metadata carried by every SSE +/// chunk. +#[derive(Debug, Clone, PartialEq)] +pub struct PreparedRequest { + /// Stable OpenAI-style request ID, reused as the external chat request ID. + pub request_id: String, + /// Public model ID echoed back to the client. + pub response_model: String, + /// Whether the caller asked for the final streamed usage chunk. + pub include_usage: bool, + /// Whether the caller requested output logprobs on chat choices. + pub requested_logprobs: bool, + /// Whether the caller requested top-level prompt logprobs. + pub include_prompt_logprobs: bool, + /// Lowered chat request for `vllm-chat`. + pub chat_request: ChatRequest, + /// Last assistant-role message content to echo back when `echo=true`. + pub echo: Option, + /// Whether to include token IDs alongside generated text. + pub return_token_ids: bool, + /// Whether to format logprob tokens as `token_id:{id}`. + pub return_tokens_as_token_ids: bool, +} + +/// Validate and lower one OpenAI chat completion request into the internal chat +/// format. +/// +/// `served_model_names` must be non-empty; the first entry is used as the +/// `model` field in responses. +pub(crate) fn prepare_chat_request( + request: ChatCompletionRequest, + served_model_names: &[String], + ctx: ResolvedRequestContext, +) -> Result { + validate::validate_request_compat(&request, served_model_names)?; + + let request_id = format!("chatcmpl-{}", ctx.request_id); + let echo = request + .echo + .then(|| extract_last_assistant_content(&request.messages)) + .flatten(); + let messages: Vec<_> = request.messages.into_iter().map(convert_message).try_collect()?; + let generation_prompt_mode = normalize_generation_prompt_mode( + request.add_generation_prompt, + request.continue_final_message, + &messages, + )?; + + let template_kwargs = request.chat_template_kwargs.unwrap_or_default(); + + let include_usage = (request.stream_options.as_ref()) + .and_then(|options| options.include_usage) + .unwrap_or(false); + let requested_logprobs = request.logprobs; + + // Auto-enable prompt logprobs for non-streaming echo, matching Python vLLM's + // behavior. + let top_logprobs = request.top_logprobs.unwrap_or(0); + let prompt_logprobs = request + .prompt_logprobs + .or((request.echo && !request.stream).then_some(top_logprobs)); + let include_prompt_logprobs = prompt_logprobs.is_some(); + + let structured_outputs = convert_from_response_format( + request.response_format.as_ref(), + &request.structured_outputs, + )?; + + let chat_request = ChatRequest { + request_id: request_id.clone(), + messages, + sampling_params: SamplingParams { + temperature: request.temperature, + top_p: request.top_p, + top_k: request.top_k, + seed: request.seed, + max_tokens: request.max_completion_tokens, + min_tokens: request.min_tokens, + logprobs: request.logprobs.then_some(top_logprobs), + prompt_logprobs, + min_p: request.min_p, + frequency_penalty: request.frequency_penalty, + presence_penalty: request.presence_penalty, + repetition_penalty: request.repetition_penalty, + stop_token_ids: request.stop_token_ids, + ignore_eos: request.ignore_eos, + logit_bias: convert_logit_bias(request.logit_bias)?, + allowed_token_ids: request.allowed_token_ids, + bad_words: request.bad_words, + logprob_token_ids: None, + structured_outputs, + skip_reading_prefix_cache: None, + vllm_xargs: merge_kv_transfer_params( + request.vllm_xargs, + request.kv_transfer_params.as_ref(), + ), + }, + chat_options: ChatOptions { + generation_prompt_mode, + chat_template: request.chat_template, + reasoning_effort: request.reasoning_effort, + template_kwargs, + }, + tools: convert_tools(request.tools)?, + tool_choice: convert_tool_choice(request.tool_choice.as_ref())?, + decode_options: vllm_text::output::TextDecodeOptions { + skip_special_tokens: request.skip_special_tokens, + include_stop_str_in_output: request.include_stop_str_in_output, + stop_strings: request.stop.map(|stop| stop.into_vec()), + min_tokens: request.min_tokens.unwrap_or(0), + }, + intermediate: request.stream, + priority: request.priority.unwrap_or(0), + documents: request.documents, + cache_salt: request.cache_salt, + add_special_tokens: request.add_special_tokens, + data_parallel_rank: ctx.data_parallel_rank, + }; + + Ok(PreparedRequest { + request_id, + response_model: served_model_names.first().cloned().unwrap_or_default(), + include_usage, + requested_logprobs, + include_prompt_logprobs, + chat_request, + echo, + return_token_ids: request.return_token_ids.unwrap_or(false), + return_tokens_as_token_ids: request.return_tokens_as_token_ids.unwrap_or(false), + }) +} + +fn normalize_generation_prompt_mode( + add_generation_prompt: Option, + continue_final_message: bool, + messages: &[VllmChatMessage], +) -> Result { + if add_generation_prompt == Some(true) && continue_final_message { + bail_invalid_request!( + "Cannot set both `continue_final_message` and `add_generation_prompt` to True." + ); + } + + let last_role = messages.last().map(VllmChatMessage::role); + match (add_generation_prompt, continue_final_message, last_role) { + (Some(true), true, _) => unreachable!("rejected above"), + (_, true, Some(vllm_chat::ChatRole::Assistant)) => { + Ok(GenerationPromptMode::ContinueFinalAssistant) + } + (_, true, _) => { + bail_invalid_request!( + "Cannot set `continue_final_message` to True when the last message is not from the assistant." + ); + } + (Some(false), false, _) => Ok(GenerationPromptMode::NoGenerationPrompt), + (None | Some(true), false, _) => Ok(GenerationPromptMode::StartNewAssistant), + } +} + +/// Extract the text content of the last message if it has the assistant role. +fn extract_last_assistant_content(messages: &[ChatMessage]) -> Option { + let ChatMessage::Assistant { content, .. } = messages.last()? else { + return None; + }; + let text = match content.as_ref()? { + MessageContent::Text(text) => text.clone(), + MessageContent::Parts(parts) => parts + .iter() + .filter_map(|p| match p { + ContentPart::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join("\n"), + }; + (!text.is_empty()).then_some(text) +} + +/// Lower one OpenAI chat message into the `vllm-chat` message shape. +fn convert_message(message: ChatMessage) -> Result { + match message { + ChatMessage::System { content, .. } => { + Ok(VllmChatMessage::system(convert_content(content)?)) + } + ChatMessage::User { content, .. } => Ok(VllmChatMessage::user(convert_content(content)?)), + ChatMessage::Assistant { + content, + tool_calls, + reasoning, + name: _, + } => { + let mut blocks = Vec::new(); + if let Some(reasoning) = reasoning + && !reasoning.is_empty() + { + blocks.push(AssistantContentBlock::Reasoning { text: reasoning }); + } + if let Some(content) = content { + blocks.extend(convert_assistant_text_blocks(content)?); + } + if let Some(tool_calls) = tool_calls { + blocks.extend(convert_assistant_tool_calls(tool_calls)?); + } + if blocks.is_empty() { + bail_invalid_request!( + "Assistant messages must contain text, reasoning content, or tool_calls." + ); + } + + Ok(VllmChatMessage::assistant_blocks(blocks)) + } + ChatMessage::Tool { + content, + tool_call_id, + } => Ok(VllmChatMessage::tool_response( + convert_content(content)?, + tool_call_id, + )), + ChatMessage::Function { .. } => { + bail_invalid_request!("Function messages are not supported.") + } + ChatMessage::Developer { + content, + tools, + name: _, + } => Ok(VllmChatMessage::developer( + convert_content(content)?, + convert_message_tools(tools)?, + )), + } +} + +/// Convert the given OpenAI message content value into the internal format in +/// `vllm-chat`. +fn convert_content(content: MessageContent) -> Result { + match content { + MessageContent::Text(text) => Ok(ChatContent::Text(text)), + MessageContent::Parts(parts) => parts + .into_iter() + .map(|part| match part { + ContentPart::Text { text } => Ok(ChatContentPart::text(text)), + ContentPart::ImageUrl { image_url, uuid } => Ok(ChatContentPart::ImageUrl { + image_url: image_url.url, + detail: image_url.detail, + uuid, + }), + _ => bail_invalid_request!("Only text and image_url content parts are supported."), + }) + .try_collect() + .map(ChatContent::Parts), + } +} + +/// Convert the given OpenAI assistant message content into the internal format +/// in `vllm-chat`. +fn convert_assistant_text_blocks( + content: MessageContent, +) -> Result, ApiError> { + match content { + MessageContent::Text(text) => Ok(vec![AssistantContentBlock::Text { text }]), + MessageContent::Parts(parts) => parts + .into_iter() + .map(|part| match part { + ContentPart::Text { text } => Ok(AssistantContentBlock::Text { text }), + _ => bail_invalid_request!( + "Only text content parts are supported for assistant messages." + ), + }) + .try_collect(), + } +} + +fn convert_assistant_tool_calls( + tool_calls: Vec, +) -> Result, ApiError> { + tool_calls + .into_iter() + .map(|tool_call| { + if tool_call.tool_type != "function" { + bail_invalid_request!("Only function tool calls are supported."); + } + + Ok(AssistantContentBlock::ToolCall(AssistantToolCall { + id: tool_call.id, + name: tool_call.function.name, + arguments: tool_call.function.arguments.unwrap_or_else(|| "{}".to_string()), + })) + }) + .collect() +} + +fn convert_tools(tools: Option>) -> Result, ApiError> { + tools + .unwrap_or_default() + .into_iter() + .map(|tool| { + if tool.tool_type != "function" { + bail_invalid_request!("Only function tools are supported."); + } + Ok(ChatTool { + name: tool.function.name, + description: tool.function.description, + parameters: tool.function.parameters, + strict: tool.function.strict, + }) + }) + .collect() +} + +fn convert_message_tools(tools: Option>) -> Result>, ApiError> { + let tools = convert_tools(tools)?; + Ok((!tools.is_empty()).then_some(tools)) +} + +fn convert_tool_choice(tool_choice: Option<&ToolChoice>) -> Result { + match tool_choice { + None | Some(ToolChoice::Value(ToolChoiceValue::Auto)) => Ok(ChatToolChoice::Auto), + Some(ToolChoice::Value(ToolChoiceValue::None)) => Ok(ChatToolChoice::None), + _ => bail_invalid_request!("tool_choice={:?} is not supported yet.", tool_choice), + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use axum::http::HeaderMap; + use expect_test::expect; + use llm_multimodal::ImageDetail; + use serde_json::json; + use vllm_chat::{ + AssistantContentBlock, AssistantToolCall, ChatContentPart, ChatMessage as VllmChatMessage, + ChatTool as VllmChatTool, ChatToolChoice, GenerationPromptMode, + SamplingParams as VllmSamplingParams, + }; + use vllm_text::output::TextDecodeOptions; + + use super::prepare_chat_request; + use crate::routes::openai::chat_completions::types::{ + AssistantRole, ChatCompletionMessage, ChatCompletionRequest, + }; + use crate::routes::openai::utils::types::{ + ChatMessage, ContentPart, Function, FunctionCallResponse, ImageUrl, MessageContent, Tool, + ToolCall, ToolChoice, ToolChoiceValue, VideoUrl, + }; + use crate::utils::{ResolvedRequestContext, resolve_request_context}; + + fn request_context(headers: &HeaderMap, request_id: Option<&str>) -> ResolvedRequestContext { + resolve_request_context(headers, request_id) + } + + fn served(names: &[&str]) -> Vec { + names.iter().map(|s| s.to_string()).collect() + } + + fn base_request() -> ChatCompletionRequest { + ChatCompletionRequest { + model: "Qwen/Qwen1.5-0.5B-Chat".to_string(), + messages: vec![ChatMessage::User { + content: MessageContent::Text("hello".to_string()), + name: None, + }], + stream: true, + ..Default::default() + } + } + + #[test] + fn prepare_chat_request_maps_text_parts() { + let mut request = base_request(); + request.messages = vec![ChatMessage::Assistant { + content: Some(MessageContent::Parts(vec![ContentPart::Text { + text: "hello".to_string(), + }])), + name: None, + tool_calls: None, + reasoning: None, + }]; + request.add_generation_prompt = Some(false); + request.continue_final_message = true; + request.skip_special_tokens = false; + request.chat_template_kwargs = Some(HashMap::from([("foo".to_string(), json!("bar"))])); + + let prepared = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + + assert!(prepared.request_id.starts_with("chatcmpl-")); + assert_eq!( + prepared.chat_request.messages, + vec![VllmChatMessage::assistant_text("hello")] + ); + assert_eq!( + prepared.chat_request.sampling_params, + VllmSamplingParams::default() + ); + assert_eq!( + prepared.chat_request.chat_options.generation_prompt_mode, + GenerationPromptMode::ContinueFinalAssistant + ); + assert_eq!( + prepared.chat_request.chat_options.template_kwargs, + HashMap::from([("foo".to_string(), json!("bar"))]) + ); + assert_eq!( + prepared.chat_request.decode_options, + TextDecodeOptions { + skip_special_tokens: false, + include_stop_str_in_output: false, + stop_strings: None, + min_tokens: 0, + } + ); + assert!(prepared.chat_request.tools.is_empty()); + assert_eq!(prepared.chat_request.tool_choice, ChatToolChoice::Auto); + } + + #[test] + fn prepare_chat_request_keeps_optional_sampling_fields_unset() { + let prepared = prepare_chat_request( + base_request(), + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + + assert!(prepared.request_id.starts_with("chatcmpl-")); + assert_eq!( + prepared.chat_request.messages, + vec![VllmChatMessage::user("hello")] + ); + assert_eq!( + prepared.chat_request.sampling_params, + VllmSamplingParams::default() + ); + assert_eq!( + prepared.chat_request.chat_options.generation_prompt_mode, + GenerationPromptMode::StartNewAssistant + ); + assert_eq!( + prepared.chat_request.decode_options, + TextDecodeOptions { + skip_special_tokens: true, + include_stop_str_in_output: false, + stop_strings: None, + min_tokens: 0, + } + ); + assert!(prepared.chat_request.tools.is_empty()); + assert_eq!(prepared.chat_request.tool_choice, ChatToolChoice::Auto); + } + + #[test] + fn prepare_chat_request_preserves_sampling_passthrough_fields() { + let request = ChatCompletionRequest { + seed: Some(42), + min_p: Some(0.2), + frequency_penalty: Some(0.3), + presence_penalty: Some(0.4), + repetition_penalty: Some(1.1), + ..base_request() + }; + + let prepared = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + let expected = VllmSamplingParams { + seed: Some(42), + min_p: Some(0.2), + frequency_penalty: Some(0.3), + presence_penalty: Some(0.4), + repetition_penalty: Some(1.1), + ..VllmSamplingParams::default() + }; + assert_eq!(prepared.chat_request.sampling_params, expected); + } + + #[test] + fn prepare_chat_request_accepts_developer_messages() { + let request = ChatCompletionRequest { + messages: vec![ChatMessage::Developer { + content: MessageContent::Text("hello".to_string()), + tools: Some(vec![Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + }), + strict: Some(true), + }, + }]), + name: None, + }], + ..base_request() + }; + + let prepared = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + + assert_eq!( + prepared.chat_request.messages, + vec![VllmChatMessage::developer( + "hello", + Some(vec![VllmChatTool { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + }), + strict: Some(true), + }]), + )] + ); + } + + #[test] + fn prepare_chat_request_maps_image_url_content_parts() { + let request = ChatCompletionRequest { + messages: vec![ChatMessage::User { + content: MessageContent::Parts(vec![ + ContentPart::Text { + text: "describe ".to_string(), + }, + ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://example.com/image.png".to_string(), + detail: Some(ImageDetail::Low), + }, + uuid: Some("image-1".to_string()), + }, + ContentPart::Text { + text: " briefly".to_string(), + }, + ]), + name: None, + }], + ..base_request() + }; + + let prepared = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + + assert_eq!( + prepared.chat_request.messages, + vec![VllmChatMessage::user(vec![ + ChatContentPart::text("describe "), + ChatContentPart::ImageUrl { + image_url: "https://example.com/image.png".to_string(), + detail: Some(ImageDetail::Low), + uuid: Some("image-1".to_string()), + }, + ChatContentPart::text(" briefly"), + ])] + ); + } + + #[test] + fn prepare_chat_request_maps_developer_image_url_content_parts() { + let request = ChatCompletionRequest { + messages: vec![ChatMessage::Developer { + content: MessageContent::Parts(vec![ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://example.com/image.png".to_string(), + detail: None, + }, + uuid: None, + }]), + tools: None, + name: None, + }], + ..base_request() + }; + + let prepared = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + + assert_eq!( + prepared.chat_request.messages, + vec![VllmChatMessage::developer( + vec![ChatContentPart::image_url("https://example.com/image.png")], + None, + )] + ); + } + + #[test] + fn prepare_chat_request_rejects_video_content_parts() { + let request = ChatCompletionRequest { + messages: vec![ChatMessage::User { + content: MessageContent::Parts(vec![ContentPart::VideoUrl { + video_url: VideoUrl { + url: "https://example.com/video.mp4".to_string(), + }, + }]), + name: None, + }], + ..base_request() + }; + + let error = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .unwrap_err(); + + expect!["Only text and image_url content parts are supported."] + .assert_eq(&error.to_error_response().error.message); + } + + #[test] + fn prepare_chat_request_rejects_assistant_image_url_content_parts() { + let request = ChatCompletionRequest { + messages: vec![ChatMessage::Assistant { + content: Some(MessageContent::Parts(vec![ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://example.com/image.png".to_string(), + detail: None, + }, + uuid: None, + }])), + name: None, + tool_calls: None, + reasoning: None, + }], + ..base_request() + }; + + let error = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .unwrap_err(); + + expect!["Only text content parts are supported for assistant messages."] + .assert_eq(&error.to_error_response().error.message); + } + + #[test] + fn prepare_chat_request_accepts_assistant_reasoning_history() { + let message = ChatCompletionMessage { + role: AssistantRole, + content: Some("answer".to_string()), + tool_calls: None, + reasoning: Some("inner".to_string()), + }; + let message_json = serde_json::to_value(message).expect("message serializes"); + + let request = ChatCompletionRequest { + messages: vec![ + serde_json::from_value(message_json).expect("response message is valid history"), + ], + add_generation_prompt: Some(false), + ..base_request() + }; + + let prepared = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + assert_eq!( + prepared.chat_request.messages, + vec![VllmChatMessage::assistant_blocks(vec![ + AssistantContentBlock::Reasoning { + text: "inner".to_string(), + }, + AssistantContentBlock::Text { + text: "answer".to_string(), + }, + ])] + ); + assert!(prepared.chat_request.tools.is_empty()); + assert_eq!(prepared.chat_request.tool_choice, ChatToolChoice::Auto); + } + + #[test] + fn prepare_chat_request_accepts_legacy_reasoning_content_alias() { + let request = ChatCompletionRequest { + messages: vec![ + serde_json::from_value(json!({ + "role": "assistant", + "content": "answer", + "reasoning_content": "inner", + })) + .expect("legacy reasoning_content alias is accepted"), + ], + add_generation_prompt: Some(false), + ..base_request() + }; + + let prepared = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + assert_eq!( + prepared.chat_request.messages, + vec![VllmChatMessage::assistant_blocks(vec![ + AssistantContentBlock::Reasoning { + text: "inner".to_string(), + }, + AssistantContentBlock::Text { + text: "answer".to_string(), + }, + ])] + ); + } + + #[test] + fn prepare_chat_request_accepts_tools_and_tool_history() { + let request = ChatCompletionRequest { + messages: vec![ + ChatMessage::Assistant { + content: None, + name: None, + tool_calls: Some(vec![ToolCall { + id: "call_1".to_string(), + tool_type: "function".to_string(), + function: FunctionCallResponse { + name: "get_weather".to_string(), + arguments: Some(r#"{"city":"Paris"}"#.to_string()), + }, + }]), + reasoning: None, + }, + ChatMessage::Tool { + content: MessageContent::Text("Sunny".to_string()), + tool_call_id: "call_1".to_string(), + }, + ], + tools: Some(vec![Tool { + tool_type: "function".to_string(), + function: Function { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + }), + strict: None, + }, + }]), + tool_choice: Some(ToolChoice::Value(ToolChoiceValue::None)), + ..base_request() + }; + + let prepared = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + assert_eq!( + prepared.chat_request.messages, + vec![ + VllmChatMessage::assistant_blocks(vec![AssistantContentBlock::ToolCall( + AssistantToolCall { + id: "call_1".to_string(), + name: "get_weather".to_string(), + arguments: r#"{"city":"Paris"}"#.to_string(), + }, + )]), + VllmChatMessage::tool_response("Sunny", "call_1"), + ] + ); + assert_eq!( + prepared.chat_request.tools, + vec![VllmChatTool { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + }), + strict: None, + }] + ); + assert_eq!(prepared.chat_request.tool_choice, ChatToolChoice::None); + } + + #[test] + fn prepare_chat_request_lowers_logprobs_fields() { + let request = ChatCompletionRequest { + stream: false, + logprobs: true, + prompt_logprobs: Some(2), + ..base_request() + }; + + let prepared = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + + assert!(prepared.requested_logprobs); + assert!(prepared.include_prompt_logprobs); + assert_eq!(prepared.chat_request.sampling_params.logprobs, Some(0)); + assert_eq!( + prepared.chat_request.sampling_params.prompt_logprobs, + Some(2) + ); + } + + #[test] + fn prepare_chat_request_keeps_prompt_logprobs_independent_from_echo() { + let request = ChatCompletionRequest { + logprobs: true, + top_logprobs: Some(3), + echo: true, + ..base_request() + }; + + let prepared = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + + assert_eq!(prepared.chat_request.sampling_params.logprobs, Some(3)); + assert_eq!(prepared.chat_request.sampling_params.prompt_logprobs, None); + assert!(!prepared.include_prompt_logprobs); + } + + #[test] + fn prepare_chat_request_threads_data_parallel_rank() { + let mut headers = HeaderMap::new(); + headers.insert("X-data-parallel-rank", "7".parse().unwrap()); + let prepared = prepare_chat_request( + base_request(), + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + request_context(&headers, None), + ) + .expect("request is valid"); + assert_eq!(prepared.chat_request.data_parallel_rank, Some(7)); + } + + #[test] + fn prepare_chat_request_leaves_data_parallel_rank_none_when_absent() { + let prepared = prepare_chat_request( + base_request(), + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + assert_eq!(prepared.chat_request.data_parallel_rank, None); + } + + #[test] + fn prepare_chat_request_maps_no_generation_prompt_mode() { + let mut request = base_request(); + request.add_generation_prompt = Some(false); + + let prepared = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + + assert_eq!( + prepared.chat_request.chat_options.generation_prompt_mode, + GenerationPromptMode::NoGenerationPrompt + ); + } + + #[test] + fn prepare_chat_request_rejects_conflicting_explicit_generation_prompt_flags() { + let mut request = base_request(); + request.add_generation_prompt = Some(true); + request.continue_final_message = true; + + let error = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .unwrap_err(); + + expect!["Cannot set both `continue_final_message` and `add_generation_prompt` to True."] + .assert_eq(&error.to_error_response().error.message); + } + + #[test] + fn prepare_chat_request_accepts_continue_final_message_with_implicit_add_generation_prompt() { + let mut request = base_request(); + request.messages = vec![ChatMessage::Assistant { + content: Some(MessageContent::Text("hello".to_string())), + name: None, + tool_calls: None, + reasoning: None, + }]; + request.continue_final_message = true; + + let prepared = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + + assert_eq!( + prepared.chat_request.chat_options.generation_prompt_mode, + GenerationPromptMode::ContinueFinalAssistant + ); + } + + #[test] + fn prepare_chat_request_rejects_continue_final_message_without_final_assistant() { + let mut request = base_request(); + request.continue_final_message = true; + + let error = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .unwrap_err(); + + expect!["Cannot set `continue_final_message` to True when the last message is not from the assistant."] + .assert_eq(&error.to_error_response().error.message); + } + + #[test] + fn prepare_chat_request_allows_new_assistant_mode_after_final_assistant() { + let request = ChatCompletionRequest { + messages: vec![ChatMessage::Assistant { + content: Some(MessageContent::Text("hello".to_string())), + name: None, + tool_calls: None, + reasoning: None, + }], + ..base_request() + }; + + let prepared = prepare_chat_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("request is valid"); + + assert_eq!( + prepared.chat_request.chat_options.generation_prompt_mode, + GenerationPromptMode::StartNewAssistant + ); + } +} diff --git a/rust/src/server/src/routes/openai/chat_completions/types.rs b/rust/src/server/src/routes/openai/chat_completions/types.rs new file mode 100644 index 00000000000..00557ad53d2 --- /dev/null +++ b/rust/src/server/src/routes/openai/chat_completions/types.rs @@ -0,0 +1,600 @@ +use std::collections::HashMap; +use std::fmt; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use serde_with::SerializeDisplay; +use validator::Validate; +use vllm_chat::ReasoningEffort; + +use crate::routes::openai::utils::structured_outputs::ResponseFormat; +use crate::routes::openai::utils::types::{ + ChatLogProbs, ChatMessage, MessageContent, Normalizable, StreamOptions, StringOrArray, Tool, + ToolCall, ToolCallDelta, ToolChoice, ToolChoiceValue, ToolReference, UNKNOWN_MODEL_ID, Usage, + default_true, validate_stop, validate_top_p_value, +}; + +/// vLLM-compatible request type for the Chat Completions API. +/// +/// Mirrors the Python vLLM `ChatCompletionRequest` class. The local copy keeps +/// the request type route-owned so we can add vLLM-only fields directly instead +/// of layering wrapper deserializers on top. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Deserialize, Serialize, Validate)] +#[validate(schema(function = "validate_chat_cross_parameters"))] +pub struct ChatCompletionRequest { + // -------- Standard OpenAI API Parameters -------- + /// A list of messages comprising the conversation so far + #[validate(custom(function = "validate_messages"))] + pub messages: Vec, + + /// ID of the model to use + #[serde(default = "default_model")] + pub model: String, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based + /// on their existing frequency in the text so far + #[validate(range(min = -2.0, max = 2.0))] + pub frequency_penalty: Option, + + /// Modify the likelihood of specified tokens appearing in the completion + pub logit_bias: Option>, + + /// Whether to return log probabilities of the output tokens + #[serde(default)] + pub logprobs: bool, + + /// An integer specifying the number of most likely tokens to return + /// -1 means return all + #[validate(range(min = -1))] + pub top_logprobs: Option, + + /// Deprecated: Replaced by max_completion_tokens + #[deprecated(note = "Use max_completion_tokens instead")] + #[validate(range(min = 1))] + pub max_tokens: Option, + + /// An upper bound for the number of tokens that can be generated for a + /// completion + #[validate(range(min = 1))] + pub max_completion_tokens: Option, + + /// How many chat completion choices to generate for each input message + #[validate(range(min = 1, max = 10))] + pub n: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based + /// on whether they appear in the text so far + #[validate(range(min = -2.0, max = 2.0))] + pub presence_penalty: Option, + + /// An object specifying the format that the model must output + pub response_format: Option, + + /// If specified, our system will make a best effort to sample + /// deterministically + pub seed: Option, + + /// Up to 4 sequences where the API will stop generating further tokens + #[validate(custom(function = "validate_stop"))] + pub stop: Option, + + /// If set, partial message deltas will be sent + #[serde(default)] + pub stream: bool, + + /// Options for streaming response + pub stream_options: Option, + + /// What sampling temperature to use, between 0 and 2 + #[validate(range(min = 0.0, max = 2.0))] + pub temperature: Option, + + /// An alternative to sampling with temperature + #[validate(custom(function = "validate_top_p_value"))] + pub top_p: Option, + + /// A list of tools the model may call + pub tools: Option>, + + /// Controls which (if any) tool is called by the model + pub tool_choice: Option, + + /// Effort level for reasoning models (none, minimal, low, medium, high, + /// xhigh, max) + pub reasoning_effort: Option, + + /// Whether to enable parallel function calling during tool use + pub parallel_tool_calls: Option, + + /// A unique identifier representing your end-user + pub user: Option, + + // -------- vLLM Sampling Parameters -------- + /// Use beam search instead of sampling + #[serde(default)] + pub use_beam_search: bool, + + /// Top-k sampling parameter + pub top_k: Option, + + /// Min-p nucleus sampling parameter + #[validate(range(min = 0.0, max = 1.0))] + pub min_p: Option, + + /// Repetition penalty for reducing repetitive text + #[validate(range(min = 0.0, max = 2.0))] + pub repetition_penalty: Option, + + /// Length penalty for beam search + pub length_penalty: Option, + + /// Specific token IDs to use as stop conditions + pub stop_token_ids: Option>, + + /// Include stop string in output + #[serde(default)] + pub include_stop_str_in_output: bool, + + /// Ignore end-of-sequence tokens during generation + #[serde(default)] + pub ignore_eos: bool, + + /// Minimum number of tokens to generate + #[validate(range(min = 1))] + pub min_tokens: Option, + + /// Skip special tokens during detokenization + #[serde(default = "default_true")] + pub skip_special_tokens: bool, + + /// Add spaces between special tokens during detokenization + #[serde(default = "default_true")] + pub spaces_between_special_tokens: bool, + + /// Truncate prompt tokens to this length + pub truncate_prompt_tokens: Option, + + /// Number of prompt logprobs to return + pub prompt_logprobs: Option, + + /// Restrict output to these token IDs only + pub allowed_token_ids: Option>, + + /// List of bad words to avoid during generation + pub bad_words: Option>, + + // -------- Extra vLLM Parameters -------- + /// Token budget for reasoning/thinking + pub thinking_token_budget: Option, + + /// Whether to include reasoning content in the response + #[serde(default = "default_true")] + pub include_reasoning: bool, + + /// If true, the new message will be prepended with the last message if they + /// belong to the same role. + #[serde(default)] + pub echo: bool, + + /// Whether to add the generation prompt to the chat template. + /// + /// When omitted, the request follows the API default behavior, which is + /// equivalent to `true` unless `continue_final_message=true` selects + /// final assistant continuation instead. + pub add_generation_prompt: Option, + + /// Continue generating from final assistant message + #[serde(default)] + pub continue_final_message: bool, + + /// Whether to add special tokens (e.g. BOS) to the prompt + #[serde(default)] + pub add_special_tokens: bool, + + /// Documents for RAG (retrieval-augmented generation) + pub documents: Option>, + + /// Jinja chat template override + pub chat_template: Option, + + /// Additional keyword args passed to the chat template renderer + pub chat_template_kwargs: Option>, + + /// Additional kwargs for media IO connectors, keyed by modality + pub media_io_kwargs: Option>, + + /// Additional kwargs for the HF processor + pub mm_processor_kwargs: Option>, + + /// Additional kwargs for structured outputs + pub structured_outputs: Option, + + /// Request scheduling priority (lower means earlier; default 0) + pub priority: Option, + + /// External request ID used for response correlation. + pub request_id: Option, + + /// Tokens represented as strings of the form 'token_id:{token_id}' in + /// logprobs + pub return_tokens_as_token_ids: Option, + + /// Include token IDs alongside generated text + pub return_token_ids: Option, + + /// Salt for prefix cache isolation in multi-user environments + pub cache_salt: Option, + + /// KV transfer parameters for disaggregated serving + pub kv_transfer_params: Option>, + + /// Additional request parameters with string or numeric values for custom + /// extensions + pub vllm_xargs: Option>, + + /// Parameters for detecting repetitive N-gram patterns in output tokens + pub repetition_detection: Option, +} + +impl Default for ChatCompletionRequest { + #[expect(deprecated)] + fn default() -> Self { + Self { + messages: Vec::new(), + model: default_model(), + frequency_penalty: None, + logit_bias: None, + logprobs: false, + top_logprobs: None, + max_tokens: None, + max_completion_tokens: None, + n: None, + presence_penalty: None, + response_format: None, + seed: None, + stop: None, + stream: false, + stream_options: None, + temperature: None, + top_p: None, + tools: None, + tool_choice: None, + reasoning_effort: None, + thinking_token_budget: None, + include_reasoning: true, + parallel_tool_calls: None, + user: None, + use_beam_search: false, + top_k: None, + min_p: None, + repetition_penalty: None, + length_penalty: None, + stop_token_ids: None, + include_stop_str_in_output: false, + ignore_eos: false, + min_tokens: None, + skip_special_tokens: true, + spaces_between_special_tokens: true, + truncate_prompt_tokens: None, + prompt_logprobs: None, + allowed_token_ids: None, + bad_words: None, + echo: false, + add_generation_prompt: None, + continue_final_message: false, + add_special_tokens: false, + documents: None, + chat_template: None, + chat_template_kwargs: None, + media_io_kwargs: None, + mm_processor_kwargs: None, + structured_outputs: None, + priority: None, + request_id: None, + return_tokens_as_token_ids: None, + return_token_ids: None, + cache_salt: None, + kv_transfer_params: None, + vllm_xargs: None, + repetition_detection: None, + } + } +} + +impl Normalizable for ChatCompletionRequest { + /// Normalize the request by applying migrations and defaults. + fn normalize(&mut self) { + // Migrate deprecated max_tokens → max_completion_tokens + #[expect(deprecated)] + if self.max_completion_tokens.is_none() && self.max_tokens.is_some() { + self.max_completion_tokens = self.max_tokens; + self.max_tokens = None; + } + + // Apply tool_choice defaults + // If tools is None, leave tool_choice as None (don't set it) + if self.tool_choice.is_none() + && let Some(tools) = &self.tools + { + let choice_value = if tools.is_empty() { + ToolChoiceValue::None + } else { + ToolChoiceValue::Auto + }; + self.tool_choice = Some(ToolChoice::Value(choice_value)); + } + } +} + +/// Mirrors the Python vLLM `ChatCompletionResponse` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub(super) struct ChatCompletionResponse { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Option, + pub system_fingerprint: Option, + pub prompt_logprobs: Option>>>, + pub prompt_token_ids: Option>, + pub kv_transfer_params: Option, +} + +/// Mirrors the Python vLLM `ChatCompletionResponseChoice` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub(super) struct ChatCompletionChoice { + pub index: u32, + pub message: ChatCompletionMessage, + pub logprobs: Option, + pub finish_reason: Option, + pub stop_reason: Option, + pub token_ids: Option>, +} + +/// A literal type for the "assistant" role, since the API only allows that +/// specific value in responses. +#[derive(Debug, Clone, Copy, PartialEq, Eq, SerializeDisplay)] +pub(super) struct AssistantRole; + +impl fmt::Display for AssistantRole { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("assistant") + } +} + +/// Mirrors the Python vLLM response `ChatMessage` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub(super) struct ChatCompletionMessage { + pub role: AssistantRole, + pub content: Option, + pub tool_calls: Option>, + pub reasoning: Option, +} + +/// Mirrors the Python vLLM `ChatCompletionStreamResponse` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub(super) struct ChatCompletionStreamResponse { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Option, + pub prompt_token_ids: Option>, +} + +impl ChatCompletionStreamResponse { + /// Create a stream response with the standard envelope fields pre-filled. + pub fn new(id: &str, model: &str, created: u64) -> Self { + Self { + id: id.to_string(), + object: "chat.completion.chunk".to_string(), + created, + model: model.to_string(), + choices: Vec::new(), + usage: None, + prompt_token_ids: None, + } + } +} + +/// Mirrors the Python vLLM `ChatCompletionResponseStreamChoice` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Default, Serialize)] +pub(super) struct ChatCompletionStreamChoice { + pub index: u32, + pub delta: ChatMessageDelta, + pub logprobs: Option, + pub finish_reason: Option, + pub stop_reason: Option, + pub token_ids: Option>, +} + +/// Mirrors the Python vLLM `DeltaMessage` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Default, Serialize)] +pub(super) struct ChatMessageDelta { + pub role: Option, + pub content: Option, + pub tool_calls: Option>, + pub reasoning: Option, +} + +fn default_model() -> String { + UNKNOWN_MODEL_ID.to_string() +} + +/// Validates messages array is not empty and has valid content +fn validate_messages(messages: &[ChatMessage]) -> Result<(), validator::ValidationError> { + if messages.is_empty() { + return Err(validator::ValidationError::new("messages cannot be empty")); + } + + for msg in messages { + if let ChatMessage::User { content, .. } = msg { + match content { + MessageContent::Text(text) if text.is_empty() => { + return Err(validator::ValidationError::new( + "message content cannot be empty", + )); + } + MessageContent::Parts(parts) if parts.is_empty() => { + return Err(validator::ValidationError::new( + "message content parts cannot be empty", + )); + } + _ => {} + } + } + } + Ok(()) +} + +/// Schema-level validation for cross-field dependencies +fn validate_chat_cross_parameters( + req: &ChatCompletionRequest, +) -> Result<(), validator::ValidationError> { + // 1. Validate logprobs dependency + if req.top_logprobs.is_some() && !req.logprobs { + let mut e = validator::ValidationError::new("top_logprobs_requires_logprobs"); + e.message = Some("top_logprobs is only allowed when logprobs is enabled".into()); + return Err(e); + } + + // 2. Validate stream_options dependency + if req.stream_options.is_some() && !req.stream { + let mut e = validator::ValidationError::new("stream_options_requires_stream"); + e.message = Some("stream_options can only be used when stream is true".into()); + return Err(e); + } + + // 3. Validate token limits - min <= max + if let (Some(min_tokens), Some(max_completion_tokens)) = + (req.min_tokens, req.max_completion_tokens) + && min_tokens > max_completion_tokens + { + let mut e = validator::ValidationError::new("min_tokens_exceeds_max_completion_tokens"); + e.message = Some("min_tokens cannot be greater than max_completion_tokens".into()); + return Err(e); + } + + #[expect(deprecated, reason = "Local type still mirrors legacy upstream field")] + if let (Some(min_tokens), Some(max_tokens)) = (req.min_tokens, req.max_tokens) + && min_tokens > max_tokens + { + let mut e = validator::ValidationError::new("min_tokens_exceeds_max_tokens"); + e.message = Some("min_tokens cannot be greater than max_tokens".into()); + return Err(e); + } + + // 4. Validate response format JSON schema name + if let Some(ResponseFormat::JsonSchema { json_schema }) = &req.response_format + && json_schema.name.is_empty() + { + let mut e = validator::ValidationError::new("json_schema_name_empty"); + e.message = Some("JSON schema name cannot be empty".into()); + return Err(e); + } + + // 5. Validate tool_choice requires tools (except for "none") + if let Some(ref tool_choice) = req.tool_choice { + let has_tools = req.tools.as_ref().is_some_and(|t| !t.is_empty()); + + // Check if tool_choice is anything other than "none" + let is_some_choice = !matches!(tool_choice, ToolChoice::Value(ToolChoiceValue::None)); + + if is_some_choice && !has_tools { + let mut e = validator::ValidationError::new("tool_choice_requires_tools"); + e.message = Some("Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.".into()); + return Err(e); + } + + // Additional validation when tools are present + if let Some(tools) = req.tools.as_ref().filter(|t| !t.is_empty()) { + match tool_choice { + ToolChoice::Function { function, .. } => { + // Validate that the specified function name exists in tools + let function_exists = tools.iter().any(|tool| { + tool.tool_type == "function" && tool.function.name == function.name + }); + + if !function_exists { + let mut e = + validator::ValidationError::new("tool_choice_function_not_found"); + e.message = Some( + format!( + "Invalid value for 'tool_choice': function '{}' not found in 'tools'.", + function.name + ) + .into(), + ); + return Err(e); + } + } + ToolChoice::AllowedTools { + mode, + tools: allowed_tools, + .. + } => { + // Validate mode is "auto" or "required" + if mode != "auto" && mode != "required" { + let mut e = validator::ValidationError::new("tool_choice_invalid_mode"); + e.message = Some(format!( + "Invalid value for 'tool_choice.mode': must be 'auto' or 'required', got '{mode}'." + ).into()); + return Err(e); + } + + // Validate that all ToolReferences are Function type (Chat API only supports + // function tools) + for tool_ref in allowed_tools { + match tool_ref { + ToolReference::Function { name } => { + // Validate that the function exists in tools array + let tool_exists = tools.iter().any(|tool| { + tool.tool_type == "function" && tool.function.name == *name + }); + + if !tool_exists { + let mut e = validator::ValidationError::new( + "tool_choice_tool_not_found", + ); + e.message = Some( + format!( + "Invalid value for 'tool_choice.tools': tool '{name}' not found in 'tools'." + ) + .into(), + ); + return Err(e); + } + } + _ => { + // Chat Completion API only supports function tools in tool_choice + let mut e = validator::ValidationError::new( + "tool_choice_invalid_tool_type", + ); + e.message = Some( + format!( + "Invalid value for 'tool_choice.tools': Chat Completion API only supports function tools, got '{}'.", + tool_ref.identifier() + ) + .into(), + ); + return Err(e); + } + } + } + } + ToolChoice::Value(_) => {} + } + } + } + + Ok(()) +} diff --git a/rust/src/server/src/routes/openai/chat_completions/validate.rs b/rust/src/server/src/routes/openai/chat_completions/validate.rs new file mode 100644 index 00000000000..fbd10eea0cb --- /dev/null +++ b/rust/src/server/src/routes/openai/chat_completions/validate.rs @@ -0,0 +1,410 @@ +use super::types::ChatCompletionRequest; +use crate::error::{ApiError, bail_invalid_request}; +use crate::routes::openai::utils::types::{ChatMessage, Tool, ToolChoice, ToolChoiceValue}; + +/// Enforce the minimal compatibility contract for the Rust OpenAI server. +pub(super) fn validate_request_compat( + request: &ChatCompletionRequest, + served_model_names: &[String], +) -> Result<(), ApiError> { + if !served_model_names.iter().any(|n| n == &request.model) { + return Err(ApiError::model_not_found(request.model.clone())); + } + + if request.stream_options.is_some() && !request.stream { + bail_invalid_request!( + param = "stream_options", + "stream_options are only supported when stream=true." + ); + } + + if request.n.unwrap_or(1) > 1 { + bail_invalid_request!(param = "n", "Only n=1 is supported."); + } + + if request.top_logprobs.is_some() && !request.logprobs { + bail_invalid_request!( + param = "top_logprobs", + "top_logprobs can only be used when logprobs=true." + ); + } + + if let Some(prompt_logprobs) = request.prompt_logprobs { + if prompt_logprobs < 0 && prompt_logprobs != -1 { + bail_invalid_request!( + param = "prompt_logprobs", + "prompt_logprobs must be a non-negative value or -1." + ); + } + + if request.stream && (prompt_logprobs > 0 || prompt_logprobs == -1) { + bail_invalid_request!( + param = "prompt_logprobs", + "prompt_logprobs are not available when stream=true." + ); + } + } + + if let Some(tools) = request.tools.as_ref() { + validate_function_tools(tools, "tools")?; + } + + for message in &request.messages { + if let ChatMessage::Developer { + tools: Some(tools), .. + } = message + { + validate_function_tools(tools, "messages[].tools")?; + } + } + + if let Some(tool_choice) = &request.tool_choice { + match tool_choice { + ToolChoice::Value(ToolChoiceValue::Auto | ToolChoiceValue::None) => {} + ToolChoice::Value(ToolChoiceValue::Required) => { + bail_invalid_request!( + param = "tool_choice", + "tool_choice=required is not supported yet." + ); + } + ToolChoice::Function { .. } => { + bail_invalid_request!( + param = "tool_choice", + "Named function tool_choice is not supported yet." + ); + } + ToolChoice::AllowedTools { .. } => { + bail_invalid_request!( + param = "tool_choice", + "allowed_tools tool_choice is not supported yet." + ); + } + } + } + + if request.use_beam_search { + bail_invalid_request!( + param = "use_beam_search", + "use_beam_search is not supported." + ); + } + + // ---- Reject parameters that are accepted for deserialization but not yet + // implemented ---- + + if request.parallel_tool_calls.is_some() { + bail_invalid_request!( + param = "parallel_tool_calls", + "parallel_tool_calls is not supported." + ); + } + + reject_non_default( + request.length_penalty.as_ref(), + "length_penalty", + "length_penalty is not supported.", + )?; + if !request.spaces_between_special_tokens { + bail_invalid_request!( + param = "spaces_between_special_tokens", + "spaces_between_special_tokens is not supported." + ); + } + reject_non_default( + request.truncate_prompt_tokens.as_ref(), + "truncate_prompt_tokens", + "truncate_prompt_tokens is not supported.", + )?; + reject_non_default( + request.thinking_token_budget.as_ref(), + "thinking_token_budget", + "thinking_token_budget is not supported.", + )?; + if !request.include_reasoning { + bail_invalid_request!( + param = "include_reasoning", + "include_reasoning is not supported." + ); + } + reject_non_default( + request.media_io_kwargs.as_ref(), + "media_io_kwargs", + "media_io_kwargs is not supported.", + )?; + reject_non_default( + request.mm_processor_kwargs.as_ref(), + "mm_processor_kwargs", + "mm_processor_kwargs is not supported.", + )?; + reject_non_default( + request.repetition_detection.as_ref(), + "repetition_detection", + "repetition_detection is not supported.", + )?; + + if let Some(options) = &request.stream_options + && options.continuous_usage_stats.is_some() + { + bail_invalid_request!( + param = "stream_options", + "continuous_usage_stats is not supported." + ); + } + + Ok(()) +} + +/// Reject one option unless it is entirely absent. +fn reject_non_default( + value: Option<&T>, + param: &'static str, + message: &str, +) -> Result<(), ApiError> { + if value.is_some() { + bail_invalid_request!(param = param, "{}", message); + } + Ok(()) +} + +fn validate_function_tools(tools: &[Tool], param: &'static str) -> Result<(), ApiError> { + for tool in tools { + if tool.tool_type != "function" { + bail_invalid_request!(param = param, "Only function tools are supported."); + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use serde_json::json; + use vllm_chat::ReasoningEffort; + + use super::validate_request_compat; + use crate::routes::openai::chat_completions::types::ChatCompletionRequest; + use crate::routes::openai::utils::structured_outputs::ResponseFormat; + use crate::routes::openai::utils::types::{ + ChatMessage, Function, FunctionChoice, MessageContent, StringOrArray, Tool, ToolChoice, + ToolChoiceValue, ToolReference, + }; + + fn served(names: &[&str]) -> Vec { + names.iter().map(|s| s.to_string()).collect() + } + + fn base_request() -> ChatCompletionRequest { + ChatCompletionRequest { + model: "Qwen/Qwen1.5-0.5B-Chat".to_string(), + messages: vec![ChatMessage::User { + content: MessageContent::Text("hello".to_string()), + name: None, + }], + stream: true, + ..Default::default() + } + } + + #[test] + fn validate_request_compat_accepts_stop() { + let request = ChatCompletionRequest { + stop: Some(StringOrArray::String("stop".to_string())), + ..base_request() + }; + + validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])) + .expect("stop strings should be accepted"); + } + + #[test] + fn validate_request_compat_accepts_non_zero_penalties_and_function_tools() { + let request = ChatCompletionRequest { + frequency_penalty: Some(0.5), + presence_penalty: Some(0.25), + min_p: Some(0.2), + repetition_penalty: Some(1.1), + seed: Some(7), + ..base_request() + }; + validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])) + .expect("sampling fields should be accepted"); + + let request = ChatCompletionRequest { + tools: Some(vec![Tool { + tool_type: "function".to_string(), + function: Function { + name: "tool".to_string(), + description: None, + parameters: json!({}), + strict: None, + }, + }]), + ..base_request() + }; + validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])) + .expect("function tools should be accepted"); + + let request = ChatCompletionRequest { + messages: vec![ChatMessage::Developer { + content: MessageContent::Text("policy".to_string()), + tools: Some(vec![Tool { + tool_type: "function".to_string(), + function: Function { + name: "tool".to_string(), + description: None, + parameters: json!({}), + strict: None, + }, + }]), + name: None, + }], + ..base_request() + }; + validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])) + .expect("developer function tools should be accepted"); + } + + #[test] + fn validate_request_compat_rejects_non_function_developer_tools() { + let request = ChatCompletionRequest { + messages: vec![ChatMessage::Developer { + content: MessageContent::Text("policy".to_string()), + tools: Some(vec![Tool { + tool_type: "mcp".to_string(), + function: Function { + name: "tool".to_string(), + description: None, + parameters: json!({}), + strict: None, + }, + }]), + name: None, + }], + ..base_request() + }; + + assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err()); + } + + #[test] + fn validate_request_compat_accepts_output_logprobs() { + let request = ChatCompletionRequest { + logprobs: true, + ..base_request() + }; + validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])) + .expect("logprobs should be accepted"); + } + + #[test] + fn validate_request_compat_accepts_reasoning_effort() { + let request = ChatCompletionRequest { + reasoning_effort: Some(ReasoningEffort::Max), + chat_template_kwargs: Some(HashMap::from([( + "reasoning_effort".to_string(), + json!("low"), + )])), + ..base_request() + }; + + validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])) + .expect("reasoning_effort should be accepted"); + } + + #[test] + fn validate_request_compat_rejects_top_logprobs_without_logprobs() { + let request = ChatCompletionRequest { + top_logprobs: Some(0), + ..base_request() + }; + assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err()); + } + + #[test] + fn validate_request_compat_rejects_streaming_prompt_logprobs_requests() { + let request = ChatCompletionRequest { + prompt_logprobs: Some(1), + ..base_request() + }; + assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err()); + + let request = ChatCompletionRequest { + prompt_logprobs: Some(-1), + ..base_request() + }; + assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err()); + } + + #[test] + fn validate_request_compat_rejects_invalid_prompt_logprobs_value() { + let request = ChatCompletionRequest { + stream: false, + prompt_logprobs: Some(-2), + ..base_request() + }; + assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err()); + } + + #[test] + fn validate_request_compat_accepts_response_format() { + let request = ChatCompletionRequest { + response_format: Some(ResponseFormat::Text), + ..base_request() + }; + validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])) + .expect("response_format=text should be accepted"); + + let request = ChatCompletionRequest { + response_format: Some(ResponseFormat::JsonObject), + ..base_request() + }; + validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])) + .expect("response_format=json_object should be accepted"); + } + + #[test] + fn validate_request_compat_accepts_noop_tool_choice_none() { + let request = ChatCompletionRequest { + tool_choice: Some(ToolChoice::Value(ToolChoiceValue::None)), + ..base_request() + }; + + validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])) + .expect("tool_choice=none is ok"); + } + + #[test] + fn validate_request_compat_rejects_required_and_named_tool_choices() { + let required = ChatCompletionRequest { + tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Required)), + ..base_request() + }; + assert!(validate_request_compat(&required, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err()); + + let named = ChatCompletionRequest { + tool_choice: Some(ToolChoice::Function { + tool_type: "function".to_string(), + function: FunctionChoice { + name: "tool".to_string(), + }, + }), + ..base_request() + }; + assert!(validate_request_compat(&named, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err()); + + let allowed_tools = ChatCompletionRequest { + tool_choice: Some(ToolChoice::AllowedTools { + tool_type: "allowed_tools".to_string(), + mode: "auto".to_string(), + tools: vec![ToolReference::Function { + name: "tool".to_string(), + }], + }), + ..base_request() + }; + assert!( + validate_request_compat(&allowed_tools, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err() + ); + } +} diff --git a/rust/src/server/src/routes/openai/completions.rs b/rust/src/server/src/routes/openai/completions.rs new file mode 100644 index 00000000000..33813e67687 --- /dev/null +++ b/rust/src/server/src/routes/openai/completions.rs @@ -0,0 +1,571 @@ +mod convert; +mod types; +mod validate; + +use std::convert::Infallible; +use std::result::Result; +use std::sync::Arc; + +use asynk_strim_attr::{TryYielder, try_stream}; +use axum::Json; +use axum::extract::State; +use axum::http::HeaderMap; +use axum::response::sse::{Event, Sse}; +use axum::response::{IntoResponse, Response}; +use futures::{Stream, StreamExt as _, pin_mut}; +use thiserror_ext::AsReport as _; +use tracing::{debug, error, info, trace}; +use tracing_futures::Instrument as _; +use vllm_text::{DecodedTextEvent, FinishReason, TextOutputStream, TextOutputStreamExt as _}; + +use super::utils::logprobs::{ + collected_logprobs_to_openai, decoded_logprobs_to_openai, decoded_prompt_logprobs_to_maps, + text_len, +}; +use super::utils::types::Usage; +use crate::error::{ApiError, bail_server_error, server_error}; +use crate::routes::openai::completions::convert::prepare_completion_request; +use crate::routes::openai::completions::types::{ + CompletionChoice, CompletionRequest, CompletionResponse, CompletionSseChunk, + CompletionStreamChoice, CompletionStreamResponse, +}; +use crate::routes::openai::utils::types::LogProbs; +use crate::routes::openai::utils::validated_json::ValidatedJson; +use crate::state::AppState; +use crate::utils::{resolve_request_context, unix_timestamp}; + +/// Validate one completions request and proxy it into the shared `vllm-text` +/// stack. +pub async fn completions( + State(state): State>, + headers: HeaderMap, + ValidatedJson(body): ValidatedJson, +) -> Response { + let stream = body.stream; + let logprobs = body.logprobs; + let request_context = resolve_request_context(&headers, body.request_id.as_deref()); + + let prepared = + match prepare_completion_request(body, state.served_model_names(), request_context) { + Ok(prepared) => prepared, + Err(error) => return error.into_response(), + }; + let request_span = tracing::info_span!( + "completions", + request_id = %prepared.request_id, + engine_request_id = tracing::field::Empty, + ); + + let created = unix_timestamp(); + let include_prompt_logprobs = prepared.text_request.sampling_params.prompt_logprobs.is_some(); + let log_request = state.enable_log_requests; + + let text_stream = match state + .chat + .text() + .generate(prepared.text_request) + .instrument(request_span.clone()) + .await + { + Ok(stream) => stream, + Err(error) => { + return server_error!( + "failed to submit completion request: {}", + error.to_report_string() + ) + .into_response(); + } + }; + + if stream { + let chunk_stream = completion_chunk_stream( + text_stream, + prepared.request_id, + prepared.response_model, + created, + log_request, + prepared.include_usage, + prepared.echo, + logprobs, + prepared.return_token_ids, + prepared.return_tokens_as_token_ids, + ); + let sse_stream = completion_sse_stream(chunk_stream).instrument(request_span); + + Sse::new(sse_stream).into_response() + } else { + let response = match collect_completion( + text_stream, + prepared.request_id, + prepared.response_model, + created, + prepared.echo, + logprobs, + include_prompt_logprobs, + prepared.return_token_ids, + prepared.return_tokens_as_token_ids, + ) + .instrument(request_span.clone()) + .await + { + Ok(response) => response, + Err(error) => return error.into_response(), + }; + + if log_request { + let usage = response.usage.as_ref(); + info!( + parent: &request_span, + model = %response.model, + prompt_tokens = usage.map_or(0, |u| u.prompt_tokens), + output_tokens = usage.and_then(|u| u.completion_tokens).unwrap_or(0), + finish_reason = response.choices.first().and_then(|c| c.finish_reason.as_deref()).unwrap_or("unknown"), + "completion finished" + ); + } + + Json(response).into_response() + } +} + +async fn collect_completion( + stream: impl TextOutputStream, + request_id: String, + response_model: String, + created: u64, + echo: Option, + requested_logprobs: Option, + include_prompt_logprobs: bool, + return_token_ids: bool, + return_tokens_as_token_ids: bool, +) -> Result { + let collected = stream + .collect_output() + .await + .map_err(|error| server_error!("completion stream failed: {}", error.to_report_string()))?; + let finish_reason = collected.finish_reason.clone(); + let stop_reason = finish_reason + .as_stop_reason() + .map(|sr| serde_json::to_value(sr).expect("StopReason must serialize to JSON")); + + let prompt_char_count = echo.as_ref().map(|prompt| text_len(prompt)).unwrap_or_default(); + let prompt_logprobs = if include_prompt_logprobs { + let prompt_logprobs = collected.prompt_logprobs.as_ref().ok_or_else(|| { + server_error!( + "completion response requested prompt_logprobs but generation returned none" + ) + })?; + Some(prompt_logprobs) + } else { + None + }; + let logprobs = if requested_logprobs.is_some() { + Some(collected_logprobs_to_openai( + &collected, + echo.is_some(), + prompt_char_count, + return_tokens_as_token_ids, + )?) + } else { + None + }; + let prompt_logprobs = + prompt_logprobs.map(|lp| decoded_prompt_logprobs_to_maps(lp, return_tokens_as_token_ids)); + let text = match &echo { + None => collected.text, + Some(prompt) => format!("{prompt}{}", collected.text), + }; + + Ok(CompletionResponse { + id: request_id, + object: "text_completion".to_string(), + created, + model: response_model, + choices: vec![CompletionChoice { + index: 0, + text, + logprobs, + finish_reason: Some(completion_finish_reason_to_openai(finish_reason)?.into()), + stop_reason, + prompt_logprobs, + token_ids: return_token_ids.then(|| collected.token_ids.clone()), + prompt_token_ids: return_token_ids.then(|| collected.prompt_token_ids.to_vec()), + }], + usage: Some(Usage::from_counts( + collected.prompt_token_ids.len() as u32, + collected.token_ids.len() as u32, + )), + system_fingerprint: None, + kv_transfer_params: collected.kv_transfer_params, + }) +} + +/// Convert one internal decoded-text stream into OpenAI completions chunks. +#[try_stream] +async fn completion_chunk_stream( + stream: impl TextOutputStream, + request_id: String, + response_model: String, + created: u64, + log_request: bool, + include_usage: bool, + echo: Option, + requested_logprobs: Option, + return_token_ids: bool, + return_tokens_as_token_ids: bool, + mut y: TryYielder, +) -> Result<(), ApiError> { + pin_mut!(stream); + let mut visible_text_len = 0_u32; + let mut first_chunk = true; + + while let Some(next) = stream.next().await { + match next { + Ok(DecodedTextEvent::Start { + prompt_token_ids, .. + }) => { + debug!("completion stream started"); + if let Some(prompt) = echo.as_ref() { + visible_text_len = text_len(prompt); + let mut chunk = + delta_chunk(&request_id, &response_model, created, prompt.clone(), None); + if return_token_ids && first_chunk { + if let Some(choice) = chunk.choices.first_mut() { + choice.prompt_token_ids = Some(prompt_token_ids.to_vec()); + } + first_chunk = false; + } + y.yield_ok(CompletionSseChunk::Chunk(chunk)).await; + } else if return_token_ids { + // Emit a chunk with prompt_token_ids in the first streaming response + let mut chunk = + delta_chunk(&request_id, &response_model, created, String::new(), None); + if let Some(choice) = chunk.choices.first_mut() { + choice.prompt_token_ids = Some(prompt_token_ids.to_vec()); + } + first_chunk = false; + y.yield_ok(CompletionSseChunk::Chunk(chunk)).await; + } + } + Ok(DecodedTextEvent::TextDelta { + delta, + token_ids, + logprobs, + finished, + }) => { + let delta_text_len = text_len(&delta); + let logprobs = if requested_logprobs.is_some() { + let decoded_logprobs = logprobs.as_ref().ok_or_else(|| { + server_error!( + "completion stream requested logprobs but generation returned none" + ) + })?; + Some(decoded_logprobs_to_openai( + decoded_logprobs, + visible_text_len, + return_tokens_as_token_ids, + )?) + } else { + None + }; + let mut chunk = delta_chunk(&request_id, &response_model, created, delta, logprobs); + if return_token_ids && let Some(choice) = chunk.choices.first_mut() { + choice.token_ids = Some(token_ids); + } + y.yield_ok(CompletionSseChunk::Chunk(chunk)).await; + visible_text_len = visible_text_len.saturating_add(delta_text_len); + + if let Some(finished) = finished { + if log_request { + info!( + stream = true, + model = %response_model, + prompt_tokens = finished.prompt_token_count, + output_tokens = finished.output_token_count, + finish_reason = finished.finish_reason.as_str(), + "completion finished" + ); + } + y.yield_ok(CompletionSseChunk::Chunk(final_chunk( + &request_id, + &response_model, + created, + finished.finish_reason, + )?)) + .await; + + if include_usage { + y.yield_ok(CompletionSseChunk::Usage(usage_chunk( + &request_id, + &response_model, + created, + Usage::from_counts( + finished.prompt_token_count as u32, + finished.output_token_count as u32, + ), + ))) + .await; + } + } + } + Err(error) => { + error!( + error = %error.as_report(), + "completion stream failed" + ); + bail_server_error!("{}", error.to_report_string()); + } + } + } + Ok(()) +} + +fn delta_chunk( + request_id: &str, + response_model: &str, + created: u64, + text: String, + logprobs: Option, +) -> CompletionStreamResponse { + let mut chunk = CompletionStreamResponse::new(request_id, response_model, created); + chunk.choices.push(CompletionStreamChoice { + text, + logprobs, + ..Default::default() + }); + chunk +} + +fn final_chunk( + request_id: &str, + response_model: &str, + created: u64, + finish_reason: FinishReason, +) -> Result { + let finish_reason = completion_finish_reason_to_openai(finish_reason)?; + + let mut chunk = CompletionStreamResponse::new(request_id, response_model, created); + chunk.choices.push(CompletionStreamChoice { + finish_reason: Some(finish_reason.to_string()), + ..Default::default() + }); + Ok(chunk) +} + +fn completion_finish_reason_to_openai( + finish_reason: FinishReason, +) -> Result<&'static str, ApiError> { + match finish_reason { + FinishReason::Stop(_) | FinishReason::Repetition => Ok("stop"), + FinishReason::Length => Ok("length"), + FinishReason::Abort => Ok("abort"), + FinishReason::Error => { + bail_server_error!("Internal server error"); + } + } +} + +fn usage_chunk( + request_id: &str, + response_model: &str, + created: u64, + usage: Usage, +) -> CompletionStreamResponse { + let mut chunk = CompletionStreamResponse::new(request_id, response_model, created); + chunk.usage = Some(usage); + chunk +} + +/// Convert one chunk stream into OpenAI-style SSE events. +/// +/// OpenAI-style streaming errors are encoded as ordinary `data: {"error": ...}` +/// events followed by `data: [DONE]`, so the transport stream itself stays +/// infallible even when generation fails after the HTTP response has started. +#[try_stream] +async fn completion_sse_stream( + stream: impl Stream>, + mut y: TryYielder, +) -> Result<(), Infallible> { + pin_mut!(stream); + + while let Some(next) = stream.next().await { + match next { + Ok(chunk) => y.yield_ok(to_sse_event(&chunk)).await, + Err(error) => { + y.yield_ok(to_error_sse_event(&error)).await; + break; + } + } + } + + y.yield_ok(done_sse_event()).await; + Ok(()) +} + +/// Serialize one OpenAI chunk payload into one SSE `data:` event. +fn to_sse_event(chunk: &CompletionSseChunk) -> Event { + let payload = serde_json::to_string(chunk).expect("completion chunk must serialize to JSON"); + trace!(payload, "completion emitting chunk"); + Event::default().data(payload) +} + +/// Serialize one OpenAI error payload into one SSE `data:` event. +fn to_error_sse_event(error: &ApiError) -> Event { + let payload = serde_json::to_string(&error.to_error_response()) + .expect("ErrorResponse must serialize to JSON"); + trace!(payload, "completion emitting error"); + Event::default().data(payload) +} + +/// Build the terminal OpenAI SSE sentinel event. +fn done_sse_event() -> Event { + trace!("completion emitting done"); + Event::default().data("[DONE]") +} + +#[cfg(test)] +mod tests { + use futures::{StreamExt as _, stream}; + use itertools::Itertools as _; + use vllm_text::{ + DecodedLogprobs, DecodedPositionLogprobs, DecodedTextEvent, DecodedTokenLogprob, + FinishReason, Finished, + }; + + use super::{CompletionSseChunk, completion_chunk_stream, final_chunk}; + + #[test] + fn final_chunk_maps_stop_finish_reason() { + let chunk = final_chunk("cmpl-1", "model", 1, FinishReason::stop_eos()) + .expect("finish reason valid"); + assert_eq!(chunk.choices[0].finish_reason.as_deref(), Some("stop")); + assert_eq!(chunk.choices[0].text, ""); + } + + #[test] + fn final_chunk_maps_length_finish_reason() { + let chunk = + final_chunk("cmpl-1", "model", 1, FinishReason::Length).expect("finish reason valid"); + assert_eq!(chunk.choices[0].finish_reason.as_deref(), Some("length")); + } + + #[test] + fn final_chunk_maps_abort_finish_reason() { + let chunk = + final_chunk("cmpl-1", "model", 1, FinishReason::Abort).expect("finish reason valid"); + assert_eq!(chunk.choices[0].finish_reason.as_deref(), Some("abort")); + } + + #[test] + fn final_chunk_rejects_error_finish_reason() { + assert!(final_chunk("cmpl-1", "model", 1, FinishReason::Error).is_err()); + } + + #[tokio::test] + async fn completion_chunk_stream_maps_streaming_logprobs() { + let stream = stream::iter(vec![ + Ok(DecodedTextEvent::Start { + prompt_token_ids: vec![1, 2, 3, 4, 5].into(), + prompt_logprobs: None, + }), + Ok(DecodedTextEvent::TextDelta { + delta: "h".to_string(), + token_ids: vec![b'h' as u32], + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![ + DecodedTokenLogprob { + token_id: 0, + token: "h".to_string(), + logprob: -0.1, + rank: 1, + }, + DecodedTokenLogprob { + token_id: 0, + token: "H".to_string(), + logprob: -0.2, + rank: 1, + }, + ], + }], + }), + finished: None, + }), + Ok(DecodedTextEvent::TextDelta { + delta: String::new(), + token_ids: vec![b'!' as u32], + logprobs: Some(DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![ + DecodedTokenLogprob { + token_id: 0, + token: "!".to_string(), + logprob: -0.3, + rank: 1, + }, + DecodedTokenLogprob { + token_id: 0, + token: "?".to_string(), + logprob: -0.4, + rank: 1, + }, + ], + }], + }), + finished: Some(Finished { + prompt_token_count: 5, + output_token_count: 2, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + }), + ]); + + let chunks = completion_chunk_stream( + stream, + "cmpl-1".to_string(), + "model".to_string(), + 1, + false, + false, + None, + Some(1), + false, + false, + ) + .collect::>() + .await; + + let chunks: Vec<_> = chunks.into_iter().try_collect().expect("stream should succeed"); + + match &chunks[0] { + CompletionSseChunk::Chunk(chunk) => { + assert_eq!(chunk.choices[0].text, "h"); + assert_eq!( + chunk.choices[0].logprobs.as_ref().expect("logprobs").tokens, + vec!["h".to_string()] + ); + assert_eq!( + chunk.choices[0].logprobs.as_ref().expect("logprobs").text_offset, + vec![0] + ); + } + CompletionSseChunk::Usage(_) => panic!("expected regular chunk"), + } + + match &chunks[1] { + CompletionSseChunk::Chunk(chunk) => { + assert_eq!(chunk.choices[0].text, ""); + assert_eq!( + chunk.choices[0].logprobs.as_ref().expect("logprobs").tokens, + vec!["!".to_string()] + ); + assert_eq!( + chunk.choices[0].logprobs.as_ref().expect("logprobs").text_offset, + vec![1] + ); + } + CompletionSseChunk::Usage(_) => panic!("expected regular chunk"), + } + } +} diff --git a/rust/src/server/src/routes/openai/completions/convert.rs b/rust/src/server/src/routes/openai/completions/convert.rs new file mode 100644 index 00000000000..066c4c046f4 --- /dev/null +++ b/rust/src/server/src/routes/openai/completions/convert.rs @@ -0,0 +1,352 @@ +use vllm_text::{SamplingParams, TextDecodeOptions, TextRequest}; + +use super::types::CompletionRequest; +use crate::error::ApiError; +use crate::routes::openai::completions::validate; +use crate::routes::openai::utils::structured_outputs::convert_from_response_format_value; +use crate::utils::{ResolvedRequestContext, convert_logit_bias, merge_kv_transfer_params}; + +/// Lowered completion request plus the public response metadata carried by +/// every SSE chunk. +#[derive(Debug, Clone, PartialEq)] +pub struct PreparedRequest { + /// Stable OpenAI-style request ID, reused as the external text request ID. + pub request_id: String, + /// Public model ID echoed back to the client. + pub response_model: String, + /// Whether the caller asked for the final streamed usage chunk. + pub include_usage: bool, + /// Lowered text request for the shared `vllm-text` facade. + pub text_request: TextRequest, + /// Original text prompt that should be echoed back northbound when + /// `echo=true`. + pub echo: Option, + /// Whether to include token IDs alongside generated text. + pub return_token_ids: bool, + /// Whether to format logprob tokens as `token_id:{id}`. + pub return_tokens_as_token_ids: bool, +} + +/// Validate and lower one OpenAI completions request into the internal +/// text-generation format. +/// +/// `served_model_names` must be non-empty; the first entry is used as the +/// `model` field in responses. +pub(crate) fn prepare_completion_request( + request: CompletionRequest, + served_model_names: &[String], + ctx: ResolvedRequestContext, +) -> Result { + validate::validate_request_compat(&request, served_model_names)?; + + let request_id = format!("cmpl-{}", ctx.request_id); + + let logprobs = match request.logprobs { + Some(logprobs) => Some(i32::try_from(logprobs).map_err(|_| { + ApiError::invalid_request( + "`logprobs` must fit within a signed 32-bit integer.".to_string(), + Some("logprobs"), + ) + })?), + None => None, + }; + let prompt_logprobs = request.prompt_logprobs.or(if request.echo && !request.stream { + logprobs + } else { + None + }); + let include_usage = (request.stream_options.as_ref()) + .and_then(|options| options.include_usage) + .unwrap_or(false); + let echo = request.echo.then(|| request.prompt.as_text().cloned()).flatten(); + + let structured_outputs = + convert_from_response_format_value(&request.response_format, &request.structured_outputs)?; + + let text_request = TextRequest { + request_id: request_id.clone(), + prompt: request.prompt, + mm_features: None, + sampling_params: SamplingParams { + temperature: request.temperature, + top_p: request.top_p, + top_k: request.top_k, + seed: request.seed, + max_tokens: request.max_tokens, + min_tokens: request.min_tokens, + logprobs, + prompt_logprobs, + min_p: request.min_p, + frequency_penalty: request.frequency_penalty, + presence_penalty: request.presence_penalty, + repetition_penalty: request.repetition_penalty, + stop_token_ids: request.stop_token_ids, + ignore_eos: request.ignore_eos, + logit_bias: convert_logit_bias(request.logit_bias)?, + allowed_token_ids: request.allowed_token_ids, + bad_words: None, + logprob_token_ids: None, + structured_outputs, + skip_reading_prefix_cache: None, + vllm_xargs: merge_kv_transfer_params( + request.vllm_xargs, + request.kv_transfer_params.as_ref(), + ), + }, + decode_options: TextDecodeOptions { + skip_special_tokens: request.skip_special_tokens, + include_stop_str_in_output: request.include_stop_str_in_output, + stop_strings: request.stop.map(|stop| stop.into_vec()), + min_tokens: request.min_tokens.unwrap_or(0), + }, + intermediate: request.stream, + priority: request.priority.unwrap_or(0), + cache_salt: request.cache_salt, + add_special_tokens: request.add_special_tokens, + data_parallel_rank: ctx.data_parallel_rank, + }; + + Ok(PreparedRequest { + request_id, + response_model: served_model_names.first().cloned().unwrap_or_default(), + include_usage, + text_request, + echo, + return_token_ids: request.return_token_ids.unwrap_or(false), + return_tokens_as_token_ids: request.return_tokens_as_token_ids.unwrap_or(false), + }) +} + +#[cfg(test)] +mod tests { + use axum::http::HeaderMap; + use serde_json::json; + use vllm_text::Prompt; + + use super::prepare_completion_request; + use crate::routes::openai::completions::types::CompletionRequest; + use crate::utils::{ResolvedRequestContext, resolve_request_context}; + + fn request_context(headers: &HeaderMap, request_id: Option<&str>) -> ResolvedRequestContext { + resolve_request_context(headers, request_id) + } + + fn served(names: &[&str]) -> Vec { + names.iter().map(|s| s.to_string()).collect() + } + + fn base_request_json() -> serde_json::Value { + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": true + }) + } + + #[test] + fn completion_http_request_deserializes_text_prompt() { + let request: CompletionRequest = + serde_json::from_value(base_request_json()).expect("parse request"); + + assert_eq!(request.prompt, Prompt::Text("hello".to_string())); + assert_eq!(request.model, "Qwen/Qwen1.5-0.5B-Chat"); + } + + #[test] + fn completion_http_request_deserializes_token_id_prompt() { + let request: CompletionRequest = serde_json::from_value(json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": [11, 22, 33], + "stream": true, + "ignore_eos": true, + "max_tokens": 7 + })) + .expect("parse request"); + + assert_eq!(request.prompt, Prompt::TokenIds(vec![11, 22, 33])); + assert_eq!(request.max_tokens, Some(7)); + assert!(request.ignore_eos); + } + + #[test] + fn prepare_completion_request_maps_sampling_fields() { + let request: CompletionRequest = serde_json::from_value(json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": [11, 22, 33], + "stream": true, + "stream_options": {"include_usage": true}, + "max_tokens": 7, + "logprobs": 2, + "top_p": 0.9, + "top_k": 42, + "min_p": 0.1, + "frequency_penalty": 0.2, + "presence_penalty": 0.3, + "repetition_penalty": 1.1, + "ignore_eos": true, + "skip_special_tokens": false + })) + .expect("parse request"); + + let prepared = prepare_completion_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("prepare"); + + assert!(prepared.include_usage); + assert_eq!( + prepared.text_request.prompt, + Prompt::TokenIds(vec![11, 22, 33]) + ); + assert_eq!(prepared.text_request.sampling_params.max_tokens, Some(7)); + assert_eq!(prepared.text_request.sampling_params.logprobs, Some(2)); + assert_eq!(prepared.text_request.sampling_params.top_p, Some(0.9)); + assert_eq!(prepared.text_request.sampling_params.top_k, Some(42)); + assert_eq!(prepared.text_request.sampling_params.min_p, Some(0.1)); + assert_eq!( + prepared.text_request.sampling_params.frequency_penalty, + Some(0.2) + ); + assert_eq!( + prepared.text_request.sampling_params.presence_penalty, + Some(0.3) + ); + assert_eq!( + prepared.text_request.sampling_params.repetition_penalty, + Some(1.1) + ); + assert!(prepared.text_request.sampling_params.ignore_eos); + assert!(!prepared.text_request.decode_options.skip_special_tokens); + } + + #[test] + fn prepare_completion_request_accepts_text_echo() { + let request: CompletionRequest = serde_json::from_value(json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": true, + "echo": true, + "max_tokens": 7 + })) + .expect("parse request"); + + let prepared = prepare_completion_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("prepare"); + + assert_eq!(prepared.echo, Some("hello".to_string())); + assert_eq!(prepared.text_request.sampling_params.max_tokens, Some(7)); + } + + #[test] + fn prepare_completion_request_enables_prompt_logprobs_for_non_stream_echo() { + let request: CompletionRequest = serde_json::from_value(json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "echo": true, + "stream": false, + "logprobs": 3 + })) + .expect("parse request"); + + let prepared = prepare_completion_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("prepare"); + + assert_eq!(prepared.text_request.sampling_params.logprobs, Some(3)); + assert_eq!( + prepared.text_request.sampling_params.prompt_logprobs, + Some(3) + ); + } + + #[test] + fn prepare_completion_request_rejects_token_id_prompt_echo() { + let request: CompletionRequest = serde_json::from_value(json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": [11, 22, 33], + "stream": true, + "echo": true + })) + .expect("parse request"); + + assert!( + prepare_completion_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .is_err() + ); + } + + #[test] + fn prepare_completion_request_accepts_logprobs_fields() { + let request: CompletionRequest = serde_json::from_value(json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": false, + "logprobs": 1, + "prompt_logprobs": 2 + })) + .expect("parse request"); + + let prepared = prepare_completion_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("prepare"); + assert_eq!(prepared.text_request.sampling_params.logprobs, Some(1)); + assert_eq!( + prepared.text_request.sampling_params.prompt_logprobs, + Some(2) + ); + } + + #[test] + fn prepare_completion_request_threads_data_parallel_rank() { + let request: CompletionRequest = serde_json::from_value(json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": false, + })) + .expect("parse request"); + + let mut headers = HeaderMap::new(); + headers.insert("X-data-parallel-rank", "3".parse().unwrap()); + let prepared = prepare_completion_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + request_context(&headers, None), + ) + .expect("prepare"); + assert_eq!(prepared.text_request.data_parallel_rank, Some(3)); + } + + #[test] + fn prepare_completion_request_leaves_data_parallel_rank_none_when_absent() { + let request: CompletionRequest = serde_json::from_value(json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": false, + })) + .expect("parse request"); + + let prepared = prepare_completion_request( + request, + &served(&["Qwen/Qwen1.5-0.5B-Chat"]), + ResolvedRequestContext::default(), + ) + .expect("prepare"); + assert_eq!(prepared.text_request.data_parallel_rank, None); + } +} diff --git a/rust/src/server/src/routes/openai/completions/types.rs b/rust/src/server/src/routes/openai/completions/types.rs new file mode 100644 index 00000000000..adc8a7ba7cb --- /dev/null +++ b/rust/src/server/src/routes/openai/completions/types.rs @@ -0,0 +1,253 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; +use validator::Validate; +use vllm_text::Prompt; + +use crate::routes::openai::utils::types::{ + LogProbs, Normalizable, StreamOptions, StringOrArray, Usage, default_true, validate_stop, +}; + +/// Serde default for `CompletionRequest::max_tokens`, matching the Python vLLM +/// / OpenAI default. +fn default_completion_max_tokens() -> Option { + Some(16) +} + +/// vLLM-compatible request type for the Completions API. +/// +/// Mirrors the Python vLLM `CompletionRequest` class. The local copy keeps the +/// request type route-owned so we can accept token-id prompts via +/// [`vllm_text::Prompt`] and add vLLM-only fields directly instead of layering +/// wrapper deserializers on top. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Deserialize, Serialize, Validate)] +pub struct CompletionRequest { + // -------- Standard OpenAI API Parameters -------- + /// ID of the model to use + pub model: String, + + /// The prompt(s) to generate completions for. + /// + /// We use [`vllm_text::Prompt`] here to support token-id input. + pub prompt: Prompt, + + /// Echo back the prompt in addition to the completion + #[serde(default)] + pub echo: bool, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based + /// on their existing frequency in the text so far + pub frequency_penalty: Option, + + /// Modify the likelihood of specified tokens appearing in the completion + pub logit_bias: Option>, + + /// Include the log probabilities on the logprobs most likely tokens + pub logprobs: Option, + + /// The maximum number of tokens to generate (defaults to 16 when absent, + /// matching the Python vLLM / OpenAI API convention) + #[serde(default = "default_completion_max_tokens")] + pub max_tokens: Option, + + /// How many completions to generate for each prompt + pub n: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based + /// on whether they appear in the text so far + pub presence_penalty: Option, + + /// If specified, our system will make a best effort to sample + /// deterministically + pub seed: Option, + + /// Up to 4 sequences where the API will stop generating further tokens + #[validate(custom(function = "validate_stop"))] + pub stop: Option, + + /// Whether to stream back partial progress + #[serde(default)] + pub stream: bool, + + /// The suffix that comes after a completion of inserted text + pub suffix: Option, + + /// What sampling temperature to use, between 0 and 2 + pub temperature: Option, + + /// An alternative to sampling with temperature (nucleus sampling) + pub top_p: Option, + + /// A unique identifier representing your end-user + pub user: Option, + + // -------- vLLM Sampling Parameters -------- + /// Options for streaming response + pub stream_options: Option, + + /// Use beam search instead of sampling + #[serde(default)] + pub use_beam_search: bool, + + /// Top-k sampling parameter + pub top_k: Option, + + /// Min-p nucleus sampling parameter + pub min_p: Option, + + /// Repetition penalty for reducing repetitive text + pub repetition_penalty: Option, + + /// Length penalty for beam search + pub length_penalty: Option, + + /// Specific token IDs to use as stop conditions + pub stop_token_ids: Option>, + + /// Include stop string in output + #[serde(default)] + pub include_stop_str_in_output: bool, + + /// Ignore end-of-sequence tokens during generation + #[serde(default)] + pub ignore_eos: bool, + + /// Minimum number of tokens to generate + pub min_tokens: Option, + + /// Skip special tokens during detokenization + #[serde(default = "default_true")] + pub skip_special_tokens: bool, + + /// Add spaces between special tokens during detokenization + #[serde(default = "default_true")] + pub spaces_between_special_tokens: bool, + + /// Truncate prompt tokens to this length + pub truncate_prompt_tokens: Option, + + /// Restrict output to these token IDs only + pub allowed_token_ids: Option>, + + /// Number of prompt logprobs to return + pub prompt_logprobs: Option, + + // -------- Extra vLLM Parameters -------- + /// Whether to add special tokens (e.g. BOS) to the prompt + #[serde(default = "default_true")] + pub add_special_tokens: bool, + + /// Format specification for structured output (JSON mode, JSON schema, + /// etc.) + pub response_format: Option, + + /// Additional kwargs for structured outputs + pub structured_outputs: Option, + + /// Request scheduling priority (lower means earlier; default 0) + pub priority: Option, + + /// External request ID used for response correlation. + pub request_id: Option, + + /// Tokens represented as strings of the form 'token_id:{token_id}' in + /// logprobs + pub return_tokens_as_token_ids: Option, + + /// Include token IDs alongside generated text + pub return_token_ids: Option, + + /// Salt for prefix cache isolation in multi-user environments + pub cache_salt: Option, + + /// KV transfer parameters for disaggregated serving + pub kv_transfer_params: Option>, + + /// Additional request parameters with string or numeric values for custom + /// extensions + pub vllm_xargs: Option>, + + /// Additional fields + #[serde(flatten)] + pub other: Map, +} + +impl Normalizable for CompletionRequest {} + +/// Mirrors the Python vLLM `CompletionResponse` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub(super) struct CompletionResponse { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Option, + pub system_fingerprint: Option, + pub kv_transfer_params: Option, +} + +/// Mirrors the Python vLLM `CompletionResponseChoice` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub(super) struct CompletionChoice { + pub index: u32, + pub text: String, + pub logprobs: Option, + pub finish_reason: Option, + pub stop_reason: Option, + pub prompt_logprobs: Option>>>, + pub token_ids: Option>, + pub prompt_token_ids: Option>, +} + +/// Mirrors the Python vLLM `CompletionStreamResponse` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub(super) struct CompletionStreamResponse { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +impl CompletionStreamResponse { + /// Create a stream response with the standard envelope fields pre-filled. + pub fn new(id: &str, model: &str, created: u64) -> Self { + Self { + id: id.to_string(), + object: "text_completion".to_string(), + created, + model: model.to_string(), + choices: Vec::new(), + usage: None, + } + } +} + +/// Mirrors the Python vLLM `CompletionResponseStreamChoice` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Default, Serialize)] +pub(super) struct CompletionStreamChoice { + pub index: u32, + pub text: String, + pub logprobs: Option, + pub finish_reason: Option, + pub stop_reason: Option, + pub token_ids: Option>, + pub prompt_token_ids: Option>, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(untagged)] +pub(super) enum CompletionSseChunk { + /// Ordinary OpenAI completions delta/final chunk. + Chunk(CompletionStreamResponse), + /// Final usage chunk emitted before `[DONE]` when `include_usage=true`. + Usage(CompletionStreamResponse), +} diff --git a/rust/src/server/src/routes/openai/completions/validate.rs b/rust/src/server/src/routes/openai/completions/validate.rs new file mode 100644 index 00000000000..a53609234b6 --- /dev/null +++ b/rust/src/server/src/routes/openai/completions/validate.rs @@ -0,0 +1,178 @@ +use vllm_text::Prompt; + +use super::types::CompletionRequest; +use crate::error::{ApiError, bail_invalid_request}; + +/// Enforce the minimal compatibility contract for the Rust OpenAI server. +pub(super) fn validate_request_compat( + request: &CompletionRequest, + served_model_names: &[String], +) -> Result<(), ApiError> { + // This path is intentionally scoped to the minimum surface needed by + // `vllm-bench` random workload compatibility, so unsupported legacy + // completions features fail early here. + if !served_model_names.iter().any(|n| n == &request.model) { + return Err(ApiError::model_not_found(request.model.clone())); + } + + if request.stream_options.is_some() && !request.stream { + bail_invalid_request!( + param = "stream_options", + "stream_options are only supported when stream=true." + ); + } + + if request.n.unwrap_or(1) > 1 { + bail_invalid_request!(param = "n", "Only n=1 is supported."); + } + + if request.max_tokens == Some(0) { + bail_invalid_request!(param = "max_tokens", "max_tokens must be greater than 0."); + } + + if request.echo && matches!(request.prompt, Prompt::TokenIds(_)) { + bail_invalid_request!( + param = "echo", + "echo is not supported with token-ID prompts." + ); + } + + if request.suffix.is_some() { + bail_invalid_request!(param = "suffix", "suffix is not supported."); + } + + if let Some(logprobs) = request.logprobs + && logprobs > i32::MAX as u32 + { + bail_invalid_request!( + param = "logprobs", + "`logprobs` must fit within a signed 32-bit integer." + ); + } + + if let Some(prompt_logprobs) = request.prompt_logprobs { + if request.stream && (prompt_logprobs > 0 || prompt_logprobs == -1) { + bail_invalid_request!( + param = "prompt_logprobs", + "`prompt_logprobs` are not available when `stream=true`." + ); + } + + if prompt_logprobs < 0 && prompt_logprobs != -1 { + bail_invalid_request!( + param = "prompt_logprobs", + "`prompt_logprobs` must be a non-negative value or -1." + ); + } + } + + if request.use_beam_search { + bail_invalid_request!( + param = "use_beam_search", + "use_beam_search is not supported." + ); + } + + // ---- Reject parameters that are accepted for deserialization but not yet + // implemented ---- + + if request.length_penalty.is_some() { + bail_invalid_request!(param = "length_penalty", "length_penalty is not supported."); + } + if !request.spaces_between_special_tokens { + bail_invalid_request!( + param = "spaces_between_special_tokens", + "spaces_between_special_tokens is not supported." + ); + } + if request.truncate_prompt_tokens.is_some() { + bail_invalid_request!( + param = "truncate_prompt_tokens", + "truncate_prompt_tokens is not supported." + ); + } + + if let Some(options) = &request.stream_options + && options.continuous_usage_stats.is_some() + { + bail_invalid_request!( + param = "stream_options", + "continuous_usage_stats is not supported." + ); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::validate_request_compat; + use crate::routes::openai::completions::types::CompletionRequest; + + fn base_request() -> CompletionRequest { + serde_json::from_value(json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": true, + })) + .expect("parse request") + } + + fn served_names(names: &[&str]) -> Vec { + names.iter().map(|s| s.to_string()).collect() + } + + #[test] + fn validate_request_compat_accepts_logprobs() { + let request = CompletionRequest { + logprobs: Some(1), + ..base_request() + }; + assert!( + validate_request_compat(&request, &served_names(&["Qwen/Qwen1.5-0.5B-Chat"])).is_ok() + ); + } + + #[test] + fn validate_request_compat_accepts_any_served_name() { + let request = base_request(); + assert!( + validate_request_compat( + &request, + &served_names(&["other-alias", "Qwen/Qwen1.5-0.5B-Chat"]) + ) + .is_ok() + ); + } + + #[test] + fn validate_request_compat_rejects_unknown_model() { + let request = base_request(); + assert!(validate_request_compat(&request, &served_names(&["other-model"])).is_err()); + } + + #[test] + fn validate_request_compat_rejects_streaming_prompt_logprobs() { + let request = CompletionRequest { + prompt_logprobs: Some(1), + ..base_request() + }; + assert!( + validate_request_compat(&request, &served_names(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err() + ); + } + + #[test] + fn validate_request_compat_accepts_non_stream_prompt_logprobs() { + let request = CompletionRequest { + stream: false, + prompt_logprobs: Some(-1), + ..base_request() + }; + assert!( + validate_request_compat(&request, &served_names(&["Qwen/Qwen1.5-0.5B-Chat"])).is_ok() + ); + } +} diff --git a/rust/src/server/src/routes/openai/mod.rs b/rust/src/server/src/routes/openai/mod.rs new file mode 100644 index 00000000000..0cc8887dfab --- /dev/null +++ b/rust/src/server/src/routes/openai/mod.rs @@ -0,0 +1,8 @@ +pub mod chat_completions; +mod completions; +mod models; +pub(crate) mod utils; + +pub use chat_completions::chat_completions; +pub use completions::completions; +pub use models::list_models; diff --git a/rust/src/server/src/routes/openai/models.rs b/rust/src/server/src/routes/openai/models.rs new file mode 100644 index 00000000000..42e3098fc9e --- /dev/null +++ b/rust/src/server/src/routes/openai/models.rs @@ -0,0 +1,24 @@ +use std::sync::Arc; + +use axum::Json; +use axum::extract::State; + +use crate::routes::openai::utils::types::{ListModelsResponse, ModelObject}; +use crate::state::AppState; + +/// Return all configured served model names in OpenAI `list models` format. +pub async fn list_models(State(state): State>) -> Json { + Json(ListModelsResponse { + object: "list".to_string(), + data: state + .served_model_names() + .iter() + .map(|name| ModelObject { + id: name.clone(), + object: "model".to_string(), + created: 0, + owned_by: "vllm-frontend-rs".to_string(), + }) + .collect(), + }) +} diff --git a/rust/src/server/src/routes/openai/utils/logprobs.rs b/rust/src/server/src/routes/openai/utils/logprobs.rs new file mode 100644 index 00000000000..aab02e46ec6 --- /dev/null +++ b/rust/src/server/src/routes/openai/utils/logprobs.rs @@ -0,0 +1,256 @@ +use std::collections::HashMap; + +use itertools::Itertools as _; +use vllm_text::{ + CollectedTextOutput, DecodedLogprobs, DecodedPositionLogprobs, DecodedPromptLogprobs, + DecodedTokenLogprob, +}; + +use super::types::{ChatLogProbs, ChatLogProbsContent, LogProbs, TopLogProb}; +use crate::error::{ApiError, server_error}; + +/// Convert decoded token-position logprobs into the OpenAI completions +/// `logprobs` shape. +pub fn decoded_logprobs_to_openai( + logprobs: &DecodedLogprobs, + initial_text_offset: u32, + return_tokens_as_token_ids: bool, +) -> Result { + let mut text_offset = Vec::with_capacity(logprobs.positions.len()); + let mut token_logprobs = Vec::with_capacity(logprobs.positions.len()); + let mut tokens = Vec::with_capacity(logprobs.positions.len()); + let mut top_logprobs = Vec::with_capacity(logprobs.positions.len()); + let mut current_offset = initial_text_offset; + + for position in &logprobs.positions { + let chosen = position.entries.first().ok_or_else(|| { + server_error!("decoded logprobs position unexpectedly had no token candidates") + })?; + + let token_str = format_token(chosen, return_tokens_as_token_ids); + text_offset.push(current_offset); + token_logprobs.push(Some(clamp_logprob(chosen.logprob))); + current_offset = current_offset.saturating_add(text_len(&token_str)); + tokens.push(token_str); + top_logprobs.push(Some(position_top_logprobs_map( + position, + return_tokens_as_token_ids, + ))); + } + + Ok(LogProbs { + tokens, + token_logprobs, + top_logprobs, + text_offset, + }) +} + +/// Convert decoded prompt logprobs into the OpenAI completions `logprobs` +/// shape. +/// +/// The first prompt token is included with `None` logprob metadata, matching +/// Python vLLM's echoed completions behavior. +pub fn decoded_prompt_logprobs_to_openai( + prompt_logprobs: &DecodedPromptLogprobs, + initial_text_offset: u32, + return_tokens_as_token_ids: bool, +) -> Result { + let mut text_offset = Vec::with_capacity(prompt_logprobs.scored_positions.len() + 1); + let mut token_logprobs = Vec::with_capacity(prompt_logprobs.scored_positions.len() + 1); + let mut tokens = Vec::with_capacity(prompt_logprobs.scored_positions.len() + 1); + let mut top_logprobs = Vec::with_capacity(prompt_logprobs.scored_positions.len() + 1); + let mut current_offset = initial_text_offset; + + let first_token_str = if return_tokens_as_token_ids { + format!("token_id:{}", prompt_logprobs.first_token_id) + } else { + prompt_logprobs.first_token.clone() + }; + text_offset.push(current_offset); + token_logprobs.push(None); + current_offset = current_offset.saturating_add(text_len(&first_token_str)); + tokens.push(first_token_str); + top_logprobs.push(None); + + for position in &prompt_logprobs.scored_positions { + let chosen = position.entries.first().ok_or_else(|| { + server_error!("decoded prompt logprobs position unexpectedly had no token candidates") + })?; + + let token_str = format_token(chosen, return_tokens_as_token_ids); + text_offset.push(current_offset); + token_logprobs.push(Some(clamp_logprob(chosen.logprob))); + current_offset = current_offset.saturating_add(text_len(&token_str)); + tokens.push(token_str); + top_logprobs.push(Some(position_top_logprobs_map( + position, + return_tokens_as_token_ids, + ))); + } + + Ok(LogProbs { + tokens, + token_logprobs, + top_logprobs, + text_offset, + }) +} + +/// Convert decoded prompt logprobs into the vLLM-style prompt-logprobs response +/// shape. +pub fn decoded_prompt_logprobs_to_maps( + prompt_logprobs: &DecodedPromptLogprobs, + return_tokens_as_token_ids: bool, +) -> Vec>> { + std::iter::once(None) + .chain(prompt_logprobs.scored_positions.iter().map(|position| { + Some(position_top_logprobs_map( + position, + return_tokens_as_token_ids, + )) + })) + .collect() +} + +/// Convert decoded token-position logprobs into the OpenAI chat `logprobs` +/// shape. +pub fn decoded_logprobs_to_openai_chat( + logprobs: &DecodedLogprobs, + return_tokens_as_token_ids: bool, +) -> Result { + let content = logprobs + .positions + .iter() + .map(|pos| position_to_chat_logprobs_content(pos, return_tokens_as_token_ids)) + .try_collect()?; + + Ok(ChatLogProbs { + content: Some(content), + }) +} + +/// Count visible text positions using OpenAI completions' character-offset +/// convention. +pub fn text_len(text: &str) -> u32 { + u32::try_from(text.chars().count()).unwrap_or(u32::MAX) +} + +/// Concatenate two OpenAI-style completion logprobs payloads in token order. +pub fn append_openai_logprobs(mut prefix: LogProbs, suffix: LogProbs) -> LogProbs { + prefix.tokens.extend(suffix.tokens); + prefix.token_logprobs.extend(suffix.token_logprobs); + prefix.top_logprobs.extend(suffix.top_logprobs); + prefix.text_offset.extend(suffix.text_offset); + prefix +} + +/// Build the non-stream completions `logprobs` payload from collected text +/// output. +/// +/// When `echoed_prompt` is true, the returned payload matches Python vLLM's +/// echoed completions behavior by concatenating prompt and completion logprobs +/// into one OpenAI `LogProbs` object. +pub fn collected_logprobs_to_openai( + collected: &CollectedTextOutput, + echoed_prompt: bool, + initial_completion_offset: u32, + return_tokens_as_token_ids: bool, +) -> Result { + if echoed_prompt { + let prompt_logprobs = collected.prompt_logprobs.as_ref().ok_or_else(|| { + server_error!( + "echoed completion logprobs require prompt logprobs but generation returned none" + ) + })?; + let prompt_logprobs = + decoded_prompt_logprobs_to_openai(prompt_logprobs, 0, return_tokens_as_token_ids)?; + let completion_start = prompt_logprobs + .text_offset + .last() + .zip(prompt_logprobs.tokens.last()) + .map(|(&offset, token)| offset.saturating_add(text_len(token))) + .unwrap_or(0); + return match collected.logprobs.as_ref() { + Some(completion_logprobs) => Ok(append_openai_logprobs( + prompt_logprobs, + decoded_logprobs_to_openai( + completion_logprobs, + completion_start, + return_tokens_as_token_ids, + )?, + )), + None => Ok(prompt_logprobs), + }; + } + + let completion_logprobs = collected.logprobs.as_ref().ok_or_else(|| { + server_error!("completion response requested logprobs but generation returned none") + })?; + decoded_logprobs_to_openai( + completion_logprobs, + initial_completion_offset, + return_tokens_as_token_ids, + ) +} + +/// Format a token entry as either its decoded string or `token_id:{id}`. +fn format_token(entry: &DecodedTokenLogprob, as_token_id: bool) -> String { + if as_token_id { + format!("token_id:{}", entry.token_id) + } else { + entry.token.clone() + } +} + +fn position_top_logprobs_map( + position: &DecodedPositionLogprobs, + return_tokens_as_token_ids: bool, +) -> HashMap { + position + .entries + .iter() + .map(|entry| { + ( + format_token(entry, return_tokens_as_token_ids), + clamp_logprob(entry.logprob), + ) + }) + .collect() +} + +fn position_to_chat_logprobs_content( + position: &DecodedPositionLogprobs, + return_tokens_as_token_ids: bool, +) -> Result { + let chosen = position.entries.first().ok_or_else(|| { + server_error!("decoded chat logprobs position unexpectedly had no token candidates") + })?; + + let token_str = format_token(chosen, return_tokens_as_token_ids); + Ok(ChatLogProbsContent { + token: token_str.clone(), + logprob: clamp_logprob(chosen.logprob), + bytes: Some(token_bytes(&token_str)), + top_logprobs: position + .entries + .iter() + .map(|entry| { + let t = format_token(entry, return_tokens_as_token_ids); + TopLogProb { + logprob: clamp_logprob(entry.logprob), + bytes: Some(token_bytes(&t)), + token: t, + } + }) + .collect(), + }) +} + +fn token_bytes(token: &str) -> Vec { + token.as_bytes().to_vec() +} + +pub fn clamp_logprob(logprob: f32) -> f32 { + logprob.max(-9999.0) +} diff --git a/rust/src/server/src/routes/openai/utils/mod.rs b/rust/src/server/src/routes/openai/utils/mod.rs new file mode 100644 index 00000000000..57b1d99690d --- /dev/null +++ b/rust/src/server/src/routes/openai/utils/mod.rs @@ -0,0 +1,4 @@ +pub mod logprobs; +pub mod structured_outputs; +pub mod types; +pub mod validated_json; diff --git a/rust/src/server/src/routes/openai/utils/structured_outputs.rs b/rust/src/server/src/routes/openai/utils/structured_outputs.rs new file mode 100644 index 00000000000..e974c836bb5 --- /dev/null +++ b/rust/src/server/src/routes/openai/utils/structured_outputs.rs @@ -0,0 +1,132 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use vllm_engine_core_client::protocol::StructuredOutputsParams; + +use crate::error::ApiError; + +/// JSON schema specification nested inside a `json_schema` response format. +/// +/// Mirrors the Python vLLM `JsonSchemaResponseFormat` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct JsonSchemaFormat { + pub name: String, + #[serde(default)] + pub description: Option, + /// The actual JSON schema object. + #[serde(alias = "json_schema")] + pub schema: Value, + #[serde(default)] + pub strict: Option, +} + +/// Supported `response_format` types for chat and completion requests. +/// +/// This is our own definition (rather than the `openai-protocol` crate's) so +/// that we can support the vLLM-specific `structural_tag` variant. +/// +/// Original Python definitions: +/// +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ResponseFormat { + Text, + JsonObject, + JsonSchema { + json_schema: JsonSchemaFormat, + }, + /// vLLM-specific structural tag format. The entire object (including the + /// `type` field) is JSON-serialized and passed as + /// `StructuredOutputsParams.structural_tag`. + /// + /// We capture the payload as a catch-all map so both the legacy + /// (`structures`/`triggers`) and current (`format`) shapes are + /// preserved without needing typed structs. + StructuralTag { + #[serde(flatten)] + extra: serde_json::Map, + }, +} + +/// Convert an explicit `structured_outputs` JSON blob into +/// [`StructuredOutputsParams`]. +fn deserialize_structured_outputs( + raw: &serde_json::Value, +) -> Result { + serde_json::from_value(raw.clone()).map_err(|e| { + ApiError::invalid_request( + format!("invalid structured_outputs: {e}"), + Some("structured_outputs"), + ) + }) +} + +/// Convert a typed [`ResponseFormat`] and/or raw `structured_outputs` blob into +/// engine-core [`StructuredOutputsParams`]. +/// +/// Mirrors the Python vLLM conversion in +/// `ChatCompletionRequest.to_sampling_params()`: +pub fn convert_from_response_format( + response_format: Option<&ResponseFormat>, + structured_outputs: &Option, +) -> Result, ApiError> { + if let Some(raw) = structured_outputs { + return Ok(Some(deserialize_structured_outputs(raw)?)); + } + + let Some(fmt) = response_format else { + return Ok(None); + }; + match fmt { + ResponseFormat::Text => Ok(None), + ResponseFormat::JsonObject => Ok(Some(StructuredOutputsParams { + json_object: Some(true), + ..Default::default() + })), + ResponseFormat::JsonSchema { json_schema } => Ok(Some(StructuredOutputsParams { + json: Some(json_schema.schema.clone()), + ..Default::default() + })), + ResponseFormat::StructuralTag { .. } => { + // The Python frontend dumps the entire response_format object (including the + // `type` field) as a JSON string for the engine-core backend. + let tag_json = serde_json::to_string(fmt).map_err(|e| { + ApiError::invalid_request( + format!("failed to serialize structural_tag: {e}"), + Some("response_format"), + ) + })?; + Ok(Some(StructuredOutputsParams { + structural_tag: Some(tag_json), + ..Default::default() + })) + } + } +} + +/// Convert raw `response_format` and/or `structured_outputs` JSON blobs into +/// engine-core [`StructuredOutputsParams`]. +/// +/// Used by the completions endpoint which keeps both fields as opaque +/// `serde_json::Value`. +pub fn convert_from_response_format_value( + response_format: &Option, + structured_outputs: &Option, +) -> Result, ApiError> { + if let Some(raw) = structured_outputs { + return Ok(Some(deserialize_structured_outputs(raw)?)); + } + + let Some(raw) = response_format else { + return Ok(None); + }; + + // Deserialize into our typed enum and delegate. + let fmt: ResponseFormat = serde_json::from_value(raw.clone()).map_err(|e| { + ApiError::invalid_request( + format!("invalid response_format: {e}"), + Some("response_format"), + ) + })?; + convert_from_response_format(Some(&fmt), &None) +} diff --git a/rust/src/server/src/routes/openai/utils/types.rs b/rust/src/server/src/routes/openai/utils/types.rs new file mode 100644 index 00000000000..ff747a5daf5 --- /dev/null +++ b/rust/src/server/src/routes/openai/utils/types.rs @@ -0,0 +1,425 @@ +use std::collections::HashMap; +use std::slice; + +use llm_multimodal::ImageDetail; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +// ============================================================================ +// Constants +// ============================================================================ + +/// Default model identifier used when no model is specified. +pub const UNKNOWN_MODEL_ID: &str = "unknown"; + +// ============================================================================ +// Default value helpers +// ============================================================================ + +/// Helper function for serde default value (returns true). +pub fn default_true() -> bool { + true +} + +// ============================================================================ +// String/Array Utilities +// ============================================================================ + +/// A type that can be either a single string or an array of strings. +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] +pub enum StringOrArray { + String(String), + Array(Vec), +} + +impl StringOrArray { + pub fn as_slice(&self) -> &[String] { + match self { + StringOrArray::String(s) => slice::from_ref(s), + StringOrArray::Array(arr) => arr, + } + } + + #[allow(unused)] + pub fn into_vec(self) -> Vec { + match self { + StringOrArray::String(s) => vec![s], + StringOrArray::Array(arr) => arr, + } + } +} + +/// Validates stop sequences (non-empty strings) +pub fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> { + if stop.as_slice().iter().any(|s| s.is_empty()) { + return Err(validator::ValidationError::new( + "stop strings cannot be empty", + )); + } + Ok(()) +} + +// ============================================================================ +// Validation helpers +// ============================================================================ + +/// Validates top_p: 0.0 < top_p <= 1.0. +pub fn validate_top_p_value(top_p: f32) -> Result<(), validator::ValidationError> { + if !(top_p > 0.0 && top_p <= 1.0) { + return Err(validator::ValidationError::new( + "top_p must be in (0, 1] - greater than 0.0 and at most 1.0", + )); + } + Ok(()) +} + +// ============================================================================ +// Content Parts (for multimodal messages) +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[serde(tag = "type")] +pub enum ContentPart { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image_url")] + ImageUrl { + image_url: ImageUrl, + #[serde(skip_serializing_if = "Option::is_none")] + uuid: Option, + }, + #[serde(rename = "video_url")] + VideoUrl { video_url: VideoUrl }, +} + +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct ImageUrl { + pub url: String, + pub detail: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub struct VideoUrl { + pub url: String, +} + +// ============================================================================ +// Streaming +// ============================================================================ + +/// Mirrors the Python vLLM `StreamOptions` class. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct StreamOptions { + pub include_usage: Option, + pub continuous_usage_stats: Option, +} + +// ============================================================================ +// Tools and Function Calling +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Tool { + #[serde(rename = "type")] + pub tool_type: String, + pub function: Function, +} + +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Function { + pub name: String, + pub description: Option, + pub parameters: Value, + /// Whether to enable strict schema adherence (OpenAI structured outputs). + pub strict: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolCall { + pub id: String, + #[serde(rename = "type")] + pub tool_type: String, + pub function: FunctionCallResponse, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionCallResponse { + pub name: String, + #[serde(default)] + pub arguments: Option, +} + +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolCallDelta { + pub index: u32, + pub id: Option, + #[serde(rename = "type")] + pub tool_type: Option, + pub function: Option, +} + +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionCallDelta { + pub name: Option, + pub arguments: Option, +} + +/// Tool choice value for simple string options. +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ToolChoiceValue { + Auto, + Required, + None, +} + +/// Tool choice for the Chat Completion API. +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ToolChoice { + Value(ToolChoiceValue), + Function { + #[serde(rename = "type")] + tool_type: String, + function: FunctionChoice, + }, + AllowedTools { + #[serde(rename = "type")] + tool_type: String, + mode: String, + tools: Vec, + }, +} + +impl Default for ToolChoice { + fn default() -> Self { + Self::Value(ToolChoiceValue::Auto) + } +} + +/// Function choice specification for `ToolChoice::Function`. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionChoice { + pub name: String, +} + +/// Tool reference for `ToolChoice::AllowedTools`. +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +#[serde(rename_all = "snake_case")] +pub enum ToolReference { + #[serde(rename = "function")] + Function { name: String }, + #[serde(rename = "mcp")] + Mcp { + server_label: String, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + #[serde(rename = "file_search")] + FileSearch, + #[serde(rename = "web_search_preview")] + WebSearchPreview, + #[serde(rename = "computer_use_preview")] + ComputerUsePreview, + #[serde(rename = "code_interpreter")] + CodeInterpreter, + #[serde(rename = "image_generation")] + ImageGeneration, +} + +impl ToolReference { + /// Get a unique identifier for this tool reference. + pub fn identifier(&self) -> String { + match self { + ToolReference::Function { name } => format!("function:{name}"), + ToolReference::Mcp { + server_label, + name: Some(n), + } => format!("mcp:{server_label}:{n}"), + ToolReference::Mcp { + server_label, + name: _, + } => format!("mcp:{server_label}"), + ToolReference::FileSearch => "file_search".to_string(), + ToolReference::WebSearchPreview => "web_search_preview".to_string(), + ToolReference::ComputerUsePreview => "computer_use_preview".to_string(), + ToolReference::CodeInterpreter => "code_interpreter".to_string(), + ToolReference::ImageGeneration => "image_generation".to_string(), + } + } +} + +// ============================================================================ +// Chat Messages +// ============================================================================ + +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "role")] +pub enum ChatMessage { + #[serde(rename = "system")] + System { + content: MessageContent, + name: Option, + }, + #[serde(rename = "user")] + User { + content: MessageContent, + name: Option, + }, + #[serde(rename = "assistant")] + Assistant { + content: Option, + name: Option, + tool_calls: Option>, + /// Reasoning content for reasoning-capable models. + #[serde(alias = "reasoning_content")] + #[serde(alias = "thinking")] + reasoning: Option, + }, + #[serde(rename = "tool")] + Tool { + content: MessageContent, + tool_call_id: String, + }, + #[serde(rename = "function")] + Function { content: String, name: String }, + #[serde(rename = "developer")] + Developer { + content: MessageContent, + tools: Option>, + name: Option, + }, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[serde(untagged)] +pub enum MessageContent { + Text(String), + Parts(Vec), +} + +// ============================================================================ +// Usage and Logging +// ============================================================================ + +/// Mirrors the Python vLLM `UsageInfo` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub total_tokens: u32, + pub completion_tokens: Option, + pub prompt_tokens_details: Option, +} + +impl Usage { + /// Create a Usage from prompt and completion token counts. + pub fn from_counts(prompt_tokens: u32, completion_tokens: u32) -> Self { + Self { + prompt_tokens, + total_tokens: prompt_tokens + completion_tokens, + completion_tokens: Some(completion_tokens), + prompt_tokens_details: None, + } + } +} + +/// Mirrors the Python vLLM `PromptTokenUsageInfo` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub struct PromptTokenUsageInfo { + pub cached_tokens: Option, +} + +/// OpenAI completions-style logprobs. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct LogProbs { + pub tokens: Vec, + pub token_logprobs: Vec>, + pub top_logprobs: Vec>>, + pub text_offset: Vec, +} + +/// Mirrors the Python vLLM `ChatCompletionLogProbs` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub struct ChatLogProbs { + pub content: Option>, +} + +/// Mirrors the Python vLLM `ChatCompletionLogProbsContent` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub struct ChatLogProbsContent { + pub token: String, + pub logprob: f32, + pub bytes: Option>, + pub top_logprobs: Vec, +} + +/// Mirrors the Python vLLM `ChatCompletionLogProb` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub struct TopLogProb { + pub token: String, + pub logprob: f32, + pub bytes: Option>, +} + +// ============================================================================ +// Error Types +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ErrorResponse { + pub error: ErrorDetail, +} + +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ErrorDetail { + pub message: String, + #[serde(rename = "type")] + pub error_type: String, + pub param: Option, + pub code: Option, +} + +// ============================================================================ +// Model types +// ============================================================================ + +/// A single model entry in the `/v1/models` response. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelObject { + pub id: String, + pub object: String, + pub created: i64, + pub owned_by: String, +} + +/// Response body for `GET /v1/models`. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ListModelsResponse { + pub object: String, + pub data: Vec, +} + +// ============================================================================ +// Normalizable trait +// ============================================================================ + +/// Trait for request types that need post-deserialization normalization. +pub trait Normalizable { + /// Normalize the request by applying defaults and transformations. + fn normalize(&mut self) { + // Default: no-op + } +} diff --git a/rust/src/server/src/routes/openai/utils/validated_json.rs b/rust/src/server/src/routes/openai/utils/validated_json.rs new file mode 100644 index 00000000000..d07158cc2e0 --- /dev/null +++ b/rust/src/server/src/routes/openai/utils/validated_json.rs @@ -0,0 +1,55 @@ +//! Validated JSON extractor for automatic request validation. +//! Variation of https://github.com/lightseekorg/smg/blob/main/crates/protocols/src/validated.rs + +use axum::Json; +use axum::extract::rejection::JsonRejection; +use axum::extract::{FromRequest, Request}; +use serde::de::DeserializeOwned; +use validator::Validate; + +use super::types::Normalizable; +use crate::error::{ApiError, invalid_request}; + +/// A JSON extractor that automatically validates and normalizes the request +/// body. +/// +/// This extractor deserializes the request body and automatically calls +/// `.validate()` on types that implement the `Validate` trait. If validation +/// fails, it returns [`ApiError::InvalidRequest`] with details about the +/// validation errors. +pub struct ValidatedJson(pub T); + +impl FromRequest for ValidatedJson +where + T: DeserializeOwned + Validate + Normalizable + Send, + S: Send + Sync, +{ + type Rejection = ApiError; + + async fn from_request(req: Request, state: &S) -> Result { + let Json(mut data) = Json::::from_request(req, state) + .await + .map_err(|err: JsonRejection| ApiError::json_parse_error(err.body_text()))?; + + data.normalize(); + + data.validate() + .map_err(|validation_errors| invalid_request!("{}", validation_errors))?; + + Ok(ValidatedJson(data)) + } +} + +impl std::ops::Deref for ValidatedJson { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for ValidatedJson { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/rust/src/server/src/routes/sleep.rs b/rust/src/server/src/routes/sleep.rs new file mode 100644 index 00000000000..d7b279699b3 --- /dev/null +++ b/rust/src/server/src/routes/sleep.rs @@ -0,0 +1,78 @@ +use std::sync::Arc; + +use axum::Json; +use axum::extract::{Query, State}; +use axum::http::StatusCode; +use serde::{Deserialize, Serialize}; + +use crate::error::ApiError; +use crate::state::AppState; +use crate::utils::utility_call_error; + +#[derive(Serialize)] +pub(crate) struct IsSleepingResponse { + is_sleeping: bool, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct SleepParams { + #[serde(default = "default_sleep_level")] + level: u32, + #[serde(default = "default_sleep_mode")] + mode: String, +} + +#[derive(Debug, Default, Deserialize)] +pub(crate) struct WakeUpParams { + #[serde(default)] + tags: Option>, +} + +const fn default_sleep_level() -> u32 { + 1 +} + +fn default_sleep_mode() -> String { + "abort".to_string() +} + +/// Put the engine to sleep. +pub async fn sleep( + State(state): State>, + Query(params): Query, +) -> Result { + state + .engine_core_client() + .sleep(params.level, ¶ms.mode) + .await + .map_err(|error| utility_call_error("sleep", error))?; + + Ok(StatusCode::OK) +} + +/// Wake the engine from sleep mode. +pub async fn wake_up( + State(state): State>, + Query(params): Query, +) -> Result { + state + .engine_core_client() + .wake_up(params.tags) + .await + .map_err(|error| utility_call_error("wake_up", error))?; + + Ok(StatusCode::OK) +} + +/// Return whether the engine is currently sleeping at any level. +pub async fn is_sleeping( + State(state): State>, +) -> Result, ApiError> { + let is_sleeping = state + .engine_core_client() + .is_sleeping() + .await + .map_err(|error| utility_call_error("is_sleeping", error))?; + + Ok(Json(IsSleepingResponse { is_sleeping })) +} diff --git a/rust/src/server/src/routes/tests.rs b/rust/src/server/src/routes/tests.rs new file mode 100644 index 00000000000..dbf904cd1f7 --- /dev/null +++ b/rust/src/server/src/routes/tests.rs @@ -0,0 +1,3679 @@ +// Route tests should use `Service::call` rather than `ServiceExt::oneshot`. +// `oneshot` consumes the router and can drop `AppState` before a streaming +// response body is fully drained, which closes the mock engine connection too +// early and causes flaky `closed unexpectedly` failures. + +use std::collections::BTreeSet; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; +use std::{fmt, fs}; + +use axum::body::{Body, to_bytes}; +use axum::http::{Request, StatusCode}; +use bytes::Bytes; +use futures::StreamExt as _; +use rmpv::Value; +use serde_json::json; +use serial_test::serial; +use tower::{Service as _, ServiceExt as _}; +use vllm_chat::{ + ChatBackend, ChatContent, ChatContentPart, ChatEvent, ChatLlm, ChatMessage, ChatRenderer, + ChatRequest, ChatRole, ChatTextBackend, DefaultChatOutputProcessor, DynChatOutputProcessor, + DynChatRenderer, NewChatOutputProcessorOptions, SamplingParams, +}; +use vllm_engine_core_client::protocol::logprobs::{ + Logprobs, MaybeWireLogprobs, PositionLogprobs, TokenLogprob, +}; +use vllm_engine_core_client::protocol::{ + EngineCoreFinishReason, EngineCoreOutput, EngineCoreOutputs, EngineCoreRequest, StopReason, + UtilityOutput, UtilityResultEnvelope, decode_value, +}; +use vllm_engine_core_client::test_utils::{IpcNamespace, spawn_mock_engine_task}; +use vllm_engine_core_client::{ + ENGINE_CORE_DEAD_SENTINEL, EngineCoreClient, EngineCoreClientConfig, EngineId, +}; +use vllm_llm::Llm; +use vllm_metrics::METRICS; +use vllm_text::tokenizer::{DynTokenizer, Tokenizer}; +use vllm_text::{Prompt, TextBackend}; +use zeromq::prelude::{SocketRecv, SocketSend}; +use zeromq::{DealerSocket, PushSocket, ZmqMessage}; + +use super::{build_router, build_router_with_dev_mode}; +use crate::routes::openai::chat_completions::convert::prepare_chat_request; +use crate::state::AppState; + +fn request_output( + request_id: &str, + new_token_ids: Vec, + finish_reason: Option, +) -> EngineCoreOutput { + request_output_with_stop_reason(request_id, new_token_ids, finish_reason, None) +} + +fn request_output_with_stop_reason( + request_id: &str, + new_token_ids: Vec, + finish_reason: Option, + stop_reason: Option, +) -> EngineCoreOutput { + EngineCoreOutput { + request_id: request_id.to_string(), + new_token_ids, + new_logprobs: None, + new_prompt_logprobs_tensors: None, + pooling_output: None, + finish_reason, + stop_reason, + events: None, + kv_transfer_params: None, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + } +} + +fn request_output_with_logprobs( + request_id: &str, + new_token_ids: Vec, + finish_reason: Option, + stop_reason: Option, + new_logprobs: Option, + new_prompt_logprobs_tensors: Option, +) -> EngineCoreOutput { + EngineCoreOutput { + request_id: request_id.to_string(), + new_token_ids, + new_logprobs: new_logprobs.map(MaybeWireLogprobs::Direct), + new_prompt_logprobs_tensors: new_prompt_logprobs_tensors.map(MaybeWireLogprobs::Direct), + pooling_output: None, + finish_reason, + stop_reason, + events: None, + kv_transfer_params: None, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + } +} + +fn request_output_with_logprobs_and_kv( + request_id: &str, + new_token_ids: Vec, + finish_reason: Option, + stop_reason: Option, + new_logprobs: Option, + new_prompt_logprobs_tensors: Option, + kv_transfer_params: Option, +) -> EngineCoreOutput { + EngineCoreOutput { + request_id: request_id.to_string(), + new_token_ids, + new_logprobs: new_logprobs.map(MaybeWireLogprobs::Direct), + new_prompt_logprobs_tensors: new_prompt_logprobs_tensors.map(MaybeWireLogprobs::Direct), + pooling_output: None, + finish_reason, + stop_reason, + events: None, + kv_transfer_params, + trace_headers: None, + prefill_stats: None, + routed_experts: None, + num_nans_in_logits: 0, + } +} + +fn bytes_to_token_ids(bytes: &[u8]) -> Vec { + bytes.iter().map(|byte| u32::from(*byte)).collect() +} + +fn default_stream_output_specs() -> Vec<(Vec, Option)> { + vec![ + (vec![b'h' as u32], None), + (vec![b'i' as u32], None), + (vec![b'!' as u32], Some(EngineCoreFinishReason::Stop)), + ] +} + +fn sse_data_payloads(text: &str) -> Vec<&str> { + text.lines().filter_map(|line| line.strip_prefix("data: ")).collect() +} + +type TestFuture<'a> = Pin + Send + 'a>>; + +fn boxed_test_future<'a>(future: impl Future + Send + 'a) -> TestFuture<'a> { + Box::pin(future) +} + +struct MockEngineTask { + shutdown_tx: Option>, + join_handle: Option>, +} + +impl MockEngineTask { + fn new( + (shutdown_tx, join_handle): ( + tokio::sync::oneshot::Sender<()>, + tokio::task::JoinHandle<()>, + ), + ) -> Self { + Self { + shutdown_tx: Some(shutdown_tx), + join_handle: Some(join_handle), + } + } + + async fn finish(self) { + self.await.expect("mock engine task"); + } + + fn abort(&self) { + if let Some(join_handle) = &self.join_handle { + join_handle.abort(); + } + } + + async fn abort_and_join(mut self) { + if let Some(join_handle) = self.join_handle.take() { + join_handle.abort(); + let _ = join_handle.await; + } + } +} + +impl Future for MockEngineTask { + type Output = Result<(), tokio::task::JoinError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Some(shutdown_tx) = self.shutdown_tx.take() { + let _ = shutdown_tx.send(()); + } + match self.join_handle.as_mut() { + Some(join_handle) => Pin::new(join_handle).poll(cx), + None => Poll::Ready(Ok(())), + } + } +} + +impl Drop for MockEngineTask { + fn drop(&mut self) { + if let Some(join_handle) = &self.join_handle { + join_handle.abort(); + } + } +} + +fn engine_outputs_for_request( + request_id: &str, + output_specs: Vec<(Vec, Option)>, +) -> EngineCoreOutputs { + EngineCoreOutputs { + engine_index: 0, + outputs: output_specs + .into_iter() + .map(|(token_ids, finish_reason)| request_output(request_id, token_ids, finish_reason)) + .collect(), + scheduler_stats: None, + timestamp: 0.0, + utility_output: None, + finished_requests: None, + wave_complete: None, + start_wave: None, + } +} + +fn test_llm(client: EngineCoreClient) -> Llm { + Llm::new(client).with_request_id_randomization(false) +} + +fn sample_logprobs_for_token(token_id: u32, alternate_token_id: u32) -> Logprobs { + Logprobs { + positions: vec![PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id, + logprob: -0.1, + rank: 1, + }, + TokenLogprob { + token_id: alternate_token_id, + logprob: -0.2, + rank: 1, + }, + ], + }], + } +} + +fn sample_logprobs_for_tokens(token_ids: &[u32]) -> Logprobs { + Logprobs { + positions: token_ids + .iter() + .map(|&token_id| PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id, + logprob: -0.1, + rank: 1, + }, + TokenLogprob { + token_id: token_id.saturating_add(1), + logprob: -0.2, + rank: 2, + }, + ], + }) + .collect(), + } +} + +fn prompt_logprobs_for_hello() -> Logprobs { + Logprobs { + positions: vec![ + PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id: b'e' as u32, + logprob: -0.3, + rank: 1, + }, + TokenLogprob { + token_id: b'a' as u32, + logprob: -0.5, + rank: 1, + }, + ], + }, + PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id: b'l' as u32, + logprob: -0.4, + rank: 1, + }, + TokenLogprob { + token_id: b'r' as u32, + logprob: -0.6, + rank: 1, + }, + ], + }, + PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id: b'l' as u32, + logprob: -0.45, + rank: 1, + }, + TokenLogprob { + token_id: b'i' as u32, + logprob: -0.65, + rank: 1, + }, + ], + }, + PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id: b'o' as u32, + logprob: -0.5, + rank: 1, + }, + TokenLogprob { + token_id: b'u' as u32, + logprob: -0.7, + rank: 1, + }, + ], + }, + ], + } +} + +fn prompt_logprobs_for_tokens(token_ids: &[u32]) -> Logprobs { + Logprobs { + positions: token_ids + .iter() + .skip(1) + .map(|&token_id| PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id, + logprob: -0.3, + rank: 1, + }, + TokenLogprob { + token_id: token_id.saturating_add(1), + logprob: -0.5, + rank: 2, + }, + ], + }) + .collect(), + } +} + +fn utility_result_value(value: T) -> UtilityResultEnvelope +where + T: serde::Serialize, +{ + UtilityResultEnvelope::without_type_info(rmpv::ext::to_value(value).expect("encode result")) +} + +fn utility_none_result() -> UtilityResultEnvelope { + UtilityResultEnvelope::without_type_info(Value::Nil) +} + +fn utility_outputs(call_id: i64, result: UtilityResultEnvelope) -> EngineCoreOutputs { + EngineCoreOutputs { + utility_output: Some(UtilityOutput { + call_id, + failure_message: None, + result: Some(result), + }), + ..Default::default() + } +} + +async fn send_outputs(push: &mut PushSocket, outputs: EngineCoreOutputs) { + push.send(ZmqMessage::from( + rmp_serde::to_vec_named(&outputs).expect("encode outputs"), + )) + .await + .expect("send outputs"); +} + +async fn recv_engine_message(dealer: &mut DealerSocket) -> Vec { + dealer.recv().await.expect("recv engine message").into_vec() +} + +#[derive(Clone)] +struct FakeChatBackend { + model_id: String, + multimodal_model_info: Option, +} + +#[derive(Debug)] +struct FakeChatTokenizer; + +impl Tokenizer for FakeChatTokenizer { + fn encode( + &self, + text: &str, + _add_special_tokens: bool, + ) -> vllm_text::tokenizer::Result> { + let mut token_ids = Vec::new(); + let mut rest = text; + while !rest.is_empty() { + if let Some(stripped) = rest.strip_prefix("") { + token_ids.push(999); + rest = stripped; + continue; + } + + let ch = rest.chars().next().expect("rest is not empty"); + let mut buf = [0; 4]; + token_ids.extend(ch.encode_utf8(&mut buf).bytes().map(u32::from)); + rest = &rest[ch.len_utf8()..]; + } + Ok(token_ids) + } + + fn decode( + &self, + token_ids: &[u32], + _skip_special_tokens: bool, + ) -> vllm_text::tokenizer::Result { + Ok( + String::from_utf8_lossy(&token_ids.iter().map(|id| *id as u8).collect::>()) + .into_owned(), + ) + } + + fn token_to_id(&self, token: &str) -> Option { + match token { + "" => Some(999), + "<|image_pad|>" => Some(151655), + "" => Some(0xF001), + "" => Some(0xF002), + "<|START_THINKING|>" => Some(0xF003), + "<|END_THINKING|>" => Some(0xF004), + "◁think▷" => Some(0xF005), + "◁/think▷" => Some(0xF006), + _ => None, + } + } + + fn id_to_token(&self, id: u32) -> Option { + match id { + 999 => Some("".to_string()), + 151655 => Some("<|image_pad|>".to_string()), + 0xF001 => Some("".to_string()), + 0xF002 => Some("".to_string()), + 0xF003 => Some("<|START_THINKING|>".to_string()), + 0xF004 => Some("<|END_THINKING|>".to_string()), + 0xF005 => Some("◁think▷".to_string()), + 0xF006 => Some("◁/think▷".to_string()), + _ => None, + } + } +} + +impl FakeChatBackend { + fn new() -> Self { + Self { + model_id: "test-model".to_string(), + multimodal_model_info: None, + } + } + + fn with_model_id(model_id: impl Into) -> Self { + Self { + model_id: model_id.into(), + multimodal_model_info: None, + } + } + + fn with_multimodal_model_info( + multimodal_model_info: vllm_chat::multimodal::MultimodalModelInfo, + ) -> Self { + Self { + model_id: "test-model".to_string(), + multimodal_model_info: Some(multimodal_model_info), + } + } +} + +impl fmt::Debug for FakeChatBackend { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FakeChatBackend") + .field("model_id", &self.model_id) + .finish_non_exhaustive() + } +} + +impl TextBackend for FakeChatBackend { + fn tokenizer(&self) -> DynTokenizer { + Arc::new(FakeChatTokenizer) + } + + fn model_id(&self) -> &str { + &self.model_id + } +} + +impl ChatBackend for FakeChatBackend { + fn chat_renderer(&self) -> DynChatRenderer { + Arc::new(self.clone()) + } + + fn multimodal_model_info(&self) -> Option<&vllm_chat::multimodal::MultimodalModelInfo> { + self.multimodal_model_info.as_ref() + } + + fn new_chat_output_processor( + &self, + request: &mut ChatRequest, + options: NewChatOutputProcessorOptions<'_>, + ) -> vllm_chat::Result { + Ok(Box::new(DefaultChatOutputProcessor::new( + request, + &self.model_id, + self.tokenizer(), + options.tool_call_parser, + options.reasoning_parser, + )?)) + } +} + +impl ChatRenderer for FakeChatBackend { + fn render(&self, request: &ChatRequest) -> vllm_chat::Result { + let mut prompt = String::new(); + for message in &request.messages { + prompt.push_str(message.role().as_str()); + prompt.push_str(": "); + prompt.push_str(&render_fake_message_content(message)?); + prompt.push('\n'); + } + if request.chat_options.add_generation_prompt() { + prompt.push_str("assistant:"); + } + Ok(vllm_chat::RenderedPrompt { + prompt: Prompt::Text(prompt), + }) + } +} + +fn render_fake_message_content(message: &ChatMessage) -> vllm_chat::Result { + match message { + ChatMessage::System { content } + | ChatMessage::Developer { content, .. } + | ChatMessage::User { content } + | ChatMessage::ToolResponse { content, .. } => render_fake_content(content), + ChatMessage::Assistant { .. } => message.text_content(), + } +} + +fn render_fake_content(content: &ChatContent) -> vllm_chat::Result { + Ok(match content { + ChatContent::Text(text) => text.clone(), + ChatContent::Parts(parts) => { + let mut out = String::new(); + for part in parts { + match part { + ChatContentPart::Text { text } => out.push_str(text), + ChatContentPart::ImageUrl { .. } => out.push_str(""), + } + } + out + } + }) +} + +fn qwen_multimodal_model_info() -> vllm_chat::multimodal::MultimodalModelInfo { + let config_path = std::env::temp_dir().join(format!( + "vllm-server-qwen-config-{}.json", + uuid::Uuid::new_v4() + )); + fs::write( + &config_path, + r#"{"model_type":"qwen2_vl","vision_token_id":151655}"#, + ) + .expect("write qwen test config"); + let info = vllm_chat::multimodal::MultimodalModelInfo::from_paths( + "qwen2-vl-test".to_string(), + Some("qwen2_vl".to_string()), + Some(&config_path), + None, + Arc::new(FakeChatTokenizer), + ) + .expect("load multimodal info") + .expect("qwen multimodal info is registered"); + let _ = fs::remove_file(config_path); + info +} + +#[derive(Clone, Debug)] +struct FailingDecodeChatBackend; + +#[derive(Debug)] +struct FailingDecodeTokenizer; + +impl Tokenizer for FailingDecodeTokenizer { + fn encode( + &self, + text: &str, + add_special_tokens: bool, + ) -> vllm_text::tokenizer::Result> { + FakeChatTokenizer.encode(text, add_special_tokens) + } + + fn decode( + &self, + token_ids: &[u32], + skip_special_tokens: bool, + ) -> vllm_text::tokenizer::Result { + if token_ids.contains(&(b'i' as u32)) { + return Err(vllm_text::tokenizer::TokenizerError( + "forced decode failure for streaming test".to_string(), + )); + } + + FakeChatTokenizer.decode(token_ids, skip_special_tokens) + } + + fn token_to_id(&self, token: &str) -> Option { + FakeChatTokenizer.token_to_id(token) + } +} + +impl TextBackend for FailingDecodeChatBackend { + fn tokenizer(&self) -> DynTokenizer { + Arc::new(FailingDecodeTokenizer) + } + + fn model_id(&self) -> &str { + "test-model" + } +} + +impl ChatBackend for FailingDecodeChatBackend { + fn chat_renderer(&self) -> DynChatRenderer { + Arc::new(self.clone()) + } + + fn new_chat_output_processor( + &self, + _request: &mut ChatRequest, + _options: NewChatOutputProcessorOptions<'_>, + ) -> vllm_chat::Result { + Ok(Box::new(DefaultChatOutputProcessor::plain_text_only())) + } +} + +impl ChatRenderer for FailingDecodeChatBackend { + fn render(&self, request: &ChatRequest) -> vllm_chat::Result { + FakeChatBackend::new().render(request) + } +} + +async fn test_models_with_engine_outputs_and_backend_inner( + engine_id: impl Into, + output_specs: Vec<(Vec, Option)>, + expected_prompt_token_ids: Option>, + backend: Arc, +) -> (ChatLlm, MockEngineTask) { + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = engine_id.into(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + move |dealer, push| { + boxed_test_future(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = + rmp_serde::from_slice(&add[1]).expect("decode request"); + if let Some(expected_prompt_token_ids) = expected_prompt_token_ids { + assert_eq!( + request.prompt_token_ids.as_deref(), + Some(expected_prompt_token_ids.as_slice()) + ); + } + send_outputs( + push, + engine_outputs_for_request(&request.request_id, output_specs), + ) + .await; + }) + }, + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + + ( + ChatLlm::from_shared_backend(test_llm(client), backend), + engine_task, + ) +} + +async fn test_models_with_engine_outputs_and_backend( + engine_id: impl Into, + output_specs: Vec<(Vec, Option)>, + backend: Arc, +) -> (ChatLlm, MockEngineTask) { + test_models_with_engine_outputs_and_backend_inner(engine_id, output_specs, None, backend).await +} + +async fn test_chat_with_engine_outputs( + engine_id: impl Into, + output_specs: Vec<(Vec, Option)>, +) -> (ChatLlm, MockEngineTask) { + test_models_with_engine_outputs_and_backend( + engine_id, + output_specs, + Arc::new(FakeChatBackend::new()), + ) + .await +} + +async fn test_app() -> axum::Router { + let (chat, _engine_task) = test_models_with_engine_outputs_and_backend( + b"engine-openai", + default_stream_output_specs(), + Arc::new(FakeChatBackend::new()), + ) + .await; + build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))) +} + +async fn test_health_app_with_engine_script( + script: F, +) -> (axum::Router, Arc, MockEngineTask) +where + F: for<'a> FnOnce(&'a mut PushSocket) -> TestFuture<'a> + Send + 'static, +{ + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-openai-health".to_vec(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + move |_dealer, push| script(push), + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + + let chat = ChatLlm::from_shared_backend(test_llm(client), Arc::new(FakeChatBackend::new())); + let state = Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + )); + (build_router(state.clone()), state, engine_task) +} + +async fn test_admin_app_with_engine_script(script: F) -> (axum::Router, MockEngineTask) +where + F: for<'a> FnOnce(&'a mut DealerSocket, &'a mut PushSocket) -> TestFuture<'a> + Send + 'static, +{ + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-openai-admin".to_vec(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + move |dealer, push| script(dealer, push), + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + + let chat = ChatLlm::from_shared_backend(test_llm(client), Arc::new(FakeChatBackend::new())); + ( + build_router_with_dev_mode( + Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + )), + true, + ), + engine_task, + ) +} + +async fn test_app_with_engine_handle() -> (axum::Router, MockEngineTask) { + test_app_with_stream_output_specs(default_stream_output_specs()).await +} + +async fn test_app_with_stream_output_specs( + output_specs: Vec<(Vec, Option)>, +) -> (axum::Router, MockEngineTask) { + let (chat, engine_task) = test_models_with_engine_outputs_and_backend( + b"engine-openai", + output_specs, + Arc::new(FakeChatBackend::new()), + ) + .await; + ( + build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))), + engine_task, + ) +} + +async fn test_app_with_backend_and_stream_output_specs( + backend: Arc, + output_specs: Vec<(Vec, Option)>, +) -> (axum::Router, MockEngineTask) { + let (chat, engine_task) = + test_models_with_engine_outputs_and_backend(b"engine-openai", output_specs, backend).await; + ( + build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))), + engine_task, + ) +} + +async fn test_app_with_backend_and_engine_request_check( + backend: Arc, + check_request: F, +) -> (axum::Router, MockEngineTask) +where + F: FnOnce(&EngineCoreRequest) + Send + 'static, +{ + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-openai-check-request".to_vec(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + move |dealer, push| { + boxed_test_future(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = + rmp_serde::from_slice(&add[1]).expect("decode request"); + check_request(&request); + send_outputs( + push, + engine_outputs_for_request(&request.request_id, default_stream_output_specs()), + ) + .await; + }) + }, + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + + let chat = ChatLlm::from_shared_backend(test_llm(client), backend); + ( + build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))), + engine_task, + ) +} + +async fn test_chat_with_engine_handle() -> (ChatLlm, MockEngineTask) { + test_chat_with_engine_outputs(b"engine-openai-chat", default_stream_output_specs()).await +} + +async fn server_load(app: &axum::Router) -> u64 { + let response = app + .clone() + .call( + Request::builder() + .method("GET") + .uri("/load") + .body(Body::empty()) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + let value: serde_json::Value = serde_json::from_slice(&body).expect("json body"); + value["server_load"].as_u64().expect("server_load") +} + +async fn health_status(app: &axum::Router) -> (StatusCode, Bytes) { + let response = app + .clone() + .call( + Request::builder() + .method("GET") + .uri("/health") + .body(Body::empty()) + .expect("build request"), + ) + .await + .expect("call app"); + + let status = response.status(); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + (status, body) +} + +fn metric_value(rendered: &str, metric: &str, labels: Option<&str>) -> Option { + rendered.lines().find_map(|line| { + let rest = line.strip_prefix(metric)?; + + match labels { + Some(labels) => { + let (encoded_labels, value) = rest.split_once("} ")?; + if !encoded_labels.starts_with('{') { + return None; + } + let expected_parts = labels.split(','); + if expected_parts.into_iter().all(|part| encoded_labels.contains(part)) { + value.parse::().ok() + } else { + None + } + } + None => rest.strip_prefix(' ').and_then(|value| value.parse::().ok()), + } + }) +} + +fn metric_delta( + rendered_before: &str, + rendered_after: &str, + metric: &str, + labels: Option<&str>, +) -> f64 { + metric_value(rendered_after, metric, labels).unwrap_or(0.0) + - metric_value(rendered_before, metric, labels).unwrap_or(0.0) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn list_models_returns_configured_model() { + let mut app = test_app().await; + let response = app + .call(Request::builder().uri("/v1/models").body(Body::empty()).expect("build request")) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + assert_eq!(json["data"][0]["id"], "Qwen/Qwen1.5-0.5B-Chat"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn http_metrics_record_list_models_requests() { + let mut app = test_app().await; + let before = METRICS.render().unwrap(); + + let response = app + .call( + Request::builder() + .method("GET") + .uri("/v1/models") + .body(Body::empty()) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let after = METRICS.render().unwrap(); + assert_eq!( + metric_delta( + &before, + &after, + "http_requests_total", + Some("method=\"GET\",status=\"2xx\",handler=\"/v1/models\""), + ), + 1.0 + ); + assert_eq!( + metric_delta( + &before, + &after, + "http_request_duration_seconds_count", + Some("method=\"GET\",handler=\"/v1/models\""), + ), + 1.0 + ); + assert_eq!( + metric_delta( + &before, + &after, + "http_request_duration_highr_seconds_count", + None, + ), + 1.0 + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn wrong_model_returns_not_found() { + let mut app = test_app().await; + let response = app + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "wrong-model", + "stream": true, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn invalid_request_returns_openai_error() { + let mut app = test_app().await; + let response = app + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": false, + "stream_options": {"include_usage": true}, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + assert_eq!(json["error"]["type"], "invalid_request_error"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_stream_chat_returns_json_response() { + let (app, engine_task) = test_app_with_engine_handle().await; + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": false, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + assert!( + response + .headers() + .get("content-type") + .and_then(|value| value.to_str().ok()) + .is_some_and(|value| value.starts_with("application/json")) + ); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + + assert_eq!(json["object"], "chat.completion"); + assert_eq!(json["choices"][0]["message"]["role"], "assistant"); + assert_eq!(json["choices"][0]["message"]["content"], "hi"); + assert_eq!(json["choices"][0]["finish_reason"], "stop"); + assert_eq!(json["usage"]["prompt_tokens"], 22); + assert_eq!(json["usage"]["completion_tokens"], 3); + assert_eq!(json["usage"]["total_tokens"], 25); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_stream_chat_image_url_reaches_engine_mm_features() { + let (app, engine_task) = test_app_with_backend_and_engine_request_check( + Arc::new(FakeChatBackend::with_multimodal_model_info( + qwen_multimodal_model_info(), + )), + |request| { + let prompt_token_ids = request.prompt_token_ids.as_ref().expect("prompt token ids"); + assert!(prompt_token_ids.contains(&151655)); + + let features = request.mm_features.as_ref().expect("multimodal features"); + assert_eq!(features.len(), 1); + assert_eq!(features[0].modality, "image"); + assert_eq!(features[0].identifier, "image-1"); + assert!(features[0].mm_position.length > 0); + assert!(features[0].mm_position.is_embed.is_some()); + + let data = features[0].data.as_ref().expect("feature data"); + assert!(data.contains_key("pixel_values")); + assert!(data.contains_key("image_grid_thw")); + }, + ) + .await; + + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": false, + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "describe "}, + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" + }, + "uuid": "image-1" + } + ] + }] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + + assert_eq!(json["object"], "chat.completion"); + assert_eq!(json["choices"][0]["message"]["content"], "hi"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_stream_chat_includes_logprobs_and_prompt_logprobs() { + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-openai-chat-logprobs".to_vec(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + boxed_test_future(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = + rmp_serde::from_slice(&add[1]).expect("decode request"); + let prompt_token_ids = request.prompt_token_ids.clone().expect("prompt token ids"); + + send_outputs( + push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![request_output_with_logprobs( + &request.request_id, + bytes_to_token_ids(b"hi"), + Some(EngineCoreFinishReason::Stop), + None, + Some(sample_logprobs_for_tokens(&bytes_to_token_ids(b"hi"))), + Some(prompt_logprobs_for_tokens(&prompt_token_ids)), + )], + scheduler_stats: None, + timestamp: 0.0, + utility_output: None, + finished_requests: None, + wave_complete: None, + start_wave: None, + }, + ) + .await; + }) + }, + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + let chat = ChatLlm::from_shared_backend(test_llm(client), Arc::new(FakeChatBackend::new())); + let mut app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": false, + "logprobs": true, + "prompt_logprobs": 1, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + + assert_eq!( + json["choices"][0]["logprobs"]["content"][0]["token"], + json!("h") + ); + assert_eq!( + json["choices"][0]["logprobs"]["content"][1]["token"], + json!("i") + ); + assert_eq!(json["prompt_logprobs"][0], serde_json::Value::Null); + assert!(json["prompt_logprobs"][1].is_object()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn happy_path_returns_sse_stream() { + let (app, engine_task) = test_app_with_engine_handle().await; + let before = METRICS.render().unwrap(); + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": true, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get("content-type").and_then(|value| value.to_str().ok()), + Some("text/event-stream") + ); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + let after = METRICS.render().unwrap(); + + assert!(text.contains("\"role\":\"assistant\""), "{text}"); + assert!(text.starts_with("data: "), "{text}"); + assert_eq!( + metric_delta( + &before, + &after, + "http_requests_total", + Some("method=\"POST\",status=\"2xx\",handler=\"/v1/chat/completions\""), + ), + 1.0 + ); + assert_eq!( + metric_delta( + &before, + &after, + "http_request_duration_seconds_count", + Some("method=\"POST\",handler=\"/v1/chat/completions\""), + ), + 1.0 + ); + assert_eq!( + metric_delta( + &before, + &after, + "http_request_duration_highr_seconds_count", + None, + ), + 1.0 + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn http_metrics_exclude_metrics_route() { + let mut app = test_app().await; + let before = METRICS.render().unwrap(); + + let response = app + .call( + Request::builder() + .method("GET") + .uri("/metrics") + .body(Body::empty()) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let after = METRICS.render().unwrap(); + assert_eq!( + metric_delta( + &before, + &after, + "http_request_duration_highr_seconds_count", + None, + ), + 0.0 + ); + assert_eq!( + metric_value( + &after, + "http_requests_total", + Some("method=\"GET\",status=\"2xx\",handler=\"/metrics\""), + ), + None + ); + assert_eq!( + metric_value( + &after, + "http_request_duration_seconds_count", + Some("method=\"GET\",handler=\"/metrics\""), + ), + None + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn http_metrics_group_error_statuses() { + let mut app = test_app().await; + let before = METRICS.render().unwrap(); + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": false, + "stream_options": {"include_usage": true}, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let after = METRICS.render().unwrap(); + assert_eq!( + metric_delta( + &before, + &after, + "http_requests_total", + Some("method=\"POST\",status=\"4xx\",handler=\"/v1/chat/completions\""), + ), + 1.0 + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn load_endpoint_tracks_chat_stream_lifecycle() { + let (app, engine_task) = test_app_with_engine_handle().await; + + assert_eq!(server_load(&app).await, 0); + + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": true, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(server_load(&app).await, 1); + + let _body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + + assert_eq!(server_load(&app).await, 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn health_endpoint_returns_ok_with_empty_body_when_client_is_healthy() { + let (app, _state, engine_task) = + test_health_app_with_engine_script(|_push| boxed_test_future(async move {})).await; + + let (status, body) = health_status(&app).await; + assert_eq!(status, StatusCode::OK); + assert!(body.is_empty(), "expected empty body, got {:?}", body); + + engine_task.await.expect("mock engine task"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn health_endpoint_returns_503_after_engine_core_dead_sentinel() { + let (app, state, engine_task) = test_health_app_with_engine_script(|push| { + boxed_test_future(async move { + push.send(ZmqMessage::from(ENGINE_CORE_DEAD_SENTINEL.to_vec())) + .await + .expect("send sentinel"); + }) + }) + .await; + + tokio::time::timeout(Duration::from_secs(2), async { + while state.chat.engine_core_client().is_healthy() { + tokio::task::yield_now().await; + } + }) + .await + .expect("wait for unhealthy client"); + + let (status, body) = health_status(&app).await; + assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE); + assert!(body.is_empty(), "expected empty body, got {:?}", body); + assert!(matches!( + state.chat.engine_core_client().health_error().as_deref(), + Some(vllm_engine_core_client::Error::EngineCoreDead) + )); + + engine_task.await.expect("mock engine task"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn load_endpoint_resets_when_stream_response_is_dropped() { + let (app, engine_task) = test_app_with_engine_handle().await; + + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": true + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(server_load(&app).await, 1); + + drop(response); + tokio::task::yield_now().await; + engine_task.await.expect("mock engine task"); + + assert_eq!(server_load(&app).await, 0); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn stream_error_is_returned_as_openai_error_sse() { + let (app, engine_task) = test_app_with_backend_and_stream_output_specs( + Arc::new(FailingDecodeChatBackend), + default_stream_output_specs(), + ) + .await; + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": true, + "stream_options": {"include_usage": true}, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + + assert!(text.contains("\"role\":\"assistant\""), "{text}"); + assert!(text.contains("\"type\":\"server_error\""), "{text}"); + assert!( + text.contains("forced decode failure for streaming test"), + "{text}" + ); + assert!(!text.contains("\"usage\":"), "{text}"); + assert!(text.trim_end().ends_with("data: [DONE]"), "{text}"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn invalid_terminal_finish_reason_is_returned_as_openai_error_sse() { + let (app, engine_task) = + test_app_with_stream_output_specs(vec![(vec![], Some(EngineCoreFinishReason::Error))]) + .await; + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": true, + "stream_options": {"include_usage": true}, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + + assert!(text.contains("\"role\":\"assistant\""), "{text}"); + assert!(text.contains("\"type\":\"server_error\""), "{text}"); + assert!(text.contains("Internal server error"), "{text}"); + assert!(!text.contains("\"usage\":"), "{text}"); + assert!(text.trim_end().ends_with("data: [DONE]"), "{text}"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn include_usage_adds_final_usage_chunk_before_done() { + let (app, engine_task) = test_app_with_stream_output_specs(default_stream_output_specs()).await; + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": true, + "stream_options": {"include_usage": true}, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + + let payloads = sse_data_payloads(&text); + let finish_index = payloads + .iter() + .position(|payload| payload.contains("\"finish_reason\":\"stop\"")) + .expect("finish chunk"); + let usage_index = payloads + .iter() + .position(|payload| payload.contains("\"usage\":")) + .expect("usage chunk"); + let done_index = + payloads.iter().position(|payload| *payload == "[DONE]").expect("done sentinel"); + + assert!(finish_index < usage_index, "{text}"); + assert!(usage_index < done_index, "{text}"); + + let usage_chunk: serde_json::Value = + serde_json::from_str(payloads[usage_index]).expect("usage chunk json"); + assert_eq!(usage_chunk["choices"], json!([])); + assert_eq!(usage_chunk["usage"]["prompt_tokens"], 22); + assert_eq!(usage_chunk["usage"]["completion_tokens"], 3); + assert_eq!(usage_chunk["usage"]["total_tokens"], 25); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn stream_without_include_usage_keeps_existing_shape() { + let (app, engine_task) = test_app_with_stream_output_specs(default_stream_output_specs()).await; + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": true, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + + assert!(!text.contains("\"usage\":"), "{text}"); + assert!(text.contains("\"finish_reason\":\"stop\""), "{text}"); + assert!(text.trim_end().ends_with("data: [DONE]"), "{text}"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn completions_invalid_request_returns_openai_error() { + let mut app = test_app().await; + let response = app + .call( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": false, + "stream_options": {"include_usage": true} + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + assert_eq!(json["error"]["type"], "invalid_request_error"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_stream_completions_return_json_response() { + let (app, engine_task) = test_app_with_engine_handle().await; + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": false + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + assert!( + response + .headers() + .get("content-type") + .and_then(|value| value.to_str().ok()) + .is_some_and(|value| value.starts_with("application/json")) + ); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + + assert_eq!(json["object"], "text_completion"); + assert_eq!(json["choices"][0]["text"], "hi"); + assert_eq!(json["choices"][0]["finish_reason"], "stop"); + assert_eq!(json["usage"]["completion_tokens"], 3); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_stream_completions_echo_prepends_prompt_text() { + let (app, engine_task) = test_app_with_engine_handle().await; + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "echo": true, + "stream": false + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + + assert_eq!(json["choices"][0]["text"], "hellohi"); + assert_eq!(json["usage"]["prompt_tokens"], 5); + assert_eq!(json["usage"]["completion_tokens"], 3); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_stream_completions_include_logprobs() { + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-openai-completion-logprobs".to_vec(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + boxed_test_future(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = + rmp_serde::from_slice(&add[1]).expect("decode request"); + send_outputs( + push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![ + request_output_with_logprobs( + &request.request_id, + vec![b'h' as u32], + None, + None, + Some(sample_logprobs_for_token(b'h' as u32, b'H' as u32)), + None, + ), + request_output_with_logprobs( + &request.request_id, + vec![b'i' as u32], + Some(EngineCoreFinishReason::Stop), + None, + Some(sample_logprobs_for_token(b'i' as u32, b'I' as u32)), + None, + ), + ], + scheduler_stats: None, + timestamp: 0.0, + utility_output: None, + finished_requests: Some(BTreeSet::from([request.request_id.clone()])), + wave_complete: None, + start_wave: None, + }, + ) + .await; + }) + }, + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + let chat = ChatLlm::from_shared_backend(test_llm(client), Arc::new(FakeChatBackend::new())); + let mut app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": false, + "logprobs": 1 + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + + assert_eq!(json["choices"][0]["logprobs"]["tokens"], json!(["h", "i"])); + assert_eq!( + json["choices"][0]["logprobs"]["token_logprobs"], + json!([-0.1, -0.1]) + ); + assert_eq!(json["choices"][0]["logprobs"]["text_offset"], json!([0, 1])); + assert_eq!( + json["choices"][0]["logprobs"]["top_logprobs"][0], + json!({"h": -0.1, "H": -0.2}) + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_stream_completions_include_prompt_logprobs() { + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-openai-completion-prompt-logprobs".to_vec(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + boxed_test_future(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = + rmp_serde::from_slice(&add[1]).expect("decode request"); + send_outputs( + push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![request_output_with_logprobs( + &request.request_id, + vec![b'h' as u32, b'i' as u32, b'!' as u32], + Some(EngineCoreFinishReason::Stop), + None, + Some(Logprobs { + positions: vec![ + sample_logprobs_for_token(b'h' as u32, b'H' as u32).positions + [0] + .clone(), + sample_logprobs_for_token(b'i' as u32, b'I' as u32).positions + [0] + .clone(), + sample_logprobs_for_token(b'!' as u32, b'?' as u32).positions + [0] + .clone(), + ], + }), + Some(prompt_logprobs_for_hello()), + )], + scheduler_stats: None, + timestamp: 0.0, + utility_output: None, + finished_requests: None, + wave_complete: None, + start_wave: None, + }, + ) + .await; + }) + }, + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + let chat = ChatLlm::from_shared_backend(test_llm(client), Arc::new(FakeChatBackend::new())); + let mut app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": false, + "echo": true, + "logprobs": 1 + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + + assert_eq!(json["choices"][0]["text"], "hellohi"); + assert_eq!( + json["choices"][0]["logprobs"]["tokens"], + json!(["h", "e", "l", "l", "o", "h", "i", "!"]) + ); + assert_eq!( + json["choices"][0]["logprobs"]["text_offset"], + json!([0, 1, 2, 3, 4, 5, 6, 7]) + ); + assert_eq!( + json["choices"][0]["logprobs"]["token_logprobs"], + json!([null, -0.3, -0.4, -0.45, -0.5, -0.1, -0.1, -0.1]) + ); + assert_eq!( + json["choices"][0]["prompt_logprobs"][0], + serde_json::Value::Null + ); + assert_eq!( + json["choices"][0]["prompt_logprobs"][1], + json!({"a": -0.5, "e": -0.3}) + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_stream_chat_completions_still_succeed() { + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-openai-chat-non-stream".to_vec(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + boxed_test_future(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = + rmp_serde::from_slice(&add[1]).expect("decode request"); + send_outputs( + push, + engine_outputs_for_request(&request.request_id, default_stream_output_specs()), + ) + .await; + }) + }, + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + let chat = ChatLlm::from_shared_backend(test_llm(client), Arc::new(FakeChatBackend::new())); + let mut app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": false, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + engine_task.await.expect("mock engine task"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_stream_completions_still_succeed() { + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-openai-completion-non-stream".to_vec(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + boxed_test_future(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = + rmp_serde::from_slice(&add[1]).expect("decode request"); + send_outputs( + push, + engine_outputs_for_request(&request.request_id, default_stream_output_specs()), + ) + .await; + }) + }, + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + let chat = ChatLlm::from_shared_backend(test_llm(client), Arc::new(FakeChatBackend::new())); + let mut app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": false + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + engine_task.await.expect("mock engine task"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn chat_completions_header_request_id_takes_precedence() { + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-chat-request-id-precedence".to_vec(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + boxed_test_future(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = + rmp_serde::from_slice(&add[1]).expect("decode request"); + assert_eq!( + request.external_req_id.as_deref(), + Some("chatcmpl-header-req") + ); + assert!(request.request_id.starts_with("chatcmpl-header-req-")); + assert_ne!(request.request_id, "chatcmpl-header-req"); + + send_outputs( + push, + engine_outputs_for_request(&request.request_id, default_stream_output_specs()), + ) + .await; + }) + }, + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + let chat = ChatLlm::from_shared_backend(Llm::new(client), Arc::new(FakeChatBackend::new())); + let mut app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .header("X-Request-Id", "header-req") + .body(Body::from( + json!({ + "request_id": "body-req", + "model": "Qwen/Qwen1.5-0.5B-Chat", + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + + assert_eq!(json["id"], "chatcmpl-header-req"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_stream_raw_generate_returns_token_output_envelope() { + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-raw-generate-non-stream".to_vec(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + boxed_test_future(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = + rmp_serde::from_slice(&add[1]).expect("decode request"); + assert_eq!(request.prompt_token_ids.as_deref(), Some(&[11, 22][..])); + assert_eq!(request.external_req_id.as_deref(), Some("raw-req")); + assert!(request.request_id.starts_with("raw-req-")); + assert_ne!(request.request_id, "raw-req"); + + send_outputs( + push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![ + request_output_with_logprobs( + &request.request_id, + vec![33], + None, + None, + Some(sample_logprobs_for_token(33, 34)), + Some(prompt_logprobs_for_tokens(&[11, 22])), + ), + request_output_with_logprobs_and_kv( + &request.request_id, + vec![44], + Some(EngineCoreFinishReason::Stop), + None, + Some(sample_logprobs_for_token(44, 45)), + None, + Some(json!({"connector": "x"})), + ), + ], + scheduler_stats: None, + timestamp: 0.0, + utility_output: None, + finished_requests: None, + wave_complete: None, + start_wave: None, + }, + ) + .await; + }) + }, + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + let chat = ChatLlm::from_shared_backend(Llm::new(client), Arc::new(FakeChatBackend::new())); + let mut app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/inference/v1/generate") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "request_id": "raw-req", + "model": "Qwen/Qwen1.5-0.5B-Chat", + "token_ids": [11, 22], + "stream": false, + "sampling_params": { + "max_tokens": 2, + "logprobs": 1, + "prompt_logprobs": 1 + } + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + + assert_eq!(json["request_id"], "raw-req"); + assert_eq!(json["choices"][0]["index"], 0); + assert_eq!(json["choices"][0]["token_ids"], json!([33, 44])); + assert_eq!(json["choices"][0]["finish_reason"], "stop"); + assert_eq!( + json["choices"][0]["logprobs"]["content"][0]["token"], + "token_id:33" + ); + assert_eq!( + json["choices"][0]["logprobs"]["content"][1]["top_logprobs"][0]["token"], + "token_id:44" + ); + assert_eq!(json["prompt_logprobs"][0], serde_json::Value::Null); + assert_eq!( + json["prompt_logprobs"][1]["22"]["decoded_token"], + "token_id:22" + ); + assert_eq!(json["kv_transfer_params"], json!({"connector": "x"})); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn raw_generate_rejects_streaming() { + let mut app = test_app().await; + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/inference/v1/generate") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "token_ids": [11, 22], + "stream": true, + "sampling_params": {} + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + assert_eq!(json["error"]["param"], "stream"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn raw_generate_rejects_empty_token_ids() { + let mut app = test_app().await; + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/inference/v1/generate") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "token_ids": [], + "sampling_params": {} + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + assert_eq!(json["error"]["param"], "token_ids"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn raw_generate_rejects_wrong_model() { + let mut app = test_app().await; + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/inference/v1/generate") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "wrong-model", + "token_ids": [11, 22], + "sampling_params": {} + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn completions_happy_path_returns_sse_stream() { + let (app, engine_task) = test_app_with_engine_handle().await; + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": true, + "stream_options": {"include_usage": true} + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get("content-type").and_then(|value| value.to_str().ok()), + Some("text/event-stream") + ); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + let payloads = sse_data_payloads(&text); + let usage_index = payloads + .iter() + .position(|payload| payload.contains("\"usage\":")) + .expect("usage chunk"); + let done_index = + payloads.iter().position(|payload| *payload == "[DONE]").expect("done sentinel"); + + assert!( + payloads.iter().any(|payload| payload.contains("\"text\":\"h\"")), + "{text}" + ); + assert!( + payloads.iter().any(|payload| payload.contains("\"finish_reason\":\"stop\"")), + "{text}" + ); + assert!(usage_index < done_index, "{text}"); + + let usage_chunk: serde_json::Value = + serde_json::from_str(payloads[usage_index]).expect("usage chunk json"); + assert_eq!(usage_chunk["choices"], json!([])); + assert_eq!(usage_chunk["usage"]["completion_tokens"], 3); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn completions_echo_stream_emits_separate_prompt_chunk() { + let (app, engine_task) = test_app_with_engine_handle().await; + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "echo": true, + "stream": true, + "stream_options": {"include_usage": true} + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + let payloads = sse_data_payloads(&text); + let hello_index = payloads + .iter() + .position(|payload| payload.contains("\"text\":\"hello\"")) + .expect("prompt echo chunk"); + let h_index = payloads + .iter() + .position(|payload| payload.contains("\"text\":\"h\"")) + .expect("first generation chunk"); + + assert!(hello_index < h_index, "{text}"); + assert!( + payloads.iter().any(|payload| payload.contains("\"text\":\"i\"")), + "{text}" + ); + + let usage_chunk: serde_json::Value = serde_json::from_str( + payloads + .iter() + .find(|payload| payload.contains("\"usage\":")) + .expect("usage chunk"), + ) + .expect("usage chunk json"); + assert_eq!(usage_chunk["usage"]["prompt_tokens"], 5); + assert_eq!(usage_chunk["usage"]["completion_tokens"], 3); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn chat_harness_streams_text_events() { + let (chat, engine_task) = test_chat_with_engine_handle().await; + let mut stream = chat + .chat(ChatRequest { + messages: vec![ChatMessage::text(ChatRole::User, "hello")], + sampling_params: SamplingParams { + max_tokens: Some(8), + ..Default::default() + }, + request_id: "chat-harness".to_string(), + ..ChatRequest::for_test() + }) + .await + .expect("submit chat request"); + + let mut saw_text = false; + let mut saw_done = false; + while let Some(event) = stream.next().await { + match event.expect("chat event") { + ChatEvent::BlockDelta { .. } => saw_text = true, + ChatEvent::Done { .. } => { + saw_done = true; + break; + } + ChatEvent::Start { .. } + | ChatEvent::LogprobsDelta { .. } + | ChatEvent::BlockStart { .. } + | ChatEvent::BlockEnd { .. } + | ChatEvent::ToolCallStart { .. } + | ChatEvent::ToolCallArgumentsDelta { .. } + | ChatEvent::ToolCallEnd { .. } => {} + } + } + engine_task.await.expect("mock engine task"); + + assert!(saw_text); + assert!(saw_done); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn prepared_openai_request_streams_text_events() { + let (chat, engine_task) = test_chat_with_engine_handle().await; + let prepared = prepare_chat_request( + serde_json::from_value(json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": true, + "messages": [{"role": "user", "content": "hello"}] + })) + .expect("decode request"), + &["Qwen/Qwen1.5-0.5B-Chat".to_string()], + crate::utils::ResolvedRequestContext::default(), + ) + .expect("prepare request"); + + let mut stream = chat.chat(prepared.chat_request).await.expect("submit chat request"); + + let mut saw_text = false; + let mut saw_done = false; + while let Some(event) = stream.next().await { + match event.expect("chat event") { + ChatEvent::BlockDelta { .. } => saw_text = true, + ChatEvent::Done { .. } => { + saw_done = true; + break; + } + ChatEvent::Start { .. } + | ChatEvent::LogprobsDelta { .. } + | ChatEvent::BlockStart { .. } + | ChatEvent::BlockEnd { .. } + | ChatEvent::ToolCallStart { .. } + | ChatEvent::ToolCallArgumentsDelta { .. } + | ChatEvent::ToolCallEnd { .. } => {} + } + } + engine_task.await.expect("mock engine task"); + + assert!(saw_text); + assert!(saw_done); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn reasoning_blocks_are_mapped_to_reasoning_sse_chunks() { + let (app, engine_task) = test_app_with_backend_and_stream_output_specs( + Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")), + vec![ + (bytes_to_token_ids(b""), None), + (bytes_to_token_ids(b"think "), None), + (bytes_to_token_ids(b"more"), None), + ( + bytes_to_token_ids(b"answer"), + Some(EngineCoreFinishReason::Length), + ), + ], + ) + .await; + + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": true, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + + assert!(text.contains("\"reasoning\":\"think \""), "{text}"); + assert!(text.contains("\"reasoning\":\"more\""), "{text}"); + assert!(text.contains("\"content\":\"answer\""), "{text}"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn tool_calls_are_mapped_to_tool_call_sse_chunks() { + let (app, engine_task) = test_app_with_backend_and_stream_output_specs( + Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")), + vec![ + (bytes_to_token_ids(b"Need tool."), None), + ( + bytes_to_token_ids(b"\n{\"name\":\"get_weather\", "), + None, + ), + ( + bytes_to_token_ids(b"\"arguments\":{\"city\":\"Paris\"}}\n"), + Some(EngineCoreFinishReason::Stop), + ), + ], + ) + .await; + + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": true, + "messages": [{"role": "user", "content": "hello"}], + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}} + } + } + }] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + + assert!(text.contains("\"tool_calls\":"), "{text}"); + assert!(text.contains("\"name\":\"get_weather\""), "{text}"); + assert!( + text.contains("\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\""), + "{text}" + ); + assert!(text.contains("\"finish_reason\":\"tool_calls\""), "{text}"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn tool_call_sse_chunks_can_carry_logprobs() { + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-openai-chat-tools-logprobs".to_vec(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + Box::pin(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = + rmp_serde::from_slice(&add[1]).expect("decode request"); + + send_outputs( + push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![request_output_with_logprobs( + &request.request_id, + bytes_to_token_ids(b"Need tool."), + None, + None, + Some(sample_logprobs_for_tokens(&bytes_to_token_ids( + b"Need tool.", + ))), + None, + )], + scheduler_stats: None, + timestamp: 0.0, + utility_output: None, + finished_requests: None, + wave_complete: None, + start_wave: None, + }, + ) + .await; + send_outputs( + push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![request_output_with_logprobs( + &request.request_id, + bytes_to_token_ids(b"\n{\"name\":\"get_weather\", "), + None, + None, + Some(sample_logprobs_for_tokens(&bytes_to_token_ids( + b"\n{\"name\":\"get_weather\", ", + ))), + None, + )], + scheduler_stats: None, + timestamp: 0.0, + utility_output: None, + finished_requests: None, + wave_complete: None, + start_wave: None, + }, + ) + .await; + send_outputs( + push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![request_output_with_logprobs( + &request.request_id, + bytes_to_token_ids( + b"\"arguments\":{\"city\":\"Paris\"}}\n", + ), + Some(EngineCoreFinishReason::Stop), + None, + Some(sample_logprobs_for_tokens(&bytes_to_token_ids( + b"\"arguments\":{\"city\":\"Paris\"}}\n", + ))), + None, + )], + scheduler_stats: None, + timestamp: 0.0, + utility_output: None, + finished_requests: Some(BTreeSet::from([request.request_id.clone()])), + wave_complete: None, + start_wave: None, + }, + ) + .await; + }) + }, + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + let chat = ChatLlm::from_shared_backend( + test_llm(client), + Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")), + ); + let app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); + + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": true, + "logprobs": true, + "messages": [{"role": "user", "content": "hello"}], + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}} + } + } + }] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.finish().await; + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + + assert!(text.contains("\"tool_calls\":"), "{text}"); + assert!(text.contains("\"logprobs\":{\"content\":"), "{text}"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn streaming_chat_prompt_logprobs_are_rejected() { + let (app, engine_task) = test_app_with_engine_handle().await; + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "stream": true, + "prompt_logprobs": 1, + "messages": [{"role": "user", "content": "hello"}] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + engine_task.abort(); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn reset_prefix_cache_route_sends_expected_utility_call() { + let (app, engine_task) = test_admin_app_with_engine_script(|dealer, push| { + boxed_test_future(async move { + let utility = recv_engine_message(dealer).await; + assert_eq!(utility[0].as_ref(), &[0x03]); + + let payload = decode_value(&utility[1]).expect("decode utility payload"); + let array = payload.as_array().expect("utility payload array"); + let call_id = array[1].as_i64().expect("call id"); + + assert_eq!(array[2], Value::from("reset_prefix_cache")); + assert_eq!( + array[3], + Value::Array(vec![Value::from(true), Value::from(true)]) + ); + + send_outputs(push, utility_outputs(call_id, utility_result_value(true))).await; + }) + }) + .await; + + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/reset_prefix_cache?reset_running_requests=true&reset_external=true") + .body(Body::empty()) + .expect("build request"), + ) + .await + .expect("call app"); + + let status = response.status(); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + assert_eq!(status, StatusCode::OK, "{}", String::from_utf8_lossy(&body)); + assert!(body.is_empty()); + engine_task.await.expect("mock engine task"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn reset_mm_cache_route_sends_expected_utility_call() { + let (app, engine_task) = test_admin_app_with_engine_script(|dealer, push| { + boxed_test_future(async move { + let utility = recv_engine_message(dealer).await; + assert_eq!(utility[0].as_ref(), &[0x03]); + + let payload = decode_value(&utility[1]).expect("decode utility payload"); + let array = payload.as_array().expect("utility payload array"); + let call_id = array[1].as_i64().expect("call id"); + + assert_eq!(array[2], Value::from("reset_mm_cache")); + assert_eq!(array[3], Value::Array(Vec::new())); + + send_outputs(push, utility_outputs(call_id, utility_none_result())).await; + }) + }) + .await; + + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/reset_mm_cache") + .body(Body::empty()) + .expect("build request"), + ) + .await + .expect("call app"); + + let status = response.status(); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + assert_eq!(status, StatusCode::OK, "{}", String::from_utf8_lossy(&body)); + assert!(body.is_empty()); + engine_task.await.expect("mock engine task"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn reset_encoder_cache_route_sends_expected_utility_call() { + let (app, engine_task) = test_admin_app_with_engine_script(|dealer, push| { + boxed_test_future(async move { + let utility = recv_engine_message(dealer).await; + assert_eq!(utility[0].as_ref(), &[0x03]); + + let payload = decode_value(&utility[1]).expect("decode utility payload"); + let array = payload.as_array().expect("utility payload array"); + let call_id = array[1].as_i64().expect("call id"); + + assert_eq!(array[2], Value::from("reset_encoder_cache")); + assert_eq!(array[3], Value::Array(Vec::new())); + + send_outputs(push, utility_outputs(call_id, utility_none_result())).await; + }) + }) + .await; + + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/reset_encoder_cache") + .body(Body::empty()) + .expect("build request"), + ) + .await + .expect("call app"); + + let status = response.status(); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + assert_eq!(status, StatusCode::OK, "{}", String::from_utf8_lossy(&body)); + assert!(body.is_empty()); + engine_task.await.expect("mock engine task"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn collective_rpc_route_sends_expected_utility_call_and_returns_results() { + let (app, engine_task) = test_admin_app_with_engine_script(|dealer, push| { + boxed_test_future(async move { + let utility = recv_engine_message(dealer).await; + assert_eq!(utility[0].as_ref(), &[0x03]); + + let payload = decode_value(&utility[1]).expect("decode utility payload"); + let array = payload.as_array().expect("utility payload array"); + let call_id = array[1].as_i64().expect("call id"); + + assert_eq!(array[2], Value::from("collective_rpc")); + assert_eq!( + array[3], + Value::Array(vec![ + Value::from("echo_args_kwargs"), + Value::from(1.5_f64), + Value::Array(vec![Value::from("arg1"), Value::from("arg2")]), + Value::Map(vec![ + (Value::from("key1"), Value::from("value1")), + (Value::from("key2"), Value::from("value2")), + ]), + ]) + ); + + send_outputs( + push, + utility_outputs( + call_id, + UtilityResultEnvelope::without_type_info(Value::Array(vec![Value::Map(vec![ + ( + Value::from("args"), + Value::Array(vec![Value::from("arg1"), Value::from("arg2")]), + ), + ( + Value::from("kwargs"), + Value::Map(vec![ + (Value::from("key1"), Value::from("value1")), + (Value::from("key2"), Value::from("value2")), + ]), + ), + (Value::from("total_items"), Value::from(4_u64)), + ])])), + ), + ) + .await; + }) + }) + .await; + + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/collective_rpc") + .header("content-type", "application/json") + .body(Body::from( + r#"{"method":"echo_args_kwargs","args":["arg1","arg2"],"kwargs":{"key1":"value1","key2":"value2"},"timeout":1.5}"#, + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + + assert_eq!( + serde_json::from_slice::(&body).expect("decode json"), + json!({ + "results": [{ + "args": ["arg1", "arg2"], + "kwargs": { + "key1": "value1", + "key2": "value2" + }, + "total_items": 4 + }] + }) + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn sleep_route_uses_python_compatible_default_query_values() { + let (app, engine_task) = test_admin_app_with_engine_script(|dealer, push| { + boxed_test_future(async move { + let utility = recv_engine_message(dealer).await; + assert_eq!(utility[0].as_ref(), &[0x03]); + + let payload = decode_value(&utility[1]).expect("decode utility payload"); + let array = payload.as_array().expect("utility payload array"); + let call_id = array[1].as_i64().expect("call id"); + + assert_eq!(array[2], Value::from("sleep")); + assert_eq!( + array[3], + Value::Array(vec![Value::from(1_u64), Value::from("abort")]) + ); + + send_outputs(push, utility_outputs(call_id, utility_none_result())).await; + }) + }) + .await; + + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/sleep") + .body(Body::empty()) + .expect("build request"), + ) + .await + .expect("call app"); + + let status = response.status(); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + assert_eq!(status, StatusCode::OK, "{}", String::from_utf8_lossy(&body)); + assert!(body.is_empty()); + engine_task.await.expect("mock engine task"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn wake_up_route_without_tags_sends_none() { + let (app, engine_task) = test_admin_app_with_engine_script(|dealer, push| { + boxed_test_future(async move { + let utility = recv_engine_message(dealer).await; + assert_eq!(utility[0].as_ref(), &[0x03]); + + let payload = decode_value(&utility[1]).expect("decode utility payload"); + let array = payload.as_array().expect("utility payload array"); + let call_id = array[1].as_i64().expect("call id"); + + assert_eq!(array[2], Value::from("wake_up")); + assert_eq!(array[3], Value::Array(vec![Value::Nil])); + + send_outputs(push, utility_outputs(call_id, utility_none_result())).await; + }) + }) + .await; + + let response = app + .clone() + .call( + Request::builder() + .method("POST") + .uri("/wake_up") + .body(Body::empty()) + .expect("build request"), + ) + .await + .expect("call app"); + + let status = response.status(); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + assert_eq!(status, StatusCode::OK, "{}", String::from_utf8_lossy(&body)); + assert!(body.is_empty()); + engine_task.await.expect("mock engine task"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn is_sleeping_route_returns_json_payload() { + let (app, engine_task) = test_admin_app_with_engine_script(|dealer, push| { + boxed_test_future(async move { + let utility = recv_engine_message(dealer).await; + assert_eq!(utility[0].as_ref(), &[0x03]); + + let payload = decode_value(&utility[1]).expect("decode utility payload"); + let array = payload.as_array().expect("utility payload array"); + let call_id = array[1].as_i64().expect("call id"); + + assert_eq!(array[2], Value::from("is_sleeping")); + assert_eq!(array[3], Value::Array(Vec::new())); + + send_outputs(push, utility_outputs(call_id, utility_result_value(true))).await; + }) + }) + .await; + + let response = app + .clone() + .call( + Request::builder() + .method("GET") + .uri("/is_sleeping") + .body(Body::empty()) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + + assert_eq!( + serde_json::from_slice::(&body).expect("decode json"), + json!({ "is_sleeping": true }) + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn admin_routes_are_hidden_when_dev_mode_is_disabled() { + let (chat, engine_task) = test_chat_with_engine_handle().await; + let app = build_router_with_dev_mode( + Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + )), + false, + ); + + for (method, uri) in [ + ("GET", "/is_sleeping"), + ("POST", "/sleep"), + ("POST", "/wake_up"), + ("POST", "/collective_rpc"), + ("POST", "/reset_prefix_cache"), + ("POST", "/reset_mm_cache"), + ("POST", "/reset_encoder_cache"), + ] { + let response = app + .clone() + .call( + Request::builder() + .method(method) + .uri(uri) + .body(Body::empty()) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::NOT_FOUND, "{method} {uri}"); + } + + engine_task.abort_and_join().await; +} + +// ========================= Stop string tests ========================= + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_stream_completions_stop_string_excluded_from_output() { + // Engine generates "say world" but stop string "wor" truncates output to "say + // ". + let output_specs = vec![ + (bytes_to_token_ids(b"say"), None), + ( + bytes_to_token_ids(b" world"), + Some(EngineCoreFinishReason::Length), + ), + ]; + let (app, engine_task) = test_app_with_stream_output_specs(output_specs).await; + + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": false, + "stop": ["wor"] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + + assert_eq!(json["choices"][0]["text"], "say "); + assert_eq!(json["choices"][0]["finish_reason"], "stop"); + assert_eq!(json["choices"][0]["stop_reason"], "wor"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_stream_completions_stop_string_included_in_output() { + // Same tokens but include_stop_str_in_output=true includes the stop string in + // the output. + let output_specs = vec![ + (bytes_to_token_ids(b"say"), None), + ( + bytes_to_token_ids(b" world"), + Some(EngineCoreFinishReason::Length), + ), + ]; + let (app, engine_task) = test_app_with_stream_output_specs(output_specs).await; + + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": false, + "stop": ["wor"], + "include_stop_str_in_output": true + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + + assert_eq!(json["choices"][0]["text"], "say wor"); + assert_eq!(json["choices"][0]["finish_reason"], "stop"); + assert_eq!(json["choices"][0]["stop_reason"], "wor"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn stream_completions_stop_string_excluded_from_output() { + let output_specs = vec![ + (bytes_to_token_ids(b"say"), None), + ( + bytes_to_token_ids(b" world"), + Some(EngineCoreFinishReason::Length), + ), + ]; + let (app, engine_task) = test_app_with_stream_output_specs(output_specs).await; + + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": true, + "stop": ["wor"] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + let payloads = sse_data_payloads(&text); + + // Collect all text deltas from the SSE chunks. + let mut full_text = String::new(); + for payload in &payloads { + if *payload == "[DONE]" { + continue; + } + let chunk: serde_json::Value = serde_json::from_str(payload).expect("json chunk"); + if let Some(text) = chunk["choices"][0]["text"].as_str() { + full_text.push_str(text); + } + } + + // The concatenated text deltas should equal "say " (stop string excluded). + assert_eq!(full_text, "say ", "full streamed text: {text}"); + + // The final chunk should have finish_reason "stop". + assert!( + payloads.iter().any(|p| p.contains("\"finish_reason\":\"stop\"")), + "{text}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn stream_completions_stop_string_included_in_output() { + let output_specs = vec![ + (bytes_to_token_ids(b"say"), None), + ( + bytes_to_token_ids(b" world"), + Some(EngineCoreFinishReason::Length), + ), + ]; + let (app, engine_task) = test_app_with_stream_output_specs(output_specs).await; + + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": true, + "stop": ["wor"], + "include_stop_str_in_output": true + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + let payloads = sse_data_payloads(&text); + + let mut full_text = String::new(); + for payload in &payloads { + if *payload == "[DONE]" { + continue; + } + let chunk: serde_json::Value = serde_json::from_str(payload).expect("json chunk"); + if let Some(text) = chunk["choices"][0]["text"].as_str() { + full_text.push_str(text); + } + } + + // With include_stop_str_in_output, the stop string "wor" should be included. + assert_eq!(full_text, "say wor", "full streamed text: {text}"); + + assert!( + payloads.iter().any(|p| p.contains("\"finish_reason\":\"stop\"")), + "{text}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_stream_completions_no_stop_string_match_preserves_original_finish_reason() { + // Stop string "xyz" does not appear in "hi!" so the original finish reason is + // preserved. + let (app, engine_task) = test_app_with_engine_handle().await; + + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": false, + "stop": ["xyz"] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + + // Default output is "hi" (stop token '!' suppressed), finish_reason remains + // "stop" from EOS. + assert_eq!(json["choices"][0]["text"], "hi"); + assert_eq!(json["choices"][0]["finish_reason"], "stop"); + // No text stop string matched — stop_reason should be absent. + assert!(json["choices"][0]["stop_reason"].is_null()); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn non_stream_completions_stop_string_array_matches_first_occurrence() { + // Multiple stop strings: "rl" appears in "world" but " wo" appears earlier. + let output_specs = vec![( + bytes_to_token_ids(b"say world"), + Some(EngineCoreFinishReason::Length), + )]; + let (app, engine_task) = test_app_with_stream_output_specs(output_specs).await; + + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": false, + "stop": [" wo", "rl"] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); + + // " wo" is detected first (at byte 3), so output is truncated to "say". + assert_eq!(json["choices"][0]["text"], "say"); + assert_eq!(json["choices"][0]["finish_reason"], "stop"); + assert_eq!(json["choices"][0]["stop_reason"], " wo"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn completions_empty_stop_string_returns_validation_error() { + let (app, _engine_task) = test_app_with_engine_handle().await; + + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri("/v1/completions") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "prompt": "hello", + "stream": false, + "stop": [""] + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); +} diff --git a/rust/src/server/src/state.rs b/rust/src/server/src/state.rs new file mode 100644 index 00000000000..04d37f1a5d4 --- /dev/null +++ b/rust/src/server/src/state.rs @@ -0,0 +1,120 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +use tokio::time::{Duration, Instant, sleep_until}; +use tracing::warn; +use vllm_chat::ChatLlm; +use vllm_engine_core_client::EngineCoreClient; + +const SHUTDOWN_REFCOUNT_POLL_INTERVAL: Duration = Duration::from_millis(100); + +/// Shared router state for the minimal single-model OpenAI server. +pub struct AppState { + /// All public model IDs served by this frontend. The first entry is the + /// primary ID used in responses; all entries are valid in requests. + served_model_names: Vec, + /// Shared chat facade used by all requests. + pub chat: ChatLlm, + /// Whether to log a summary line for each completed request. + pub enable_log_requests: bool, + /// Number of in-flight inference requests currently owned by this frontend. + server_load: AtomicU64, +} + +impl AppState { + /// Construct one application state instance. + /// + /// `served_model_names` must be non-empty; the first entry is the primary + /// model ID returned in API responses. + /// + /// # Panics + /// + /// Panics if `served_model_names` is empty. + pub fn new(served_model_names: Vec, chat: ChatLlm) -> Self { + assert!( + !served_model_names.is_empty(), + "served_model_names must not be empty" + ); + Self { + served_model_names, + chat, + enable_log_requests: false, + server_load: AtomicU64::new(0), + } + } + + /// Enable per-request completion logging. + pub fn with_log_requests(mut self, enabled: bool) -> Self { + self.enable_log_requests = enabled; + self + } + + /// The primary model name echoed back in API responses (the first served + /// name). + pub fn primary_model_name(&self) -> &str { + self.served_model_names.first().map(String::as_str).unwrap_or_default() + } + + /// All model names served by this frontend. + pub fn served_model_names(&self) -> &[String] { + &self.served_model_names + } + + /// Return a reference to the underlying engine core client for utility + /// calls. + pub(crate) fn engine_core_client(&self) -> &EngineCoreClient { + self.chat.engine_core_client() + } + + /// Return the current in-flight inference request count for the `/load` + /// endpoint. + pub fn server_load(&self) -> u64 { + self.server_load.load(Ordering::Relaxed) + } + + /// Increment the in-flight inference request count, called by the load + /// tracking middleware. + pub(crate) fn increment_server_load(&self) { + self.server_load.fetch_add(1, Ordering::Relaxed); + } + + /// Decrement the in-flight inference request count, called by the load + /// tracking middleware. + pub(crate) fn decrement_server_load(&self) { + self.server_load.fetch_sub(1, Ordering::Relaxed); + } + + /// Wait until all request-owned references are dropped, then shut down the + /// engine client. + /// + /// If the deadline elapses while request/connection tasks still hold state + /// references, skip the clean engine-client shutdown and let process + /// teardown reclaim the remaining resources. + pub async fn shutdown(mut self: Arc, deadline: Instant) -> anyhow::Result<()> { + loop { + match Arc::try_unwrap(self) { + Ok(state) => { + state.chat.shutdown().await?; + return Ok(()); + } + Err(state) => self = state, + } + let ref_count = Arc::strong_count(&self); + + let now = Instant::now(); + if now >= deadline { + warn!( + ref_count, + "shutdown deadline elapsed before app state became idle; skipping engine-client shutdown" + ); + return Ok(()); + } + + sleep_until(std::cmp::min( + deadline, + now + SHUTDOWN_REFCOUNT_POLL_INTERVAL, + )) + .await; + } + } +} diff --git a/rust/src/server/src/utils.rs b/rust/src/server/src/utils.rs new file mode 100644 index 00000000000..13fa0dfaeec --- /dev/null +++ b/rust/src/server/src/utils.rs @@ -0,0 +1,106 @@ +use std::collections::HashMap; +use std::time::{SystemTime, UNIX_EPOCH}; + +use axum::http::HeaderMap; +use serde_json::Value; +use thiserror_ext::AsReport; +use uuid::Uuid; + +use crate::error::ApiError; + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct ResolvedRequestContext { + pub request_id: String, + pub data_parallel_rank: Option, +} + +/// Return the current Unix timestamp in seconds for OpenAI response objects. +pub fn unix_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or_default() +} + +/// Construct an API error for a failed utility call to the engine core. +pub fn utility_call_error(method: &str, error: impl AsReport) -> ApiError { + ApiError::server_error(format!("failed to call {method}: {}", error.as_report())) +} + +/// Merge `kv_transfer_params` into the `vllm_xargs` map, mirroring the Python +/// vLLM behavior where `kv_transfer_params` is injected into `extra_args` for +/// engine-core consumption. +pub fn merge_kv_transfer_params( + mut xargs: Option>, + kv_transfer_params: Option<&HashMap>, +) -> Option> { + if let Some(kv_params) = kv_transfer_params { + let map = xargs.get_or_insert_with(HashMap::new); + map.insert( + "kv_transfer_params".to_string(), + // This is safe because we know that `kv_params` is already valid JSON. + serde_json::to_value(kv_params).unwrap(), + ); + } + xargs +} + +/// Convert OpenAI-style `logit_bias` with string token-ID keys into the +/// internal `HashMap` representation, validating that every key +/// parses as a `u32`. +pub fn convert_logit_bias( + logit_bias: Option>, +) -> Result>, ApiError> { + logit_bias + .map(|bias| { + bias.into_iter() + .map(|(key, value)| { + key.parse().map(|k| (k, value)).map_err(|_| { + ApiError::invalid_request( + format!( + "Invalid key in 'logit_bias': '{key}' is not a valid token ID. \ + Token IDs must be non-negative integers." + ), + Some("logit_bias"), + ) + }) + }) + .collect() + }) + .transpose() +} + +/// Extract common request metadata from HTTP headers: the external request ID +/// and the optional data-parallel rank used for engine routing. +pub fn resolve_request_context( + headers: &HeaderMap, + request_id: Option<&str>, +) -> ResolvedRequestContext { + // `None` when the header is absent or cannot be parsed as a `u32`. + let data_parallel_rank = headers + .get("X-data-parallel-rank") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.trim().parse().ok()); + + // Extract request id from header. + let request_id_header = headers.get("X-Request-Id").and_then(|value| value.to_str().ok()); + let request_id = resolve_base_request_id(request_id_header, request_id); + + ResolvedRequestContext { + request_id, + data_parallel_rank, + } +} + +/// Resolve the base external request ID before API-specific prefixes such as +/// `chatcmpl-`. +pub fn resolve_base_request_id( + request_id_header: Option<&str>, + request_id: Option<&str>, +) -> String { + request_id_header.or(request_id).map(ToOwned::to_owned).unwrap_or_else(|| { + let mut id = Uuid::new_v4().simple().to_string(); + id.truncate(8); + id + }) +} diff --git a/rust/src/text/Cargo.toml b/rust/src/text/Cargo.toml new file mode 100644 index 00000000000..04202348182 --- /dev/null +++ b/rust/src/text/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "vllm-text" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +anyhow.workspace = true +asynk-strim-attr.workspace = true +easy-ext.workspace = true +enum-as-inner.workspace = true +futures.workspace = true +hf-hub.workspace = true +itertools.workspace = true +serde.workspace = true +serde_json.workspace = true +serde_with.workspace = true +thiserror.workspace = true +thiserror-ext.workspace = true +tracing.workspace = true +trait-set.workspace = true +vllm-engine-core-client.workspace = true +vllm-llm.workspace = true +vllm-tokenizer.workspace = true + +[dev-dependencies] +expect-test.workspace = true +futures.workspace = true +tempfile.workspace = true +tokio.workspace = true +vllm-llm = { workspace = true, features = ["test-util"] } + +[lints] +workspace = true diff --git a/rust/src/text/src/backend/hf/config.rs b/rust/src/text/src/backend/hf/config.rs new file mode 100644 index 00000000000..5f2ecf8ba60 --- /dev/null +++ b/rust/src/text/src/backend/hf/config.rs @@ -0,0 +1,373 @@ +use std::collections::BTreeSet; +use std::fs; +use std::path::Path; + +use serde::{Deserialize, Serialize}; +use thiserror_ext::AsReport as _; + +use crate::error::{Error, Result}; + +/// Minimal subset of `tokenizer_config.json` needed by chat/EOS handling. +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +pub struct HfTokenizerConfig { + #[serde(flatten)] + pub special_tokens: HfSpecialTokens, + pub chat_template: Option, + /// The `tokenizer_class` field from HuggingFace tokenizer configs. Some + /// tiktoken-based models (e.g. DeepSeek, Kimi K2) set this to a value + /// containing "Tiktoken" which can be used as a hint for backend + /// selection. + pub tokenizer_class: Option, +} + +/// Hugging Face named special tokens may be serialized as a string or an +/// object carrying the token content. +#[derive(Debug, Clone, Deserialize)] +#[serde(untagged)] +pub enum NamedSpecialToken { + Text(String), + WithContent { content: String }, +} + +impl Serialize for NamedSpecialToken { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.as_str()) + } +} + +impl From for String { + fn from(value: NamedSpecialToken) -> Self { + match value { + NamedSpecialToken::Text(string) => string, + NamedSpecialToken::WithContent { content } => content, + } + } +} + +impl NamedSpecialToken { + pub fn as_str(&self) -> &str { + match self { + Self::Text(value) => value, + Self::WithContent { content } => content, + } + } +} + +/// Minimal set of special-token entries needed by chat/EOS handling. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +#[serde(default)] +pub struct HfSpecialTokens { + pub bos_token: Option, + pub eos_token: Option, + pub unk_token: Option, + pub pad_token: Option, +} + +impl HfSpecialTokens { + /// Returns true if we don't discover any special tokens in the config. + pub fn is_empty(&self) -> bool { + self.bos_token.is_none() + && self.eos_token.is_none() + && self.unk_token.is_none() + && self.pad_token.is_none() + } +} + +/// Minimal subset of `config.json` (the model's main HF config). +/// +/// This intentionally supports only the two layouts we currently care about in +/// the Rust frontend: +/// - pure text models that keep text metadata at the top level +/// - composite models that expose a single nested `text_config` +/// +/// We do not support additional entry points such as `decoder`, `generator`, or +/// `text_encoder`. +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +pub struct ModelConfig { + model_type: Option, + max_position_embeddings: Option, + num_attention_heads: Option, + num_experts: Option, + moe_num_experts: Option, + n_routed_experts: Option, + num_local_experts: Option, + block_configs: Vec, + text_config: Option>, +} + +/// Minimal subset of `generation_config.json`. +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +pub(super) struct GenerationConfig { + pub eos_token_id: Option, + pub temperature: Option, + pub top_p: Option, + pub top_k: Option, + pub min_p: Option, + pub repetition_penalty: Option, + pub max_new_tokens: Option, +} + +/// HF generation configs allow either one EOS id or a list of EOS ids. +#[derive(Debug, Clone, Deserialize)] +#[serde(untagged)] +pub(super) enum OneOrManyTokenIds { + One(u32), + Many(Vec), +} + +impl OneOrManyTokenIds { + pub(super) fn into_set(self) -> BTreeSet { + match self { + Self::One(id) => BTreeSet::from([id]), + Self::Many(ids) => ids.into_iter().collect(), + } + } +} + +/// Hugging Face configs may expose the expert count either as one integer or +/// as a list of repeated integers. +#[derive(Debug, Clone, Deserialize)] +#[serde(untagged)] +pub(super) enum OneOrManyExpertCount { + One(u32), + Many(Vec), +} + +impl OneOrManyExpertCount { + fn first_value(&self) -> u32 { + match self { + Self::One(value) => *value, + // Python currently takes the first value for list[int] expert + // counts in remote-code configs. + Self::Many(values) => values.first().copied().unwrap_or(0), + } + } +} + +/// Heterogeneous block-level MoE metadata used as a fallback when no top-level +/// expert-count field is available. +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +pub(super) struct BlockConfig { + pub block_type: String, + pub n_routed_experts: u32, +} + +impl ModelConfig { + /// Return the config that the Rust frontend treats as the text/LLM config. + /// + /// This is deliberately narrower than Python/transformers: we only support + /// either the top-level config itself or a single nested `text_config`. + fn effective_text_config(&self) -> &Self { + self.text_config.as_deref().unwrap_or(self) + } + + /// Return the effective Hugging Face `model_type` used by the Rust + /// frontend. + /// + /// This follows the same simplified text-config selection as the rest of + /// this type: the top-level config wins, otherwise a single nested + /// `text_config` may provide the value. + pub fn model_type(&self) -> Option<&str> { + self.model_type.as_deref().or_else(|| self.text_config.as_deref()?.model_type()) + } + + /// Reject partially nested `text_config` payloads that are unlikely to be + /// valid LLM configs for our current use. + /// + /// This keeps the simplified Rust-side parsing honest: if a model declares + /// `text_config`, it must at least look like a real text model config. + fn validate_text_config_selection(&self) -> Result<()> { + if let Some(text_config) = self.text_config.as_deref() + && text_config.num_attention_heads.is_none() + { + return Err(Error::Tokenizer( + "the text config extracted from the model config does not have `num_attention_heads`" + .to_string(), + )); + } + + Ok(()) + } + + /// Match Python's current expert-count priority on the selected text + /// config. + /// + /// The only intentional simplification here is how we pick the text config: + /// Rust only looks at the top level or `text_config`, not the broader + /// transformers composite-config surface. + fn num_experts_from_block_configs(&self) -> u32 { + self.effective_text_config() + .block_configs + .iter() + .filter(|block| block.block_type == "moe") + .map(|block| block.n_routed_experts) + .max() + .unwrap_or(0) + } + + pub(super) fn num_experts(&self) -> u32 { + let config = self.effective_text_config(); + let direct = [ + config.num_experts.as_ref(), + config.moe_num_experts.as_ref(), + config.n_routed_experts.as_ref(), + config.num_local_experts.as_ref(), + ] + .into_iter() + .flatten() + .map(OneOrManyExpertCount::first_value) + .next() + .unwrap_or(0); + + if direct > 0 { + direct + } else { + self.num_experts_from_block_configs() + } + } + + pub(super) fn is_moe(&self) -> bool { + self.num_experts() > 0 + } + + pub(super) fn max_position_embeddings(&self) -> Option { + self.effective_text_config().max_position_embeddings + } +} + +/// Load the tokenizer-side EOS metadata if a config file is present. +pub fn load_tokenizer_config(path: Option<&Path>) -> Result { + read_json_file(path) +} + +/// Load the generation-side EOS metadata if a config file is present. +pub(super) fn load_generation_config(path: Option<&Path>) -> Result { + read_json_file(path) +} + +/// Load the model-side config (`config.json`) if present. +pub fn load_model_config(path: Option<&Path>) -> Result { + let config: ModelConfig = read_json_file(path)?; + config.validate_text_config_selection()?; + Ok(config) +} + +fn read_json_file(path: Option<&Path>) -> Result +where + T: for<'de> Deserialize<'de> + Default, +{ + let Some(path) = path else { + return Ok(T::default()); + }; + let content = fs::read_to_string(path).map_err(|error| { + Error::Tokenizer(format!( + "failed to read {}: {}", + path.display(), + error.as_report() + )) + })?; + serde_json::from_str(&content).map_err(|error| { + Error::Tokenizer(format!( + "failed to parse {}: {}", + path.display(), + error.as_report() + )) + }) +} + +#[cfg(test)] +mod tests { + use super::ModelConfig; + + #[test] + fn model_config_detects_moe_from_named_expert_fields() { + let field_names = [ + "num_experts", + "moe_num_experts", + "n_routed_experts", + "num_local_experts", + ]; + + for field_name in field_names { + let config: ModelConfig = + serde_json::from_str(&format!(r#"{{"{field_name}": 64}}"#)).unwrap(); + assert_eq!(config.num_experts(), 64, "field_name={field_name}"); + assert!(config.is_moe(), "field_name={field_name}"); + } + } + + #[test] + fn model_config_uses_first_value_for_list_expert_counts() { + let config: ModelConfig = serde_json::from_str(r#"{"num_experts":[16,16]}"#).unwrap(); + + assert_eq!(config.num_experts(), 16); + assert!(config.is_moe()); + } + + #[test] + fn model_config_falls_back_to_block_configs_maximum() { + let config: ModelConfig = serde_json::from_str( + r#"{ + "block_configs": [ + {"block_type":"attention","n_routed_experts":9}, + {"block_type":"moe","n_routed_experts":32}, + {"block_type":"moe","n_routed_experts":64} + ] + }"#, + ) + .unwrap(); + + assert_eq!(config.num_experts(), 64); + assert!(config.is_moe()); + } + + #[test] + fn model_config_prefers_nested_text_config_like_python_hf_text_config() { + let config: ModelConfig = serde_json::from_str( + r#"{ + "model_type": "top_level", + "num_experts": 64, + "max_position_embeddings": 8192, + "text_config": { + "model_type": "nested", + "num_attention_heads": 32, + "num_local_experts": 8, + "max_position_embeddings": 4096 + } + }"#, + ) + .unwrap(); + + assert_eq!(config.num_experts(), 8); + assert_eq!(config.model_type(), Some("top_level")); + assert_eq!(config.max_position_embeddings(), Some(4096)); + assert!(config.is_moe()); + } + + #[test] + fn model_config_defaults_to_non_moe_when_no_expert_metadata_exists() { + let config: ModelConfig = + serde_json::from_str(r#"{"max_position_embeddings":4096}"#).unwrap(); + + assert_eq!(config.num_experts(), 0); + assert!(!config.is_moe()); + assert_eq!(config.max_position_embeddings(), Some(4096)); + } + + #[test] + fn model_config_rejects_nested_text_config_without_attention_heads() { + let config: ModelConfig = + serde_json::from_str(r#"{"text_config":{"max_position_embeddings":4096}}"#).unwrap(); + + let error = config.validate_text_config_selection().unwrap_err(); + assert!(error.to_string().contains("does not have `num_attention_heads`"),); + } +} diff --git a/rust/src/text/src/backend/hf/mod.rs b/rust/src/text/src/backend/hf/mod.rs new file mode 100644 index 00000000000..a5d07dd8fc0 --- /dev/null +++ b/rust/src/text/src/backend/hf/mod.rs @@ -0,0 +1,120 @@ +mod config; +mod model_files; + +use std::collections::BTreeSet; +use std::sync::Arc; + +use tracing::info; +use vllm_tokenizer::{DynTokenizer, HuggingFaceTokenizer, TekkenTokenizer, TiktokenTokenizer}; + +use self::config::{GenerationConfig, load_generation_config}; +pub use self::config::{ + HfSpecialTokens, HfTokenizerConfig, ModelConfig, NamedSpecialToken, load_model_config, + load_tokenizer_config, +}; +pub use self::model_files::{ResolvedModelFiles, TokenizerSource}; +use crate::backend::{SamplingHints, TextBackend}; +use crate::error::Result; + +fn load_tokenizer(tokenizer: &TokenizerSource) -> Result { + match tokenizer { + TokenizerSource::HuggingFace(path) => Ok(Arc::new(HuggingFaceTokenizer::new(path)?)), + TokenizerSource::Tiktoken(path) => Ok(Arc::new(TiktokenTokenizer::new(path)?)), + TokenizerSource::Tekken(path) => Ok(Arc::new(TekkenTokenizer::new(path)?)), + } +} + +/// [`TextBackend`] implementation built on Hugging Face model files. +pub struct HfTextBackend { + model_id: String, + files: ResolvedModelFiles, + tokenizer: DynTokenizer, + /// Primary EOS handled by engine-core's dedicated EOS path. + primary_eos_token_id: Option, + /// Additional EOS ids that should flow through stop-token handling. + extra_eos_token_ids: BTreeSet, + /// Generation-config for sampling defaults that may be inherited when the + /// user does not explicitly override them. + generation_config: GenerationConfig, + /// Model config (`config.json`). + model_config: ModelConfig, +} + +impl HfTextBackend { + /// Load the text backend with the given model id. + pub async fn from_model(model_id: &str) -> Result { + let files = ResolvedModelFiles::new(model_id).await?; + Self::from_resolved_model_files(files, model_id.to_string()) + } + + /// Load the text backend from resolved Hugging Face model files. + pub fn from_resolved_model_files(files: ResolvedModelFiles, model_id: String) -> Result { + let tokenizer_config = load_tokenizer_config(files.tokenizer_config_path.as_deref())?; + let tokenizer = load_tokenizer(&files.tokenizer)?; + let primary_eos_token_id = tokenizer_config + .special_tokens + .eos_token + .as_ref() + .and_then(|token| tokenizer.token_to_id(token.as_str())); + + let model_config = load_model_config(files.config_path.as_deref())?; + let generation_config = load_generation_config(files.generation_config_path.as_deref())?; + let mut extra_eos_token_ids = generation_config + .eos_token_id + .clone() + .map(|value| value.into_set()) + .unwrap_or_default(); + if let Some(primary_eos_token_id) = primary_eos_token_id { + extra_eos_token_ids.remove(&primary_eos_token_id); + } + + info!( + model_id, + "loaded text backend with Hugging Face model files" + ); + + Ok(Self { + model_id, + files, + tokenizer, + primary_eos_token_id, + extra_eos_token_ids, + generation_config, + model_config, + }) + } + + /// Expose the resolved model files for use by the chat backend to load the + /// chat template. + pub fn resolved_model_files(&self) -> &ResolvedModelFiles { + &self.files + } +} + +impl TextBackend for HfTextBackend { + fn tokenizer(&self) -> DynTokenizer { + self.tokenizer.clone() + } + + fn is_moe(&self) -> bool { + self.model_config.is_moe() + } + + fn model_id(&self) -> &str { + &self.model_id + } + + fn sampling_hints(&self) -> Result { + Ok(SamplingHints { + primary_eos_token_id: self.primary_eos_token_id, + extra_eos_token_ids: self.extra_eos_token_ids.clone(), + default_temperature: self.generation_config.temperature, + default_top_p: self.generation_config.top_p, + default_top_k: self.generation_config.top_k, + default_min_p: self.generation_config.min_p, + default_repetition_penalty: self.generation_config.repetition_penalty, + default_max_tokens: self.generation_config.max_new_tokens, + max_model_len: self.model_config.max_position_embeddings(), + }) + } +} diff --git a/rust/src/text/src/backend/hf/model_files.rs b/rust/src/text/src/backend/hf/model_files.rs new file mode 100644 index 00000000000..f4a66d30dae --- /dev/null +++ b/rust/src/text/src/backend/hf/model_files.rs @@ -0,0 +1,459 @@ +use std::path::{Path, PathBuf}; + +use hf_hub::Cache; +use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo}; +use thiserror_ext::AsReport as _; + +use super::config::{HfTokenizerConfig, load_tokenizer_config}; +use crate::error::{Error, Result}; + +const HF_TOKEN_ENV: &str = "HF_TOKEN"; + +/// The tokenizer source selected for a model. +#[derive(Debug, Clone)] +pub enum TokenizerSource { + /// Path to `tokenizer.json` in HuggingFace format. + HuggingFace(PathBuf), + /// Path to `tiktoken.model` or `*.tiktoken` file for tiktoken-based models. + Tiktoken(PathBuf), + /// Path to `tekken.json` when present (Mistral native tokenizer format). + /// + /// When set, the Tekken tokenizer should be preferred over the Hugging Face + /// tokenizer because the HuggingFace `tokenizer.json` for Mistral + /// models has a known regex bug that produces incorrect token IDs for + /// some inputs. + Tekken(PathBuf), +} + +impl TokenizerSource { + pub fn path(&self) -> &Path { + match self { + Self::HuggingFace(path) | Self::Tiktoken(path) | Self::Tekken(path) => path, + } + } +} + +/// Concrete tokenizer/config file locations resolved for one HF model id. +#[derive(Debug, Clone)] +pub struct ResolvedModelFiles { + /// The selected tokenizer source for this model. + pub tokenizer: TokenizerSource, + pub tokenizer_config_path: Option, + pub generation_config_path: Option, + pub preprocessor_config_path: Option, + pub chat_template_path: Option, + pub config_path: Option, +} + +impl ResolvedModelFiles { + /// Resolve tokenizer/config files from a local model directory first when + /// `model_id` points to one, otherwise consult the local HF cache and + /// finally the Hub. + pub async fn new(model_id: &str) -> Result { + if Path::new(model_id).is_dir() { + return resolve_local_model_files(Path::new(model_id)); + } + if let Some(files) = resolve_cached_model_files(model_id)? { + return Ok(files); + } + resolve_remote_model_files(model_id).await + } +} + +fn resolve_local_model_files(model_dir: &Path) -> Result { + let tokenizer_config_path = local_file_if_exists(model_dir, "tokenizer_config.json"); + let tokenizer_config = load_tokenizer_config(tokenizer_config_path.as_deref())?; + let tokenizer = resolve_local_tokenizer_source(model_dir, &tokenizer_config)?; + + Ok(ResolvedModelFiles { + tokenizer, + tokenizer_config_path, + generation_config_path: local_file_if_exists(model_dir, "generation_config.json"), + preprocessor_config_path: local_file_if_exists(model_dir, "preprocessor_config.json"), + chat_template_path: discover_chat_template_in_dir(model_dir), + config_path: local_file_if_exists(model_dir, "config.json"), + }) +} + +async fn resolve_remote_model_files(model_id: &str) -> Result { + let api = build_api().map_err(|error| Error::Tokenizer(error.to_report_string()))?; + let repo = api.model(model_id.to_string()); + let info = repo.info().await.map_err(|error| { + Error::Tokenizer(format!( + "failed to fetch model '{model_id}': {}", + error.as_report() + )) + })?; + + let siblings = info + .siblings + .iter() + .map(|sibling| sibling.rfilename.as_str()) + .collect::>(); + + let tokenizer_config_path = + download_if_present(&repo, model_id, &siblings, "tokenizer_config.json").await?; + let tokenizer_config = load_tokenizer_config(tokenizer_config_path.as_deref())?; + + let tokenizer = resolve_remote_tokenizer_source( + &repo, + model_id, + &siblings, + tokenizer_config.tokenizer_class.as_deref(), + ) + .await?; + + let generation_config_path = + download_if_present(&repo, model_id, &siblings, "generation_config.json").await?; + let preprocessor_config_path = + download_if_present(&repo, model_id, &siblings, "preprocessor_config.json").await?; + let chat_template_name = siblings + .contains("chat_template.json") + .then_some("chat_template.json") + .or_else(|| siblings.contains("chat_template.jinja").then_some("chat_template.jinja")) + .or_else(|| siblings.iter().copied().find(|name| name.ends_with(".jinja"))); + let chat_template_path = match chat_template_name { + Some(name) => Some(download_known_file(&repo, model_id, name).await?), + None => None, + }; + let config_path = download_if_present(&repo, model_id, &siblings, "config.json").await?; + + Ok(ResolvedModelFiles { + tokenizer, + tokenizer_config_path, + generation_config_path, + preprocessor_config_path, + chat_template_path, + config_path, + }) +} + +fn resolve_cached_model_files(model_id: &str) -> Result> { + let cache_repo = Cache::from_env().model(model_id.to_string()); + + let tokenizer_config_path = cache_repo.get("tokenizer_config.json"); + let tokenizer_config = load_tokenizer_config(tokenizer_config_path.as_deref())?; + let tokenizer = match resolve_cached_tokenizer_source(&cache_repo, &tokenizer_config)? { + Some(tokenizer) => tokenizer, + None => return Ok(None), + }; + + let model_dir = tokenizer.path().parent().ok_or_else(|| { + Error::Tokenizer("resolved tokenizer file has no parent directory".to_string()) + })?; + let generation_config_path = cache_repo.get("generation_config.json"); + let preprocessor_config_path = cache_repo.get("preprocessor_config.json"); + let chat_template_path = discover_chat_template_in_dir(model_dir); + let config_path = cache_repo.get("config.json"); + + Ok(Some(ResolvedModelFiles { + tokenizer, + tokenizer_config_path, + generation_config_path, + preprocessor_config_path, + chat_template_path, + config_path, + })) +} + +async fn resolve_remote_tokenizer_source( + repo: &ApiRepo, + model_id: &str, + siblings: &std::collections::BTreeSet<&str>, + tokenizer_class: Option<&str>, +) -> Result { + if let Some(tekken_path) = download_if_present(repo, model_id, siblings, "tekken.json").await? { + return Ok(TokenizerSource::Tekken(tekken_path)); + } + + let tokenizer_path = if siblings.contains("tokenizer.json") { + download_known_file(repo, model_id, "tokenizer.json").await? + } else if let Some(tiktoken_name) = find_tiktoken_sibling(siblings) { + download_known_file(repo, model_id, tiktoken_name).await? + } else { + return Err(Error::Tokenizer(format!( + "model '{model_id}' does not expose a supported tokenizer file \ + (tokenizer.json, tiktoken.model, or *.tiktoken) on Hugging Face" + ))); + }; + + Ok(resolve_tokenizer_source( + tokenizer_path, + tokenizer_class, + None, + )) +} + +fn resolve_cached_tokenizer_source( + cache_repo: &hf_hub::CacheRepo, + tokenizer_config: &HfTokenizerConfig, +) -> Result> { + let tekken_path = cache_repo.get("tekken.json"); + + if let Some(tekken_path) = tekken_path { + return Ok(Some(TokenizerSource::Tekken(tekken_path))); + } + + let Some(tokenizer_path) = cache_repo.get("tokenizer.json").or_else(|| { + // tiktoken.model is the most common name, try it first. + cache_repo.get("tiktoken.model").or_else(|| { + // Scan for any *.tiktoken file in the cache snapshot directory. + let snapshot_dir = cache_repo.get("config.json")?.parent()?.to_path_buf(); + discover_tiktoken_in_dir(&snapshot_dir) + }) + }) else { + return Ok(None); + }; + + Ok(Some(resolve_tokenizer_source( + tokenizer_path, + tokenizer_config.tokenizer_class.as_deref(), + None, + ))) +} + +fn resolve_local_tokenizer_source( + model_dir: &Path, + tokenizer_config: &HfTokenizerConfig, +) -> Result { + let tekken_path = local_file_if_exists(model_dir, "tekken.json"); + if let Some(tekken_path) = tekken_path { + return Ok(TokenizerSource::Tekken(tekken_path)); + } + + let tokenizer_path = local_file_if_exists(model_dir, "tokenizer.json") + .or_else(|| local_file_if_exists(model_dir, "tiktoken.model")) + .or_else(|| discover_tiktoken_in_dir(model_dir)) + .ok_or_else(|| { + Error::Tokenizer(format!( + "local model directory '{}' does not contain a supported tokenizer file \ + (tokenizer.json, tiktoken.model, or *.tiktoken)", + model_dir.display() + )) + })?; + + Ok(resolve_tokenizer_source( + tokenizer_path, + tokenizer_config.tokenizer_class.as_deref(), + None, + )) +} + +/// Choose the tokenizer. +/// +/// Selection order: +/// 1. `tekken.json` — Mistral native tokenizer (preferred over HF `tokenizer.json` because the HF +/// version has a known regex bug for Mistral models). +/// 2. File extension — `.tiktoken` / `tiktoken.model` files use tiktoken from BPE data. +/// 3. `tokenizer_class` in `tokenizer_config.json` — classes containing "Tiktoken" (case- +/// insensitive) trigger tiktoken loading from a sibling BPE file. +/// 4. Default — `tokenizer.json` in HuggingFace format. +fn resolve_tokenizer_source( + tokenizer_path: PathBuf, + tokenizer_class: Option<&str>, + tekken_path: Option, +) -> TokenizerSource { + if let Some(tekken_path) = tekken_path { + return TokenizerSource::Tekken(tekken_path); + } + + if is_tiktoken_file(&tokenizer_path) { + return TokenizerSource::Tiktoken(tokenizer_path); + } + + if tokenizer_class.is_some_and(|cls| cls.to_ascii_lowercase().contains("tiktoken")) + && let Some(dir) = tokenizer_path.parent() + && let Some(tiktoken_path) = discover_tiktoken_in_dir(dir) + { + return TokenizerSource::Tiktoken(tiktoken_path); + } + + TokenizerSource::HuggingFace(tokenizer_path) +} + +/// Download `filename` only if it exists in `siblings`. +async fn download_if_present( + repo: &ApiRepo, + model_id: &str, + siblings: &std::collections::BTreeSet<&str>, + filename: &str, +) -> Result> { + match siblings.contains(filename) { + true => download_known_file(repo, model_id, filename).await.map(Some), + false => Ok(None), + } +} + +async fn download_known_file(repo: &ApiRepo, model_id: &str, filename: &str) -> Result { + repo.get(filename).await.map_err(|error| { + Error::Tokenizer(format!( + "failed to download '{filename}' for model '{model_id}': {}", + error.as_report() + )) + }) +} + +fn build_api() -> anyhow::Result { + let mut builder = ApiBuilder::from_env().with_progress(true); + if let Ok(token) = std::env::var(HF_TOKEN_ENV) + && !token.is_empty() + { + builder = builder.with_token(Some(token)); + } + Ok(builder.build()?) +} + +fn local_file_if_exists(dir: &Path, filename: &str) -> Option { + let path = dir.join(filename); + path.is_file().then_some(path) +} + +/// Find a tiktoken file name among repo siblings, preferring `tiktoken.model`. +fn find_tiktoken_sibling<'a>(siblings: &std::collections::BTreeSet<&'a str>) -> Option<&'a str> { + if siblings.contains("tiktoken.model") { + return Some("tiktoken.model"); + } + siblings.iter().copied().find(|name| name.ends_with(".tiktoken")) +} + +/// Discover a tiktoken model file in a local directory. +pub(super) fn discover_tiktoken_in_dir(dir: &std::path::Path) -> Option { + let tiktoken_model = dir.join("tiktoken.model"); + if tiktoken_model.exists() { + return Some(tiktoken_model); + } + std::fs::read_dir(dir).ok()?.flatten().find_map(|entry| { + let path = entry.path(); + if path + .file_name() + .and_then(|n| n.to_str()) + .is_some_and(|n| n.ends_with(".tiktoken")) + { + Some(path) + } else { + None + } + }) +} + +/// Returns `true` if `path` points to a tiktoken-format file (by name). +pub(super) fn is_tiktoken_file(path: &std::path::Path) -> bool { + path.file_name() + .and_then(|n| n.to_str()) + .is_some_and(|name| name == "tiktoken.model" || name.ends_with(".tiktoken")) +} + +/// Chat templates are sometimes stored as dedicated .jinja files rather than as +/// a fixed-name config entry, so we scan the cached model dir. +fn discover_chat_template_in_dir(dir: &std::path::Path) -> Option { + let json_template_path = dir.join("chat_template.json"); + if json_template_path.exists() { + return Some(json_template_path); + } + + let jinja_path = dir.join("chat_template.jinja"); + if jinja_path.exists() { + return Some(jinja_path); + } + + std::fs::read_dir(dir).ok()?.flatten().map(|entry| entry.path()).find(|path| { + path.file_name() + .and_then(|name| name.to_str()) + .is_some_and(|name| name.ends_with(".jinja")) + }) +} + +#[cfg(test)] +mod tests { + use std::fs; + + use tempfile::tempdir; + use vllm_tokenizer::{TiktokenTokenizer, Tokenizer}; + + use super::{ResolvedModelFiles, TokenizerSource}; + + #[tokio::test] + async fn resolved_model_files_prefers_absolute_local_model_dir() { + let dir = tempdir().expect("create temp dir"); + fs::write(dir.path().join("tokenizer.json"), "{}").expect("write tokenizer"); + fs::write( + dir.path().join("tokenizer_config.json"), + r#"{"tokenizer_class":"PreTrainedTokenizerFast"}"#, + ) + .expect("write tokenizer config"); + fs::write(dir.path().join("config.json"), "{}").expect("write config"); + + let files = ResolvedModelFiles::new(dir.path().to_str().expect("utf8 path")) + .await + .expect("resolve local model files"); + + match files.tokenizer { + TokenizerSource::HuggingFace(path) => { + assert_eq!(path, dir.path().join("tokenizer.json")); + } + other => panic!("expected HuggingFace tokenizer, got {other:?}"), + } + assert_eq!(files.config_path, Some(dir.path().join("config.json"))); + assert_eq!( + files.tokenizer_config_path, + Some(dir.path().join("tokenizer_config.json")) + ); + } + + #[tokio::test] + #[ignore = "requires network access to Hugging Face and downloads the real Kimi K2.5 tokenizer"] + async fn tiktoken_real_kimi_k25_tokenizer_files_load_and_handle_special_tokens() { + let files = ResolvedModelFiles::new("moonshotai/Kimi-K2.5") + .await + .expect("resolve real Kimi K2.5 model files"); + + let tokenizer_path = match &files.tokenizer { + TokenizerSource::Tiktoken(path) => path.clone(), + other => panic!("expected tiktoken tokenizer source, got {other:?}"), + }; + + for backend in [ + TiktokenTokenizer::new_riptoken(&tokenizer_path).expect("load riptoken backend"), + TiktokenTokenizer::new_tiktoken_rs(&tokenizer_path).expect("load tiktoken-rs backend"), + ] { + let think_id = backend.token_to_id("").expect("resolve "); + let end_think_id = backend.token_to_id("").expect("resolve "); + let tool_section_id = backend + .token_to_id("<|tool_calls_section_begin|>") + .expect("resolve tool call section marker"); + let contraction_heavy_text = + "I'm sure it's fine, but I can't say I'd trust that it's what we'd ship."; + let contraction_heavy_ids = backend.encode(contraction_heavy_text, false).unwrap(); + + assert_eq!( + (think_id, end_think_id, tool_section_id), + (163606, 163607, 163595) + ); + assert_eq!(backend.decode(&[think_id], true).unwrap(), ""); + assert_eq!(backend.decode(&[end_think_id], true).unwrap(), ""); + assert_eq!( + backend.decode(&[tool_section_id], true).unwrap(), + "<|tool_calls_section_begin|>" + ); + + // This demonstrates that we're using Kimi's custom BPE pattern. + // With CL100K this will be 23 tokens instead. + assert_eq!( + contraction_heavy_ids, + vec![ + 17172, 3287, 4643, 8201, 11, 996, 374, 8971, 3637, 20020, 8173, 473, 4643, + 1573, 56229, 13922, 13, + ] + ); + assert_eq!(contraction_heavy_ids.len(), 17); + assert_eq!( + backend.decode(&contraction_heavy_ids, false).unwrap(), + contraction_heavy_text + ); + + // Special-looking text that is not actually registered should fail gracefully. + assert_eq!(backend.token_to_id("◁think▷"), None); + assert_eq!(backend.token_to_id("<|definitely_not_registered|>"), None); + } + } +} diff --git a/rust/src/text/src/backend/mod.rs b/rust/src/text/src/backend/mod.rs new file mode 100644 index 00000000000..4f2d7093a75 --- /dev/null +++ b/rust/src/text/src/backend/mod.rs @@ -0,0 +1,47 @@ +pub mod hf; + +use std::sync::Arc; + +use vllm_tokenizer::DynTokenizer; + +use crate::error::Result; + +/// Tokenizer/model-derived hints used to enrich text-generation requests before +/// they are lowered into engine-core. +#[derive(Debug, Clone, Default, PartialEq)] +pub struct SamplingHints { + pub primary_eos_token_id: Option, + pub extra_eos_token_ids: std::collections::BTreeSet, + pub default_temperature: Option, + pub default_top_p: Option, + pub default_top_k: Option, + pub default_min_p: Option, + pub default_repetition_penalty: Option, + pub default_max_tokens: Option, + /// Model context window size (`max_position_embeddings` from + /// `config.json`). + pub max_model_len: Option, +} + +/// Minimal text-processing backend needed by `vllm-text`. +pub trait TextBackend: Send + Sync { + /// Return the tokenizer used by this backend. + fn tokenizer(&self) -> DynTokenizer; + + /// Return whether the loaded model is a mixture-of-experts model. + fn is_moe(&self) -> bool { + false + } + + /// Return the backend model ID. + fn model_id(&self) -> &str; + + /// Return tokenizer/model-derived hints used to enrich southbound sampling + /// parameters. + fn sampling_hints(&self) -> Result { + Ok(SamplingHints::default()) + } +} + +/// Shared trait-object form of [`TextBackend`]. +pub type DynTextBackend = Arc; diff --git a/rust/src/text/src/error.rs b/rust/src/text/src/error.rs new file mode 100644 index 00000000000..62e8e2ae98a --- /dev/null +++ b/rust/src/text/src/error.rs @@ -0,0 +1,30 @@ +use thiserror::Error; +use vllm_engine_core_client::Error as EngineCoreError; +use vllm_llm::Error as LlmError; + +#[derive(Debug, Error)] +pub enum Error { + #[error("tokenizer error: {0}")] + Tokenizer(String), + #[error("text request `{request_id}` must contain at least one prompt token ID")] + EmptyPromptTokenIds { request_id: String }, + #[error( + "this model's maximum context length is {max_model_len} tokens, \ + but the prompt contains {prompt_len} input tokens" + )] + PromptTooLong { max_model_len: u32, prompt_len: u32 }, + #[error("text request stream `{request_id}` closed before terminal output")] + StreamClosedBeforeTerminalOutput { request_id: String }, + #[error(transparent)] + Llm(#[from] LlmError), + #[error(transparent)] + EngineCore(#[from] EngineCoreError), +} + +pub type Result = std::result::Result; + +impl From for Error { + fn from(error: vllm_tokenizer::TokenizerError) -> Self { + Self::Tokenizer(error.0) + } +} diff --git a/rust/src/text/src/lib.rs b/rust/src/text/src/lib.rs new file mode 100644 index 00000000000..ef5615ec6d4 --- /dev/null +++ b/rust/src/text/src/lib.rs @@ -0,0 +1,149 @@ +//! Shared text-generation support used by chat and future raw completions. +//! +//! This crate intentionally stays below chat semantics: +//! prompt text handling, tokenizer/model loading, incremental detokenization, +//! and the thin generate-facing backend interface live here. + +use std::mem::take; + +pub use backend::{DynTextBackend, SamplingHints, TextBackend}; +pub use error::{Error, Result}; +use futures::Stream; +pub use lower::{ + PreparedTextRequest, lower_sampling_params, lower_text_request, resolve_max_tokens, +}; +pub use output::{ + CollectedTextOutput, DecodedLogprobs, DecodedPositionLogprobs, DecodedPromptLogprobs, + DecodedTextEvent, DecodedTokenLogprob, Finished, TextDecodeOptions, TextOutputStreamExt, +}; +pub use request::{Prompt, SamplingParams, TextRequest}; +use trait_set::trait_set; +use vllm_engine_core_client::EngineCoreClient; +pub use vllm_llm::FinishReason; +use vllm_llm::{GenerateOutputStream, Llm}; +use vllm_tokenizer::DynTokenizer; + +pub mod backend; +mod error; +mod lower; +pub mod output; +mod request; +pub use vllm_tokenizer as tokenizer; + +trait_set! { + /// Shared streamed text output type used by raw completions and other text-only northbound paths. + pub trait TextOutputStream = Stream> + Send + 'static; +} + +/// Raw text facade above [`Llm`]. +/// +/// This layer stays below chat semantics: prompt text or prompt token IDs flow +/// in, decoded text deltas and terminal metadata flow out. +pub struct TextLlm { + /// Generate-only client owned by this text facade. + llm: Llm, + /// Tokenizer/model metadata backend responsible for prompt encode/decode + /// and sampling hints. + backend: DynTextBackend, + /// Context window size derived by the backend or from engine startup + /// handshake, with optional override from config. + max_model_len: Option, +} + +impl TextLlm { + /// Create a new text-generation facade from a shared LLM client plus a text + /// backend. + pub fn new(llm: Llm, backend: DynTextBackend) -> Self { + // Prefer the engine-reported max_model_len because it reflects the + // post-profiling, auto-fitted KV cache limit rather than static + // frontend metadata. + let max_model_len = llm.engine_core_client().max_model_len(); + + Self { + llm, + backend, + max_model_len, + } + } + + /// Override the maximum model context length explicitly. + /// + /// This takes priority over both the engine-reported default and any + /// tokenizer/model metadata exposed by the backend. + pub fn with_max_model_len(mut self, max_model_len: u32) -> Self { + self.max_model_len = Some(max_model_len); + self + } + + /// Return the backend model ID. + pub fn model_id(&self) -> &str { + self.backend.model_id() + } + + /// Expose the underlying engine-core client for low-level utility/admin + /// calls. + pub fn engine_core_client(&self) -> &EngineCoreClient { + self.llm.engine_core_client() + } + + /// Return the tokenizer used by this text backend. + pub fn tokenizer(&self) -> DynTokenizer { + self.backend.tokenizer() + } + + /// Tokenize if needed, lower to a generate request, and return the raw + /// token stream. + pub async fn generate_raw(&self, request: TextRequest) -> Result { + let (_, raw_stream) = self.generate_inner(request).await?; + Ok(raw_stream) + } + + /// Tokenize if needed, lower to a generate request, and stream + /// incrementally decoded text. + pub async fn generate(&self, request: TextRequest) -> Result { + let (text_request, raw_stream) = self.generate_inner(request).await?; + let tokenizer = self.backend.tokenizer(); + let decoded_stream = output::decoded_text_event_stream( + text_request.request_id, + tokenizer, + raw_stream, + text_request.decode_options, + text_request.intermediate, + ); + + Ok(decoded_stream) + } + + async fn generate_inner( + &self, + mut request: TextRequest, + ) -> Result<(TextRequest, GenerateOutputStream)> { + request.validate()?; + + let tokenizer = self.backend.tokenizer(); + let prompt_token_ids = match take(&mut request.prompt) { + Prompt::Text(text) => tokenizer.encode(&text, request.add_special_tokens)?, + // Pre-tokenized prompts are the main completions-side escape hatch that lets benchmark + // and infra workloads bypass chat rendering and tokenizer overhead entirely. + Prompt::TokenIds(token_ids) => token_ids, + }; + + let mut sampling_hints = self.backend.sampling_hints()?; + if let Some(max_model_len) = self.max_model_len { + sampling_hints.max_model_len = Some(max_model_len); + } + let PreparedTextRequest { + text_request, + generate_request, + } = lower_text_request(request, prompt_token_ids, sampling_hints, &*tokenizer)?; + + let raw_stream = self.llm.generate(generate_request).await?; + Ok((text_request, raw_stream)) + } + + /// Shut down the underlying LLM client and its background tasks. + pub async fn shutdown(self) -> Result<()> { + self.llm.shutdown().await?; + Ok(()) + } +} diff --git a/rust/src/text/src/lower.rs b/rust/src/text/src/lower.rs new file mode 100644 index 00000000000..7cbbd53dc6a --- /dev/null +++ b/rust/src/text/src/lower.rs @@ -0,0 +1,739 @@ +use std::collections::BTreeSet; + +use vllm_engine_core_client::protocol::EngineCoreSamplingParams; +use vllm_llm::GenerateRequest; +use vllm_tokenizer::Tokenizer; + +use crate::backend::SamplingHints; +use crate::error::{Error, Result}; +use crate::request::{SamplingParams, TextRequest}; + +/// One text request after it has been lowered into the raw generate boundary. +#[derive(Debug)] +pub struct PreparedTextRequest { + /// The original high-level request, preserved for response-side metadata + /// and decoding options. + pub text_request: TextRequest, + /// The southbound request ready to be sent to `vllm-llm`. + pub generate_request: GenerateRequest, +} + +/// Convert a high-level [`TextRequest`] into one lower-level +/// [`GenerateRequest`] ready for the `llm` crate. +pub fn lower_text_request( + request: TextRequest, + prompt_token_ids: Vec, + sampling_hints: SamplingHints, + tokenizer: &dyn Tokenizer, +) -> Result { + let prompt_len = prompt_token_ids.len() as u32; + let generate_request = GenerateRequest { + request_id: request.request_id.clone(), + prompt_token_ids, + mm_features: request.mm_features.clone(), + sampling_params: lower_sampling_params( + request.sampling_params.clone(), + sampling_hints, + prompt_len, + tokenizer, + )?, + cache_salt: request.cache_salt.clone(), + priority: request.priority, + data_parallel_rank: request.data_parallel_rank, + // Fields below are currently placeholders. + arrival_time: None, + trace_headers: None, + reasoning_ended: None, + lora_request: None, + }; + + Ok(PreparedTextRequest { + text_request: request, + generate_request, + }) +} + +/// Convert [`SamplingParams`] into [`EngineCoreSamplingParams`], enriching +/// omitted user values with tokenizer/model-derived hints when available. +pub fn lower_sampling_params( + sampling_params: SamplingParams, + SamplingHints { + primary_eos_token_id, + extra_eos_token_ids, + default_temperature, + default_top_p, + default_top_k, + default_min_p, + default_repetition_penalty, + default_max_tokens, + max_model_len, + }: SamplingHints, + prompt_len: u32, + tokenizer: &dyn Tokenizer, +) -> Result { + let SamplingParams { + temperature, + top_p, + top_k, + seed, + max_tokens, + min_tokens, + logprobs, + prompt_logprobs, + min_p, + frequency_penalty, + presence_penalty, + repetition_penalty, + stop_token_ids, + ignore_eos, + logit_bias, + allowed_token_ids, + bad_words, + logprob_token_ids, + structured_outputs, + skip_reading_prefix_cache, + vllm_xargs, + } = sampling_params; + + // Mirrors the model-generation-config inheritance used by vLLM's OpenAI chat + // path: https://github.com/vllm-project/vllm/blob/bc2c0c86efb28e77677a3cfb8687e976914a313a/vllm/entrypoints/openai/chat_completion/protocol.py#L424-L450 + // If neither the caller nor the model provides a value, fall back to 1.0 — the + // default used by the Python vLLM OpenAI-compatible API (via + // `_DEFAULT_SAMPLING_PARAMS`). + let temperature = temperature.or(default_temperature).unwrap_or(1.0); + let top_p = top_p.or(default_top_p).unwrap_or(1.0); + let top_k = top_k.or(default_top_k).unwrap_or(0); + let min_p = min_p.or(default_min_p).unwrap_or(0.0); + let repetition_penalty = repetition_penalty.or(default_repetition_penalty).unwrap_or(1.0); + let max_tokens = resolve_max_tokens(max_tokens, default_max_tokens, max_model_len, prompt_len)?; + let min_tokens = min_tokens.unwrap_or(0); + let frequency_penalty = frequency_penalty.unwrap_or(0.0); + let presence_penalty = presence_penalty.unwrap_or(0.0); + + let mut stop_token_ids = stop_token_ids.unwrap_or_default(); + let mut all_stop_token_ids = BTreeSet::from_iter(stop_token_ids.iter().copied()); + if let Some(primary_eos_token_id) = primary_eos_token_id { + all_stop_token_ids.insert(primary_eos_token_id); + } + all_stop_token_ids.extend(extra_eos_token_ids.iter().copied()); + + if !ignore_eos { + merge_unique_token_ids(&mut stop_token_ids, extra_eos_token_ids.iter().copied()); + } + + Ok(EngineCoreSamplingParams { + temperature, + top_p, + top_k, + seed, + max_tokens, + min_tokens, + logprobs, + prompt_logprobs, + min_p, + frequency_penalty, + presence_penalty, + repetition_penalty, + stop_token_ids, + eos_token_id: (!ignore_eos).then_some(primary_eos_token_id).flatten(), + all_stop_token_ids, + logit_bias, + allowed_token_ids, + bad_words_token_ids: tokenize_bad_words(bad_words.as_deref(), tokenizer)?, + structured_outputs, + logprob_token_ids, + skip_reading_prefix_cache, + extra_args: vllm_xargs, + }) +} + +/// Convert bad-word strings into token-ID sequences, following the Python vLLM +/// logic in `SamplingParams.update_from_tokenizer()`. +/// +/// Each word is encoded both with and without a leading space so that the ban +/// applies regardless of whether the word appears at the beginning or in the +/// middle of generated text (this accounts for tokenizers that use an +/// `add_prefix_space` convention). +/// +/// Reference: +fn tokenize_bad_words( + bad_words: Option<&[String]>, + tokenizer: &dyn Tokenizer, +) -> Result>>> { + let bad_words = bad_words.filter(|w| !w.is_empty()); + let mut all_token_ids = Vec::new(); + + for bad_word in bad_words.into_iter().flatten() { + // Without a leading space we always keep the encoding. + // With a leading space we only keep it when the prefix-space variant produces a + // distinct first token but the same sequence length — this mirrors the Python + // dedup condition that avoids redundant entries. + let without_space = tokenizer.encode(bad_word, false)?; + let with_space = tokenizer.encode(&format!(" {}", bad_word.trim_start()), false)?; + + if !without_space.is_empty() { + all_token_ids.push(without_space); + } + if !with_space.is_empty() + && all_token_ids.last().is_some_and(|prev: &Vec| { + with_space[0] != prev[0] && with_space.len() == prev.len() + }) + { + all_token_ids.push(with_space); + } + } + + Ok((!all_token_ids.is_empty()).then_some(all_token_ids)) +} + +/// Resolve the effective `max_tokens` for generation, mirroring vLLM Python's +/// `get_max_tokens()` in `vllm/entrypoints/utils.py`. +/// +/// Takes the minimum of all available limits (user-specified, generation-config +/// default, and `max_model_len - prompt_len`). When nothing is known, falls +/// back to `u32::MAX` so the engine-core can apply its own context-window +/// limit. +pub fn resolve_max_tokens( + user_max_tokens: Option, + default_max_tokens: Option, + max_model_len: Option, + prompt_len: u32, +) -> Result { + let model_max_tokens = match max_model_len { + Some(max_model_len) if prompt_len >= max_model_len => { + return Err(Error::PromptTooLong { + max_model_len, + prompt_len, + }); + } + Some(max_model_len) => Some(max_model_len - prompt_len), + None => None, + }; + + let fallback_max_tokens = user_max_tokens.or(default_max_tokens); + Ok([fallback_max_tokens, model_max_tokens] + .into_iter() + .flatten() + .min() + .unwrap_or(u32::MAX /* TODO: a reasonable fallback? */)) +} + +fn merge_unique_token_ids( + stop_token_ids: &mut Vec, + extra_token_ids: impl Iterator, +) { + // Keep user-provided ordering stable while still folding in backend-derived EOS + // aliases. + for token_id in extra_token_ids { + if !stop_token_ids.contains(&token_id) { + stop_token_ids.push(token_id); + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + + use super::*; + use crate::backend::hf::HfTextBackend; + use crate::backend::{SamplingHints, TextBackend as _}; + use crate::request::{Prompt, TextRequest}; + + /// Stub tokenizer that returns empty token IDs — sufficient for tests that + /// don't exercise bad-words tokenization. + struct StubTokenizer; + + impl Tokenizer for StubTokenizer { + fn encode( + &self, + _text: &str, + _add_special_tokens: bool, + ) -> vllm_tokenizer::Result> { + Ok(vec![]) + } + + fn decode( + &self, + _token_ids: &[u32], + _skip_special_tokens: bool, + ) -> vllm_tokenizer::Result { + Ok(String::new()) + } + + fn token_to_id(&self, _token: &str) -> Option { + None + } + } + + fn stub_tokenizer() -> StubTokenizer { + StubTokenizer + } + + fn sample_request() -> TextRequest { + TextRequest { + prompt: Prompt::TokenIds(vec![1, 2, 3]), + request_id: "text-1".to_string(), + ..TextRequest::for_test() + } + } + + fn sample_sampling_hints() -> SamplingHints { + SamplingHints { + primary_eos_token_id: Some(99), + extra_eos_token_ids: BTreeSet::from([77]), + default_temperature: None, + default_top_p: None, + default_top_k: None, + default_min_p: None, + default_repetition_penalty: None, + default_max_tokens: None, + max_model_len: None, + } + } + + #[test] + fn lower_text_request_applies_python_style_eos_hints() { + let prepared = lower_text_request( + sample_request(), + vec![1, 2, 3], + sample_sampling_hints(), + &stub_tokenizer(), + ) + .unwrap(); + + let params = prepared.generate_request.sampling_params; + expect_test::expect![[r#" + EngineCoreSamplingParams { + temperature: 1.0, + top_p: 1.0, + top_k: 0, + seed: None, + max_tokens: 4294967295, + min_tokens: 0, + logprobs: None, + prompt_logprobs: None, + min_p: 0.0, + frequency_penalty: 0.0, + presence_penalty: 0.0, + repetition_penalty: 1.0, + stop_token_ids: [ + 77, + ], + eos_token_id: Some( + 99, + ), + all_stop_token_ids: { + 77, + 99, + }, + logit_bias: None, + allowed_token_ids: None, + bad_words_token_ids: None, + structured_outputs: None, + logprob_token_ids: None, + skip_reading_prefix_cache: None, + extra_args: None, + } + "#]] + .assert_debug_eq(¶ms); + } + + #[test] + fn lower_text_request_respects_ignore_eos_for_stop_token_ids() { + let mut request = sample_request(); + request.sampling_params.ignore_eos = true; + + let prepared = lower_text_request( + request, + vec![1, 2, 3], + sample_sampling_hints(), + &stub_tokenizer(), + ) + .unwrap(); + + let params = prepared.generate_request.sampling_params; + expect_test::expect![[r#" + EngineCoreSamplingParams { + temperature: 1.0, + top_p: 1.0, + top_k: 0, + seed: None, + max_tokens: 4294967295, + min_tokens: 0, + logprobs: None, + prompt_logprobs: None, + min_p: 0.0, + frequency_penalty: 0.0, + presence_penalty: 0.0, + repetition_penalty: 1.0, + stop_token_ids: [], + eos_token_id: None, + all_stop_token_ids: { + 77, + 99, + }, + logit_bias: None, + allowed_token_ids: None, + bad_words_token_ids: None, + structured_outputs: None, + logprob_token_ids: None, + skip_reading_prefix_cache: None, + extra_args: None, + } + "#]] + .assert_debug_eq(¶ms); + } + + #[tokio::test] + #[ignore = "requires network access to Hugging Face"] + async fn lower_text_request_uses_real_qwen_generation_defaults() { + let backend = HfTextBackend::from_model("Qwen/Qwen3-0.6B") + .await + .expect("load qwen tokenizer and generation config"); + let hints = backend.sampling_hints().expect("collect sampling hints"); + + expect_test::expect![[r#" + SamplingHints { + primary_eos_token_id: Some( + 151645, + ), + extra_eos_token_ids: { + 151643, + }, + default_temperature: Some( + 0.6, + ), + default_top_p: Some( + 0.95, + ), + default_top_k: Some( + 20, + ), + default_min_p: Some( + 0.1, + ), + default_repetition_penalty: Some( + 1.2, + ), + default_max_tokens: None, + max_model_len: Some( + 40960, + ), + } + "#]] + .assert_debug_eq(&hints); + + let prepared = + lower_text_request(sample_request(), vec![1, 2, 3], hints, &stub_tokenizer()) + .expect("lower request"); + let params = prepared.generate_request.sampling_params; + + expect_test::expect![[r#" + EngineCoreSamplingParams { + temperature: 0.6, + top_p: 0.95, + top_k: 20, + seed: None, + max_tokens: 40957, + min_tokens: 0, + logprobs: None, + prompt_logprobs: None, + min_p: 0.1, + frequency_penalty: 0.0, + presence_penalty: 0.0, + repetition_penalty: 1.2, + stop_token_ids: [ + 151643, + ], + eos_token_id: Some( + 151645, + ), + all_stop_token_ids: { + 151643, + 151645, + }, + } + "#]] + .assert_debug_eq(¶ms); + } + + #[test] + fn lower_sampling_params_preserves_explicit_stop_token_ids_in_all_stop_set() { + let sampling_params = SamplingParams { + stop_token_ids: Some(vec![11, 77]), + ..SamplingParams::default() + }; + + let params = lower_sampling_params( + sampling_params, + SamplingHints { + primary_eos_token_id: Some(99), + extra_eos_token_ids: BTreeSet::from([77, 88]), + default_temperature: None, + default_top_p: None, + default_top_k: None, + default_min_p: None, + default_repetition_penalty: None, + default_max_tokens: None, + max_model_len: None, + }, + 3, + &stub_tokenizer(), + ) + .unwrap(); + + expect_test::expect![[r#" + EngineCoreSamplingParams { + temperature: 1.0, + top_p: 1.0, + top_k: 0, + seed: None, + max_tokens: 4294967295, + min_tokens: 0, + logprobs: None, + prompt_logprobs: None, + min_p: 0.0, + frequency_penalty: 0.0, + presence_penalty: 0.0, + repetition_penalty: 1.0, + stop_token_ids: [ + 11, + 77, + 88, + ], + eos_token_id: Some( + 99, + ), + all_stop_token_ids: { + 11, + 77, + 88, + 99, + }, + logit_bias: None, + allowed_token_ids: None, + bad_words_token_ids: None, + structured_outputs: None, + logprob_token_ids: None, + skip_reading_prefix_cache: None, + extra_args: None, + } + "#]] + .assert_debug_eq(¶ms); + } + + #[test] + fn lower_sampling_params_prefers_user_values_over_generation_defaults() { + let sampling_params = SamplingParams { + temperature: Some(0.2), + top_p: Some(0.3), + top_k: Some(4), + max_tokens: Some(32), + min_tokens: Some(2), + ..Default::default() + }; + + let params = lower_sampling_params( + sampling_params, + SamplingHints { + primary_eos_token_id: None, + extra_eos_token_ids: BTreeSet::new(), + default_temperature: Some(0.8), + default_top_p: Some(0.9), + default_top_k: Some(12), + default_min_p: Some(0.1), + default_repetition_penalty: Some(1.2), + default_max_tokens: Some(128), + max_model_len: None, + }, + 3, + &stub_tokenizer(), + ) + .unwrap(); + + expect_test::expect![[r#" + EngineCoreSamplingParams { + temperature: 0.2, + top_p: 0.3, + top_k: 4, + seed: None, + max_tokens: 32, + min_tokens: 2, + logprobs: None, + prompt_logprobs: None, + min_p: 0.1, + frequency_penalty: 0.0, + presence_penalty: 0.0, + repetition_penalty: 1.2, + stop_token_ids: [], + eos_token_id: None, + all_stop_token_ids: {}, + logit_bias: None, + allowed_token_ids: None, + bad_words_token_ids: None, + structured_outputs: None, + logprob_token_ids: None, + skip_reading_prefix_cache: None, + extra_args: None, + } + "#]] + .assert_debug_eq(¶ms); + } + + #[test] + fn lower_sampling_params_passes_logprobs_fields_through() { + let sampling_params = SamplingParams { + logprobs: Some(3), + prompt_logprobs: Some(-1), + ..Default::default() + }; + + let params = lower_sampling_params( + sampling_params, + SamplingHints { + primary_eos_token_id: None, + extra_eos_token_ids: BTreeSet::new(), + default_temperature: None, + default_top_p: None, + default_top_k: None, + default_min_p: None, + default_repetition_penalty: None, + default_max_tokens: None, + max_model_len: None, + }, + 3, + &stub_tokenizer(), + ) + .unwrap(); + + assert_eq!(params.logprobs, Some(3)); + assert_eq!(params.prompt_logprobs, Some(-1)); + } + + #[test] + fn lower_sampling_params_uses_generation_defaults_when_user_omits_values() { + let params = lower_sampling_params( + SamplingParams::default(), + SamplingHints { + primary_eos_token_id: None, + extra_eos_token_ids: BTreeSet::new(), + default_temperature: Some(0.8), + default_top_p: Some(0.9), + default_top_k: Some(12), + default_min_p: Some(0.1), + default_repetition_penalty: Some(1.2), + default_max_tokens: Some(128), + max_model_len: None, + }, + 3, + &stub_tokenizer(), + ) + .unwrap(); + + expect_test::expect![[r#" + EngineCoreSamplingParams { + temperature: 0.8, + top_p: 0.9, + top_k: 12, + seed: None, + max_tokens: 128, + min_tokens: 0, + logprobs: None, + prompt_logprobs: None, + min_p: 0.1, + frequency_penalty: 0.0, + presence_penalty: 0.0, + repetition_penalty: 1.2, + stop_token_ids: [], + eos_token_id: None, + all_stop_token_ids: {}, + logit_bias: None, + allowed_token_ids: None, + bad_words_token_ids: None, + structured_outputs: None, + logprob_token_ids: None, + skip_reading_prefix_cache: None, + extra_args: None, + } + "#]] + .assert_debug_eq(¶ms); + } + + #[test] + fn resolve_max_tokens_caps_by_model_len() { + let result = resolve_max_tokens(Some(150), None, Some(200), 100); + assert_eq!(result.unwrap(), 100); + } + + #[test] + fn lower_text_request_preserves_non_streaming_request_metadata() { + let mut request = sample_request(); + request.intermediate = false; + + let prepared = lower_text_request( + request, + vec![1, 2, 3], + sample_sampling_hints(), + &stub_tokenizer(), + ) + .unwrap(); + + assert!(!prepared.text_request.intermediate); + assert_eq!(prepared.generate_request.request_id, "text-1"); + } + + #[test] + fn resolve_max_tokens_user_smaller_than_model_limit() { + let result = resolve_max_tokens(Some(50), None, Some(200), 100); + assert_eq!(result.unwrap(), 50); + } + + #[test] + fn resolve_max_tokens_uses_default_when_user_omits() { + let result = resolve_max_tokens(None, Some(64), Some(200), 100); + assert_eq!(result.unwrap(), 64); + } + + #[test] + fn resolve_max_tokens_default_capped_by_model_len() { + let result = resolve_max_tokens(None, Some(256), Some(200), 100); + assert_eq!(result.unwrap(), 100); + } + + #[test] + fn resolve_max_tokens_no_model_len_falls_back() { + let result = resolve_max_tokens(Some(9999), None, None, 100); + assert_eq!(result.unwrap(), 9999); + } + + #[test] + fn resolve_max_tokens_no_limits_known_falls_back_to_u32_max() { + let result = resolve_max_tokens(None, None, None, 100); + assert_eq!(result.unwrap(), u32::MAX); + } + + #[test] + fn resolve_max_tokens_prompt_too_long() { + let result = resolve_max_tokens(Some(10), None, Some(100), 100); + assert!(matches!( + result, + Err(Error::PromptTooLong { + max_model_len: 100, + prompt_len: 100, + }) + )); + } + + #[test] + fn resolve_max_tokens_prompt_exceeds_model_len() { + let result = resolve_max_tokens(Some(10), None, Some(100), 200); + assert!(matches!( + result, + Err(Error::PromptTooLong { + max_model_len: 100, + prompt_len: 200, + }) + )); + } +} diff --git a/rust/src/text/src/output/decoded.rs b/rust/src/text/src/output/decoded.rs new file mode 100644 index 00000000000..2ebc6f38532 --- /dev/null +++ b/rust/src/text/src/output/decoded.rs @@ -0,0 +1,607 @@ +use std::sync::Arc; + +use asynk_strim_attr::{TryYielder, try_stream}; +use futures::{Stream, StreamExt}; +use serde::{Deserialize, Serialize}; +use tracing::{Level, debug, trace}; +use vllm_engine_core_client::AbortCause; +use vllm_engine_core_client::protocol::StopReason; +use vllm_llm::{FinishReason, GenerateOutput}; +use vllm_tokenizer::{DynTokenizer, IncrementalDecoder}; + +use super::logprobs::{ + DecodedLogprobs, DecodedPromptLogprobs, decode_logprobs, decode_prompt_logprobs, +}; +use crate::error::Error; + +/// Request-neutral options for incremental text decoding. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct TextDecodeOptions { + pub skip_special_tokens: bool, + pub include_stop_str_in_output: bool, + pub stop_strings: Option>, + /// Minimum number of tokens to generate before stop-string checking kicks + /// in. Stop strings found within the first `min_tokens` tokens are + /// ignored. + pub min_tokens: u32, +} + +impl Default for TextDecodeOptions { + fn default() -> Self { + Self { + skip_special_tokens: true, + include_stop_str_in_output: false, + stop_strings: None, + min_tokens: 0, + } + } +} + +/// Terminal metadata carried on the final [`DecodedTextEvent`]. +#[derive(Debug, Clone, PartialEq)] +pub struct Finished { + pub prompt_token_count: usize, + pub output_token_count: usize, + pub finish_reason: FinishReason, + /// Connector-specific KV transfer parameters for disaggregated serving. + pub kv_transfer_params: Option, +} + +/// Internal decoded-text event emitted before higher-level assistant +/// adaptation. +#[derive(Debug, Clone, PartialEq)] +pub enum DecodedTextEvent { + /// The request has reached the point where prompt-scoped decoding metadata + /// is ready. + Start { + /// The actual prompt token IDs for this request. + prompt_token_ids: Arc<[u32]>, + /// Once-only prompt logprobs metadata, when requested. + /// + /// The first prompt token is carried separately because it has no left + /// context to score against; `scored_positions` covers the + /// remaining prompt positions. + prompt_logprobs: Option, + }, + /// A delta of text has been decoded, optionally alongside token-position + /// logprobs. + /// + /// `delta` is the newly visible decoded text fragment for this update. + /// + /// `logprobs` covers the newly generated token positions from the same + /// update, but is not guaranteed to align with `delta` by character + /// span. One update may carry token logprobs but no newly visible text + /// yet, and one visible text fragment may reflect multiple token + /// positions becoming decodable together. + /// + /// Upper-level may further parse `delta` as reasoning or tool calls. + /// + /// When `finished` is `Some`, this is the terminal event for the request. + TextDelta { + delta: String, + token_ids: Vec, + logprobs: Option, + finished: Option, + }, +} + +/// Convert the output token stream from the `vllm_llm` layer into incrementally +/// decoded text. +#[try_stream] +pub async fn decoded_text_event_stream( + request_id: String, + tokenizer: DynTokenizer, + mut raw_stream: impl Stream> + Unpin, + mut decode_options: TextDecodeOptions, + intermediate: bool, + mut y: TryYielder, +) -> crate::Result<()> { + let mut decoder: Option> = None; + let mut prompt_token_count = 0_usize; + let mut token_ids = Vec::new(); + let mut output_token_count: usize = 0; + let mut logprobs: Option = None; + + while let Some(next) = raw_stream.next().await { + let output = next?; + + // If it's the first output, init states and yield `Start` event. + if decoder.is_none() { + let prompt_token_ids = + output.prompt_token_ids().expect("first llm output must carry prompt token ids"); + prompt_token_count = prompt_token_ids.len(); + + let dec = tokenizer.create_decode_stream( + prompt_token_ids, + decode_options.skip_special_tokens, + // If we are excluding stop strings from output, we need to buffer + // the output so that we don't return the beginning of a stop string + // when streaming the outputs. + match decode_options.include_stop_str_in_output { + true => 0, + false => { + decode_options + .stop_strings + .as_ref() + .and_then(|stops| stops.iter().map(|ss| ss.len()).max()) + .unwrap_or(1) + - 1 + } + }, + ); + decoder = Some(dec); + + y.yield_ok(DecodedTextEvent::Start { + prompt_token_ids: prompt_token_ids.clone(), + prompt_logprobs: output + .prompt_logprobs() + .map(|logprobs| { + decode_prompt_logprobs( + tokenizer.as_ref(), + prompt_token_ids, + logprobs, + decode_options.skip_special_tokens, + ) + }) + .transpose()?, + }) + .await; + }; + let decoder = decoder.as_mut().unwrap(); + + let kv_transfer_params = output.kv_transfer_params; + let mut finish_reason = output.finish_reason; + let mut stop_str_matched = false; + let suppress_terminal_stop_token = finish_reason.as_ref().is_some_and(|r| r.is_stop()) + && !decode_options.include_stop_str_in_output; + let decodable_token_ids = if suppress_terminal_stop_token { + // Match Python V1 token-stop detokenization by keeping the stop token + // in metadata while excluding it from user-visible text. + output.token_ids.split_last().map(|(_, rest)| rest).unwrap_or(&[]) + } else { + &output.token_ids + }; + + let mut delta: Option = None; + let mut truncate_output_to = None; + let mut truncate_tokens_to = None; + for (tok_idx, &token_id) in decodable_token_ids.iter().enumerate() { + let new_bytes = decoder.push_token(token_id)?; + if output_token_count + tok_idx + 1 > decode_options.min_tokens as usize + && let Some(stops) = decode_options.stop_strings.as_mut() + && let Some((idx, off)) = matches_stop_string(stops, decoder.output(), new_bytes) + { + let stop_str = stops.swap_remove(idx); + truncate_output_to = match decode_options.include_stop_str_in_output { + true => Some(off + stop_str.len()), + false => Some(off), + }; + finish_reason = Some(FinishReason::Stop(Some(StopReason::Text(stop_str)))); + truncate_tokens_to = Some(tok_idx + 1); + stop_str_matched = true; + + break; + } + + if intermediate && let Some(chunk) = decoder.next_chunk() { + if let Some(delta_str) = delta.as_mut() { + delta_str.push_str(&chunk); + } else { + delta = Some(chunk); + } + } + } + + let mut new_token_ids = output.token_ids; + let mut new_logprobs = output.logprobs; + + // Trim tokens and logprobs if we matched stop string. + if let Some(num_tokens) = truncate_tokens_to { + new_token_ids.truncate(num_tokens); + if let Some(logprobs) = &mut new_logprobs { + logprobs.positions.truncate(num_tokens); + } + } + + output_token_count += new_token_ids.len(); + + let decoded_logprobs = new_logprobs + .as_ref() + .map(|logprobs| { + decode_logprobs( + tokenizer.as_ref(), + logprobs, + decode_options.skip_special_tokens, + ) + }) + .transpose()?; + + if !intermediate { + token_ids.extend(&new_token_ids); + if let Some(dlp) = decoded_logprobs.as_ref() { + logprobs + .get_or_insert_with(|| DecodedLogprobs { positions: vec![] }) + .positions + .extend_from_slice(&dlp.positions); + } + } + + if let Some(reason) = finish_reason { + // Flush any remaining buffered text. + let (last_chunk, mut text) = decoder.flush(truncate_output_to)?; + let text_len = text.len(); + let full_text = tracing::enabled!(Level::TRACE).then(|| text.clone()); + + if intermediate { + if let Some(chunk) = last_chunk { + if let Some(delta_str) = delta.as_mut() { + delta_str.push_str(&chunk); + } else { + delta = Some(chunk); + } + } + token_ids = new_token_ids; + logprobs = decoded_logprobs; + text = delta.unwrap_or_default(); + } + + debug!( + finish_reason = ?reason, + text_length_bytes = text_len, + output_token_count = output_token_count, + "request finished with terminal output" + ); + if let Some(full_text) = full_text { + trace!(full_text, "request finished with terminal decoded text"); + } + + // Intentionally drop the stream with explicit cause, so that the engine core + // can distinguish between such normal completion vs an unexpected + // early drop. + if stop_str_matched { + AbortCause::StopStringMatched.drop_as(raw_stream); + } + + y.yield_ok(DecodedTextEvent::TextDelta { + delta: text, + token_ids, + logprobs, + finished: Some(Finished { + prompt_token_count, + output_token_count, + finish_reason: reason, + kv_transfer_params, + }), + }) + .await; + return Ok(()); + } + + if intermediate { + y.yield_ok(DecodedTextEvent::TextDelta { + delta: delta.unwrap_or_default(), + token_ids: new_token_ids, + logprobs: decoded_logprobs, + finished: None, + }) + .await; + } + } + + Err(Error::StreamClosedBeforeTerminalOutput { request_id }) +} + +/// If stop string matches, returns tuple +/// (index into stop string vec, byte index of first byte of stop string in +/// output) +fn matches_stop_string(stops: &[String], output: &str, new_bytes: usize) -> Option<(usize, usize)> { + // We compare byte subslices to avoid utf8 boundary problem + let output = output.as_bytes(); + let next_off = (output.len() + 1) - new_bytes; + stops + .iter() + .map(|ss| (ss.as_bytes(), ss.len(), next_off.saturating_sub(ss.len()))) + .enumerate() + .find_map(|(ss_idx, (ss, len, start_off))| { + output[start_off..] + .windows(len) + .rposition(|w| w == ss) + .map(|pos| (ss_idx, start_off + pos)) + }) +} + +#[cfg(test)] +mod tests { + use std::pin::Pin; + use std::sync::{Arc, Mutex}; + use std::task::{Context, Poll}; + + use futures::{Stream, stream}; + use vllm_engine_core_client::AbortCause; + use vllm_llm::GenerateOutput; + use vllm_tokenizer::Tokenizer; + + use super::*; + use crate::output::TextOutputStreamExt as _; + + /// Backend that treats each token ID as a raw byte, producing lossy UTF-8. + struct ByteTokenizer; + + impl Tokenizer for ByteTokenizer { + fn encode( + &self, + _text: &str, + _add_special_tokens: bool, + ) -> vllm_tokenizer::Result> { + unreachable!() + } + + fn decode( + &self, + token_ids: &[u32], + _skip_special_tokens: bool, + ) -> vllm_tokenizer::Result { + let bytes = token_ids.iter().map(|id| *id as u8).collect::>(); + Ok(String::from_utf8_lossy(&bytes).into_owned()) + } + + fn token_to_id(&self, _token: &str) -> Option { + unreachable!() + } + } + + /// Helper: run `decoded_text_event_stream` to completion and return the + /// collected output. + async fn run_to_completion( + token_ids: Vec, + decode_options: TextDecodeOptions, + ) -> crate::output::CollectedTextOutput { + let prompt: Arc<[u32]> = Arc::from([]); + let raw_stream = stream::iter(vec![Ok(GenerateOutput::for_test( + Some(prompt), + token_ids, + Some(FinishReason::Length), + ))]); + let tokenizer: DynTokenizer = Arc::new(ByteTokenizer); + decoded_text_event_stream("test".into(), tokenizer, raw_stream, decode_options, false) + .collect_output() + .await + .unwrap() + } + + /// Convert ASCII string to token IDs (one byte per token). + fn ascii_tokens(s: &str) -> Vec { + s.bytes().map(u32::from).collect() + } + + fn opts(stop: &[&str], min_tokens: u32) -> TextDecodeOptions { + TextDecodeOptions { + stop_strings: Some(stop.iter().map(|s| s.to_string()).collect()), + min_tokens, + ..Default::default() + } + } + + struct DropRecordingStream { + next: Option>, + dropped_cause: Arc>>, + } + + impl Stream for DropRecordingStream { + type Item = vllm_llm::Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.next.take()) + } + } + + impl Drop for DropRecordingStream { + fn drop(&mut self) { + *self.dropped_cause.lock().unwrap() = Some(AbortCause::current()); + } + } + + // --- stop string stream tests --- + + #[tokio::test] + async fn stream_stop_string_sets_task_local_abort_cause_on_raw_stream_drop() { + let prompt: Arc<[u32]> = Arc::from([]); + let dropped_cause = Arc::new(Mutex::new(None)); + let raw_stream = DropRecordingStream { + next: Some(Ok(GenerateOutput::for_test( + Some(prompt), + ascii_tokens("hello"), + Some(FinishReason::Length), + ))), + dropped_cause: Arc::clone(&dropped_cause), + }; + let tokenizer: DynTokenizer = Arc::new(ByteTokenizer); + + let output = decoded_text_event_stream( + "test".into(), + tokenizer, + raw_stream, + opts(&["ll"], 0), + false, + ) + .collect_output() + .await + .unwrap(); + + assert_eq!(output.text, "he"); + assert!(output.finish_reason.is_stop()); + assert_eq!( + *dropped_cause.lock().unwrap(), + Some(AbortCause::StopStringMatched) + ); + } + + #[tokio::test] + async fn stream_stop_string_truncates_at_match() { + let output = run_to_completion(ascii_tokens("hello"), opts(&["e"], 0)).await; + assert_eq!(output.text, "h"); + assert!(output.finish_reason.is_stop()); + } + + #[tokio::test] + async fn stream_stop_string_at_end() { + let output = run_to_completion(ascii_tokens("abcxyz"), opts(&["xyz"], 0)).await; + assert_eq!(output.text, "abc"); + assert!(output.finish_reason.is_stop()); + } + + #[tokio::test] + async fn stream_stop_string_first_token() { + let output = run_to_completion(ascii_tokens("xhello"), opts(&["x"], 0)).await; + assert_eq!(output.text, ""); + assert!(output.finish_reason.is_stop()); + } + + #[tokio::test] + async fn stream_stop_string_no_match_runs_to_completion() { + let output = run_to_completion(ascii_tokens("hello"), opts(&["z"], 0)).await; + assert_eq!(output.text, "hello"); + assert_eq!(output.finish_reason, FinishReason::Length); + } + + #[tokio::test] + async fn stream_stop_string_multi_char() { + let output = run_to_completion(ascii_tokens("say hello world"), opts(&["lo"], 0)).await; + assert_eq!(output.text, "say hel"); + assert!(output.finish_reason.is_stop()); + } + + #[tokio::test] + async fn stream_stop_string_first_of_multiple_wins() { + // Both "ll" and "lo" are present; "ll" appears first in the output. + let output = run_to_completion(ascii_tokens("hello"), opts(&["ll", "lo"], 0)).await; + assert_eq!(output.text, "he"); + assert!(output.finish_reason.is_stop()); + } + + #[tokio::test] + async fn stream_stop_string_include_in_output() { + let output = run_to_completion( + ascii_tokens("hello"), + TextDecodeOptions { + stop_strings: Some(vec!["ll".to_string()]), + include_stop_str_in_output: true, + ..Default::default() + }, + ) + .await; + assert_eq!(output.text, "hell"); + assert!(output.finish_reason.is_stop()); + } + + // --- min_tokens + stop string interaction --- + + #[tokio::test] + async fn min_tokens_suppresses_early_stop_string() { + // stop="e", min_tokens=3: the 'e' at token 2 is within the first 3 tokens, + // so it should be skipped. No later 'e' exists, so output runs to completion. + let output = run_to_completion(ascii_tokens("hello"), opts(&["e"], 3)).await; + assert_eq!(output.text, "hello"); + assert_eq!(output.finish_reason, FinishReason::Length); + } + + #[tokio::test] + async fn min_tokens_allows_stop_string_after_threshold() { + // stop="e", min_tokens=2: the first 'e' at token 3 is past the threshold. + let output = run_to_completion(ascii_tokens("greet"), opts(&["e"], 2)).await; + assert_eq!(output.text, "gr"); + assert!(output.finish_reason.is_stop()); + } + + #[tokio::test] + async fn min_tokens_zero_behaves_like_absent() { + let output = run_to_completion(ascii_tokens("hello"), opts(&["e"], 0)).await; + assert_eq!(output.text, "h"); + assert!(output.finish_reason.is_stop()); + } + + #[test] + fn stop_string_matches_at_end() { + let stops = vec!["wor".to_string()]; + // Output: "say wor", last byte 'r' was just added (new_bytes=1) + let result = matches_stop_string(&stops, "say wor", 1); + assert_eq!(result, Some((0, 4))); + } + + #[test] + fn stop_string_no_match() { + let stops = vec!["xyz".to_string()]; + let result = matches_stop_string(&stops, "say wor", 1); + assert_eq!(result, None); + } + + #[test] + fn stop_string_matches_first_of_multiple() { + let stops = vec!["wor".to_string(), "say".to_string()]; + // "say" appears earlier but "wor" is checked first (index 0) + let result = matches_stop_string(&stops, "say wor", 1); + assert_eq!(result, Some((0, 4))); + } + + #[test] + fn stop_string_matches_second_of_multiple() { + let stops = vec!["xyz".to_string(), "wor".to_string()]; + let result = matches_stop_string(&stops, "say wor", 1); + assert_eq!(result, Some((1, 4))); + } + + #[test] + fn stop_string_matches_with_multiple_new_bytes() { + let stops = vec!["wor".to_string()]; + // "say wor" where last 3 bytes "wor" were added at once + let result = matches_stop_string(&stops, "say wor", 3); + assert_eq!(result, Some((0, 4))); + } + + #[test] + fn stop_string_matches_at_beginning() { + let stops = vec!["say".to_string()]; + let result = matches_stop_string(&stops, "say wor", 7); + assert_eq!(result, Some((0, 0))); + } + + #[test] + fn stop_string_exact_output() { + let stops = vec!["abc".to_string()]; + let result = matches_stop_string(&stops, "abc", 3); + assert_eq!(result, Some((0, 0))); + } + + #[test] + fn stop_string_single_char() { + let stops = vec!["!".to_string()]; + let result = matches_stop_string(&stops, "hello!", 1); + assert_eq!(result, Some((0, 5))); + } + + #[test] + fn stop_string_not_in_new_bytes_region() { + let stops = vec!["say".to_string()]; + // "say" is in the output but before the new byte region. + // new_bytes=1 means only 'r' was added; "say" ended at byte 3, + // but the search window starts at next_off - stop_len = 7+1-1 - 3 = 4. + let result = matches_stop_string(&stops, "say wor", 1); + assert_eq!(result, None); + } + + #[test] + fn stop_string_empty_list() { + let stops: Vec = vec![]; + let result = matches_stop_string(&stops, "hello", 1); + assert_eq!(result, None); + } + + #[test] + fn stop_string_multibyte_utf8() { + let stops = vec!["世界".to_string()]; + // "你好世界" is 12 bytes: 你(3) + 好(3) + 世(3) + 界(3) + // "世界" starts at byte 6 + let result = matches_stop_string(&stops, "你好世界", 3); + assert_eq!(result, Some((0, 6))); + } +} diff --git a/rust/src/text/src/output/logprobs.rs b/rust/src/text/src/output/logprobs.rs new file mode 100644 index 00000000000..7024c52b779 --- /dev/null +++ b/rust/src/text/src/output/logprobs.rs @@ -0,0 +1,236 @@ +use itertools::Itertools as _; +use serde::{Deserialize, Serialize}; +use vllm_llm::{Logprobs, PositionLogprobs}; +use vllm_tokenizer::Tokenizer; + +use crate::error::Error; + +/// One decoded token candidate and its logprob metadata. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct DecodedTokenLogprob { + /// Original vocabulary token ID for this candidate. + pub token_id: u32, + /// Best-effort decoded token string for this candidate. + pub token: String, + /// Log probability of this token candidate. + pub logprob: f32, + /// Vocabulary rank of this token candidate. + pub rank: u32, +} + +/// One position's decoded token candidates and their logprobs. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct DecodedPositionLogprobs { + /// Candidate tokens for this position. + pub entries: Vec, +} + +/// Decoded sample logprobs for generated token positions. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct DecodedLogprobs { + /// Generated token positions covered by this payload. + pub positions: Vec, +} + +/// Decoded prompt logprobs for prompt token positions. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct DecodedPromptLogprobs { + /// Original vocabulary token ID for the first prompt token. + pub first_token_id: u32, + /// Best-effort decoded string for the first prompt token. + /// + /// The first prompt token has no left context to score against, so it is + /// stored separately instead of appearing in `scored_positions`. + pub first_token: String, + /// Scored prompt positions after the first prompt token. + /// + /// `scored_positions[i]` corresponds to the prompt token at position `i + + /// 1`. + pub scored_positions: Vec, +} + +/// Decode generated-token logprobs from the raw `llm` token-ID shape into the +/// text-layer decoded-token representation. +/// +/// Each returned position corresponds to one generated token position from the +/// same `llm` update. +pub(super) fn decode_logprobs( + tokenizer: &T, + logprobs: &Logprobs, + skip_special_tokens: bool, +) -> Result { + Ok(DecodedLogprobs { + positions: logprobs + .positions + .iter() + .map(|position| decode_position_logprobs(tokenizer, position, skip_special_tokens)) + .try_collect()?, + }) +} + +/// Decode prompt logprobs from the raw `llm` token-ID shape into the text-layer +/// decoded-token representation. +/// +/// The returned payload stores the first prompt token separately and decodes +/// the remaining scored prompt positions into `scored_positions`, matching +/// vLLM's prompt-logprobs semantics. +pub(super) fn decode_prompt_logprobs( + tokenizer: &T, + prompt_token_ids: &[u32], + logprobs: &Logprobs, + skip_special_tokens: bool, +) -> Result { + let first_token_id = prompt_token_ids + .first() + .copied() + .expect("prompt logprobs require at least one prompt token"); + let first_token = tokenizer.decode(&[first_token_id], skip_special_tokens)?; + let scored_positions = logprobs + .positions + .iter() + .map(|position| decode_position_logprobs(tokenizer, position, skip_special_tokens)) + .try_collect()?; + + Ok(DecodedPromptLogprobs { + first_token_id, + first_token, + scored_positions, + }) +} + +/// Decode one token position's raw candidate set into decoded token strings +/// plus logprob metadata. +/// +/// This decodes every candidate token ID independently through the active text +/// backend. +fn decode_position_logprobs( + tokenizer: &T, + position: &PositionLogprobs, + skip_special_tokens: bool, +) -> Result { + Ok(DecodedPositionLogprobs { + entries: position + .entries + .iter() + .map(|entry| { + tokenizer.decode(&[entry.token_id], skip_special_tokens).map(|token| { + DecodedTokenLogprob { + token_id: entry.token_id, + token, + logprob: entry.logprob, + rank: entry.rank, + } + }) + }) + .try_collect()?, + }) +} + +#[cfg(test)] +mod tests { + use vllm_llm::{Logprobs, PositionLogprobs, TokenLogprob}; + + use super::*; + + #[derive(Debug)] + struct ByteTokenizer; + + impl vllm_tokenizer::Tokenizer for ByteTokenizer { + fn encode( + &self, + _text: &str, + _add_special_tokens: bool, + ) -> vllm_tokenizer::Result> { + unreachable!() + } + + fn decode( + &self, + token_ids: &[u32], + _skip_special_tokens: bool, + ) -> vllm_tokenizer::Result { + Ok(String::from_utf8_lossy( + &token_ids.iter().map(|token_id| *token_id as u8).collect::>(), + ) + .into_owned()) + } + + fn token_to_id(&self, _token: &str) -> Option { + unreachable!() + } + } + + #[test] + fn decode_logprobs_decodes_every_candidate_token() { + let tokenizer = ByteTokenizer; + let logprobs = Logprobs { + positions: vec![PositionLogprobs { + entries: vec![ + TokenLogprob { + token_id: b'a' as u32, + logprob: -0.1, + rank: 3, + }, + TokenLogprob { + token_id: b'b' as u32, + logprob: -0.2, + rank: 1, + }, + ], + }], + }; + + assert_eq!( + decode_logprobs(&tokenizer, &logprobs, false).unwrap(), + DecodedLogprobs { + positions: vec![DecodedPositionLogprobs { + entries: vec![ + DecodedTokenLogprob { + token_id: b'a' as u32, + token: "a".to_string(), + logprob: -0.1, + rank: 3, + }, + DecodedTokenLogprob { + token_id: b'b' as u32, + token: "b".to_string(), + logprob: -0.2, + rank: 1, + }, + ], + }], + } + ); + } + + #[test] + fn decode_prompt_logprobs_separates_first_prompt_token() { + let tokenizer = ByteTokenizer; + let logprobs = Logprobs { + positions: vec![PositionLogprobs { + entries: vec![TokenLogprob { + token_id: b'x' as u32, + logprob: -0.4, + rank: 1, + }], + }], + }; + + assert_eq!( + decode_prompt_logprobs(&tokenizer, &[b'p' as u32, b'x' as u32], &logprobs, false) + .unwrap(), + DecodedPromptLogprobs { + first_token_id: b'p' as u32, + first_token: "p".to_string(), + scored_positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: b'x' as u32, + token: "x".to_string(), + logprob: -0.4, + rank: 1, + }], + }], + } + ); + } +} diff --git a/rust/src/text/src/output/mod.rs b/rust/src/text/src/output/mod.rs new file mode 100644 index 00000000000..064b820d57f --- /dev/null +++ b/rust/src/text/src/output/mod.rs @@ -0,0 +1,323 @@ +//! Output processing helpers shared by text and chat layers. + +pub use decoded::{DecodedTextEvent, Finished, TextDecodeOptions, decoded_text_event_stream}; +pub use logprobs::{ + DecodedLogprobs, DecodedPositionLogprobs, DecodedPromptLogprobs, DecodedTokenLogprob, +}; + +mod decoded; +mod logprobs; + +use std::sync::Arc; + +use futures::{StreamExt as _, pin_mut}; + +use crate::{Error, FinishReason, Result, TextOutputStream}; + +/// Final decoded text plus terminal stream metadata. +#[derive(Debug, Clone, PartialEq)] +pub struct CollectedTextOutput { + pub text: String, + pub prompt_token_ids: Arc<[u32]>, + pub prompt_logprobs: Option, + pub logprobs: Option, + pub token_ids: Vec, + pub finish_reason: FinishReason, + /// Connector-specific KV transfer parameters for disaggregated serving. + pub kv_transfer_params: Option, +} + +#[allow(clippy::manual_async_fn, reason = "specify `Send` bound")] +#[easy_ext::ext(TextOutputStreamExt)] +impl T { + /// Collect the stream to completion and return the final decoded text plus + /// terminal metadata. + pub fn collect_output(self) -> impl Future> + Send { + async move { + let stream = self; + pin_mut!(stream); + let mut prompt_logprobs = None; + let mut prompt_token_ids: Arc<[u32]> = Arc::from([]); + let mut collected: Option = None; + + while let Some(event) = stream.next().await.transpose()? { + match event { + DecodedTextEvent::Start { + prompt_logprobs: start_prompt_logprobs, + prompt_token_ids: start_prompt_token_ids, + .. + } => { + prompt_logprobs = start_prompt_logprobs; + prompt_token_ids = start_prompt_token_ids; + } + DecodedTextEvent::TextDelta { + delta, + token_ids: delta_token_ids, + logprobs: mut delta_logprobs, + finished, + } => { + if let Some(c) = collected.as_mut() { + c.text.push_str(&delta); + c.token_ids.extend(delta_token_ids); + if let Some(dlp) = delta_logprobs.as_mut() { + if let Some(lp) = c.logprobs.as_mut() { + lp.positions.extend_from_slice(&dlp.positions); + } else { + c.logprobs = delta_logprobs; + } + } + } else { + collected = Some(CollectedTextOutput { + text: delta, + prompt_token_ids: Arc::clone(&prompt_token_ids), + prompt_logprobs: prompt_logprobs.take(), + logprobs: delta_logprobs, + token_ids: delta_token_ids, + finish_reason: FinishReason::Error, + kv_transfer_params: None, + }) + }; + + if let Some(finished) = finished { + let mut collected = collected.unwrap(); + collected.finish_reason = finished.finish_reason; + collected.kv_transfer_params = finished.kv_transfer_params; + return Ok(collected); + } + } + } + } + + // Note: this is actually unreachable, as the underlying stream always emit an + // error on unexpected close. + Err(Error::StreamClosedBeforeTerminalOutput { + request_id: "unknown".to_string(), + }) + } + } +} + +#[cfg(test)] +mod tests { + use futures::stream; + use vllm_llm::FinishReason; + + use super::*; + + #[tokio::test] + async fn collect_output_retains_prompt_and_sample_logprobs() { + let stream = stream::iter(vec![ + Ok(DecodedTextEvent::Start { + prompt_token_ids: vec![10, 11].into(), + prompt_logprobs: Some(DecodedPromptLogprobs { + first_token_id: 0, + first_token: "o".to_string(), + scored_positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "p".to_string(), + logprob: -0.1, + rank: 1, + }], + }], + }), + }), + Ok(DecodedTextEvent::TextDelta { + delta: "bc".to_string(), + token_ids: vec![1, 2], + logprobs: Some(DecodedLogprobs { + positions: vec![ + DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "a".to_string(), + logprob: -0.2, + rank: 1, + }], + }, + DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "bc".to_string(), + logprob: -0.3, + rank: 1, + }], + }, + ], + }), + finished: Some(Finished { + prompt_token_count: 2, + output_token_count: 2, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + }), + ]); + + let collected = stream.collect_output().await.unwrap(); + assert_eq!(collected.text, "bc"); + assert_eq!( + collected.prompt_logprobs, + Some(DecodedPromptLogprobs { + first_token_id: 0, + first_token: "o".to_string(), + scored_positions: vec![DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "p".to_string(), + logprob: -0.1, + rank: 1, + }], + }], + }) + ); + assert_eq!( + collected.logprobs, + Some(DecodedLogprobs { + positions: vec![ + DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "a".to_string(), + logprob: -0.2, + rank: 1, + }], + }, + DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "bc".to_string(), + logprob: -0.3, + rank: 1, + }], + }, + ], + }) + ); + } + + #[tokio::test] + async fn collect_output_accumulates_intermediate_deltas() { + let stream = stream::iter(vec![ + Ok(DecodedTextEvent::Start { + prompt_token_ids: vec![10, 11].into(), + prompt_logprobs: None, + }), + Ok(DecodedTextEvent::TextDelta { + delta: "he".to_string(), + token_ids: vec![1, 2], + logprobs: Some(DecodedLogprobs { + positions: vec![ + DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "h".to_string(), + logprob: -0.1, + rank: 1, + }], + }, + DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "e".to_string(), + logprob: -0.2, + rank: 1, + }], + }, + ], + }), + finished: None, + }), + Ok(DecodedTextEvent::TextDelta { + delta: "llo".to_string(), + token_ids: vec![3, 4, 5], + logprobs: Some(DecodedLogprobs { + positions: vec![ + DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "l".to_string(), + logprob: -0.3, + rank: 1, + }], + }, + DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "l".to_string(), + logprob: -0.4, + rank: 1, + }], + }, + DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "o".to_string(), + logprob: -0.5, + rank: 1, + }], + }, + ], + }), + finished: Some(Finished { + prompt_token_count: 2, + output_token_count: 5, + finish_reason: FinishReason::stop_eos(), + kv_transfer_params: None, + }), + }), + ]); + + let collected = stream.collect_output().await.unwrap(); + assert_eq!(collected.text, "hello"); + assert_eq!(collected.prompt_logprobs, None); + assert_eq!(collected.token_ids, vec![1, 2, 3, 4, 5]); + assert_eq!( + collected.logprobs, + Some(DecodedLogprobs { + positions: vec![ + DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "h".to_string(), + logprob: -0.1, + rank: 1, + }], + }, + DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "e".to_string(), + logprob: -0.2, + rank: 1, + }], + }, + DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "l".to_string(), + logprob: -0.3, + rank: 1, + }], + }, + DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "l".to_string(), + logprob: -0.4, + rank: 1, + }], + }, + DecodedPositionLogprobs { + entries: vec![DecodedTokenLogprob { + token_id: 0, + token: "o".to_string(), + logprob: -0.5, + rank: 1, + }], + }, + ], + }) + ); + } +} diff --git a/rust/src/text/src/request.rs b/rust/src/text/src/request.rs new file mode 100644 index 00000000000..9e2464f14af --- /dev/null +++ b/rust/src/text/src/request.rs @@ -0,0 +1,197 @@ +use std::collections::HashMap; + +use enum_as_inner::EnumAsInner; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use vllm_engine_core_client::protocol::StructuredOutputsParams; +use vllm_engine_core_client::protocol::multimodal::MmFeatures; + +use crate::error::{Error, Result}; +use crate::output::TextDecodeOptions; + +/// One raw text-generation prompt. +/// +/// This supports either ordinary text that still needs tokenization or +/// already-tokenized prompt IDs that should bypass tokenizer work entirely. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, EnumAsInner)] +#[serde(untagged)] +pub enum Prompt { + /// Untokenized prompt text that still needs tokenizer work before + /// generation. + Text(String), + /// Pre-tokenized prompt IDs that should be forwarded southbound without + /// re-encoding. + TokenIds(Vec), +} + +impl Default for Prompt { + fn default() -> Self { + Self::Text(String::new()) // placeholder + } +} + +/// User-facing sampling parameters accepted by `vllm-text`. +/// +/// This intentionally keeps only the subset that the current Rust text layer +/// supports as northbound request semantics. Engine-core-specific normalized +/// fields are derived later during lowering. +/// +/// Original Python definition: +/// +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(default)] +pub struct SamplingParams { + /// Controls randomness. Lower values are more deterministic; zero means + /// greedy sampling. `None` means no explicit user override. + pub temperature: Option, + /// Cumulative probability threshold for nucleus sampling. + pub top_p: Option, + /// Maximum number of top tokens to consider. `Some(0)` means all tokens. + pub top_k: Option, + /// Random seed used by the sampler when present. + pub seed: Option, + /// Maximum number of tokens to generate. `None` means no explicit user + /// override. + pub max_tokens: Option, + /// Minimum number of tokens to generate before EOS or stop-token handling. + pub min_tokens: Option, + /// Number of log probabilities to return per generated token. + /// + /// `None` disables sample logprobs. `-1` requests the full vocabulary. + pub logprobs: Option, + /// Number of log probabilities to return per prompt token. + /// + /// `None` disables prompt logprobs. `-1` requests the full vocabulary. + pub prompt_logprobs: Option, + /// Minimum probability threshold for token sampling. `None` means no + /// explicit user override. + pub min_p: Option, + /// Frequency penalty applied by the sampler. `None` means no explicit user + /// override. + pub frequency_penalty: Option, + /// Presence penalty applied by the sampler. `None` means no explicit user + /// override. + pub presence_penalty: Option, + /// Repetition penalty applied by the sampler. `None` means no explicit user + /// override. + pub repetition_penalty: Option, + /// Explicit stop token IDs provided by the caller. `None` means no explicit + /// user override. + pub stop_token_ids: Option>, + /// If true, do not stop on the model's primary EOS token. + pub ignore_eos: bool, + /// Modify the likelihood of specified tokens appearing in the completion. + /// Keys are token IDs. + pub logit_bias: Option>, + /// Restrict output to these token IDs only. + pub allowed_token_ids: Option>, + /// Words to avoid during generation (tokenized to IDs during lowering). + pub bad_words: Option>, + /// Specific token IDs for which log probabilities should be returned at + /// each position. + /// + /// When set, the engine returns logprobs for exactly these tokens in + /// addition to the sampled/scored token. Mutually exclusive with + /// `logprobs` in practice. + pub logprob_token_ids: Option>, + /// Parameters for configuring structured outputs (guided decoding). + pub structured_outputs: Option, + /// If true, bypass reads from the prefix cache for this request (the prompt + /// will not reuse cached KV blocks from earlier requests, though newly + /// computed blocks may still populate the cache). `None` defers to + /// engine-core defaults. + pub skip_reading_prefix_cache: Option, + /// Additional request parameters for custom extensions. + pub vllm_xargs: Option>, +} + +#[allow(clippy::derivable_impls)] // more explicit +impl Default for SamplingParams { + fn default() -> Self { + Self { + temperature: None, + top_p: None, + top_k: None, + seed: None, + max_tokens: None, + min_tokens: None, + logprobs: None, + prompt_logprobs: None, + min_p: None, + frequency_penalty: None, + presence_penalty: None, + repetition_penalty: None, + stop_token_ids: None, + ignore_eos: false, + logit_bias: None, + allowed_token_ids: None, + bad_words: None, + logprob_token_ids: None, + structured_outputs: None, + skip_reading_prefix_cache: None, + vllm_xargs: None, + } + } +} + +/// One raw text-generation request ready to be tokenized or sent directly to +/// the engine. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct TextRequest { + /// Stable caller-supplied request ID. + pub request_id: String, + /// Prompt text or prompt token IDs for this request. + pub prompt: Prompt, + /// Multimodal features prepared by a higher-level frontend. Raw text + /// requests keep this empty; multimodal chat uses it with pre-tokenized + /// prompt IDs. + pub mm_features: Option, + /// User-facing sampling parameters accepted by `vllm-text`. + pub sampling_params: SamplingParams, + /// Incremental detokenization options for the response path. + pub decode_options: TextDecodeOptions, + /// Whether to emit intermediate northbound deltas before the terminal + /// result. + /// + /// If `false`, callers only observe the terminal accumulated output. If + /// `true`, callers may receive zero or more incremental decoded updates + /// before the final terminal event. + pub intermediate: bool, + /// Request scheduling priority (lower means earlier handling; default 0). + pub priority: i32, + /// Salt for prefix cache isolation in multi-user environments. + pub cache_salt: Option, + /// Whether to add special tokens (e.g. BOS) during prompt tokenization. + pub add_special_tokens: bool, + /// Override data parallel rank. + #[serde(default)] + pub data_parallel_rank: Option, +} + +impl TextRequest { + /// Return one minimal valid request fixture for tests. + pub fn for_test() -> Self { + Self { + request_id: "test-request".to_string(), + prompt: Prompt::Text("test".to_string()), + mm_features: None, + sampling_params: SamplingParams::default(), + decode_options: TextDecodeOptions::default(), + intermediate: true, + priority: 0, + cache_salt: None, + add_special_tokens: false, + data_parallel_rank: None, + } + } + + /// Validate the minimum invariants before tokenization or request lowering. + pub fn validate(&self) -> Result<()> { + if matches!(&self.prompt, Prompt::TokenIds(ids) if ids.is_empty()) { + return Err(Error::EmptyPromptTokenIds { + request_id: self.request_id.clone(), + }); + } + Ok(()) + } +} diff --git a/rust/src/tokenizer/Cargo.toml b/rust/src/tokenizer/Cargo.toml new file mode 100644 index 00000000000..786c46f4031 --- /dev/null +++ b/rust/src/tokenizer/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "vllm-tokenizer" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +base64.workspace = true +fastokens.workspace = true +riptoken.workspace = true +rustc-hash.workspace = true +serde.workspace = true +serde_json.workspace = true +tekken.workspace = true +thiserror.workspace = true +thiserror-ext.workspace = true +tiktoken-rs.workspace = true +tokenizers.workspace = true +tracing.workspace = true + +[dev-dependencies] +criterion.workspace = true +hf-hub.workspace = true +tempfile.workspace = true + +[[bench]] +name = "hf" +harness = false + +[[bench]] +name = "tiktoken" +harness = false + +[lints] +workspace = true diff --git a/rust/src/tokenizer/benches/hf.rs b/rust/src/tokenizer/benches/hf.rs new file mode 100644 index 00000000000..9bf37778089 --- /dev/null +++ b/rust/src/tokenizer/benches/hf.rs @@ -0,0 +1,118 @@ +use criterion::{Criterion, Throughput, black_box, criterion_group, criterion_main}; +use hf_hub::api::sync::ApiBuilder; +use vllm_tokenizer::{HuggingFaceTokenizer, Tokenizer}; + +const MODEL_ID: &str = "Qwen/Qwen3.5-0.8B"; +const SAMPLE_TEXT: &str = "\ +<|im_start|>system +You are Qwen3.5, a helpful assistant. +<|im_end|> +<|im_start|>user +请用中英混合总结以下需求,并给出一个简短的 JSON 示例。 +The service should stop cleanly at EOS, avoid leaking the next template turn, and keep decode latency low. +Input: 4 concurrent requests, 10240 prompt tokens, 16 generated tokens. +<|im_end|> +<|im_start|>assistant +"; + +struct BenchFixture { + fastokens: HuggingFaceTokenizer, + hf: HuggingFaceTokenizer, + text: String, + token_ids: Vec, +} + +impl BenchFixture { + fn load() -> Self { + let path = tokenizer_json(); + let fastokens = + HuggingFaceTokenizer::new_fastokens(&path).expect("load fastokens tokenizer"); + let hf = HuggingFaceTokenizer::new_hf(&path).expect("load huggingface tokenizer"); + + let text = SAMPLE_TEXT.repeat(32); + let hf_token_ids = + hf.encode(text.as_str(), false).expect("encode sample text with hf tokenizer"); + let fastokens_token_ids = fastokens + .encode(text.as_str(), false) + .expect("encode sample text with fastokens"); + assert_eq!(fastokens_token_ids, hf_token_ids); + + let hf_decoded = hf + .decode(hf_token_ids.as_slice(), false) + .expect("decode sample token ids with hf tokenizer"); + let fastokens_decoded = fastokens + .decode(hf_token_ids.as_slice(), false) + .expect("decode sample token ids with fastokens"); + assert_eq!(fastokens_decoded, hf_decoded); + + Self { + fastokens, + hf, + text, + token_ids: hf_token_ids, + } + } +} + +fn tokenizer_json() -> std::path::PathBuf { + ApiBuilder::from_env() + .with_progress(false) + .build() + .expect("build hf-hub api") + .model(MODEL_ID.to_string()) + .get("tokenizer.json") + .expect("fetch tokenizer.json from hf-hub") +} + +fn bench_encode(c: &mut Criterion) { + let fixture = BenchFixture::load(); + let mut group = c.benchmark_group("tokenizer_encode"); + group.throughput(Throughput::Bytes(fixture.text.len() as u64)); + + group.bench_function("fastokens", |b| { + b.iter(|| { + fixture + .fastokens + .encode(black_box(fixture.text.as_str()), black_box(false)) + .expect("encode sample text with fastokens") + }) + }); + group.bench_function("hf_tokenizers", |b| { + b.iter(|| { + fixture + .hf + .encode(black_box(fixture.text.as_str()), black_box(false)) + .expect("encode sample text with hf tokenizer") + }) + }); + + group.finish(); +} + +fn bench_decode(c: &mut Criterion) { + let fixture = BenchFixture::load(); + let mut group = c.benchmark_group("tokenizer_decode"); + group.throughput(Throughput::Elements(fixture.token_ids.len() as u64)); + + group.bench_function("fastokens", |b| { + b.iter(|| { + fixture + .fastokens + .decode(black_box(fixture.token_ids.as_slice()), black_box(false)) + .expect("decode sample token ids with fastokens") + }) + }); + group.bench_function("hf_tokenizers", |b| { + b.iter(|| { + fixture + .hf + .decode(black_box(fixture.token_ids.as_slice()), black_box(false)) + .expect("decode sample token ids with hf tokenizer") + }) + }); + + group.finish(); +} + +criterion_group!(benches, bench_encode, bench_decode); +criterion_main!(benches); diff --git a/rust/src/tokenizer/benches/tiktoken.rs b/rust/src/tokenizer/benches/tiktoken.rs new file mode 100644 index 00000000000..54b9805f01a --- /dev/null +++ b/rust/src/tokenizer/benches/tiktoken.rs @@ -0,0 +1,117 @@ +use criterion::{Criterion, Throughput, black_box, criterion_group, criterion_main}; +use hf_hub::api::sync::ApiBuilder; +use vllm_tokenizer::{TiktokenTokenizer, Tokenizer}; + +const MODEL_ID: &str = "moonshotai/Kimi-K2.5"; +const SAMPLE_TEXT: &str = "\ + +I'm sure it's fine, but I can't say I'd trust that it's what we'd ship. + +请用中英混合总结以下需求,并保留 tool-call marker: +<|tool_calls_section_begin|>{\"name\":\"summarize\",\"arguments\":{\"style\":\"brief\"}}<|tool_calls_section_end|> +The service should stop cleanly at EOS, avoid leaking the next template turn, and keep decode latency low. +"; + +struct BenchFixture { + riptoken: TiktokenTokenizer, + tiktoken_rs: TiktokenTokenizer, + text: String, + token_ids: Vec, +} + +impl BenchFixture { + fn load() -> Self { + let path = tiktoken_model(); + let riptoken = TiktokenTokenizer::new_riptoken(&path).expect("load riptoken tokenizer"); + let tiktoken_rs = + TiktokenTokenizer::new_tiktoken_rs(&path).expect("load tiktoken-rs tokenizer"); + + let text = SAMPLE_TEXT.repeat(32); + let riptoken_token_ids = + riptoken.encode(text.as_str(), false).expect("encode sample text with riptoken"); + let tiktoken_rs_token_ids = tiktoken_rs + .encode(text.as_str(), false) + .expect("encode sample text with tiktoken-rs"); + assert_eq!(riptoken_token_ids, tiktoken_rs_token_ids); + + let riptoken_decoded = riptoken + .decode(riptoken_token_ids.as_slice(), false) + .expect("decode sample token ids with riptoken"); + let tiktoken_rs_decoded = tiktoken_rs + .decode(riptoken_token_ids.as_slice(), false) + .expect("decode sample token ids with tiktoken-rs"); + assert_eq!(riptoken_decoded, tiktoken_rs_decoded); + + Self { + riptoken, + tiktoken_rs, + text, + token_ids: riptoken_token_ids, + } + } +} + +fn tiktoken_model() -> std::path::PathBuf { + let repo = ApiBuilder::from_env() + .with_progress(false) + .build() + .expect("build hf-hub api") + .model(MODEL_ID.to_string()); + repo.get("config.json").expect("fetch config.json from hf-hub"); + repo.get("tokenizer_config.json") + .expect("fetch tokenizer_config.json from hf-hub"); + repo.get("tiktoken.model").expect("fetch tiktoken.model from hf-hub") +} + +fn bench_encode(c: &mut Criterion) { + let fixture = BenchFixture::load(); + let mut group = c.benchmark_group("tiktoken_encode"); + group.throughput(Throughput::Bytes(fixture.text.len() as u64)); + + group.bench_function("riptoken", |b| { + b.iter(|| { + fixture + .riptoken + .encode(black_box(fixture.text.as_str()), black_box(false)) + .expect("encode sample text with riptoken") + }) + }); + group.bench_function("tiktoken_rs", |b| { + b.iter(|| { + fixture + .tiktoken_rs + .encode(black_box(fixture.text.as_str()), black_box(false)) + .expect("encode sample text with tiktoken-rs") + }) + }); + + group.finish(); +} + +fn bench_decode(c: &mut Criterion) { + let fixture = BenchFixture::load(); + let mut group = c.benchmark_group("tiktoken_decode"); + group.throughput(Throughput::Elements(fixture.token_ids.len() as u64)); + + group.bench_function("riptoken", |b| { + b.iter(|| { + fixture + .riptoken + .decode(black_box(fixture.token_ids.as_slice()), black_box(false)) + .expect("decode sample token ids with riptoken") + }) + }); + group.bench_function("tiktoken_rs", |b| { + b.iter(|| { + fixture + .tiktoken_rs + .decode(black_box(fixture.token_ids.as_slice()), black_box(false)) + .expect("decode sample token ids with tiktoken-rs") + }) + }); + + group.finish(); +} + +criterion_group!(benches, bench_encode, bench_decode); +criterion_main!(benches); diff --git a/rust/src/tokenizer/src/byte_level_decode.rs b/rust/src/tokenizer/src/byte_level_decode.rs new file mode 100644 index 00000000000..8208e8e8d48 --- /dev/null +++ b/rust/src/tokenizer/src/byte_level_decode.rs @@ -0,0 +1,119 @@ +//! Fast GPT-2 byte-level detokenization that writes into a single `Vec`, +//! avoiding the `Vec` / `String::join` assembly in fastokens' generic +//! `Decoder::decode` pipeline. + +/// Reverse GPT-2 byte-to-unicode mapping: codepoint → original byte. The GPT-2 +/// table only emits codepoints in U+0000..U+0143, so a flat array suffices. +const CHAR_TO_BYTE: [u8; 324] = build_char_to_byte(); + +const fn is_nice(b: u8) -> bool { + (b >= b'!' && b <= b'~') || (b >= 0xA1 && b <= 0xAC) || b >= 0xAE +} + +const fn build_char_to_byte() -> [u8; 324] { + let mut table = [0u8; 324]; + let mut b: u16 = 0; + while b < 256 { + let cp = if is_nice(b as u8) { + b as u32 + } else { + 256 + nice_offset(b as u8) + }; + table[cp as usize] = b as u8; + b += 1; + } + table +} + +const fn nice_offset(b: u8) -> u32 { + let mut i: u16 = 0; + let mut n: u32 = 0; + while i < b as u16 { + if !is_nice(i as u8) { + n += 1; + } + i += 1; + } + n +} + +/// Decode byte-level encoded token strings into a single UTF-8 string, +/// matching `fastokens::decoders::ByteLevelDecoder`. +pub fn decode_byte_level<'a, I: IntoIterator>(tokens: I) -> String { + let iter = tokens.into_iter(); + let (lower, _) = iter.size_hint(); + let mut bytes: Vec = Vec::with_capacity(lower.saturating_mul(4)); + for token in iter { + for c in token.chars() { + let cp = c as usize; + if cp < CHAR_TO_BYTE.len() { + bytes.push(CHAR_TO_BYTE[cp]); + } else { + // Non-GPT2 codepoints (e.g. DeepSeek's U+FF5C, U+2581) pass through. + let mut buf = [0u8; 4]; + bytes.extend_from_slice(c.encode_utf8(&mut buf).as_bytes()); + } + } + } + String::from_utf8(bytes).unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned()) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn build_byte_to_char_ref() -> [char; 256] { + let mut table = ['\0'; 256]; + let mut next: u32 = 256; + for b in 0..=255u8 { + let cp = if is_nice(b) { + b as u32 + } else { + let cp = next; + next += 1; + cp + }; + table[b as usize] = char::from_u32(cp).unwrap(); + } + table + } + + #[test] + fn char_to_byte_roundtrips_every_byte() { + let byte_to_char = build_byte_to_char_ref(); + for b in 0..=255u8 { + let cp = byte_to_char[b as usize] as usize; + assert!(cp < CHAR_TO_BYTE.len()); + assert_eq!(CHAR_TO_BYTE[cp], b, "mismatch for byte {b:#x}"); + } + } + + #[test] + fn decode_ascii() { + assert_eq!(decode_byte_level(["Hello"]), "Hello"); + } + + #[test] + fn decode_space_marker() { + // GPT-2 maps 0x20 → Ġ (U+0120). + assert_eq!( + decode_byte_level(["\u{120}Hello", "\u{120}world"]), + " Hello world", + ); + } + + #[test] + fn decode_multibyte_euro() { + // € → 0xE2 0x82 0xAC, each mapped to a specific GPT-2 char. + let byte_to_char = build_byte_to_char_ref(); + let encoded: String = + [0xE2u8, 0x82, 0xAC].iter().map(|&b| byte_to_char[b as usize]).collect(); + assert_eq!(decode_byte_level([encoded.as_str()]), "€"); + } + + #[test] + fn decode_preserves_non_gpt2_chars() { + let tok = "<\u{FF5C}begin\u{2581}of\u{2581}sentence\u{FF5C}>"; + assert_eq!(decode_byte_level([tok]), "<|begin▁of▁sentence|>"); + } +} diff --git a/rust/src/tokenizer/src/error.rs b/rust/src/tokenizer/src/error.rs new file mode 100644 index 00000000000..cfe8253ffe2 --- /dev/null +++ b/rust/src/tokenizer/src/error.rs @@ -0,0 +1,9 @@ +use thiserror::Error; +use thiserror_ext::Macro; + +pub type Result = std::result::Result; + +#[derive(Debug, Error, Macro)] +#[thiserror_ext(macro(path = "crate::error"))] +#[error("tokenizer error: {0}")] +pub struct TokenizerError(#[message] pub String); diff --git a/rust/src/tokenizer/src/hf.rs b/rust/src/tokenizer/src/hf.rs new file mode 100644 index 00000000000..93b48545a24 --- /dev/null +++ b/rust/src/tokenizer/src/hf.rs @@ -0,0 +1,344 @@ +use std::path::Path; +use std::sync::Arc; + +use fastokens::Tokenizer as FastokensTokenizer; +use fastokens::decoders::Decoder as FastokensDecoder; +use thiserror_ext::AsReport as _; +use tokenizers::Tokenizer as HfTokenizer; +use tracing::{info, warn}; + +use crate::byte_level_decode::decode_byte_level; +use crate::{Result, Tokenizer}; + +enum Backend { + Hf(Box), + Fastokens(Box), + /// Fastokens tokenizer whose decoder is pure GPT-2 byte-level, so we can + /// bypass `Decoder::decode`'s `Vec`/`join("")` assembly. + FastokensByteLevel(Box), +} + +/// True if `dec` is effectively a single `ByteLevel` stage — one `ByteLevel` +/// leaf in a tree of `Sequence`s (fastokens represents `Fuse` as an empty +/// `Sequence`, which is a no-op for our purposes). +fn is_byte_level_only(dec: &FastokensDecoder) -> bool { + fn count_byte_level(dec: &FastokensDecoder) -> usize { + match dec { + FastokensDecoder::ByteLevel(_) => 1, + FastokensDecoder::Sequence(steps) => steps.iter().map(count_byte_level).sum(), + } + } + count_byte_level(dec) == 1 +} + +fn decode_fastokens_byte_level( + t: &FastokensTokenizer, + token_ids: &[u32], + skip_special_tokens: bool, +) -> Result { + let tokens: Vec<&str> = token_ids + .iter() + .filter(|&&id| !(skip_special_tokens && t.is_special_token(id))) + .map(|&id| { + t.id_to_token(id) + .ok_or_else(|| tokenizer_error!("decoding failed: unknown token ID: {id}")) + }) + .collect::>()?; + Ok(decode_byte_level(tokens)) +} + +/// Tokenizer from `tokenizer.json` in HuggingFace format. +/// +/// This tries to load with `fastokens` first for better performance, then falls +/// back to HuggingFace's `tokenizers` if the former fails (e.g. due to +/// unsupported tokenizer features or file formats). +pub struct HuggingFaceTokenizer { + backend: Backend, + special_token_ids: Arc<[u32]>, +} + +impl HuggingFaceTokenizer { + fn from_hf_backend(tokenizer: HfTokenizer) -> Self { + let special_token_ids = { + let mut ids: Vec = tokenizer + .get_added_tokens_decoder() + .iter() + .filter(|(_id, token)| token.special) + .map(|(id, _token)| *id) + .collect(); + ids.sort_unstable(); + ids.dedup(); + Arc::from(ids) + }; + Self { + backend: Backend::Hf(Box::new(tokenizer)), + special_token_ids, + } + } + + fn from_fastokens_backend(tokenizer: FastokensTokenizer) -> Self { + let special_token_ids = { + let mut ids: Vec = tokenizer + .added_tokens() + .into_iter() + .flat_map(|added_tokens| added_tokens.iter()) + .filter(|token| token.special) + .map(|token| token.id) + .collect(); + ids.sort_unstable(); + ids.dedup(); + Arc::from(ids) + }; + let byte_level = tokenizer.decoder().is_some_and(is_byte_level_only); + let backend = if byte_level { + Backend::FastokensByteLevel(Box::new(tokenizer)) + } else { + Backend::Fastokens(Box::new(tokenizer)) + }; + Self { + backend, + special_token_ids, + } + } + + /// Load from `tokenizer.json` with `fastokens`. + pub fn new_fastokens(path: &Path) -> Result { + info!(path = %path.display(), "loading tokenizer with fastokens"); + let t = FastokensTokenizer::from_file(path) + .map_err(|error| tokenizer_error!("failed to load tokenizer: {}", error.as_report()))?; + Ok(Self::from_fastokens_backend(t)) + } + + /// Load from `tokenizer.json` with Hugging Face `tokenizers`. + pub fn new_hf(path: &Path) -> Result { + info!(path = %path.display(), "loading tokenizer with huggingface tokenizers"); + let t = HfTokenizer::from_file(path) + .map_err(|error| tokenizer_error!("failed to load tokenizer: {}", error.as_report()))?; + Ok(Self::from_hf_backend(t)) + } + + /// Load from `tokenizer.json` via fastokens or HuggingFace tokenizers. + pub fn new(path: &Path) -> Result { + match Self::new_fastokens(path) { + Ok(tokenizer) => Ok(tokenizer), + Err(error) => { + warn!( + path = %path.display(), + error = %error.as_report(), + "failed to load tokenizer with fastokens; falling back to HuggingFace tokenizers" + ); + Self::new_hf(path) + } + } + } +} + +impl Tokenizer for HuggingFaceTokenizer { + fn encode(&self, text: &str, add_special_tokens: bool) -> Result> { + match &self.backend { + Backend::Hf(t) => { + let encoding = t + .encode(text, add_special_tokens) + .map_err(|error| tokenizer_error!("encoding failed: {}", error.as_report()))?; + Ok(encoding.get_ids().to_vec()) + } + Backend::Fastokens(t) | Backend::FastokensByteLevel(t) => t + .encode_with_special_tokens(text, add_special_tokens) + .map_err(|error| tokenizer_error!("encoding failed: {}", error.as_report())), + } + } + + fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result { + match &self.backend { + Backend::Hf(t) => t + .decode(token_ids, skip_special_tokens) + .map_err(|error| tokenizer_error!("decoding failed: {}", error.as_report())), + Backend::Fastokens(t) => t + .decode(token_ids, skip_special_tokens) + .map_err(|error| tokenizer_error!("decoding failed: {}", error.as_report())), + Backend::FastokensByteLevel(t) => { + decode_fastokens_byte_level(t, token_ids, skip_special_tokens) + } + } + } + + fn token_to_id(&self, token: &str) -> Option { + match &self.backend { + Backend::Hf(t) => t.token_to_id(token), + Backend::Fastokens(t) | Backend::FastokensByteLevel(t) => t.token_to_id(token), + } + } + + fn id_to_token(&self, id: u32) -> Option { + match &self.backend { + Backend::Hf(t) => t.id_to_token(id), + Backend::Fastokens(t) | Backend::FastokensByteLevel(t) => { + t.id_to_token(id).map(ToOwned::to_owned) + } + } + } + + fn is_special_id(&self, token_id: u32) -> bool { + self.special_token_ids.binary_search(&token_id).is_ok() + } +} + +#[cfg(test)] +mod tests { + use tempfile::tempdir; + use tokenizers::models::bpe::BPE; + use tokenizers::{AddedToken, Tokenizer as HfTokenizer}; + + use super::{HuggingFaceTokenizer, Tokenizer}; + + fn tiny_bpe_tokenizer() -> HfTokenizer { + let vocab = [ + ("".to_string(), 0), + ("h".to_string(), 1), + ("e".to_string(), 2), + ("l".to_string(), 3), + ("o".to_string(), 4), + ("he".to_string(), 5), + ("ll".to_string(), 6), + ("hell".to_string(), 7), + ("hello".to_string(), 8), + ]; + let merges = vec![ + ("h".to_string(), "e".to_string()), + ("l".to_string(), "l".to_string()), + ("he".to_string(), "ll".to_string()), + ("hell".to_string(), "o".to_string()), + ]; + let model = BPE::builder() + .vocab_and_merges(vocab, merges) + .unk_token("".to_string()) + .build() + .expect("build bpe tokenizer"); + HfTokenizer::new(model) + } + + #[test] + fn hf_constructor_resolves_added_token_ids() { + let mut tokenizer = tiny_bpe_tokenizer(); + tokenizer.add_special_tokens(&[AddedToken::from("<|im_end|>", true)]); + + let dir = tempdir().expect("create temp dir"); + let path = dir.path().join("tokenizer.json"); + tokenizer.save(&path, false).expect("save tokenizer json"); + + let wrapper = HuggingFaceTokenizer::new_hf(&path).expect("load hf wrapper"); + let special_id = wrapper.token_to_id("<|im_end|>").expect("resolve added special token id"); + assert!(wrapper.is_special_id(special_id)); + } + + #[test] + fn new_fastokens_preserves_special_ids_from_fastokens_metadata() { + let mut tokenizer = tiny_bpe_tokenizer(); + tokenizer.add_special_tokens(&[AddedToken::from("<|im_end|>", true)]); + + let dir = tempdir().expect("create temp dir"); + let path = dir.path().join("tokenizer.json"); + tokenizer.save(&path, false).expect("save tokenizer json"); + + let wrapper = HuggingFaceTokenizer::new_fastokens(&path) + .expect("load wrapper with fastokens backend"); + assert!(matches!( + wrapper.backend, + super::Backend::Fastokens(_) | super::Backend::FastokensByteLevel(_), + )); + let special_id = wrapper.token_to_id("<|im_end|>").expect("resolve added special token id"); + assert!(wrapper.is_special_id(special_id)); + } + + /// BPE tokenizer that round-trips through fastokens with a genuine + /// `ByteLevel` decoder; vocab covers both GPT-2 (Ġ U+0120) and non-GPT-2 + /// (| U+FF5C) codepoints. + fn tiny_byte_level_bpe() -> fastokens::Tokenizer { + let raw = r#"{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + {"id": 0, "content": "<|endoftext|>", "single_word": false, + "lstrip": false, "rstrip": false, "normalized": false, "special": true} + ], + "normalizer": null, + "pre_tokenizer": {"type": "ByteLevel", "add_prefix_space": false, + "trim_offsets": true, "use_regex": true}, + "post_processor": null, + "decoder": {"type": "ByteLevel", "add_prefix_space": false, + "trim_offsets": true, "use_regex": true}, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": false, + "vocab": { + "<|endoftext|>": 0, + "H": 1, "e": 2, "l": 3, "o": 4, "w": 5, "r": 6, "d": 7, + "Ġ": 8, "!": 9, + "|": 10 + }, + "merges": [] + } + }"#; + let value: serde_json::Value = serde_json::from_str(raw).expect("parse tokenizer json"); + fastokens::Tokenizer::from_json(value).expect("build fastokens tokenizer") + } + + #[test] + fn byte_level_detected_direct() { + let t = tiny_byte_level_bpe(); + assert!(super::is_byte_level_only(t.decoder().expect("decoder"))); + } + + #[test] + fn byte_level_detected_inside_sequence() { + let raw = r#"{ + "type": "Sequence", + "decoders": [ + {"type": "ByteLevel", "add_prefix_space": false, + "trim_offsets": true, "use_regex": true}, + {"type": "Fuse"} + ] + }"#; + let config: fastokens::DecoderConfig = + serde_json::from_str(raw).expect("parse decoder config"); + let dec = + fastokens::decoders::Decoder::from_config(config).expect("build decoder from config"); + assert!(super::is_byte_level_only(&dec)); + } + + /// Fast path must produce byte-identical output to fastokens' own decode. + #[test] + fn fast_byte_level_matches_fastokens_decode() { + let t = tiny_byte_level_bpe(); + let cases: &[&[u32]] = &[ + &[], + &[1, 2, 3, 3, 4], // "Hello" + &[1, 2, 3, 3, 4, 8, 5, 4, 6, 3, 7], // "Hello world" + &[0, 1, 2, 3, 3, 4, 0, 9, 0], // specials interleaved + &[10, 1, 2, 3, 3, 4, 10], // |Hello| (non-GPT2 chars) + ]; + for ids in cases { + for &skip in &[false, true] { + let expected = t.decode(ids, skip).expect("fastokens decode"); + let got = + super::decode_fastokens_byte_level(&t, ids, skip).expect("fast-path decode"); + assert_eq!(got, expected, "ids={ids:?} skip={skip}"); + } + } + } + + #[test] + fn fast_byte_level_errors_on_unknown_id() { + let t = tiny_byte_level_bpe(); + let err = super::decode_fastokens_byte_level(&t, &[999], false) + .expect_err("unknown id must error"); + assert!(format!("{err:?}").contains("999")); + } +} diff --git a/rust/src/tokenizer/src/incremental.rs b/rust/src/tokenizer/src/incremental.rs new file mode 100644 index 00000000000..7a025d35e5c --- /dev/null +++ b/rust/src/tokenizer/src/incremental.rs @@ -0,0 +1,359 @@ +use std::mem::take; + +use crate::{Result, Tokenizer}; + +/// Stateful incremental decoder that emits text chunks one token at a time. +pub trait IncrementalDecoder: Send { + /// Push one generated token and return how many new string bytes were + /// added. + fn push_token(&mut self, token_id: u32) -> Result; + + /// Consume any text which is currently ready. + fn next_chunk(&mut self) -> Option; + + /// Flush any remaining buffered text that has not yet been emitted. + /// + /// Called after the final generated token to force out buffered/incomplete + /// fragments. + fn flush(&mut self, truncate_output_to: Option) -> Result<(Option, String)>; + + /// Return cumulative decoded text so far. + fn output(&self) -> &str; +} + +/// [`IncrementalDecoder`] built on [`Tokenizer::decode()`] with prefix-diffing. +/// +/// This is the same sliding-window algorithm used by `tokenizers::DecodeStream` +pub(crate) struct DecodeStream<'a, T: Tokenizer + ?Sized> { + tokenizer: &'a T, + skip_special_tokens: bool, + min_bytes_to_buffer: usize, + // mutated state + ids: Vec, + prefix: String, + prefix_index: usize, + cumulative_output: String, + output_index: usize, +} + +impl<'a, T: Tokenizer + ?Sized> DecodeStream<'a, T> { + pub(crate) fn new( + tokenizer: &'a T, + prompt_token_ids: &[u32], + skip_special_tokens: bool, + min_bytes_to_buffer: usize, + ) -> Self { + Self { + tokenizer, + skip_special_tokens, + min_bytes_to_buffer, + ids: prompt_token_ids.to_vec(), + prefix: String::new(), + prefix_index: 0, + cumulative_output: String::new(), + output_index: 0, + } + } +} + +/// Try a short tail suffix first (covers a CJK glyph straddling 1-2 token +/// boundaries); beyond 6 tokens the fallback full-prompt decode is no worse +/// than baseline so widening the sweep just adds overhead. +const SAFE_SUFFIX_MIN: usize = 4; +const SAFE_SUFFIX_MAX: usize = 6; + +impl DecodeStream<'_, T> { + /// Seed `self.prefix` from the shortest trailing suffix whose decoded text + /// has no U+FFFD — a clean decode means the suffix starts and ends at + /// valid UTF-8/token boundaries, so priming from it is equivalent to + /// priming from the full prompt. + fn seed_prefix(&mut self) -> Result<()> { + let prompt_len = self.ids.len(); + if prompt_len > SAFE_SUFFIX_MIN { + let max_try = SAFE_SUFFIX_MAX.min(prompt_len - 1); + for suffix_len in SAFE_SUFFIX_MIN..=max_try { + let start = prompt_len - suffix_len; + let decoded = + self.tokenizer.decode(&self.ids[start..], self.skip_special_tokens)?; + if !decoded.contains('\u{FFFD}') { + self.prefix = decoded; + self.ids.drain(..start); + self.prefix_index = self.ids.len(); + return Ok(()); + } + } + } + let decoded = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?; + if !decoded.ends_with('\u{FFFD}') { + self.prefix = decoded; + self.prefix_index = self.ids.len(); + } + Ok(()) + } +} + +impl IncrementalDecoder for DecodeStream<'_, T> { + fn push_token(&mut self, token_id: u32) -> Result { + if self.prefix.is_empty() && !self.ids.is_empty() { + self.seed_prefix()?; + } + + self.ids.push(token_id); + let string = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?; + let prefix_len = self.prefix.len(); + if string.len() <= prefix_len || string.ends_with('\u{FFFD}') { + return Ok(0); + } + // Ensure we split at a utf-8 char boundary. + let new_chunk = &string[string.floor_char_boundary(prefix_len)..]; + self.cumulative_output.push_str(new_chunk); + self.ids.drain(..self.prefix_index); + self.prefix = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?; + self.prefix_index = self.ids.len(); + Ok(new_chunk.len()) + } + + fn next_chunk(&mut self) -> Option { + let cutoff = self.cumulative_output.len().saturating_sub(self.min_bytes_to_buffer); + (cutoff > self.output_index).then(|| { + let chunk = self.cumulative_output[self.output_index..cutoff].to_string(); + self.output_index = cutoff; + chunk + }) + } + + fn flush(&mut self, truncate_output_to: Option) -> Result<(Option, String)> { + if !self.ids.is_empty() { + let string = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?; + let prefix_len = self.prefix.len(); + self.ids.clear(); + self.prefix.clear(); + self.prefix_index = 0; + // Ensure we split at a utf-8 char boundary. + self.cumulative_output + .push_str(&string[string.floor_char_boundary(prefix_len)..]); + } + if let Some(truncate_output_to) = truncate_output_to { + self.cumulative_output.truncate(truncate_output_to); + } + let last_chunk = (self.output_index < self.cumulative_output.len()) + .then(|| self.cumulative_output[self.output_index..].to_string()); + self.output_index = 0; + Ok((last_chunk, take(&mut self.cumulative_output))) + } + + fn output(&self) -> &str { + &self.cumulative_output + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Backend that treats each token ID as a raw byte, producing lossy UTF-8. + #[derive(Debug)] + struct Utf8Backend; + + impl Tokenizer for Utf8Backend { + fn encode(&self, _text: &str, _add_special_tokens: bool) -> Result> { + unreachable!() + } + + fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result { + let bytes = token_ids.iter().map(|id| *id as u8).collect::>(); + Ok(String::from_utf8_lossy(&bytes).into_owned()) + } + + fn token_to_id(&self, _token: &str) -> Option { + unreachable!() + } + } + + #[test] + fn holds_incomplete_utf8_until_complete() { + let backend = Utf8Backend; + let mut decoder = backend.create_decode_stream(&[], false, 0); + + // 你 = U+4F60 = 0xE4 0xBD 0xA0 + assert_eq!(decoder.push_token(0xe4).unwrap(), 0); + assert_eq!(decoder.push_token(0xbd).unwrap(), 0); + assert_eq!(decoder.push_token(0xa0).unwrap(), 3); // "你" is 3 bytes + assert_eq!(decoder.output(), "你"); + } + + #[test] + fn emits_ascii_immediately() { + let backend = Utf8Backend; + let mut decoder = backend.create_decode_stream(&[], false, 0); + + assert_eq!(decoder.push_token(b'o' as u32).unwrap(), 1); + assert_eq!(decoder.push_token(b'k' as u32).unwrap(), 1); + assert_eq!(decoder.output(), "ok"); + } + + #[test] + fn flush_returns_none_when_fully_consumed() { + let backend = Utf8Backend; + let mut decoder = backend.create_decode_stream(&[], false, 0); + + assert_eq!(decoder.push_token(b'o' as u32).unwrap(), 1); + assert_eq!(decoder.next_chunk().as_deref(), Some("o")); + assert_eq!(decoder.push_token(b'k' as u32).unwrap(), 1); + assert_eq!(decoder.next_chunk().as_deref(), Some("k")); + // All text already consumed via next_chunk + let (last_chunk, full_text) = decoder.flush(None).unwrap(); + assert_eq!(last_chunk, None); + assert_eq!(full_text, "ok"); + } + + #[test] + fn flush_emits_buffered_incomplete_utf8() { + let backend = Utf8Backend; + let mut decoder = backend.create_decode_stream(&[], false, 0); + + // Push incomplete multi-byte sequence — step returns 0 bytes. + assert_eq!(decoder.push_token(0xe4).unwrap(), 0); + assert_eq!(decoder.push_token(0xbd).unwrap(), 0); + + // Flush forces out whatever the decoder can produce (lossy replacement). + let (last_chunk, _full_text) = decoder.flush(None).unwrap(); + assert!(last_chunk.is_some()); + } + + /// Backend where token 0 is a special token. + #[derive(Debug)] + struct SpecialTokenBackend; + + impl Tokenizer for SpecialTokenBackend { + fn encode(&self, _text: &str, _add_special_tokens: bool) -> Result> { + unreachable!() + } + + fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result { + let mut text = String::new(); + for &token_id in token_ids { + match token_id { + 0 if !skip_special_tokens => text.push_str(""), + 0 => {} + 1 => text.push('a'), + _ => {} + } + } + Ok(text) + } + + fn token_to_id(&self, _token: &str) -> Option { + unreachable!() + } + } + + #[test] + fn respects_skip_special_tokens() { + let backend = SpecialTokenBackend; + let mut skip_decoder = backend.create_decode_stream(&[], true, 0); + let mut keep_decoder = backend.create_decode_stream(&[], false, 0); + + assert_eq!(skip_decoder.push_token(0).unwrap(), 0); + assert_eq!(keep_decoder.push_token(0).unwrap(), 9); // "" is 9 bytes + assert_eq!(keep_decoder.output(), ""); + } + + #[test] + fn prompt_tokens_provide_context_without_re_emission() { + let backend = Utf8Backend; + let prompt = &[b'H' as u32, b'i' as u32]; + let mut decoder = backend.create_decode_stream(prompt, false, 0); + + // First generated token should not re-emit "Hi". + let added = decoder.push_token(b'!' as u32).unwrap(); + assert_eq!(added, 1); + assert_eq!(decoder.output(), "!"); + } + + #[test] + fn chunks_concatenate_to_full_text() { + let backend = Utf8Backend; + let mut decoder = backend.create_decode_stream(&[], false, 0); + + let input = b"Hello, world!"; + let mut full = String::new(); + for &byte in input { + decoder.push_token(byte as u32).unwrap(); + if let Some(chunk) = decoder.next_chunk() { + full.push_str(&chunk); + } + } + let (last_chunk, full_text) = decoder.flush(None).unwrap(); + assert_eq!(last_chunk, None); // all consumed via next_chunk + assert_eq!(full, "Hello, world!"); + assert_eq!(full_text, "Hello, world!"); + } + + /// Backend simulating non-monotonic decode where adding a token changes how + /// earlier tokens decode (context-dependent normalization), causing + /// prefix_len to land mid-UTF-8. Reproduces the class of bug from + /// vllm-project/vllm#17448. + #[derive(Debug)] + struct NonMonotonicBackend; + + impl Tokenizer for NonMonotonicBackend { + fn encode(&self, _text: &str, _add_special_tokens: bool) -> Result> { + unreachable!() + } + + fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result { + match token_ids { + [1] => Ok("abc".into()), + [1, 2] => Ok("ab".into()), + // Token 3 triggers a normalization change: "ab" becomes emoji + "d". + // prefix_len=3 ("abc") lands inside the 4-byte emoji 🎉. + [1, 2, 3] => Ok("🎉d".into()), // 🎉 is 4 bytes + d = 5 bytes + [2, 3] => Ok("🎉d".into()), // prefix recompute after drain + [3] => Ok("d".into()), // after drain + _ => panic!("unexpected decode: {:?}", token_ids), + } + } + + fn token_to_id(&self, _token: &str) -> Option { + unreachable!() + } + } + + /// Without the char-boundary fix, this panics slicing mid-emoji. + #[test] + fn non_monotonic_decode_does_not_panic() { + let backend = NonMonotonicBackend; + let mut decoder = backend.create_decode_stream(&[], false, 0); + + // Token 1: "abc", prefix="abc" + assert_eq!(decoder.push_token(1).unwrap(), 3); + // Token 2: "ab" (shorter), no emit + assert_eq!(decoder.push_token(2).unwrap(), 0); + // Token 3: "🎉d" — prefix_len=3 is mid-emoji. Without fix this panics. + let added = decoder.push_token(3).unwrap(); + assert!(added > 0); + } + + #[test] + fn next_chunk_with_hold_back() { + let backend = Utf8Backend; + // hold_back_bytes: 3 means we buffer the last 3 bytes + let mut decoder = backend.create_decode_stream(&[], false, 3); + + let input = b"Hello!"; + let mut chunks = String::new(); + for &byte in input { + decoder.push_token(byte as u32).unwrap(); + if let Some(chunk) = decoder.next_chunk() { + chunks.push_str(&chunk); + } + } + // With hold_back_bytes=3, last 3 bytes ("lo!") are held back + assert_eq!(chunks, "Hel"); + // Flush returns the rest + let (last_chunk, full_text) = decoder.flush(None).unwrap(); + assert_eq!(last_chunk.as_deref(), Some("lo!")); + assert_eq!(full_text, "Hello!"); + } +} diff --git a/rust/src/tokenizer/src/lib.rs b/rust/src/tokenizer/src/lib.rs new file mode 100644 index 00000000000..6a512a5a620 --- /dev/null +++ b/rust/src/tokenizer/src/lib.rs @@ -0,0 +1,62 @@ +use std::sync::Arc; + +use crate::incremental::DecodeStream; + +mod byte_level_decode; +#[macro_use] +mod error; +mod hf; +mod incremental; +mod tekken; +mod tiktoken; + +pub use error::{Result, TokenizerError}; +pub use hf::HuggingFaceTokenizer; +pub use incremental::IncrementalDecoder; +pub use tekken::TekkenTokenizer; +pub use tiktoken::TiktokenTokenizer; + +pub trait Tokenizer: Send + Sync { + /// Encode one prompt string into token IDs. + fn encode(&self, text: &str, add_special_tokens: bool) -> Result>; + + /// Decode one token sequence into text. + fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result; + + /// Convert one token string into a token ID, returning `None` if the token + /// is not in the tokenizer vocabulary. + fn token_to_id(&self, token: &str) -> Option; + + /// Convert one token ID into the tokenizer's raw token string. + fn id_to_token(&self, _id: u32) -> Option { + // TODO: remove default impl and require this to be implemented by all + // tokenizers + None + } + + /// Return whether the given token ID is special. + fn is_special_id(&self, _token_id: u32) -> bool { + false + } + + /// Create a stateful incremental decoder primed with the given prompt + /// tokens. + /// + /// The prompt tokens provide left context for the first generated token; + /// the decoder does not re-emit prompt text. + fn create_decode_stream( + &self, + prompt_token_ids: &[u32], + skip_special_tokens: bool, + min_bytes_to_buffer: usize, + ) -> Box { + Box::new(DecodeStream::new( + self, + prompt_token_ids, + skip_special_tokens, + min_bytes_to_buffer, + )) + } +} + +pub type DynTokenizer = Arc; diff --git a/rust/src/tokenizer/src/tekken.rs b/rust/src/tokenizer/src/tekken.rs new file mode 100644 index 00000000000..e8560c65a30 --- /dev/null +++ b/rust/src/tokenizer/src/tekken.rs @@ -0,0 +1,62 @@ +use std::path::Path; + +use tekken::Tekkenizer; +use tracing::info; + +use crate::{Result, Tokenizer}; + +/// Mistral Tekken tokenizer from a `tekken.json` file. +pub struct TekkenTokenizer { + inner: Tekkenizer, +} + +impl TekkenTokenizer { + /// Load a Mistral Tekken tokenizer from a `tekken.json` file. + pub fn new(path: &Path) -> Result { + info!(path = %path.display(), "loading tokenizer with Mistral Tekken"); + + let inner = Tekkenizer::from_file(path).map_err(|error| { + tokenizer_error!( + "failed to load tekken tokenizer from {}: {error}", + path.display() + ) + })?; + Ok(Self { inner }) + } +} + +impl Tokenizer for TekkenTokenizer { + fn encode(&self, text: &str, add_special_tokens: bool) -> Result> { + self.inner + .encode(text, add_special_tokens, false) + .map_err(|error| tokenizer_error!("encoding failed: {error}")) + } + + fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result { + let policy = if skip_special_tokens { + tekken::SpecialTokenPolicy::Ignore + } else { + tekken::SpecialTokenPolicy::Keep + }; + self.inner + .decode(token_ids, policy) + .map_err(|error| tokenizer_error!("decoding failed: {error}")) + } + + fn token_to_id(&self, token: &str) -> Option { + // tekken-rs exposes `get_control_token` for special tokens. Try that first, + // then fall back to encoding. + self.inner.get_control_token(token).ok().or_else(|| { + let ids = self.inner.encode(token, false, false).ok()?; + if ids.len() == 1 { Some(ids[0]) } else { None } + }) + } + + fn id_to_token(&self, id: u32) -> Option { + self.inner.id_to_piece(id).ok() + } + + fn is_special_id(&self, token_id: u32) -> bool { + self.inner.is_special_token(token_id) + } +} diff --git a/rust/src/tokenizer/src/tiktoken.rs b/rust/src/tokenizer/src/tiktoken.rs new file mode 100644 index 00000000000..0c57ff5f6b6 --- /dev/null +++ b/rust/src/tokenizer/src/tiktoken.rs @@ -0,0 +1,1013 @@ +use std::collections::HashSet; +use std::path::Path; +use std::sync::Mutex; + +use base64::Engine as _; +use rustc_hash::{FxHashMap, FxHashSet}; +use serde::Deserialize; +use thiserror_ext::AsReport as _; +use tracing::{info, warn}; + +use crate::{Result, Tokenizer}; + +/// Default regex pattern used when loading tiktoken from a BPE file. This is +/// the same `cl100k_base` pattern that HuggingFace transformers uses as its +/// default in `TikTokenConverter`. +const CL100K_BASE_PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"; + +/// Kimi BPE pattern from `moonshotai/Kimi-K2-Instruct/tokenization_kimi.py`. +const KIMI_PATTERN: &str = r"[\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"; + +/// Fallback number of reserved special-token slots to assume when the model's +/// `config.json` is not available (so we cannot read `vocab_size` directly). +/// +/// 256 is the value used by Kimi K2 / K2.5 (`tokenization_kimi.py`'s +/// `num_reserved_special_tokens`) and by Llama 3, and it appears to be the most +/// common convention among modern tiktoken-based HF tokenizers. When +/// `config.json` *is* present we honour the model's actual `vocab_size` instead +/// of this fallback — see `Self::new`. +const FALLBACK_NUM_RESERVED_SPECIAL_TOKENS: u32 = 256; +const DISABLE_RIPTOKEN_ENV: &str = "VLLM_RS_DISABLE_RIPTOKEN"; + +/// Parsed entry from `tokenizer_config.json`'s `added_tokens_decoder`. +#[derive(Debug, Clone, Deserialize)] +struct AddedToken { + content: String, + /// HuggingFace `added_tokens_decoder` entries can be marked `"special": + /// true|false`. Special tokens are dropped from output when `decode` is + /// called with `skip_special_tokens = true`. Defaults to `false` when + /// the field is omitted, matching HuggingFace's `AddedToken` default — + /// so only tokens explicitly marked special are stripped during normal + /// decode (where `skip_special_tokens` itself defaults to true). + #[serde(default)] + special: bool, +} + +/// Minimal subset of `tokenizer_config.json` needed by the tiktoken loader. +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct TiktokenTokenizerConfig { + /// Format: + /// `{ "added_tokens_decoder": { "163584": { "content": "[BOS]", "special": + /// true }, ... } }` + #[serde(default)] + added_tokens_decoder: FxHashMap, +} + +/// Minimal subset of model `config.json` needed by the tiktoken loader. +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct TiktokenModelConfig { + model_type: Option, + vocab_size: Option, + text_config: Option>, +} + +impl TiktokenModelConfig { + /// Read `model_type` from a model `config.json` value, falling back to a + /// single-level nested `text_config.model_type` for composite (e.g. + /// multimodal) configs that keep text metadata under a `text_config` + /// object. + fn effective_model_type(&self) -> Option<&str> { + self.model_type + .as_deref() + .or_else(|| self.text_config.as_deref()?.effective_model_type()) + } + + /// Read `vocab_size` from a model `config.json` value, falling back to a + /// single-level nested `text_config.vocab_size` for composite (e.g. + /// multimodal) configs that keep text metadata under a `text_config` + /// object — matching the same shape `ModelConfig` parses. + fn effective_vocab_size(&self) -> Option { + self.vocab_size.or_else(|| self.text_config.as_deref()?.effective_vocab_size()) + } +} + +/// Tiktoken tokenizer from `tiktoken.model` or `*.tiktoken` BPE files. +pub struct TiktokenTokenizer { + backend: Backend, + metadata: TokenMetadata, +} + +enum Backend { + Riptoken(RiptokenBackend), + TiktokenRs(TiktokenRsBackend), +} + +struct RiptokenBackend { + inner: Box, + allowed_special_tokens: Vec, +} + +struct TiktokenRsBackend { + inner: Box, + /// Reverse map for special / added token strings populated from the + /// reserved range. This lets `token_to_id` answer special-token lookups + /// directly without round-tripping through `tiktoken-rs`'s encoder, + /// which can panic for unknown special-looking strings. + special_token_ids_by_text: FxHashMap, + /// Set of out-of-vocab token IDs we have already warned about. The + /// reserved-slot population in the constructor should keep this empty + /// under normal operation; it only fills up if a model emits ids at or + /// above `vocab_upper_bound` (e.g. an engine sampling bug). We dedupe + /// so streaming decode (which calls `decode` repeatedly on the same prefix) + /// does not spam. + warned_unknown_ids: Mutex>, +} + +struct TokenMetadata { + /// Number of regular BPE tokens. Token ids in `[0, num_base_tokens)` are + /// BPE tokens that always decode to text; ids in `[num_base_tokens, + /// vocab_upper_bound)` live in the special-token slots and are subject + /// to `skip_special_tokens` filtering. + num_base_tokens: u32, + /// Exclusive upper bound on token IDs that `inner` is guaranteed to know + /// how to decode. + /// + /// The constructor registers every id in `[num_base_tokens, + /// vocab_upper_bound)` with the inner `CoreBPE` as a (named or + /// `<|reserved_token_{id}|>`) special token, and the BPE + /// encoder densely covers `[0, num_base_tokens)`. So any id below this + /// bound is in one of the inner `CoreBPE`'s decoder maps and + /// `_decode_native_and_split` will not panic on it. `decode` filters + /// out ids at or above this bound to keep that guarantee. + vocab_upper_bound: u32, + /// Ids in `[num_base_tokens, vocab_upper_bound)` whose + /// `added_tokens_decoder` entry was explicitly marked `"special": + /// false` — i.e. tokens that should still appear in output + /// even when `skip_special_tokens = true`. For Kimi K2 / K2.5 this + /// typically holds the tool-call markers and `` / ``. + /// Reserved-slot placeholders are not in this set (they default to + /// special and get skipped). + non_special_added_ids: FxHashSet, + /// Raw token string by token id. Base BPE tokens are represented with + /// lossy UTF-8, matching decode behavior for byte sequences that are not + /// valid UTF-8 on their own. + token_by_id: FxHashMap, +} + +impl TokenMetadata { + fn filter_special_tokens(&self, token_ids: &[u32]) -> Vec { + token_ids + .iter() + .copied() + .filter(|&id| { + id < self.num_base_tokens + || id >= self.vocab_upper_bound + || self.non_special_added_ids.contains(&id) + }) + .collect() + } + + fn is_special_id(&self, token_id: u32) -> bool { + token_id >= self.num_base_tokens + && token_id < self.vocab_upper_bound + && !self.non_special_added_ids.contains(&token_id) + } + + fn id_to_token(&self, token_id: u32) -> Option { + self.token_by_id.get(&token_id).cloned() + } +} + +impl RiptokenBackend { + fn encode(&self, text: &str) -> Vec { + // TODO: avoid collecting `allowed_special` every time this method is called. + let allowed_special: HashSet<&str> = + self.allowed_special_tokens.iter().map(String::as_str).collect(); + self.inner.encode(text, &allowed_special) + } + + fn decode(&self, token_ids: &[u32]) -> String { + let bytes = self.inner.decode_bytes(token_ids); + // TODO: use `from_utf8_lossy_owned` once it's stabilized. + String::from_utf8_lossy(&bytes).into_owned() + } + + fn token_to_id(&self, token: &str) -> Option { + self.inner.encode_single_token(token.as_bytes()) + } +} + +impl TiktokenRsBackend { + fn encode(&self, text: &str) -> Vec { + self.inner.encode_with_special_tokens(text) + } + + fn decode(&self, token_ids: &[u32], metadata: &TokenMetadata) -> String { + let safe_ids: Vec = token_ids + .iter() + .copied() + .filter(|&id| { + if id >= metadata.vocab_upper_bound { + self.warn_unknown_id(id); + return false; + } + true + }) + .collect(); + let bytes: Vec = self.inner._decode_native_and_split(safe_ids).flatten().collect(); + // TODO: use `from_utf8_lossy_owned` once it's stabilized. + String::from_utf8_lossy(&bytes).into_owned() + } + + fn token_to_id(&self, token: &str) -> Option { + if let Some(&token_id) = self.special_token_ids_by_text.get(token) { + return Some(token_id); + } + + // Fall back to ordinary encoding for regular vocabulary items. This + // deliberately avoids `encode_with_special_tokens`: older `tiktoken-rs` + // versions can panic if the input text merely *looks* like a special + // token but is not registered in `special_tokens_encoder`. + let ids = self.inner.encode_ordinary(token); + if ids.len() == 1 { Some(ids[0]) } else { None } + } + + /// Log a warning the first time an unknown token id is seen during decode, + /// deduped across calls so streaming decode does not spam the log for + /// the same id. + fn warn_unknown_id(&self, token_id: u32) { + let newly_inserted = self + .warned_unknown_ids + .lock() + .map(|mut set| set.insert(token_id)) + .unwrap_or(false); + if newly_inserted { + warn!( + token_id, + "tiktoken-rs decode encountered token id not in the vocabulary; skipping. \ + This typically indicates a sparse-vocab model whose `added_tokens_decoder` \ + does not list every reserved id in the special-token range." + ); + } + } +} + +impl TiktokenTokenizer { + /// Load a tiktoken tokenizer from a `.tiktoken` / `tiktoken.model` BPE + /// file. + /// + /// The BPE file format is one ` ` pair per line, + /// the same format used by OpenAI's tiktoken and by HuggingFace model + /// repos that ship tiktoken files (e.g. DeepSeek, Kimi K2). + /// + /// Special / added tokens are read from `tokenizer_config.json` in the same + /// directory when present. The `cl100k_base` regex pattern is used as a + /// reasonable default. + pub fn new(path: &Path) -> Result { + if std::env::var_os(DISABLE_RIPTOKEN_ENV).is_some() { + return Self::new_tiktoken_rs(path); + } + + match Self::new_riptoken(path) { + Ok(tokenizer) => Ok(tokenizer), + Err(error) => { + warn!( + path = %path.display(), + error = %error.as_report(), + "failed to load tokenizer with riptoken; falling back to tiktoken-rs" + ); + Self::new_tiktoken_rs(path) + } + } + } + + /// Load from `tiktoken.model` / `*.tiktoken` with riptoken. + pub fn new_riptoken(path: &Path) -> Result { + info!(path = %path.display(), "loading tokenizer with riptoken (BPE file)"); + + let config = LoadedTiktokenConfig::load(path)?; + let allowed_special_tokens = config.special_tokens_encoder.keys().cloned().collect(); + let inner = riptoken::CoreBPE::new( + config.encoder.into_iter().collect(), + config.special_tokens_encoder.into_iter().collect(), + config.pattern, + ) + .map_err(|error| { + tokenizer_error!( + "failed to create riptoken tokenizer from {}: {error}", + path.display() + ) + })?; + + Ok(Self { + backend: Backend::Riptoken(RiptokenBackend { + inner: Box::new(inner), + allowed_special_tokens, + }), + metadata: config.metadata, + }) + } + + /// Load from `tiktoken.model` / `*.tiktoken` with tiktoken-rs. + pub fn new_tiktoken_rs(path: &Path) -> Result { + info!(path = %path.display(), "loading tokenizer with tiktoken-rs (BPE file)"); + + let config = LoadedTiktokenConfig::load(path)?; + let special_token_ids_by_text = config.special_tokens_encoder.clone(); + let inner = tiktoken_rs::CoreBPE::new( + config.encoder, + config.special_tokens_encoder, + config.pattern, + ) + .map_err(|error| { + tokenizer_error!( + "failed to create tiktoken-rs tokenizer from {}: {error}", + path.display() + ) + })?; + + Ok(Self { + backend: Backend::TiktokenRs(TiktokenRsBackend { + inner: Box::new(inner), + special_token_ids_by_text, + warned_unknown_ids: Mutex::new(FxHashSet::default()), + }), + metadata: config.metadata, + }) + } +} + +struct LoadedTiktokenConfig { + encoder: FxHashMap, u32>, + special_tokens_encoder: FxHashMap, + metadata: TokenMetadata, + pattern: &'static str, +} + +impl LoadedTiktokenConfig { + fn load(path: &Path) -> Result { + let content = std::fs::read_to_string(path).map_err(|error| { + tokenizer_error!( + "failed to read tiktoken file {}: {}", + path.display(), + error.as_report() + ) + })?; + let mut encoder: FxHashMap, u32> = + FxHashMap::with_capacity_and_hasher(content.lines().count(), Default::default()); + for line in content.lines() { + if line.is_empty() { + continue; + } + let mut parts = line.split_whitespace(); + let token_b64 = + parts.next().ok_or_else(|| tokenizer_error!("missing token in tiktoken file"))?; + let rank_str = + parts.next().ok_or_else(|| tokenizer_error!("missing rank in tiktoken file"))?; + let token_bytes = base64::engine::general_purpose::STANDARD + .decode(token_b64) + .map_err(|error| tokenizer_error!("invalid base64 in tiktoken file: {error}"))?; + let rank: u32 = rank_str + .parse() + .map_err(|error| tokenizer_error!("invalid rank in tiktoken file: {error}"))?; + encoder.insert(token_bytes, rank); + } + + let parent_dir = path.parent(); + + // Read added/special tokens (id -> {name, special}) from + // tokenizer_config.json in the same dir. + let added_tokens_by_id = parent_dir + .map(|dir| dir.join("tokenizer_config.json")) + .filter(|p| p.exists()) + .and_then(|config_path| { + let content = std::fs::read_to_string(&config_path).ok()?; + serde_json::from_str(&content).ok() + }) + .map(|config: TiktokenTokenizerConfig| config.added_tokens_decoder) + .unwrap_or_default(); + + let model_config: Option = parent_dir + .map(|dir| dir.join("config.json")) + .filter(|p| p.exists()) + .and_then(|config_path| { + let content = std::fs::read_to_string(&config_path).ok()?; + serde_json::from_str(&content).ok() + }); + let vocab_size_from_config = model_config.as_ref().and_then(|c| c.effective_vocab_size()); + + // Build the full special-tokens encoder by populating the reserved + // range that follows the BPE vocabulary. Unknown reserved slots get + // Python-compatible placeholder names so sampled ids can still decode. + // + // Note: `*.tiktoken` ranks are token ids, and they are not guaranteed + // to be contiguous. The base-vocab boundary is therefore `max_rank + 1`, + // not `encoder.len()`. + let num_base_tokens = + encoder.values().copied().max().map_or(0, |max_rank| max_rank.saturating_add(1)); + let max_added_id = added_tokens_by_id.keys().copied().max().unwrap_or(0); + let reserved_end = vocab_size_from_config + .unwrap_or_else(|| num_base_tokens.saturating_add(FALLBACK_NUM_RESERVED_SPECIAL_TOKENS)) + .max(num_base_tokens) + .max(max_added_id.saturating_add(1)); + + let mut special_tokens_encoder: FxHashMap = + FxHashMap::with_capacity_and_hasher( + (reserved_end - num_base_tokens) as usize, + Default::default(), + ); + let mut non_special_added_ids: FxHashSet = FxHashSet::default(); + for id in num_base_tokens..reserved_end { + let name = match added_tokens_by_id.get(&id) { + Some(token) => { + if !token.special { + non_special_added_ids.insert(id); + } + token.content.clone() + } + None => format!("<|reserved_token_{id}|>"), + }; + special_tokens_encoder.insert(name, id); + } + + let mut token_by_id: FxHashMap = FxHashMap::with_capacity_and_hasher( + encoder.len() + special_tokens_encoder.len(), + Default::default(), + ); + for (token_bytes, &id) in &encoder { + token_by_id.insert(id, String::from_utf8_lossy(token_bytes).into_owned()); + } + for (token, &id) in &special_tokens_encoder { + token_by_id.insert(id, token.clone()); + } + + let pattern = model_config.as_ref().map_or(CL100K_BASE_PATTERN, detect_bpe_pattern); + + Ok(Self { + encoder, + special_tokens_encoder, + metadata: TokenMetadata { + num_base_tokens, + vocab_upper_bound: reserved_end, + non_special_added_ids, + token_by_id, + }, + pattern, + }) + } +} + +impl Tokenizer for TiktokenTokenizer { + fn encode(&self, text: &str, _add_special_tokens: bool) -> Result> { + // Tiktoken does not have a separate add_special_tokens toggle; both + // backends recognize registered special tokens in the input. + Ok(match &self.backend { + Backend::Riptoken(backend) => backend.encode(text), + Backend::TiktokenRs(backend) => backend.encode(text), + }) + } + + fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result { + // Filter passes: + // + // 1. The constructor registers every id in `[num_base_tokens, vocab_upper_bound)` as a + // special token (named or `<|reserved_token_{id}|>` placeholder, matching + // `tokenization_kimi.py`). The tiktoken-rs backend additionally drops ids at or above + // that bound so `_decode_native_and_split` cannot panic; riptoken's `decode_bytes` + // already skips unknown ids. + // + // 2. When `skip_special_tokens = true`, ids in `[num_base_tokens, vocab_upper_bound)` are + // dropped *unless* they were marked `"special": false` in `added_tokens_decoder`. This + // matches HuggingFace's tokenizer semantics: tool-call markers and `` / + // `` (which Kimi K2 / K2.5 declare as non-special) stay in the output, while + // BOS/EOS/header tokens and reserved-slot placeholders are stripped. + // + // Lossy UTF-8 decoding (instead of strict `String::from_utf8`) is used so + // partial multi-byte sequences become `\u{FFFD}`, which `DecodeStream` + // relies on to detect incomplete characters during streaming. + let ids = if skip_special_tokens { + &self.metadata.filter_special_tokens(token_ids) + } else { + token_ids + }; + + Ok(match &self.backend { + Backend::Riptoken(backend) => backend.decode(ids), + Backend::TiktokenRs(backend) => backend.decode(ids, &self.metadata), + }) + } + + fn token_to_id(&self, token: &str) -> Option { + match &self.backend { + Backend::Riptoken(backend) => backend.token_to_id(token), + Backend::TiktokenRs(backend) => backend.token_to_id(token), + } + } + + fn id_to_token(&self, id: u32) -> Option { + self.metadata.id_to_token(id) + } + + fn is_special_id(&self, token_id: u32) -> bool { + self.metadata.is_special_id(token_id) + } +} + +/// Select the BPE regex pattern for a tiktoken model based on `config.json`. +/// +/// Most tiktoken models use the `cl100k_base` regex. Kimi models ship a custom +/// regex in their Python tokenizer implementation; we mirror the explicit +/// `model_type` switch used by Dynamo instead of heuristically parsing Python +/// source files. +fn detect_bpe_pattern(config: &TiktokenModelConfig) -> &'static str { + let model_type = config.effective_model_type(); + + match model_type { + Some("kimi" | "kimi_k2" | "kimi_k25" | "deepseek_v3") => KIMI_PATTERN, + _ => CL100K_BASE_PATTERN, + } +} + +#[cfg(test)] +mod tests { + use std::fs; + use std::path::{Path, PathBuf}; + + use base64::Engine as _; + use tempfile::TempDir; + + use super::{ + CL100K_BASE_PATTERN, KIMI_PATTERN, TiktokenModelConfig, TiktokenTokenizer, + TiktokenTokenizerConfig, detect_bpe_pattern, + }; + use crate::Tokenizer; + + macro_rules! config_json { + ($($json:tt)+) => { + serde_json::from_value::(serde_json::json!($($json)+)).unwrap() + }; + } + + /// Write a minimal `*.tiktoken` BPE file (one token per byte 0..=255) into + /// `dir` and return its path. The single-byte vocab is enough to + /// exercise the multi-byte / streaming UTF-8 paths without depending on + /// any pretrained tokenizer asset. + fn write_synthetic_bpe_file(dir: &std::path::Path) -> PathBuf { + let mut content = String::new(); + for byte in 0u8..=255 { + let b64 = base64::engine::general_purpose::STANDARD.encode([byte]); + content.push_str(&format!("{b64} {}\n", byte as u32)); + } + let path = dir.join("test.tiktoken"); + fs::write(&path, content).expect("write tiktoken file"); + path + } + + /// Write a synthetic `*.tiktoken` file whose base-vocab ranks are + /// sparse/non-contiguous. + /// + /// This reproduces the important edge case for `num_base_tokens`: it must + /// be derived from `max_rank + 1`, not `encoder.len()`, otherwise + /// high-rank base tokens get misclassified as reserved/special ids. + fn write_sparse_rank_bpe_file(dir: &std::path::Path) -> PathBuf { + let mut content = String::new(); + for byte in 0u8..=255 { + let b64 = base64::engine::general_purpose::STANDARD.encode([byte]); + content.push_str(&format!("{b64} {}\n", byte as u32)); + } + + let high_rank_token = base64::engine::general_purpose::STANDARD.encode(b"SPARSE"); + content.push_str(&format!("{high_rank_token} 1000\n")); + + let path = dir.join("sparse-rank.tiktoken"); + fs::write(&path, content).expect("write sparse-rank tiktoken file"); + path + } + + /// Build a `TiktokenTokenizer` from the synthetic BPE file with no sibling + /// config files, so the constructor takes the + /// `FALLBACK_NUM_RESERVED_SPECIAL_TOKENS` (256) path. + fn explicit_backends(path: &Path) -> Vec { + vec![ + TiktokenTokenizer::new_riptoken(path).expect("load riptoken backend"), + TiktokenTokenizer::new_tiktoken_rs(path).expect("load tiktoken-rs backend"), + ] + } + + fn tiktoken_backends() -> (Vec, TempDir) { + let dir = tempfile::tempdir().expect("create temp dir"); + let path = write_synthetic_bpe_file(dir.path()); + (explicit_backends(&path), dir) + } + + /// Verify that tiktoken decode uses lossy UTF-8 (producing `\u{FFFD}`) + /// rather than returning an error for incomplete multi-byte sequences. + /// This is critical for streaming decode — `DecodeStream` relies on + /// `\u{FFFD}` to detect incomplete characters. + #[test] + fn tiktoken_decode_incomplete_utf8_produces_replacement_char() { + let (backends, _dir) = tiktoken_backends(); + + for backend in backends { + let ids = backend.encode("你", false).unwrap(); + let full = backend.decode(&ids, false).unwrap(); + assert_eq!(full, "你"); + + let text_with_multibyte = "Hello你好World"; + let all_ids = backend.encode(text_with_multibyte, false).unwrap(); + for &id in &all_ids { + let result = backend.decode(&[id], false); + assert!(result.is_ok(), "decode of token {id} should not error"); + } + } + } + + /// When `config.json` exposes a `vocab_size`, the reserved-token range must + /// be sized to it rather than to the 256-slot fallback. This is the + /// general (non-Kimi-specific) path: any tiktoken model whose own + /// `config.json` says e.g. `vocab_size = 280` should populate + /// reserved slots for `[num_base_tokens, 280)` and nothing beyond. + #[test] + fn tiktoken_reserved_range_uses_vocab_size_from_config_json() { + let dir = tempfile::tempdir().expect("create temp dir"); + let bpe_path = write_synthetic_bpe_file(dir.path()); + // num_base_tokens = 256, vocab_size = 280 → reserved range = [256, 280) (24 + // slots, smaller than the 256 fallback so we can prove the config value + // is honoured). + fs::write(dir.path().join("config.json"), r#"{"vocab_size": 280}"#) + .expect("write config.json"); + + for backend in explicit_backends(&bpe_path) { + // Inside the configured range: reserved placeholder, round-trips both ways. + let in_range_id: u32 = 270; + let placeholder = format!("<|reserved_token_{in_range_id}|>"); + assert_eq!(backend.decode(&[in_range_id], false).unwrap(), placeholder); + assert_eq!( + backend.encode(&placeholder, false).unwrap(), + vec![in_range_id] + ); + assert_eq!( + backend.id_to_token(in_range_id).as_deref(), + Some(placeholder.as_str()) + ); + + // Outside the configured range: not registered as a reserved slot — falls + // through to the backend's unknown-id behavior. The point is that we *don't* + // over-populate beyond what the model actually exposes. + let out_of_range_id: u32 = 290; + let out_of_range_placeholder = format!("<|reserved_token_{out_of_range_id}|>"); + assert_eq!(backend.decode(&[out_of_range_id], false).unwrap(), ""); + assert_eq!(backend.token_to_id(&out_of_range_placeholder), None); + assert_eq!(backend.id_to_token(out_of_range_id), None); + } + } + + /// Sparse/non-contiguous BPE ranks must still count as base-vocab ids. + /// + /// Regression shape: + /// - base vocabulary contains ids 0..=255 and also a normal BPE token at id 1000 + /// - if `num_base_tokens` were computed as `encoder.len()` (257), id 1000 would be + /// misclassified as special/reserved and disappear under `skip_special_tokens = true` + #[test] + fn tiktoken_sparse_base_ranks_are_not_misclassified_as_special() { + let dir = tempfile::tempdir().expect("create temp dir"); + let bpe_path = write_sparse_rank_bpe_file(dir.path()); + fs::write(dir.path().join("config.json"), r#"{"vocab_size": 1002}"#) + .expect("write config.json"); + + for backend in explicit_backends(&bpe_path) { + let sparse_id = backend.token_to_id("SPARSE"); + assert_eq!(sparse_id, Some(1000)); + assert_eq!(backend.id_to_token(1000).as_deref(), Some("SPARSE")); + assert!(!backend.is_special_id(1000)); + assert_eq!(backend.decode(&[1000], false).unwrap(), "SPARSE"); + assert_eq!(backend.decode(&[1000], true).unwrap(), "SPARSE"); + } + } + + /// `skip_special_tokens` must: + /// * keep regular BPE token text unchanged, + /// * drop ids whose `added_tokens_decoder` entry says `"special": true`, + /// * drop reserved-slot placeholder ids (which default to special), + /// * keep ids whose `added_tokens_decoder` entry says `"special": false` — this is how Kimi K2 + /// / K2.5 marks tool-call markers and `` / ``. + /// + /// Synthetic backend has `num_base_tokens = 256`. We write a + /// `tokenizer_config.json` that names ids 257 (special) and 258 + /// (non-special), and a `config.json` with `vocab_size` covering both. + /// Id 259 stays a default reserved placeholder (special). + #[test] + fn tiktoken_skip_special_tokens_filters_special_but_keeps_non_special_added_tokens() { + let dir = tempfile::tempdir().expect("create temp dir"); + let bpe_path = write_synthetic_bpe_file(dir.path()); + fs::write( + dir.path().join("tokenizer_config.json"), + r#"{ + "added_tokens_decoder": { + "257": { "content": "<|im_end|>", "special": true }, + "258": { "content": "<|tool_call_begin|>", "special": false } + } + }"#, + ) + .expect("write tokenizer_config.json"); + fs::write(dir.path().join("config.json"), r#"{"vocab_size": 260}"#) + .expect("write config.json"); + + for backend in explicit_backends(&bpe_path) { + // Resolve the BPE ids for "Hi" so we can interleave them with special-token + // ids. + let h = backend.encode("H", false).unwrap()[0]; + let i = backend.encode("i", false).unwrap()[0]; + + let special_id: u32 = 257; // <|im_end|> + let non_special_id: u32 = 258; // <|tool_call_begin|> + let reserved_id: u32 = 259; // default <|reserved_token_259|> placeholder + + let ids = vec![h, special_id, i, non_special_id, reserved_id]; + + // skip_special_tokens = false: everything is rendered as-is. + let kept = backend.decode(&ids, false).unwrap(); + assert_eq!( + kept, + "H<|im_end|>i<|tool_call_begin|><|reserved_token_259|>" + ); + + // skip_special_tokens = true: special token (257) and reserved placeholder + // (259) are dropped; the non-special added token (258) survives. + let stripped = backend.decode(&ids, true).unwrap(); + assert_eq!(stripped, "Hi<|tool_call_begin|>"); + } + } + + /// `vocab_size` may live under `text_config` for composite (e.g. + /// multimodal) configs. + #[test] + fn tiktoken_reserved_range_reads_text_config_vocab_size() { + let dir = tempfile::tempdir().expect("create temp dir"); + let bpe_path = write_synthetic_bpe_file(dir.path()); + fs::write( + dir.path().join("config.json"), + r#"{"text_config": {"vocab_size": 270}}"#, + ) + .expect("write config.json"); + + for backend in explicit_backends(&bpe_path) { + let in_range_id: u32 = 260; + let placeholder = format!("<|reserved_token_{in_range_id}|>"); + assert_eq!(backend.decode(&[in_range_id], false).unwrap(), placeholder); + + // Just outside the nested vocab_size — should not be registered. + assert_eq!(backend.decode(&[270], false).unwrap(), ""); + } + } + + #[test] + fn tiktoken_detects_kimi_pattern_from_model_type() { + let kimi = config_json!({ "model_type": "kimi_k25" }); + let baseten_kimi = config_json!({ "model_type": "deepseek_v3" }); + let nested_kimi = config_json!({ + "model_type": "composite_wrapper", + "text_config": { "model_type": "kimi_k2" } + }); + let generic = config_json!({ "model_type": "gpt2" }); + let nested_generic = config_json!({ + "model_type": "composite_wrapper", + "text_config": { "model_type": "gpt2" } + }); + let missing = config_json!({ "text_config": {} }); + + assert_eq!(detect_bpe_pattern(&kimi), KIMI_PATTERN); + assert_eq!(detect_bpe_pattern(&baseten_kimi), KIMI_PATTERN); + assert_eq!(detect_bpe_pattern(&nested_kimi), CL100K_BASE_PATTERN); + assert_eq!(detect_bpe_pattern(&generic), CL100K_BASE_PATTERN); + assert_eq!(detect_bpe_pattern(&nested_generic), CL100K_BASE_PATTERN); + assert_eq!(detect_bpe_pattern(&missing), CL100K_BASE_PATTERN); + } + + #[test] + fn tiktoken_reads_model_type_from_text_config_when_top_level_missing() { + let nested_only = config_json!({ + "text_config": { "model_type": "kimi_k2" } + }); + let direct_and_nested = config_json!({ + "model_type": "kimi_k25", + "text_config": { "model_type": "kimi_k2" } + }); + let missing = config_json!({ + "text_config": {} + }); + + assert_eq!(nested_only.effective_model_type(), Some("kimi_k2")); + assert_eq!(direct_and_nested.effective_model_type(), Some("kimi_k25")); + assert_eq!(missing.effective_model_type(), None); + } + + #[test] + fn tiktoken_tokenizer_config_models_added_tokens_decoder() { + let config: TiktokenTokenizerConfig = serde_json::from_value(serde_json::json!({ + "added_tokens_decoder": { + "257": { "content": "" }, + "258": { "content": "", "special": true } + } + })) + .unwrap(); + + let added_tokens = config.added_tokens_decoder; + assert_eq!(added_tokens.len(), 2); + assert_eq!( + added_tokens.get(&257).map(|t| t.content.as_str()), + Some("") + ); + assert_eq!(added_tokens.get(&257).map(|t| t.special), Some(false)); + assert_eq!( + added_tokens.get(&258).map(|t| (t.content.as_str(), t.special)), + Some(("", true)) + ); + } + + /// Reserved token ids in `[num_base_tokens, num_base_tokens + 256)` must + /// decode to their placeholder name (matching `tokenization_kimi.py`'s + /// `<|reserved_token_{i}|>` format), even when the source + /// `tokenizer_config.json` does not list them in `added_tokens_decoder`. + /// + /// In our synthetic backend `num_base_tokens = 256` (256 single-byte BPE + /// tokens), so the reserved range is `[256, 512)`. Picking id 300 — + /// well inside that range and absent from any `added_tokens_decoder` — + /// should round-trip both ways. + #[test] + fn tiktoken_reserved_token_round_trip() { + let (backends, _dir) = tiktoken_backends(); + + for backend in backends { + let reserved_id: u32 = 300; + let placeholder = format!("<|reserved_token_{reserved_id}|>"); + + let decoded = backend.decode(&[reserved_id], false).unwrap(); + assert_eq!(decoded, placeholder); + + // The placeholder name should also encode back to the same single id, since + // the constructor registers it as a special token with `CoreBPE`. + let encoded = backend.encode(&placeholder, false).unwrap(); + assert_eq!(encoded, vec![reserved_id]); + + assert_eq!(backend.token_to_id(&placeholder), Some(reserved_id)); + assert_eq!( + backend.id_to_token(reserved_id).as_deref(), + Some(placeholder.as_str()) + ); + } + } + + /// Decoding a token id that is beyond even the reserved range must not + /// panic — it falls through to the warn-and-skip backstop instead of + /// crashing the worker thread. + #[test] + fn tiktoken_rs_decode_unknown_token_id_does_not_panic() { + let dir = tempfile::tempdir().expect("create temp dir"); + let path = write_synthetic_bpe_file(dir.path()); + let backend = TiktokenTokenizer::new_tiktoken_rs(&path).expect("load tiktoken-rs backend"); + + // ID well above num_base_tokens (256) + reserved (256) = 512 — guaranteed + // unknown. + let unknown_id: u32 = 999_999; + let result = backend.decode(&[unknown_id], false); + assert_eq!(result.unwrap(), ""); + + // Mixed: known bytes for "Hi" surrounding an unknown id should yield just "Hi". + let h = backend.encode("H", false).unwrap()[0]; + let i = backend.encode("i", false).unwrap()[0]; + let result = backend.decode(&[h, unknown_id, i], false).unwrap(); + assert_eq!(result, "Hi"); + } + + #[test] + fn riptoken_decode_unknown_token_id_does_not_panic() { + let dir = tempfile::tempdir().expect("create temp dir"); + let path = write_synthetic_bpe_file(dir.path()); + let backend = TiktokenTokenizer::new_riptoken(&path).expect("load riptoken backend"); + + let unknown_id: u32 = 999_999; + assert_eq!(backend.decode(&[unknown_id], false).unwrap(), ""); + + let h = backend.encode("H", false).unwrap()[0]; + let i = backend.encode("i", false).unwrap()[0]; + assert_eq!(backend.decode(&[h, unknown_id, i], false).unwrap(), "Hi"); + } + + /// Streaming decode of CJK text through tiktoken should produce the + /// original text without errors, even though individual tokens may + /// represent partial UTF-8 byte sequences. + #[test] + fn tiktoken_streaming_decode_multibyte() { + let (backends, _dir) = tiktoken_backends(); + for backend in backends { + let text = "你好世界"; // 4 CJK characters + let ids = backend.encode(text, false).unwrap(); + + let mut decoder = backend.create_decode_stream(&[], false, 0); + let mut output = String::new(); + for &id in &ids { + decoder.push_token(id).unwrap(); + if let Some(chunk) = decoder.next_chunk() { + output.push_str(&chunk); + } + } + let (last_chunk, full_text) = decoder.flush(None).unwrap(); + if let Some(chunk) = last_chunk { + output.push_str(&chunk); + } + + assert_eq!(output, text); + assert_eq!(full_text, text); + } + } + + /// Mixed ASCII and multi-byte text should stream correctly through + /// tiktoken. + #[test] + fn tiktoken_streaming_decode_mixed_ascii_and_multibyte() { + let (backends, _dir) = tiktoken_backends(); + for backend in backends { + let text = "Hello 你好 World 🌍"; + let ids = backend.encode(text, false).unwrap(); + + let mut decoder = backend.create_decode_stream(&[], false, 0); + let mut output = String::new(); + for &id in &ids { + decoder.push_token(id).unwrap(); + if let Some(chunk) = decoder.next_chunk() { + output.push_str(&chunk); + } + } + let (last_chunk, full_text) = decoder.flush(None).unwrap(); + if let Some(chunk) = last_chunk { + output.push_str(&chunk); + } + + assert_eq!(output, text); + assert_eq!(full_text, text); + } + } + + #[test] + fn tiktoken_token_to_id_resolves_added_special_tokens() { + let dir = tempfile::tempdir().expect("create temp dir"); + let bpe_path = write_synthetic_bpe_file(dir.path()); + fs::write( + dir.path().join("tokenizer_config.json"), + r#"{ + "added_tokens_decoder": { + "257": { "content": "", "special": false }, + "258": { "content": "", "special": false } + } + }"#, + ) + .expect("write tokenizer_config.json"); + fs::write(dir.path().join("config.json"), r#"{"vocab_size": 259}"#) + .expect("write config.json"); + + for backend in explicit_backends(&bpe_path) { + assert_eq!(backend.token_to_id(""), Some(257)); + assert_eq!(backend.token_to_id(""), Some(258)); + assert_eq!(backend.id_to_token(257).as_deref(), Some("")); + assert_eq!(backend.id_to_token(258).as_deref(), Some("")); + assert_eq!( + backend.decode(&[257, 258], true).unwrap(), + "" + ); + } + } + + #[test] + fn riptoken_token_to_id_uses_encode_single_token_path() { + let dir = tempfile::tempdir().expect("create temp dir"); + let bpe_path = write_synthetic_bpe_file(dir.path()); + fs::write( + dir.path().join("tokenizer_config.json"), + r#"{ + "added_tokens_decoder": { + "257": { "content": "", "special": false } + } + }"#, + ) + .expect("write tokenizer_config.json"); + fs::write(dir.path().join("config.json"), r#"{"vocab_size": 258}"#) + .expect("write config.json"); + let backend = TiktokenTokenizer::new_riptoken(&bpe_path).expect("load riptoken backend"); + + assert_eq!(backend.token_to_id("H"), Some(b'H' as u32)); + assert_eq!(backend.id_to_token(b'H' as u32).as_deref(), Some("H")); + assert_eq!(backend.token_to_id(""), Some(257)); + assert_eq!(backend.id_to_token(257).as_deref(), Some("")); + } + + #[test] + fn tiktoken_rs_token_to_id_handles_unknown_special_like_text_without_panicking() { + let dir = tempfile::tempdir().expect("create temp dir"); + let path = write_synthetic_bpe_file(dir.path()); + let backend = TiktokenTokenizer::new_tiktoken_rs(&path).expect("load tiktoken-rs backend"); + + assert_eq!(backend.token_to_id("<|definitely_not_registered|>"), None); + } + + #[test] + fn riptoken_token_to_id_handles_unknown_special_like_text_without_panicking() { + let dir = tempfile::tempdir().expect("create temp dir"); + let path = write_synthetic_bpe_file(dir.path()); + let backend = TiktokenTokenizer::new_riptoken(&path).expect("load riptoken backend"); + + assert_eq!(backend.token_to_id("<|definitely_not_registered|>"), None); + } +} diff --git a/rust/src/tool-parser/Cargo.toml b/rust/src/tool-parser/Cargo.toml new file mode 100644 index 00000000000..e02c397cabf --- /dev/null +++ b/rust/src/tool-parser/Cargo.toml @@ -0,0 +1,75 @@ +[package] +name = "vllm-tool-parser" +version.workspace = true +edition.workspace = true +license.workspace = true + +[features] +test-util = [] + +[dependencies] +serde.workspace = true +serde_json.workspace = true +thiserror.workspace = true +thiserror-ext.workspace = true +winnow.workspace = true + +[dev-dependencies] +criterion.workspace = true +expect-test.workspace = true +futures.workspace = true +openai-protocol.workspace = true +tool-parser.workspace = true + +[[bench]] +name = "deepseek_v3" +harness = false +required-features = ["test-util"] + +[[bench]] +name = "deepseek_v31" +harness = false +required-features = ["test-util"] + +[[bench]] +name = "deepseek_v32" +harness = false +required-features = ["test-util"] + +[[bench]] +name = "qwen3_coder" +harness = false +required-features = ["test-util"] + +[[bench]] +name = "qwen3_xml" +harness = false +required-features = ["test-util"] + +[[bench]] +name = "llama3_json" +harness = false +required-features = ["test-util"] + +[[bench]] +name = "minimax_m2" +harness = false +required-features = ["test-util"] + +[[bench]] +name = "glm45_moe" +harness = false +required-features = ["test-util"] + +[[bench]] +name = "kimi_k2" +harness = false +required-features = ["test-util"] + +[[bench]] +name = "gemma4" +harness = false +required-features = ["test-util"] + +[lints] +workspace = true diff --git a/rust/src/tool-parser/benches/deepseek_v3.rs b/rust/src/tool-parser/benches/deepseek_v3.rs new file mode 100644 index 00000000000..75d2e417ace --- /dev/null +++ b/rust/src/tool-parser/benches/deepseek_v3.rs @@ -0,0 +1,130 @@ +use std::time::Duration; + +use criterion::{BatchSize, Criterion, Throughput, black_box, criterion_group, criterion_main}; +use tool_parser::parsers::DeepSeekParser as ExternalDeepSeekParser; +use vllm_tool_parser::test_utils::{split_by_chars, test_tools}; +use vllm_tool_parser::{DeepSeekV3ToolParser, Tool, ToolParser}; + +mod utils; +use utils::{feed_external_parser, feed_parser, openai_tools}; + +const CHUNK_CHARS: usize = 7; +const LONG_NORMAL_TEXT_REPEATS: usize = 2048; + +fn mixed_fixture() -> String { + concat!( + "I will check two cities before answering.\n", + "<|tool▁calls▁begin|>", + "<|tool▁call▁begin|>function<|tool▁sep|>get_weather\n", + "```json\n", + "{\"location\":\"Hangzhou\",\"days\":3}", + "\n```<|tool▁call▁end|>", + "<|tool▁call▁begin|>function<|tool▁sep|>get_weather\n", + "```json\n", + "{\"location\":\"San Francisco\",\"days\":2}", + "\n```<|tool▁call▁end|>", + "<|tool▁calls▁end|>", + ) + .to_string() +} + +fn long_normal_text_fixture() -> String { + let line = "This is ordinary assistant text with no DeepSeek V3 tool markers at all.\n"; + line.repeat(LONG_NORMAL_TEXT_REPEATS) +} + +fn native_parser(tools: &[Tool]) -> Box { + DeepSeekV3ToolParser::create(tools).expect("DeepSeek V3 parser should initialize") +} + +fn run_stream_group( + c: &mut Criterion, + name: &str, + tools: &[Tool], + text: &str, + chunk_chars: usize, + expected_normal_text: &str, + expected_native_calls_len: usize, +) { + let chunks = split_by_chars(text, chunk_chars); + let openai_tools = openai_tools(tools); + + let mut group = c.benchmark_group(name); + group.sample_size(50); + group.warm_up_time(Duration::from_millis(300)); + group.measurement_time(Duration::from_secs(2)); + group.throughput(Throughput::Bytes(text.len() as u64)); + + group.bench_function("native_reuse_parser", |b| { + let mut parser = native_parser(tools); + b.iter(|| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }) + }); + + group.bench_function("native_create_parser", |b| { + b.iter_batched( + || native_parser(tools), + |mut parser| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.bench_function("external_reuse_parser", |b| { + let mut parser = ExternalDeepSeekParser::new(); + b.iter(|| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(&chunks)); + black_box(result); + }) + }); + + group.bench_function("external_create_parser", |b| { + b.iter_batched( + ExternalDeepSeekParser::new, + |mut parser| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(&chunks)); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +fn bench_deepseek_v3(c: &mut Criterion) { + let tools = test_tools(); + let mixed_text = mixed_fixture(); + let long_normal_text = long_normal_text_fixture(); + + run_stream_group( + c, + "deepseek_v3/mixed_text_tool_call", + &tools, + &mixed_text, + CHUNK_CHARS, + "I will check two cities before answering.\n", + 2, + ); + + run_stream_group( + c, + "deepseek_v3/long_normal_text", + &tools, + &long_normal_text, + CHUNK_CHARS, + &long_normal_text, + 0, + ); +} + +criterion_group!(benches, bench_deepseek_v3); +criterion_main!(benches); diff --git a/rust/src/tool-parser/benches/deepseek_v31.rs b/rust/src/tool-parser/benches/deepseek_v31.rs new file mode 100644 index 00000000000..bb6d029baff --- /dev/null +++ b/rust/src/tool-parser/benches/deepseek_v31.rs @@ -0,0 +1,128 @@ +use std::time::Duration; + +use criterion::{BatchSize, Criterion, Throughput, black_box, criterion_group, criterion_main}; +use tool_parser::parsers::DeepSeek31Parser as ExternalDeepSeek31Parser; +use vllm_tool_parser::test_utils::{split_by_chars, test_tools}; +use vllm_tool_parser::{DeepSeekV31ToolParser, Tool, ToolParser}; + +mod utils; +use utils::{feed_external_parser, feed_parser, openai_tools}; + +const CHUNK_CHARS: usize = 7; +const LONG_NORMAL_TEXT_REPEATS: usize = 2048; + +fn mixed_fixture() -> String { + concat!( + "I will check two cities before answering.\n", + "<|tool▁calls▁begin|>", + "<|tool▁call▁begin|>get_weather<|tool▁sep|>", + "{\"location\":\"Hangzhou\",\"days\":3}", + "<|tool▁call▁end|>", + "<|tool▁call▁begin|>get_weather<|tool▁sep|>", + "{\"location\":\"San Francisco\",\"days\":2}", + "<|tool▁call▁end|>", + "<|tool▁calls▁end|>", + ) + .to_string() +} + +fn long_normal_text_fixture() -> String { + let line = "This is ordinary assistant text with no DeepSeek V3.1 tool markers at all.\n"; + line.repeat(LONG_NORMAL_TEXT_REPEATS) +} + +fn native_parser(tools: &[Tool]) -> Box { + DeepSeekV31ToolParser::create(tools).expect("DeepSeek V3.1 parser should initialize") +} + +fn run_stream_group( + c: &mut Criterion, + name: &str, + tools: &[Tool], + text: &str, + chunk_chars: usize, + expected_normal_text: &str, + expected_native_calls_len: usize, +) { + let chunks = split_by_chars(text, chunk_chars); + let openai_tools = openai_tools(tools); + + let mut group = c.benchmark_group(name); + group.sample_size(50); + group.warm_up_time(Duration::from_millis(300)); + group.measurement_time(Duration::from_secs(2)); + group.throughput(Throughput::Bytes(text.len() as u64)); + + group.bench_function("native_reuse_parser", |b| { + let mut parser = native_parser(tools); + b.iter(|| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }) + }); + + group.bench_function("native_create_parser", |b| { + b.iter_batched( + || native_parser(tools), + |mut parser| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.bench_function("external_reuse_parser", |b| { + let mut parser = ExternalDeepSeek31Parser::new(); + b.iter(|| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(&chunks)); + black_box(result); + }) + }); + + group.bench_function("external_create_parser", |b| { + b.iter_batched( + ExternalDeepSeek31Parser::new, + |mut parser| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(&chunks)); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +fn bench_deepseek_v31(c: &mut Criterion) { + let tools = test_tools(); + let mixed_text = mixed_fixture(); + let long_normal_text = long_normal_text_fixture(); + + run_stream_group( + c, + "deepseek_v31/mixed_text_tool_call", + &tools, + &mixed_text, + CHUNK_CHARS, + "I will check two cities before answering.\n", + 2, + ); + + run_stream_group( + c, + "deepseek_v31/long_normal_text", + &tools, + &long_normal_text, + CHUNK_CHARS, + &long_normal_text, + 0, + ); +} + +criterion_group!(benches, bench_deepseek_v31); +criterion_main!(benches); diff --git a/rust/src/tool-parser/benches/deepseek_v32.rs b/rust/src/tool-parser/benches/deepseek_v32.rs new file mode 100644 index 00000000000..c7a8346120d --- /dev/null +++ b/rust/src/tool-parser/benches/deepseek_v32.rs @@ -0,0 +1,113 @@ +use std::time::Duration; + +use criterion::{BatchSize, Criterion, Throughput, black_box, criterion_group, criterion_main}; +use vllm_tool_parser::test_utils::{split_by_chars, test_tools}; +use vllm_tool_parser::{DeepSeekV32ToolParser, Tool, ToolParser}; + +mod utils; +use utils::feed_parser; + +const CHUNK_CHARS: usize = 7; +const LONG_NORMAL_TEXT_REPEATS: usize = 2048; + +fn mixed_fixture() -> String { + concat!( + "I will check two cities before answering.\n", + "<|DSML|function_calls>\n", + "<|DSML|invoke name=\"get_weather\">\n", + "<|DSML|parameter name=\"location\" string=\"true\">Hangzhou\n", + "<|DSML|parameter name=\"date\" string=\"true\">2026-04-28\n", + "<|DSML|parameter name=\"unit\" string=\"true\">celsius\n", + "<|DSML|parameter name=\"days\" string=\"false\">3\n", + "\n", + "<|DSML|invoke name=\"get_weather\">\n", + "<|DSML|parameter name=\"location\" string=\"true\">San Francisco\n", + "<|DSML|parameter name=\"date\" string=\"true\">2026-04-28\n", + "<|DSML|parameter name=\"unit\" string=\"true\">fahrenheit\n", + "<|DSML|parameter name=\"days\" string=\"false\">2\n", + "\n", + "", + ) + .to_string() +} + +fn long_normal_text_fixture() -> String { + let line = "This is ordinary assistant text with no DSML tool markers at all.\n"; + line.repeat(LONG_NORMAL_TEXT_REPEATS) +} + +fn parser(tools: &[Tool]) -> Box { + DeepSeekV32ToolParser::create(tools).expect("DeepSeek V3.2 parser should initialize") +} + +fn run_stream_group( + c: &mut Criterion, + name: &str, + tools: &[Tool], + text: &str, + chunk_chars: usize, + expected_normal_text: &str, + expected_calls_len: usize, +) { + let chunks = split_by_chars(text, chunk_chars); + + let mut group = c.benchmark_group(name); + group.sample_size(50); + group.warm_up_time(Duration::from_millis(300)); + group.measurement_time(Duration::from_secs(2)); + group.throughput(Throughput::Bytes(text.len() as u64)); + + group.bench_function("reuse_parser", |b| { + let mut parser = parser(tools); + b.iter(|| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_calls_len); + black_box(result); + }) + }); + + group.bench_function("create_parser", |b| { + b.iter_batched( + || parser(tools), + |mut parser| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_calls_len); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +fn bench_deepseek_v32(c: &mut Criterion) { + let tools = test_tools(); + let mixed_text = mixed_fixture(); + let long_normal_text = long_normal_text_fixture(); + + run_stream_group( + c, + "deepseek_v32/mixed_text_tool_call", + &tools, + &mixed_text, + CHUNK_CHARS, + "I will check two cities before answering.\n", + 2, + ); + + run_stream_group( + c, + "deepseek_v32/long_normal_text", + &tools, + &long_normal_text, + CHUNK_CHARS, + &long_normal_text, + 0, + ); +} + +criterion_group!(benches, bench_deepseek_v32); +criterion_main!(benches); diff --git a/rust/src/tool-parser/benches/gemma4.rs b/rust/src/tool-parser/benches/gemma4.rs new file mode 100644 index 00000000000..f638c14733a --- /dev/null +++ b/rust/src/tool-parser/benches/gemma4.rs @@ -0,0 +1,125 @@ +use std::time::Duration; + +use criterion::{BatchSize, Criterion, Throughput, black_box, criterion_group, criterion_main}; +use vllm_tool_parser::test_utils::{split_by_chars, test_tools}; +use vllm_tool_parser::{Gemma4ToolParser, Tool, ToolParser}; + +mod utils; +use utils::feed_parser; + +const CHUNK_CHARS: usize = 7; +const LONG_NORMAL_TEXT_REPEATS: usize = 2048; + +fn mixed_fixture() -> String { + concat!( + "I will inspect the data before answering.\n", + "<|tool_call>", + "call:convert{", + "whole:114.514,", + "flag:true,", + "empty:<|\"|><|\"|>,", + "payload:{", + "name:<|\"|>demo<|\"|>,", + "count:42,", + "enabled:false,", + "missing:null,", + "nested:{level:2,label:<|\"|>deep<|\"|>},", + "tags:[<|\"|>red<|\"|>,<|\"|>blue<|\"|>,3,true,null,{kind:<|\"|>leaf<|\"|>}]", + "},", + "items:[", + "<|\"|>alpha<|\"|>,", + "{key:<|\"|>value<|\"|>,score:0.75},", + "[1,2,3]", + "]", + "}", + "", + "<|tool_call>", + "call:update_record{", + "data:{id:7,active:true,notes:[<|\"|>keep<|\"|>,<|\"|>review<|\"|>]}", + "}", + "", + " Finished.", + ) + .to_string() +} + +fn long_normal_text_fixture() -> String { + let line = "This is ordinary assistant text with no Gemma4 tool markers at all.\n"; + line.repeat(LONG_NORMAL_TEXT_REPEATS) +} + +fn parser(tools: &[Tool]) -> Box { + Gemma4ToolParser::create(tools).expect("Gemma4 parser should initialize") +} + +fn run_stream_group( + c: &mut Criterion, + name: &str, + tools: &[Tool], + text: &str, + chunk_chars: usize, + expected_normal_text: &str, + expected_calls_len: usize, +) { + let chunks = split_by_chars(text, chunk_chars); + + let mut group = c.benchmark_group(name); + group.sample_size(50); + group.warm_up_time(Duration::from_millis(300)); + group.measurement_time(Duration::from_secs(2)); + group.throughput(Throughput::Bytes(text.len() as u64)); + + group.bench_function("reuse_parser", |b| { + let mut parser = parser(tools); + b.iter(|| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_calls_len); + black_box(result); + }) + }); + + group.bench_function("create_parser", |b| { + b.iter_batched( + || parser(tools), + |mut parser| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_calls_len); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +fn bench_gemma4(c: &mut Criterion) { + let tools = test_tools(); + let mixed_text = mixed_fixture(); + let long_normal_text = long_normal_text_fixture(); + + run_stream_group( + c, + "gemma4/mixed_complex_tool_call", + &tools, + &mixed_text, + CHUNK_CHARS, + "I will inspect the data before answering.\n Finished.", + 2, + ); + + run_stream_group( + c, + "gemma4/long_normal_text", + &tools, + &long_normal_text, + CHUNK_CHARS, + &long_normal_text, + 0, + ); +} + +criterion_group!(benches, bench_gemma4); +criterion_main!(benches); diff --git a/rust/src/tool-parser/benches/glm45_moe.rs b/rust/src/tool-parser/benches/glm45_moe.rs new file mode 100644 index 00000000000..8486885eceb --- /dev/null +++ b/rust/src/tool-parser/benches/glm45_moe.rs @@ -0,0 +1,210 @@ +use std::time::Duration; + +use criterion::{BatchSize, Criterion, Throughput, black_box, criterion_group, criterion_main}; +use tool_parser::parsers::Glm4MoeParser as ExternalGlm4MoeParser; +use vllm_tool_parser::test_utils::{split_by_chars, test_tools}; +use vllm_tool_parser::{Glm45MoeToolParser, Glm47MoeToolParser, Tool, ToolParser}; + +mod utils; +use utils::{feed_external_parser, feed_parser, openai_tools}; + +const GLM45_PARSER_NAME: &str = "glm45"; +const GLM47_PARSER_NAME: &str = "glm47"; +const CHUNK_CHARS: usize = 7; +const LONG_NORMAL_TEXT_REPEATS: usize = 2048; + +fn glm45_mixed_fixture() -> String { + concat!( + "I will check two cities before answering.\n", + "get_weather\n", + "city\n", + "Hangzhou\n", + "date\n", + "2026-05-07\n", + "unit\n", + "celsius\n", + "days\n", + "3\n", + "\n", + "get_weather\n", + "city\n", + "San Francisco\n", + "date\n", + "2026-05-07\n", + "unit\n", + "fahrenheit\n", + "days\n", + "2\n", + "", + ) + .to_string() +} + +fn glm47_mixed_fixture() -> String { + concat!( + "I will check two cities before answering.\n", + "get_weather", + "city", + "Hangzhou", + "date", + "2026-05-07", + "unit", + "celsius", + "days", + "3", + "", + "get_weather", + "city", + "San Francisco", + "date", + "2026-05-07", + "unit", + "fahrenheit", + "days", + "2", + "", + ) + .to_string() +} + +fn long_normal_text_fixture() -> String { + let line = "This is ordinary assistant text with no GLM MoE tool markers at all.\n"; + line.repeat(LONG_NORMAL_TEXT_REPEATS) +} + +fn native_parser(name: &str, tools: &[Tool]) -> Box { + match name { + GLM45_PARSER_NAME => Glm45MoeToolParser::create(tools), + GLM47_PARSER_NAME => Glm47MoeToolParser::create(tools), + _ => unreachable!("unexpected GLM parser name"), + } + .expect("GLM MoE parser should initialize") +} + +fn external_parser(name: &str) -> ExternalGlm4MoeParser { + match name { + GLM45_PARSER_NAME => ExternalGlm4MoeParser::glm45(), + GLM47_PARSER_NAME => ExternalGlm4MoeParser::glm47(), + _ => unreachable!("unexpected GLM parser name"), + } +} + +fn run_stream_group( + c: &mut Criterion, + name: &str, + parser_name: &str, + tools: &[Tool], + text: &str, + chunk_chars: usize, + expected_normal_text: &str, + expected_native_calls_len: usize, +) { + let chunks = split_by_chars(text, chunk_chars); + let openai_tools = openai_tools(tools); + + let mut group = c.benchmark_group(name); + group.sample_size(50); + group.warm_up_time(Duration::from_millis(300)); + group.measurement_time(Duration::from_secs(2)); + group.throughput(Throughput::Bytes(text.len() as u64)); + + group.bench_function("native_reuse_parser", |b| { + let mut parser = native_parser(parser_name, tools); + b.iter(|| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }) + }); + + group.bench_function("native_create_parser", |b| { + b.iter_batched( + || native_parser(parser_name, tools), + |mut parser| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.bench_function("external_reuse_parser", |b| { + let mut parser = external_parser(parser_name); + b.iter(|| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + black_box(result); + }) + }); + + group.bench_function("external_create_parser", |b| { + b.iter_batched( + || external_parser(parser_name), + |mut parser| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +fn bench_glm45_moe(c: &mut Criterion) { + let tools = test_tools(); + let glm45_mixed_text = glm45_mixed_fixture(); + let glm47_mixed_text = glm47_mixed_fixture(); + let long_normal_text = long_normal_text_fixture(); + + run_stream_group( + c, + "glm45/mixed_text_tool_call", + GLM45_PARSER_NAME, + &tools, + &glm45_mixed_text, + CHUNK_CHARS, + "I will check two cities before answering.\n", + 2, + ); + + run_stream_group( + c, + "glm47/mixed_text_tool_call", + GLM47_PARSER_NAME, + &tools, + &glm47_mixed_text, + CHUNK_CHARS, + "I will check two cities before answering.\n", + 2, + ); + + run_stream_group( + c, + "glm45/long_normal_text", + GLM45_PARSER_NAME, + &tools, + &long_normal_text, + CHUNK_CHARS, + &long_normal_text, + 0, + ); + + run_stream_group( + c, + "glm47/long_normal_text", + GLM47_PARSER_NAME, + &tools, + &long_normal_text, + CHUNK_CHARS, + &long_normal_text, + 0, + ); +} + +criterion_group!(benches, bench_glm45_moe); +criterion_main!(benches); diff --git a/rust/src/tool-parser/benches/kimi_k2.rs b/rust/src/tool-parser/benches/kimi_k2.rs new file mode 100644 index 00000000000..5a80f660673 --- /dev/null +++ b/rust/src/tool-parser/benches/kimi_k2.rs @@ -0,0 +1,149 @@ +use std::time::Duration; + +use criterion::{BatchSize, Criterion, Throughput, black_box, criterion_group, criterion_main}; +use tool_parser::parsers::KimiK2Parser as ExternalKimiK2Parser; +use vllm_tool_parser::test_utils::{split_by_chars, test_tools}; +use vllm_tool_parser::{KimiK2ToolParser, Tool, ToolParser}; + +mod utils; +use utils::{feed_external_parser, feed_parser, openai_tools}; + +const CHUNK_CHARS: usize = 7; +const LONG_NORMAL_TEXT_REPEATS: usize = 2048; + +fn mixed_fixture() -> String { + concat!( + "I will check two cities before answering.\n", + "<|tool_calls_section_begin|>", + "<|tool_call_begin|>functions.get_weather:0", + "<|tool_call_argument_begin|>{\"location\":\"Hangzhou\",\"days\":3}", + "<|tool_call_end|>", + "<|tool_call_begin|>functions.get_weather:1", + "<|tool_call_argument_begin|>{\"location\":\"San Francisco\",\"days\":2}", + "<|tool_call_end|>", + "<|tool_calls_section_end|>", + ) + .to_string() +} + +fn mixed_chunks() -> Vec<&'static str> { + vec![ + "I will check two cities before answering.\n", + "<|tool_calls_section_begin|>", + "<|tool_call_begin|>functions.get_weather:0", + "<|tool_call_argument_begin|>", + "{\"location\":", + "\"Hangzhou\",", + "\"days\":3}", + "<|tool_call_end|>", + "<|tool_call_begin|>functions.get_weather:1", + "<|tool_call_argument_begin|>", + "{\"location\":", + "\"San Francisco\",", + "\"days\":2}", + "<|tool_call_end|>", + "<|tool_calls_section_end|>", + ] +} + +fn long_normal_text_fixture() -> String { + let line = "This is ordinary assistant text with no Kimi K2 tool markers at all.\n"; + line.repeat(LONG_NORMAL_TEXT_REPEATS) +} + +fn native_parser(tools: &[Tool]) -> Box { + KimiK2ToolParser::create(tools).expect("Kimi K2 parser should initialize") +} + +fn run_stream_group( + c: &mut Criterion, + name: &str, + tools: &[Tool], + text: &str, + chunks: &[&str], + expected_normal_text: &str, + expected_native_calls_len: usize, +) { + let openai_tools = openai_tools(tools); + + let mut group = c.benchmark_group(name); + group.sample_size(50); + group.warm_up_time(Duration::from_millis(300)); + group.measurement_time(Duration::from_secs(2)); + group.throughput(Throughput::Bytes(text.len() as u64)); + + group.bench_function("native_reuse_parser", |b| { + let mut parser = native_parser(tools); + b.iter(|| { + let result = feed_parser(&mut *parser, black_box(chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }) + }); + + group.bench_function("native_create_parser", |b| { + b.iter_batched( + || native_parser(tools), + |mut parser| { + let result = feed_parser(&mut *parser, black_box(chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.bench_function("external_reuse_parser", |b| { + let mut parser = ExternalKimiK2Parser::new(); + b.iter(|| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(chunks)); + black_box(result); + }) + }); + + group.bench_function("external_create_parser", |b| { + b.iter_batched( + ExternalKimiK2Parser::new, + |mut parser| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(chunks)); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +fn bench_kimi_k2(c: &mut Criterion) { + let tools = test_tools(); + let mixed_text = mixed_fixture(); + let mixed_chunks = mixed_chunks(); + let long_normal_text = long_normal_text_fixture(); + let long_normal_chunks = split_by_chars(&long_normal_text, CHUNK_CHARS); + + run_stream_group( + c, + "kimi_k2/mixed_text_tool_call", + &tools, + &mixed_text, + &mixed_chunks, + "I will check two cities before answering.\n", + 2, + ); + + run_stream_group( + c, + "kimi_k2/long_normal_text", + &tools, + &long_normal_text, + &long_normal_chunks, + &long_normal_text, + 0, + ); +} + +criterion_group!(benches, bench_kimi_k2); +criterion_main!(benches); diff --git a/rust/src/tool-parser/benches/llama3_json.rs b/rust/src/tool-parser/benches/llama3_json.rs new file mode 100644 index 00000000000..03b5b54ee78 --- /dev/null +++ b/rust/src/tool-parser/benches/llama3_json.rs @@ -0,0 +1,128 @@ +use std::time::Duration; + +use criterion::{BatchSize, Criterion, Throughput, black_box, criterion_group, criterion_main}; +use tool_parser::parsers::LlamaParser as ExternalLlamaParser; +use vllm_tool_parser::test_utils::{split_by_chars, test_tools}; +use vllm_tool_parser::{Llama3JsonToolParser, Tool, ToolParser}; + +mod utils; +use utils::{feed_external_parser, feed_parser, openai_tools}; + +const CHUNK_CHARS: usize = 7; +const LONG_NORMAL_TEXT_REPEATS: usize = 2048; + +fn tool_call(function_name: &str, parameters: &str) -> String { + format!(r#"{{"name":"{function_name}","parameters":{parameters}}}"#) +} + +fn mixed_fixture() -> String { + format!( + "{}; {}", + tool_call("get_weather", r#"{"location":"Hangzhou","days":3}"#), + tool_call( + "convert", + r#"{"whole":42.5,"flag":true,"payload":{"nested":["x",null]},"items":[1,2,3],"empty":""}"# + ), + ) +} + +fn long_normal_text_fixture() -> String { + let line = "This is ordinary assistant text with no Llama JSON tool call at the root.\n"; + line.repeat(LONG_NORMAL_TEXT_REPEATS) +} + +fn native_parser(tools: &[Tool]) -> Box { + Llama3JsonToolParser::create(tools).expect("Llama JSON parser should initialize") +} + +fn run_stream_group( + c: &mut Criterion, + name: &str, + tools: &[Tool], + text: &str, + chunk_chars: usize, + expected_normal_text: &str, + expected_native_calls_len: usize, +) { + let chunks = split_by_chars(text, chunk_chars); + let openai_tools = openai_tools(tools); + + let mut group = c.benchmark_group(name); + group.sample_size(50); + group.warm_up_time(Duration::from_millis(300)); + group.measurement_time(Duration::from_secs(2)); + group.throughput(Throughput::Bytes(text.len() as u64)); + + group.bench_function("native_reuse_parser", |b| { + let mut parser = native_parser(tools); + b.iter(|| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }) + }); + + group.bench_function("native_create_parser", |b| { + b.iter_batched( + || native_parser(tools), + |mut parser| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.bench_function("external_reuse_parser", |b| { + let mut parser = ExternalLlamaParser::new(); + b.iter(|| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(&chunks)); + black_box(result); + }) + }); + + group.bench_function("external_create_parser", |b| { + b.iter_batched( + ExternalLlamaParser::new, + |mut parser| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(&chunks)); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +fn bench_llama3_json(c: &mut Criterion) { + let tools = test_tools(); + let mixed_text = mixed_fixture(); + let long_normal_text = long_normal_text_fixture(); + + run_stream_group( + c, + "llama3_json/mixed_text_tool_call", + &tools, + &mixed_text, + CHUNK_CHARS, + "", + 2, + ); + + run_stream_group( + c, + "llama3_json/long_normal_text", + &tools, + &long_normal_text, + CHUNK_CHARS, + &long_normal_text, + 0, + ); +} + +criterion_group!(benches, bench_llama3_json); +criterion_main!(benches); diff --git a/rust/src/tool-parser/benches/minimax_m2.rs b/rust/src/tool-parser/benches/minimax_m2.rs new file mode 100644 index 00000000000..4ad20400934 --- /dev/null +++ b/rust/src/tool-parser/benches/minimax_m2.rs @@ -0,0 +1,136 @@ +use std::time::Duration; + +use criterion::{BatchSize, Criterion, Throughput, black_box, criterion_group, criterion_main}; +use tool_parser::parsers::MinimaxM2Parser as ExternalMinimaxM2Parser; +use vllm_tool_parser::test_utils::{split_by_chars, test_tools}; +use vllm_tool_parser::{MinimaxM2ToolParser, Tool, ToolParser}; + +mod utils; +use utils::{feed_external_parser, feed_parser, openai_tools}; + +const CHUNK_CHARS: usize = 7; +const LONG_NORMAL_TEXT_REPEATS: usize = 2048; + +fn mixed_fixture() -> String { + concat!( + "I will check two cities before answering.\n", + "", + "", + "Hangzhou", + "2026-04-30", + "celsius", + "3", + "", + "", + "San Francisco", + "2026-04-30", + "fahrenheit", + "2", + "", + "", + ) + .to_string() +} + +fn long_normal_text_fixture() -> String { + let line = "This is ordinary assistant text with no MiniMax M2 tool markers at all.\n"; + line.repeat(LONG_NORMAL_TEXT_REPEATS) +} + +fn native_parser(tools: &[Tool]) -> Box { + MinimaxM2ToolParser::create(tools).expect("MiniMax M2 parser should initialize") +} + +fn run_stream_group( + c: &mut Criterion, + name: &str, + tools: &[Tool], + text: &str, + chunk_chars: usize, + expected_normal_text: &str, + expected_native_calls_len: usize, +) { + let chunks = split_by_chars(text, chunk_chars); + let openai_tools = openai_tools(tools); + + let mut group = c.benchmark_group(name); + group.sample_size(50); + group.warm_up_time(Duration::from_millis(300)); + group.measurement_time(Duration::from_secs(2)); + group.throughput(Throughput::Bytes(text.len() as u64)); + + group.bench_function("native_reuse_parser", |b| { + let mut parser = native_parser(tools); + b.iter(|| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }) + }); + + group.bench_function("native_create_parser", |b| { + b.iter_batched( + || native_parser(tools), + |mut parser| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.bench_function("external_reuse_parser", |b| { + let mut parser = ExternalMinimaxM2Parser::new(); + b.iter(|| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + black_box(result); + }) + }); + + group.bench_function("external_create_parser", |b| { + b.iter_batched( + ExternalMinimaxM2Parser::new, + |mut parser| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +fn bench_minimax_m2(c: &mut Criterion) { + let tools = test_tools(); + let mixed_text = mixed_fixture(); + let long_normal_text = long_normal_text_fixture(); + + run_stream_group( + c, + "minimax_m2/mixed_text_tool_call", + &tools, + &mixed_text, + CHUNK_CHARS, + "I will check two cities before answering.\n", + 2, + ); + + run_stream_group( + c, + "minimax_m2/long_normal_text", + &tools, + &long_normal_text, + CHUNK_CHARS, + &long_normal_text, + 0, + ); +} + +criterion_group!(benches, bench_minimax_m2); +criterion_main!(benches); diff --git a/rust/src/tool-parser/benches/qwen3_coder.rs b/rust/src/tool-parser/benches/qwen3_coder.rs new file mode 100644 index 00000000000..850badaac52 --- /dev/null +++ b/rust/src/tool-parser/benches/qwen3_coder.rs @@ -0,0 +1,138 @@ +use std::time::Duration; + +use criterion::{BatchSize, Criterion, Throughput, black_box, criterion_group, criterion_main}; +use tool_parser::parsers::QwenCoderParser as ExternalQwenCoderParser; +use vllm_tool_parser::test_utils::{split_by_chars, test_tools}; +use vllm_tool_parser::{Qwen3CoderToolParser, Tool, ToolParser}; + +mod utils; +use utils::{feed_external_parser, feed_parser, openai_tools}; + +const CHUNK_CHARS: usize = 7; +const LONG_NORMAL_TEXT_REPEATS: usize = 2048; + +fn mixed_fixture() -> String { + concat!( + "I will check two cities before answering.\n", + "\n", + "\n", + "Hangzhou\n", + "2026-04-29\n", + "celsius\n", + "3\n", + "\n", + "\n", + "\n", + "\n", + "San Francisco\n", + "2026-04-29\n", + "fahrenheit\n", + "2\n", + "\n", + "", + ) + .to_string() +} + +fn long_normal_text_fixture() -> String { + let line = "This is ordinary assistant text with no Qwen Coder tool markers at all.\n"; + line.repeat(LONG_NORMAL_TEXT_REPEATS) +} + +fn native_parser(tools: &[Tool]) -> Box { + Qwen3CoderToolParser::create(tools).expect("Qwen Coder parser should initialize") +} + +fn run_stream_group( + c: &mut Criterion, + name: &str, + tools: &[Tool], + text: &str, + chunk_chars: usize, + expected_normal_text: &str, + expected_native_calls_len: usize, +) { + let chunks = split_by_chars(text, chunk_chars); + let openai_tools = openai_tools(tools); + + let mut group = c.benchmark_group(name); + group.sample_size(50); + group.warm_up_time(Duration::from_millis(300)); + group.measurement_time(Duration::from_secs(2)); + group.throughput(Throughput::Bytes(text.len() as u64)); + + group.bench_function("native_reuse_parser", |b| { + let mut parser = native_parser(tools); + b.iter(|| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }) + }); + + group.bench_function("native_create_parser", |b| { + b.iter_batched( + || native_parser(tools), + |mut parser| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.bench_function("external_reuse_parser", |b| { + let mut parser = ExternalQwenCoderParser::new(); + b.iter(|| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + black_box(result); + }) + }); + + group.bench_function("external_create_parser", |b| { + b.iter_batched( + ExternalQwenCoderParser::new, + |mut parser| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +fn bench_qwen3_coder(c: &mut Criterion) { + let tools = test_tools(); + let mixed_text = mixed_fixture(); + let long_normal_text = long_normal_text_fixture(); + + run_stream_group( + c, + "qwen3_coder/mixed_text_tool_call", + &tools, + &mixed_text, + CHUNK_CHARS, + "I will check two cities before answering.\n", + 2, + ); + + run_stream_group( + c, + "qwen3_coder/long_normal_text", + &tools, + &long_normal_text, + CHUNK_CHARS, + &long_normal_text, + 0, + ); +} + +criterion_group!(benches, bench_qwen3_coder); +criterion_main!(benches); diff --git a/rust/src/tool-parser/benches/qwen3_xml.rs b/rust/src/tool-parser/benches/qwen3_xml.rs new file mode 100644 index 00000000000..f2e37551dda --- /dev/null +++ b/rust/src/tool-parser/benches/qwen3_xml.rs @@ -0,0 +1,125 @@ +use std::time::Duration; + +use criterion::{BatchSize, Criterion, Throughput, black_box, criterion_group, criterion_main}; +use tool_parser::parsers::QwenParser as ExternalQwenParser; +use vllm_tool_parser::test_utils::{split_by_chars, test_tools}; +use vllm_tool_parser::{Qwen3XmlToolParser, Tool, ToolParser}; + +mod utils; +use utils::{feed_external_parser, feed_parser, openai_tools}; + +const CHUNK_CHARS: usize = 7; +const LONG_NORMAL_TEXT_REPEATS: usize = 2048; + +fn tool_call(function_name: &str, arguments: &str) -> String { + format!("\n{{\"name\":\"{function_name}\",\"arguments\":{arguments}}}\n") +} + +fn mixed_fixture() -> String { + format!( + "I will check two cities before answering.\n{}{}", + tool_call("get_weather", r#"{"location":"Hangzhou","days":3}"#), + tool_call("get_weather", r#"{"location":"San Francisco","days":2}"#), + ) +} + +fn long_normal_text_fixture() -> String { + let line = "This is ordinary assistant text with no Qwen XML tool markers at all.\n"; + line.repeat(LONG_NORMAL_TEXT_REPEATS) +} + +fn native_parser(tools: &[Tool]) -> Box { + Qwen3XmlToolParser::create(tools).expect("Qwen XML parser should initialize") +} + +fn run_stream_group( + c: &mut Criterion, + name: &str, + tools: &[Tool], + text: &str, + chunk_chars: usize, + expected_normal_text: &str, + expected_native_calls_len: usize, +) { + let chunks = split_by_chars(text, chunk_chars); + let openai_tools = openai_tools(tools); + + let mut group = c.benchmark_group(name); + group.sample_size(50); + group.warm_up_time(Duration::from_millis(300)); + group.measurement_time(Duration::from_secs(2)); + group.throughput(Throughput::Bytes(text.len() as u64)); + + group.bench_function("native_reuse_parser", |b| { + let mut parser = native_parser(tools); + b.iter(|| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }) + }); + + group.bench_function("native_create_parser", |b| { + b.iter_batched( + || native_parser(tools), + |mut parser| { + let result = feed_parser(&mut *parser, black_box(&chunks)); + debug_assert_eq!(result.0, expected_normal_text); + debug_assert_eq!(result.1, expected_native_calls_len); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.bench_function("external_reuse_parser", |b| { + let mut parser = ExternalQwenParser::new(); + b.iter(|| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(&chunks)); + black_box(result); + }) + }); + + group.bench_function("external_create_parser", |b| { + b.iter_batched( + ExternalQwenParser::new, + |mut parser| { + let result = feed_external_parser(&mut parser, &openai_tools, black_box(&chunks)); + black_box(result); + }, + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +fn bench_qwen3_xml(c: &mut Criterion) { + let tools = test_tools(); + let mixed_text = mixed_fixture(); + let long_normal_text = long_normal_text_fixture(); + + run_stream_group( + c, + "qwen3_xml/mixed_text_tool_call", + &tools, + &mixed_text, + CHUNK_CHARS, + "I will check two cities before answering.\n", + 2, + ); + + run_stream_group( + c, + "qwen3_xml/long_normal_text", + &tools, + &long_normal_text, + CHUNK_CHARS, + &long_normal_text, + 0, + ); +} + +criterion_group!(benches, bench_qwen3_xml); +criterion_main!(benches); diff --git a/rust/src/tool-parser/benches/utils/mod.rs b/rust/src/tool-parser/benches/utils/mod.rs new file mode 100644 index 00000000000..a0ad768f115 --- /dev/null +++ b/rust/src/tool-parser/benches/utils/mod.rs @@ -0,0 +1,49 @@ +#![allow(dead_code)] + +use futures::FutureExt as _; +use openai_protocol::common::{Function as OpenAiFunction, Tool as OpenAiTool}; +use tool_parser::traits::ToolParser as ExternalToolParser; +use vllm_tool_parser::test_utils::collect_stream; +use vllm_tool_parser::{Tool, ToolParser}; + +pub(super) fn openai_tools(tools: &[Tool]) -> Vec { + tools + .iter() + .map(|tool| OpenAiTool { + tool_type: "function".to_string(), + function: OpenAiFunction { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: tool.parameters.clone(), + strict: tool.strict, + }, + }) + .collect() +} + +pub(super) fn feed_parser(parser: &mut dyn ToolParser, chunks: &[&str]) -> (String, usize) { + let result = collect_stream(parser, chunks); + (result.normal_text, result.calls.len()) +} + +pub(super) fn feed_external_parser( + parser: &mut impl ExternalToolParser, + tools: &[OpenAiTool], + chunks: &[&str], +) -> (String, usize) { + ExternalToolParser::reset(parser); + + let mut normal_text = String::new(); + let mut calls_len = 0; + for chunk in chunks { + let delta = parser + .parse_incremental(chunk, tools) + .now_or_never() + .expect("external parser should not suspend") + .expect("chunk should parse"); + normal_text.push_str(&delta.normal_text); + calls_len += delta.calls.len(); + } + calls_len += parser.get_unstreamed_tool_args().unwrap_or_default().len(); + (normal_text, calls_len) +} diff --git a/rust/src/tool-parser/src/deepseek_dsml/deepseek_v32.rs b/rust/src/tool-parser/src/deepseek_dsml/deepseek_v32.rs new file mode 100644 index 00000000000..ddc630ceab4 --- /dev/null +++ b/rust/src/tool-parser/src/deepseek_dsml/deepseek_v32.rs @@ -0,0 +1,439 @@ +use super::{DeepSeekDsmlToolParser, DsmlTokens}; +use crate::{Result, Tool, ToolParseResult, ToolParser}; + +/// Tool parser for DeepSeek V3.2 models. +/// +/// Example tool call content: +/// +/// ```text +/// <|DSML|function_calls> +/// <|DSML|invoke name="get_weather"> +/// <|DSML|parameter name="location" string="true">杭州 +/// <|DSML|parameter name="date" string="true">2024-01-16 +/// +/// <|DSML|invoke name="get_weather"> +/// <|DSML|parameter name="location" string="true">北京 +/// <|DSML|parameter name="date" string="true">2024-01-16 +/// +/// +/// ``` +/// +/// Arguments are emitted only after a full `invoke` block is parsed. +/// +/// DeepSeek V3.2 relies on DSML markers such as `|DSML|`, which are +/// represented as special tokens in the tokenizer and therefore must be +/// preserved during decode for parsing to work. +pub struct DeepSeekV32ToolParser(DeepSeekDsmlToolParser); + +impl DeepSeekV32ToolParser { + /// Create a DeepSeek V3.2 tool parser. + pub(super) fn new(tools: &[Tool]) -> Self { + Self(DeepSeekDsmlToolParser::new(tools, DsmlTokens::V32)) + } +} + +impl ToolParser for DeepSeekV32ToolParser { + /// Create a boxed DeepSeek V3.2 tool parser. + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tools))) + } + + /// Preserve DSML special tokens while decoding. + fn preserve_special_tokens(&self) -> bool { + true + } + + /// Push one decoded text chunk through the DSML parser. + fn push(&mut self, chunk: &str) -> Result { + self.0.push(chunk) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + self.0.finish() + } +} + +#[cfg(test)] +mod tests { + use serde_json::{Value, json}; + use thiserror_ext::AsReport; + + use super::DeepSeekV32ToolParser; + use crate::ToolParser; + use crate::test_utils::{collect_stream, split_by_chars, test_tools}; + + fn build_tool_call(function_name: &str, params: &[(&str, &str)]) -> String { + let params = params + .iter() + .map(|(name, value)| { + format!( + r#"<|DSML|parameter name="{name}" string="true">{value}"# + ) + }) + .collect::>() + .join("\n"); + format!( + "<|DSML|function_calls>\n<|DSML|invoke name=\"{function_name}\">\n{params}\n\n" + ) + } + + #[test] + fn deepseek_v32_parse_complete_without_tool_call_keeps_text() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let result = parser.parse_complete("Hello, world!").unwrap(); + + assert_eq!(result.normal_text, "Hello, world!"); + assert!(result.calls.is_empty()); + } + + #[test] + fn deepseek_v32_parse_complete_extracts_single_tool_call() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let result = parser + .parse_complete(&build_tool_call( + "get_weather", + &[("location", "SF"), ("date", "2024-01-16")], + )) + .unwrap(); + + assert!(result.normal_text.is_empty()); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "location": "SF", + "date": "2024-01-16" + }) + ); + } + + #[test] + fn deepseek_v32_parse_complete_preserves_prefix_text() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let output = format!( + "Thinking... {}", + build_tool_call("get_weather", &[("location", "NYC")]) + ); + let result = parser.parse_complete(&output).unwrap(); + + assert_eq!(result.normal_text, "Thinking... "); + assert_eq!(result.calls.len(), 1); + } + + #[test] + fn deepseek_v32_parse_complete_converts_schema_types() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let result = parser + .parse_complete( + "<|DSML|function_calls>\n\ + <|DSML|invoke name=\"convert\">\n\ + <|DSML|parameter name=\"whole\" string=\"false\">5.0\n\ + <|DSML|parameter name=\"flag\" string=\"false\">true\n\ + <|DSML|parameter name=\"payload\" string=\"false\">{\"nested\":true}\n\ + <|DSML|parameter name=\"items\" string=\"false\">[1,2]\n\ + <|DSML|parameter name=\"empty\" string=\"false\">null\n\ + \n\ + ", + ) + .unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "whole": 5.0, + "flag": true, + "payload": { "nested": true }, + "items": [1, 2], + "empty": null, + }) + ); + } + + #[test] + fn deepseek_v32_parse_complete_string_attr_overrides_schema_types() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let result = parser + .parse_complete( + "<|DSML|function_calls>\n\ + <|DSML|invoke name=\"convert\">\n\ + <|DSML|parameter name=\"whole\" string=\"true\">5.0\n\ + <|DSML|parameter name=\"flag\" string=\"true\">true\n\ + <|DSML|parameter name=\"payload\" string=\"true\">{\"nested\":true}\n\ + <|DSML|parameter name=\"items\" string=\"true\">[1,2]\n\ + <|DSML|parameter name=\"empty\" string=\"true\">null\n\ + \n\ + ", + ) + .unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "whole": "5.0", + "flag": "true", + "payload": "{\"nested\":true}", + "items": "[1,2]", + "empty": "null", + }) + ); + } + + #[test] + fn deepseek_v32_parse_complete_unescapes_literal_closing_tags_in_parameter_value() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let result = parser + .parse_complete(&build_tool_call( + "get_weather", + &[ + ( + "location", + "Hangzhou </|DSML|parameter></|DSML|invoke></|DSML|function_calls>", + ), + ("date", "2026-05-08"), + ], + )) + .unwrap(); + + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "location": "Hangzhou ", + "date": "2026-05-08", + }) + ); + } + + #[test] + fn deepseek_v32_streaming_extracts_single_tool_call() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let result = collect_stream( + &mut parser, + &[ + "<|DSML|function_calls>\n", + "<|DSML|invoke name=\"get_weather\">\n", + "<|DSML|parameter name=\"location\" string=\"true\">SF\n", + "\n", + "", + ], + ); + + assert!(result.normal_text.is_empty()); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ "location": "SF" }) + ); + } + + #[test] + fn deepseek_v32_streaming_preserves_prefix_text() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let result = collect_stream( + &mut parser, + &[ + "Thinking... ", + "<|DSML|function_calls>\n", + "<|DSML|invoke name=\"get_weather\">\n", + "<|DSML|parameter name=\"location\" string=\"true\">SF\n", + "\n", + "", + ], + ); + + assert_eq!(result.normal_text, "Thinking... "); + assert_eq!(result.calls.len(), 1); + } + + #[test] + fn deepseek_v32_streaming_without_tool_call_emits_text_incrementally() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let result = collect_stream(&mut parser, &["Hello, ", "world!"]); + + assert_eq!(result.normal_text, "Hello, world!"); + assert!(result.calls.is_empty()); + } + + #[test] + fn deepseek_v32_streaming_extracts_multiple_tool_calls_in_order() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let result = collect_stream( + &mut parser, + &[&format!( + "{}\n{}", + build_tool_call("get_weather", &[("location", "SF")]) + .trim_end_matches(""), + "<|DSML|invoke name=\"get_weather\">\n<|DSML|parameter name=\"location\" string=\"true\">NYC\n\n" + )], + ); + + assert_eq!(result.calls.len(), 2); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!(result.calls[1].name.as_deref(), Some("get_weather")); + assert_eq!(result.calls[0].tool_index, 0); + assert_eq!(result.calls[1].tool_index, 1); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ "location": "SF" }) + ); + assert_eq!( + serde_json::from_str::(&result.calls[1].arguments).unwrap(), + json!({ "location": "NYC" }) + ); + } + + #[test] + fn deepseek_v32_streaming_handles_start_token_split_across_chunks() { + let text = build_tool_call("get_weather", &[("location", "SF")]); + let chunks = split_by_chars(&text, 5); + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let result = collect_stream(&mut parser, &chunks); + + assert_eq!(result.calls.len(), 1); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ "location": "SF" }) + ); + } + + #[test] + fn deepseek_v32_streaming_handles_bpe_chunked_dsml_opener() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let result = collect_stream( + &mut parser, + &[ + "<|DSML|", + "function", + "_c", + "all", + "s", + ">\n", + "<|DSML|", + "invoke", + " name=\"", + "get_weather", + "\">\n", + "<|DSML|", + "parameter", + " name=\"location\" string=\"true\">", + "Beijing", + "\n", + "\n", + "", + ], + ); + + assert!(result.normal_text.is_empty()); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ "location": "Beijing" }) + ); + } + + #[test] + fn deepseek_v32_streaming_truncated_parameter_does_not_leak_eos() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + parser.push("<|DSML|function_calls>\n").unwrap(); + parser.push("<|DSML|invoke name=\"get_weather\">\n").unwrap(); + parser + .push("<|DSML|parameter name=\"location\" string=\"true\">Tokyo") + .unwrap(); + parser.push("<|end▁of▁sentence|>").unwrap(); + + let error = parser.finish().unwrap_err(); + assert!(error.to_report_string().contains("incomplete DeepSeek DSML tool call")); + } + #[test] + fn deepseek_v32_streaming_drops_eos_after_complete_tool_calls() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let result = collect_stream( + &mut parser, + &[ + "<|DSML|function_calls>\n", + "<|DSML|invoke name=\"get_weather\">\n", + "<|DSML|parameter name=\"location\" string=\"true\">SF\n", + "\n", + "<|end▁of▁sentence|>", + ], + ); + + assert!(result.normal_text.is_empty()); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + } + + #[test] + fn deepseek_v32_streaming_ignores_text_after_complete_tool_calls() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let result = collect_stream( + &mut parser, + &[ + "<|DSML|function_calls>\n", + "<|DSML|invoke name=\"get_weather\">\n", + "<|DSML|parameter name=\"location\" string=\"true\">SF\n", + "\n", + "", + "trailing text", + ], + ); + + assert!(result.normal_text.is_empty()); + assert_eq!(result.calls.len(), 1); + } + + #[test] + fn deepseek_v32_streaming_does_not_emit_incomplete_invoke() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + parser.push("<|DSML|function_calls>\n").unwrap(); + parser.push("<|DSML|invoke name=\"get_weather\">\n").unwrap(); + parser + .push("<|DSML|parameter name=\"location\" string=\"true\">SF\n") + .unwrap(); + + let error = parser.finish().unwrap_err(); + assert!(error.to_report_string().contains("incomplete DeepSeek DSML tool call")); + } + #[test] + fn deepseek_v32_parser_state_resets_after_finish() { + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let first = parser + .parse_complete(&build_tool_call("get_weather", &[("location", "SF")])) + .unwrap(); + let second = parser + .parse_complete(&build_tool_call("get_weather", &[("location", "NYC")])) + .unwrap(); + + assert_eq!(first.calls.len(), 1); + assert_eq!(second.calls.len(), 1); + assert_eq!( + serde_json::from_str::(&second.calls[0].arguments).unwrap(), + json!({ "location": "NYC" }) + ); + } + + #[test] + fn deepseek_v32_streaming_matches_parse_complete() { + let full_text = build_tool_call("add", &[("x", "3"), ("y", "4")]); + let chunks = split_by_chars(&full_text, 7); + let mut streaming_parser = DeepSeekV32ToolParser::new(&test_tools()); + let streamed = collect_stream(&mut streaming_parser, &chunks); + + let mut parser = DeepSeekV32ToolParser::new(&test_tools()); + let complete = parser.parse_complete(&full_text).unwrap(); + + assert_eq!(streamed.normal_text, complete.normal_text); + assert_eq!(streamed.calls, complete.calls); + } +} diff --git a/rust/src/tool-parser/src/deepseek_dsml/deepseek_v4.rs b/rust/src/tool-parser/src/deepseek_dsml/deepseek_v4.rs new file mode 100644 index 00000000000..20fe23b95f8 --- /dev/null +++ b/rust/src/tool-parser/src/deepseek_dsml/deepseek_v4.rs @@ -0,0 +1,133 @@ +use super::{DeepSeekDsmlToolParser, DsmlTokens}; +use crate::{Result, Tool, ToolParseResult, ToolParser}; + +/// Tool parser for DeepSeek V4 models. +/// +/// Example tool call content: +/// +/// ```text +/// <|DSML|tool_calls> +/// <|DSML|invoke name="get_weather"> +/// <|DSML|parameter name="location" string="true">杭州 +/// <|DSML|parameter name="date" string="true">2024-01-16 +/// +/// <|DSML|invoke name="get_weather"> +/// <|DSML|parameter name="location" string="true">北京 +/// <|DSML|parameter name="date" string="true">2024-01-16 +/// +/// +/// ``` +/// +/// Arguments are emitted only after a full `invoke` block is parsed. +/// +/// V4 reuses the V3.2 DSML invoke/parameter grammar but wraps calls in +/// `<|DSML|tool_calls>` instead of `<|DSML|function_calls>`. +/// +/// DeepSeek V4 relies on DSML markers such as `|DSML|`, which are +/// represented as special tokens in the tokenizer and therefore must be +/// preserved during decode for parsing to work. +pub struct DeepSeekV4ToolParser(DeepSeekDsmlToolParser); + +impl DeepSeekV4ToolParser { + /// Create a DeepSeek V4 tool parser. + fn new(tools: &[Tool]) -> Self { + Self(DeepSeekDsmlToolParser::new(tools, DsmlTokens::V4)) + } +} + +impl ToolParser for DeepSeekV4ToolParser { + /// Create a boxed DeepSeek V4 tool parser. + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tools))) + } + + /// Preserve DSML special tokens while decoding. + fn preserve_special_tokens(&self) -> bool { + true + } + + /// Push one decoded text chunk through the DSML parser. + fn push(&mut self, chunk: &str) -> Result { + self.0.push(chunk) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + self.0.finish() + } +} + +#[cfg(test)] +mod tests { + use serde_json::{Value, json}; + + use super::{DeepSeekV4ToolParser, ToolParser}; + use crate::test_utils::{collect_stream, test_tools}; + + fn build_tool_call(function_name: &str, params: &[(&str, &str)]) -> String { + let params = params + .iter() + .map(|(name, value)| { + format!( + r#"<|DSML|parameter name="{name}" string="true">{value}"# + ) + }) + .collect::>() + .join("\n"); + format!( + "<|DSML|tool_calls>\n<|DSML|invoke name=\"{function_name}\">\n{params}\n\n" + ) + } + + #[test] + fn deepseek_v4_parse_complete_reuses_dsml_parser_with_tool_calls_token() { + let mut parser = DeepSeekV4ToolParser::new(&test_tools()); + let result = parser + .parse_complete(&build_tool_call( + "get_weather", + &[("location", "SF"), ("date", "2024-01-16")], + )) + .unwrap(); + + assert!(result.normal_text.is_empty()); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "location": "SF", + "date": "2024-01-16" + }) + ); + } + + #[test] + fn deepseek_v4_streaming_handles_tool_calls_token_split_across_chunks() { + let mut parser = DeepSeekV4ToolParser::new(&test_tools()); + let result = collect_stream( + &mut parser, + &[ + "Thinking... ", + "<|DSML|", + "tool", + "_calls>\n", + "<|DSML|invoke name=\"get_weather\">\n", + "<|DSML|parameter name=\"location\" string=\"true\">Beijing\n", + "\n", + "", + ], + ); + + assert_eq!(result.normal_text, "Thinking... "); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ "location": "Beijing" }) + ); + } +} diff --git a/rust/src/tool-parser/src/deepseek_dsml/mod.rs b/rust/src/tool-parser/src/deepseek_dsml/mod.rs new file mode 100644 index 00000000000..dda0b2368ce --- /dev/null +++ b/rust/src/tool-parser/src/deepseek_dsml/mod.rs @@ -0,0 +1,278 @@ +use winnow::ascii::{multispace0 as ws0, multispace1 as ws1}; +use winnow::combinator::{alt, delimited, eof, repeat, seq, terminated}; +use winnow::prelude::*; +use winnow::stream::Partial; +use winnow::token::{literal, rest, take_until}; + +use super::parameters::ToolSchemas; +use super::utils::{parse_buffered_event, safe_text_len, xml_unescape}; +use super::{Result, ToolCallDelta, ToolParseResult}; +use crate::Tool; + +mod deepseek_v32; +mod deepseek_v4; + +pub use deepseek_v4::DeepSeekV4ToolParser; +pub use deepseek_v32::DeepSeekV32ToolParser; + +const INVOKE_START: &str = "<|DSML|invoke"; +const INVOKE_END: &str = ""; +const PARAMETER_START: &str = "<|DSML|parameter"; +const PARAMETER_END: &str = ""; + +type DsmlInput<'i> = Partial<&'i str>; + +#[derive(Debug, Clone, Copy)] +struct DsmlTokens { + tool_calls_start: &'static str, + tool_calls_end: &'static str, +} + +impl DsmlTokens { + const V32: Self = Self { + tool_calls_start: "<|DSML|function_calls>", + tool_calls_end: "", + }; + const V4: Self = Self { + tool_calls_start: "<|DSML|tool_calls>", + tool_calls_end: "", + }; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum DsmlMode { + Text, + ToolBlock, + Done, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum DsmlEvent { + Text { + len: usize, + }, + ToolCallsStart, + Invoke { + name: String, + raw_params: Vec, + }, + ToolCallsEnd, + IgnoredRest, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct DsmlParameter { + name: String, + value: String, + is_string: bool, +} + +/// Tool parser core for DeepSeek DSML tool calls. +struct DeepSeekDsmlToolParser { + buffer: String, + mode: DsmlMode, + emitted_invoke_count: usize, + tool_parameters: ToolSchemas, + tokens: DsmlTokens, +} + +impl DeepSeekDsmlToolParser { + /// Create a parser with DSML tokens for one DeepSeek format. + fn new(tools: &[Tool], tokens: DsmlTokens) -> Self { + Self { + buffer: String::new(), + mode: DsmlMode::Text, + emitted_invoke_count: 0, + tool_parameters: ToolSchemas::from_tools(tools), + tokens, + } + } + + /// Apply one parsed DSML event to parser state and output. + fn apply_event(&mut self, event: DsmlEvent, result: &mut ToolParseResult) -> Result<()> { + match event { + DsmlEvent::Text { len: consumed_len } => { + result.normal_text.push_str(&self.buffer[..consumed_len]); + } + DsmlEvent::ToolCallsStart => self.mode = DsmlMode::ToolBlock, + DsmlEvent::Invoke { name, raw_params } => { + let mut arguments = serde_json::Map::with_capacity(raw_params.len()); + for param in raw_params { + let value = if param.is_string { + serde_json::Value::String(param.value) + } else { + self.tool_parameters.convert_param_with_schema( + &name, + ¶m.name, + ¶m.value, + ) + }; + arguments.insert(param.name, value); + } + let arguments = serde_json::to_string(&arguments) + .map_err(|error| parsing_failed!("failed to serialize arguments: {}", error))?; + + result.calls.push(ToolCallDelta { + tool_index: self.emitted_invoke_count, + name: Some(name), + arguments, + }); + self.emitted_invoke_count += 1; + } + DsmlEvent::ToolCallsEnd => self.mode = DsmlMode::Done, + DsmlEvent::IgnoredRest => {} + }; + Ok(()) + } + + /// Reset all streaming state. + fn reset(&mut self) { + self.buffer.clear(); + self.mode = DsmlMode::Text; + self.emitted_invoke_count = 0; + } + + /// Push one decoded text chunk through the DSML parser. + fn push(&mut self, chunk: &str) -> Result { + // Extract tool calls from streaming model output. + // + // Uses a buffer-until-complete-invoke strategy: text is buffered until + // a complete invoke block is available, then parsed and emitted in one + // shot. + self.buffer.push_str(chunk); + let mut result = ToolParseResult::default(); + + while let Some((event, consumed_len)) = parse_buffered_event(&self.buffer, |input| { + parse_next_dsml_event(input, self.mode, self.tokens) + })? { + self.apply_event(event, &mut result)?; + self.buffer.drain(..consumed_len); + } + + Ok(result) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + let mut result = ToolParseResult::default(); + match self.mode { + DsmlMode::Text => result.normal_text.push_str(&self.buffer), + DsmlMode::Done => {} + DsmlMode::ToolBlock => { + self.reset(); + return Err(parsing_failed!("incomplete DeepSeek DSML tool call")); + } + } + self.reset(); + Ok(result) + } +} + +/// Parse a DSML event for the current parser mode. +fn parse_next_dsml_event( + input: &mut DsmlInput<'_>, + mode: DsmlMode, + tokens: DsmlTokens, +) -> ModalResult { + match mode { + DsmlMode::Text => parse_text_event(input, tokens), + DsmlMode::ToolBlock => parse_tool_block_event(input, tokens), + DsmlMode::Done => ignored_rest_event(input), + } +} + +/// Parse a text-mode DSML event. +fn parse_text_event(input: &mut DsmlInput<'_>, tokens: DsmlTokens) -> ModalResult { + alt(( + |input: &mut DsmlInput<'_>| tool_calls_start_event(input, tokens), + |input: &mut DsmlInput<'_>| safe_text_event(input, tokens), + )) + .parse_next(input) +} + +/// Parse a tool-block DSML event. +fn parse_tool_block_event(input: &mut DsmlInput<'_>, tokens: DsmlTokens) -> ModalResult { + ws0.void().parse_next(input)?; + alt((invoke_event, |input: &mut DsmlInput<'_>| { + tool_calls_end_event(input, tokens) + })) + .parse_next(input) +} + +/// Parse a DSML function-calls start marker. +fn tool_calls_start_event(input: &mut DsmlInput<'_>, tokens: DsmlTokens) -> ModalResult { + literal(tokens.tool_calls_start) + .value(DsmlEvent::ToolCallsStart) + .parse_next(input) +} + +/// Parse a DSML function-calls end marker. +fn tool_calls_end_event(input: &mut DsmlInput<'_>, tokens: DsmlTokens) -> ModalResult { + literal(tokens.tool_calls_end).value(DsmlEvent::ToolCallsEnd).parse_next(input) +} + +/// Parse a trailing rest after DSML function calls. +fn ignored_rest_event(input: &mut DsmlInput<'_>) -> ModalResult { + rest.value(DsmlEvent::IgnoredRest).parse_next(input) +} + +/// Parse a safe text run before the next DSML marker. +fn safe_text_event(input: &mut DsmlInput<'_>, tokens: DsmlTokens) -> ModalResult { + safe_text_len(input, tokens.tool_calls_start).map(|len| DsmlEvent::Text { len }) +} + +/// Parse a DSML invoke block. +fn invoke_event(input: &mut DsmlInput<'_>) -> ModalResult { + let (name, body) = seq!( + _: literal(INVOKE_START), + _: ws1, + dsml_name_attr, + _: ws0, + _: ">", + take_until(0.., INVOKE_END), + _: literal(INVOKE_END), + ) + .parse_next(input)?; + let raw_params = parse_invoke_params(body)?; + Ok(DsmlEvent::Invoke { + name: name.to_string(), + raw_params, + }) +} + +/// Parse a DSML invoke body. +fn parse_invoke_params(invoke_body: &str) -> ModalResult> { + let mut input = invoke_body; + delimited(ws0, repeat(0.., terminated(parse_parameter, ws0)), eof).parse_next(&mut input) +} + +/// Parse a DSML parameter block. +fn parse_parameter(input: &mut &str) -> ModalResult { + seq! {DsmlParameter { + _: literal(PARAMETER_START), + _: ws1, + name: name_attr.map(|name: &str| name.to_string()), + _: ws1, + is_string: string_attr.map(|value| value == "true"), + _: ws0, + _: ">", + value: take_until(0.., PARAMETER_END).map(xml_unescape).map(|value| value.into_owned()), + _: literal(PARAMETER_END), + }} + .parse_next(input) +} + +/// Parse a name attribute. +fn name_attr<'i>(input: &mut &'i str) -> ModalResult<&'i str> { + delimited("name=\"", take_until(1.., "\""), "\"").parse_next(input) +} + +/// Parse a string attribute. +fn string_attr<'i>(input: &mut &'i str) -> ModalResult<&'i str> { + delimited("string=\"", alt(("true", "false")), "\"").parse_next(input) +} + +/// Parse a DSML name attribute. +fn dsml_name_attr<'i>(input: &mut DsmlInput<'i>) -> ModalResult<&'i str> { + delimited("name=\"", take_until(1.., "\""), "\"").parse_next(input) +} diff --git a/rust/src/tool-parser/src/deepseek_json/deepseek_v3.rs b/rust/src/tool-parser/src/deepseek_json/deepseek_v3.rs new file mode 100644 index 00000000000..3481951b8c2 --- /dev/null +++ b/rust/src/tool-parser/src/deepseek_json/deepseek_v3.rs @@ -0,0 +1,235 @@ +use super::{DeepSeekJsonFormat, DeepSeekJsonToolParser}; +use crate::{Result, Tool, ToolParseResult, ToolParser}; + +/// Tool parser for DeepSeek V3 JSON-fenced tool calls. +/// +/// Example tool call content: +/// +/// ````text +/// <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather +/// ```json +/// {"location":"Tokyo"} +/// ```<|tool▁call▁end|><|tool▁calls▁end|> +/// ```` +/// +/// Arguments are already OpenAI-style JSON text inside the markdown fence, so +/// they are streamed as raw argument deltas without schema conversion or JSON +/// normalization. +pub struct DeepSeekV3ToolParser(DeepSeekJsonToolParser); + +impl DeepSeekV3ToolParser { + /// Create a DeepSeek V3 tool parser. + fn new(_tools: &[Tool]) -> Self { + Self(DeepSeekJsonToolParser::new(DeepSeekJsonFormat::V3)) + } +} + +impl ToolParser for DeepSeekV3ToolParser { + /// Create a boxed DeepSeek V3 tool parser. + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tools))) + } + + /// Push one decoded text chunk through the DeepSeek V3 parser. + fn push(&mut self, chunk: &str) -> Result { + self.0.push(chunk) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + self.0.finish() + } +} + +#[cfg(test)] +mod tests { + use expect_test::expect; + use thiserror_ext::AsReport; + + use super::DeepSeekV3ToolParser; + use crate::deepseek_json::{ + TOOL_CALL_SEPARATOR, TOOL_CALL_START, TOOL_CALLS_END, TOOL_CALLS_START, V3_ARGUMENT_END, + V3_JSON_START, + }; + use crate::test_utils::{collect_stream, split_by_chars, test_tools}; + use crate::{ToolParseResult, ToolParser}; + + fn v3_tool_call(function_name: &str, arguments: &str) -> String { + format!( + "{TOOL_CALL_START}function{TOOL_CALL_SEPARATOR}{function_name}{V3_JSON_START}{arguments}{V3_ARGUMENT_END}" + ) + } + + fn tool_section(tool_calls: &[String]) -> String { + format!("{TOOL_CALLS_START}{}{TOOL_CALLS_END}", tool_calls.join("")) + } + + #[test] + fn deepseek_v3_parse_complete_without_tool_call_keeps_text() { + let mut parser = DeepSeekV3ToolParser::new(&test_tools()); + let result = parser.parse_complete("Hello, world!").unwrap(); + + assert_eq!(result.normal_text, "Hello, world!"); + assert!(result.calls.is_empty()); + } + + #[test] + fn deepseek_v3_parse_complete_extracts_raw_json_arguments() { + let mut parser = DeepSeekV3ToolParser::new(&test_tools()); + let arguments = r#"{ "location": "Tokyo", "days": "3" }"#; + let result = parser + .parse_complete(&format!( + "Let me check.\n{} trailing text", + tool_section(&[v3_tool_call("get_weather", arguments)]) + )) + .unwrap(); + + assert_eq!(result.normal_text, "Let me check.\n"); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].tool_index, 0); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn deepseek_v3_does_not_validate_or_normalize_arguments() { + let mut parser = DeepSeekV3ToolParser::new(&test_tools()); + let arguments = r#"{"location":"Tokyo",}"#; + let result = parser + .parse_complete(&tool_section(&[v3_tool_call("get_weather", arguments)])) + .unwrap(); + + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn deepseek_v3_streaming_emits_argument_deltas() { + let mut parser = DeepSeekV3ToolParser::new(&test_tools()); + let chunks = [ + TOOL_CALLS_START, + TOOL_CALL_START, + "function", + TOOL_CALL_SEPARATOR, + "get_weather", + V3_JSON_START, + "{\"location\":", + "\"Beijing\"", + "}", + V3_ARGUMENT_END, + TOOL_CALLS_END, + ]; + + let mut result = ToolParseResult::default(); + let mut observed_arguments = Vec::new(); + for chunk in chunks { + let next = parser.push(chunk).unwrap(); + observed_arguments.extend( + next.calls + .iter() + .filter(|call| call.name.is_none()) + .map(|call| call.arguments.clone()), + ); + result.append(next); + } + result.append(parser.finish().unwrap()); + + assert_eq!(observed_arguments, ["{\"location\":", "\"Beijing\"", "}"]); + assert_eq!( + result.coalesce_calls().calls[0].arguments, + r#"{"location":"Beijing"}"# + ); + } + + #[test] + fn deepseek_v3_streaming_handles_split_markers() { + let input = format!( + "hello {}", + tool_section(&[v3_tool_call("get_weather", r#"{"location":"Tokyo"}"#)]) + ); + let chunks = split_by_chars(&input, 5); + let mut parser = DeepSeekV3ToolParser::new(&test_tools()); + + let result = collect_stream(&mut parser, &chunks); + + assert_eq!(result.normal_text, "hello "); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].arguments, r#"{"location":"Tokyo"}"#); + } + + #[test] + fn deepseek_v3_keeps_fenced_end_marker_literal_inside_json_string() { + let mut parser = DeepSeekV3ToolParser::new(&test_tools()); + let arguments = format!("{{\"text\":\"literal {V3_ARGUMENT_END} inside\"}}"); + let input = tool_section(&[v3_tool_call("echo", &arguments)]); + + let result = parser.parse_complete(&input).unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn deepseek_v3_streaming_extracts_multiple_tool_calls() { + let input = tool_section(&[ + v3_tool_call("get_weather", r#"{"location":"Shanghai"}"#), + v3_tool_call("add", r#"{"x":1,"y":2}"#), + ]); + let chunks = split_by_chars(&input, 7); + let mut parser = DeepSeekV3ToolParser::new(&test_tools()); + + let result = collect_stream(&mut parser, &chunks); + + expect![[r#" + ToolParseResult { + normal_text: "", + calls: [ + ToolCallDelta { + tool_index: 0, + name: Some( + "get_weather", + ), + arguments: "{\"location\":\"Shanghai\"}", + }, + ToolCallDelta { + tool_index: 1, + name: Some( + "add", + ), + arguments: "{\"x\":1,\"y\":2}", + }, + ], + } + "#]] + .assert_debug_eq(&result); + } + + #[test] + fn deepseek_v3_finish_fails_incomplete_tool_call() { + let mut parser = DeepSeekV3ToolParser::new(&test_tools()); + parser + .push(&format!( + "{TOOL_CALLS_START}{TOOL_CALL_START}function{TOOL_CALL_SEPARATOR}get_weather{V3_JSON_START}{{\"location\"" + )) + .unwrap(); + + let error = parser.finish().unwrap_err(); + + expect!["tool parser parsing failed: incomplete DeepSeek V3 tool call"] + .assert_eq(&error.to_report_string()); + } + + #[test] + fn deepseek_v3_malformed_type_fails_fast() { + let mut parser = DeepSeekV3ToolParser::new(&test_tools()); + let input = format!( + "{TOOL_CALLS_START}{TOOL_CALL_START}tool{TOOL_CALL_SEPARATOR}get_weather{V3_JSON_START}{{}}" + ); + + let error = parser.push(&input).unwrap_err(); + + expect!["tool parser parsing failed: "].assert_eq(&error.to_report_string()); + } +} diff --git a/rust/src/tool-parser/src/deepseek_json/deepseek_v31.rs b/rust/src/tool-parser/src/deepseek_json/deepseek_v31.rs new file mode 100644 index 00000000000..bde6bc8fe8a --- /dev/null +++ b/rust/src/tool-parser/src/deepseek_json/deepseek_v31.rs @@ -0,0 +1,239 @@ +use super::{DeepSeekJsonFormat, DeepSeekJsonToolParser}; +use crate::{Result, Tool, ToolParseResult, ToolParser}; + +/// Tool parser for DeepSeek V3.1 raw JSON tool calls. +/// +/// Example tool call content: +/// +/// ```text +/// <|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location":"Tokyo"}<|tool▁call▁end|><|tool▁calls▁end|> +/// ``` +/// +/// Arguments are already OpenAI-style JSON text, so they are streamed as raw +/// argument deltas without schema conversion or JSON normalization. +pub struct DeepSeekV31ToolParser(DeepSeekJsonToolParser); + +impl DeepSeekV31ToolParser { + /// Create a DeepSeek V3.1 tool parser. + fn new(_tools: &[Tool]) -> Self { + Self(DeepSeekJsonToolParser::new(DeepSeekJsonFormat::V31)) + } +} + +impl ToolParser for DeepSeekV31ToolParser { + /// Create a boxed DeepSeek V3.1 tool parser. + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tools))) + } + + /// Push one decoded text chunk through the DeepSeek V3.1 parser. + fn push(&mut self, chunk: &str) -> Result { + self.0.push(chunk) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + self.0.finish() + } +} + +#[cfg(test)] +mod tests { + use expect_test::expect; + use thiserror_ext::AsReport; + + use super::DeepSeekV31ToolParser; + use crate::deepseek_json::{ + TOOL_CALL_END, TOOL_CALL_SEPARATOR, TOOL_CALL_START, TOOL_CALLS_END, TOOL_CALLS_START, + }; + use crate::test_utils::{collect_stream, split_by_chars, test_tools}; + use crate::{ToolParseResult, ToolParser}; + + fn v31_tool_call(function_name: &str, arguments: &str) -> String { + format!("{TOOL_CALL_START}{function_name}{TOOL_CALL_SEPARATOR}{arguments}{TOOL_CALL_END}") + } + + fn tool_section(tool_calls: &[String]) -> String { + format!("{TOOL_CALLS_START}{}{TOOL_CALLS_END}", tool_calls.join("")) + } + + #[test] + fn deepseek_v31_parse_complete_without_tool_call_keeps_text() { + let mut parser = DeepSeekV31ToolParser::new(&test_tools()); + let result = parser.parse_complete("Hello, world!").unwrap(); + + assert_eq!(result.normal_text, "Hello, world!"); + assert!(result.calls.is_empty()); + } + + #[test] + fn deepseek_v31_parse_complete_extracts_raw_json_arguments() { + let mut parser = DeepSeekV31ToolParser::new(&test_tools()); + let arguments = r#"{ "location": "Tokyo", "days": "3" }"#; + let result = parser + .parse_complete(&format!( + "Let me check.{} trailing text", + tool_section(&[v31_tool_call("get_weather", arguments)]) + )) + .unwrap(); + + assert_eq!(result.normal_text, "Let me check."); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].tool_index, 0); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn deepseek_v31_does_not_validate_or_normalize_arguments() { + let mut parser = DeepSeekV31ToolParser::new(&test_tools()); + let arguments = r#"{"location":"Tokyo",}"#; + let result = parser + .parse_complete(&tool_section(&[v31_tool_call("get_weather", arguments)])) + .unwrap(); + + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn deepseek_v31_streaming_emits_argument_deltas() { + let mut parser = DeepSeekV31ToolParser::new(&test_tools()); + let chunks = [ + TOOL_CALLS_START, + TOOL_CALL_START, + "get_weather", + TOOL_CALL_SEPARATOR, + "{\"location\":", + "\"Beijing\"", + "}", + TOOL_CALL_END, + TOOL_CALLS_END, + ]; + + let mut result = ToolParseResult::default(); + let mut observed_arguments = Vec::new(); + for chunk in chunks { + let next = parser.push(chunk).unwrap(); + observed_arguments.extend( + next.calls + .iter() + .filter(|call| call.name.is_none()) + .map(|call| call.arguments.clone()), + ); + result.append(next); + } + result.append(parser.finish().unwrap()); + + assert_eq!(observed_arguments, ["{\"location\":", "\"Beijing\"", "}"]); + assert_eq!( + result.coalesce_calls().calls[0].arguments, + r#"{"location":"Beijing"}"# + ); + } + + #[test] + fn deepseek_v31_streaming_handles_split_markers() { + let input = format!( + "hello {}", + tool_section(&[v31_tool_call("get_weather", r#"{"location":"Tokyo"}"#)]) + ); + let chunks = split_by_chars(&input, 5); + let mut parser = DeepSeekV31ToolParser::new(&test_tools()); + + let result = collect_stream(&mut parser, &chunks); + + assert_eq!(result.normal_text, "hello "); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].arguments, r#"{"location":"Tokyo"}"#); + } + + #[test] + fn deepseek_v31_keeps_end_marker_literal_inside_json_string() { + let mut parser = DeepSeekV31ToolParser::new(&test_tools()); + let arguments = format!(r#"{{"text":"literal {TOOL_CALL_END} inside"}}"#); + let input = tool_section(&[v31_tool_call("echo", &arguments)]); + + let result = parser.parse_complete(&input).unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn deepseek_v31_streaming_extracts_multiple_tool_calls() { + let input = tool_section(&[ + v31_tool_call("get_weather", r#"{"location":"Shanghai"}"#), + v31_tool_call("add", r#"{"x":1,"y":2}"#), + ]); + let chunks = split_by_chars(&input, 7); + let mut parser = DeepSeekV31ToolParser::new(&test_tools()); + + let result = collect_stream(&mut parser, &chunks); + + expect![[r#" + ToolParseResult { + normal_text: "", + calls: [ + ToolCallDelta { + tool_index: 0, + name: Some( + "get_weather", + ), + arguments: "{\"location\":\"Shanghai\"}", + }, + ToolCallDelta { + tool_index: 1, + name: Some( + "add", + ), + arguments: "{\"x\":1,\"y\":2}", + }, + ], + } + "#]] + .assert_debug_eq(&result); + } + + #[test] + fn deepseek_v31_streaming_drops_eos_after_complete_tool_calls() { + let input = format!( + "{}<|end▁of▁sentence|>", + tool_section(&[v31_tool_call("get_weather", r#"{"location":"Tokyo"}"#)]) + ); + let mut parser = DeepSeekV31ToolParser::new(&test_tools()); + + let result = collect_stream(&mut parser, &[&input]); + + assert!(result.normal_text.is_empty()); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].arguments, r#"{"location":"Tokyo"}"#); + } + + #[test] + fn deepseek_v31_finish_fails_incomplete_tool_call() { + let mut parser = DeepSeekV31ToolParser::new(&test_tools()); + parser + .push(&format!( + "{TOOL_CALLS_START}{TOOL_CALL_START}get_weather{TOOL_CALL_SEPARATOR}{{\"location\"" + )) + .unwrap(); + + let error = parser.finish().unwrap_err(); + + expect!["tool parser parsing failed: incomplete DeepSeek V3.1 tool call"] + .assert_eq(&error.to_report_string()); + } + + #[test] + fn deepseek_v31_malformed_empty_name_fails_fast() { + let mut parser = DeepSeekV31ToolParser::new(&test_tools()); + let input = format!("{TOOL_CALLS_START}{TOOL_CALL_START}{TOOL_CALL_SEPARATOR}{{}}"); + + let error = parser.push(&input).unwrap_err(); + + expect!["tool parser parsing failed: "].assert_eq(&error.to_report_string()); + } +} diff --git a/rust/src/tool-parser/src/deepseek_json/mod.rs b/rust/src/tool-parser/src/deepseek_json/mod.rs new file mode 100644 index 00000000000..a24d47ae468 --- /dev/null +++ b/rust/src/tool-parser/src/deepseek_json/mod.rs @@ -0,0 +1,308 @@ +mod deepseek_v3; +mod deepseek_v31; + +pub use deepseek_v3::DeepSeekV3ToolParser; +pub use deepseek_v31::DeepSeekV31ToolParser; +use winnow::ascii::multispace0 as ws0; +use winnow::combinator::{alt, seq}; +use winnow::prelude::*; +use winnow::stream::Partial; +use winnow::token::{literal, rest, take_until}; + +use super::utils::{JsonObjectScanState, parse_buffered_event, safe_text_len, take_json_object}; +use super::{Result, ToolCallDelta, ToolParseResult}; + +pub(super) const TOOL_CALLS_START: &str = "<|tool▁calls▁begin|>"; +pub(super) const TOOL_CALLS_END: &str = "<|tool▁calls▁end|>"; +pub(super) const TOOL_CALL_START: &str = "<|tool▁call▁begin|>"; +pub(super) const TOOL_CALL_END: &str = "<|tool▁call▁end|>"; +pub(super) const TOOL_CALL_SEPARATOR: &str = "<|tool▁sep|>"; +pub(super) const V3_JSON_START: &str = "\n```json\n"; +pub(super) const V3_ARGUMENT_END: &str = "\n```<|tool▁call▁end|>"; + +type DeepSeekJsonInput<'i> = Partial<&'i str>; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum DeepSeekJsonFormat { + V3, + V31, +} + +impl DeepSeekJsonFormat { + /// Return the parser name used in diagnostics. + const fn parser_name(self) -> &'static str { + match self { + Self::V3 => "DeepSeek V3", + Self::V31 => "DeepSeek V3.1", + } + } + + /// Return the marker that closes the raw JSON arguments payload. + const fn argument_end_marker(self) -> &'static str { + match self { + Self::V3 => V3_ARGUMENT_END, + Self::V31 => TOOL_CALL_END, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum DeepSeekJsonMode { + Text, + ToolBlock, + Header, + Arguments { json_scan: JsonObjectScanState }, + Done, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum DeepSeekJsonEvent { + Text { len: usize }, + ToolCallsStart, + ToolCallStart, + ToolCallHeader { function_name: String }, + Arguments { len: usize }, + ToolCallEnd, + ToolCallsEnd, + IgnoredRest, +} + +/// Tool parser core for DeepSeek JSON-argument tool calls. +struct DeepSeekJsonToolParser { + buffer: String, + mode: DeepSeekJsonMode, + active_tool_index: Option, + emitted_tool_count: usize, + format: DeepSeekJsonFormat, +} + +impl DeepSeekJsonToolParser { + /// Create a parser for one DeepSeek JSON-argument format. + fn new(format: DeepSeekJsonFormat) -> Self { + Self { + buffer: String::new(), + mode: DeepSeekJsonMode::Text, + active_tool_index: None, + emitted_tool_count: 0, + format, + } + } + + /// Apply one parsed DeepSeek JSON event to parser state and output. + fn apply_event( + &mut self, + event: DeepSeekJsonEvent, + result: &mut ToolParseResult, + ) -> Result<()> { + match event { + DeepSeekJsonEvent::Text { len: consumed_len } => { + result.normal_text.push_str(&self.buffer[..consumed_len]); + } + DeepSeekJsonEvent::ToolCallsStart => self.mode = DeepSeekJsonMode::ToolBlock, + DeepSeekJsonEvent::ToolCallStart => self.mode = DeepSeekJsonMode::Header, + DeepSeekJsonEvent::ToolCallHeader { function_name } => { + let tool_index = self.emitted_tool_count; + self.emitted_tool_count += 1; + self.active_tool_index = Some(tool_index); + self.mode = DeepSeekJsonMode::Arguments { + json_scan: JsonObjectScanState::default(), + }; + result.calls.push(ToolCallDelta { + tool_index, + name: Some(function_name), + arguments: String::new(), + }); + } + DeepSeekJsonEvent::Arguments { len: consumed_len } => { + let Some(tool_index) = self.active_tool_index else { + return Err(parsing_failed!( + "{} arguments without an active tool call", + self.format.parser_name() + )); + }; + result.calls.push(ToolCallDelta { + tool_index, + name: None, + arguments: self.buffer[..consumed_len].to_string(), + }); + } + DeepSeekJsonEvent::ToolCallEnd => { + self.active_tool_index = None; + self.mode = DeepSeekJsonMode::ToolBlock; + } + DeepSeekJsonEvent::ToolCallsEnd => { + self.active_tool_index = None; + self.mode = DeepSeekJsonMode::Done; + } + DeepSeekJsonEvent::IgnoredRest => {} + } + Ok(()) + } + + /// Push one decoded text chunk through the DeepSeek JSON parser. + fn push(&mut self, chunk: &str) -> Result { + self.buffer.push_str(chunk); + let mut result = ToolParseResult::default(); + + while let Some((event, consumed_len)) = parse_buffered_event(&self.buffer, |input| { + parse_next_deepseek_json_event(input, &mut self.mode, self.format) + })? { + self.apply_event(event, &mut result)?; + self.buffer.drain(..consumed_len); + } + + Ok(result) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + let mut result = ToolParseResult::default(); + match &self.mode { + DeepSeekJsonMode::Text => result.normal_text.push_str(&self.buffer), + DeepSeekJsonMode::ToolBlock | DeepSeekJsonMode::Done => {} + DeepSeekJsonMode::Header | DeepSeekJsonMode::Arguments { .. } => { + return Err(parsing_failed!( + "incomplete {} tool call", + self.format.parser_name() + )); + } + } + self.reset(); + Ok(result) + } + + /// Reset all streaming state. + fn reset(&mut self) { + self.buffer.clear(); + self.mode = DeepSeekJsonMode::Text; + self.active_tool_index = None; + self.emitted_tool_count = 0; + } +} + +/// Parse a DeepSeek JSON event for the current parser mode. +fn parse_next_deepseek_json_event( + input: &mut DeepSeekJsonInput<'_>, + mode: &mut DeepSeekJsonMode, + format: DeepSeekJsonFormat, +) -> ModalResult { + match mode { + DeepSeekJsonMode::Text => parse_text_event(input), + DeepSeekJsonMode::ToolBlock => parse_tool_block_event(input), + DeepSeekJsonMode::Header => tool_call_header_event(input, format), + DeepSeekJsonMode::Arguments { json_scan } => { + parse_arguments_event(input, format, json_scan) + } + DeepSeekJsonMode::Done => ignored_rest_event(input), + } +} + +/// Parse a text-mode DeepSeek JSON event. +fn parse_text_event(input: &mut DeepSeekJsonInput<'_>) -> ModalResult { + alt((tool_calls_start_event, safe_text_event)).parse_next(input) +} + +/// Parse one event inside the DeepSeek tool-calls section. +fn parse_tool_block_event(input: &mut DeepSeekJsonInput<'_>) -> ModalResult { + ws0.void().parse_next(input)?; + alt((tool_calls_end_event, tool_call_start_event)).parse_next(input) +} + +/// Parse one event inside a DeepSeek tool-call arguments payload. +fn parse_arguments_event( + input: &mut DeepSeekJsonInput<'_>, + format: DeepSeekJsonFormat, + json_scan: &mut JsonObjectScanState, +) -> ModalResult { + if json_scan.complete() { + tool_call_end_event(input, format) + } else { + argument_delta_event(input, json_scan) + } +} + +/// Parse a DeepSeek tool-calls start marker. +fn tool_calls_start_event(input: &mut DeepSeekJsonInput<'_>) -> ModalResult { + literal(TOOL_CALLS_START) + .value(DeepSeekJsonEvent::ToolCallsStart) + .parse_next(input) +} + +/// Parse a DeepSeek tool-calls end marker. +fn tool_calls_end_event(input: &mut DeepSeekJsonInput<'_>) -> ModalResult { + literal(TOOL_CALLS_END).value(DeepSeekJsonEvent::ToolCallsEnd).parse_next(input) +} + +/// Parse a DeepSeek tool-call start marker. +fn tool_call_start_event(input: &mut DeepSeekJsonInput<'_>) -> ModalResult { + literal(TOOL_CALL_START) + .value(DeepSeekJsonEvent::ToolCallStart) + .parse_next(input) +} + +/// Parse a DeepSeek tool-call end marker. +fn tool_call_end_event( + input: &mut DeepSeekJsonInput<'_>, + format: DeepSeekJsonFormat, +) -> ModalResult { + literal(format.argument_end_marker()) + .value(DeepSeekJsonEvent::ToolCallEnd) + .parse_next(input) +} + +/// Parse a DeepSeek tool-call header before the JSON arguments payload. +fn tool_call_header_event( + input: &mut DeepSeekJsonInput<'_>, + format: DeepSeekJsonFormat, +) -> ModalResult { + match format { + DeepSeekJsonFormat::V3 => v3_tool_call_header_event(input), + DeepSeekJsonFormat::V31 => v31_tool_call_header_event(input), + } +} + +/// Parse a DeepSeek V3 tool-call header. +fn v3_tool_call_header_event(input: &mut DeepSeekJsonInput<'_>) -> ModalResult { + let name = seq!( + _: literal("function"), + _: literal(TOOL_CALL_SEPARATOR), + take_until(1.., V3_JSON_START), + _: literal(V3_JSON_START), + ) + .parse_next(input)?; + + Ok(DeepSeekJsonEvent::ToolCallHeader { + function_name: name.0.trim().to_string(), + }) +} + +/// Parse a DeepSeek V3.1 tool-call header. +fn v31_tool_call_header_event(input: &mut DeepSeekJsonInput<'_>) -> ModalResult { + let (name, _) = ( + take_until(1.., TOOL_CALL_SEPARATOR), + literal(TOOL_CALL_SEPARATOR), + ) + .parse_next(input)?; + + Ok(DeepSeekJsonEvent::ToolCallHeader { + function_name: name.trim().to_string(), + }) +} + +/// Parse a DeepSeek raw JSON arguments delta. +fn argument_delta_event( + input: &mut DeepSeekJsonInput<'_>, + json_scan: &mut JsonObjectScanState, +) -> ModalResult { + take_json_object(input, json_scan).map(|len| DeepSeekJsonEvent::Arguments { len }) +} + +/// Parse a safe text run before the next DeepSeek tool-calls section. +fn safe_text_event(input: &mut DeepSeekJsonInput<'_>) -> ModalResult { + safe_text_len(input, TOOL_CALLS_START).map(|len| DeepSeekJsonEvent::Text { len }) +} + +/// Parse ignored rest after the DeepSeek tool-calls section ends. +fn ignored_rest_event(input: &mut DeepSeekJsonInput<'_>) -> ModalResult { + rest.value(DeepSeekJsonEvent::IgnoredRest).parse_next(input) +} diff --git a/rust/src/tool-parser/src/error.rs b/rust/src/tool-parser/src/error.rs new file mode 100644 index 00000000000..0ac4a02c658 --- /dev/null +++ b/rust/src/tool-parser/src/error.rs @@ -0,0 +1,13 @@ +use thiserror::Error; +use thiserror_ext::Macro; + +/// Result alias for tool parser operations. +pub type Result = std::result::Result; + +/// Errors produced while creating or running tool parsers. +#[derive(Debug, Error, Macro)] +#[thiserror_ext(macro(path = "crate::error"))] +pub enum ToolParserError { + #[error("tool parser parsing failed: {message}")] + ParsingFailed { message: String }, +} diff --git a/rust/src/tool-parser/src/gemma4.rs b/rust/src/tool-parser/src/gemma4.rs new file mode 100644 index 00000000000..2eb3608a5bb --- /dev/null +++ b/rust/src/tool-parser/src/gemma4.rs @@ -0,0 +1,653 @@ +use serde_json::{Map, Number, Value}; +use winnow::ascii::multispace0 as ws0; +use winnow::combinator::{alt, delimited, opt, separated, seq, terminated}; +use winnow::error::{ContextError, ErrMode, ModalResult}; +use winnow::prelude::*; +use winnow::stream::Partial; +use winnow::token::{literal, take_till, take_until}; + +use super::utils::{parse_buffered_event, safe_text_len}; +use super::{Result, ToolCallDelta, ToolParseResult, ToolParser}; +use crate::Tool; + +const TOOL_CALL_START: &str = "<|tool_call>"; +const TOOL_CALL_END: &str = ""; +const STRING_DELIM: &str = "<|\"|>"; +const CALL_PREFIX: &str = "call:"; + +type Gemma4Input<'i> = Partial<&'i str>; + +#[derive(Debug, Clone, PartialEq)] +enum Gemma4Event { + Text { + len: usize, + }, + ToolCall { + name: String, + args: Map, + }, +} + +/// Tool parser for Google Gemma4 models. +/// +/// Original Python implementation: +/// +/// +/// Handles the Gemma4 function call format: +/// +/// `<|tool_call>call:func_name{key:<|"|>value<|"|>}` +/// +/// Arguments are emitted only after a full Gemma4 tool call is parsed. +pub struct Gemma4ToolParser { + buffer: String, + emitted_tool_count: usize, +} + +impl Gemma4ToolParser { + fn new(_tools: &[Tool]) -> Self { + Self { + buffer: String::new(), + emitted_tool_count: 0, + } + } + + fn apply_event(&mut self, event: Gemma4Event, result: &mut ToolParseResult) -> Result<()> { + match event { + Gemma4Event::Text { len: consumed_len } => { + result.normal_text.push_str(&self.buffer[..consumed_len]); + } + Gemma4Event::ToolCall { name, args } => { + let arguments = serde_json::to_string(&args) + .map_err(|error| parsing_failed!("failed to serialize arguments: {}", error))?; + + result.calls.push(ToolCallDelta { + tool_index: self.emitted_tool_count, + name: Some(name), + arguments, + }); + self.emitted_tool_count += 1; + } + } + Ok(()) + } + + fn reset(&mut self) { + self.buffer.clear(); + self.emitted_tool_count = 0; + } +} + +impl ToolParser for Gemma4ToolParser { + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tools))) + } + + fn preserve_special_tokens(&self) -> bool { + true + } + + fn push(&mut self, chunk: &str) -> Result { + self.buffer.push_str(chunk); + let mut result = ToolParseResult::default(); + + while let Some((event, consumed_len)) = + parse_buffered_event(&self.buffer, parse_next_gemma4_event)? + { + self.apply_event(event, &mut result)?; + self.buffer.drain(..consumed_len); + } + + Ok(result) + } + + fn finish(&mut self) -> Result { + let mut result = ToolParseResult::default(); + + if !self.buffer.is_empty() { + if self.buffer.starts_with(TOOL_CALL_START) { + self.reset(); + return Err(parsing_failed!("incomplete Gemma4 tool call")); + } + result.normal_text.push_str(&self.buffer); + } + + self.reset(); + Ok(result) + } +} + +/// Parse one Gemma4 event from buffered streaming input. +fn parse_next_gemma4_event(input: &mut Gemma4Input<'_>) -> ModalResult { + alt((tool_call_event, safe_text_event)).parse_next(input) +} + +/// Parse a complete Gemma4 tool call. +// TODO: incremental parsing arguments to reduce scanning from O(n^2) to O(n). +fn tool_call_event(input: &mut Gemma4Input<'_>) -> ModalResult { + let (name, args) = seq!( + _: literal(TOOL_CALL_START), + _: literal(CALL_PREFIX), + gemma4_tool_name, + _: literal("{"), + gemma4_args, + _: literal("}"), + _: literal(TOOL_CALL_END), + ) + .parse_next(input)?; + + Ok(Gemma4Event::ToolCall { name, args }) +} + +/// Parse a Gemma4 tool name. +fn gemma4_tool_name(input: &mut Gemma4Input<'_>) -> ModalResult { + let name = take_until(1.., "{").parse_next(input)?.trim(); + if name.is_empty() { + return Err(ErrMode::Cut(ContextError::new())); + } + Ok(name.to_string()) +} + +/// Parse a safe text run before the next Gemma4 marker. +fn safe_text_event(input: &mut Gemma4Input<'_>) -> ModalResult { + safe_text_len(input, TOOL_CALL_START).map(|len| Gemma4Event::Text { len }) +} + +/// Parse Gemma4's custom key-value argument object content. +fn gemma4_args(input: &mut Gemma4Input<'_>) -> ModalResult> { + let pairs: Vec<(String, Value)> = delimited( + ws0, + terminated( + separated(0.., gemma4_pair, comma_separator), + opt(comma_separator), + ), + ws0, + ) + .parse_next(input)?; + Ok(pairs.into_iter().collect()) +} + +/// Parse a Gemma4 key-value pair. +fn gemma4_pair(input: &mut Gemma4Input<'_>) -> ModalResult<(String, Value)> { + let (key, value) = seq!( + _: ws0, + gemma4_key, + _: ws0, + _: literal(":"), + _: ws0, + gemma4_value, + ) + .parse_next(input)?; + Ok((key, value)) +} + +/// Parse a Gemma4 bare key. +fn gemma4_key(input: &mut Gemma4Input<'_>) -> ModalResult { + let key = take_till(1.., |char: char| char == ':').parse_next(input)?.trim(); + if key.is_empty() { + return Err(ErrMode::Cut(ContextError::new())); + } + Ok(key.to_string()) +} + +/// Parse a Gemma4 value. +fn gemma4_value(input: &mut Gemma4Input<'_>) -> ModalResult { + alt(( + gemma4_string.map(|value: &str| Value::String(value.to_string())), + gemma4_object.map(Value::Object), + gemma4_array_value.map(Value::Array), + gemma4_bare_value, + )) + .parse_next(input) +} + +/// Parse a Gemma4 string delimited by `<|"|>`. +fn gemma4_string<'i>(input: &mut Gemma4Input<'i>) -> ModalResult<&'i str> { + delimited( + literal(STRING_DELIM), + take_until(0.., STRING_DELIM), + literal(STRING_DELIM), + ) + .parse_next(input) +} + +/// Parse a nested Gemma4 object. +fn gemma4_object(input: &mut Gemma4Input<'_>) -> ModalResult> { + delimited(literal("{"), gemma4_args, literal("}")).parse_next(input) +} + +/// Parse a Gemma4 array value. +fn gemma4_array_value(input: &mut Gemma4Input<'_>) -> ModalResult> { + delimited(literal("["), gemma4_array_content, literal("]")).parse_next(input) +} + +/// Parse Gemma4 array content. +fn gemma4_array_content(input: &mut Gemma4Input<'_>) -> ModalResult> { + delimited( + ws0, + terminated( + separated(0.., gemma4_value, comma_separator), + opt(comma_separator), + ), + ws0, + ) + .parse_next(input) +} + +/// Parse a Gemma4 bare scalar. +fn gemma4_bare_value(input: &mut Gemma4Input<'_>) -> ModalResult { + take_till(1.., |char: char| matches!(char, ',' | '}' | ']')) + .map(parse_gemma4_scalar) + .parse_next(input) +} + +/// Parse a Gemma4 comma separator. +fn comma_separator(input: &mut Gemma4Input<'_>) -> ModalResult<()> { + delimited(ws0, literal(","), ws0).void().parse_next(input) +} + +fn parse_gemma4_scalar(value: &str) -> Value { + let value = value.trim(); + if value.is_empty() { + return Value::String(String::new()); + } + if value == "true" { + return Value::Bool(true); + } + if value == "false" { + return Value::Bool(false); + } + if matches!(value, "null" | "none" | "nil" | "NULL" | "None" | "NIL") { + return Value::Null; + } + if value.contains('.') { + if let Ok(parsed) = value.parse::() + && let Some(number) = Number::from_f64(parsed) + { + return Value::Number(number); + } + } else if let Ok(parsed) = value.parse::() { + return Value::Number(Number::from(parsed)); + } + + Value::String(value.to_string()) +} + +#[cfg(test)] +mod tests { + use serde_json::{Value, json}; + use thiserror_ext::AsReport; + use winnow::combinator::{eof, terminated}; + use winnow::error::ErrMode; + use winnow::prelude::*; + use winnow::stream::Partial; + + use super::{ + Gemma4ToolParser, ToolCallDelta, ToolParseResult, ToolParser, gemma4_args, + gemma4_array_content, + }; + use crate::Tool; + + fn parse_gemma4_args(args: &str) -> super::Result> { + let mut input = Partial::new(args); + let _ = input.complete(); + match terminated(gemma4_args, eof).parse_next(&mut input) { + Ok(value) => Ok(value), + Err(ErrMode::Incomplete(_)) => Err(parsing_failed!("incomplete Gemma4 arguments")), + Err(ErrMode::Backtrack(error) | ErrMode::Cut(error)) => { + Err(parsing_failed!("{}", error)) + } + } + } + + fn parse_gemma4_array(array: &str) -> super::Result> { + let mut input = Partial::new(array); + let _ = input.complete(); + match terminated(gemma4_array_content, eof).parse_next(&mut input) { + Ok(value) => Ok(value), + Err(ErrMode::Incomplete(_)) => Err(parsing_failed!("incomplete Gemma4 array")), + Err(ErrMode::Backtrack(error) | ErrMode::Cut(error)) => { + Err(parsing_failed!("{}", error)) + } + } + } + + fn test_tools() -> Vec { + vec![ + Tool { + name: "get_weather".to_string(), + description: None, + parameters: json!({ "type": "object" }), + strict: None, + }, + Tool { + name: "get_time".to_string(), + description: None, + parameters: json!({ "type": "object" }), + strict: None, + }, + Tool { + name: "write_file".to_string(), + description: None, + parameters: json!({ "type": "object" }), + strict: None, + }, + Tool { + name: "Edit".to_string(), + description: None, + parameters: json!({ "type": "object" }), + strict: None, + }, + Tool { + name: "search".to_string(), + description: None, + parameters: json!({ "type": "object" }), + strict: None, + }, + Tool { + name: "set".to_string(), + description: None, + parameters: json!({ "type": "object" }), + strict: None, + }, + Tool { + name: "get_status".to_string(), + description: None, + parameters: json!({ "type": "object" }), + strict: None, + }, + Tool { + name: "todowrite".to_string(), + description: None, + parameters: json!({ "type": "object" }), + strict: None, + }, + ] + } + + fn collect_stream(chunks: &[&str]) -> ToolParseResult { + let mut parser = Gemma4ToolParser::new(&test_tools()); + let mut result = ToolParseResult::default(); + for chunk in chunks { + result.append(parser.push(chunk).unwrap()); + } + result.append(parser.finish().unwrap()); + result.coalesce_calls() + } + + fn first_call(result: &ToolParseResult) -> &ToolCallDelta { + result.calls.first().expect("expected one tool call") + } + + #[test] + fn gemma4_parse_args_handles_scalars_and_nested_values() { + let parsed = parse_gemma4_args( + "name:<|\"|>test<|\"|>,count:42,active:true,score:114.514,nested:{inner:<|\"|>value<|\"|>},items:[<|\"|>a<|\"|>,<|\"|>b<|\"|>]", + ) + .unwrap(); + + assert_eq!( + Value::Object(parsed), + json!({ + "name": "test", + "count": 42, + "active": true, + "score": 114.514, + "nested": { "inner": "value" }, + "items": ["a", "b"], + }) + ); + } + + #[test] + fn gemma4_parse_args_handles_empty_arguments() { + let parsed = parse_gemma4_args("").unwrap(); + assert_eq!(Value::Object(parsed), json!({})); + } + + #[test] + fn gemma4_parse_array_handles_bare_values() { + let parsed = parse_gemma4_array("42,true,114.514").unwrap(); + assert_eq!(Value::Array(parsed), json!([42, true, 114.514])); + } + + #[test] + fn gemma4_parse_complete_extracts_single_tool_call() { + let mut parser = Gemma4ToolParser::new(&test_tools()); + let result = parser + .parse_complete("<|tool_call>call:get_weather{location:<|\"|>London<|\"|>}") + .unwrap(); + + assert!(result.normal_text.is_empty()); + assert_eq!(result.calls.len(), 1); + assert_eq!(first_call(&result).name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&first_call(&result).arguments).unwrap(), + json!({ "location": "London" }) + ); + } + + #[test] + fn gemma4_parse_complete_rejects_incomplete_tool_call() { + let mut parser = Gemma4ToolParser::new(&test_tools()); + let error = parser + .parse_complete("<|tool_call>call:get_weather{location:<|\"|>London") + .unwrap_err(); + + assert!(error.to_report_string().contains("incomplete Gemma4 tool call")); + } + + #[test] + fn gemma4_streaming_basic_single_tool_call() { + let result = collect_stream(&[ + "<|tool_call>", + "call:get_weather{", + "location:<|\"|>Paris", + ", France", + "<|\"|>}", + "", + ]); + + assert!(result.normal_text.is_empty()); + assert_eq!(first_call(&result).name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&first_call(&result).arguments).unwrap(), + json!({ "location": "Paris, France" }) + ); + } + + #[test] + fn gemma4_streaming_text_before_and_after_tool_call() { + let result = collect_stream(&[ + "Let me check ", + "the weather. ", + "<|tool_call>", + "call:get_weather{", + "location:<|\"|>London<|\"|>}", + "<", + "div>", + ]); + + assert_eq!(result.normal_text, "Let me check the weather.
"); + assert_eq!(first_call(&result).name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&first_call(&result).arguments).unwrap(), + json!({ "location": "London" }) + ); + } + + #[test] + fn gemma4_streaming_waits_for_complete_tool_call() { + let mut parser = Gemma4ToolParser::new(&test_tools()); + let mut result = ToolParseResult::default(); + + for chunk in [ + "<|tool_call>", + "call:get_weather{", + "location:<|\"|>Paris<|\"|>}", + ] { + result.append(parser.push(chunk).unwrap()); + assert!(result.calls.is_empty()); + } + + result.append(parser.push("").unwrap()); + let result = result.coalesce_calls(); + + assert_eq!(first_call(&result).name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&first_call(&result).arguments).unwrap(), + json!({ "location": "Paris" }) + ); + } + + #[test] + fn gemma4_streaming_handles_boolean_split_across_chunks() { + let result = collect_stream(&[ + "<|tool_call>", + "call:search{input:{all:tru", + "e}}", + "", + ]); + + assert_eq!(first_call(&result).name.as_deref(), Some("search")); + assert_eq!( + serde_json::from_str::(&first_call(&result).arguments).unwrap(), + json!({ "input": { "all": true } }) + ); + } + + #[test] + fn gemma4_streaming_handles_false_split_across_chunks() { + let result = collect_stream(&["<|tool_call>", "call:set{flag:fals", "e}", ""]); + + assert_eq!(first_call(&result).name.as_deref(), Some("set")); + assert_eq!( + serde_json::from_str::(&first_call(&result).arguments).unwrap(), + json!({ "flag": false }) + ); + } + + #[test] + fn gemma4_streaming_handles_number_split_across_chunks() { + let result = collect_stream(&["<|tool_call>", "call:set{count:4", "2}", ""]); + + assert_eq!(first_call(&result).name.as_deref(), Some("set")); + assert_eq!( + serde_json::from_str::(&first_call(&result).arguments).unwrap(), + json!({ "count": 42 }) + ); + } + + #[test] + fn gemma4_streaming_handles_split_string_delimiter() { + let result = collect_stream(&[ + "<|tool_call>", + "call:todowrite{", + "content:<|\"|>Buy milk<|", + "\"|>}", + "", + ]); + + assert_eq!(first_call(&result).name.as_deref(), Some("todowrite")); + assert_eq!( + serde_json::from_str::(&first_call(&result).arguments).unwrap(), + json!({ "content": "Buy milk" }) + ); + assert!(!first_call(&result).arguments.contains("<|")); + } + + #[test] + fn gemma4_streaming_handles_end_marker_literal_inside_string() { + let result = collect_stream(&[ + "<|tool_call>", + "call:todowrite{", + "content:<|\"|>literal } inside", + "<|\"|>}", + "", + ]); + + assert_eq!(first_call(&result).name.as_deref(), Some("todowrite")); + assert_eq!( + serde_json::from_str::(&first_call(&result).arguments).unwrap(), + json!({ "content": "literal } inside" }) + ); + } + + #[test] + fn gemma4_streaming_handles_html_argument_without_duplication() { + let result = collect_stream(&[ + "<|tool_call>", + "call:write_file{", + "path:<|\"|>index.html<|\"|>,", + "content:<|\"|>\n<", + "html lang=\"zh-CN\">\n<", + "head>\n <", + "meta charset=\"UTF-8\">\n <", + "meta name=\"viewport\" content=\"width=device-width\">\n", + "<|\"|>}", + "", + ]); + + assert_eq!(first_call(&result).name.as_deref(), Some("write_file")); + assert_eq!( + serde_json::from_str::(&first_call(&result).arguments).unwrap(), + json!({ + "path": "index.html", + "content": "\n\n\n \n \n", + }) + ); + } + + #[test] + fn gemma4_streaming_trailing_bare_bool_is_not_duplicated() { + let result = collect_stream(&[ + "<|tool_call>", + "call:Edit{", + "file_path:<|\"|>src/env.py<|\"|>,", + "old_string:<|\"|>old_val<|\"|>,", + "new_string:<|\"|>new_val<|\"|>,", + "replace_all:", + "false}", + "", + ]); + + assert_eq!(first_call(&result).name.as_deref(), Some("Edit")); + assert_eq!( + serde_json::from_str::(&first_call(&result).arguments).unwrap(), + json!({ + "file_path": "src/env.py", + "old_string": "old_val", + "new_string": "new_val", + "replace_all": false, + }) + ); + assert_eq!( + first_call(&result).arguments.matches("replace_all").count(), + 1 + ); + } + + #[test] + fn gemma4_finish_flushes_partial_start_marker_as_text() { + let mut parser = Gemma4ToolParser::new(&test_tools()); + let mut result = parser.push("<").unwrap(); + result.append(parser.finish().unwrap()); + + assert_eq!(result.normal_text, "<"); + assert!(result.calls.is_empty()); + } + + #[test] + fn gemma4_finish_rejects_complete_args_without_end_marker() { + let mut parser = Gemma4ToolParser::new(&test_tools()); + for chunk in ["<|tool_call>", "call:get_status{}"] { + parser.push(chunk).unwrap(); + } + + let error = parser.finish().unwrap_err(); + + assert!(error.to_report_string().contains("incomplete Gemma4 tool call")); + } +} diff --git a/rust/src/tool-parser/src/glm_xml/glm45_moe.rs b/rust/src/tool-parser/src/glm_xml/glm45_moe.rs new file mode 100644 index 00000000000..b145671429c --- /dev/null +++ b/rust/src/tool-parser/src/glm_xml/glm45_moe.rs @@ -0,0 +1,43 @@ +use super::{GlmXmlToolParser, Separator}; +use crate::{Result, Tool, ToolParseResult, ToolParser}; + +/// Tool parser for GLM-4.5/4.6 MoE XML-style tool calls. +/// +/// Example tool call content: +/// +/// ```text +/// get_weather +/// city +/// Hangzhou +/// +/// ``` +/// +/// Arguments are emitted only after a full `tool_call` block is parsed. +pub struct Glm45MoeToolParser(GlmXmlToolParser); + +impl Glm45MoeToolParser { + /// Create a GLM-4.5/4.6 MoE tool parser. + pub(super) fn new(tools: &[Tool]) -> Self { + Self(GlmXmlToolParser::new(tools, Separator::Newline)) + } +} + +impl ToolParser for Glm45MoeToolParser { + /// Create a boxed GLM-4.5/4.6 MoE tool parser. + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tools))) + } + + /// Push one decoded text chunk through the GLM MoE parser. + fn push(&mut self, chunk: &str) -> Result { + self.0.push(chunk) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + self.0.finish() + } +} diff --git a/rust/src/tool-parser/src/glm_xml/glm47_moe.rs b/rust/src/tool-parser/src/glm_xml/glm47_moe.rs new file mode 100644 index 00000000000..7e7538d7c38 --- /dev/null +++ b/rust/src/tool-parser/src/glm_xml/glm47_moe.rs @@ -0,0 +1,135 @@ +use super::{GlmXmlToolParser, Separator}; +use crate::{Result, Tool, ToolParseResult, ToolParser}; + +/// Tool parser for GLM-4.7 MoE XML-style tool calls. +/// +/// GLM-4.7 reuses the GLM-4.5 parser with a more flexible function-name +/// separator, so the name may be followed by whitespace, a newline, or the +/// first `` tag directly. +pub struct Glm47MoeToolParser(GlmXmlToolParser); + +impl Glm47MoeToolParser { + fn new(tools: &[Tool]) -> Self { + Self(GlmXmlToolParser::new(tools, Separator::Flexible)) + } +} + +impl ToolParser for Glm47MoeToolParser { + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tools))) + } + + fn push(&mut self, chunk: &str) -> Result { + self.0.push(chunk) + } + + fn finish(&mut self) -> Result { + self.0.finish() + } +} + +#[cfg(test)] +mod tests { + use serde_json::{Value, json}; + + use super::{Glm47MoeToolParser, ToolParser}; + use crate::test_utils::{collect_stream, split_by_chars, test_tools}; + + fn glm47_tool_call(function_name: &str, params: &[(&str, &str)]) -> String { + let params = params + .iter() + .map(|(name, value)| format!("{name}{value}")) + .collect::>() + .join(""); + format!("{function_name}{params}") + } + + #[test] + fn glm47_parse_complete_extracts_single_tool_call() { + let mut parser = Glm47MoeToolParser::new(&test_tools()); + let output = format!( + "Let me search for that.\n{}", + glm47_tool_call( + "get_weather", + &[("city", "Beijing"), ("date", "2024-12-25")] + ) + ); + + let result = parser.parse_complete(&output).unwrap(); + + assert_eq!(result.normal_text, "Let me search for that.\n"); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({"city": "Beijing", "date": "2024-12-25"}) + ); + } + + #[test] + fn glm47_streaming_extracts_multiple_tool_calls() { + let mut parser = Glm47MoeToolParser::new(&test_tools()); + let output = format!( + "{}{}", + glm47_tool_call("get_weather", &[("city", "Shanghai")]), + glm47_tool_call("add", &[("x", "1"), ("y", "2")]) + ); + + let chunks = split_by_chars(&output, 7); + let result = collect_stream(&mut parser, &chunks); + + assert_eq!(result.normal_text, ""); + assert_eq!(result.calls.len(), 2); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!(result.calls[1].name.as_deref(), Some("add")); + assert_eq!( + serde_json::from_str::(&result.calls[1].arguments).unwrap(), + json!({"x": 1, "y": 2}) + ); + } + + #[test] + fn glm47_parse_complete_converts_schema_types() { + let mut parser = Glm47MoeToolParser::new(&test_tools()); + let result = parser + .parse_complete(&glm47_tool_call( + "convert", + &[ + ("whole", "42"), + ("flag", "true"), + ("payload", r#"{"nested":{"key":"value"}}"#), + ("items", "[1, 2, 3]"), + ("empty", ""), + ], + )) + .unwrap(); + + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "whole": 42, + "flag": true, + "payload": {"nested": {"key": "value"}}, + "items": [1, 2, 3], + "empty": "" + }) + ); + } + + #[test] + fn glm47_parse_complete_extracts_zero_argument_call() { + let mut parser = Glm47MoeToolParser::new(&test_tools()); + + let result = parser.parse_complete("add").unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("add")); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({}) + ); + } +} diff --git a/rust/src/tool-parser/src/glm_xml/mod.rs b/rust/src/tool-parser/src/glm_xml/mod.rs new file mode 100644 index 00000000000..ccb05e8d85f --- /dev/null +++ b/rust/src/tool-parser/src/glm_xml/mod.rs @@ -0,0 +1,435 @@ +use winnow::ascii::multispace0 as ws0; +use winnow::combinator::{alt, eof, repeat, seq, terminated}; +use winnow::prelude::*; +use winnow::stream::Partial; +use winnow::token::{literal, rest, take_until, take_while}; + +use super::parameters::ToolSchemas; +use super::utils::{parse_buffered_event, safe_text_len, xml_unescape}; +use super::{Result, ToolCallDelta, ToolParseResult}; +use crate::Tool; + +mod glm45_moe; +mod glm47_moe; + +pub use glm45_moe::Glm45MoeToolParser; +pub use glm47_moe::Glm47MoeToolParser; + +const TOOL_CALL_START: &str = ""; +const TOOL_CALL_END: &str = ""; +const ARG_KEY_START: &str = ""; +const ARG_KEY_END: &str = ""; +const ARG_VALUE_START: &str = ""; +const ARG_VALUE_END: &str = ""; + +type GlmInput<'i> = Partial<&'i str>; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum GlmMode { + Text, + ToolCall, + AfterToolCall, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Separator { + /// GLM-4.5/4.6 format: function name must end at a newline before + /// arguments. + Newline, + /// GLM-4.7 format: function name may end at whitespace or directly before + /// ``. + Flexible, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum GlmEvent { + Text { + len: usize, + }, + ToolCallStart, + ToolCall { + name: String, + raw_params: Vec<(String, String)>, + }, + IgnoredRest, +} + +/// Tool parser core for GLM XML-style tool calls. +struct GlmXmlToolParser { + buffer: String, + mode: GlmMode, + emitted_tool_count: usize, + tool_parameters: ToolSchemas, + separator: Separator, +} + +impl GlmXmlToolParser { + /// Create a GLM XML tool parser with a function-name separator. + fn new(tools: &[Tool], separator: Separator) -> Self { + Self { + buffer: String::new(), + mode: GlmMode::Text, + emitted_tool_count: 0, + tool_parameters: ToolSchemas::from_tools(tools), + separator, + } + } + + /// Apply one parsed GLM event to parser state and output. + fn apply_event(&mut self, event: GlmEvent, result: &mut ToolParseResult) -> Result<()> { + match event { + GlmEvent::Text { len: consumed_len } => { + result.normal_text.push_str(&self.buffer[..consumed_len]); + } + GlmEvent::ToolCallStart => self.mode = GlmMode::ToolCall, + GlmEvent::ToolCall { name, raw_params } => { + self.mode = GlmMode::AfterToolCall; + let arguments = self.tool_parameters.convert_params_with_schema(&name, raw_params); + let arguments = serde_json::to_string(&arguments) + .map_err(|error| parsing_failed!("failed to serialize arguments: {}", error))?; + + result.calls.push(ToolCallDelta { + tool_index: self.emitted_tool_count, + name: Some(name), + arguments, + }); + self.emitted_tool_count += 1; + } + GlmEvent::IgnoredRest => {} + } + Ok(()) + } + + /// Reset all streaming state. + fn reset(&mut self) { + self.buffer.clear(); + self.mode = GlmMode::Text; + self.emitted_tool_count = 0; + } + + /// Push one decoded text chunk through the GLM MoE parser. + fn push(&mut self, chunk: &str) -> Result { + self.buffer.push_str(chunk); + let mut result = ToolParseResult::default(); + + while let Some((event, consumed_len)) = parse_buffered_event(&self.buffer, |input| { + parse_next_glm_event(input, self.mode, self.separator) + })? { + self.apply_event(event, &mut result)?; + self.buffer.drain(..consumed_len); + } + + Ok(result) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + let mut result = ToolParseResult::default(); + if !self.buffer.is_empty() { + match self.mode { + GlmMode::Text => result.normal_text.push_str(&self.buffer), + GlmMode::ToolCall => return Err(parsing_failed!("incomplete GLM MoE tool call")), + GlmMode::AfterToolCall => {} + } + } + self.reset(); + Ok(result) + } +} + +/// Parse a GLM event for the current parser mode. +fn parse_next_glm_event( + input: &mut GlmInput<'_>, + mode: GlmMode, + separator: Separator, +) -> ModalResult { + match mode { + GlmMode::Text => parse_text_event(input), + GlmMode::ToolCall => tool_call_event(input, separator), + GlmMode::AfterToolCall => after_tool_call_event(input), + } +} + +/// Parse a text-mode GLM event. +fn parse_text_event(input: &mut GlmInput<'_>) -> ModalResult { + alt((tool_call_start_event, safe_text_event)).parse_next(input) +} + +/// Parse a GLM tool-call start marker. +fn tool_call_start_event(input: &mut GlmInput<'_>) -> ModalResult { + literal(TOOL_CALL_START).value(GlmEvent::ToolCallStart).parse_next(input) +} + +/// Parse a safe text run before the next GLM marker. +fn safe_text_event(input: &mut GlmInput<'_>) -> ModalResult { + safe_text_len(input, TOOL_CALL_START).map(|len| GlmEvent::Text { len }) +} + +/// Parse text after a completed GLM tool call. +fn after_tool_call_event(input: &mut GlmInput<'_>) -> ModalResult { + ws0.void().parse_next(input)?; + alt((tool_call_start_event, ignored_rest_event)).parse_next(input) +} + +/// Parse a trailing rest after GLM tool calls. +fn ignored_rest_event(input: &mut GlmInput<'_>) -> ModalResult { + rest.value(GlmEvent::IgnoredRest).parse_next(input) +} + +/// Parse a complete GLM tool call. +fn tool_call_event(input: &mut GlmInput<'_>, separator: Separator) -> ModalResult { + let (body,) = seq!( + take_until(0.., TOOL_CALL_END), + _: literal(TOOL_CALL_END), + ) + .parse_next(input)?; + + parse_tool_call_body(body, separator) +} + +/// Parse a GLM tool-call body. +fn parse_tool_call_body(body: &str, separator: Separator) -> ModalResult { + let mut input = body; + let (name, raw_params) = match separator { + Separator::Newline => seq!( + _: ws0, + parse_newline_separated_function_name, + parse_parameters, + _: ws0, + _: eof, + ) + .parse_next(&mut input)?, + Separator::Flexible => seq!( + _: ws0, + parse_flexible_function_name, + parse_parameters, + _: ws0, + _: eof, + ) + .parse_next(&mut input)?, + }; + + Ok(GlmEvent::ToolCall { + name: name.to_string(), + raw_params, + }) +} + +/// Parse a GLM-4.5 newline-separated function name. +fn parse_newline_separated_function_name<'i>(input: &mut &'i str) -> ModalResult<&'i str> { + terminated(take_until(1.., "\n"), "\n").map(str::trim).parse_next(input) +} + +/// Parse a GLM-4.7 whitespace-or-tag-separated function name. +fn parse_flexible_function_name<'i>(input: &mut &'i str) -> ModalResult<&'i str> { + terminated( + take_while(1.., |ch: char| !ch.is_whitespace() && ch != '<'), + ws0, + ) + .parse_next(input) +} + +/// Parse GLM argument key-value pairs. +fn parse_parameters(input: &mut &str) -> ModalResult> { + repeat(0.., terminated(parse_parameter, ws0)).parse_next(input) +} + +/// Parse a GLM argument key-value pair. +fn parse_parameter(input: &mut &str) -> ModalResult<(String, String)> { + let (key, value) = seq!( + _: literal(ARG_KEY_START), + take_until(1.., ARG_KEY_END), + _: literal(ARG_KEY_END), + _: ws0, + _: literal(ARG_VALUE_START), + take_until(0.., ARG_VALUE_END).map(str::trim).map(xml_unescape), + _: literal(ARG_VALUE_END), + ) + .parse_next(input)?; + + Ok((key.trim().to_string(), value.into_owned())) +} + +#[cfg(test)] +mod tests { + use serde_json::{Value, json}; + use thiserror_ext::AsReport; + + use super::Glm45MoeToolParser; + use crate::ToolParser; + use crate::test_utils::{collect_stream, split_by_chars, test_tools}; + + fn glm45_tool_call(function_name: &str, params: &[(&str, &str)]) -> String { + let params = params + .iter() + .map(|(name, value)| { + format!("{name}\n{value}") + }) + .collect::>() + .join("\n"); + format!("{function_name}\n{params}\n") + } + + #[test] + fn glm45_parse_complete_without_tool_call_keeps_text() { + let mut parser = Glm45MoeToolParser::new(&test_tools()); + let result = parser.parse_complete("Hello, world!").unwrap(); + + assert_eq!(result.normal_text, "Hello, world!"); + assert!(result.calls.is_empty()); + } + + #[test] + fn glm45_parse_complete_extracts_single_tool_call() { + let mut parser = Glm45MoeToolParser::new(&test_tools()); + let output = format!( + "Let me search for that.\n{}", + glm45_tool_call( + "get_weather", + &[("city", "Beijing"), ("date", "2024-12-25")] + ) + ); + + let result = parser.parse_complete(&output).unwrap(); + + assert_eq!(result.normal_text, "Let me search for that.\n"); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({"city": "Beijing", "date": "2024-12-25"}) + ); + } + + #[test] + fn glm45_streaming_extracts_multiple_tool_calls() { + let mut parser = Glm45MoeToolParser::new(&test_tools()); + let output = format!( + "{}\n{}", + glm45_tool_call("get_weather", &[("city", "Shanghai")]), + glm45_tool_call("add", &[("x", "1"), ("y", "2")]) + ); + + let chunks = split_by_chars(&output, 11); + let result = collect_stream(&mut parser, &chunks); + + assert_eq!(result.normal_text, ""); + assert_eq!(result.calls.len(), 2); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!(result.calls[1].name.as_deref(), Some("add")); + assert_eq!( + serde_json::from_str::(&result.calls[1].arguments).unwrap(), + json!({"x": 1, "y": 2}) + ); + } + + #[test] + fn glm45_parse_complete_unescapes_literal_closing_tags_in_arg_value() { + let mut parser = Glm45MoeToolParser::new(&test_tools()); + let result = parser + .parse_complete(&glm45_tool_call( + "get_weather", + &[ + ("city", "Paris </arg_value></tool_call>"), + ("date", "2026-05-08"), + ], + )) + .unwrap(); + + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "city": "Paris ", + "date": "2026-05-08", + }) + ); + } + + #[test] + fn glm45_streaming_without_tool_call_emits_text_incrementally() { + let mut parser = Glm45MoeToolParser::new(&test_tools()); + + let result = collect_stream(&mut parser, &["hello ", "world"]); + + assert_eq!(result.normal_text, "hello world"); + assert!(result.calls.is_empty()); + } + + #[test] + fn glm45_streaming_preserves_prefix_text() { + let mut parser = Glm45MoeToolParser::new(&test_tools()); + + let result = collect_stream( + &mut parser, + &[ + "Prefix ", + &glm45_tool_call("get_weather", &[("city", "Hangzhou")]), + ], + ); + + assert_eq!(result.normal_text, "Prefix "); + assert_eq!(result.calls.len(), 1); + } + + #[test] + fn glm45_streaming_handles_start_token_split_across_chunks() { + let mut parser = Glm45MoeToolParser::new(&test_tools()); + let result = collect_stream( + &mut parser, + &[ + "hello get_weather\n", + "cityParis", + ], + ); + + assert_eq!(result.normal_text, "hello "); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + } + + #[test] + fn glm45_streaming_does_not_emit_incomplete_tool_call() { + let mut parser = Glm45MoeToolParser::new(&test_tools()); + + let result = parser.push("get_weather\ncity").unwrap(); + + assert_eq!(result.normal_text, ""); + assert!(result.calls.is_empty()); + } + + #[test] + fn glm45_finish_fails_incomplete_tool_call() { + let mut parser = Glm45MoeToolParser::new(&test_tools()); + + parser.push("get_weather\ncity").unwrap(); + let error = parser.finish().unwrap_err(); + + assert!(error.as_report().to_string().contains("incomplete GLM MoE tool call")); + } + + #[test] + fn glm45_malformed_tool_call_fails_fast() { + let mut parser = Glm45MoeToolParser::new(&test_tools()); + + let error = parser.push("get_weathercityParis").unwrap_err(); + + assert!(error.as_report().to_string().contains("tool parser parsing failed")); + } + + #[test] + fn glm45_streaming_ignores_trailing_text_after_tool_calls() { + let mut parser = Glm45MoeToolParser::new(&test_tools()); + + let result = collect_stream( + &mut parser, + &[&format!( + "{}<|endoftext|>", + glm45_tool_call("get_weather", &[("city", "Paris")]) + )], + ); + + assert_eq!(result.normal_text, ""); + assert_eq!(result.calls.len(), 1); + } +} diff --git a/rust/src/tool-parser/src/json/hermes.rs b/rust/src/tool-parser/src/json/hermes.rs new file mode 100644 index 00000000000..d57c7de8ef3 --- /dev/null +++ b/rust/src/tool-parser/src/json/hermes.rs @@ -0,0 +1,221 @@ +use super::{JsonToolCallConfig, JsonToolCallParser, JsonToolCallWhitespace}; +use crate::{Result, Tool, ToolParseResult, ToolParser}; + +const HERMES_CONFIG: JsonToolCallConfig = JsonToolCallConfig { + parser_name: "Hermes", + start_marker: "", + end_marker: "", + marker_whitespace: JsonToolCallWhitespace::Optional, + delimiter: None, + name_key: "name", + arguments_key: "arguments", +}; + +/// Tool parser for Hermes XML-wrapped JSON tool calls. +/// +/// Example tool call content: +/// +/// ```text +/// {"name": "get_weather", "arguments": {"location":"Tokyo"}} +/// ``` +/// +/// Arguments are already OpenAI-style JSON text, so they are streamed as raw +/// argument deltas without schema conversion or JSON normalization. +/// +/// Note: parallel calls are represented as repeated +/// `...` blocks, not as multiple calls inside one tag. +pub struct HermesToolParser { + inner: JsonToolCallParser, +} + +impl HermesToolParser { + /// Create a Hermes tool parser. + fn new(_tools: &[Tool]) -> Self { + Self { + inner: JsonToolCallParser::new(HERMES_CONFIG), + } + } +} + +impl ToolParser for HermesToolParser { + /// Create a boxed Hermes tool parser. + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tools))) + } + + /// Push one decoded text chunk through the Hermes parser. + fn push(&mut self, chunk: &str) -> Result { + self.inner.push(chunk) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + self.inner.finish() + } +} + +#[cfg(test)] +mod tests { + use expect_test::expect; + use thiserror_ext::AsReport; + + use super::HermesToolParser; + use crate::test_utils::{collect_stream, split_by_chars, test_tools}; + use crate::{ToolParseResult, ToolParser}; + + fn build_tool_call(function_name: &str, arguments: &str) -> String { + format!(r#"{{"name":"{function_name}","arguments":{arguments}}}"#) + } + + #[test] + fn hermes_parse_complete_without_tool_call_keeps_text() { + let mut parser = HermesToolParser::new(&test_tools()); + let result = parser.parse_complete("Hello, world!").unwrap(); + + assert_eq!(result.normal_text, "Hello, world!"); + assert!(result.calls.is_empty()); + } + + #[test] + fn hermes_parse_complete_extracts_raw_json_arguments() { + let mut parser = HermesToolParser::new(&test_tools()); + let arguments = r#"{ "location": "Tokyo", "days": "3" }"#; + let result = parser + .parse_complete(&format!( + "Let me check.\n{}", + build_tool_call("get_weather", arguments) + )) + .unwrap(); + + assert_eq!(result.normal_text, "Let me check.\n"); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].tool_index, 0); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn hermes_accepts_newline_after_tool_call_start() { + let mut parser = HermesToolParser::new(&test_tools()); + let result = parser + .parse_complete( + r#" +{"name":"get_weather","arguments":{}}"#, + ) + .unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + } + + #[test] + fn hermes_does_not_validate_or_normalize_arguments() { + let mut parser = HermesToolParser::new(&test_tools()); + let arguments = r#"{"location":"Tokyo",}"#; + let result = parser.parse_complete(&build_tool_call("get_weather", arguments)).unwrap(); + + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn hermes_streaming_emits_argument_deltas() { + let mut parser = HermesToolParser::new(&test_tools()); + let chunks = [ + "preface {\"name\":\"get_weather\",\"arguments\":", + "{\"location\":", + "\"Beijing\"", + "}", + "} suffix", + ]; + + let mut result = ToolParseResult::default(); + let mut observed_arguments = Vec::new(); + for chunk in chunks { + let next = parser.push(chunk).unwrap(); + observed_arguments.extend( + next.calls + .iter() + .filter(|call| call.name.is_none()) + .map(|call| call.arguments.clone()), + ); + result.append(next); + } + result.append(parser.finish().unwrap()); + + assert_eq!(observed_arguments, ["{\"location\":", "\"Beijing\"", "}"]); + assert_eq!(result.normal_text, "preface suffix"); + assert_eq!( + result.coalesce_calls().calls[0].arguments, + r#"{"location":"Beijing"}"# + ); + } + + #[test] + fn hermes_streaming_handles_split_markers() { + let input = format!( + "hello {}", + build_tool_call("get_weather", r#"{"location":"Tokyo"}"#) + ); + let chunks = split_by_chars(&input, 5); + let mut parser = HermesToolParser::new(&test_tools()); + + let result = collect_stream(&mut parser, &chunks); + + assert_eq!(result.normal_text, "hello "); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].arguments, r#"{"location":"Tokyo"}"#); + } + + #[test] + fn hermes_streaming_extracts_multiple_tool_calls() { + let input = format!( + "{}{}", + build_tool_call("get_weather", r#"{"location":"Shanghai"}"#), + build_tool_call("add", r#"{"x":1,"y":2}"#), + ); + let chunks = split_by_chars(&input, 7); + let mut parser = HermesToolParser::new(&test_tools()); + + let result = collect_stream(&mut parser, &chunks); + + expect![[r#" + ToolParseResult { + normal_text: "", + calls: [ + ToolCallDelta { + tool_index: 0, + name: Some( + "get_weather", + ), + arguments: "{\"location\":\"Shanghai\"}", + }, + ToolCallDelta { + tool_index: 1, + name: Some( + "add", + ), + arguments: "{\"x\":1,\"y\":2}", + }, + ], + } + "#]] + .assert_debug_eq(&result); + } + + #[test] + fn hermes_finish_fails_incomplete_tool_call() { + let mut parser = HermesToolParser::new(&test_tools()); + parser + .push(r#"{"name":"get_weather","arguments":{"location""#) + .unwrap(); + + let error = parser.finish().unwrap_err(); + + expect!["tool parser parsing failed: incomplete Hermes tool call"] + .assert_eq(&error.to_report_string()); + } +} diff --git a/rust/src/tool-parser/src/json/llama.rs b/rust/src/tool-parser/src/json/llama.rs new file mode 100644 index 00000000000..3e579be72fb --- /dev/null +++ b/rust/src/tool-parser/src/json/llama.rs @@ -0,0 +1,487 @@ +use winnow::ascii::multispace0 as ws0; +use winnow::combinator::seq; +use winnow::error::{ModalResult, StrContext}; +use winnow::prelude::*; +use winnow::token::literal; + +use super::{ + JsonToolCallConfig, JsonToolCallEvent, JsonToolCallWhitespace, JsonToolInput, + argument_delta_event, tool_call_header_event, +}; +use crate::utils::{JsonObjectScanState, parse_buffered_event}; +use crate::{Result, Tool, ToolCallDelta, ToolParseResult, ToolParser}; + +#[derive(Debug, Clone, PartialEq, Eq)] +enum LlamaJsonMode { + Start, + Header, + Arguments { json_scan: JsonObjectScanState }, + AfterCall, + Passthrough, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum LlamaJsonEvent { + ToolCallHeader { function_name: String }, + Arguments { len: usize }, + ToolCallClose, + Separator, +} + +/// Tool parser for strict Llama JSON-template tool calls. +/// +/// Example tool call content: +/// +/// ```text +/// {"name":"get_weather","parameters":{"location":"Tokyo"}}; {"name":"add","parameters":{"x":1,"y":2}} +/// ``` +/// +/// Arguments are already OpenAI-style JSON text, so they are streamed as raw +/// argument deltas without schema conversion or JSON normalization. +/// +/// Natural text at the beginning of the stream permanently disables tool +/// parsing for that assistant output. +pub struct Llama3JsonToolParser { + buffer: String, + mode: LlamaJsonMode, + active_tool_index: Option, + emitted_tool_count: usize, +} + +impl Llama3JsonToolParser { + /// Create a Llama JSON tool parser. + fn new(_tools: &[Tool]) -> Self { + Self { + buffer: String::new(), + mode: LlamaJsonMode::Start, + active_tool_index: None, + emitted_tool_count: 0, + } + } + + /// Commit the stream to JSON parsing or permanent passthrough. + fn commit_start(&mut self) -> bool { + if !matches!(self.mode, LlamaJsonMode::Start) { + return true; + } + + if self.buffer.is_empty() { + return false; + } + + if self.buffer.starts_with('{') { + self.mode = LlamaJsonMode::Header; + } else { + self.mode = LlamaJsonMode::Passthrough; + } + true + } + + /// Apply one parsed Llama JSON event to parser state and output. + fn apply_event(&mut self, event: LlamaJsonEvent, result: &mut ToolParseResult) -> Result<()> { + match event { + LlamaJsonEvent::ToolCallHeader { function_name } => { + let tool_index = self.emitted_tool_count; + self.emitted_tool_count += 1; + self.active_tool_index = Some(tool_index); + self.mode = LlamaJsonMode::Arguments { + json_scan: JsonObjectScanState::default(), + }; + result.calls.push(ToolCallDelta { + tool_index, + name: Some(function_name), + arguments: String::new(), + }); + } + LlamaJsonEvent::Arguments { len: consumed_len } => { + let Some(tool_index) = self.active_tool_index else { + return Err(parsing_failed!( + "Llama JSON arguments without an active tool call" + )); + }; + result.calls.push(ToolCallDelta { + tool_index, + name: None, + arguments: self.buffer[..consumed_len].to_string(), + }); + } + LlamaJsonEvent::ToolCallClose => { + self.active_tool_index = None; + self.mode = LlamaJsonMode::AfterCall; + } + LlamaJsonEvent::Separator => { + self.active_tool_index = None; + self.mode = LlamaJsonMode::Header; + } + } + Ok(()) + } + + /// Reset all streaming state. + fn reset(&mut self) { + self.buffer.clear(); + self.mode = LlamaJsonMode::Start; + self.active_tool_index = None; + self.emitted_tool_count = 0; + } +} + +impl ToolParser for Llama3JsonToolParser { + /// Create a boxed Llama JSON tool parser. + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tools))) + } + + /// Push one decoded text chunk through the Llama JSON parser. + fn push(&mut self, chunk: &str) -> Result { + self.buffer.push_str(chunk); + let mut result = ToolParseResult::default(); + + if !self.commit_start() { + return Ok(result); + } + + if matches!(self.mode, LlamaJsonMode::Passthrough) { + result.normal_text.push_str(&self.buffer); + self.buffer.clear(); + return Ok(result); + } + + while let Some((event, consumed_len)) = parse_buffered_event(&self.buffer, |input| { + parse_next_llama_json_event(input, &mut self.mode) + })? { + self.apply_event(event, &mut result)?; + self.buffer.drain(..consumed_len); + } + + Ok(result) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + let mut result = ToolParseResult::default(); + match &self.mode { + LlamaJsonMode::Start | LlamaJsonMode::Passthrough => { + result.normal_text.push_str(&self.buffer); + } + LlamaJsonMode::AfterCall if self.buffer.trim().is_empty() => {} + LlamaJsonMode::Header | LlamaJsonMode::Arguments { .. } => { + return Err(parsing_failed!("incomplete Llama JSON tool call")); + } + LlamaJsonMode::AfterCall => { + return Err(parsing_failed!("invalid Llama JSON")); + } + } + self.reset(); + Ok(result) + } +} + +/// Parse a Llama JSON event for the current parser mode. +fn parse_next_llama_json_event( + input: &mut JsonToolInput<'_>, + mode: &mut LlamaJsonMode, +) -> ModalResult { + match mode { + LlamaJsonMode::Start | LlamaJsonMode::Passthrough => { + unreachable!("Llama JSON parser driver must commit before parsing events") + } + LlamaJsonMode::Header => llama_tool_call_header_event(input), + LlamaJsonMode::Arguments { json_scan } => parse_llama_arguments_event(input, json_scan), + LlamaJsonMode::AfterCall => after_call_event(input), + } +} + +/// Parse a Llama JSON tool-call header. +fn llama_tool_call_header_event(input: &mut JsonToolInput<'_>) -> ModalResult { + const CONFIG: JsonToolCallConfig = JsonToolCallConfig { + parser_name: "Llama JSON", + start_marker: "", + end_marker: "", + marker_whitespace: JsonToolCallWhitespace::Optional, + delimiter: Some(";"), + name_key: "name", + arguments_key: "parameters", + }; + + match tool_call_header_event(input, CONFIG)? { + JsonToolCallEvent::ToolCallHeader { function_name } => { + Ok(LlamaJsonEvent::ToolCallHeader { function_name }) + } + _ => unreachable!("tool_call_header_event only emits ToolCallHeader"), + } +} + +/// Parse one event inside a Llama JSON arguments payload. +fn parse_llama_arguments_event( + input: &mut JsonToolInput<'_>, + json_scan: &mut JsonObjectScanState, +) -> ModalResult { + if json_scan.complete() { + tool_call_close_event(input) + } else { + match argument_delta_event(input, json_scan)? { + JsonToolCallEvent::Arguments { len } => Ok(LlamaJsonEvent::Arguments { len }), + _ => unreachable!("argument_delta_event only emits Arguments"), + } + } +} + +/// Parse the outer closing brace for one Llama JSON tool call. +fn tool_call_close_event(input: &mut JsonToolInput<'_>) -> ModalResult { + literal("}").value(LlamaJsonEvent::ToolCallClose).parse_next(input) +} + +/// Parse a semicolon separator after one Llama JSON tool call. +fn after_call_event(input: &mut JsonToolInput<'_>) -> ModalResult { + seq!( + _: ws0, + _: literal(";"), + _: ws0, + ) + .value(LlamaJsonEvent::Separator) + .context(StrContext::Label("Llama JSON")) + .parse_next(input) +} + +#[cfg(test)] +mod tests { + use expect_test::expect; + use thiserror_ext::AsReport; + + use super::Llama3JsonToolParser; + use crate::test_utils::{collect_stream, split_by_chars, test_tools}; + use crate::{ToolParseResult, ToolParser}; + + fn build_tool_call(function_name: &str, parameters: &str) -> String { + format!(r#"{{"name":"{function_name}","parameters":{parameters}}}"#) + } + + #[test] + fn llama_json_parse_complete_without_tool_call_keeps_text() { + let mut parser = Llama3JsonToolParser::new(&test_tools()); + let result = parser.parse_complete("Hello, world!").unwrap(); + + assert_eq!(result.normal_text, "Hello, world!"); + assert!(result.calls.is_empty()); + } + + #[test] + fn llama_json_passthrough_never_reenters_tool_parsing() { + let mut parser = Llama3JsonToolParser::new(&test_tools()); + let mut result = parser.push("plain text first ").unwrap(); + result.append( + parser.push(&build_tool_call("get_weather", r#"{"location":"Tokyo"}"#)).unwrap(), + ); + result.append(parser.finish().unwrap()); + + assert_eq!( + result.normal_text, + r#"plain text first {"name":"get_weather","parameters":{"location":"Tokyo"}}"# + ); + assert!(result.calls.is_empty()); + } + + #[test] + fn llama_json_does_not_support_python_tag_prefix() { + let mut parser = Llama3JsonToolParser::new(&test_tools()); + let input = format!( + "<|python_tag|>{}", + build_tool_call("get_weather", r#"{"location":"Tokyo"}"#) + ); + let result = parser.parse_complete(&input).unwrap(); + + assert_eq!(result.normal_text, input); + assert!(result.calls.is_empty()); + } + + #[test] + fn llama_json_rejects_leading_whitespace_before_tool_call() { + let mut parser = Llama3JsonToolParser::new(&test_tools()); + let input = format!( + "\n {}", + build_tool_call("get_weather", r#"{"location":"Tokyo"}"#) + ); + let result = parser.parse_complete(&input).unwrap(); + + assert_eq!(result.normal_text, input); + assert!(result.calls.is_empty()); + } + + #[test] + fn llama_json_extracts_raw_parameters_object() { + let mut parser = Llama3JsonToolParser::new(&test_tools()); + let arguments = r#"{ "location": "Tokyo", "days": 3 }"#; + let result = parser.parse_complete(&build_tool_call("get_weather", arguments)).unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].tool_index, 0); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn llama_json_rejects_arguments_key() { + let mut parser = Llama3JsonToolParser::new(&test_tools()); + let error = parser + .parse_complete(r#"{"name":"get_weather","arguments":{"location":"Tokyo"}}"#) + .unwrap_err(); + + expect![[r#" + tool parser parsing failed: invalid Llama JSON + expected `parameters`"#]] + .assert_eq(&error.to_report_string()); + } + + #[test] + fn llama_json_extracts_multiple_semicolon_separated_calls() { + let mut parser = Llama3JsonToolParser::new(&test_tools()); + let input = format!( + "{} \n; {}", + build_tool_call("get_weather", r#"{"location":"Shanghai"}"#), + build_tool_call("add", r#"{"x":1,"y":2}"#), + ); + let result = parser.parse_complete(&input).unwrap(); + + expect![[r#" + ToolParseResult { + normal_text: "", + calls: [ + ToolCallDelta { + tool_index: 0, + name: Some( + "get_weather", + ), + arguments: "{\"location\":\"Shanghai\"}", + }, + ToolCallDelta { + tool_index: 1, + name: Some( + "add", + ), + arguments: "{\"x\":1,\"y\":2}", + }, + ], + } + "#]] + .assert_debug_eq(&result); + } + + #[test] + fn llama_json_streaming_emits_argument_deltas() { + let mut parser = Llama3JsonToolParser::new(&test_tools()); + let chunks = [ + "{\"name\":\"get_weather\",\"parameters\":", + "{\"location\":", + "\"Beijing\"", + "}}", + ]; + + let mut result = ToolParseResult::default(); + let mut observed_arguments = Vec::new(); + for chunk in chunks { + let next = parser.push(chunk).unwrap(); + observed_arguments.extend( + next.calls + .iter() + .filter(|call| call.name.is_none()) + .map(|call| call.arguments.clone()), + ); + result.append(next); + } + result.append(parser.finish().unwrap()); + + assert_eq!(observed_arguments, ["{\"location\":", "\"Beijing\"", "}"]); + assert_eq!( + result.coalesce_calls().calls[0].arguments, + r#"{"location":"Beijing"}"# + ); + } + + #[test] + fn llama_json_streaming_handles_split_objects_and_separator() { + let input = format!( + "{};{}", + build_tool_call("get_weather", r#"{"location":"Dallas","state":"TX"}"#), + build_tool_call("add", r#"{"x":4,"y":5}"#), + ); + let chunks = split_by_chars(&input, 6); + let mut parser = Llama3JsonToolParser::new(&test_tools()); + + let result = collect_stream(&mut parser, &chunks); + + assert_eq!(result.normal_text, ""); + assert_eq!(result.calls.len(), 2); + assert_eq!( + result.calls[0].arguments, + r#"{"location":"Dallas","state":"TX"}"# + ); + assert_eq!(result.calls[1].name.as_deref(), Some("add")); + assert_eq!(result.calls[1].arguments, r#"{"x":4,"y":5}"#); + } + + #[test] + fn llama_json_handles_nested_multiline_and_escaped_string_parameters() { + let mut parser = Llama3JsonToolParser::new(&test_tools()); + let arguments = r#"{ + "payload": {"items": [1, {"value": "literal { brace } and \"quote\""}]}, + "flag": true +}"#; + let result = parser.parse_complete(&build_tool_call("convert", arguments)).unwrap(); + + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn llama_json_keeps_trailing_whitespace_after_tool_call() { + let mut parser = Llama3JsonToolParser::new(&test_tools()); + let result = parser + .parse_complete(&format!( + "{}\n\t ", + build_tool_call("get_weather", r#"{"location":"Tokyo"}"#) + )) + .unwrap(); + + assert_eq!(result.normal_text, ""); + assert_eq!(result.calls.len(), 1); + } + + #[test] + fn llama_json_finish_fails_incomplete_tool_call() { + let mut parser = Llama3JsonToolParser::new(&test_tools()); + parser.push(r#"{"name":"get_weather","parameters":{"location""#).unwrap(); + + let error = parser.finish().unwrap_err(); + + expect!["tool parser parsing failed: incomplete Llama JSON tool call"] + .assert_eq(&error.to_report_string()); + } + + #[test] + fn llama_json_malformed_field_order_fails_fast() { + let mut parser = Llama3JsonToolParser::new(&test_tools()); + let error = parser.push(r#"{"parameters":{},"name":"get_weather"}"#).unwrap_err(); + + expect![[r#" + tool parser parsing failed: invalid Llama JSON + expected `name`"#]] + .assert_eq(&error.to_report_string()); + } + + #[test] + fn llama_json_trailing_non_separator_content_errors() { + let mut parser = Llama3JsonToolParser::new(&test_tools()); + let error = parser + .push(&format!( + "{} trailing", + build_tool_call("get_weather", r#"{"location":"Tokyo"}"#) + )) + .unwrap_err(); + + expect!["tool parser parsing failed: invalid Llama JSON"] + .assert_eq(&error.to_report_string()); + } +} diff --git a/rust/src/tool-parser/src/json/mistral.rs b/rust/src/tool-parser/src/json/mistral.rs new file mode 100644 index 00000000000..ac5dea804cc --- /dev/null +++ b/rust/src/tool-parser/src/json/mistral.rs @@ -0,0 +1,240 @@ +use super::{JsonToolCallConfig, JsonToolCallParser, JsonToolCallWhitespace}; +use crate::{Result, Tool, ToolParseResult, ToolParser}; + +const MISTRAL_CONFIG: JsonToolCallConfig = JsonToolCallConfig { + parser_name: "Mistral", + start_marker: "[TOOL_CALLS] [", + end_marker: "]", + marker_whitespace: JsonToolCallWhitespace::Optional, + delimiter: Some(","), + name_key: "name", + arguments_key: "arguments", +}; + +/// Tool parser for Mistral JSON-array tool calls. +/// +/// Example tool call content: +/// +/// ```text +/// [TOOL_CALLS] [{"name": "get_weather", "arguments": {"location":"Tokyo"}}] +/// ``` +/// +/// Arguments are already OpenAI-style JSON text, so they are streamed as raw +/// argument deltas without schema conversion or JSON normalization. +pub struct MistralToolParser { + inner: JsonToolCallParser, +} + +impl MistralToolParser { + /// Create a Mistral tool parser. + fn new(_tools: &[Tool]) -> Self { + Self { + inner: JsonToolCallParser::new(MISTRAL_CONFIG), + } + } +} + +impl ToolParser for MistralToolParser { + /// Create a boxed Mistral tool parser. + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tools))) + } + + /// Push one decoded text chunk through the Mistral parser. + fn push(&mut self, chunk: &str) -> Result { + self.inner.push(chunk) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + self.inner.finish() + } +} + +#[cfg(test)] +mod tests { + use expect_test::expect; + use thiserror_ext::AsReport; + + use super::MistralToolParser; + use crate::test_utils::{collect_stream, split_by_chars, test_tools}; + use crate::{ToolParseResult, ToolParser}; + + fn build_tool_call(function_name: &str, arguments: &str) -> String { + format!(r#"{{"name":"{function_name}","arguments":{arguments}}}"#) + } + + fn build_tool_calls(tool_calls: &[String]) -> String { + format!("[TOOL_CALLS] [{}]", tool_calls.join(",")) + } + + #[test] + fn mistral_parse_complete_without_tool_call_keeps_text() { + let mut parser = MistralToolParser::new(&test_tools()); + let result = parser.parse_complete("Hello, world!").unwrap(); + + assert_eq!(result.normal_text, "Hello, world!"); + assert!(result.calls.is_empty()); + } + + #[test] + fn mistral_parse_complete_extracts_raw_json_arguments() { + let mut parser = MistralToolParser::new(&test_tools()); + let arguments = r#"{ "location": "Tokyo", "days": "3" }"#; + let result = parser + .parse_complete(&format!( + "Let me check.\n{}", + build_tool_calls(&[build_tool_call("get_weather", arguments)]) + )) + .unwrap(); + + assert_eq!(result.normal_text, "Let me check.\n"); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].tool_index, 0); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn mistral_parse_complete_extracts_pretty_multiple_tool_calls() { + let mut parser = MistralToolParser::new(&test_tools()); + let result = parser + .parse_complete( + r#"I'll help. +[TOOL_CALLS] [ + {"name": "get_weather", "arguments": {"city": "Tokyo", "units": "celsius"}} + , + {"name": "add", "arguments": {"x": 1, "y": 2}} +]"#, + ) + .unwrap(); + + expect![[r#" + ToolParseResult { + normal_text: "I'll help.\n", + calls: [ + ToolCallDelta { + tool_index: 0, + name: Some( + "get_weather", + ), + arguments: "{\"city\": \"Tokyo\", \"units\": \"celsius\"}", + }, + ToolCallDelta { + tool_index: 1, + name: Some( + "add", + ), + arguments: "{\"x\": 1, \"y\": 2}", + }, + ], + } + "#]] + .assert_debug_eq(&result); + } + + #[test] + fn mistral_does_not_validate_or_normalize_arguments() { + let mut parser = MistralToolParser::new(&test_tools()); + let arguments = r#"{"location":"Tokyo",}"#; + let result = parser + .parse_complete(&build_tool_calls(&[build_tool_call( + "get_weather", + arguments, + )])) + .unwrap(); + + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn mistral_streaming_emits_argument_deltas() { + let mut parser = MistralToolParser::new(&test_tools()); + let chunks = [ + "preface [TOOL", + "_CALLS] [{\"name\":\"get_weather\",\"arguments\":", + "{\"location\":", + "\"Beijing\"", + "}", + "}] suffix", + ]; + + let mut result = ToolParseResult::default(); + let mut observed_arguments = Vec::new(); + for chunk in chunks { + let next = parser.push(chunk).unwrap(); + observed_arguments.extend( + next.calls + .iter() + .filter(|call| call.name.is_none()) + .map(|call| call.arguments.clone()), + ); + result.append(next); + } + result.append(parser.finish().unwrap()); + + assert_eq!(observed_arguments, ["{\"location\":", "\"Beijing\"", "}"]); + assert_eq!(result.normal_text, "preface suffix"); + assert_eq!( + result.coalesce_calls().calls[0].arguments, + r#"{"location":"Beijing"}"# + ); + } + + #[test] + fn mistral_streaming_handles_split_markers() { + let input = format!( + "hello {}", + build_tool_calls(&[build_tool_call("get_weather", r#"{"location":"Tokyo"}"#)]) + ); + let chunks = split_by_chars(&input, 5); + let mut parser = MistralToolParser::new(&test_tools()); + + let result = collect_stream(&mut parser, &chunks); + + assert_eq!(result.normal_text, "hello "); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].arguments, r#"{"location":"Tokyo"}"#); + } + + #[test] + fn mistral_keeps_array_bracket_literal_inside_json_string() { + let mut parser = MistralToolParser::new(&test_tools()); + let arguments = r#"{"text":"Array notation: arr[0] = value[1]"}"#; + let result = parser + .parse_complete(&build_tool_calls(&[build_tool_call("echo", arguments)])) + .unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn mistral_finish_fails_incomplete_tool_call() { + let mut parser = MistralToolParser::new(&test_tools()); + parser + .push(r#"[TOOL_CALLS] [{"name":"get_weather","arguments":{"location""#) + .unwrap(); + + let error = parser.finish().unwrap_err(); + + expect!["tool parser parsing failed: incomplete Mistral tool call"] + .assert_eq(&error.to_report_string()); + } + + #[test] + fn mistral_malformed_field_order_fails_fast() { + let mut parser = MistralToolParser::new(&test_tools()); + let error = parser + .push(r#"[TOOL_CALLS] [{"arguments":{},"name":"get_weather"}]"#) + .unwrap_err(); + + expect![[r#" + tool parser parsing failed: invalid Mistral + expected `name`"#]] + .assert_eq(&error.to_report_string()); + } +} diff --git a/rust/src/tool-parser/src/json/mod.rs b/rust/src/tool-parser/src/json/mod.rs new file mode 100644 index 00000000000..ce96fbbde70 --- /dev/null +++ b/rust/src/tool-parser/src/json/mod.rs @@ -0,0 +1,465 @@ +//! Shared parser core for JSON tool calls wrapped by text markers. + +pub use hermes::HermesToolParser; +pub use llama::Llama3JsonToolParser; +pub use mistral::MistralToolParser; +pub use qwen::Qwen3XmlToolParser; + +mod hermes; +mod llama; +mod mistral; +mod qwen; + +use winnow::ascii::multispace0 as ws0; +use winnow::combinator::{alt, seq}; +use winnow::error::{ModalResult, StrContext, StrContextValue}; +use winnow::prelude::*; +use winnow::stream::Partial; +use winnow::token::literal; + +use super::utils::{ + JsonObjectScanState, json_str, parse_buffered_event, safe_text_len, take_json_object, +}; +use super::{Result, ToolCallDelta, ToolParseResult}; + +type JsonToolInput<'i> = Partial<&'i str>; + +#[derive(Debug, Clone, Copy)] +struct JsonToolCallConfig { + parser_name: &'static str, + start_marker: &'static str, + end_marker: &'static str, + marker_whitespace: JsonToolCallWhitespace, + delimiter: Option<&'static str>, + name_key: &'static str, + arguments_key: &'static str, +} + +#[derive(Debug, Clone, Copy)] +enum JsonToolCallWhitespace { + Optional, + Exact(&'static str), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum JsonToolCallMode { + Text, + Header, + Arguments { json_scan: JsonObjectScanState }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum JsonToolCallEvent { + Text { len: usize }, + ToolCallStart, + ToolCallHeader { function_name: String }, + Arguments { len: usize }, + ToolCallDelimiter, + ToolCallEnd, +} + +/// Tool parser core for marker-wrapped JSON tool calls. +#[derive(Debug)] +struct JsonToolCallParser { + config: JsonToolCallConfig, + buffer: String, + mode: JsonToolCallMode, + active_tool_index: Option, + emitted_tool_count: usize, +} + +impl JsonToolCallParser { + /// Create a marker-wrapped JSON tool-call parser. + fn new(config: JsonToolCallConfig) -> Self { + Self { + config, + buffer: String::new(), + mode: JsonToolCallMode::Text, + active_tool_index: None, + emitted_tool_count: 0, + } + } + + /// Push one decoded text chunk through the JSON tool-call parser. + fn push(&mut self, chunk: &str) -> Result { + self.buffer.push_str(chunk); + let mut result = ToolParseResult::default(); + let config = self.config; + + while let Some((event, consumed_len)) = parse_buffered_event(&self.buffer, |input| { + parse_next_json_tool_call_event(input, &mut self.mode, config) + })? { + self.apply_event(event, &mut result)?; + self.buffer.drain(..consumed_len); + } + + Ok(result) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + let mut result = ToolParseResult::default(); + match &self.mode { + JsonToolCallMode::Text => result.normal_text.push_str(&self.buffer), + JsonToolCallMode::Header | JsonToolCallMode::Arguments { .. } => { + return Err(parsing_failed!( + "incomplete {} tool call", + self.config.parser_name + )); + } + } + self.reset(); + Ok(result) + } + + /// Apply one parsed JSON tool-call event to parser state and output. + fn apply_event( + &mut self, + event: JsonToolCallEvent, + result: &mut ToolParseResult, + ) -> Result<()> { + match event { + JsonToolCallEvent::Text { len: consumed_len } => { + result.normal_text.push_str(&self.buffer[..consumed_len]); + } + JsonToolCallEvent::ToolCallStart => self.mode = JsonToolCallMode::Header, + JsonToolCallEvent::ToolCallHeader { function_name } => { + let tool_index = self.emitted_tool_count; + self.emitted_tool_count += 1; + self.active_tool_index = Some(tool_index); + self.mode = JsonToolCallMode::Arguments { + json_scan: JsonObjectScanState::default(), + }; + result.calls.push(ToolCallDelta { + tool_index, + name: Some(function_name), + arguments: String::new(), + }); + } + JsonToolCallEvent::Arguments { len: consumed_len } => { + let Some(tool_index) = self.active_tool_index else { + return Err(parsing_failed!( + "{} arguments without an active tool call", + self.config.parser_name + )); + }; + result.calls.push(ToolCallDelta { + tool_index, + name: None, + arguments: self.buffer[..consumed_len].to_string(), + }); + } + JsonToolCallEvent::ToolCallDelimiter => { + self.active_tool_index = None; + self.mode = JsonToolCallMode::Header; + } + JsonToolCallEvent::ToolCallEnd => { + self.active_tool_index = None; + self.mode = JsonToolCallMode::Text; + } + } + Ok(()) + } + + /// Reset all streaming state. + fn reset(&mut self) { + self.buffer.clear(); + self.mode = JsonToolCallMode::Text; + self.active_tool_index = None; + self.emitted_tool_count = 0; + } +} + +/// Parse a JSON tool-call event for the current parser mode. +fn parse_next_json_tool_call_event( + input: &mut JsonToolInput<'_>, + mode: &mut JsonToolCallMode, + config: JsonToolCallConfig, +) -> ModalResult { + match mode { + JsonToolCallMode::Text => parse_text_event(input, config), + JsonToolCallMode::Header => tool_call_header_event(input, config), + JsonToolCallMode::Arguments { json_scan } => { + parse_arguments_event(input, json_scan, config) + } + } +} + +/// Parse a text-mode JSON tool-call event. +fn parse_text_event( + input: &mut JsonToolInput<'_>, + config: JsonToolCallConfig, +) -> ModalResult { + alt(( + |input: &mut JsonToolInput<'_>| tool_call_start_event(input, config), + |input: &mut JsonToolInput<'_>| safe_text_event(input, config), + )) + .parse_next(input) +} + +/// Parse a marker-wrapped JSON tool-call start marker. +fn tool_call_start_event( + input: &mut JsonToolInput<'_>, + config: JsonToolCallConfig, +) -> ModalResult { + seq!( + _: literal(config.start_marker), + _: |input: &mut JsonToolInput<'_>| marker_whitespace(input, config), + ) + .value(JsonToolCallEvent::ToolCallStart) + .parse_next(input) +} + +/// Parse a marker-wrapped JSON tool-call header before the raw arguments +/// payload. +fn tool_call_header_event( + input: &mut JsonToolInput<'_>, + config: JsonToolCallConfig, +) -> ModalResult { + let (function_name,) = seq!( + _: ws0, + _: literal("{"), + _: ws0, + _: |input: &mut JsonToolInput<'_>| json_key(input, config.name_key), + _: ws0, + _: literal(":"), + _: ws0, + json_str, + _: ws0, + _: literal(","), + _: ws0, + _: |input: &mut JsonToolInput<'_>| json_key(input, config.arguments_key), + _: ws0, + _: literal(":"), + _: ws0, + ) + .context(StrContext::Label(config.parser_name)) + .parse_next(input)?; + + Ok(JsonToolCallEvent::ToolCallHeader { function_name }) +} + +/// Parse a configured JSON object key. +fn json_key(input: &mut JsonToolInput<'_>, key: &'static str) -> ModalResult<()> { + seq!( + _: literal("\""), + _: literal(key).context(StrContext::Expected(StrContextValue::StringLiteral(key))), + _: literal("\""), + ) + .void() + .parse_next(input) +} + +/// Parse one event inside a marker-wrapped JSON tool-call arguments payload. +fn parse_arguments_event( + input: &mut JsonToolInput<'_>, + json_scan: &mut JsonObjectScanState, + config: JsonToolCallConfig, +) -> ModalResult { + if json_scan.complete() { + tool_call_close_event(input, config) + } else { + argument_delta_event(input, json_scan) + } +} + +/// Parse a raw JSON arguments delta. +fn argument_delta_event( + input: &mut JsonToolInput<'_>, + json_scan: &mut JsonObjectScanState, +) -> ModalResult { + take_json_object(input, json_scan).map(|len| JsonToolCallEvent::Arguments { len }) +} + +/// Parse a marker-wrapped JSON tool-call close marker. +fn tool_call_close_event( + input: &mut JsonToolInput<'_>, + config: JsonToolCallConfig, +) -> ModalResult { + let _ = literal("}").parse_next(input)?; + + match config.delimiter { + Some(delimiter) => alt(( + |input: &mut JsonToolInput<'_>| tool_call_end_event(input, config), + |input: &mut JsonToolInput<'_>| tool_call_delimiter_event(input, delimiter), + )) + .parse_next(input), + None => tool_call_end_event(input, config), + } +} + +/// Parse a marker-wrapped JSON tool-call end marker. +fn tool_call_end_event( + input: &mut JsonToolInput<'_>, + config: JsonToolCallConfig, +) -> ModalResult { + seq!( + _: |input: &mut JsonToolInput<'_>| marker_whitespace(input, config), + _: literal(config.end_marker), + ) + .value(JsonToolCallEvent::ToolCallEnd) + .parse_next(input) +} + +/// Parse a delimiter between JSON tool calls inside one marker block. +fn tool_call_delimiter_event( + input: &mut JsonToolInput<'_>, + delimiter: &'static str, +) -> ModalResult { + seq!( + _: ws0, + _: literal(delimiter), + _: ws0, + ) + .value(JsonToolCallEvent::ToolCallDelimiter) + .parse_next(input) +} + +/// Parse configured whitespace around a marker-wrapped JSON tool call. +fn marker_whitespace(input: &mut JsonToolInput<'_>, config: JsonToolCallConfig) -> ModalResult<()> { + match config.marker_whitespace { + JsonToolCallWhitespace::Optional => ws0.void().parse_next(input), + JsonToolCallWhitespace::Exact(whitespace) => literal(whitespace).void().parse_next(input), + } +} + +/// Parse a safe text run before the next marker-wrapped JSON tool call. +fn safe_text_event( + input: &mut JsonToolInput<'_>, + config: JsonToolCallConfig, +) -> ModalResult { + safe_text_len(input, config.start_marker).map(|len| JsonToolCallEvent::Text { len }) +} + +#[cfg(test)] +mod tests { + use expect_test::expect; + + use super::{JsonToolCallConfig, JsonToolCallParser, JsonToolCallWhitespace}; + use crate::ToolParseResult; + + const DELIMITED_CONFIG: JsonToolCallConfig = JsonToolCallConfig { + parser_name: "Delimited JSON", + start_marker: "", + end_marker: "", + marker_whitespace: JsonToolCallWhitespace::Optional, + delimiter: Some("<"), + name_key: "function", + arguments_key: "parameters", + }; + + fn build_tool_call(function_name: &str, arguments: &str) -> String { + format!(r#"{{"function":"{function_name}","parameters":{arguments}}}"#) + } + + fn build_tool_calls(tool_calls: &[String]) -> String { + format!("{}", tool_calls.join(" <\n")) + } + + fn collect_chunks(parser: &mut JsonToolCallParser, chunks: &[&str]) -> ToolParseResult { + let mut result = ToolParseResult::default(); + for chunk in chunks { + result.append(parser.push(chunk).unwrap()); + } + result.append(parser.finish().unwrap()); + result.coalesce_calls() + } + + #[test] + fn json_tool_call_delimiter_extracts_multiple_calls_in_one_block() { + let input = build_tool_calls(&[ + build_tool_call("get_weather", r#"{"location":"Shanghai"}"#), + build_tool_call("add", r#"{"x":1,"y":2}"#), + ]); + let mut parser = JsonToolCallParser::new(DELIMITED_CONFIG); + + let result = collect_chunks(&mut parser, &[&input]); + + expect![[r#" + ToolParseResult { + normal_text: "", + calls: [ + ToolCallDelta { + tool_index: 0, + name: Some( + "get_weather", + ), + arguments: "{\"location\":\"Shanghai\"}", + }, + ToolCallDelta { + tool_index: 1, + name: Some( + "add", + ), + arguments: "{\"x\":1,\"y\":2}", + }, + ], + } + "#]] + .assert_debug_eq(&result); + } + + #[test] + fn json_tool_call_delimiter_can_arrive_in_later_chunk() { + let mut parser = JsonToolCallParser::new(DELIMITED_CONFIG); + let chunks = [ + r#"{"function":"get_weather","parameters":{"location":"Shanghai"}}"#, + " <\n", + r#"{"function":"add","parameters":{"x":1,"y":2}}"#, + "", + ]; + + let result = collect_chunks(&mut parser, &chunks); + + expect![[r#" + ToolParseResult { + normal_text: "", + calls: [ + ToolCallDelta { + tool_index: 0, + name: Some( + "get_weather", + ), + arguments: "{\"location\":\"Shanghai\"}", + }, + ToolCallDelta { + tool_index: 1, + name: Some( + "add", + ), + arguments: "{\"x\":1,\"y\":2}", + }, + ], + } + "#]] + .assert_debug_eq(&result); + } + + #[test] + fn json_tool_call_end_marker_wins_over_delimiter_prefix() { + let mut parser = JsonToolCallParser::new(DELIMITED_CONFIG); + let chunks = [ + r#"{"function":"get_weather","parameters":{"location":"Shanghai"}}"#, + " ", + " trailing text", + ]; + + let result = collect_chunks(&mut parser, &chunks); + + expect![[r#" + ToolParseResult { + normal_text: " trailing text", + calls: [ + ToolCallDelta { + tool_index: 0, + name: Some( + "get_weather", + ), + arguments: "{\"location\":\"Shanghai\"}", + }, + ], + } + "#]] + .assert_debug_eq(&result); + } +} diff --git a/rust/src/tool-parser/src/json/qwen.rs b/rust/src/tool-parser/src/json/qwen.rs new file mode 100644 index 00000000000..34bb190ec09 --- /dev/null +++ b/rust/src/tool-parser/src/json/qwen.rs @@ -0,0 +1,279 @@ +use super::{JsonToolCallConfig, JsonToolCallParser, JsonToolCallWhitespace}; +use crate::{Result, Tool, ToolParseResult, ToolParser}; + +const QWEN_XML_CONFIG: JsonToolCallConfig = JsonToolCallConfig { + parser_name: "Qwen XML", + start_marker: "", + end_marker: "", + marker_whitespace: JsonToolCallWhitespace::Exact("\n"), + delimiter: None, + name_key: "name", + arguments_key: "arguments", +}; + +/// Tool parser for Qwen XML-wrapped JSON tool calls. +/// +/// Example tool call content: +/// +/// ```text +/// +/// {"name": "get_weather", "arguments": {"location":"Tokyo"}} +/// +/// ``` +/// +/// Arguments are already OpenAI-style JSON text, so they are streamed as raw +/// argument deltas without schema conversion or JSON normalization. +/// +/// Note: parallel calls are represented as repeated +/// `...` blocks, not as multiple calls inside one tag. +pub struct Qwen3XmlToolParser { + inner: JsonToolCallParser, +} + +impl Qwen3XmlToolParser { + /// Create a Qwen XML tool parser. + fn new(_tools: &[Tool]) -> Self { + Self { + inner: JsonToolCallParser::new(QWEN_XML_CONFIG), + } + } +} + +impl ToolParser for Qwen3XmlToolParser { + /// Create a boxed Qwen XML tool parser. + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tools))) + } + + /// Push one decoded text chunk through the Qwen XML parser. + fn push(&mut self, chunk: &str) -> Result { + self.inner.push(chunk) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + self.inner.finish() + } +} + +#[cfg(test)] +mod tests { + use expect_test::expect; + use thiserror_ext::AsReport; + + use super::Qwen3XmlToolParser; + use crate::test_utils::{collect_stream, split_by_chars, test_tools}; + use crate::{ToolParseResult, ToolParser}; + + fn build_tool_call(function_name: &str, arguments: &str) -> String { + format!( + "\n{{\"name\": \"{function_name}\", \"arguments\": {arguments}}}\n" + ) + } + + #[test] + fn qwen_xml_parse_complete_without_tool_call_keeps_text() { + let mut parser = Qwen3XmlToolParser::new(&test_tools()); + let result = parser.parse_complete("Hello, world!").unwrap(); + + assert_eq!(result.normal_text, "Hello, world!"); + assert!(result.calls.is_empty()); + } + + #[test] + fn qwen_xml_parse_complete_extracts_raw_json_arguments() { + let mut parser = Qwen3XmlToolParser::new(&test_tools()); + let arguments = r#"{ "location": "Tokyo", "days": "3" }"#; + let result = parser + .parse_complete(&format!( + "Let me check.\n{}", + build_tool_call("get_weather", arguments) + )) + .unwrap(); + + assert_eq!(result.normal_text, "Let me check.\n"); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].tool_index, 0); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn qwen_xml_does_not_validate_or_normalize_arguments() { + let mut parser = Qwen3XmlToolParser::new(&test_tools()); + let arguments = r#"{"location":"Tokyo",}"#; + let result = parser.parse_complete(&build_tool_call("get_weather", arguments)).unwrap(); + + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn qwen_xml_streaming_emits_argument_deltas() { + let mut parser = Qwen3XmlToolParser::new(&test_tools()); + let chunks = [ + "", + "\n{\"name\": \"get_weather\", \"arguments\": ", + "{\"location\":", + "\"Beijing\"", + "}", + "}\n", + ]; + + let mut result = ToolParseResult::default(); + let mut observed_arguments = Vec::new(); + for chunk in chunks { + let next = parser.push(chunk).unwrap(); + observed_arguments.extend( + next.calls + .iter() + .filter(|call| call.name.is_none()) + .map(|call| call.arguments.clone()), + ); + result.append(next); + } + result.append(parser.finish().unwrap()); + + assert_eq!(observed_arguments, ["{\"location\":", "\"Beijing\"", "}"]); + assert_eq!( + result.coalesce_calls().calls[0].arguments, + r#"{"location":"Beijing"}"# + ); + } + + #[test] + fn qwen_xml_streaming_handles_split_markers() { + let input = format!( + "hello {}", + build_tool_call("get_weather", r#"{"location":"Tokyo"}"#) + ); + let chunks = split_by_chars(&input, 5); + let mut parser = Qwen3XmlToolParser::new(&test_tools()); + + let result = collect_stream(&mut parser, &chunks); + + assert_eq!(result.normal_text, "hello "); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].arguments, r#"{"location":"Tokyo"}"#); + } + + #[test] + fn qwen_xml_keeps_end_marker_literal_inside_json_string() { + let mut parser = Qwen3XmlToolParser::new(&test_tools()); + let arguments = r#"{"text":"literal inside"}"#; + let result = parser.parse_complete(&build_tool_call("echo", arguments)).unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn qwen_xml_decodes_escaped_function_name() { + let mut parser = Qwen3XmlToolParser::new(&test_tools()); + let result = parser + .parse_complete( + r#" +{"name":"say_\"hi","arguments":{}} +"#, + ) + .unwrap(); + + assert_eq!(result.calls[0].name.as_deref(), Some("say_\"hi")); + } + + #[test] + fn qwen_xml_requires_newline_after_tool_call_start() { + let mut parser = Qwen3XmlToolParser::new(&test_tools()); + let input = r#"{"name":"get_weather","arguments":{}} +"#; + + let result = parser.parse_complete(input).unwrap(); + + assert_eq!(result.normal_text, input); + assert!(result.calls.is_empty()); + } + + #[test] + fn qwen_xml_requires_newline_before_tool_call_end() { + let mut parser = Qwen3XmlToolParser::new(&test_tools()); + let error = parser + .parse_complete( + r#" +{"name":"get_weather","arguments":{}}"#, + ) + .unwrap_err(); + + assert!(error.to_report_string().starts_with("tool parser parsing failed:")); + } + + #[test] + fn qwen_xml_streaming_extracts_multiple_tool_calls() { + let input = format!( + "{}{}", + build_tool_call("get_weather", r#"{"location":"Shanghai"}"#), + build_tool_call("add", r#"{"x":1,"y":2}"#), + ); + let chunks = split_by_chars(&input, 7); + let mut parser = Qwen3XmlToolParser::new(&test_tools()); + + let result = collect_stream(&mut parser, &chunks); + + expect![[r#" + ToolParseResult { + normal_text: "", + calls: [ + ToolCallDelta { + tool_index: 0, + name: Some( + "get_weather", + ), + arguments: "{\"location\":\"Shanghai\"}", + }, + ToolCallDelta { + tool_index: 1, + name: Some( + "add", + ), + arguments: "{\"x\":1,\"y\":2}", + }, + ], + } + "#]] + .assert_debug_eq(&result); + } + + #[test] + fn qwen_xml_finish_fails_incomplete_tool_call() { + let mut parser = Qwen3XmlToolParser::new(&test_tools()); + parser + .push( + r#" +{"name":"get_weather","arguments":{"location""#, + ) + .unwrap(); + + let error = parser.finish().unwrap_err(); + + expect!["tool parser parsing failed: incomplete Qwen XML tool call"] + .assert_eq(&error.to_report_string()); + } + + #[test] + fn qwen_xml_malformed_field_order_fails_fast() { + let mut parser = Qwen3XmlToolParser::new(&test_tools()); + let error = parser + .push( + r#" +{"arguments":{},"name":"get_weather"} +"#, + ) + .unwrap_err(); + + expect![[r#" + tool parser parsing failed: invalid Qwen XML + expected `name`"#]] + .assert_eq(&error.to_report_string()); + } +} diff --git a/rust/src/tool-parser/src/kimi_k2.rs b/rust/src/tool-parser/src/kimi_k2.rs new file mode 100644 index 00000000000..921b6e0c011 --- /dev/null +++ b/rust/src/tool-parser/src/kimi_k2.rs @@ -0,0 +1,560 @@ +use winnow::ascii::{digit1, multispace0 as ws0}; +use winnow::combinator::{alt, eof, repeat, seq}; +use winnow::prelude::*; +use winnow::stream::Partial; +use winnow::token::{literal, rest, take_until, take_while}; + +use super::utils::{JsonObjectScanState, parse_buffered_event, safe_text_len, take_json_object}; +use super::{Result, ToolCallDelta, ToolParseResult, ToolParser}; +use crate::Tool; + +const TOOL_CALLS_START: &str = "<|tool_calls_section_begin|>"; +const TOOL_CALLS_END: &str = "<|tool_calls_section_end|>"; +const TOOL_CALL_START: &str = "<|tool_call_begin|>"; +const TOOL_CALL_END: &str = "<|tool_call_end|>"; +const TOOL_CALL_ARGUMENT_START: &str = "<|tool_call_argument_begin|>"; + +type KimiK2Input<'i> = Partial<&'i str>; + +#[derive(Debug, Clone, PartialEq, Eq)] +enum KimiK2Mode { + Text, + ToolBlock, + Header, + Arguments { json_scan: JsonObjectScanState }, + Done, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum KimiK2Event { + Text { + len: usize, + }, + ToolCallsStart, + ToolCallStart, + ToolCallHeader { + function_name: String, + function_index: usize, + }, + Arguments { + len: usize, + }, + ToolCallEnd, + ToolCallsEnd, + IgnoredRest, +} + +/// Tool parser for Kimi K2 token-delimited tool calls. +/// +/// Example tool call content: +/// +/// ```text +/// <|tool_calls_section_begin|> +/// <|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location":"NYC"}<|tool_call_end|> +/// <|tool_calls_section_end|> +/// ``` +/// +/// Arguments are already OpenAI-style JSON text, so they are streamed as raw +/// argument deltas without schema conversion or JSON normalization. +pub struct KimiK2ToolParser { + buffer: String, + mode: KimiK2Mode, + active_tool_index: Option, +} + +impl KimiK2ToolParser { + /// Create a Kimi K2 tool parser. + fn new(_tools: &[Tool]) -> Self { + Self { + buffer: String::new(), + mode: KimiK2Mode::Text, + active_tool_index: None, + } + } + + /// Apply one parsed Kimi K2 event to parser state and output. + fn apply_event(&mut self, event: KimiK2Event, result: &mut ToolParseResult) -> Result<()> { + match event { + KimiK2Event::Text { len: consumed_len } => { + result.normal_text.push_str(&self.buffer[..consumed_len]); + } + KimiK2Event::ToolCallsStart => self.mode = KimiK2Mode::ToolBlock, + KimiK2Event::ToolCallStart => self.mode = KimiK2Mode::Header, + KimiK2Event::ToolCallHeader { + function_name, + function_index, + } => { + let tool_index = function_index; + self.active_tool_index = Some(tool_index); + self.mode = KimiK2Mode::Arguments { + json_scan: JsonObjectScanState::default(), + }; + result.calls.push(ToolCallDelta { + tool_index, + name: Some(function_name), + arguments: String::new(), + }); + } + KimiK2Event::Arguments { len: consumed_len } => { + let Some(tool_index) = self.active_tool_index else { + return Err(parsing_failed!( + "Kimi K2 arguments without an active tool call" + )); + }; + result.calls.push(ToolCallDelta { + tool_index, + name: None, + arguments: self.buffer[..consumed_len].to_string(), + }); + } + KimiK2Event::ToolCallEnd => { + self.active_tool_index = None; + self.mode = KimiK2Mode::ToolBlock; + } + KimiK2Event::ToolCallsEnd => { + self.active_tool_index = None; + self.mode = KimiK2Mode::Done; + } + KimiK2Event::IgnoredRest => {} + } + Ok(()) + } + + /// Reset all streaming state. + fn reset(&mut self) { + self.buffer.clear(); + self.mode = KimiK2Mode::Text; + self.active_tool_index = None; + } +} + +impl ToolParser for KimiK2ToolParser { + /// Create a boxed Kimi K2 tool parser. + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tools))) + } + + /// Preserve Kimi K2 special-token markers while decoding. + fn preserve_special_tokens(&self) -> bool { + true + } + + /// Push one decoded text chunk through the Kimi K2 parser. + fn push(&mut self, chunk: &str) -> Result { + self.buffer.push_str(chunk); + let mut result = ToolParseResult::default(); + + while let Some((event, consumed_len)) = parse_buffered_event(&self.buffer, |input| { + parse_next_kimi_k2_event(input, &mut self.mode) + })? { + self.apply_event(event, &mut result)?; + self.buffer.drain(..consumed_len); + } + + Ok(result) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + let mut result = ToolParseResult::default(); + match &self.mode { + KimiK2Mode::Text => result.normal_text.push_str(&self.buffer), + KimiK2Mode::ToolBlock | KimiK2Mode::Done => {} + KimiK2Mode::Header | KimiK2Mode::Arguments { .. } => { + return Err(parsing_failed!("incomplete Kimi K2 tool call")); + } + } + self.reset(); + Ok(result) + } +} + +/// Parse a Kimi K2 event for the current parser mode. +fn parse_next_kimi_k2_event( + input: &mut KimiK2Input<'_>, + mode: &mut KimiK2Mode, +) -> ModalResult { + match mode { + KimiK2Mode::Text => parse_text_event(input), + KimiK2Mode::ToolBlock => parse_tool_block_event(input), + KimiK2Mode::Header => tool_call_header_event(input), + KimiK2Mode::Arguments { json_scan } => parse_arguments_event(input, json_scan), + KimiK2Mode::Done => ignored_rest_event(input), + } +} + +/// Parse a text-mode Kimi K2 event. +fn parse_text_event(input: &mut KimiK2Input<'_>) -> ModalResult { + alt((tool_calls_start_event, safe_text_event)).parse_next(input) +} + +/// Parse one event inside the Kimi K2 tool-calls section. +fn parse_tool_block_event(input: &mut KimiK2Input<'_>) -> ModalResult { + alt((tool_calls_end_event, tool_call_start_event)).parse_next(input) +} + +/// Parse one event inside a Kimi K2 tool-call arguments payload. +fn parse_arguments_event( + input: &mut KimiK2Input<'_>, + json_scan: &mut JsonObjectScanState, +) -> ModalResult { + if json_scan.complete() { + tool_call_end_event(input) + } else { + argument_delta_event(input, json_scan) + } +} + +/// Parse a Kimi K2 tool-calls section start marker. +fn tool_calls_start_event(input: &mut KimiK2Input<'_>) -> ModalResult { + literal(TOOL_CALLS_START).value(KimiK2Event::ToolCallsStart).parse_next(input) +} + +/// Parse a Kimi K2 tool-calls section end marker. +fn tool_calls_end_event(input: &mut KimiK2Input<'_>) -> ModalResult { + (ws0, literal(TOOL_CALLS_END)) + .value(KimiK2Event::ToolCallsEnd) + .parse_next(input) +} + +/// Parse a Kimi K2 tool-call start marker. +fn tool_call_start_event(input: &mut KimiK2Input<'_>) -> ModalResult { + (ws0, literal(TOOL_CALL_START)) + .value(KimiK2Event::ToolCallStart) + .parse_next(input) +} + +/// Parse a Kimi K2 tool-call end marker. +fn tool_call_end_event(input: &mut KimiK2Input<'_>) -> ModalResult { + literal(TOOL_CALL_END).value(KimiK2Event::ToolCallEnd).parse_next(input) +} + +/// Parse a Kimi K2 tool-call header before the argument marker. +fn tool_call_header_event(input: &mut KimiK2Input<'_>) -> ModalResult { + let (header, _) = ( + take_until(1.., TOOL_CALL_ARGUMENT_START), + literal(TOOL_CALL_ARGUMENT_START), + ) + .parse_next(input)?; + + let mut header_input = header; + let (header, _, _) = (tool_header, ws0, eof).parse_next(&mut header_input)?; + + Ok(KimiK2Event::ToolCallHeader { + function_name: header.function_name, + function_index: header.function_index, + }) +} + +/// Parse a Kimi K2 raw JSON arguments delta. +fn argument_delta_event( + input: &mut KimiK2Input<'_>, + json_scan: &mut JsonObjectScanState, +) -> ModalResult { + take_json_object(input, json_scan).map(|len| KimiK2Event::Arguments { len }) +} + +/// Parse a safe text run before the next Kimi K2 tool-calls section. +fn safe_text_event(input: &mut KimiK2Input<'_>) -> ModalResult { + safe_text_len(input, TOOL_CALLS_START).map(|len| KimiK2Event::Text { len }) +} + +/// Parse ignored rest after the Kimi K2 tool-calls section ends. +fn ignored_rest_event(input: &mut KimiK2Input<'_>) -> ModalResult { + rest.value(KimiK2Event::IgnoredRest).parse_next(input) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct KimiK2ToolHeader { + function_name: String, + function_index: usize, +} + +/// Parse a Kimi K2 tool-call header. +fn tool_header(input: &mut &str) -> ModalResult { + let (function_name, function_index) = seq!( + _: ws0, + _: namespace_prefix, + tool_name_segment, + _: literal(":"), + tool_call_index, + ) + .parse_next(input)?; + + Ok(KimiK2ToolHeader { + function_name: function_name.to_string(), + function_index, + }) +} + +/// Parse Kimi K2 namespace segments before the final tool name. +fn namespace_prefix(input: &mut &str) -> ModalResult<()> { + repeat(0.., namespace_segment).parse_next(input) +} + +/// Parse a Kimi K2 namespace segment. +fn namespace_segment<'i>(input: &mut &'i str) -> ModalResult<&'i str> { + let (segment, _) = (tool_name_segment, literal(".")).parse_next(input)?; + Ok(segment) +} + +/// Parse a Kimi K2 tool name segment. +fn tool_name_segment<'i>(input: &mut &'i str) -> ModalResult<&'i str> { + take_while(1.., |ch: char| { + !ch.is_whitespace() && ch != '<' && ch != ':' && ch != '.' + }) + .parse_next(input) +} + +/// Parse a Kimi K2 tool-call index. +fn tool_call_index(input: &mut &str) -> ModalResult { + digit1.parse_to().parse_next(input) +} + +#[cfg(test)] +mod tests { + use expect_test::expect; + use thiserror_ext::AsReport; + + use super::{ + KimiK2ToolParser, TOOL_CALL_ARGUMENT_START, TOOL_CALL_END, TOOL_CALL_START, TOOL_CALLS_END, + TOOL_CALLS_START, ToolParser, tool_header, + }; + use crate::ToolParseResult; + use crate::test_utils::{collect_stream, split_by_chars, test_tools}; + + fn build_tool_call(function_name: &str, index: usize, arguments: &str) -> String { + format!( + "{TOOL_CALL_START}functions.{function_name}:{index}{TOOL_CALL_ARGUMENT_START}{arguments}{TOOL_CALL_END}" + ) + } + + fn build_tool_section(tool_calls: &[String]) -> String { + format!("{TOOL_CALLS_START}{}{TOOL_CALLS_END}", tool_calls.join("")) + } + + #[test] + fn kimi_k2_parse_complete_without_tool_call_keeps_text() { + let mut parser = KimiK2ToolParser::new(&test_tools()); + let result = parser.parse_complete("Hello, world!").unwrap(); + + assert_eq!(result.normal_text, "Hello, world!"); + assert!(result.calls.is_empty()); + } + + #[test] + fn kimi_k2_parse_complete_extracts_raw_json_arguments() { + let mut parser = KimiK2ToolParser::new(&test_tools()); + let arguments = r#"{ "location": "NYC", "days": "3" }"#; + let result = parser + .parse_complete(&format!( + "Checking. {} trailing text", + build_tool_section(&[build_tool_call("get_weather", 0, arguments)]) + )) + .unwrap(); + + assert_eq!(result.normal_text, "Checking. "); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].tool_index, 0); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn kimi_k2_does_not_validate_or_normalize_arguments() { + let mut parser = KimiK2ToolParser::new(&test_tools()); + let arguments = r#"{"location":"NYC",}"#; + let result = parser + .parse_complete(&build_tool_section(&[build_tool_call( + "get_weather", + 0, + arguments, + )])) + .unwrap(); + + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn kimi_k2_streaming_emits_argument_deltas() { + let mut parser = KimiK2ToolParser::new(&test_tools()); + let chunks = [ + TOOL_CALLS_START, + TOOL_CALL_START, + "functions.get_weather:0", + TOOL_CALL_ARGUMENT_START, + "{\"location\":", + "\"Paris\"", + "}", + TOOL_CALL_END, + TOOL_CALLS_END, + ]; + + let mut result = ToolParseResult::default(); + let mut observed_arguments = Vec::new(); + for chunk in chunks { + let next = parser.push(chunk).unwrap(); + observed_arguments.extend( + next.calls + .iter() + .filter(|call| call.name.is_none()) + .map(|call| call.arguments.clone()), + ); + result.append(next); + } + result.append(parser.finish().unwrap()); + + assert_eq!(observed_arguments, ["{\"location\":", "\"Paris\"", "}"]); + let result = result.coalesce_calls(); + assert_eq!(result.calls[0].arguments, r#"{"location":"Paris"}"#); + } + + #[test] + fn kimi_k2_streaming_holds_back_split_markers() { + let mut parser = KimiK2ToolParser::new(&test_tools()); + let chunks = [ + "hello <|tool_calls", + "_section_begin|>", + TOOL_CALL_START, + "functions.get_weather:0", + TOOL_CALL_ARGUMENT_START, + r#"{"location":"NYC"}"#, + "<|tool_call", + "_end|>", + TOOL_CALLS_END, + ]; + + let result = collect_stream(&mut parser, &chunks); + + assert_eq!(result.normal_text, "hello "); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].arguments, r#"{"location":"NYC"}"#); + } + + #[test] + fn kimi_k2_keeps_end_marker_literal_inside_json_string() { + let mut parser = KimiK2ToolParser::new(&test_tools()); + let arguments = format!(r#"{{"text":"literal {TOOL_CALL_END} inside"}}"#); + let input = build_tool_section(&[build_tool_call("echo", 0, &arguments)]); + + let result = parser.parse_complete(&input).unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].arguments, arguments); + } + + #[test] + fn kimi_k2_streaming_keeps_split_end_marker_literal_inside_json_string() { + let mut parser = KimiK2ToolParser::new(&test_tools()); + let chunks = [ + TOOL_CALLS_START, + TOOL_CALL_START, + "functions.echo:0", + TOOL_CALL_ARGUMENT_START, + r#"{"text":"literal <|tool"#, + r#"_call_end|> inside"}"#, + TOOL_CALL_END, + TOOL_CALLS_END, + ]; + + let result = collect_stream(&mut parser, &chunks); + + assert_eq!(result.calls.len(), 1); + assert_eq!( + result.calls[0].arguments, + r#"{"text":"literal <|tool_call_end|> inside"}"# + ); + } + + #[test] + fn kimi_k2_streaming_extracts_multiple_tool_calls() { + let mut parser = KimiK2ToolParser::new(&test_tools()); + let input = build_tool_section(&[ + build_tool_call("get_weather", 0, r#"{"location":"Shanghai"}"#), + build_tool_call("add", 1, r#"{"x":1,"y":2}"#), + ]); + + let chunks = split_by_chars(&input, 7); + let result = collect_stream(&mut parser, &chunks); + + expect![[r#" + ToolParseResult { + normal_text: "", + calls: [ + ToolCallDelta { + tool_index: 0, + name: Some( + "get_weather", + ), + arguments: "{\"location\":\"Shanghai\"}", + }, + ToolCallDelta { + tool_index: 1, + name: Some( + "add", + ), + arguments: "{\"x\":1,\"y\":2}", + }, + ], + } + "#]] + .assert_debug_eq(&result); + } + + #[test] + fn kimi_k2_accepts_non_functions_header_prefix() { + let mut parser = KimiK2ToolParser::new(&test_tools()); + let input = format!( + "{TOOL_CALLS_START}{TOOL_CALL_START}api.tools.search:42{TOOL_CALL_ARGUMENT_START}{{}}{TOOL_CALL_END}{TOOL_CALLS_END}" + ); + + let result = parser.parse_complete(&input).unwrap(); + + assert_eq!(result.calls[0].tool_index, 42); + assert_eq!(result.calls[0].name.as_deref(), Some("search")); + assert_eq!(result.calls[0].arguments, "{}"); + } + + #[test] + fn kimi_k2_tool_header_parses_namespace_function_and_index() { + let mut input = "api.tools.search:42"; + let header = tool_header(&mut input).unwrap(); + + expect![[r#" + KimiK2ToolHeader { + function_name: "search", + function_index: 42, + } + "#]] + .assert_debug_eq(&header); + } + + #[test] + fn kimi_k2_finish_fails_incomplete_tool_call() { + let mut parser = KimiK2ToolParser::new(&test_tools()); + parser + .push(&format!( + "{TOOL_CALLS_START}{TOOL_CALL_START}functions.get_weather:0{TOOL_CALL_ARGUMENT_START}{{\"location\"" + )) + .unwrap(); + + let error = parser.finish().unwrap_err(); + + expect!["tool parser parsing failed: incomplete Kimi K2 tool call"] + .assert_eq(&error.to_report_string()); + } + + #[test] + fn kimi_k2_malformed_header_fails_fast() { + let mut parser = KimiK2ToolParser::new(&test_tools()); + let input = + format!("{TOOL_CALLS_START}{TOOL_CALL_START}get_weather{TOOL_CALL_ARGUMENT_START}{{}}"); + + let error = parser.push(&input).unwrap_err(); + + expect!["tool parser parsing failed: "].assert_eq(&error.to_report_string()); + } +} diff --git a/rust/src/tool-parser/src/lib.rs b/rust/src/tool-parser/src/lib.rs new file mode 100644 index 00000000000..f7d3411f3b7 --- /dev/null +++ b/rust/src/tool-parser/src/lib.rs @@ -0,0 +1,141 @@ +//! Streaming tool parsers for chat completions. + +#[macro_use] +mod error; +mod deepseek_dsml; +mod deepseek_json; +mod gemma4; +mod glm_xml; +mod json; +mod kimi_k2; +mod minimax_m2; +mod parameters; +mod qwen_coder; +#[cfg(any(test, feature = "test-util"))] +pub mod test_utils; +mod utils; + +use std::collections::{BTreeMap, btree_map}; + +pub use deepseek_dsml::{DeepSeekV4ToolParser, DeepSeekV32ToolParser}; +pub use deepseek_json::{DeepSeekV3ToolParser, DeepSeekV31ToolParser}; +pub use error::{Result, ToolParserError}; +pub use gemma4::Gemma4ToolParser; +pub use glm_xml::{Glm45MoeToolParser, Glm47MoeToolParser}; +pub use json::{HermesToolParser, Llama3JsonToolParser, MistralToolParser, Qwen3XmlToolParser}; +pub use kimi_k2::KimiK2ToolParser; +pub use minimax_m2::MinimaxM2ToolParser; +pub use qwen_coder::Qwen3CoderToolParser; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +/// One function-style tool made available to the model. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Tool { + pub name: String, + pub description: Option, + pub parameters: Value, + pub strict: Option, +} + +/// One tool-call update emitted while parsing assistant text. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ToolCallDelta { + /// Stable parser-local tool index for this call within one assistant turn. + pub tool_index: usize, + /// Function name, present on the first update for one tool call. + pub name: Option, + /// Arguments text contributed by this update. + pub arguments: String, +} + +/// Result of advancing tool parsing with one assistant-text input. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct ToolParseResult { + /// Plain assistant text that is not part of any tool call. + pub normal_text: String, + /// Tool-call updates extracted from this input. + pub calls: Vec, +} + +impl ToolParseResult { + /// Append another parser result onto this one. + /// + /// Note that this does not attempt to merge multiple deltas for the same + /// tool call into one complete item. Call `coalesce_calls()` after if + /// that behavior is desired. + pub(crate) fn append(&mut self, mut other: Self) { + self.normal_text.push_str(&other.normal_text); + self.calls.append(&mut other.calls); + } + + /// Merge multiple deltas for the same tool call into one complete item. + /// + /// This is primarily used by the default `parse_complete()` implementation, + /// which delegates through the incremental parser lifecycle and then + /// needs to collapse streaming-style argument fragments into one final + /// tool call. + pub(crate) fn coalesce_calls(mut self) -> Self { + let mut merged = BTreeMap::::new(); + let mut order = Vec::new(); + + for call in self.calls { + match merged.entry(call.tool_index) { + btree_map::Entry::Vacant(entry) => { + order.push(call.tool_index); + entry.insert(call); + } + btree_map::Entry::Occupied(mut entry) => { + let existing = entry.get_mut(); + if existing.name.is_none() { + existing.name = call.name; + } + existing.arguments.push_str(&call.arguments); + } + } + } + + self.calls = + order.into_iter().filter_map(|tool_index| merged.remove(&tool_index)).collect(); + self + } +} + +/// Incremental parser that extracts tool calls from assistant output. +pub trait ToolParser: Send { + /// Construct a boxed parser instance for one request stream. + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static; + + /// Return whether decoded output must preserve tokenizer special tokens. + /// + /// Some model families emit tool-call sentinels as special tokens. Those + /// parsers need `skip_special_tokens = false` while parsing is enabled. + fn preserve_special_tokens(&self) -> bool { + false + } + + /// Feed one decoded text delta into the parser. + fn push(&mut self, chunk: &str) -> Result; + + /// Flush any buffered partial state at end of stream. + fn finish(&mut self) -> Result { + Ok(ToolParseResult::default()) + } + + /// Parse complete tool calls from final output. + /// + /// The default implementation reuses the incremental parser lifecycle by + /// feeding the full output through `push()` and then calling `finish()`. + /// This keeps one source of truth for robust parsers whose incremental + /// state machine is equivalent across arbitrary chunking. + fn parse_complete(&mut self, output: &str) -> Result { + let mut result = self.push(output)?; + result.append(self.finish()?); + Ok(result.coalesce_calls()) + } +} + +#[cfg(test)] +mod tests; diff --git a/rust/src/tool-parser/src/minimax_m2.rs b/rust/src/tool-parser/src/minimax_m2.rs new file mode 100644 index 00000000000..34f2bc89b11 --- /dev/null +++ b/rust/src/tool-parser/src/minimax_m2.rs @@ -0,0 +1,519 @@ +use winnow::ascii::{multispace0 as ws0, multispace1 as ws1}; +use winnow::combinator::{alt, delimited, repeat, seq, terminated}; +use winnow::prelude::*; +use winnow::stream::Partial; +use winnow::token::{literal, rest, take_until}; + +use super::parameters::ToolSchemas; +use super::utils::{parse_buffered_event, safe_text_len, xml_unescape}; +use super::{Result, ToolCallDelta, ToolParseResult, ToolParser}; +use crate::Tool; + +const TOOL_CALL_START: &str = ""; +const TOOL_CALL_END: &str = ""; +const INVOKE_START: &str = " = Partial<&'i str>; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum MinimaxM2Mode { + Text, + ToolBlock, + Done, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum MinimaxM2Event { + Text { + len: usize, + }, + ToolBlockStart, + Invoke { + name: String, + raw_params: Vec<(String, String)>, + }, + ToolBlockEnd, + IgnoredRest, +} + +/// Tool parser for MiniMax M2 XML-style tool calls. +/// +/// Example tool call content: +/// +/// ```text +/// +/// Seattle +/// +/// ``` +/// +/// Arguments are emitted only after a full `` block is parsed. +pub struct MinimaxM2ToolParser { + buffer: String, + mode: MinimaxM2Mode, + emitted_tool_count: usize, + tool_parameters: ToolSchemas, +} + +impl MinimaxM2ToolParser { + /// Create a MiniMax M2 tool parser. + fn new(tools: &[Tool]) -> Self { + Self { + buffer: String::new(), + mode: MinimaxM2Mode::Text, + emitted_tool_count: 0, + tool_parameters: ToolSchemas::from_tools(tools), + } + } + + /// Apply one parsed MiniMax M2 event to parser state and output. + fn apply_event(&mut self, event: MinimaxM2Event, result: &mut ToolParseResult) -> Result<()> { + match event { + MinimaxM2Event::Text { len: consumed_len } => { + result.normal_text.push_str(&self.buffer[..consumed_len]); + } + MinimaxM2Event::ToolBlockStart => self.mode = MinimaxM2Mode::ToolBlock, + MinimaxM2Event::Invoke { name, raw_params } => { + let arguments = self.tool_parameters.convert_params_with_schema(&name, raw_params); + let arguments = serde_json::to_string(&arguments) + .map_err(|error| parsing_failed!("failed to serialize arguments: {}", error))?; + + result.calls.push(ToolCallDelta { + tool_index: self.emitted_tool_count, + name: Some(name), + arguments, + }); + self.emitted_tool_count += 1; + } + MinimaxM2Event::ToolBlockEnd => self.mode = MinimaxM2Mode::Done, + MinimaxM2Event::IgnoredRest => {} + } + Ok(()) + } + + /// Reset all streaming state. + fn reset(&mut self) { + self.buffer.clear(); + self.mode = MinimaxM2Mode::Text; + self.emitted_tool_count = 0; + } +} + +impl ToolParser for MinimaxM2ToolParser { + /// Create a boxed MiniMax M2 tool parser. + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tools))) + } + + /// Push one decoded text chunk through the MiniMax M2 parser. + fn push(&mut self, chunk: &str) -> Result { + self.buffer.push_str(chunk); + let mut result = ToolParseResult::default(); + + while let Some((event, consumed_len)) = parse_buffered_event(&self.buffer, |input| { + parse_next_minimax_m2_event(input, self.mode) + })? { + self.apply_event(event, &mut result)?; + self.buffer.drain(..consumed_len); + } + + Ok(result) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + let mut result = ToolParseResult::default(); + match self.mode { + MinimaxM2Mode::Text => { + result.normal_text.push_str(&self.buffer); + } + MinimaxM2Mode::ToolBlock => { + return Err(parsing_failed!("incomplete MiniMax M2 tool call")); + } + MinimaxM2Mode::Done => {} + } + self.reset(); + Ok(result) + } +} + +/// Parse a MiniMax M2 event for the current parser mode. +fn parse_next_minimax_m2_event( + input: &mut MinimaxM2Input<'_>, + mode: MinimaxM2Mode, +) -> ModalResult { + match mode { + MinimaxM2Mode::Text => parse_text_event(input), + MinimaxM2Mode::ToolBlock => parse_tool_block_event(input), + MinimaxM2Mode::Done => ignored_rest_event(input), + } +} + +/// Parse a text-mode MiniMax M2 event. +fn parse_text_event(input: &mut MinimaxM2Input<'_>) -> ModalResult { + alt((tool_block_start_event, safe_text_event)).parse_next(input) +} + +/// Parse a MiniMax M2 tool-block start marker. +fn tool_block_start_event(input: &mut MinimaxM2Input<'_>) -> ModalResult { + literal(TOOL_CALL_START).value(MinimaxM2Event::ToolBlockStart).parse_next(input) +} + +/// Parse a safe text run before the next MiniMax M2 marker. +fn safe_text_event(input: &mut MinimaxM2Input<'_>) -> ModalResult { + safe_text_len(input, TOOL_CALL_START).map(|len| MinimaxM2Event::Text { len }) +} + +/// Parse one event inside a MiniMax M2 tool block. +fn parse_tool_block_event(input: &mut MinimaxM2Input<'_>) -> ModalResult { + alt((tool_block_end_event, invoke_event)).parse_next(input) +} + +/// Parse a MiniMax M2 tool-block end marker. +fn tool_block_end_event(input: &mut MinimaxM2Input<'_>) -> ModalResult { + (ws0, literal(TOOL_CALL_END)) + .value(MinimaxM2Event::ToolBlockEnd) + .parse_next(input) +} + +/// Parse a complete MiniMax M2 invoke block. +fn invoke_event(input: &mut MinimaxM2Input<'_>) -> ModalResult { + let (name, raw_params) = seq!( + _: ws0, + _: literal(INVOKE_START), + _: (ws1, literal("name=")), + attr_value, + _: literal(">"), + repeat(0.., terminated(parameter, ws0)), + _: literal(INVOKE_END), + ) + .parse_next(input)?; + + Ok(MinimaxM2Event::Invoke { + name: name.trim().to_string(), + raw_params, + }) +} + +/// Parse a MiniMax M2 parameter block. +fn parameter(input: &mut MinimaxM2Input<'_>) -> ModalResult<(String, String)> { + let (name, value) = seq!( + _: literal(PARAMETER_START), + _: (ws1, literal("name=")), + attr_value, + _: literal(">"), + take_until(0.., PARAMETER_END).map(xml_unescape), + _: literal(PARAMETER_END), + ) + .parse_next(input)?; + + Ok((name.trim().to_string(), value.into_owned())) +} + +/// Parse a quoted or unquoted XML attribute value. +fn attr_value<'i>(input: &mut MinimaxM2Input<'i>) -> ModalResult<&'i str> { + alt(( + delimited(literal("\""), take_until(1.., "\""), literal("\"")), + delimited(literal("'"), take_until(1.., "'"), literal("'")), + take_until(1.., ">"), + )) + .parse_next(input) +} + +/// Parse ignored rest after the MiniMax M2 tool block ends. +fn ignored_rest_event(input: &mut MinimaxM2Input<'_>) -> ModalResult { + rest.value(MinimaxM2Event::IgnoredRest).parse_next(input) +} + +#[cfg(test)] +mod tests { + use expect_test::expect; + use serde_json::{Value, json}; + use thiserror_ext::AsReport; + + use super::{MinimaxM2ToolParser, TOOL_CALL_END, TOOL_CALL_START, ToolParser}; + use crate::test_utils::{collect_stream, split_by_chars, test_tools}; + + fn build_tool_block(invokes: &[(&str, Vec<(&str, &str)>)]) -> String { + let invokes = invokes + .iter() + .map(|(function_name, params)| { + let params = params + .iter() + .map(|(name, value)| format!(r#"{value}"#)) + .collect::>() + .join(""); + format!(r#"{params}"#) + }) + .collect::(); + format!("{TOOL_CALL_START}{invokes}{TOOL_CALL_END}") + } + + #[test] + fn minimax_m2_parse_complete_without_tool_call_keeps_text() { + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let result = parser.parse_complete("Hello, world!").unwrap(); + + assert_eq!(result.normal_text, "Hello, world!"); + assert!(result.calls.is_empty()); + } + + #[test] + fn minimax_m2_parse_complete_extracts_single_tool_call() { + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let result = parser + .parse_complete(&build_tool_block(&[( + "get_weather", + vec![("city", "Seattle"), ("days", "5")], + )])) + .unwrap(); + + assert!(result.normal_text.is_empty()); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ "city": "Seattle", "days": 5 }) + ); + } + + #[test] + fn minimax_m2_parse_complete_preserves_prefix_and_ignores_trailing_text() { + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let output = format!( + "Let me check. {} This trailing text is ignored.", + build_tool_block(&[("get_weather", vec![("city", "Seattle")])]) + ); + let result = parser.parse_complete(&output).unwrap(); + + assert_eq!(result.normal_text, "Let me check. "); + assert_eq!(result.calls.len(), 1); + } + + #[test] + fn minimax_m2_parse_complete_extracts_multiple_invokes() { + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let result = parser + .parse_complete(&build_tool_block(&[ + ("get_weather", vec![("city", "Seattle")]), + ("get_weather", vec![("city", "NYC")]), + ])) + .unwrap(); + + assert_eq!(result.calls.len(), 2); + assert_eq!(result.calls[0].tool_index, 0); + assert_eq!(result.calls[1].tool_index, 1); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ "city": "Seattle" }) + ); + assert_eq!( + serde_json::from_str::(&result.calls[1].arguments).unwrap(), + json!({ "city": "NYC" }) + ); + } + + #[test] + fn minimax_m2_parse_complete_converts_schema_types() { + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let result = parser + .parse_complete(&build_tool_block(&[( + "convert", + vec![ + ("whole", "5.0"), + ("flag", "true"), + ("payload", r#"{"nested":true}"#), + ("items", "[1,2]"), + ("empty", "42"), + ], + )])) + .unwrap(); + + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "whole": 5.0, + "flag": true, + "payload": { "nested": true }, + "items": [1, 2], + "empty": "42", + }) + ); + } + + #[test] + fn minimax_m2_parse_complete_unescapes_literal_closing_tags_in_parameter_value() { + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let result = parser + .parse_complete(&build_tool_block(&[( + "get_weather", + vec![ + ( + "city", + "Seattle </parameter></invoke></minimax:tool_call>", + ), + ("days", "5"), + ], + )])) + .unwrap(); + + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "city": "Seattle ", + "days": 5, + }) + ); + } + + #[test] + fn minimax_m2_parse_complete_handles_multiline_parameters() { + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let result = parser + .parse_complete( + "\ + \ + \nrectangle\n\ + {\"width\":10,\n\"height\":20}\ + 2\ + \ + ", + ) + .unwrap(); + + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "shape": "\nrectangle\n", + "dimensions": { "width": 10, "height": 20 }, + "precision": 2, + }) + ); + } + + #[test] + fn minimax_m2_streaming_extracts_single_tool_call() { + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let result = collect_stream( + &mut parser, + &[ + "", + r#""#, + r#"Seattle"#, + "", + ], + ); + + assert!(result.normal_text.is_empty()); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ "city": "Seattle" }) + ); + } + + #[test] + fn minimax_m2_streaming_preserves_prefix_text() { + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let result = collect_stream( + &mut parser, + &[ + "Let me check. ", + "", + r#"Seattle"#, + "", + ], + ); + + assert_eq!(result.normal_text, "Let me check. "); + assert_eq!(result.calls.len(), 1); + } + + #[test] + fn minimax_m2_streaming_without_tool_call_emits_text_incrementally() { + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let result = collect_stream(&mut parser, &["Hello, ", "world!"]); + + assert_eq!(result.normal_text, "Hello, world!"); + assert!(result.calls.is_empty()); + } + + #[test] + fn minimax_m2_streaming_handles_marker_split_across_chunks() { + let text = build_tool_block(&[("get_weather", vec![("city", "Seattle")])]); + let chunks = split_by_chars(&text, 3); + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let result = collect_stream(&mut parser, &chunks); + + assert_eq!(result.calls.len(), 1); + assert!(result.normal_text.is_empty()); + } + + #[test] + fn minimax_m2_streaming_extracts_multiple_invokes_in_order() { + let text = build_tool_block(&[ + ("get_weather", vec![("city", "Seattle")]), + ("get_weather", vec![("city", "NYC")]), + ]); + let chunks = split_by_chars(&text, 7); + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let result = collect_stream(&mut parser, &chunks); + + assert_eq!(result.calls.len(), 2); + assert_eq!(result.calls[0].tool_index, 0); + assert_eq!(result.calls[1].tool_index, 1); + } + + #[test] + fn minimax_m2_streaming_ignores_text_after_tool_block() { + let text = format!( + "{} ignored", + build_tool_block(&[("get_weather", vec![("city", "Seattle")])]) + ); + let chunks = split_by_chars(&text, 5); + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let result = collect_stream(&mut parser, &chunks); + + assert!(result.normal_text.is_empty()); + assert_eq!(result.calls.len(), 1); + } + + #[test] + fn minimax_m2_streaming_does_not_emit_incomplete_tool_call() { + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let result = parser.push(r#""#).unwrap(); + + assert!(result.normal_text.is_empty()); + assert!(result.calls.is_empty()); + } + + #[test] + fn minimax_m2_finish_fails_incomplete_tool_call() { + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + parser.push(r#""#).unwrap(); + + assert!(parser.finish().is_err()); + } + + #[test] + fn minimax_m2_finish_fails_after_bare_tool_block_start() { + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + parser.push("").unwrap(); + + assert!(parser.finish().is_err()); + } + + #[test] + fn minimax_m2_malformed_tool_call_fails_fast() { + let mut parser = MinimaxM2ToolParser::new(&test_tools()); + let error = parser.push("").unwrap_err(); + + expect!["tool parser parsing failed: "].assert_eq(&error.to_report_string()); + } +} diff --git a/rust/src/tool-parser/src/parameters.rs b/rust/src/tool-parser/src/parameters.rs new file mode 100644 index 00000000000..cf21bad16cd --- /dev/null +++ b/rust/src/tool-parser/src/parameters.rs @@ -0,0 +1,508 @@ +use std::collections::BTreeMap; + +use serde_json::{Number, Value}; + +use crate::Tool; + +/// Normalized parameter schemas for all tools in one request. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(super) struct ToolSchemas { + tools: BTreeMap, +} + +/// Normalized parameter schema for one tool. +/// +/// This is a minimal subset of JSON Schema with some normalization heuristics +/// to support common schema patterns and upstream schema variations, focused on +/// coercing raw string parameter values into more specific JSON types for +/// downstream tool call execution. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(super) struct ToolSchema { + params: BTreeMap, +} + +/// Normalized JSON parameter type used for raw string coercion. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) enum JsonParamType { + String, + Integer, + Number, + Boolean, + Object, + Array, + Null, + OneOf(Vec), +} + +impl ToolSchemas { + /// Normalize OpenAI-style tool parameter JSON schemas for one request. + pub(super) fn from_tools(tools: &[Tool]) -> Self { + let tools = tools + .iter() + .map(|tool| (tool.name.clone(), ToolSchema::from_schema(&tool.parameters))) + .collect(); + + Self { tools } + } + + /// Convert raw string parameter values for one named tool. + /// + /// Unknown tool names use an empty schema, so all parameters fall back to + /// strings. + pub(super) fn convert_params_with_schema( + &self, + function_name: &str, + params: Vec<(String, String)>, + ) -> serde_json::Map { + let tool_schema = self.tools.get(function_name).unwrap_or(ToolSchema::empty()); + let mut converted = serde_json::Map::with_capacity(params.len()); + for (name, value) in params { + let value = tool_schema.convert(&name, &value); + converted.insert(name, value); + } + converted + } + + /// Convert one raw string parameter value for one named tool. + pub(super) fn convert_param_with_schema( + &self, + function_name: &str, + name: &str, + value: &str, + ) -> Value { + let tool_schema = self.tools.get(function_name).unwrap_or(ToolSchema::empty()); + tool_schema.convert(name, value) + } +} + +impl ToolSchema { + /// Return an empty schema with no parameter information, which causes all + /// parameters to be treated as strings. + const fn empty() -> &'static Self { + static EMPTY: ToolSchema = ToolSchema { + params: BTreeMap::new(), + }; + &EMPTY + } + + /// Normalize an OpenAI-style tool parameters JSON schema. + fn from_schema(parameters: &Value) -> Self { + let Some(properties) = parameters.get("properties").and_then(Value::as_object) else { + return Self::default(); + }; + + let params = properties + .iter() + .filter_map(|(name, schema)| { + JsonParamType::from_schema(schema).map(|param_type| (name.clone(), param_type)) + }) + .collect(); + + Self { params } + } + + /// Convert one raw parameter value using its normalized schema type. + /// + /// If the parameter name is unknown, or we don't have a schema for it, or + /// the value fails to convert, this falls back to returning the raw + /// string as a JSON string value. + fn convert(&self, name: &str, value: &str) -> Value { + if value.eq_ignore_ascii_case("null") { + return Value::Null; + } + + let Some(param_type) = self.params.get(name) else { + return Value::String(value.to_string()); + }; + + convert_value(param_type, value).unwrap_or_else(|| Value::String(value.to_string())) + } +} + +impl JsonParamType { + /// Normalize one parameter property schema. + fn from_schema(schema: &Value) -> Option { + let schema = schema.as_object()?; + + if let Some(type_value) = schema.get("type") { + return Self::from_type_value(type_value); + } + + if let Some(composite) = schema.get("anyOf").or_else(|| schema.get("oneOf")) { + let param_type = composite + .as_array() + .map(|schemas| schemas.iter().filter_map(Self::from_schema).collect::>()) + .filter(|types| !types.is_empty()) + .map(Self::one_of) + .unwrap_or(Self::Object); + return Some(param_type); + } + + if schema.contains_key("enum") { + return Some(Self::String); + } + if schema.contains_key("items") { + return Some(Self::Array); + } + if schema.contains_key("properties") { + return Some(Self::Object); + } + + None + } + + /// Normalize a JSON schema `type` value. + fn from_type_value(type_value: &Value) -> Option { + match type_value { + Value::String(kind) => Self::from_type_name(kind), + Value::Array(kinds) => { + let types = kinds + .iter() + .filter_map(Value::as_str) + .filter_map(Self::from_type_name) + .collect::>(); + if types.is_empty() { + None + } else { + Some(Self::one_of(types)) + } + } + _ => None, + } + } + + /// Normalize one JSON schema type name. + fn from_type_name(kind: &str) -> Option { + let kind = kind.trim().to_ascii_lowercase(); + match kind.as_str() { + "string" | "str" | "text" | "varchar" | "char" | "enum" => Some(Self::String), + "integer" | "int" => Some(Self::Integer), + "number" | "float" => Some(Self::Number), + "boolean" | "bool" | "binary" => Some(Self::Boolean), + "object" => Some(Self::Object), + "array" | "arr" | "sequence" => Some(Self::Array), + "null" => Some(Self::Null), + _ if kind.starts_with("int") + || kind.starts_with("uint") + || kind.starts_with("long") + || kind.starts_with("short") + || kind.starts_with("unsigned") => + { + Some(Self::Integer) + } + _ if kind.starts_with("num") || kind.starts_with("float") => Some(Self::Number), + _ if kind.starts_with("dict") => Some(Self::Object), + _ if kind.starts_with("list") => Some(Self::Array), + _ => None, + } + } + + /// Collapse a candidate type list into one normalized type. + fn one_of(mut types: Vec) -> Self { + if types.len() == 1 { + types.remove(0) + } else { + Self::OneOf(types) + } + } +} + +/// Convert one raw string value to a normalized JSON type. +fn convert_value(param_type: &JsonParamType, value: &str) -> Option { + match param_type { + JsonParamType::String => Some(Value::String(value.to_string())), + JsonParamType::Integer => value.parse::().ok().map(Number::from).map(Value::Number), + JsonParamType::Number => convert_number(value), + JsonParamType::Boolean => convert_boolean(value), + JsonParamType::Object | JsonParamType::Array => serde_json::from_str(value).ok(), + JsonParamType::Null => value.eq_ignore_ascii_case("null").then_some(Value::Null), + JsonParamType::OneOf(types) => { + types.iter().find_map(|param_type| convert_value(param_type, value)) + } + } +} + +/// Convert one raw string value to a JSON number. +fn convert_number(value: &str) -> Option { + if let Ok(parsed) = value.parse::() { + return Some(Value::Number(Number::from(parsed))); + } + Number::from_f64(value.parse::().ok()?).map(Value::Number) +} + +/// Convert one raw string value to a boolean. +fn convert_boolean(value: &str) -> Option { + match value.trim().to_ascii_lowercase().as_str() { + "true" | "1" => Some(Value::Bool(true)), + "false" | "0" => Some(Value::Bool(false)), + _ => None, + } +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::{ToolSchema, ToolSchemas}; + use crate::Tool; + + fn test_tool(name: &str, parameters: serde_json::Value) -> Tool { + Tool { + name: name.to_string(), + description: None, + parameters, + strict: None, + } + } + + #[test] + fn invalid_schema_converts_everything_as_string() { + let params = ToolSchema::from_schema(&json!({ "type": "object" })); + + assert_eq!(params.convert("count", "42"), json!("42")); + assert_eq!(params.convert("count", "null"), json!(null)); + } + + #[test] + fn skips_unknown_property_schema_and_unknown_type() { + let params = ToolSchema::from_schema(&json!({ + "type": "object", + "properties": { + "unknown_schema": true, + "unknown_type": { "type": "mystery" }, + "known": { "type": "integer" } + } + })); + + assert_eq!(params.convert("unknown_schema", "42"), json!("42")); + assert_eq!(params.convert("unknown_type", "42"), json!("42")); + assert_eq!(params.convert("known", "42"), json!(42)); + } + + #[test] + fn converts_supported_types() { + let params = ToolSchema::from_schema(&json!({ + "type": "object", + "properties": { + "text": { "type": "string" }, + "count": { "type": "integer" }, + "size": { "type": "number" }, + "enabled": { "type": "boolean" }, + "payload": { "type": "object" }, + "items": { "type": "array" }, + "nothing": { "type": "null" } + } + })); + + assert_eq!(params.convert("text", "42"), json!("42")); + assert_eq!(params.convert("count", "42"), json!(42)); + assert_eq!(params.convert("size", "5.0"), json!(5.0)); + assert_eq!(params.convert("enabled", "1"), json!(true)); + assert_eq!(params.convert("payload", r#"{"k":1}"#), json!({ "k": 1 })); + assert_eq!(params.convert("items", "[1,2]"), json!([1, 2])); + assert_eq!(params.convert("nothing", "null"), json!(null)); + } + + #[test] + fn number_conversion_parses_int_then_float() { + let params = ToolSchema::from_schema(&json!({ + "type": "object", + "properties": { + "value": { "type": "number" } + } + })); + + assert_eq!(params.convert("value", "5"), json!(5)); + assert_eq!(params.convert("value", "5.0"), json!(5.0)); + assert_eq!(params.convert("value", "5."), json!(5.0)); + assert_eq!(params.convert("value", "+1"), json!(1)); + assert_eq!(params.convert("value", "+1.0"), json!(1.0)); + assert_eq!( + params.convert("value", "9223372036854775807.5"), + json!(9223372036854775808.0) + ); + } + + #[test] + fn converts_upstream_aliases() { + let params = ToolSchema::from_schema(&json!({ + "type": "object", + "properties": { + "s": { "type": "varchar" }, + "i": { "type": "unsigned_int" }, + "n": { "type": "float64" }, + "b": { "type": "binary" }, + "a": { "type": "sequence" }, + "o": { "type": "dict" } + } + })); + + assert_eq!(params.convert("s", "x"), json!("x")); + assert_eq!(params.convert("i", "7"), json!(7)); + assert_eq!(params.convert("n", "7.5"), json!(7.5)); + assert_eq!(params.convert("b", "true"), json!(true)); + assert_eq!(params.convert("a", "[1]"), json!([1])); + assert_eq!(params.convert("o", r#"{"x":1}"#), json!({ "x": 1 })); + } + + #[test] + fn preserves_union_type_order() { + let integer_first = ToolSchema::from_schema(&json!({ + "type": "object", + "properties": { + "value": { "type": ["integer", "string"] } + } + })); + let string_first = ToolSchema::from_schema(&json!({ + "type": "object", + "properties": { + "value": { "type": ["string", "integer"] } + } + })); + + assert_eq!(integer_first.convert("value", "42"), json!(42)); + assert_eq!(string_first.convert("value", "42"), json!("42")); + } + + #[test] + fn converts_composite_schemas() { + let params = ToolSchema::from_schema(&json!({ + "type": "object", + "properties": { + "choice": { + "anyOf": [ + { "type": "integer" }, + { "type": "string" } + ] + }, + "fallback_object": { + "oneOf": [ + { "type": "mystery" } + ] + } + } + })); + + assert_eq!(params.convert("choice", "42"), json!(42)); + assert_eq!( + params.convert("fallback_object", r#"{"x":1}"#), + json!({ "x": 1 }) + ); + } + + #[test] + fn infers_type_from_schema_shape_without_type() { + let params = ToolSchema::from_schema(&json!({ + "type": "object", + "properties": { + "choice": { "enum": ["a", "b"] }, + "items": { "items": { "type": "integer" } }, + "payload": { "properties": { "x": { "type": "integer" } } } + } + })); + + assert_eq!(params.convert("choice", "a"), json!("a")); + assert_eq!(params.convert("items", "[1,2]"), json!([1, 2])); + assert_eq!(params.convert("payload", r#"{"x":1}"#), json!({ "x": 1 })); + } + + #[test] + fn converts_params_for_known_tool() { + let schemas = ToolSchemas::from_tools(&[test_tool( + "search", + json!({ + "type": "object", + "properties": { + "query": { "type": "string" }, + "topn": { "type": "integer" } + } + }), + )]); + + let converted = schemas.convert_params_with_schema( + "search", + vec![ + ("query".to_string(), "rust".to_string()), + ("topn".to_string(), "5".to_string()), + ], + ); + + assert_eq!(converted.get("query"), Some(&json!("rust"))); + assert_eq!(converted.get("topn"), Some(&json!(5))); + } + + #[test] + fn convert_params_falls_back_to_string_for_failed_coercion() { + let schemas = ToolSchemas::from_tools(&[test_tool( + "convert", + json!({ + "type": "object", + "properties": { + "whole": { "type": "number" }, + "flag": { "type": "boolean" }, + "payload": { "type": "object" }, + "items": { "type": "array" }, + "missing_type": {} + } + }), + )]); + + let converted = schemas.convert_params_with_schema( + "convert", + vec![ + ("whole".to_string(), "not-a-number".to_string()), + ("flag".to_string(), "maybe".to_string()), + ("payload".to_string(), "not-json".to_string()), + ("items".to_string(), "not-json".to_string()), + ("missing_type".to_string(), "42".to_string()), + ("unknown_param".to_string(), "42".to_string()), + ], + ); + + assert_eq!(converted.get("whole"), Some(&json!("not-a-number"))); + assert_eq!(converted.get("flag"), Some(&json!("maybe"))); + assert_eq!(converted.get("payload"), Some(&json!("not-json"))); + assert_eq!(converted.get("items"), Some(&json!("not-json"))); + assert_eq!(converted.get("missing_type"), Some(&json!("42"))); + assert_eq!(converted.get("unknown_param"), Some(&json!("42"))); + } + + #[test] + fn convert_params_preserves_null_for_known_param() { + let schemas = ToolSchemas::from_tools(&[test_tool( + "convert", + json!({ + "type": "object", + "properties": { + "value": { "type": "string" } + } + }), + )]); + + let converted = schemas + .convert_params_with_schema("convert", vec![("value".to_string(), "NULL".to_string())]); + + assert_eq!(converted.get("value"), Some(&json!(null))); + } + + #[test] + fn unknown_tool_converts_values_without_schema() { + let schemas = ToolSchemas::from_tools(&[test_tool( + "search", + json!({ "type": "object", "properties": {} }), + )]); + + let converted = schemas.convert_params_with_schema( + "missing", + vec![ + ("query".to_string(), "rust".to_string()), + ("topn".to_string(), "5".to_string()), + ("nullish".to_string(), "null".to_string()), + ], + ); + + assert_eq!(converted.get("query"), Some(&json!("rust"))); + assert_eq!(converted.get("topn"), Some(&json!("5"))); + assert_eq!(converted.get("nullish"), Some(&json!(null))); + } +} diff --git a/rust/src/tool-parser/src/qwen_coder.rs b/rust/src/tool-parser/src/qwen_coder.rs new file mode 100644 index 00000000000..227315f9bd2 --- /dev/null +++ b/rust/src/tool-parser/src/qwen_coder.rs @@ -0,0 +1,632 @@ +use winnow::ascii::multispace0 as ws0; +use winnow::combinator::{alt, delimited, eof, repeat, seq, terminated}; +use winnow::prelude::*; +use winnow::stream::Partial; +use winnow::token::{literal, take_until}; + +use super::parameters::ToolSchemas; +use super::utils::{parse_buffered_event, safe_text_len, xml_unescape}; +use super::{Result, ToolCallDelta, ToolParseResult, ToolParser}; +use crate::Tool; + +const TOOL_CALL_START: &str = ""; +const TOOL_CALL_END: &str = ""; +const FUNCTION_START: &str = " = Partial<&'i str>; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum QwenCoderMode { + Text, + ToolCall, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum QwenCoderEvent { + Text { + len: usize, + }, + ToolCallStart, + ToolCall { + name: String, + raw_params: Vec<(String, String)>, + }, +} + +/// Tool parser for Qwen Coder XML-style tool calls. +/// +/// Example tool call content: +/// +/// ```text +/// +/// +/// 杭州 +/// +/// +/// ``` +/// +/// Arguments are emitted only after a full `tool_call` block is parsed. +/// +/// Note: parallel calls are represented as repeated +/// `...` blocks, not as multiple calls inside one tag. +pub struct Qwen3CoderToolParser { + buffer: String, + mode: QwenCoderMode, + emitted_tool_count: usize, + tool_parameters: ToolSchemas, +} + +impl Qwen3CoderToolParser { + /// Create a Qwen Coder tool parser. + fn new(tools: &[Tool]) -> Self { + Self { + buffer: String::new(), + mode: QwenCoderMode::Text, + emitted_tool_count: 0, + tool_parameters: ToolSchemas::from_tools(tools), + } + } + + /// Apply one parsed Qwen Coder event to parser state and output. + fn apply_event(&mut self, event: QwenCoderEvent, result: &mut ToolParseResult) -> Result<()> { + match event { + QwenCoderEvent::Text { len: consumed_len } => { + result.normal_text.push_str(&self.buffer[..consumed_len]); + } + QwenCoderEvent::ToolCallStart => self.mode = QwenCoderMode::ToolCall, + QwenCoderEvent::ToolCall { name, raw_params } => { + self.mode = QwenCoderMode::Text; + let arguments = self.tool_parameters.convert_params_with_schema(&name, raw_params); + let arguments = serde_json::to_string(&arguments) + .map_err(|error| parsing_failed!("failed to serialize arguments: {}", error))?; + + result.calls.push(ToolCallDelta { + tool_index: self.emitted_tool_count, + name: Some(name), + arguments, + }); + self.emitted_tool_count += 1; + } + } + Ok(()) + } + + /// Reset all streaming state. + fn reset(&mut self) { + self.buffer.clear(); + self.mode = QwenCoderMode::Text; + self.emitted_tool_count = 0; + } +} + +impl ToolParser for Qwen3CoderToolParser { + /// Create a boxed Qwen Coder tool parser. + fn create(tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self::new(tools))) + } + + /// Push one decoded text chunk through the Qwen Coder parser. + fn push(&mut self, chunk: &str) -> Result { + self.buffer.push_str(chunk); + let mut result = ToolParseResult::default(); + + while let Some((event, consumed_len)) = parse_buffered_event(&self.buffer, |input| { + parse_next_qwen_coder_event(input, self.mode) + })? { + self.apply_event(event, &mut result)?; + self.buffer.drain(..consumed_len); + } + + Ok(result) + } + + /// Flush buffered text and reset parser state. + fn finish(&mut self) -> Result { + let mut result = ToolParseResult::default(); + if !self.buffer.is_empty() { + if self.mode == QwenCoderMode::ToolCall || self.buffer.starts_with(TOOL_CALL_START) { + return Err(parsing_failed!("incomplete Qwen Coder tool call")); + } + result.normal_text.push_str(&self.buffer); + } + self.reset(); + Ok(result) + } +} + +/// Parse a Qwen Coder event for the current parser mode. +fn parse_next_qwen_coder_event( + input: &mut QwenCoderInput<'_>, + mode: QwenCoderMode, +) -> ModalResult { + match mode { + QwenCoderMode::Text => parse_text_event(input), + QwenCoderMode::ToolCall => tool_call_event(input), + } +} + +/// Parse a text-mode Qwen Coder event. +fn parse_text_event(input: &mut QwenCoderInput<'_>) -> ModalResult { + alt((tool_call_start_event, safe_text_event)).parse_next(input) +} + +/// Parse a Qwen Coder tool-call start marker. +fn tool_call_start_event(input: &mut QwenCoderInput<'_>) -> ModalResult { + literal(TOOL_CALL_START).value(QwenCoderEvent::ToolCallStart).parse_next(input) +} + +/// Parse a safe text run before the next Qwen Coder marker. +fn safe_text_event(input: &mut QwenCoderInput<'_>) -> ModalResult { + safe_text_len(input, TOOL_CALL_START).map(|len| QwenCoderEvent::Text { len }) +} + +/// Parse a complete Qwen Coder tool call. +fn tool_call_event(input: &mut QwenCoderInput<'_>) -> ModalResult { + let (body,) = seq!( + _: ws0, + take_until(0.., TOOL_CALL_END), + _: literal(TOOL_CALL_END), + ) + .parse_next(input)?; + + parse_tool_call_body(body) +} + +/// Parse a Qwen Coder function block. +fn function_event(input: &mut &str) -> ModalResult { + let (name, raw_params) = seq!( + _: literal(FUNCTION_START), + take_until(1.., ">"), + _: ">", + _: ws0, + repeat(0.., terminated(parameter, ws0)), + _: literal(FUNCTION_END), + ) + .parse_next(input)?; + + Ok(QwenCoderEvent::ToolCall { + name: name.to_string(), + raw_params, + }) +} + +/// Parse a Qwen Coder parameter block. +fn parameter(input: &mut &str) -> ModalResult<(String, String)> { + let (name, value) = seq!( + _: literal(PARAMETER_START), + take_until(1.., ">"), + _: ">", + take_until(0.., PARAMETER_END).map(trim_one_wrapping_newline).map(xml_unescape), + _: literal(PARAMETER_END), + ) + .parse_next(input)?; + + Ok((name.to_string(), value.into_owned())) +} + +/// Parse a Qwen Coder tool-call body. +fn parse_tool_call_body(body: &str) -> ModalResult { + let mut input = body; + delimited(ws0, function_event, (ws0, eof)).parse_next(&mut input) +} + +/// Trim a single leading and trailing newline from a parameter value. +fn trim_one_wrapping_newline(value: &str) -> &str { + let value = value.strip_prefix('\n').unwrap_or(value); + value.strip_suffix('\n').unwrap_or(value) +} + +#[cfg(test)] +mod tests { + use expect_test::expect; + use serde_json::{Value, json}; + use thiserror_ext::AsReport; + + use super::{Qwen3CoderToolParser, ToolParser}; + use crate::test_utils::{collect_stream, split_by_chars, test_tools}; + + fn build_tool_call(function_name: &str, params: &[(&str, &str)]) -> String { + let params = params + .iter() + .map(|(name, value)| format!("{value}")) + .collect::>() + .join("\n"); + format!("\n\n{params}\n\n") + } + + #[test] + fn qwen_coder_parse_complete_without_tool_call_keeps_text() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = parser.parse_complete("Hello, world!").unwrap(); + + assert_eq!(result.normal_text, "Hello, world!"); + assert!(result.calls.is_empty()); + } + + #[test] + fn qwen_coder_parse_complete_extracts_single_tool_call() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = parser + .parse_complete(&build_tool_call( + "get_weather", + &[("location", "SF"), ("date", "2026-04-29")], + )) + .unwrap(); + + assert!(result.normal_text.is_empty()); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "location": "SF", + "date": "2026-04-29" + }) + ); + } + + #[test] + fn qwen_coder_parse_complete_preserves_prefix_text() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let output = format!( + "Thinking... {}", + build_tool_call("get_weather", &[("location", "NYC")]) + ); + let result = parser.parse_complete(&output).unwrap(); + + assert_eq!(result.normal_text, "Thinking... "); + assert_eq!(result.calls.len(), 1); + } + + #[test] + fn qwen_coder_parse_complete_converts_schema_types() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = parser + .parse_complete(&build_tool_call( + "convert", + &[ + ("whole", "5.0"), + ("flag", "true"), + ("payload", r#"{"nested":true}"#), + ("items", "[1,2]"), + ("empty", "42"), + ], + )) + .unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "whole": 5.0, + "flag": true, + "payload": { "nested": true }, + "items": [1, 2], + "empty": "42", + }) + ); + } + + #[test] + fn qwen_coder_parse_complete_extracts_empty_arguments() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = parser.parse_complete(&build_tool_call("get_weather", &[])).unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({}) + ); + } + + #[test] + fn qwen_coder_parse_complete_handles_upstream_multiline_typed_params() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = parser + .parse_complete( + "\n\ + \n\ + \n\ + rectangle\n\ + \n\ + \n\ + {\"width\": 10,\n\ + \"height\": 20}\n\ + \n\ + \n\ + 2\n\ + \n\ + \n\ + ", + ) + .unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("calculate_area")); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "shape": "rectangle", + "dimensions": { "width": 10, "height": 20 }, + "precision": 2, + }) + ); + } + + #[test] + fn qwen_coder_parse_complete_handles_nested_json_parameter() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = parser + .parse_complete(&build_tool_call( + "convert", + &[( + "payload", + r#"{"nested":{"value":[1,2,3],"child":{"enabled":true}}}"#, + )], + )) + .unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "payload": { + "nested": { + "value": [1, 2, 3], + "child": { "enabled": true }, + }, + }, + }) + ); + } + + #[test] + fn qwen_coder_parse_complete_preserves_xml_like_parameter_values() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = parser + .parse_complete(&build_tool_call( + "process", + &[ + ( + "html_content", + r#"
Hello
"#, + ), + ("xml_snippet", r#""#), + ], + )) + .unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "html_content": r#"
Hello
"#, + "xml_snippet": r#""#, + }) + ); + } + + #[test] + fn qwen_coder_parse_complete_unescapes_literal_closing_tags_in_parameter_value() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = parser + .parse_complete(&build_tool_call( + "get_weather", + &[ + ( + "location", + "杭州 </parameter></function></tool_call>", + ), + ("date", "2026-05-08"), + ], + )) + .unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "location": "杭州
", + "date": "2026-05-08", + }) + ); + } + + #[test] + fn qwen_coder_parse_complete_does_not_double_encode_anyof_object() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = parser + .parse_complete(&build_tool_call( + "update_record", + &[("data", r#"{"key":"value","count":42}"#)], + )) + .unwrap(); + + assert_eq!(result.calls.len(), 1); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ + "data": { "key": "value", "count": 42 }, + }) + ); + } + + #[test] + fn qwen_coder_streaming_extracts_single_tool_call() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = collect_stream( + &mut parser, + &[ + "\n", + "\n", + "SF\n", + "\n", + "", + ], + ); + + assert!(result.normal_text.is_empty()); + assert_eq!(result.calls.len(), 1); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ "location": "SF" }) + ); + } + + #[test] + fn qwen_coder_streaming_preserves_prefix_text() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = collect_stream( + &mut parser, + &[ + "Thinking... ", + "\n", + "\n", + "SF\n", + "\n", + "", + ], + ); + + assert_eq!(result.normal_text, "Thinking... "); + assert_eq!(result.calls.len(), 1); + } + + #[test] + fn qwen_coder_streaming_without_tool_call_emits_text_incrementally() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = collect_stream(&mut parser, &["Hello, ", "world!"]); + + assert_eq!(result.normal_text, "Hello, world!"); + assert!(result.calls.is_empty()); + } + + #[test] + fn qwen_coder_streaming_extracts_multiple_tool_calls_in_order() { + let text = format!( + "{}\n{}", + build_tool_call("get_weather", &[("location", "SF")]), + build_tool_call("get_weather", &[("location", "NYC")]) + ); + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = collect_stream(&mut parser, &[&text]); + + assert_eq!(result.calls.len(), 2); + assert_eq!(result.calls[0].name.as_deref(), Some("get_weather")); + assert_eq!(result.calls[1].name.as_deref(), Some("get_weather")); + assert_eq!(result.calls[0].tool_index, 0); + assert_eq!(result.calls[1].tool_index, 1); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ "location": "SF" }) + ); + assert_eq!( + serde_json::from_str::(&result.calls[1].arguments).unwrap(), + json!({ "location": "NYC" }) + ); + } + + #[test] + fn qwen_coder_streaming_preserves_text_between_tool_calls() { + let text = format!( + "I'll check two cities.{}Between calls.{}Done.", + build_tool_call("get_weather", &[("city", "Dallas"), ("state", "TX")]), + build_tool_call("get_weather", &[("city", "Orlando"), ("state", "FL")]) + ); + let chunks = split_by_chars(&text, 5); + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = collect_stream(&mut parser, &chunks); + + assert_eq!( + result.normal_text, + "I'll check two cities.Between calls.Done." + ); + assert_eq!(result.calls.len(), 2); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ "city": "Dallas", "state": "TX" }) + ); + assert_eq!( + serde_json::from_str::(&result.calls[1].arguments).unwrap(), + json!({ "city": "Orlando", "state": "FL" }) + ); + } + + #[test] + fn qwen_coder_streaming_handles_start_token_split_across_chunks() { + let text = build_tool_call("get_weather", &[("location", "SF")]); + let chunks = split_by_chars(&text, 3); + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = collect_stream(&mut parser, &chunks); + + assert_eq!(result.calls.len(), 1); + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ "location": "SF" }) + ); + } + + #[test] + fn qwen_coder_streaming_does_not_emit_incomplete_tool_call() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = parser + .push("\n\nSF") + .unwrap(); + + assert!(result.normal_text.is_empty()); + assert!(result.calls.is_empty()); + } + + #[test] + fn qwen_coder_finish_fails_incomplete_tool_call() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + parser + .push("\n\nSF") + .unwrap(); + + assert!(parser.finish().is_err()); + } + + #[test] + fn qwen_coder_malformed_tool_call_fails_fast() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let error = parser.push("\n\n").unwrap_err(); + + expect!["tool parser parsing failed: "].assert_eq(&error.to_report_string()); + } + + #[test] + fn qwen_coder_missing_parameter_end_fails_fast_after_function_end() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let error = parser + .push( + "\n\nSF\n", + ) + .unwrap_err(); + + expect!["tool parser parsing failed: "].assert_eq(&error.to_report_string()); + } + + #[test] + fn qwen_coder_parse_function_body_trims_one_wrapping_newline() { + let mut parser = Qwen3CoderToolParser::new(&test_tools()); + let result = parser + .parse_complete( + "\n\n\nHangzhou\n\n\n", + ) + .unwrap(); + + assert_eq!( + serde_json::from_str::(&result.calls[0].arguments).unwrap(), + json!({ "location": "Hangzhou" }) + ); + } +} diff --git a/rust/src/tool-parser/src/test_utils.rs b/rust/src/tool-parser/src/test_utils.rs new file mode 100644 index 00000000000..45dda6ddfdf --- /dev/null +++ b/rust/src/tool-parser/src/test_utils.rs @@ -0,0 +1,115 @@ +use serde_json::json; + +use super::{ToolParseResult, ToolParser}; +use crate::Tool; + +/// Build a reusable set of function tools for parser unit tests. +pub fn test_tools() -> Vec { + vec![ + Tool { + name: "get_weather".to_string(), + description: None, + parameters: json!({ + "type": "object", + "properties": { + "location": { "type": "string" }, + "city": { "type": "string" }, + "state": { "type": "string" }, + "unit": { "type": "string" }, + "date": { "type": "string" }, + "days": { "type": "integer" } + } + }), + strict: None, + }, + Tool { + name: "add".to_string(), + description: None, + parameters: json!({ + "type": "object", + "properties": { + "x": { "type": "integer" }, + "y": { "type": "integer" } + } + }), + strict: None, + }, + Tool { + name: "convert".to_string(), + description: None, + parameters: json!({ + "type": "object", + "properties": { + "whole": { "type": "number" }, + "flag": { "type": "boolean" }, + "payload": { "type": "object" }, + "items": { "type": "array" }, + "empty": { "type": "string" } + } + }), + strict: None, + }, + Tool { + name: "calculate_area".to_string(), + description: None, + parameters: json!({ + "type": "object", + "properties": { + "shape": { "type": "string" }, + "dimensions": { "type": "object" }, + "precision": { "type": "integer" } + } + }), + strict: None, + }, + Tool { + name: "update_record".to_string(), + description: None, + parameters: json!({ + "type": "object", + "properties": { + "data": { + "anyOf": [ + { "type": "object" }, + { "type": "null" } + ] + } + } + }), + strict: None, + }, + ] +} + +/// Push chunks through a streaming parser and coalesce its tool-call deltas. +pub fn collect_stream(parser: &mut T, chunks: &[&str]) -> ToolParseResult { + let mut result = ToolParseResult::default(); + for chunk in chunks { + result.append(parser.push(chunk).unwrap()); + } + result.append(parser.finish().unwrap()); + result.coalesce_calls() +} + +/// Split text into chunks containing at most `chunk_chars` Unicode scalar +/// values. +pub fn split_by_chars(text: &str, chunk_chars: usize) -> Vec<&str> { + let mut chunks = Vec::new(); + let mut start = 0; + let mut count = 0; + + for (index, _) in text.char_indices() { + if count == chunk_chars { + chunks.push(&text[start..index]); + start = index; + count = 0; + } + count += 1; + } + + if start < text.len() { + chunks.push(&text[start..]); + } + + chunks +} diff --git a/rust/src/tool-parser/src/tests.rs b/rust/src/tool-parser/src/tests.rs new file mode 100644 index 00000000000..73e3b7bbf35 --- /dev/null +++ b/rust/src/tool-parser/src/tests.rs @@ -0,0 +1,97 @@ +use super::{Result, Tool, ToolCallDelta, ToolParseResult, ToolParser}; + +struct DefaultParser; + +impl ToolParser for DefaultParser { + fn create(_tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self)) + } + + fn push(&mut self, _chunk: &str) -> Result { + Ok(ToolParseResult::default()) + } +} + +#[test] +fn tool_parser_does_not_preserve_special_tokens_by_default() { + let parser = DefaultParser; + + assert!(!parser.preserve_special_tokens()); +} + +#[test] +fn default_parse_complete_delegates_through_push_and_finish() { + struct StreamingParser; + + impl ToolParser for StreamingParser { + fn create(_tools: &[Tool]) -> Result> + where + Self: Sized + 'static, + { + Ok(Box::new(Self)) + } + + fn push(&mut self, _chunk: &str) -> Result { + Ok(ToolParseResult { + normal_text: "prefix ".to_string(), + calls: vec![ + ToolCallDelta { + tool_index: 0, + name: Some("weather".to_string()), + arguments: "{\"location\":".to_string(), + }, + ToolCallDelta { + tool_index: 0, + name: None, + arguments: "\"Paris\"".to_string(), + }, + ToolCallDelta { + tool_index: 1, + name: Some("time".to_string()), + arguments: "{\"timezone\":".to_string(), + }, + ], + }) + } + + fn finish(&mut self) -> Result { + Ok(ToolParseResult { + normal_text: "suffix".to_string(), + calls: vec![ + ToolCallDelta { + tool_index: 0, + name: None, + arguments: "}".to_string(), + }, + ToolCallDelta { + tool_index: 1, + name: None, + arguments: "\"UTC\"}".to_string(), + }, + ], + }) + } + } + + let mut parser = StreamingParser; + let result = parser.parse_complete("ignored").unwrap(); + assert_eq!(result.normal_text, "prefix suffix"); + assert_eq!( + result.calls, + vec![ + ToolCallDelta { + tool_index: 0, + name: Some("weather".to_string()), + arguments: "{\"location\":\"Paris\"}".to_string(), + }, + ToolCallDelta { + tool_index: 1, + name: Some("time".to_string()), + arguments: "{\"timezone\":\"UTC\"}".to_string(), + }, + ] + ); +} diff --git a/rust/src/tool-parser/src/utils.rs b/rust/src/tool-parser/src/utils.rs new file mode 100644 index 00000000000..171c1af0eec --- /dev/null +++ b/rust/src/tool-parser/src/utils.rs @@ -0,0 +1,581 @@ +//! Shared helpers for tool parsers. + +use std::borrow::Cow; + +use winnow::error::{ContextError, ErrMode, ModalResult, Needed, StrContext, StrContextValue}; +use winnow::stream::{Offset, Partial, Stream}; + +use super::Result; + +/// Return the byte length of the longest proper prefix of `token` that is also +/// a suffix of `buffer`. +/// +/// Streaming parsers use this to keep only the trailing fragment that might +/// still grow into a full marker after the next decoded chunk arrives. +/// +/// The returned length is always a valid UTF-8 boundary in `token`, so callers +/// can safely slice `&token[..len]` even when markers contain non-ASCII +/// characters such as DeepSeek's DSML delimiters. +pub(super) fn partial_prefix_len(buffer: &str, token: &str) -> usize { + let Some(first_byte) = token.as_bytes().first().copied() else { + return 0; + }; + + let max_len = buffer.len().min(token.len().saturating_sub(1)); + let tail_start = buffer.len() - max_len; + let buffer_bytes = buffer.as_bytes(); + let token_bytes = token.as_bytes(); + + // Scan from the longest possible suffix to preserve overlapping prefixes. + for index in tail_start..buffer.len() { + if buffer_bytes[index] != first_byte { + continue; + } + + let len = buffer.len() - index; + if buffer.is_char_boundary(index) + && token.is_char_boundary(len) + && token_bytes[..len] == buffer_bytes[index..] + { + return len; + } + } + + 0 +} + +/// Parse a safe text run before the next marker. +/// +/// Returns the text length in bytes, and advances the input. +pub(super) fn safe_text_len(input: &mut Partial<&str>, marker: &str) -> ModalResult { + let text = **input; + if text.is_empty() { + return incomplete(); + } + + if let Some(start_idx) = text.find(marker) { + input.next_slice(start_idx); + return Ok(start_idx); + } + + let keep_len = partial_prefix_len(text, marker); + let emit_len = text.len().saturating_sub(keep_len); + if emit_len == 0 { + return incomplete(); + } + + input.next_slice(emit_len); + Ok(emit_len) +} + +/// Decode XML/HTML entities in XML-style parameter values. +pub(super) fn xml_unescape(value: &str) -> Cow<'_, str> { + if !value.as_bytes().contains(&b'&') { + return Cow::Borrowed(value); + } + + let mut output: Option = None; + let mut copied_len = 0; + let mut rest = value; + + while let Some(ampersand) = rest.find('&') { + let before_ampersand = &rest[..ampersand]; + let after_ampersand = &rest[ampersand + '&'.len_utf8()..]; + if let Some(semicolon) = after_ampersand.find(';') { + let entity = &after_ampersand[..semicolon]; + if let Some(decoded) = decode_xml_entity(entity) { + match &mut output { + Some(output) => output.push_str(before_ampersand), + None => { + let mut new_output = String::with_capacity(value.len()); + new_output.push_str(&value[..copied_len + ampersand]); + output = Some(new_output); + } + } + let output = output.as_mut().expect("output is initialized above"); + output.push(decoded); + let consumed_len = ampersand + '&'.len_utf8() + semicolon + ';'.len_utf8(); + copied_len += consumed_len; + rest = &rest[consumed_len..]; + continue; + } + } + + if let Some(output) = &mut output { + output.push_str(before_ampersand); + output.push('&'); + } + let consumed_len = ampersand + '&'.len_utf8(); + copied_len += consumed_len; + rest = after_ampersand; + } + + if let Some(mut output) = output { + output.push_str(rest); + Cow::Owned(output) + } else { + Cow::Borrowed(value) + } +} + +fn decode_xml_entity(entity: &str) -> Option { + match entity { + "amp" => Some('&'), + "lt" => Some('<'), + "gt" => Some('>'), + "quot" => Some('"'), + "apos" => Some('\''), + entity if entity.starts_with("#x") || entity.starts_with("#X") => { + u32::from_str_radix(&entity[2..], 16).ok().and_then(char::from_u32) + } + entity if entity.starts_with('#') => { + entity[1..].parse::().ok().and_then(char::from_u32) + } + _ => None, + } +} + +/// Streaming lexical state for a top-level JSON object. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(super) struct JsonObjectScanState { + object_depth: usize, + array_depth: usize, + in_string: bool, + escape: bool, + phase: JsonObjectScanPhase, +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +enum JsonObjectScanPhase { + #[default] + Initial, + Scanning, + Complete, +} + +impl JsonObjectScanState { + /// Returns whether the top-level JSON object has closed. + pub(super) const fn complete(&self) -> bool { + matches!(self.phase, JsonObjectScanPhase::Complete) + } +} + +/// Parse a raw top-level JSON object argument prefix. +/// +/// The returned length is safe to emit as raw argument text. This scans only +/// lexical boundaries from `{` through the matching `}`, preserving +/// malformed-but-balanced JSON without deserializing or normalizing it. +pub(super) fn take_json_object( + input: &mut Partial<&str>, + state: &mut JsonObjectScanState, +) -> ModalResult { + let text = **input; + if text.is_empty() { + return incomplete(); + } + if state.complete() { + return Err(json_scan_error( + "JSON object argument", + StrContextValue::Description("active JSON object scan"), + )); + } + + let bytes = text.as_bytes(); + let just_started = matches!(state.phase, JsonObjectScanPhase::Initial); + if just_started { + if bytes[0] != b'{' { + return Err(json_scan_error( + "JSON object argument", + StrContextValue::CharLiteral('{'), + )); + } + state.phase = JsonObjectScanPhase::Scanning; + state.object_depth = 1; + } + + let mut index = usize::from(just_started); + + while index < bytes.len() { + let byte = bytes[index]; + index += 1; + + if state.in_string { + if state.escape { + state.escape = false; + } else if byte == b'\\' { + state.escape = true; + } else if byte == b'"' { + state.in_string = false; + } + continue; + } + + match byte { + b'"' => state.in_string = true, + b'{' => state.object_depth += 1, + b'}' => { + state.object_depth = state.object_depth.checked_sub(1).ok_or_else(|| { + json_scan_error( + "JSON object argument", + StrContextValue::Description("balanced object braces"), + ) + })?; + if state.object_depth == 0 && state.array_depth == 0 { + state.phase = JsonObjectScanPhase::Complete; + input.next_slice(index); + return Ok(index); + } + if state.object_depth == 0 { + return Err(json_scan_error( + "JSON object argument", + StrContextValue::Description( + "nested arrays to close before the top-level object", + ), + )); + } + } + b'[' => state.array_depth += 1, + b']' => { + state.array_depth = state.array_depth.checked_sub(1).ok_or_else(|| { + json_scan_error( + "JSON object argument", + StrContextValue::Description("balanced array brackets"), + ) + })?; + } + _ => {} + } + } + + input.next_slice(text.len()); + Ok(text.len()) +} + +/// Parse a JSON string literal. +pub(super) fn json_str(input: &mut Partial<&str>) -> ModalResult { + let text = **input; + if text.is_empty() { + return incomplete(); + } + + let bytes = text.as_bytes(); + if bytes[0] != b'"' { + return Err(json_scan_error( + "JSON string", + StrContextValue::CharLiteral('"'), + )); + } + + let mut escape = false; + let mut index = 1; + while index < bytes.len() { + let byte = bytes[index]; + index += 1; + + if escape { + escape = false; + continue; + } + + match byte { + b'\\' => escape = true, + b'"' => { + let raw = &text[..index]; + let value = serde_json::from_str::(raw).map_err(|_| { + json_scan_error( + "JSON string", + StrContextValue::Description("valid JSON string"), + ) + })?; + input.next_slice(index); + return Ok(value); + } + _ => {} + } + } + + incomplete() +} + +fn json_scan_error(label: &'static str, expected: StrContextValue) -> ErrMode { + let mut error = ContextError::new(); + error.push(StrContext::Label(label)); + error.push(StrContext::Expected(expected)); + ErrMode::Cut(error) +} + +/// Parse one event from a buffered streaming input. +/// +/// Returns: +/// - `Ok(Some((event, consumed_len)))` if an event was successfully parsed, along with the number +/// of bytes consumed from the buffer. +/// - `Ok(None)` if the buffer does not contain a full event yet, and more data is needed. +/// - `Err` if a parsing error occurred. +pub(super) fn parse_buffered_event( + buffer: &str, + parse: impl FnOnce(&mut Partial<&str>) -> ModalResult, +) -> Result> { + let mut input = Partial::new(buffer); + let checkpoint = input.checkpoint(); + let event = match parse(&mut input) { + Ok(event) => event, + Err(ErrMode::Incomplete(_)) => return Ok(None), + Err(ErrMode::Backtrack(e) | ErrMode::Cut(e)) => { + // TODO: enrich context for error reporting + return Err(parsing_failed!("{}", e)); + } + }; + let consumed_len = input.offset_from(&checkpoint); + if consumed_len == 0 { + return Ok(None); + } + + Ok(Some((event, consumed_len))) +} + +/// Returns an error indicating that we need more data to continue parsing. +pub(super) fn incomplete() -> ModalResult { + Err(ErrMode::Incomplete(Needed::Unknown)) +} + +#[cfg(test)] +mod tests { + use std::borrow::Cow; + + use expect_test::expect; + use winnow::error::ErrMode; + use winnow::stream::{Offset, Partial, Stream}; + + use super::{ + JsonObjectScanState, json_str, partial_prefix_len, safe_text_len, take_json_object, + xml_unescape, + }; + + #[test] + fn partial_prefix_len_handles_ascii_markers() { + assert_eq!( + partial_prefix_len("hello<|tool", "<|tool_call>"), + "<|tool".len() + ); + assert_eq!(partial_prefix_len("hello world", "<|tool_call>"), 0); + } + + #[test] + fn partial_prefix_len_prefers_longest_overlapping_prefix() { + assert_eq!(partial_prefix_len("chunk ending in aba", "ababa"), 3); + } + + #[test] + fn partial_prefix_len_handles_unicode_markers() { + let token = "<|DSML|function_calls>"; + assert_eq!( + partial_prefix_len("prefix <|DSML|fun", token), + "<|DSML|fun".len() + ); + assert_eq!(partial_prefix_len("prefix <|DSML", token), "<|DSML".len()); + } + + #[test] + fn safe_text_len_stops_before_marker() { + let mut input = Partial::new("hello"); + let checkpoint = input.checkpoint(); + + let len = safe_text_len(&mut input, "").unwrap(); + + assert_eq!(len, "hello".len()); + assert_eq!(input.offset_from(&checkpoint), "hello".len()); + } + + #[test] + fn safe_text_len_holds_back_partial_marker() { + let mut input = Partial::new("hello").unwrap(); + + assert_eq!(len, "hello".len()); + assert_eq!(input.offset_from(&checkpoint), "hello".len()); + } + + #[test] + fn safe_text_len_reports_incomplete_for_only_partial_marker() { + let mut input = Partial::new("").unwrap_err(); + + assert!(matches!(error, ErrMode::Incomplete(_))); + } + + #[test] + fn xml_unescape_decodes_common_entities() { + assert_eq!( + xml_unescape("<tag attr="value">Tom & Jerry's</tag>"), + r#"Tom & Jerry's"# + ); + } + + #[test] + fn xml_unescape_decodes_numeric_entities() { + assert_eq!(xml_unescape("<tag>😀"), "😀"); + } + + #[test] + fn xml_unescape_preserves_unknown_and_incomplete_entities() { + let output = xml_unescape("Tom & Jerry &unknown; &"); + + assert!(matches!(output, Cow::Borrowed(_))); + assert_eq!(output, "Tom & Jerry &unknown; &"); + } + + #[test] + fn xml_unescape_borrows_when_no_entity_is_present() { + let input = "plain text"; + let output = xml_unescape(input); + + assert!(matches!(output, Cow::Borrowed(_))); + assert_eq!(output, input); + } + + #[test] + fn take_json_object_consumes_simple_object() { + let mut state = JsonObjectScanState::default(); + let buffer = r#"{"location":"Paris"}"#; + let mut input = Partial::new(buffer); + let checkpoint = input.checkpoint(); + + let len = take_json_object(&mut input, &mut state).unwrap(); + + assert_eq!(len, r#"{"location":"Paris"}"#.len()); + assert_eq!(input.offset_from(&checkpoint), len); + assert!(state.complete()); + } + + #[test] + fn take_json_object_tracks_nested_values_and_strings() { + let mut state = JsonObjectScanState::default(); + let arguments = r#"{"nested":{"items":[{"text":"} <|tool_call_end|> \" \\"}]}}"#; + let buffer = format!("{arguments}"); + let mut input = Partial::new(buffer.as_str()); + + let len = take_json_object(&mut input, &mut state).unwrap(); + + assert_eq!(len, arguments.len()); + assert!(state.complete()); + } + + #[test] + fn take_json_object_rejects_leading_whitespace() { + let mut state = JsonObjectScanState::default(); + let mut input = Partial::new(" {\"x\":1}"); + + let error = take_json_object(&mut input, &mut state).unwrap_err(); + + let ErrMode::Cut(error) = error else { + panic!("expected cut error"); + }; + expect![[r#" + invalid JSON object argument + expected `{`"#]] + .assert_eq(&error.to_string()); + } + + #[test] + fn take_json_object_leaves_trailing_whitespace_to_caller() { + let mut state = JsonObjectScanState::default(); + let mut input = Partial::new("{\"x\":1}\n"); + let checkpoint = input.checkpoint(); + + let len = take_json_object(&mut input, &mut state).unwrap(); + + assert_eq!(len, "{\"x\":1}".len()); + assert_eq!(input.offset_from(&checkpoint), len); + assert!(state.complete()); + } + + #[test] + fn take_json_object_continues_across_chunks() { + let mut state = JsonObjectScanState::default(); + let chunks = [ + r#"{"text":"literal "#, + r#"<|tool_call_end|>"#, + r#" inside"}"#, + ]; + let mut collected = String::new(); + + for chunk in chunks { + let mut input = Partial::new(chunk); + let len = take_json_object(&mut input, &mut state).unwrap(); + collected.push_str(&chunk[..len]); + } + + assert_eq!(collected, r#"{"text":"literal <|tool_call_end|> inside"}"#); + assert!(state.complete()); + } + + #[test] + fn take_json_object_rejects_non_object_top_level() { + let mut state = JsonObjectScanState::default(); + let mut input = Partial::new(r#"[{"x":1}]"#); + + let error = take_json_object(&mut input, &mut state).unwrap_err(); + + let ErrMode::Cut(error) = error else { + panic!("expected cut error"); + }; + expect![[r#" + invalid JSON object argument + expected `{`"#]] + .assert_eq(&error.to_string()); + } + + #[test] + fn take_json_object_reports_unbalanced_array() { + let mut state = JsonObjectScanState::default(); + let mut input = Partial::new(r#"{"x":]}"#); + + let error = take_json_object(&mut input, &mut state).unwrap_err(); + + let ErrMode::Cut(error) = error else { + panic!("expected cut error"); + }; + expect![[r#" + invalid JSON object argument + expected balanced array brackets"#]] + .assert_eq(&error.to_string()); + } + + #[test] + fn take_json_object_reports_top_level_close_before_nested_array() { + let mut state = JsonObjectScanState::default(); + let mut input = Partial::new(r#"{"x":[}"#); + + let error = take_json_object(&mut input, &mut state).unwrap_err(); + + let ErrMode::Cut(error) = error else { + panic!("expected cut error"); + }; + expect![[r#" + invalid JSON object argument + expected nested arrays to close before the top-level object"#]] + .assert_eq(&error.to_string()); + } + + #[test] + fn json_str_decodes_escaped_content() { + let mut input = Partial::new(r#""say_\"hi\u0021" rest"#); + + let value = json_str(&mut input).unwrap(); + + assert_eq!(value, "say_\"hi!"); + assert_eq!(*input, " rest"); + } + + #[test] + fn json_str_reports_incomplete_escaped_string() { + let mut input = Partial::new(r#""say_\"#); + + let error = json_str(&mut input).unwrap_err(); + + assert!(matches!(error, ErrMode::Incomplete(_))); + } +} diff --git a/tools/pre_commit/rust-check.sh b/tools/pre_commit/rust-check.sh new file mode 100755 index 00000000000..2bff1bfacd2 --- /dev/null +++ b/tools/pre_commit/rust-check.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash +# Wrapper for the rust-* pre-commit hooks. +# +# Skips (with a warning) when `cargo` or the requested cargo subcommand is +# not installed, so contributors who don't touch the Rust code aren't forced +# to install the Rust toolchain (or niche cargo extensions like cargo-sort +# / cargo-autoinherit). Buildkite CI covers the rust hooks regardless. +# +# Usage: tools/pre_commit/rust-check.sh [extra cargo args...] + +set -euo pipefail + +# Pre-commit captures stdout/stderr and only replays on failure. Try to write +# to /dev/tty so the warning is visible during a normal `git commit` even +# though we exit 0; fall back to stderr where there's no controlling tty +# (e.g. CI). +# +# The leading newline pushes the warning off pre-commit's dot-leader line so +# the message doesn't mash into "Rust - ... ........WARNING:". The hook's +# "Passed" still lands on its own line just below the warning. +warn() { + { printf '\n%s\n' "$*" >/dev/tty; } 2>/dev/null || printf '\n%s\n' "$*" >&2 +} + +subcommand="$1" +shift + +if ! command -v cargo >/dev/null 2>&1; then + warn "WARNING: 'cargo' not found in PATH; skipping rust pre-commit hook (cargo ${subcommand}). + Install the Rust toolchain via https://rustup.rs/ if you need to run rust hooks locally." + exit 0 +fi + +# Cargo subcommands resolve to a `cargo-` binary on PATH. Check up-front +# so a missing helper produces a friendly skip instead of a cargo error. +if ! command -v "cargo-${subcommand}" >/dev/null 2>&1; then + case "${subcommand}" in + fmt) install_hint="rustup component add rustfmt" ;; + *) install_hint="cargo install cargo-${subcommand}" ;; + esac + warn "WARNING: 'cargo ${subcommand}' is not installed; skipping rust pre-commit hook. + Install it with: ${install_hint}" + exit 0 +fi + +cd "$(git rev-parse --show-toplevel)/rust" +exec cargo "${subcommand}" "$@"