TensorRT-LLM/examples/disaggregated/clients/disagg_client.py
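
"""Example client for a disaggregated TensorRT-LLM serving deployment.

Reads prompts from a JSON file, sends each one concurrently to the server's
OpenAI-compatible completions or chat endpoint, and writes the collected
responses to a JSON output file.
"""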

import argparse
import asyncio
import json
import logging
import time

import aiohttp
import yaml

logging.basicConfig(level=logging.INFO)


async def wait_for_server(session, server_host, server_port, timeout):
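    """Poll the server's /health endpoint until it returns 200 or `timeout` seconds elapse."""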
url = f"http://{server_host}:{server_port}/health"
start_time = time.time()
logging.info("Waiting for server to start")
while time.time() - start_time < timeout:
try:
async with session.get(url) as response:
if response.status == 200:
logging.info("Server is ready.")
return
except aiohttp.ClientError:
pass
await asyncio.sleep(1)
raise Exception("Server did not become ready in time.")


async def send_request(session, server_host, server_port, model, prompt,
                       max_tokens, temperature, streaming, ignore_eos):
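    """Send one request to /v1/completions and return the generated text."""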
url = f"http://{server_host}:{server_port}/v1/completions"
headers = {"Content-Type": "application/json"}
data = {
"model": model,
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"ignore_eos": ignore_eos
}
if streaming:
data["stream"] = True
    async with session.post(url, headers=headers, json=data) as response:
        if response.status != 200:
            raise Exception(f"Error: {await response.text()}")
        if streaming:
            text = ""
            async for line in response.content:
                if line:
                    line = line.decode('utf-8').strip()
                    if line == "data: [DONE]":
                        break
                    if line.startswith("data: "):
                        line = line[len("data: "):]
                        response_json = json.loads(line)
                        text += response_json["choices"][0]["text"]
            logging.info(text)
            return text
        else:
            response_json = await response.json()
            text = response_json["choices"][0]["text"]
            logging.info(text)
            return text


async def send_chat_request(session, server_host, server_port, model, prompt,
                            max_tokens, temperature, streaming):
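    """Send one request to /v1/chat/completions and return the generated text."""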
url = f"http://{server_host}:{server_port}/v1/chat/completions"
headers = {"Content-Type": "application/json"}
data = {
"model":
model,
"messages": [{
"role": "system",
"content": "You are a helpfule assistant."
}, {
"role": "user",
"content": prompt
}],
"max_tokens":
max_tokens,
"temperature":
temperature
}
if streaming:
data["stream"] = True
    async with session.post(url, headers=headers, json=data) as response:
        if response.status != 200:
            raise Exception(f"Error: {await response.text()}")
        if streaming:
            text = ""
            async for line in response.content:
                if line:
                    line = line.decode('utf-8').strip()
                    if line == "data: [DONE]":
                        break
                    if line.startswith("data: "):
                        line = line[len("data: "):]
                        response_json = json.loads(line)
                        delta = response_json["choices"][0]["delta"]
                        if "content" in delta:
                            text += delta["content"]
            logging.info(text)
            return text
        else:
            response_json = await response.json()
            text = response_json["choices"][0]["message"]["content"]
            logging.info(text)
            return text


async def main():
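    """Parse arguments, load config and prompts, and run all requests concurrently."""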
    parser = argparse.ArgumentParser()
    parser.add_argument("-c",
                        "--disagg_config_file",
                        help="Path to YAML config file",
                        required=True)
    parser.add_argument("-p",
                        "--prompts-file",
                        help="Path to JSON file containing prompts",
                        required=True)
    parser.add_argument("--max-tokens",
                        type=int,
                        help="Max tokens to generate per request",
                        default=100)
    parser.add_argument("--temperature",
                        type=float,
                        help="Sampling temperature",
                        default=0.0)
    parser.add_argument("--server-start-timeout",
                        type=int,
                        help="Time in seconds to wait for the server to start",
                        default=None)
    parser.add_argument("-e",
                        "--endpoint",
                        type=str,
                        help="Endpoint to use: 'completions' or 'chat'",
                        default="completions")
    parser.add_argument("-o",
                        "--output-file",
                        type=str,
                        help="Output filename",
                        default="output.json")
    parser.add_argument("--streaming",
                        action="store_true",
                        help="Enable streaming responses")
    parser.add_argument("--ignore-eos",
                        action="store_true",
                        help="Ignore the end-of-sequence token")
    args = parser.parse_args()
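
    # The server's hostname, port, and model name come from the same YAML config
    # used to launch the disaggregated server; fall back to defaults if absent.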
    with open(args.disagg_config_file, "r") as file:
        config = yaml.safe_load(file)
    server_host = config.get('hostname', 'localhost')
    server_port = config.get('port', 8000)
    model = config.get('model', 'TinyLlama/TinyLlama-1.1B-Chat-v1.0')

    with open(args.prompts_file, "r") as file:
        prompts = json.load(file)
    async with aiohttp.ClientSession() as session:
        if args.server_start_timeout is not None:
            await wait_for_server(session, server_host, server_port,
                                  args.server_start_timeout)

        if args.endpoint == "completions":
            tasks = [
                send_request(session, server_host, server_port, model, prompt,
                             args.max_tokens, args.temperature, args.streaming,
                             args.ignore_eos) for prompt in prompts
            ]
        elif args.endpoint == "chat":
            tasks = [
                send_chat_request(session, server_host, server_port, model,
                                  prompt, args.max_tokens, args.temperature,
                                  args.streaming) for prompt in prompts
            ]
        else:
            # Fail fast on an unknown endpoint instead of hitting a NameError below.
            raise ValueError(f"Unsupported endpoint: {args.endpoint}")

        responses = await asyncio.gather(*tasks)

    with open(args.output_file, "w") as file:
        json.dump(responses, file, indent=2)


if __name__ == "__main__":
    asyncio.run(main())