Mirror of https://github.com/NVIDIA/TensorRT-LLM.git

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
parent 069ad30bdb
commit 93ae8a14ab
examples/disaggregated/clients/tool_call_prompts.json (new file, 24 lines)
@@ -0,0 +1,24 @@
+[
+  {
+    "messages": [
+      {
+        "role": "user",
+        "content": "What is the weather in San Francisco?"
+      },
+      {
+        "role": "assistant",
+        "content": null,
+        "tool_calls": [
+          {
+            "id": "call_1",
+            "type": "function",
+            "function": {
+              "name": "get_weather",
+              "arguments": "{\"location\":\"San Francisco\",\"unit\":\"fahrenheit\"}"
+            }
+          }
+        ]
+      }
+    ]
+  }
+]
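Each entry in tool_call_prompts.json is a complete chat history whose final turn is an assistant message carrying a tool call. As a minimal sketch of how such a prompt could be replayed against an OpenAI-compatible chat endpoint, the snippet below posts the messages with a matching tool schema; the server address, model name, and `get_weather` schema are illustrative assumptions, not part of this commit:

```python
import json

import requests

with open("examples/disaggregated/clients/tool_call_prompts.json") as f:
    conversations = json.load(f)

# Tool schema matching the get_weather call in the prompt file; the exact schema
# used by the test client is not shown in this commit, so this is an assumption.
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    },
}]

for conversation in conversations:
    # Assumed address of the disaggregated OpenAI-compatible server.
    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "TinyLlama-1.1B-Chat-v1.0",
            "messages": conversation["messages"],
            "tools": tools,
            "stream": False,
        },
        timeout=60,
    )
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"])
```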
@@ -155,7 +155,7 @@ class OpenAIHttpClient(OpenAIClient):
         request: UCompletionRequest,
         hooks: Optional[ResponseHooks] = None,
     ) -> AsyncGenerator[Any, None]:
-        json_data = request.model_dump(exclude_unset=True)
+        json_data = request.model_dump(exclude_unset=True, mode="json")
         is_stream = request.stream
         for attempt in range(self._max_retries + 1):
             try:
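The `mode="json"` argument matters because pydantic v2's default dump mode ("python") can leave non-JSON-serializable values such as enum members, datetimes, or bytes in the resulting dict, which then fails when the dict is sent as an HTTP JSON body. A self-contained illustration with a made-up model (field names are not from the commit):

```python
from datetime import datetime
from enum import Enum

from pydantic import BaseModel


class RequestType(str, Enum):
    context_only = "context_only"


class Example(BaseModel):
    request_type: RequestType
    created: datetime


ex = Example(request_type=RequestType.context_only, created=datetime(2024, 1, 1))

# Default "python" mode keeps rich Python objects in the dict.
print(ex.model_dump(exclude_unset=True))
# {'request_type': <RequestType.context_only: 'context_only'>,
#  'created': datetime.datetime(2024, 1, 1, 0, 0)}

# "json" mode coerces every value to a JSON-compatible type, so the dict can be
# passed straight to an HTTP client as the JSON payload.
print(ex.model_dump(exclude_unset=True, mode="json"))
# {'request_type': 'context_only', 'created': '2024-01-01T00:00:00'}
```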
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import asyncio
-import copy
 import os
 from typing import Any, Callable, Dict, Optional
 
@@ -145,12 +144,15 @@ class OpenAIDisaggregatedService(OpenAIService):
     def _get_ctx_request(
         self, request: UCompletionRequest, disagg_request_id: Optional[int]
     ) -> UCompletionRequest:
-        ctx_request = copy.deepcopy(request)
-        ctx_request.disaggregated_params = DisaggregatedParams(
-            request_type="context_only", disagg_request_id=disagg_request_id
-        )
-        ctx_request.stream = False
-        ctx_request.stream_options = None
+        ctx_request = request.model_copy(
+            update={
+                "disaggregated_params": DisaggregatedParams(
+                    request_type="context_only", disagg_request_id=disagg_request_id
+                ),
+                "stream": False,
+                "stream_options": None,
+            }
+        )
         return ctx_request
 
     def _get_gen_request(
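For context on the deepcopy-to-model_copy switch: in pydantic v2, `model_copy(update=...)` returns a shallow copy with the listed fields overridden (update values are not re-validated), so the request no longer has to be deep-copied just to flip a few fields. A toy sketch with an invented request model showing the same pattern:

```python
from typing import Optional

from pydantic import BaseModel


class FakeRequest(BaseModel):
    prompt: str
    stream: bool = True
    stream_options: Optional[dict] = None
    disaggregated_params: Optional[dict] = None


req = FakeRequest(prompt="hello", stream=True, stream_options={"include_usage": True})

# Shallow copy with selected fields overridden in a single call; unspecified
# fields are carried over from the original request.
ctx_req = req.model_copy(
    update={
        "disaggregated_params": {"request_type": "context_only", "disagg_request_id": 1},
        "stream": False,
        "stream_options": None,
    }
)

assert req.stream is True          # original request is untouched
assert ctx_req.stream is False     # context-only copy has streaming disabled
```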
@@ -95,6 +95,7 @@ def get_test_config(test_desc, example_dir, test_root):
             (2, f"{test_configs_root}/disagg_config_cuda_graph_padding.yaml"),
         "mixed": (2, f"{test_configs_root}/disagg_config_mixed.yaml"),
         "overlap": (2, f"{test_configs_root}/disagg_config_overlap.yaml"),
+        "tool_calls": (2, f"{test_configs_root}/disagg_config_overlap.yaml"),
         "perf_metrics": (2, f"{test_configs_root}/disagg_config_metrics.yaml"),
         "trtllm_sampler":
             (2, f"{test_configs_root}/disagg_config_trtllm_sampler.yaml"),
@@ -271,6 +272,8 @@ def run_client_tests(example_dir,
         if prompt_file == "long_prompts.json":
             # Use max_tokens 4 for long prompts to reduce test time
             client_cmd.extend(['--max-tokens', '4'])
+        if test_desc == "tool_calls":
+            client_cmd.extend(['-e', 'chat', '-o', 'output_tool_calls.json'])
 
         # Prepare poll processes
         worker_processes = []
@@ -283,6 +286,10 @@ def run_client_tests(example_dir,
             poll_procs = worker_processes + [server_proc]
             check_call(client_cmd, env=env, poll_procs=poll_procs)
 
+            # tool calls test does not need to run streaming and completion
+            if test_desc == "tool_calls":
+                continue
+
             # Streaming client run
             streaming_client_cmd = client_cmd + [
                 '--streaming', '-o', 'output_streaming.json'
@@ -304,7 +311,7 @@ def run_client_tests(example_dir,
                        poll_procs=poll_procs)
 
             # Skip output verification for long prompts test
-            if prompt_file == "long_prompts.json":
+            if prompt_file == "long_prompts.json" or prompt_file == "tool_call_prompts.json":
                 continue
 
             if extra_endpoints_test is not None:
@@ -786,6 +793,29 @@ def test_disaggregated_perf_metrics(disaggregated_test_root, llm_venv,
                            extra_endpoints_test=extra_endpoints_test)
 
 
+@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
+                         indirect=True)
+def test_disaggregated_chat_completion_tool_calls(disaggregated_test_root,
+                                                  llm_venv,
+                                                  disaggregated_example_root,
+                                                  llama_model_root):
+    src_dst_dict = {
+        llama_model_root:
+        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    run_disaggregated_test(disaggregated_example_root,
+                           "tool_calls",
+                           num_iters=1,
+                           prompt_file="tool_call_prompts.json",
+                           env=llm_venv._new_env,
+                           cwd=llm_venv.get_working_directory())
+
+
 @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                          indirect=True)
 def test_disaggregated_kv_cache_time_output(disaggregated_test_root, llm_venv,
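For local runs, the new case maps to the pytest node id disaggregated/test_disaggregated.py::test_disaggregated_chat_completion_tool_calls[TinyLlama-1.1B-Chat-v1.0], the same id registered in the l0_a10 test list below; it presumably only requires the TinyLlama-1.1B-Chat-v1.0 checkpoint to be resolvable through the llama_model_root fixture.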
@@ -48,6 +48,7 @@ l0_a10:
   - disaggregated/test_disaggregated.py::test_disaggregated_cache_aware_balance[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_conditional[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_ngram[TinyLlama-1.1B-Chat-v1.0]
+  - disaggregated/test_disaggregated.py::test_disaggregated_chat_completion_tool_calls[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_workers.py::test_workers_conditional_disaggregation[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_workers.py::test_workers_kv_cache_events[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0]