From 93ae8a14ab01e98b35dd8e1c7263c17707f4eefe Mon Sep 17 00:00:00 2001 From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Date: Tue, 27 Jan 2026 15:40:13 +0800 Subject: [PATCH] [#10889][fix] fix pydantic deepcopy bug (#11004) Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> --- .../clients/tool_call_prompts.json | 24 ++++++++++++++ tensorrt_llm/serve/openai_client.py | 2 +- tensorrt_llm/serve/openai_disagg_service.py | 14 ++++---- .../defs/disaggregated/test_disaggregated.py | 32 ++++++++++++++++++- .../integration/test_lists/test-db/l0_a10.yml | 1 + 5 files changed, 65 insertions(+), 8 deletions(-) create mode 100644 examples/disaggregated/clients/tool_call_prompts.json diff --git a/examples/disaggregated/clients/tool_call_prompts.json b/examples/disaggregated/clients/tool_call_prompts.json new file mode 100644 index 0000000000..d1690d5f47 --- /dev/null +++ b/examples/disaggregated/clients/tool_call_prompts.json @@ -0,0 +1,24 @@ +[ + { + "messages": [ + { + "role": "user", + "content": "What is the weather in San Francisco?" + }, + { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\":\"San Francisco\",\"unit\":\"fahrenheit\"}" + } + } + ] + } + ] + } +] diff --git a/tensorrt_llm/serve/openai_client.py b/tensorrt_llm/serve/openai_client.py index 951fba5a7d..25189fa73a 100644 --- a/tensorrt_llm/serve/openai_client.py +++ b/tensorrt_llm/serve/openai_client.py @@ -155,7 +155,7 @@ class OpenAIHttpClient(OpenAIClient): request: UCompletionRequest, hooks: Optional[ResponseHooks] = None, ) -> AsyncGenerator[Any, None]: - json_data = request.model_dump(exclude_unset=True) + json_data = request.model_dump(exclude_unset=True, mode="json") is_stream = request.stream for attempt in range(self._max_retries + 1): try: diff --git a/tensorrt_llm/serve/openai_disagg_service.py b/tensorrt_llm/serve/openai_disagg_service.py index c1fb8f2af5..0f43a6147d 100644 --- a/tensorrt_llm/serve/openai_disagg_service.py +++ b/tensorrt_llm/serve/openai_disagg_service.py @@ -13,7 +13,6 @@ # limitations under the License. import asyncio -import copy import os from typing import Any, Callable, Dict, Optional @@ -145,12 +144,15 @@ class OpenAIDisaggregatedService(OpenAIService): def _get_ctx_request( self, request: UCompletionRequest, disagg_request_id: Optional[int] ) -> UCompletionRequest: - ctx_request = copy.deepcopy(request) - ctx_request.disaggregated_params = DisaggregatedParams( - request_type="context_only", disagg_request_id=disagg_request_id + ctx_request = request.model_copy( + update={ + "disaggregated_params": DisaggregatedParams( + request_type="context_only", disagg_request_id=disagg_request_id + ), + "stream": False, + "stream_options": None, + } ) - ctx_request.stream = False - ctx_request.stream_options = None return ctx_request def _get_gen_request( diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index 023681c0ff..0eca927293 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -95,6 +95,7 @@ def get_test_config(test_desc, example_dir, test_root): (2, f"{test_configs_root}/disagg_config_cuda_graph_padding.yaml"), "mixed": (2, f"{test_configs_root}/disagg_config_mixed.yaml"), "overlap": (2, f"{test_configs_root}/disagg_config_overlap.yaml"), + "tool_calls": (2, f"{test_configs_root}/disagg_config_overlap.yaml"), "perf_metrics": (2, f"{test_configs_root}/disagg_config_metrics.yaml"), "trtllm_sampler": (2, f"{test_configs_root}/disagg_config_trtllm_sampler.yaml"), @@ -271,6 +272,8 @@ def run_client_tests(example_dir, if prompt_file == "long_prompts.json": # Use max_tokens 4 for long prompts to reduce test time client_cmd.extend(['--max-tokens', '4']) + if test_desc == "tool_calls": + client_cmd.extend(['-e', 'chat', '-o', 'output_tool_calls.json']) # Prepare poll processes worker_processes = [] @@ -283,6 +286,10 @@ def run_client_tests(example_dir, poll_procs = worker_processes + [server_proc] check_call(client_cmd, env=env, poll_procs=poll_procs) + # tool calls test does not need to run streaming and completion + if test_desc == "tool_calls": + continue + # Streaming client run streaming_client_cmd = client_cmd + [ '--streaming', '-o', 'output_streaming.json' @@ -304,7 +311,7 @@ def run_client_tests(example_dir, poll_procs=poll_procs) # Skip output verification for long prompts test - if prompt_file == "long_prompts.json": + if prompt_file == "long_prompts.json" or prompt_file == "tool_call_prompts.json": continue if extra_endpoints_test is not None: @@ -786,6 +793,29 @@ def test_disaggregated_perf_metrics(disaggregated_test_root, llm_venv, extra_endpoints_test=extra_endpoints_test) +@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'], + indirect=True) +def test_disaggregated_chat_completion_tool_calls(disaggregated_test_root, + llm_venv, + disaggregated_example_root, + llama_model_root): + src_dst_dict = { + llama_model_root: + f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0", + } + for src, dst in src_dst_dict.items(): + if not os.path.islink(dst): + os.makedirs(os.path.dirname(dst), exist_ok=True) + os.symlink(src, dst, target_is_directory=True) + + run_disaggregated_test(disaggregated_example_root, + "tool_calls", + num_iters=1, + prompt_file="tool_call_prompts.json", + env=llm_venv._new_env, + cwd=llm_venv.get_working_directory()) + + @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'], indirect=True) def test_disaggregated_kv_cache_time_output(disaggregated_test_root, llm_venv, diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 3080a37f38..4288f8ef18 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -48,6 +48,7 @@ l0_a10: - disaggregated/test_disaggregated.py::test_disaggregated_cache_aware_balance[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_conditional[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_ngram[TinyLlama-1.1B-Chat-v1.0] + - disaggregated/test_disaggregated.py::test_disaggregated_chat_completion_tool_calls[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_workers.py::test_workers_conditional_disaggregation[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_workers.py::test_workers_kv_cache_events[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0]