Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Fix lost requests for disaggregated serving (#5815)
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
This commit is contained in:
  parent 155b19e6b0
  commit 435eeea8b6
@@ -68,7 +68,7 @@ class OpenAIDisaggServer:
         async def lifespan(app: FastAPI):
             # Create a persistent aiohttp ClientSession
             self.session = aiohttp.ClientSession(
-                connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, keepalive_timeout=300),
+                connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True),
                 timeout=aiohttp.ClientTimeout(total=req_timeout_secs))
 
             logger.info("Waiting for context and generation servers to be ready")
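The connector change above swaps keep-alive pooling for per-request connections; a plausible reading, given the commit title, is that requests were being written to pooled sockets the upstream context/generation servers had already closed. A minimal sketch of the two connector configurations, using a hypothetical post_json helper rather than the server's actual wiring:

import asyncio
import aiohttp

async def post_json(url: str, payload: dict, force_close: bool) -> dict:
    # force_close=False: keep-alive pooling -- sockets are reused, so a request
    # can be written to a connection the peer has already torn down.
    # force_close=True: a fresh connection per request, trading a little
    # latency for never reusing a stale socket.
    connector = aiohttp.TCPConnector(limit=0, limit_per_host=0,
                                     force_close=force_close)
    async with aiohttp.ClientSession(connector=connector) as session:
        async with session.post(url, json=payload) as resp:
            return await resp.json()

# Example (hypothetical endpoint):
# asyncio.run(post_json("http://localhost:8000/v1/completions",
#                       {"prompt": "hi"}, force_close=True))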
@@ -50,183 +50,196 @@ async def async_request_trt_llm(
 async def async_request_trt_llm(
     request_func_input: RequestFuncInput,
     streaming: bool = True,
     pbar: Optional[tqdm] = None,
+    session: Optional[aiohttp.ClientSession] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")
 
-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    request_session = aiohttp.ClientSession(
+        trust_env=True,
+        timeout=AIOHTTP_TIMEOUT,
+        connector=aiohttp.TCPConnector(
+            limit=0, limit_per_host=0)) if session is None else session
+
     payload = {
         "accumulate_tokens": True,
         "text_input": request_func_input.prompt,
         "temperature": 0.0,
         "top_p": 1.0,
         "max_tokens": request_func_input.output_len,
         "stream": streaming,
     }
     if request_func_input.ignore_eos:
         payload["min_length"] = request_func_input.output_len
     output = RequestFuncOutput()
     output.prompt_len = request_func_input.prompt_len
 
     ttft = 0.0
     st = time.perf_counter()
     most_recent_timestamp = st
     try:
-        async with session.post(url=api_url, json=payload) as response:
+        async with request_session.post(url=api_url, json=payload) as response:
             if response.status == 200:
                 output.success = True
                 if streaming:
                     async for chunk_bytes in response.content:
                         chunk_bytes = chunk_bytes.strip()
                         if not chunk_bytes:
                             continue
 
                         chunk = chunk_bytes.decode("utf-8").removeprefix(
                             "data:")
 
                         data = json.loads(chunk)
                         output.generated_text += data["text_output"]
                         timestamp = time.perf_counter()
                         # First token
                         if ttft == 0.0:
                             ttft = timestamp - st
                             output.ttft = ttft
 
                         # Decoding phase
                         else:
                             output.itl.append(timestamp - most_recent_timestamp)
 
                         most_recent_timestamp = timestamp
 
                     output.latency = most_recent_timestamp - st
                 else:
                     content = await response.content.read()
                     data = json.loads(content.decode())
                     output.ttft = -1
                     output.itl = []
                     output.generated_text = data["text_output"]
                     output.latency = time.perf_counter() - st
             else:
                 output.error = response.reason or ""
                 output.success = False
     except Exception:
         output.success = False
         exc_info = sys.exc_info()
         output.error = "".join(traceback.format_exception(*exc_info))
+    finally:
+        if session is None:
+            await request_session.close()
 
     if pbar:
         pbar.update(1)
     return output
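The request functions in this hunk all adopt the same ownership rule: accept an optional ClientSession, fall back to a private one when none is given, and close only the session the call itself created (the finally block), so a caller-supplied session survives for reuse across requests. A stripped-down sketch of that rule with hypothetical names (fetch, DEFAULT_TIMEOUT), not the benchmark's actual helpers:

from typing import Optional

import aiohttp

DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)  # assumed value

async def fetch(url: str,
                payload: dict,
                session: Optional[aiohttp.ClientSession] = None) -> str:
    # Reuse the caller's session if one was passed in; otherwise create a
    # private session just for this request.
    owned = session is None
    request_session = aiohttp.ClientSession(
        trust_env=True, timeout=DEFAULT_TIMEOUT) if owned else session
    try:
        async with request_session.post(url, json=payload) as resp:
            return await resp.text()
    finally:
        # Close only what this function created; a shared session stays open
        # so the caller can reuse it across many requests.
        if owned:
            await request_session.close()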
 async def async_request_openai_completions(
     request_func_input: RequestFuncInput,
     streaming: bool = True,
     pbar: Optional[tqdm] = None,
+    session: Optional[aiohttp.ClientSession] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(
         ("completions", "profile")
     ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
 
-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    request_session = aiohttp.ClientSession(
+        trust_env=True,
+        timeout=AIOHTTP_TIMEOUT,
+        connector=aiohttp.TCPConnector(
+            limit=0, limit_per_host=0)) if session is None else session
+
     payload = {
         "model": request_func_input.model_name \
             if request_func_input.model_name else request_func_input.model,
         "prompt": request_func_input.prompt,
         "temperature": 0.0,
         "repetition_penalty": 1.0,
         "max_tokens": request_func_input.output_len,
         "logprobs": request_func_input.logprobs,
         "stream": streaming,
     }
     if streaming:
         payload["stream_options"] = {"include_usage": True}
     if request_func_input.ignore_eos:
         payload["ignore_eos"] = request_func_input.ignore_eos
     if request_func_input.extra_body:
         payload.update(request_func_input.extra_body)
     headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
 
     output = RequestFuncOutput()
     output.prompt_len = request_func_input.prompt_len
 
     generated_text = ""
     st = time.perf_counter()
     most_recent_timestamp = st
     try:
-        async with session.post(url=api_url, json=payload,
-                                headers=headers) as response:
+        async with request_session.post(url=api_url,
+                                        json=payload,
+                                        headers=headers) as response:
             if response.status == 200:
                 if streaming:
                     first_chunk_received = False
                     async for chunk_bytes in response.content:
                         chunk_bytes = chunk_bytes.strip()
                         if not chunk_bytes:
                             continue
 
                         chunk = chunk_bytes.decode("utf-8").removeprefix(
                             "data: ")
                         if chunk != "[DONE]":
                             data = json.loads(chunk)
 
                             # NOTE: Some completion API might have a last
                             # usage summary response without a token so we
                             # want to check a token was generated
                             if choices := data.get("choices"):
                                 # Note that text could be empty here
                                 # e.g. for special tokens
                                 text = choices[0].get("text")
                                 timestamp = time.perf_counter()
                                 # First token
                                 if not first_chunk_received:
                                     first_chunk_received = True
                                     ttft = time.perf_counter() - st
                                     output.ttft = ttft
 
                                 # Decoding phase
                                 else:
                                     output.itl.append(timestamp -
                                                       most_recent_timestamp)
 
                                 most_recent_timestamp = timestamp
                                 generated_text += text or ""
                             elif usage := data.get("usage"):
                                 output.output_tokens = usage.get(
                                     "completion_tokens")
                     if first_chunk_received:
                         output.success = True
                     else:
                         output.success = False
                         output.error = (
                             "Never received a valid chunk to calculate TTFT."
                             "This response will be marked as failed!")
                     output.generated_text = generated_text
                     output.latency = most_recent_timestamp - st
                 else:
                     content = await response.content.read()
                     data = json.loads(content.decode())
                     generated_text = data["choices"][0]["text"]
+                    output.success = True
                     output.generated_text = generated_text
                     output.latency = time.perf_counter() - st
                     output.ttft = -1
                     output.itl = []
                     output.output_tokens = data["usage"]["completion_tokens"]
             else:
                 output.error = response.reason or ""
                 output.success = False
     except Exception:
         output.success = False
         exc_info = sys.exc_info()
         output.error = "".join(traceback.format_exception(*exc_info))
+    finally:
+        if session is None:
+            await request_session.close()
 
     if pbar:
         pbar.update(1)
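For reference, the streaming bookkeeping these request functions share: each non-empty server-sent "data:" line is one chunk, the first chunk sets the time-to-first-token (TTFT), and every later chunk contributes an inter-token-latency (ITL) sample. A condensed sketch of that logic, using a hypothetical StreamStats container in place of RequestFuncOutput:

import json
import time
from dataclasses import dataclass, field

@dataclass
class StreamStats:  # hypothetical stand-in for RequestFuncOutput
    ttft: float = 0.0
    itl: list = field(default_factory=list)
    text: str = ""

def consume_sse_line(raw: bytes, stats: StreamStats, start: float,
                     last: float) -> float:
    """Update stats from one SSE line; return the new `last` timestamp."""
    line = raw.strip().decode("utf-8").removeprefix("data: ")
    if not line or line == "[DONE]":
        return last
    data = json.loads(line)
    now = time.perf_counter()
    if stats.ttft == 0.0:
        stats.ttft = now - start          # first chunk -> time to first token
    else:
        stats.itl.append(now - last)      # later chunks -> inter-token latency
    if choices := data.get("choices"):
        stats.text += choices[0].get("text") or ""
    return now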
@@ -237,114 +250,121 @@ async def async_request_openai_chat_completions(
 async def async_request_openai_chat_completions(
     request_func_input: RequestFuncInput,
     streaming: bool = True,
     pbar: Optional[tqdm] = None,
+    session: Optional[aiohttp.ClientSession] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(
         ("chat/completions", "profile"
          )), "OpenAI Chat Completions API URL must end with 'chat/completions'."
 
-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
+    request_session = aiohttp.ClientSession(
+        trust_env=True,
+        timeout=AIOHTTP_TIMEOUT,
+        connector=aiohttp.TCPConnector(
+            limit=0, limit_per_host=0)) if session is None else session
+
     payload = {
         "model": request_func_input.model_name \
             if request_func_input.model_name else request_func_input.model,
         "messages": [
         ],
         "temperature": 0.0,
         "max_completion_tokens": request_func_input.output_len,
         "stream": streaming,
     }
 
     if isinstance(request_func_input.prompt, list) and all(
         [isinstance(i, int) for i in request_func_input.prompt]):
         payload["prompt_token_ids"] = request_func_input.prompt
     else:
         assert isinstance(request_func_input.prompt,
                           str), "Prompt must be a string or a list of integers"
         payload["messages"].append({
             "role":
             "user",
             "content": [{
                 "type": "text",
                 "text": request_func_input.prompt
             }]
         })
 
     if streaming:
         payload["stream_options"] = {"include_usage": True}
     if request_func_input.ignore_eos:
         payload["ignore_eos"] = request_func_input.ignore_eos
     if request_func_input.extra_body:
         payload.update(request_func_input.extra_body)
     headers = {
         "Content-Type": "application/json",
         "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
     }
 
     output = RequestFuncOutput()
     output.prompt_len = request_func_input.prompt_len
 
     generated_text = ""
     ttft = 0.0
     st = time.perf_counter()
     most_recent_timestamp = st
     try:
-        async with session.post(url=api_url, json=payload,
-                                headers=headers) as response:
+        async with request_session.post(url=api_url,
+                                        json=payload,
+                                        headers=headers) as response:
             if response.status == 200:
                 output.success = True
                 if streaming:
                     async for chunk_bytes in response.content:
                         chunk_bytes = chunk_bytes.strip()
                         if not chunk_bytes:
                             continue
 
                         chunk = chunk_bytes.decode("utf-8").removeprefix(
                             "data: ")
                         if chunk != "[DONE]":
                             timestamp = time.perf_counter()
                             data = json.loads(chunk)
 
                             if choices := data.get("choices"):
                                 content = choices[0]["delta"].get("content")
                                 # First token
                                 if ttft == 0.0:
                                     ttft = timestamp - st
                                     output.ttft = ttft
 
                                 # Decoding phase
                                 else:
                                     output.itl.append(timestamp -
                                                       most_recent_timestamp)
 
                                 generated_text += content or ""
                             elif usage := data.get("usage"):
                                 output.output_tokens = usage.get(
                                     "completion_tokens")
 
                             most_recent_timestamp = timestamp
 
                     output.generated_text = generated_text
                     output.latency = most_recent_timestamp - st
                 else:
                     content = await response.content.read()
                     data = json.loads(content.decode())
                     output.generated_text = data["choices"][0]["message"][
                         "content"]
                     output.output_tokens = data["usage"]["completion_tokens"]
                     output.itl = []
                     output.latency = time.perf_counter() - st
                     output.ttft = -1
             else:
                 output.error = response.reason or ""
                 output.success = False
     except Exception:
         output.success = False
         exc_info = sys.exc_info()
         output.error = "".join(traceback.format_exception(*exc_info))
+    finally:
+        if session is None:
+            await request_session.close()
 
     if pbar:
         pbar.update(1)
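The chat variant above builds its request body in one of two shapes, depending on whether the prompt arrives as a list of token IDs or as plain text. Two illustrative payloads with made-up values (the prompt_token_ids field is the extension used in the diff, not part of the standard OpenAI schema):

# Prompt supplied as pre-tokenized IDs:
payload_ids = {
    "model": "my-model",  # hypothetical model name
    "messages": [],
    "prompt_token_ids": [101, 2009, 2003, 102],
    "temperature": 0.0,
    "max_completion_tokens": 64,
    "stream": True,
}

# Prompt supplied as text, wrapped in a single user message:
payload_text = {
    "model": "my-model",
    "messages": [{
        "role": "user",
        "content": [{"type": "text", "text": "Hello!"}],
    }],
    "temperature": 0.0,
    "max_completion_tokens": 64,
    "stream": True,
}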
@@ -30,11 +30,12 @@ from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, Optional
 
+import aiohttp
 import numpy as np
 from tqdm.asyncio import tqdm
 from transformers import PreTrainedTokenizerBase
 
-from .backend_request_func import (ASYNC_REQUEST_FUNCS,
+from .backend_request_func import (AIOHTTP_TIMEOUT, ASYNC_REQUEST_FUNCS,
                                    OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
                                    RequestFuncOutput, get_tokenizer)
 from .benchmark_dataset import (AIMODataset, BurstGPTDataset,
@@ -322,18 +323,29 @@ async def benchmark(
     semaphore = (asyncio.Semaphore(max_concurrency)
                  if max_concurrency else None)
 
-    async def limited_request_func(request_func_input, streaming, pbar):
+    async def limited_request_func(request_func_input, streaming, pbar,
+                                   session):
         if semaphore is None:
             return await request_func(request_func_input=request_func_input,
                                       streaming=streaming,
-                                      pbar=pbar)
+                                      pbar=pbar,
+                                      session=session)
         async with semaphore:
             return await request_func(request_func_input=request_func_input,
                                       streaming=streaming,
-                                      pbar=pbar)
+                                      pbar=pbar,
+                                      session=session)
 
     benchmark_start_time = time.perf_counter()
     tasks: list[asyncio.Task] = []
+    session = aiohttp.ClientSession(trust_env=True,
+                                    timeout=AIOHTTP_TIMEOUT,
+                                    connector=aiohttp.TCPConnector(
+                                        limit=0,
+                                        limit_per_host=0,
+                                        force_close=True))
+
     i = 0
     async for request in get_request(input_requests, request_rate, burstiness):
         prompt, prompt_len, output_len = request.prompt, \
             request.prompt_len, request.expected_output_len
@@ -356,7 +368,9 @@ async def benchmark(
             asyncio.create_task(
                 limited_request_func(request_func_input=request_func_input,
                                      streaming=streaming,
-                                     pbar=pbar)))
+                                     pbar=pbar,
+                                     session=session)))
         i += 1
     outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
 
     if profile:
@@ -370,7 +384,8 @@ async def benchmark(
             logprobs=logprobs,
         )
         profile_output = await request_func(request_func_input=profile_input,
-                                            streaming=streaming)
+                                            streaming=streaming,
+                                            session=session)
         if profile_output.success:
             print("Profiler stopped")
@@ -379,6 +394,9 @@ async def benchmark(
 
     benchmark_duration = time.perf_counter() - benchmark_start_time
 
+    # Close the session
+    await session.close()
+
     metrics, actual_output_lens = calculate_metrics(
         input_requests=input_requests,
         outputs=outputs,
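Taken together, the benchmark-side changes create one ClientSession up front (with force_close=True), lend it to every request task, and close it only after all tasks have finished, so no per-request session teardown can race with in-flight requests. A condensed sketch of that flow, with a hypothetical send_one standing in for the ASYNC_REQUEST_FUNCS entries and an assumed 6-hour total timeout:

import asyncio
import aiohttp

async def send_one(session: aiohttp.ClientSession, url: str, payload: dict) -> int:
    # Stand-in for a benchmark request function that accepts a shared session.
    async with session.post(url, json=payload) as resp:
        await resp.read()
        return resp.status

async def run_benchmark(url: str, payloads: list) -> list:
    session = aiohttp.ClientSession(
        trust_env=True,
        timeout=aiohttp.ClientTimeout(total=6 * 60 * 60),  # assumption
        connector=aiohttp.TCPConnector(limit=0, limit_per_host=0,
                                       force_close=True))
    try:
        tasks = [asyncio.create_task(send_one(session, url, p)) for p in payloads]
        return await asyncio.gather(*tasks)
    finally:
        # Close the shared session only after every request task has completed.
        await session.close()

# asyncio.run(run_benchmark("http://localhost:8000/v1/completions",
#                           [{"prompt": "hi", "max_tokens": 8}]))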