[TRTLLM-10866][feat] implement disaggregated harmony chat (#11336)

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Lizhi Zhou 2026-02-10 01:09:03 +08:00 committed by GitHub
parent 100bfdc516
commit e719721a60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 181 additions and 76 deletions

View File

@@ -54,12 +54,18 @@ async def send_request(session, server_host, server_port, model, prompt,
if line.startswith("data: "):
line = line[len("data: "):]
response_json = json.loads(line)
text += response_json["choices"][0]["text"]
choices = response_json.get("choices", [])
if not choices:
continue
text += choices[0].get("text", "")
logging.info(text)
return text
else:
response_json = await response.json()
text = response_json["choices"][0]["text"]
choices = response_json.get("choices", [])
if not choices:
raise ValueError("Missing choices in completion response")
text = choices[0].get("text", "")
logging.info(text)
return text
@@ -100,14 +106,21 @@ async def send_chat_request(session, server_host, server_port, model, prompt,
if line.startswith("data: "):
line = line[len("data: "):]
response_json = json.loads(line)
if "content" in response_json["choices"][0]["delta"]:
text += response_json["choices"][0]["delta"][
"content"]
choices = response_json.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
content = delta.get("content")
if content is not None:
text += content
logging.info(text)
return text
else:
response_json = await response.json()
text = response_json["choices"][0]["message"]["content"]
choices = response_json.get("choices", [])
if not choices:
raise ValueError("Missing choices in chat completion response")
text = choices[0].get("message", {}).get("content", "")
logging.info(text)
return text
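
The client change above makes both the streaming and non-streaming parsers tolerant of responses whose chunks carry no choices. A minimal standalone sketch of the same defensive pattern, assuming simple string inputs (accumulate_stream_text and the synthetic chunks below are illustrative, not part of the patch):

import json

def accumulate_stream_text(sse_lines):
    """Concatenate completion text from 'data: ' SSE lines, skipping chunks without choices."""
    text = ""
    for line in sse_lines:
        if not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload.strip() == "[DONE]":
            break
        chunk = json.loads(payload)
        choices = chunk.get("choices", [])
        if not choices:
            # Usage-only or context-only chunks may carry no choices; skip them.
            continue
        text += choices[0].get("text", "")
    return text

print(accumulate_stream_text([
    'data: {"choices": [{"text": "Hello"}]}',
    'data: {"choices": []}',
    'data: {"choices": [{"text": ", world"}]}',
    'data: [DONE]',
]))  # prints "Hello, world"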

View File

@@ -25,7 +25,7 @@ from .openai_protocol import (ChatCompletionMessageParam,
ChatCompletionStreamResponse,
ChatCompletionToolsParam, ChatMessage,
DeltaFunctionCall, DeltaMessage, DeltaToolCall,
UsageInfo)
UsageInfo, to_disaggregated_params)
# yapf: enable
@@ -1644,18 +1644,35 @@ def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
tools_for_parser = tools_dict
output = outputs[0]
parsed_output = harmony_adapter.harmony_output_to_openai(
output.token_ids, tools_for_parser, tool_choice)
disaggregated_params = output.disaggregated_params
# CONVERTED OUTPUT (after harmony to openai conversion)
logger.debug(f"✅ CONVERTED OUTPUT: {json.dumps(parsed_output, indent=2)}")
response_message = {}
finish_reason = output.finish_reason
usage_info = None
# skip harmony parsing for context only requests
if disaggregated_params is None or disaggregated_params.request_type != "context_only":
parsed_output = harmony_adapter.harmony_output_to_openai(
output.token_ids, tools_for_parser, tool_choice)
# Create response message
response_message = _create_response_message(parsed_output)
# CONVERTED OUTPUT (after harmony to openai conversion)
logger.debug(
f"✅ CONVERTED OUTPUT: {json.dumps(parsed_output, indent=2)}")
# Determine finish reason
finish_reason = _determine_finish_reason(parsed_output,
output.finish_reason)
# Create response message
response_message = _create_response_message(parsed_output)
# Determine finish reason
finish_reason = _determine_finish_reason(parsed_output,
output.finish_reason)
# Optional: Log if harmony parsing failed (for debugging)
if parsed_output.get('_harmony_parsing_failed'):
logger.warning(
f"⚠️ Harmony parsing fell back to raw text decoding, {parsed_output}"
)
else:
# Context-only requests don't need a full response message;
# the real response will be returned by the generation server
response_message = {"role": "assistant", "content": ""}
# Create usage info from metrics (RequestOutput doesn't have usage in v1)
usage_info = _create_usage_info(num_prompt_tokens, outputs)
@@ -1667,14 +1684,12 @@ def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(**response_message),
finish_reason=finish_reason)
finish_reason=finish_reason,
disaggregated_params=to_disaggregated_params(
output.disaggregated_params))
],
usage=usage_info,
)
# Optional: Log if harmony parsing failed (for debugging)
if parsed_output.get('_harmony_parsing_failed'):
logger.warning("⚠️ Harmony parsing fell back to raw text decoding")
logger.debug(f"response\n\n{response}\n")
return response
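
The key change in the non-streaming harmony handler is the context-only short circuit: a context_only request returns an empty assistant message and carries its disaggregated_params forward, while all other requests still go through the Harmony-to-OpenAI conversion. A simplified, runnable sketch of that branch, assuming stand-in types (DisaggParams and parse_with_harmony are placeholders, not the real adapter API):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class DisaggParams:
    request_type: str  # e.g. "context_only" or "generation_only"

def parse_with_harmony(token_ids: List[int]) -> dict:
    # Placeholder for harmony_adapter.harmony_output_to_openai(...).
    return {"role": "assistant", "content": f"<decoded {len(token_ids)} tokens>"}

def build_response_message(token_ids: List[int],
                           disagg: Optional[DisaggParams]) -> dict:
    # Context-only requests only prefill the KV cache; the generation server
    # produces the real answer, so Harmony parsing is skipped entirely.
    if disagg is not None and disagg.request_type == "context_only":
        return {"role": "assistant", "content": ""}
    return parse_with_harmony(token_ids)

print(build_response_message([1, 2, 3], DisaggParams("context_only")))
print(build_response_message([1, 2, 3], None))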

View File

@@ -198,7 +198,7 @@ class OpenAIServer:
@self.app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
return self.create_error_response(message=str(exc))
return JSONResponse(status_code=400, content={"error": str(exc)})
if self.server_role is not ServerRole.MM_ENCODER:
self.register_routes()
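
The validation handler above now returns an explicit HTTP 400 with a JSON error body instead of the generic error response. A self-contained FastAPI sketch of the same pattern, assuming an illustrative /v1/chat/completions stub that is not part of the patch:

from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel

app = FastAPI()

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_: Request, exc: RequestValidationError) -> JSONResponse:
    # Malformed request bodies surface as 400 with the validation error in the payload.
    return JSONResponse(status_code=400, content={"error": str(exc)})

class ChatRequest(BaseModel):
    model: str
    messages: list

@app.post("/v1/chat/completions")
async def chat(request: ChatRequest) -> dict:
    # Illustrative stub; the real server dispatches the request to the LLM backend.
    return {"ok": True, "model": request.model}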
@@ -493,6 +493,21 @@ class OpenAIServer:
async with self.perf_metrics_lock:
self.perf_metrics.append(item)
async def _create_chat_response(self,
promise: RequestOutput, postproc_params: PostprocParams, raw_request: Request, disaggregated_params: Optional[LlmDisaggregatedParams] = None) -> ChatCompletionResponse:
await promise.aresult()
if self.postproc_worker_enabled:
chat_response = promise.outputs[0]._postprocess_result
else:
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
chat_response = post_processor(promise, args)
if disaggregated_params is not None and chat_response.choices[0].disaggregated_params is None:
raise ValueError(f"disaggregated_params is not set in the response for request"
f" {disaggregated_params.disagg_request_id}")
return chat_response
async def openai_chat(self, request: ChatCompletionRequest, raw_request: Request) -> Response:
def get_role() -> str:
@@ -525,22 +540,6 @@ class OpenAIServer:
logger.error(traceback.format_exc())
raise
async def create_chat_response(
promise: RequestOutput, postproc_params: PostprocParams, disaggregated_params: Optional[LlmDisaggregatedParams] = None) -> ChatCompletionResponse:
await promise.aresult()
if self.postproc_worker_enabled:
chat_response =promise.outputs[0]._postprocess_result
else:
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
chat_response = post_processor(promise, args)
# Add prompt_tokens_ids to the response
if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only":
chat_response.prompt_token_ids = promise.prompt_token_ids
raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds()
await self._extract_metrics(promise, raw_request)
return chat_response
try:
conversation: List[ConversationMessage] = []
tool_dicts = None if request.tools is None else [
@@ -617,7 +616,7 @@ class OpenAIServer:
return StreamingResponse(content=response_generator,
media_type="text/event-stream")
else:
response = await create_chat_response(promise, postproc_params, disaggregated_params)
response = await self._create_chat_response(promise, postproc_params, raw_request, disaggregated_params)
return JSONResponse(content=response.model_dump())
except CppExecutorError:
logger.error(traceback.format_exc())
@@ -872,17 +871,6 @@ class OpenAIServer:
Supports both streaming and non-streaming modes.
"""
async def create_harmony_response(
promise: RequestOutput, postproc_params: PostprocParams) -> ChatCompletionResponse:
await promise.aresult()
if self.postproc_worker_enabled:
chat_response =promise.outputs[0]._postprocess_result
else:
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
chat_response = post_processor(promise, args)
return chat_response
async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
async for res in promise:
if not self.postproc_worker_enabled:
@@ -934,6 +922,8 @@ class OpenAIServer:
vocab_size=self.tokenizer.tokenizer.vocab_size,
reasoning_parser="gpt_oss")
sampling_params.detokenize = False # Harmony adapter handles detokenization
disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params)
trace_headers = (None if raw_request is None else tracing.extract_trace_headers(raw_request.headers))
postproc_args = ChatCompletionPostprocArgs.from_request(request)
postproc_params = PostprocParams(
@@ -949,6 +939,8 @@ class OpenAIServer:
_postproc_params=postproc_params if self.postproc_worker_enabled else None,
streaming=bool(request.stream),
lora_request=request.lora_request,
disaggregated_params=disaggregated_params,
trace_headers=trace_headers,
)
postproc_args.request_id = promise.request_id
@@ -965,7 +957,7 @@ class OpenAIServer:
media_type="text/event-stream"
)
else:
response = await create_harmony_response(promise, postproc_params)
response = await self._create_chat_response(promise, postproc_params, raw_request, disaggregated_params)
return JSONResponse(response.model_dump())
except Exception as e:
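
Both the regular chat path and the harmony path now funnel through the shared _create_chat_response helper, which awaits the request and rejects disaggregated responses that come back without disaggregated_params on the first choice. A condensed, runnable sketch of that flow, assuming stand-in types (FakePromise/FakeResponse replace RequestOutput/ChatCompletionResponse purely for illustration):

import asyncio
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class FakeChoice:
    disaggregated_params: Optional[dict] = None

@dataclass
class FakeResponse:
    choices: List[FakeChoice] = field(default_factory=list)

@dataclass
class FakePromise:
    response: FakeResponse

    async def aresult(self) -> None:
        await asyncio.sleep(0)  # pretend to wait for the request to finish

async def create_chat_response(promise: FakePromise,
                               disaggregated_params: Optional[dict] = None) -> FakeResponse:
    await promise.aresult()
    chat_response = promise.response
    # A disaggregated request must hand its params back so the client can
    # forward them to the generation server.
    if disaggregated_params is not None and chat_response.choices[0].disaggregated_params is None:
        raise ValueError("disaggregated_params is not set in the response")
    return chat_response

ok = FakeResponse([FakeChoice(disaggregated_params={"request_type": "context_only"})])
print(asyncio.run(create_chat_response(FakePromise(ok), {"request_type": "context_only"})))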

View File

@@ -19,6 +19,7 @@ import re
import subprocess
import tempfile
import time
from collections import namedtuple
from dataclasses import dataclass
from typing import Callable
@@ -199,6 +200,9 @@ def get_test_config(test_desc, example_dir, test_root):
f"{test_configs_root}/disagg_config_ctxtp4_gentp4_deepseek_r1_v2_fp4_tllm.yaml"
),
"gpt_oss_120b_stress":
(4,
f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"),
"gpt_oss_120b_harmony":
(4,
f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"),
"cancel_stress_test":
@@ -248,6 +252,52 @@ def generate_worker_commands(model_path, config, server_config,
return worker_commands
ClientTestSet = namedtuple('ClientTestSet', [
'completion', 'completion_streaming', 'chat', 'chat_streaming',
'verify_completion', 'verify_streaming_completion', 'verify_chat',
'verify_streaming_chat'
])
def get_client_test_set(test_desc):
"""Get the set of client tests to run for a given test description."""
if test_desc == "tool_calls":
return ClientTestSet(completion=False,
completion_streaming=False,
chat=True,
chat_streaming=False,
verify_completion=False,
verify_streaming_completion=False,
verify_chat=False,
verify_streaming_chat=False)
if test_desc == "gpt_oss_120b_harmony":
return ClientTestSet(completion=True,
completion_streaming=True,
chat=True,
chat_streaming=True,
verify_completion=True,
verify_streaming_completion=True,
verify_chat=False,
verify_streaming_chat=False)
if test_desc in ("overlap", "trtllm_sampler"):
return ClientTestSet(completion=True,
completion_streaming=True,
chat=True,
chat_streaming=True,
verify_completion=True,
verify_streaming_completion=True,
verify_chat=True,
verify_streaming_chat=False)
return ClientTestSet(completion=True,
completion_streaming=True,
chat=False,
chat_streaming=False,
verify_completion=True,
verify_streaming_completion=True,
verify_chat=False,
verify_streaming_chat=False)
def run_client_tests(example_dir,
config_file,
test_desc,
@@ -259,8 +309,12 @@ def run_client_tests(example_dir,
server_url,
workers_proc,
server_proc,
use_ray=False):
use_ray=False,
client_test_set=None):
"""Run client tests against the disaggregated server."""
if client_test_set is None:
client_test_set = get_client_test_set(test_desc)
client_dir = f"{example_dir}/clients"
for _ in range(num_iters):
client_cmd = [
@@ -272,8 +326,6 @@ def run_client_tests(example_dir,
if prompt_file == "long_prompts.json":
# Use max_tokens 4 for long prompts to reduce test time
client_cmd.extend(['--max-tokens', '4'])
if test_desc == "tool_calls":
client_cmd.extend(['-e', 'chat', '-o', 'output_tool_calls.json'])
# Prepare poll processes
worker_processes = []
@@ -284,33 +336,34 @@ def run_client_tests(example_dir,
worker_processes = [workers_proc]
poll_procs = worker_processes + [server_proc]
check_call(client_cmd, env=env, poll_procs=poll_procs)
# tool calls test does not need to run streaming and completion
if test_desc == "tool_calls":
continue
# Run completion test (non-streaming)
if client_test_set.completion:
check_call(client_cmd, env=env, poll_procs=poll_procs)
# Streaming client run
streaming_client_cmd = client_cmd + [
'--streaming', '-o', 'output_streaming.json'
]
check_call(streaming_client_cmd, env=env, poll_procs=poll_procs)
# Run the chat completion endpoint test only for TinyLlama
if test_desc == "overlap" or test_desc == "trtllm_sampler":
chat_client_cmd = client_cmd + [
'-e', 'chat', '-o', 'output_chat.json'
# Run streaming completion test
if client_test_set.completion_streaming:
streaming_client_cmd = client_cmd + [
'--streaming', '-o', 'output_streaming.json'
]
check_call(streaming_client_cmd, env=env, poll_procs=poll_procs)
# Run chat completion test
if client_test_set.chat:
chat_output = 'output_tool_calls.json' if test_desc == "tool_calls" else 'output_chat.json'
chat_client_cmd = client_cmd + ['-e', 'chat', '-o', chat_output]
check_call(chat_client_cmd, env=env, poll_procs=poll_procs)
streaming_chat_client_cmd = chat_client_cmd + [
'--streaming', '-o', 'output_streaming_chat.json'
# Run streaming chat completion test
if client_test_set.chat_streaming:
streaming_chat_client_cmd = client_cmd + [
'-e', 'chat', '--streaming', '-o', 'output_streaming_chat.json'
]
check_call(streaming_chat_client_cmd,
env=env,
poll_procs=poll_procs)
# Skip output verification for long prompts test
# Skip output verification for long prompts or tool call tests
if prompt_file == "long_prompts.json" or prompt_file == "tool_call_prompts.json":
continue
@@ -320,11 +373,16 @@ def run_client_tests(example_dir,
# Verify outputs
not_expected_strings = ["Berlin Berlin"]
output_files = ['output.json', 'output_streaming.json']
if test_desc == "overlap" or test_desc == "trtllm_sampler":
# Disable streaming chat completion for overlap test
# due to bug
output_files.extend(['output_chat.json'])
output_files = []
if client_test_set.completion and client_test_set.verify_completion:
output_files.append('output.json')
if client_test_set.completion_streaming and client_test_set.verify_streaming_completion:
output_files.append('output_streaming.json')
if client_test_set.chat and client_test_set.verify_chat:
# Streaming chat completion output not verified due to known bug
output_files.append('output_chat.json')
if client_test_set.chat_streaming and client_test_set.verify_streaming_chat:
output_files.append('output_streaming_chat.json')
if test_desc.startswith("gen_only"):
continue
@@ -336,6 +394,11 @@ def run_client_tests(example_dir,
expected_strings = [
"Berlin", ["Asyncio is a", "Asyncio module in"]
]
elif "gpt_oss_120b" in test_desc:
expected_strings = [
"The capital of Germany is Berlin",
"Using `asyncio` in Python"
]
else:
expected_strings = [
"The capital of Germany is Berlin",
@@ -2086,6 +2149,27 @@ def test_disaggregated_deepseek_v3_lite_bf16_tllm_gen_helix(
prompt_file="long_prompts.json")
@skip_pre_blackwell
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("model_path", ['gpt_oss/gpt-oss-120b'])
def test_disaggregated_gpt_oss_120b_harmony(disaggregated_test_root,
disaggregated_example_root,
llm_venv, model_path):
model_dir = f"{llm_models_root()}/{model_path}"
src_dst_dict = {
model_dir: f"{llm_venv.get_working_directory()}/{model_path}",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)
run_disaggregated_test(disaggregated_example_root,
"gpt_oss_120b_harmony",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
@pytest.mark.timeout(12600)
@pytest.mark.parametrize("test_config", [
pytest.param(TestConfig(model_path='DeepSeek-R1/DeepSeek-R1-0528-FP4-v2',

View File

@@ -26,6 +26,7 @@ l0_dgx_b200:
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=True]
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
- disaggregated/test_disaggregated.py::test_disaggregated_gpt_oss_120b_harmony[gpt_oss/gpt-oss-120b]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4]
- accuracy/test_llm_api_pytorch.py::TestMiniMaxM2::test_4gpus[attention_dp=False-cuda_graph=True-overlap_scheduler=True-tp_size=4-ep_size=4] TIMEOUT (60)
- condition: