diff --git a/examples/disaggregated/clients/disagg_client.py b/examples/disaggregated/clients/disagg_client.py index 3f74755ddd..e6c85d3223 100644 --- a/examples/disaggregated/clients/disagg_client.py +++ b/examples/disaggregated/clients/disagg_client.py @@ -54,12 +54,18 @@ async def send_request(session, server_host, server_port, model, prompt, if line.startswith("data: "): line = line[len("data: "):] response_json = json.loads(line) - text += response_json["choices"][0]["text"] + choices = response_json.get("choices", []) + if not choices: + continue + text += choices[0].get("text", "") logging.info(text) return text else: response_json = await response.json() - text = response_json["choices"][0]["text"] + choices = response_json.get("choices", []) + if not choices: + raise ValueError("Missing choices in completion response") + text = choices[0].get("text", "") logging.info(text) return text @@ -100,14 +106,21 @@ async def send_chat_request(session, server_host, server_port, model, prompt, if line.startswith("data: "): line = line[len("data: "):] response_json = json.loads(line) - if "content" in response_json["choices"][0]["delta"]: - text += response_json["choices"][0]["delta"][ - "content"] + choices = response_json.get("choices", []) + if not choices: + continue + delta = choices[0].get("delta", {}) + content = delta.get("content") + if content is not None: + text += content logging.info(text) return text else: response_json = await response.json() - text = response_json["choices"][0]["message"]["content"] + choices = response_json.get("choices", []) + if not choices: + raise ValueError("Missing choices in chat completion response") + text = choices[0].get("message", {}).get("content", "") logging.info(text) return text diff --git a/tensorrt_llm/serve/harmony_adapter.py b/tensorrt_llm/serve/harmony_adapter.py index f9e168a394..883483897c 100644 --- a/tensorrt_llm/serve/harmony_adapter.py +++ b/tensorrt_llm/serve/harmony_adapter.py @@ -25,7 +25,7 @@ from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionStreamResponse, ChatCompletionToolsParam, ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, - UsageInfo) + UsageInfo, to_disaggregated_params) # yapf: enable @@ -1644,18 +1644,35 @@ def handle_non_streaming_response(tools: List[ChatCompletionToolsParam], tools_for_parser = tools_dict output = outputs[0] - parsed_output = harmony_adapter.harmony_output_to_openai( - output.token_ids, tools_for_parser, tool_choice) + disaggregated_params = output.disaggregated_params - # CONVERTED OUTPUT (after harmony to openai conversion) - logger.debug(f"✅ CONVERTED OUTPUT: {json.dumps(parsed_output, indent=2)}") + response_message = {} + finish_reason = output.finish_reason + usage_info = None + # skip harmony parsing for context only requests + if disaggregated_params is None or disaggregated_params.request_type != "context_only": + parsed_output = harmony_adapter.harmony_output_to_openai( + output.token_ids, tools_for_parser, tool_choice) - # Create response message - response_message = _create_response_message(parsed_output) + # CONVERTED OUTPUT (after harmony to openai conversion) + logger.debug( + f"✅ CONVERTED OUTPUT: {json.dumps(parsed_output, indent=2)}") - # Determine finish reason - finish_reason = _determine_finish_reason(parsed_output, - output.finish_reason) + # Create response message + response_message = _create_response_message(parsed_output) + + # Determine finish reason + finish_reason = _determine_finish_reason(parsed_output, + output.finish_reason) + # 
Optional: Log if harmony parsing failed (for debugging) + if parsed_output.get('_harmony_parsing_failed'): + logger.warning( + f"⚠️ Harmony parsing fell back to raw text decoding, {parsed_output}" + ) + else: + # Context only requests don't need a full response message, + # the real response will be responded by generation server + response_message = {"role": "assistant", "content": ""} # Create usage info from metrics (RequestOutput doesn't have usage in v1) usage_info = _create_usage_info(num_prompt_tokens, outputs) @@ -1667,14 +1684,12 @@ def handle_non_streaming_response(tools: List[ChatCompletionToolsParam], ChatCompletionResponseChoice( index=0, message=ChatMessage(**response_message), - finish_reason=finish_reason) + finish_reason=finish_reason, + disaggregated_params=to_disaggregated_params( + output.disaggregated_params)) ], usage=usage_info, ) - # Optional: Log if harmony parsing failed (for debugging) - if parsed_output.get('_harmony_parsing_failed'): - logger.warning("⚠️ Harmony parsing fell back to raw text decoding") - logger.debug(f"response\n\n{response}\n") return response diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 912fb84ac5..eee3b17de6 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -198,7 +198,7 @@ class OpenAIServer: @self.app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): - return self.create_error_response(message=str(exc)) + return JSONResponse(status_code=400, content={"error": str(exc)}) if self.server_role is not ServerRole.MM_ENCODER: self.register_routes() @@ -493,6 +493,21 @@ class OpenAIServer: async with self.perf_metrics_lock: self.perf_metrics.append(item) + async def _create_chat_response(self, + promise: RequestOutput, postproc_params: PostprocParams, raw_request: Request, disaggregated_params: Optional[LlmDisaggregatedParams] = None) -> ChatCompletionResponse: + await promise.aresult() + if self.postproc_worker_enabled: + chat_response = promise.outputs[0]._postprocess_result + else: + post_processor, args = postproc_params.post_processor, postproc_params.postproc_args + chat_response = post_processor(promise, args) + + if disaggregated_params is not None and chat_response.choices[0].disaggregated_params is None: + raise ValueError(f"disaggregated_params is not set in the response for request" + f" {disaggregated_params.disagg_request_id}") + + return chat_response + async def openai_chat(self, request: ChatCompletionRequest, raw_request: Request) -> Response: def get_role() -> str: @@ -525,22 +540,6 @@ class OpenAIServer: logger.error(traceback.format_exc()) raise - async def create_chat_response( - promise: RequestOutput, postproc_params: PostprocParams, disaggregated_params: Optional[LlmDisaggregatedParams] = None) -> ChatCompletionResponse: - await promise.aresult() - if self.postproc_worker_enabled: - chat_response =promise.outputs[0]._postprocess_result - else: - post_processor, args = postproc_params.post_processor, postproc_params.postproc_args - chat_response = post_processor(promise, args) - - # Add prompt_tokens_ids to the response - if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only": - chat_response.prompt_token_ids = promise.prompt_token_ids - raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() - await self._extract_metrics(promise, raw_request) - return chat_response - try: conversation: 
List[ConversationMessage] = [] tool_dicts = None if request.tools is None else [ @@ -617,7 +616,7 @@ class OpenAIServer: return StreamingResponse(content=response_generator, media_type="text/event-stream") else: - response = await create_chat_response(promise, postproc_params, disaggregated_params) + response = await self._create_chat_response(promise, postproc_params, disaggregated_params) return JSONResponse(content=response.model_dump()) except CppExecutorError: logger.error(traceback.format_exc()) @@ -872,17 +871,6 @@ class OpenAIServer: Supports both streaming and non-streaming modes. """ - async def create_harmony_response( - promise: RequestOutput, postproc_params: PostprocParams) -> ChatCompletionResponse: - await promise.aresult() - if self.postproc_worker_enabled: - chat_response =promise.outputs[0]._postprocess_result - else: - post_processor, args = postproc_params.post_processor, postproc_params.postproc_args - chat_response = post_processor(promise, args) - - return chat_response - async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams): async for res in promise: if not self.postproc_worker_enabled: @@ -934,6 +922,8 @@ class OpenAIServer: vocab_size=self.tokenizer.tokenizer.vocab_size, reasoning_parser="gpt_oss") sampling_params.detokenize = False # Harmony adapter handles detokenization + disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params) + trace_headers = (None if raw_request is None else tracing.extract_trace_headers(raw_request.headers)) postproc_args = ChatCompletionPostprocArgs.from_request(request) postproc_params = PostprocParams( @@ -949,6 +939,8 @@ class OpenAIServer: _postproc_params=postproc_params if self.postproc_worker_enabled else None, streaming=bool(request.stream), lora_request=request.lora_request, + disaggregated_params=disaggregated_params, + trace_headers=trace_headers, ) postproc_args.request_id = promise.request_id @@ -965,7 +957,7 @@ class OpenAIServer: media_type="text/event-stream" ) else: - response = await create_harmony_response(promise, postproc_params) + response = await self._create_chat_response(promise, postproc_params, raw_request, disaggregated_params) return JSONResponse(response.model_dump()) except Exception as e: diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index 0eca927293..785866bb4e 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -19,6 +19,7 @@ import re import subprocess import tempfile import time +from collections import namedtuple from dataclasses import dataclass from typing import Callable @@ -199,6 +200,9 @@ def get_test_config(test_desc, example_dir, test_root): f"{test_configs_root}/disagg_config_ctxtp4_gentp4_deepseek_r1_v2_fp4_tllm.yaml" ), "gpt_oss_120b_stress": + (4, + f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"), + "gpt_oss_120b_harmony": (4, f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"), "cancel_stress_test": @@ -248,6 +252,52 @@ def generate_worker_commands(model_path, config, server_config, return worker_commands +ClientTestSet = namedtuple('ClientTestSet', [ + 'completion', 'completion_streaming', 'chat', 'chat_streaming', + 'verify_completion', 'verify_streaming_completion', 'verify_chat', + 'verify_streaming_chat' +]) + + +def get_client_test_set(test_desc): + """Get the set of client tests to run for a given test 
description.""" + if test_desc == "tool_calls": + return ClientTestSet(completion=False, + completion_streaming=False, + chat=True, + chat_streaming=False, + verify_completion=False, + verify_streaming_completion=False, + verify_chat=False, + verify_streaming_chat=False) + if test_desc == "gpt_oss_120b_harmony": + return ClientTestSet(completion=True, + completion_streaming=True, + chat=True, + chat_streaming=True, + verify_completion=True, + verify_streaming_completion=True, + verify_chat=False, + verify_streaming_chat=False) + if test_desc in ("overlap", "trtllm_sampler"): + return ClientTestSet(completion=True, + completion_streaming=True, + chat=True, + chat_streaming=True, + verify_completion=True, + verify_streaming_completion=True, + verify_chat=True, + verify_streaming_chat=False) + return ClientTestSet(completion=True, + completion_streaming=True, + chat=False, + chat_streaming=False, + verify_completion=True, + verify_streaming_completion=True, + verify_chat=False, + verify_streaming_chat=False) + + def run_client_tests(example_dir, config_file, test_desc, @@ -259,8 +309,12 @@ def run_client_tests(example_dir, server_url, workers_proc, server_proc, - use_ray=False): + use_ray=False, + client_test_set=None): """Run client tests against the disaggregated server.""" + if client_test_set is None: + client_test_set = get_client_test_set(test_desc) + client_dir = f"{example_dir}/clients" for _ in range(num_iters): client_cmd = [ @@ -272,8 +326,6 @@ def run_client_tests(example_dir, if prompt_file == "long_prompts.json": # Use max_tokens 4 for long prompts to reduce test time client_cmd.extend(['--max-tokens', '4']) - if test_desc == "tool_calls": - client_cmd.extend(['-e', 'chat', '-o', 'output_tool_calls.json']) # Prepare poll processes worker_processes = [] @@ -284,33 +336,34 @@ def run_client_tests(example_dir, worker_processes = [workers_proc] poll_procs = worker_processes + [server_proc] - check_call(client_cmd, env=env, poll_procs=poll_procs) - # tool calls test does not need to run streaming and completion - if test_desc == "tool_calls": - continue + # Run completion test (non-streaming) + if client_test_set.completion: + check_call(client_cmd, env=env, poll_procs=poll_procs) - # Streaming client run - streaming_client_cmd = client_cmd + [ - '--streaming', '-o', 'output_streaming.json' - ] - check_call(streaming_client_cmd, env=env, poll_procs=poll_procs) - - # Run the chat completion endpoint test only for TinyLlama - if test_desc == "overlap" or test_desc == "trtllm_sampler": - chat_client_cmd = client_cmd + [ - '-e', 'chat', '-o', 'output_chat.json' + # Run streaming completion test + if client_test_set.completion_streaming: + streaming_client_cmd = client_cmd + [ + '--streaming', '-o', 'output_streaming.json' ] + check_call(streaming_client_cmd, env=env, poll_procs=poll_procs) + + # Run chat completion test + if client_test_set.chat: + chat_output = 'output_tool_calls.json' if test_desc == "tool_calls" else 'output_chat.json' + chat_client_cmd = client_cmd + ['-e', 'chat', '-o', chat_output] check_call(chat_client_cmd, env=env, poll_procs=poll_procs) - streaming_chat_client_cmd = chat_client_cmd + [ - '--streaming', '-o', 'output_streaming_chat.json' + # Run streaming chat completion test + if client_test_set.chat_streaming: + streaming_chat_client_cmd = client_cmd + [ + '-e', 'chat', '--streaming', '-o', 'output_streaming_chat.json' ] check_call(streaming_chat_client_cmd, env=env, poll_procs=poll_procs) - # Skip output verification for long prompts test + # Skip output 
verification for long prompts or tool call tests if prompt_file == "long_prompts.json" or prompt_file == "tool_call_prompts.json": continue @@ -320,11 +373,16 @@ def run_client_tests(example_dir, # Verify outputs not_expected_strings = ["Berlin Berlin"] - output_files = ['output.json', 'output_streaming.json'] - if test_desc == "overlap" or test_desc == "trtllm_sampler": - # Disable streaming chat completion for overlap test - # due to bug - output_files.extend(['output_chat.json']) + output_files = [] + if client_test_set.completion and client_test_set.verify_completion: + output_files.append('output.json') + if client_test_set.completion_streaming and client_test_set.verify_streaming_completion: + output_files.append('output_streaming.json') + if client_test_set.chat and client_test_set.verify_chat: + # Streaming chat completion output not verified due to known bug + output_files.append('output_chat.json') + if client_test_set.chat_streaming and client_test_set.verify_streaming_chat: + output_files.append('output_streaming_chat.json') if test_desc.startswith("gen_only"): continue @@ -336,6 +394,11 @@ def run_client_tests(example_dir, expected_strings = [ "Berlin", ["Asyncio is a", "Asyncio module in"] ] + elif "gpt_oss_120b" in test_desc: + expected_strings = [ + "The capital of Germany is Berlin", + "Using `asyncio` in Python" + ] else: expected_strings = [ "The capital of Germany is Berlin", @@ -2086,6 +2149,27 @@ def test_disaggregated_deepseek_v3_lite_bf16_tllm_gen_helix( prompt_file="long_prompts.json") +@skip_pre_blackwell +@pytest.mark.skip_less_device(4) +@pytest.mark.parametrize("model_path", ['gpt_oss/gpt-oss-120b']) +def test_disaggregated_gpt_oss_120b_harmony(disaggregated_test_root, + disaggregated_example_root, + llm_venv, model_path): + model_dir = f"{llm_models_root()}/{model_path}" + src_dst_dict = { + model_dir: f"{llm_venv.get_working_directory()}/{model_path}", + } + for src, dst in src_dst_dict.items(): + if not os.path.islink(dst): + os.makedirs(os.path.dirname(dst), exist_ok=True) + os.symlink(src, dst, target_is_directory=True) + + run_disaggregated_test(disaggregated_example_root, + "gpt_oss_120b_harmony", + env=llm_venv._new_env, + cwd=llm_venv.get_working_directory()) + + @pytest.mark.timeout(12600) @pytest.mark.parametrize("test_config", [ pytest.param(TestConfig(model_path='DeepSeek-R1/DeepSeek-R1-0528-FP4-v2', diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index 1bb60496a3..2640b2eaa3 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -26,6 +26,7 @@ l0_dgx_b200: - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=True] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8] + - disaggregated/test_disaggregated.py::test_disaggregated_gpt_oss_120b_harmony[gpt_oss/gpt-oss-120b] - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4] - accuracy/test_llm_api_pytorch.py::TestMiniMaxM2::test_4gpus[attention_dp=False-cuda_graph=True-overlap_scheduler=True-tp_size=4-ep_size=4] TIMEOUT (60) - condition:
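Note: the following is not part of the patch. It is a minimal standalone sketch, assuming made-up payloads, of the defensive "choices"-parsing pattern that the disagg_client.py changes above apply to streaming completion chunks; names such as extract_text are illustrative only and do not exist in the repository.

# Sketch: tolerate SSE chunks without a "choices" entry (e.g. usage-only or
# keep-alive chunks) instead of raising KeyError, mirroring the client change.
import json


def extract_text(raw_line: str) -> str:
    """Return the text delta carried by one SSE line, or '' if it has none."""
    if raw_line.startswith("data: "):
        raw_line = raw_line[len("data: "):]
    response_json = json.loads(raw_line)
    choices = response_json.get("choices", [])
    if not choices:
        return ""  # chunk carries no text delta; skip it
    return choices[0].get("text", "")


if __name__ == "__main__":
    chunks = [
        'data: {"choices": [{"text": "Berlin"}]}',
        'data: {"usage": {"completion_tokens": 1}}',  # hypothetical chunk with no "choices"
    ]
    print("".join(extract_text(c) for c in chunks))  # prints: Berlin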