mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-16 07:53:55 +08:00

[TRTLLM-10866][feat] implement disaggregated harmony chat (#11336)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>

This commit is contained in:
parent 100bfdc516
commit e719721a60
@@ -54,12 +54,18 @@ async def send_request(session, server_host, server_port, model, prompt,
                 if line.startswith("data: "):
                     line = line[len("data: "):]
                     response_json = json.loads(line)
-                    text += response_json["choices"][0]["text"]
+                    choices = response_json.get("choices", [])
+                    if not choices:
+                        continue
+                    text += choices[0].get("text", "")
             logging.info(text)
             return text
         else:
             response_json = await response.json()
-            text = response_json["choices"][0]["text"]
+            choices = response_json.get("choices", [])
+            if not choices:
+                raise ValueError("Missing choices in completion response")
+            text = choices[0].get("text", "")
             logging.info(text)
             return text
@@ -100,14 +106,21 @@ async def send_chat_request(session, server_host, server_port, model, prompt,
                 if line.startswith("data: "):
                     line = line[len("data: "):]
                     response_json = json.loads(line)
-                    if "content" in response_json["choices"][0]["delta"]:
-                        text += response_json["choices"][0]["delta"][
-                            "content"]
+                    choices = response_json.get("choices", [])
+                    if not choices:
+                        continue
+                    delta = choices[0].get("delta", {})
+                    content = delta.get("content")
+                    if content is not None:
+                        text += content
             logging.info(text)
             return text
         else:
             response_json = await response.json()
-            text = response_json["choices"][0]["message"]["content"]
+            choices = response_json.get("choices", [])
+            if not choices:
+                raise ValueError("Missing choices in chat completion response")
+            text = choices[0].get("message", {}).get("content", "")
             logging.info(text)
             return text
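The two client hunks above make the example clients tolerant of streamed chunks that carry no choices, such as usage-only or keep-alive chunks from a disaggregated deployment. Below is a minimal, self-contained sketch of the same defensive pattern; the sample payloads are illustrative and not captured from a real server.

import json


def extract_stream_text(sse_line: str) -> str:
    """Return the text delta carried by one SSE line, or '' if it has none."""
    if not sse_line.startswith("data: "):
        return ""
    payload = sse_line[len("data: "):]
    if payload.strip() == "[DONE]":
        return ""
    chunk = json.loads(payload)
    choices = chunk.get("choices", [])
    if not choices:  # e.g. a usage-only chunk
        return ""
    # Chat streams put text under delta["content"]; completion streams use "text".
    delta = choices[0].get("delta", {})
    return delta.get("content") or choices[0].get("text", "") or ""


# Illustrative chunks (shapes assumed for the sketch):
assert extract_stream_text('data: {"choices": [{"delta": {"content": "Berlin"}}]}') == "Berlin"
assert extract_stream_text('data: {"choices": [], "usage": {"total_tokens": 42}}') == ""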
@@ -25,7 +25,7 @@ from .openai_protocol import (ChatCompletionMessageParam,
                               ChatCompletionStreamResponse,
                               ChatCompletionToolsParam, ChatMessage,
                               DeltaFunctionCall, DeltaMessage, DeltaToolCall,
-                              UsageInfo)
+                              UsageInfo, to_disaggregated_params)

 # yapf: enable
@@ -1644,18 +1644,35 @@ def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
         tools_for_parser = tools_dict

     output = outputs[0]
-    parsed_output = harmony_adapter.harmony_output_to_openai(
-        output.token_ids, tools_for_parser, tool_choice)
+    disaggregated_params = output.disaggregated_params

-    # CONVERTED OUTPUT (after harmony to openai conversion)
-    logger.debug(f"✅ CONVERTED OUTPUT: {json.dumps(parsed_output, indent=2)}")
+    response_message = {}
+    finish_reason = output.finish_reason
+    usage_info = None
+    # skip harmony parsing for context only requests
+    if disaggregated_params is None or disaggregated_params.request_type != "context_only":
+        parsed_output = harmony_adapter.harmony_output_to_openai(
+            output.token_ids, tools_for_parser, tool_choice)

-    # Create response message
-    response_message = _create_response_message(parsed_output)
+        # CONVERTED OUTPUT (after harmony to openai conversion)
+        logger.debug(
+            f"✅ CONVERTED OUTPUT: {json.dumps(parsed_output, indent=2)}")

-    # Determine finish reason
-    finish_reason = _determine_finish_reason(parsed_output,
-                                             output.finish_reason)
+        # Create response message
+        response_message = _create_response_message(parsed_output)

+        # Determine finish reason
+        finish_reason = _determine_finish_reason(parsed_output,
+                                                 output.finish_reason)
+        # Optional: Log if harmony parsing failed (for debugging)
+        if parsed_output.get('_harmony_parsing_failed'):
+            logger.warning(
+                f"⚠️ Harmony parsing fell back to raw text decoding, {parsed_output}"
+            )
+    else:
+        # Context only requests don't need a full response message;
+        # the real response will be produced by the generation server
+        response_message = {"role": "assistant", "content": ""}

     # Create usage info from metrics (RequestOutput doesn't have usage in v1)
     usage_info = _create_usage_info(num_prompt_tokens, outputs)
@@ -1667,14 +1684,12 @@ def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
             ChatCompletionResponseChoice(
                 index=0,
                 message=ChatMessage(**response_message),
-                finish_reason=finish_reason)
+                finish_reason=finish_reason,
+                disaggregated_params=to_disaggregated_params(
+                    output.disaggregated_params))
         ],
         usage=usage_info,
     )
-    # Optional: Log if harmony parsing failed (for debugging)
-    if parsed_output.get('_harmony_parsing_failed'):
-        logger.warning("⚠️ Harmony parsing fell back to raw text decoding")
     logger.debug(f"response\n\n{response}\n")

     return response
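The restructured handler above runs the Harmony-to-OpenAI conversion only when the request is not a context-only prefill; for context-only requests it returns an empty assistant message and leaves the real content to the generation server. A stripped-down sketch of just that guard, using a stand-in dataclass for the disaggregated params object referenced in the diff:

from dataclasses import dataclass
from typing import Optional


@dataclass
class DisaggParams:
    # Stand-in for the LlmDisaggregatedParams attribute used above.
    request_type: str  # e.g. "context_only" or "generation_only"


def needs_harmony_parsing(disagg: Optional[DisaggParams]) -> bool:
    """Mirror the guard in handle_non_streaming_response: parse Harmony
    output unless this is a context-only prefill request."""
    return disagg is None or disagg.request_type != "context_only"


assert needs_harmony_parsing(None)                              # aggregated serving
assert not needs_harmony_parsing(DisaggParams("context_only"))  # prefill worker skips parsing
assert needs_harmony_parsing(DisaggParams("generation_only"))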
@@ -198,7 +198,7 @@ class OpenAIServer:

         @self.app.exception_handler(RequestValidationError)
         async def validation_exception_handler(_, exc):
-            return self.create_error_response(message=str(exc))
+            return JSONResponse(status_code=400, content={"error": str(exc)})

         if self.server_role is not ServerRole.MM_ENCODER:
             self.register_routes()
@@ -493,6 +493,21 @@ class OpenAIServer:
         async with self.perf_metrics_lock:
             self.perf_metrics.append(item)

+    async def _create_chat_response(self,
+                                    promise: RequestOutput, postproc_params: PostprocParams, raw_request: Request, disaggregated_params: Optional[LlmDisaggregatedParams] = None) -> ChatCompletionResponse:
+        await promise.aresult()
+        if self.postproc_worker_enabled:
+            chat_response = promise.outputs[0]._postprocess_result
+        else:
+            post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+            chat_response = post_processor(promise, args)
+
+        if disaggregated_params is not None and chat_response.choices[0].disaggregated_params is None:
+            raise ValueError(f"disaggregated_params is not set in the response for request"
+                             f" {disaggregated_params.disagg_request_id}")
+
+        return chat_response
+
     async def openai_chat(self, request: ChatCompletionRequest, raw_request: Request) -> Response:

         def get_role() -> str:
@@ -525,22 +540,6 @@ class OpenAIServer:
             logger.error(traceback.format_exc())
             raise

-        async def create_chat_response(
-                promise: RequestOutput, postproc_params: PostprocParams, disaggregated_params: Optional[LlmDisaggregatedParams] = None) -> ChatCompletionResponse:
-            await promise.aresult()
-            if self.postproc_worker_enabled:
-                chat_response = promise.outputs[0]._postprocess_result
-            else:
-                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
-                chat_response = post_processor(promise, args)
-
-            # Add prompt_tokens_ids to the response
-            if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only":
-                chat_response.prompt_token_ids = promise.prompt_token_ids
-            raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds()
-            await self._extract_metrics(promise, raw_request)
-            return chat_response
-
         try:
             conversation: List[ConversationMessage] = []
             tool_dicts = None if request.tools is None else [
@@ -617,7 +616,7 @@ class OpenAIServer:
                 return StreamingResponse(content=response_generator,
                                          media_type="text/event-stream")
             else:
-                response = await create_chat_response(promise, postproc_params, disaggregated_params)
+                response = await self._create_chat_response(promise, postproc_params, raw_request, disaggregated_params)
                 return JSONResponse(content=response.model_dump())
         except CppExecutorError:
             logger.error(traceback.format_exc())
@@ -872,17 +871,6 @@ class OpenAIServer:
         Supports both streaming and non-streaming modes.
         """

-        async def create_harmony_response(
-                promise: RequestOutput, postproc_params: PostprocParams) -> ChatCompletionResponse:
-            await promise.aresult()
-            if self.postproc_worker_enabled:
-                chat_response = promise.outputs[0]._postprocess_result
-            else:
-                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
-                chat_response = post_processor(promise, args)
-
-            return chat_response
-
         async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
             async for res in promise:
                 if not self.postproc_worker_enabled:
@@ -934,6 +922,8 @@ class OpenAIServer:
                 vocab_size=self.tokenizer.tokenizer.vocab_size,
                 reasoning_parser="gpt_oss")
             sampling_params.detokenize = False  # Harmony adapter handles detokenization
+            disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params)
+            trace_headers = (None if raw_request is None else tracing.extract_trace_headers(raw_request.headers))

             postproc_args = ChatCompletionPostprocArgs.from_request(request)
             postproc_params = PostprocParams(
@@ -949,6 +939,8 @@ class OpenAIServer:
                 _postproc_params=postproc_params if self.postproc_worker_enabled else None,
                 streaming=bool(request.stream),
                 lora_request=request.lora_request,
+                disaggregated_params=disaggregated_params,
+                trace_headers=trace_headers,
             )
             postproc_args.request_id = promise.request_id

@@ -965,7 +957,7 @@ class OpenAIServer:
                     media_type="text/event-stream"
                 )
             else:
-                response = await create_harmony_response(promise, postproc_params)
+                response = await self._create_chat_response(promise, postproc_params, raw_request, disaggregated_params)
                 return JSONResponse(response.model_dump())

         except Exception as e:
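Taken together, the server changes let a context-only /v1/chat/completions request come back with disaggregated_params attached to its first choice, which can then be forwarded with the follow-up generation-only request. A rough end-to-end sketch follows; the endpoint URLs, model name, and JSON field shapes are assumptions for illustration (in a real deployment the disaggregated router performs this hand-off), not a documented client API.

import requests

CTX_URL = "http://ctx-server:8000/v1/chat/completions"  # hypothetical context (prefill) worker
GEN_URL = "http://gen-server:8001/v1/chat/completions"  # hypothetical generation worker

messages = [{"role": "user", "content": "What is the capital of Germany?"}]

# Step 1: context-only request; the prefill worker returns hand-off state
# (disaggregated_params) on the choice instead of a full answer.
ctx_resp = requests.post(CTX_URL,
                         json={
                             "model": "gpt-oss-120b",
                             "messages": messages,
                             "max_tokens": 64,
                             "disaggregated_params": {"request_type": "context_only"},
                         }).json()
handoff = ctx_resp["choices"][0]["disaggregated_params"]  # populated by this commit

# Step 2: generation-only request reuses the prefill state on the generation worker.
gen_resp = requests.post(GEN_URL,
                         json={
                             "model": "gpt-oss-120b",
                             "messages": messages,
                             "max_tokens": 64,
                             "disaggregated_params": {**handoff, "request_type": "generation_only"},
                         }).json()
print(gen_resp["choices"][0]["message"]["content"])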
@@ -19,6 +19,7 @@ import re
 import subprocess
 import tempfile
 import time
+from collections import namedtuple
 from dataclasses import dataclass
 from typing import Callable
@@ -199,6 +200,9 @@ def get_test_config(test_desc, example_dir, test_root):
             f"{test_configs_root}/disagg_config_ctxtp4_gentp4_deepseek_r1_v2_fp4_tllm.yaml"
         ),
         "gpt_oss_120b_stress":
         (4,
          f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"),
+        "gpt_oss_120b_harmony":
+        (4,
+         f"{test_configs_root}/disagg_config_ctxtp2_gentp2_gptoss_tllm.yaml"),
         "cancel_stress_test":
@@ -248,6 +252,52 @@ def generate_worker_commands(model_path, config, server_config,
     return worker_commands


+ClientTestSet = namedtuple('ClientTestSet', [
+    'completion', 'completion_streaming', 'chat', 'chat_streaming',
+    'verify_completion', 'verify_streaming_completion', 'verify_chat',
+    'verify_streaming_chat'
+])
+
+
+def get_client_test_set(test_desc):
+    """Get the set of client tests to run for a given test description."""
+    if test_desc == "tool_calls":
+        return ClientTestSet(completion=False,
+                             completion_streaming=False,
+                             chat=True,
+                             chat_streaming=False,
+                             verify_completion=False,
+                             verify_streaming_completion=False,
+                             verify_chat=False,
+                             verify_streaming_chat=False)
+    if test_desc == "gpt_oss_120b_harmony":
+        return ClientTestSet(completion=True,
+                             completion_streaming=True,
+                             chat=True,
+                             chat_streaming=True,
+                             verify_completion=True,
+                             verify_streaming_completion=True,
+                             verify_chat=False,
+                             verify_streaming_chat=False)
+    if test_desc in ("overlap", "trtllm_sampler"):
+        return ClientTestSet(completion=True,
+                             completion_streaming=True,
+                             chat=True,
+                             chat_streaming=True,
+                             verify_completion=True,
+                             verify_streaming_completion=True,
+                             verify_chat=True,
+                             verify_streaming_chat=False)
+    return ClientTestSet(completion=True,
+                         completion_streaming=True,
+                         chat=False,
+                         chat_streaming=False,
+                         verify_completion=True,
+                         verify_streaming_completion=True,
+                         verify_chat=False,
+                         verify_streaming_chat=False)
+
+
 def run_client_tests(example_dir,
                      config_file,
                      test_desc,
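For reference, the selection table above decides which clients run and which outputs get string-verified for each disaggregated test description. A quick sanity sketch of how it is consumed (pure function, no servers required, assuming the test module above is importable):

harmony = get_client_test_set("gpt_oss_120b_harmony")
assert harmony.completion and harmony.completion_streaming
assert harmony.chat and harmony.chat_streaming
# Chat outputs are exercised but not string-verified for the harmony flow.
assert not harmony.verify_chat and not harmony.verify_streaming_chat

default = get_client_test_set("some_other_desc")  # any description without a dedicated entry
assert default.completion and not default.chat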
@@ -259,8 +309,12 @@ def run_client_tests(example_dir,
                      server_url,
                      workers_proc,
                      server_proc,
-                     use_ray=False):
+                     use_ray=False,
+                     client_test_set=None):
+    """Run client tests against the disaggregated server."""
+    if client_test_set is None:
+        client_test_set = get_client_test_set(test_desc)
+
     client_dir = f"{example_dir}/clients"
     for _ in range(num_iters):
         client_cmd = [
@@ -272,8 +326,6 @@ def run_client_tests(example_dir,
         if prompt_file == "long_prompts.json":
             # Use max_tokens 4 for long prompts to reduce test time
             client_cmd.extend(['--max-tokens', '4'])
-        if test_desc == "tool_calls":
-            client_cmd.extend(['-e', 'chat', '-o', 'output_tool_calls.json'])

         # Prepare poll processes
         worker_processes = []
@@ -284,33 +336,34 @@ def run_client_tests(example_dir,
             worker_processes = [workers_proc]

         poll_procs = worker_processes + [server_proc]
-        check_call(client_cmd, env=env, poll_procs=poll_procs)
-
-        # tool calls test does not need to run streaming and completion
-        if test_desc == "tool_calls":
-            continue
+        # Run completion test (non-streaming)
+        if client_test_set.completion:
+            check_call(client_cmd, env=env, poll_procs=poll_procs)

-        # Streaming client run
-        streaming_client_cmd = client_cmd + [
-            '--streaming', '-o', 'output_streaming.json'
-        ]
-        check_call(streaming_client_cmd, env=env, poll_procs=poll_procs)
-
-        # Run the chat completion endpoint test only for TinyLlama
-        if test_desc == "overlap" or test_desc == "trtllm_sampler":
-            chat_client_cmd = client_cmd + [
-                '-e', 'chat', '-o', 'output_chat.json'
-            ]
+        # Run streaming completion test
+        if client_test_set.completion_streaming:
+            streaming_client_cmd = client_cmd + [
+                '--streaming', '-o', 'output_streaming.json'
+            ]
+            check_call(streaming_client_cmd, env=env, poll_procs=poll_procs)
+
+        # Run chat completion test
+        if client_test_set.chat:
+            chat_output = 'output_tool_calls.json' if test_desc == "tool_calls" else 'output_chat.json'
+            chat_client_cmd = client_cmd + ['-e', 'chat', '-o', chat_output]
             check_call(chat_client_cmd, env=env, poll_procs=poll_procs)

-            streaming_chat_client_cmd = chat_client_cmd + [
-                '--streaming', '-o', 'output_streaming_chat.json'
+        # Run streaming chat completion test
+        if client_test_set.chat_streaming:
+            streaming_chat_client_cmd = client_cmd + [
+                '-e', 'chat', '--streaming', '-o', 'output_streaming_chat.json'
             ]
             check_call(streaming_chat_client_cmd,
                        env=env,
                        poll_procs=poll_procs)

-        # Skip output verification for long prompts test
+        # Skip output verification for long prompts or tool call tests
         if prompt_file == "long_prompts.json" or prompt_file == "tool_call_prompts.json":
             continue
@@ -320,11 +373,16 @@ def run_client_tests(example_dir,
         # Verify outputs
         not_expected_strings = ["Berlin Berlin"]

-        output_files = ['output.json', 'output_streaming.json']
-        if test_desc == "overlap" or test_desc == "trtllm_sampler":
-            # Disable streaming chat completion for overlap test
-            # due to bug
-            output_files.extend(['output_chat.json'])
+        output_files = []
+        if client_test_set.completion and client_test_set.verify_completion:
+            output_files.append('output.json')
+        if client_test_set.completion_streaming and client_test_set.verify_streaming_completion:
+            output_files.append('output_streaming.json')
+        if client_test_set.chat and client_test_set.verify_chat:
+            # Streaming chat completion output not verified due to known bug
+            output_files.append('output_chat.json')
+        if client_test_set.chat_streaming and client_test_set.verify_streaming_chat:
+            output_files.append('output_streaming_chat.json')

         if test_desc.startswith("gen_only"):
             continue
@@ -336,6 +394,11 @@ def run_client_tests(example_dir,
             expected_strings = [
                 "Berlin", ["Asyncio is a", "Asyncio module in"]
             ]
+        elif "gpt_oss_120b" in test_desc:
+            expected_strings = [
+                "The capital of Germany is Berlin",
+                "Using `asyncio` in Python"
+            ]
         else:
             expected_strings = [
                 "The capital of Germany is Berlin",
@@ -2086,6 +2149,27 @@ def test_disaggregated_deepseek_v3_lite_bf16_tllm_gen_helix(
                            prompt_file="long_prompts.json")


+@skip_pre_blackwell
+@pytest.mark.skip_less_device(4)
+@pytest.mark.parametrize("model_path", ['gpt_oss/gpt-oss-120b'])
+def test_disaggregated_gpt_oss_120b_harmony(disaggregated_test_root,
+                                            disaggregated_example_root,
+                                            llm_venv, model_path):
+    model_dir = f"{llm_models_root()}/{model_path}"
+    src_dst_dict = {
+        model_dir: f"{llm_venv.get_working_directory()}/{model_path}",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    run_disaggregated_test(disaggregated_example_root,
+                           "gpt_oss_120b_harmony",
+                           env=llm_venv._new_env,
+                           cwd=llm_venv.get_working_directory())
+
+
 @pytest.mark.timeout(12600)
 @pytest.mark.parametrize("test_config", [
     pytest.param(TestConfig(model_path='DeepSeek-R1/DeepSeek-R1-0528-FP4-v2',
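A hedged note on running the new case locally: the repository-relative test path below and the environment expectations (a cached gpt_oss/gpt-oss-120b checkpoint under llm_models_root() and at least four Blackwell GPUs, per the markers above) are assumptions inferred from this diff rather than documented switches.

import pytest

# Select only the new harmony test via pytest's programmatic entry point.
exit_code = pytest.main([
    "tests/integration/defs/disaggregated/test_disaggregated.py",  # assumed path
    "-k", "test_disaggregated_gpt_oss_120b_harmony",
])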
@@ -26,6 +26,7 @@ l0_dgx_b200:
       - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=True]
       - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
       - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
+      - disaggregated/test_disaggregated.py::test_disaggregated_gpt_oss_120b_harmony[gpt_oss/gpt-oss-120b]
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4]
       - accuracy/test_llm_api_pytorch.py::TestMiniMaxM2::test_4gpus[attention_dp=False-cuda_graph=True-overlap_scheduler=True-tp_size=4-ep_size=4] TIMEOUT (60)
   - condition: