Fix lost requests for disaggregated serving (#5815)

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Iman Tabrizian 2025-07-08 16:42:45 -07:00 committed by Zhihan Jiang
parent 155b19e6b0
commit 435eeea8b6
3 changed files with 279 additions and 241 deletions

View File

@@ -68,7 +68,7 @@ class OpenAIDisaggServer:
         async def lifespan(app: FastAPI):
             # Create a persistent aiohttp ClientSession
             self.session = aiohttp.ClientSession(
-                connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, keepalive_timeout=300),
+                connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True),
                 timeout=aiohttp.ClientTimeout(total=req_timeout_secs))
             logger.info("Waiting for context and generation servers to be ready")
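The one-line change above is the heart of the fix: the proxy's persistent session stops pooling keep-alive sockets and instead opens a fresh connection per request, so a request can no longer be written to a socket the peer has already closed. A minimal sketch of the two connector modes (illustrative only, not part of the patch; call from inside a running event loop):

    import aiohttp

    def make_connector(reuse: bool) -> aiohttp.TCPConnector:
        if reuse:
            # Pooled keep-alive sockets: faster, but an idle socket can be
            # closed server-side, and a request sent on it may be lost.
            return aiohttp.TCPConnector(limit=0, limit_per_host=0,
                                        keepalive_timeout=300)
        # force_close=True: one TCP connection per request, closed afterwards.
        # (aiohttp rejects combining force_close with keepalive_timeout.)
        return aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True)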

View File

@@ -50,183 +50,196 @@ async def async_request_trt_llm(
     request_func_input: RequestFuncInput,
     streaming: bool = True,
     pbar: Optional[tqdm] = None,
+    session: Optional[aiohttp.ClientSession] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")
-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
-        payload = {
-            "accumulate_tokens": True,
-            "text_input": request_func_input.prompt,
-            "temperature": 0.0,
-            "top_p": 1.0,
-            "max_tokens": request_func_input.output_len,
-            "stream": streaming,
-        }
-        if request_func_input.ignore_eos:
-            payload["min_length"] = request_func_input.output_len
-        output = RequestFuncOutput()
-        output.prompt_len = request_func_input.prompt_len
-        ttft = 0.0
-        st = time.perf_counter()
-        most_recent_timestamp = st
-        try:
-            async with session.post(url=api_url, json=payload) as response:
-                if response.status == 200:
-                    output.success = True
-                    if streaming:
-                        async for chunk_bytes in response.content:
-                            chunk_bytes = chunk_bytes.strip()
-                            if not chunk_bytes:
-                                continue
-                            chunk = chunk_bytes.decode("utf-8").removeprefix(
-                                "data:")
-                            data = json.loads(chunk)
-                            output.generated_text += data["text_output"]
-                            timestamp = time.perf_counter()
-                            # First token
-                            if ttft == 0.0:
-                                ttft = timestamp - st
-                                output.ttft = ttft
-                            # Decoding phase
-                            else:
-                                output.itl.append(timestamp -
-                                                  most_recent_timestamp)
-                            most_recent_timestamp = timestamp
-                        output.latency = most_recent_timestamp - st
-                    else:
-                        content = await response.content.read()
-                        data = json.loads(content.decode())
-                        output.ttft = -1
-                        output.itl = []
-                        output.generated_text = data["text_output"]
-                        output.latency = time.perf_counter() - st
-                else:
-                    output.error = response.reason or ""
-                    output.success = False
-        except Exception:
-            output.success = False
-            exc_info = sys.exc_info()
-            output.error = "".join(traceback.format_exception(*exc_info))
-        if pbar:
-            pbar.update(1)
-        return output
+    request_session = aiohttp.ClientSession(
+        trust_env=True,
+        timeout=AIOHTTP_TIMEOUT,
+        connector=aiohttp.TCPConnector(
+            limit=0, limit_per_host=0)) if session is None else session
+    payload = {
+        "accumulate_tokens": True,
+        "text_input": request_func_input.prompt,
+        "temperature": 0.0,
+        "top_p": 1.0,
+        "max_tokens": request_func_input.output_len,
+        "stream": streaming,
+    }
+    if request_func_input.ignore_eos:
+        payload["min_length"] = request_func_input.output_len
+    output = RequestFuncOutput()
+    output.prompt_len = request_func_input.prompt_len
+    ttft = 0.0
+    st = time.perf_counter()
+    most_recent_timestamp = st
+    try:
+        async with request_session.post(url=api_url, json=payload) as response:
+            if response.status == 200:
+                output.success = True
+                if streaming:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+                        chunk = chunk_bytes.decode("utf-8").removeprefix(
+                            "data:")
+                        data = json.loads(chunk)
+                        output.generated_text += data["text_output"]
+                        timestamp = time.perf_counter()
+                        # First token
+                        if ttft == 0.0:
+                            ttft = timestamp - st
+                            output.ttft = ttft
+                        # Decoding phase
+                        else:
+                            output.itl.append(timestamp - most_recent_timestamp)
+                        most_recent_timestamp = timestamp
+                    output.latency = most_recent_timestamp - st
+                else:
+                    content = await response.content.read()
+                    data = json.loads(content.decode())
+                    output.ttft = -1
+                    output.itl = []
+                    output.generated_text = data["text_output"]
+                    output.latency = time.perf_counter() - st
+            else:
+                output.error = response.reason or ""
+                output.success = False
+    except Exception:
+        output.success = False
+        exc_info = sys.exc_info()
+        output.error = "".join(traceback.format_exception(*exc_info))
+    finally:
+        if session is None:
+            await request_session.close()
+    if pbar:
+        pbar.update(1)
+    return output


 async def async_request_openai_completions(
     request_func_input: RequestFuncInput,
     streaming: bool = True,
     pbar: Optional[tqdm] = None,
+    session: Optional[aiohttp.ClientSession] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(
         ("completions", "profile")
     ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
-        payload = {
-            "model": request_func_input.model_name \
-                if request_func_input.model_name else request_func_input.model,
-            "prompt": request_func_input.prompt,
-            "temperature": 0.0,
-            "repetition_penalty": 1.0,
-            "max_tokens": request_func_input.output_len,
-            "logprobs": request_func_input.logprobs,
-            "stream": streaming,
-        }
-        if streaming:
-            payload["stream_options"] = {"include_usage": True}
-        if request_func_input.ignore_eos:
-            payload["ignore_eos"] = request_func_input.ignore_eos
-        if request_func_input.extra_body:
-            payload.update(request_func_input.extra_body)
-        headers = {
-            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
-        }
-        output = RequestFuncOutput()
-        output.prompt_len = request_func_input.prompt_len
-        generated_text = ""
-        st = time.perf_counter()
-        most_recent_timestamp = st
-        try:
-            async with session.post(url=api_url, json=payload,
-                                    headers=headers) as response:
-                if response.status == 200:
-                    if streaming:
-                        first_chunk_received = False
-                        async for chunk_bytes in response.content:
-                            chunk_bytes = chunk_bytes.strip()
-                            if not chunk_bytes:
-                                continue
-                            chunk = chunk_bytes.decode("utf-8").removeprefix(
-                                "data: ")
-                            if chunk != "[DONE]":
-                                data = json.loads(chunk)
-                                # NOTE: Some completion API might have a last
-                                # usage summary response without a token so we
-                                # want to check a token was generated
-                                if choices := data.get("choices"):
-                                    # Note that text could be empty here
-                                    # e.g. for special tokens
-                                    text = choices[0].get("text")
-                                    timestamp = time.perf_counter()
-                                    # First token
-                                    if not first_chunk_received:
-                                        first_chunk_received = True
-                                        ttft = time.perf_counter() - st
-                                        output.ttft = ttft
-                                    # Decoding phase
-                                    else:
-                                        output.itl.append(timestamp -
-                                                          most_recent_timestamp)
-                                    most_recent_timestamp = timestamp
-                                    generated_text += text or ""
-                                elif usage := data.get("usage"):
-                                    output.output_tokens = usage.get(
-                                        "completion_tokens")
-                        if first_chunk_received:
-                            output.success = True
-                        else:
-                            output.success = False
-                            output.error = (
-                                "Never received a valid chunk to calculate TTFT."
-                                "This response will be marked as failed!")
-                        output.generated_text = generated_text
-                        output.latency = most_recent_timestamp - st
-                    else:
-                        content = await response.content.read()
-                        data = json.loads(content.decode())
-                        generated_text = data["choices"][0]["text"]
-                        output.success = True
-                        output.generated_text = generated_text
-                        output.latency = time.perf_counter() - st
-                        output.ttft = -1
-                        output.itl = []
-                        output.output_tokens = data["usage"][
-                            "completion_tokens"]
-                else:
-                    output.error = response.reason or ""
-                    output.success = False
-        except Exception:
-            output.success = False
-            exc_info = sys.exc_info()
-            output.error = "".join(traceback.format_exception(*exc_info))
-        if pbar:
-            pbar.update(1)
+    request_session = aiohttp.ClientSession(
+        trust_env=True,
+        timeout=AIOHTTP_TIMEOUT,
+        connector=aiohttp.TCPConnector(
+            limit=0, limit_per_host=0)) if session is None else session
+    output = RequestFuncOutput()
+    output.prompt_len = request_func_input.prompt_len
+    payload = {
+        "model": request_func_input.model_name \
+            if request_func_input.model_name else request_func_input.model,
+        "prompt": request_func_input.prompt,
+        "temperature": 0.0,
+        "repetition_penalty": 1.0,
+        "max_tokens": request_func_input.output_len,
+        "logprobs": request_func_input.logprobs,
+        "stream": streaming,
+    }
+    if streaming:
+        payload["stream_options"] = {"include_usage": True}
+    if request_func_input.ignore_eos:
+        payload["ignore_eos"] = request_func_input.ignore_eos
+    if request_func_input.extra_body:
+        payload.update(request_func_input.extra_body)
+    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+    generated_text = ""
+    st = time.perf_counter()
+    most_recent_timestamp = st
+    try:
+        async with request_session.post(url=api_url,
+                                        json=payload,
+                                        headers=headers) as response:
+            if response.status == 200:
+                if streaming:
+                    first_chunk_received = False
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+                        chunk = chunk_bytes.decode("utf-8").removeprefix(
+                            "data: ")
+                        if chunk != "[DONE]":
+                            data = json.loads(chunk)
+                            # NOTE: Some completion API might have a last
+                            # usage summary response without a token so we
+                            # want to check a token was generated
+                            if choices := data.get("choices"):
+                                # Note that text could be empty here
+                                # e.g. for special tokens
+                                text = choices[0].get("text")
+                                timestamp = time.perf_counter()
+                                # First token
+                                if not first_chunk_received:
+                                    first_chunk_received = True
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp -
+                                                      most_recent_timestamp)
+                                most_recent_timestamp = timestamp
+                                generated_text += text or ""
+                            elif usage := data.get("usage"):
+                                output.output_tokens = usage.get(
+                                    "completion_tokens")
+                    if first_chunk_received:
+                        output.success = True
+                    else:
+                        output.success = False
+                        output.error = (
+                            "Never received a valid chunk to calculate TTFT."
+                            "This response will be marked as failed!")
+                    output.generated_text = generated_text
+                    output.latency = most_recent_timestamp - st
+                else:
+                    content = await response.content.read()
+                    data = json.loads(content.decode())
+                    generated_text = data["choices"][0]["text"]
+                    output.success = True
+                    output.generated_text = generated_text
+                    output.latency = time.perf_counter() - st
+                    output.ttft = -1
+                    output.itl = []
+                    output.output_tokens = data["usage"]["completion_tokens"]
+            else:
+                output.error = response.reason or ""
+                output.success = False
+    except Exception:
+        output.success = False
+        exc_info = sys.exc_info()
+        output.error = "".join(traceback.format_exception(*exc_info))
+    finally:
+        if session is None:
+            await request_session.close()
+    if pbar:
+        pbar.update(1)
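With the new optional session argument, a caller can drive many requests through one shared ClientSession; each function only closes the session it created itself. A hypothetical caller (inside an async function; req1 and req2 stand in for real RequestFuncInput objects):

    shared = aiohttp.ClientSession(
        trust_env=True,
        timeout=AIOHTTP_TIMEOUT,
        connector=aiohttp.TCPConnector(limit=0, limit_per_host=0,
                                       force_close=True))
    try:
        # session is not None inside the calls, so neither closes `shared`.
        out1 = await async_request_trt_llm(req1, session=shared)
        out2 = await async_request_openai_completions(req2, session=shared)
    finally:
        await shared.close()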
@@ -237,114 +250,121 @@ async def async_request_openai_chat_completions(
     request_func_input: RequestFuncInput,
     streaming: bool = True,
     pbar: Optional[tqdm] = None,
+    session: Optional[aiohttp.ClientSession] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(
         ("chat/completions", "profile"
          )), "OpenAI Chat Completions API URL must end with 'chat/completions'."
-    async with aiohttp.ClientSession(trust_env=True,
-                                     timeout=AIOHTTP_TIMEOUT) as session:
-        payload = {
-            "model": request_func_input.model_name \
-                if request_func_input.model_name else request_func_input.model,
-            "messages": [
-            ],
-            "temperature": 0.0,
-            "max_completion_tokens": request_func_input.output_len,
-            "stream": streaming,
-        }
-        if isinstance(request_func_input.prompt, list) and all(
-                [isinstance(i, int) for i in request_func_input.prompt]):
-            payload["prompt_token_ids"] = request_func_input.prompt
-        else:
-            assert isinstance(
-                request_func_input.prompt,
-                str), "Prompt must be a string or a list of integers"
-            payload["messages"].append({
-                "role":
-                "user",
-                "content": [{
-                    "type": "text",
-                    "text": request_func_input.prompt
-                }]
-            })
-        if streaming:
-            payload["stream_options"] = {"include_usage": True}
-        if request_func_input.ignore_eos:
-            payload["ignore_eos"] = request_func_input.ignore_eos
-        if request_func_input.extra_body:
-            payload.update(request_func_input.extra_body)
-        headers = {
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
-        }
-        output = RequestFuncOutput()
-        output.prompt_len = request_func_input.prompt_len
-        generated_text = ""
-        ttft = 0.0
-        st = time.perf_counter()
-        most_recent_timestamp = st
-        try:
-            async with session.post(url=api_url, json=payload,
-                                    headers=headers) as response:
-                if response.status == 200:
-                    output.success = True
-                    if streaming:
-                        async for chunk_bytes in response.content:
-                            chunk_bytes = chunk_bytes.strip()
-                            if not chunk_bytes:
-                                continue
-                            chunk = chunk_bytes.decode("utf-8").removeprefix(
-                                "data: ")
-                            if chunk != "[DONE]":
-                                timestamp = time.perf_counter()
-                                data = json.loads(chunk)
-                                if choices := data.get("choices"):
-                                    content = choices[0]["delta"].get("content")
-                                    # First token
-                                    if ttft == 0.0:
-                                        ttft = timestamp - st
-                                        output.ttft = ttft
-                                    # Decoding phase
-                                    else:
-                                        output.itl.append(timestamp -
-                                                          most_recent_timestamp)
-                                    generated_text += content or ""
-                                elif usage := data.get("usage"):
-                                    output.output_tokens = usage.get(
-                                        "completion_tokens")
-                                most_recent_timestamp = timestamp
-                        output.generated_text = generated_text
-                        output.latency = most_recent_timestamp - st
-                    else:
-                        content = await response.content.read()
-                        data = json.loads(content.decode())
-                        output.generated_text = data["choices"][0]["message"][
-                            "content"]
-                        output.output_tokens = data["usage"][
-                            "completion_tokens"]
-                        output.itl = []
-                        output.latency = time.perf_counter() - st
-                        output.ttft = -1
-                else:
-                    output.error = response.reason or ""
-                    output.success = False
-        except Exception:
-            output.success = False
-            exc_info = sys.exc_info()
-            output.error = "".join(traceback.format_exception(*exc_info))
-        if pbar:
-            pbar.update(1)
+    request_session = aiohttp.ClientSession(
+        trust_env=True,
+        timeout=AIOHTTP_TIMEOUT,
+        connector=aiohttp.TCPConnector(
+            limit=0, limit_per_host=0)) if session is None else session
+    payload = {
+        "model": request_func_input.model_name \
+            if request_func_input.model_name else request_func_input.model,
+        "messages": [
+        ],
+        "temperature": 0.0,
+        "max_completion_tokens": request_func_input.output_len,
+        "stream": streaming,
+    }
+    if isinstance(request_func_input.prompt, list) and all(
+            [isinstance(i, int) for i in request_func_input.prompt]):
+        payload["prompt_token_ids"] = request_func_input.prompt
+    else:
+        assert isinstance(request_func_input.prompt,
+                          str), "Prompt must be a string or a list of integers"
+        payload["messages"].append({
+            "role":
+            "user",
+            "content": [{
+                "type": "text",
+                "text": request_func_input.prompt
+            }]
+        })
+    if streaming:
+        payload["stream_options"] = {"include_usage": True}
+    if request_func_input.ignore_eos:
+        payload["ignore_eos"] = request_func_input.ignore_eos
+    if request_func_input.extra_body:
+        payload.update(request_func_input.extra_body)
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+    }
+    output = RequestFuncOutput()
+    output.prompt_len = request_func_input.prompt_len
+    generated_text = ""
+    ttft = 0.0
+    st = time.perf_counter()
+    most_recent_timestamp = st
+    try:
+        async with request_session.post(url=api_url,
+                                        json=payload,
+                                        headers=headers) as response:
+            if response.status == 200:
+                output.success = True
+                if streaming:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+                        chunk = chunk_bytes.decode("utf-8").removeprefix(
+                            "data: ")
+                        if chunk != "[DONE]":
+                            timestamp = time.perf_counter()
+                            data = json.loads(chunk)
+                            if choices := data.get("choices"):
+                                content = choices[0]["delta"].get("content")
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = timestamp - st
+                                    output.ttft = ttft
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp -
+                                                      most_recent_timestamp)
+                                generated_text += content or ""
+                            elif usage := data.get("usage"):
+                                output.output_tokens = usage.get(
+                                    "completion_tokens")
+                            most_recent_timestamp = timestamp
+                    output.generated_text = generated_text
+                    output.latency = most_recent_timestamp - st
+                else:
+                    content = await response.content.read()
+                    data = json.loads(content.decode())
+                    output.generated_text = data["choices"][0]["message"][
+                        "content"]
+                    output.output_tokens = data["usage"]["completion_tokens"]
+                    output.itl = []
+                    output.latency = time.perf_counter() - st
+                    output.ttft = -1
+            else:
+                output.error = response.reason or ""
+                output.success = False
+    except Exception:
+        output.success = False
+        exc_info = sys.exc_info()
+        output.error = "".join(traceback.format_exception(*exc_info))
+    finally:
+        if session is None:
+            await request_session.close()
+    if pbar:
+        pbar.update(1)
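All three request functions now share the same ownership idiom: build a session only if the caller did not pass one, and close it in finally only if it was built here. Distilled into a standalone sketch (hypothetical fetch helper, assuming nothing beyond aiohttp):

    from typing import Optional

    import aiohttp

    async def fetch(url: str,
                    session: Optional[aiohttp.ClientSession] = None) -> str:
        # Create-if-absent: ownership is remembered via the original argument.
        request_session = aiohttp.ClientSession() if session is None else session
        try:
            async with request_session.get(url) as resp:
                return await resp.text()
        finally:
            # Close-if-owned: never close a session the caller still holds.
            if session is None:
                await request_session.close()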

View File

@@ -30,11 +30,12 @@ from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, Optional

 import aiohttp
 import numpy as np
 from tqdm.asyncio import tqdm
 from transformers import PreTrainedTokenizerBase

-from .backend_request_func import (ASYNC_REQUEST_FUNCS,
+from .backend_request_func import (AIOHTTP_TIMEOUT, ASYNC_REQUEST_FUNCS,
                                    OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
                                    RequestFuncOutput, get_tokenizer)
 from .benchmark_dataset import (AIMODataset, BurstGPTDataset,
@@ -322,18 +323,29 @@ async def benchmark(
     semaphore = (asyncio.Semaphore(max_concurrency)
                  if max_concurrency else None)

-    async def limited_request_func(request_func_input, streaming, pbar):
+    async def limited_request_func(request_func_input, streaming, pbar,
+                                   session):
         if semaphore is None:
             return await request_func(request_func_input=request_func_input,
                                       streaming=streaming,
-                                      pbar=pbar)
+                                      pbar=pbar,
+                                      session=session)
         async with semaphore:
             return await request_func(request_func_input=request_func_input,
                                       streaming=streaming,
-                                      pbar=pbar)
+                                      pbar=pbar,
+                                      session=session)

     benchmark_start_time = time.perf_counter()
     tasks: list[asyncio.Task] = []
+    session = aiohttp.ClientSession(trust_env=True,
+                                    timeout=AIOHTTP_TIMEOUT,
+                                    connector=aiohttp.TCPConnector(
+                                        limit=0,
+                                        limit_per_host=0,
+                                        force_close=True))
+    i = 0
     async for request in get_request(input_requests, request_rate, burstiness):
         prompt, prompt_len, output_len = request.prompt, \
             request.prompt_len, request.expected_output_len
@@ -356,7 +368,9 @@ async def benchmark(
             asyncio.create_task(
                 limited_request_func(request_func_input=request_func_input,
                                      streaming=streaming,
-                                     pbar=pbar)))
+                                     pbar=pbar,
+                                     session=session)))
+        i += 1
     outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)

     if profile:
@@ -370,7 +384,8 @@ async def benchmark(
             logprobs=logprobs,
         )
         profile_output = await request_func(request_func_input=profile_input,
-                                            streaming=streaming)
+                                            streaming=streaming,
+                                            session=session)
         if profile_output.success:
             print("Profiler stopped")
@@ -379,6 +394,9 @@ async def benchmark(
     benchmark_duration = time.perf_counter() - benchmark_start_time

+    # Close the session
+    await session.close()
+
     metrics, actual_output_lens = calculate_metrics(
         input_requests=input_requests,
         outputs=outputs,