mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-04 18:21:52 +08:00
Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
This commit is contained in:
parent
5d7a5e6800
commit
ce37e27066
@ -1521,10 +1521,10 @@ class HarmonyAdapter:
|
||||
return True
|
||||
|
||||
|
||||
_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None
|
||||
_SERVE_HARMONY_ADAPTER: HarmonyAdapter | None = None
|
||||
|
||||
|
||||
def get_harmony_adapter():
|
||||
def get_harmony_adapter() -> HarmonyAdapter:
|
||||
global _SERVE_HARMONY_ADAPTER
|
||||
if _SERVE_HARMONY_ADAPTER is None:
|
||||
_SERVE_HARMONY_ADAPTER = HarmonyAdapter()
|
||||
@ -1535,8 +1535,8 @@ def get_harmony_adapter():
|
||||
def handle_streaming_response(tools: List[ChatCompletionToolsParam],
|
||||
tool_choice: str, result: GenerationResult,
|
||||
model: str, request_id: str, done: bool,
|
||||
num_prompt_tokens: int) -> List[str]:
|
||||
first_iteration = True
|
||||
num_prompt_tokens: int,
|
||||
first_iteration: bool) -> List[str]:
|
||||
output = result.outputs[0]
|
||||
|
||||
# Convert tools to dictionary format for harmony adapter (standard pattern)
|
||||
|
||||
@ -873,11 +873,12 @@ class OpenAIServer:
|
||||
return chat_response
|
||||
|
||||
async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
|
||||
if not self.postproc_worker_enabled:
|
||||
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
|
||||
|
||||
async for res in promise:
|
||||
pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
|
||||
if not self.postproc_worker_enabled:
|
||||
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
|
||||
pp_results = post_processor(res, args)
|
||||
else:
|
||||
pp_results = res.outputs[0]._postprocess_result
|
||||
for pp_res in pp_results:
|
||||
yield pp_res
|
||||
|
||||
|
||||
@ -593,7 +593,9 @@ def chat_harmony_streaming_post_processor(
|
||||
request_id=args.request_id,
|
||||
done=rsp._done,
|
||||
num_prompt_tokens=args.num_prompt_tokens,
|
||||
first_iteration=args.first_iteration,
|
||||
)
|
||||
args.first_iteration = False
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@ -161,15 +161,29 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str):
|
||||
}],
|
||||
stream=True,
|
||||
)
|
||||
collected_chunks = []
|
||||
collected_messages = []
|
||||
first_iteration = True
|
||||
async for chunk in response:
|
||||
# Last streaming response will only contain usage info
|
||||
if len(chunk.choices) <= 0:
|
||||
continue
|
||||
|
||||
collected_chunks.append(chunk)
|
||||
collected_messages.append(chunk.choices[0].delta)
|
||||
if chunk.choices:
|
||||
delta = chunk.choices[0].delta
|
||||
if first_iteration:
|
||||
assert delta.role == "assistant", ValueError(
|
||||
"Expected role 'assistant' for first iteration")
|
||||
collected_messages.append(delta)
|
||||
first_iteration = False
|
||||
else:
|
||||
assert delta.role is None, ValueError(
|
||||
"Expected no role except for first iteration")
|
||||
collected_messages.append(delta)
|
||||
else:
|
||||
assert hasattr(chunk, "usage"), ValueError(
|
||||
"Expected usage info in last streaming response")
|
||||
assert chunk.usage.prompt_tokens is not None
|
||||
assert chunk.usage.completion_tokens is not None
|
||||
assert chunk.usage.total_tokens is not None
|
||||
assert chunk.usage.prompt_tokens > 0
|
||||
assert chunk.usage.completion_tokens > 0
|
||||
assert chunk.usage.total_tokens > 0
|
||||
|
||||
full_response = "".join([
|
||||
m.content for m in collected_messages
|
||||
|
||||
@ -362,6 +362,7 @@ class TestHandleStreamingResponse:
|
||||
request_id=request_id,
|
||||
done=False, # Not done yet, still streaming
|
||||
num_prompt_tokens=10,
|
||||
first_iteration=True,
|
||||
)
|
||||
|
||||
# CRITICAL ASSERTION: result.abort() should be called
|
||||
@ -424,6 +425,7 @@ class TestHandleStreamingResponse:
|
||||
request_id=request_id,
|
||||
done=False,
|
||||
num_prompt_tokens=10,
|
||||
first_iteration=True,
|
||||
)
|
||||
|
||||
# CRITICAL ASSERTION: result.abort() should NOT be called
|
||||
|
||||
Loading…
Reference in New Issue
Block a user