[#10889][fix] fix pydantic deepcopy bug (#11004)

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Lizhi Zhou 2026-01-27 15:40:13 +08:00 committed by GitHub
parent 069ad30bdb
commit 93ae8a14ab
5 changed files with 65 additions and 8 deletions
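
The diff does not show which field made copy.deepcopy fail, but the shape of the fix (switching to model_copy(update=...) and dumping with mode="json") matches a common pydantic failure mode: deep-copying a model recurses into every field and breaks on values that cannot be deep-copied, while a shallow model_copy with targeted overrides never walks the nested objects. A minimal, hypothetical repro of that class of bug follows; the Request model and its handle field are illustrative stand-ins, not the real UCompletionRequest.

# Hypothetical repro: deepcopy of a pydantic model fails on a non-copyable field,
# while model_copy (shallow by default) leaves that field alone.
import copy
import threading
from typing import Any

from pydantic import BaseModel


class Request(BaseModel):
    prompt: str
    handle: Any = None  # stand-in for a runtime object attached to the request


req = Request(prompt="hi", handle=threading.Lock())

try:
    copy.deepcopy(req)  # deepcopy recurses into every field and trips on the lock
except TypeError as exc:
    print("deepcopy failed:", exc)

clone = req.model_copy(update={"prompt": "bye"})  # shallow copy, override applied
print(clone.prompt, clone.handle is req.handle)   # bye True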


@@ -0,0 +1,24 @@
+[
+  {
+    "messages": [
+      {
+        "role": "user",
+        "content": "What is the weather in San Francisco?"
+      },
+      {
+        "role": "assistant",
+        "content": null,
+        "tool_calls": [
+          {
+            "id": "call_1",
+            "type": "function",
+            "function": {
+              "name": "get_weather",
+              "arguments": "{\"location\":\"San Francisco\",\"unit\":\"fahrenheit\"}"
+            }
+          }
+        ]
+      }
+    ]
+  }
+]
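
The new prompt file (referenced by the test below as tool_call_prompts.json) feeds the chat endpoint a history that already contains an assistant turn with tool_calls and null content, which is exactly the kind of nested request the copy/serialization path has to survive. Below is a hedged sketch of replaying one entry against an OpenAI-compatible chat completions endpoint; the URL, port, model name, and get_weather tool schema are assumptions for illustration, while the actual test drives this through the example client with -e chat.

# Illustrative replay of one prompt from tool_call_prompts.json against an
# OpenAI-compatible /v1/chat/completions endpoint. Endpoint URL, model name,
# and the get_weather tool schema are assumed for this sketch.
import json

import requests

with open("tool_call_prompts.json") as f:
    prompts = json.load(f)

payload = {
    "model": "TinyLlama-1.1B-Chat-v1.0",
    "messages": prompts[0]["messages"],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        },
    }],
    "stream": False,
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"])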


@@ -155,7 +155,7 @@ class OpenAIHttpClient(OpenAIClient):
         request: UCompletionRequest,
         hooks: Optional[ResponseHooks] = None,
     ) -> AsyncGenerator[Any, None]:
-        json_data = request.model_dump(exclude_unset=True)
+        json_data = request.model_dump(exclude_unset=True, mode="json")
         is_stream = request.stream
         for attempt in range(self._max_retries + 1):
             try:
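
For context, a minimal sketch of what mode="json" changes, assuming pydantic v2; Unit and ToolArgs are illustrative stand-ins, not the request types used here. The default python-mode dump keeps Python objects such as enum members in the output, which a JSON-posting HTTP client cannot serialize, while mode="json" coerces every value to a JSON-compatible type.

# Sketch of the model_dump difference; the types are illustrative only.
import json
from enum import Enum

from pydantic import BaseModel


class Unit(Enum):
    CELSIUS = "celsius"
    FAHRENHEIT = "fahrenheit"


class ToolArgs(BaseModel):
    location: str
    unit: Unit


args = ToolArgs(location="San Francisco", unit=Unit.FAHRENHEIT)

python_dump = args.model_dump(exclude_unset=True)             # unit stays an Enum member
json_dump = args.model_dump(exclude_unset=True, mode="json")   # unit becomes "fahrenheit"

json.dumps(json_dump)       # serializes fine
try:
    json.dumps(python_dump)
except TypeError as exc:
    print("not JSON-serializable:", exc)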


@@ -13,7 +13,6 @@
 # limitations under the License.
 import asyncio
-import copy
 import os
 from typing import Any, Callable, Dict, Optional
@@ -145,12 +144,15 @@ class OpenAIDisaggregatedService(OpenAIService):
     def _get_ctx_request(
         self, request: UCompletionRequest, disagg_request_id: Optional[int]
     ) -> UCompletionRequest:
-        ctx_request = copy.deepcopy(request)
-        ctx_request.disaggregated_params = DisaggregatedParams(
-            request_type="context_only", disagg_request_id=disagg_request_id
+        ctx_request = request.model_copy(
+            update={
+                "disaggregated_params": DisaggregatedParams(
+                    request_type="context_only", disagg_request_id=disagg_request_id
+                ),
+                "stream": False,
+                "stream_options": None,
+            }
         )
-        ctx_request.stream = False
-        ctx_request.stream_options = None
         return ctx_request
 
     def _get_gen_request(
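
A hedged sketch of the copy semantics this change relies on, with stand-in models rather than the real UCompletionRequest and DisaggregatedParams: model_copy(update=...) returns a copy with only the listed fields overridden, leaves the original request (still needed for the generation phase) untouched, and, being shallow by default, does not recurse into nested values the way copy.deepcopy did.

# Stand-in types for illustration; the real request and parameter classes
# live elsewhere in the codebase.
from typing import Optional

from pydantic import BaseModel


class DisaggParams(BaseModel):
    request_type: str
    disagg_request_id: Optional[int] = None


class Request(BaseModel):
    prompt: str
    stream: bool = True
    stream_options: Optional[dict] = None
    disaggregated_params: Optional[DisaggParams] = None


request = Request(prompt="hi", stream=True, stream_options={"include_usage": True})

ctx_request = request.model_copy(
    update={
        "disaggregated_params": DisaggParams(request_type="context_only",
                                             disagg_request_id=7),
        "stream": False,
        "stream_options": None,
    })

assert request.stream is True                  # original request is untouched
assert ctx_request.stream is False             # overrides applied only on the copy
assert ctx_request.prompt is request.prompt    # untouched fields are shared, not deep-copied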


@@ -95,6 +95,7 @@ def get_test_config(test_desc, example_dir, test_root):
         (2, f"{test_configs_root}/disagg_config_cuda_graph_padding.yaml"),
         "mixed": (2, f"{test_configs_root}/disagg_config_mixed.yaml"),
         "overlap": (2, f"{test_configs_root}/disagg_config_overlap.yaml"),
+        "tool_calls": (2, f"{test_configs_root}/disagg_config_overlap.yaml"),
         "perf_metrics": (2, f"{test_configs_root}/disagg_config_metrics.yaml"),
         "trtllm_sampler":
         (2, f"{test_configs_root}/disagg_config_trtllm_sampler.yaml"),
@@ -271,6 +272,8 @@ def run_client_tests(example_dir,
     if prompt_file == "long_prompts.json":
         # Use max_tokens 4 for long prompts to reduce test time
         client_cmd.extend(['--max-tokens', '4'])
+    if test_desc == "tool_calls":
+        client_cmd.extend(['-e', 'chat', '-o', 'output_tool_calls.json'])
 
     # Prepare poll processes
     worker_processes = []
@@ -283,6 +286,10 @@ def run_client_tests(example_dir,
         poll_procs = worker_processes + [server_proc]
         check_call(client_cmd, env=env, poll_procs=poll_procs)
 
+        # tool calls test does not need to run streaming and completion
+        if test_desc == "tool_calls":
+            continue
+
         # Streaming client run
         streaming_client_cmd = client_cmd + [
             '--streaming', '-o', 'output_streaming.json'
@@ -304,7 +311,7 @@ def run_client_tests(example_dir,
                    poll_procs=poll_procs)
 
         # Skip output verification for long prompts test
-        if prompt_file == "long_prompts.json":
+        if prompt_file == "long_prompts.json" or prompt_file == "tool_call_prompts.json":
             continue
 
         if extra_endpoints_test is not None:
@@ -786,6 +793,29 @@ def test_disaggregated_perf_metrics(disaggregated_test_root, llm_venv,
                            extra_endpoints_test=extra_endpoints_test)
 
 
+@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
+                         indirect=True)
+def test_disaggregated_chat_completion_tool_calls(disaggregated_test_root,
+                                                  llm_venv,
+                                                  disaggregated_example_root,
+                                                  llama_model_root):
+    src_dst_dict = {
+        llama_model_root:
+        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    run_disaggregated_test(disaggregated_example_root,
+                           "tool_calls",
+                           num_iters=1,
+                           prompt_file="tool_call_prompts.json",
+                           env=llm_venv._new_env,
+                           cwd=llm_venv.get_working_directory())
+
+
 @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                          indirect=True)
 def test_disaggregated_kv_cache_time_output(disaggregated_test_root, llm_venv,


@@ -48,6 +48,7 @@ l0_a10:
 - disaggregated/test_disaggregated.py::test_disaggregated_cache_aware_balance[TinyLlama-1.1B-Chat-v1.0]
 - disaggregated/test_disaggregated.py::test_disaggregated_conditional[TinyLlama-1.1B-Chat-v1.0]
 - disaggregated/test_disaggregated.py::test_disaggregated_ngram[TinyLlama-1.1B-Chat-v1.0]
+- disaggregated/test_disaggregated.py::test_disaggregated_chat_completion_tool_calls[TinyLlama-1.1B-Chat-v1.0]
 - disaggregated/test_workers.py::test_workers_conditional_disaggregation[TinyLlama-1.1B-Chat-v1.0]
 - disaggregated/test_workers.py::test_workers_kv_cache_events[TinyLlama-1.1B-Chat-v1.0]
 - disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0]