[None][fix] make health_generate work with beam search (#11097)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
mpikulski 2026-02-04 09:46:19 +01:00 committed by GitHub
parent 02b80bfd58
commit f0ca62b175
2 changed files with 61 additions and 15 deletions


@@ -348,6 +348,13 @@ class OpenAIServer:
 
     async def health_generate(self, raw_request: Request) -> Response:
         """Health check that performs a minimal generation."""
+        extra_args = {}
+        if self.llm.args.max_beam_width > 1:
+            extra_args = dict(
+                use_beam_search=True,
+                best_of=self.llm.args.max_beam_width,
+                n=1,
+            )
         try:
             # Create a minimal chat request
             health_request = ChatCompletionRequest(
@@ -355,7 +362,8 @@
                 model=self.model,
                 max_completion_tokens=1,  # Request only 1 token out
                 stream=False,
-                temperature=0.0  # Deterministic output
+                temperature=0.0,  # Deterministic output
+                **extra_args,
             )
             # Call the chat completion logic
 
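With the arguments above, health_generate now succeeds on servers whose engine was built with max_beam_width > 1: the probe request searches the full beam width while still returning a single choice. A minimal external probe of the endpoint (a sketch: the /health_generate route path and the localhost:8000 address are assumptions based on the handler name and common trtllm-serve defaults, not taken from this diff):

    import requests

    # Hypothetical probe of the health route registered for the handler above;
    # a 200 response means the one-token generation completed.
    resp = requests.get("http://localhost:8000/health_generate", timeout=60)
    assert resp.status_code == 200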

@@ -10,11 +10,23 @@ from ..test_llm import get_model_path
 from .openai_server import RemoteOpenAIServer
 
 
-@pytest.fixture(scope="module", params=["trt", "pytorch"])
-def backend(request):
+@pytest.fixture(scope="module",
+                params=[("trt", True), ("pytorch", False), ("pytorch", True)],
+                ids=lambda p: f"{p[0]}-{'with' if p[1] else 'no'}_beam_search")
+def backend_and_beam_search(request):
     return request.param
 
 
+@pytest.fixture(scope="module")
+def backend(backend_and_beam_search):
+    return backend_and_beam_search[0]
+
+
+@pytest.fixture(scope="module")
+def enable_beam_search(backend_and_beam_search):
+    return backend_and_beam_search[1]
+
+
 @pytest.fixture(scope="module")
 def model_name(backend):
     # Note: TRT backend does not support Qwen3-0.6B-Base,
@@ -31,6 +43,11 @@ def max_batch_size(request):
     return request.param
 
 
+@pytest.fixture(scope="module")
+def max_beam_width(enable_beam_search):
+    return 4 if enable_beam_search else 1
+
+
 # Note: In the model Qwen3-0.6B-Base, "max_position_embeddings" is 32768,
 # so the inferred max_seq_len is 32768.
 @pytest.fixture(scope="module", params=["32768"])
@@ -39,12 +56,12 @@
 
 
 @pytest.fixture(scope="module")
-def server(model_name: str, backend: str, max_batch_size: str,
-           max_seq_len: str):
+def server(model_name: str, backend: str, max_batch_size: str, max_seq_len: str,
+           enable_beam_search: bool, max_beam_width: int):
     model_path = get_model_path(model_name)
     args = ["--backend", f"{backend}"]
-    if backend != "pytorch":
-        args.extend(["--max_beam_width", "4"])
+    if enable_beam_search:
+        args.extend(["--max_beam_width", str(max_beam_width)])
     if max_batch_size is not None:
         args.extend(["--max_batch_size", max_batch_size])
     if max_seq_len is not None:
@@ -82,10 +99,24 @@ def test_model(client: openai.OpenAI, model_name: str):
     assert model.id == model_name.split('/')[-1]
 
 
+@pytest.mark.parametrize("use_beam_search", [False, True])
 # reference: https://github.com/vllm-project/vllm/blob/44f990515b124272f87954fc763d90697d8aa1db/tests/entrypoints/openai/test_basic.py#L123
 @pytest.mark.asyncio
-async def test_request_cancellation(server: RemoteOpenAIServer,
-                                    model_name: str):
+async def test_request_cancellation(server: RemoteOpenAIServer, model_name: str,
+                                    use_beam_search: bool,
+                                    enable_beam_search: bool, backend: str,
+                                    max_beam_width: int):
+    if backend == "pytorch" and use_beam_search != enable_beam_search:
+        pytest.skip("PyTorch backend fixes beam width on startup")
+
+    beam_search_args = {}
+    if use_beam_search:
+        beam_search_args = dict(
+            use_beam_search=True,
+            n=1,
+            best_of=max_beam_width,
+        )
+
     # clunky test: send an ungodly amount of load in with short timeouts
     # then ensure that it still responds quickly afterwards
     chat_input = [{"role": "user", "content": "Write a long story"}]
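use_beam_search, best_of, and n are not part of the stock OpenAI chat-completion schema, so the test passes them via the client's extra_body, which merges extra fields into the request JSON for the server's ChatCompletionRequest to validate. The same mechanism in isolation (a sketch; the base URL and model name are illustrative placeholders, not taken from this diff):

    import asyncio
    import openai

    async def main():
        client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                    api_key="dummy")
        resp = await client.chat.completions.create(
            messages=[{"role": "user", "content": "Write a long story"}],
            model="Qwen3-0.6B-Base",
            max_tokens=10,
            # non-standard sampling knobs travel in extra_body
            extra_body={"use_beam_search": True, "best_of": 4, "n": 1},
        )
        assert len(resp.choices) == 1

    asyncio.run(main())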
@@ -94,17 +125,23 @@
     client = server.get_async_client()
     response = await client.chat.completions.create(messages=chat_input,
                                                     model=model_name,
-                                                    max_tokens=10000)
+                                                    max_tokens=10000,
+                                                    extra_body=beam_search_args)
 
     client = server.get_async_client(timeout=0.5, max_retries=3)
     tasks = []
     # Request about 2 million tokens
     for _ in range(200):
         task = asyncio.create_task(
-            client.chat.completions.create(messages=chat_input,
-                                           model=model_name,
-                                           max_tokens=10000,
-                                           extra_body={"min_tokens": 10000}))
+            client.chat.completions.create(
+                messages=chat_input,
+                model=model_name,
+                max_tokens=10000,
+                extra_body={
+                    "min_tokens": 10000,
+                    **beam_search_args,
+                },
+            ))
         tasks.append(task)
 
     done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
@@ -122,6 +159,7 @@ async def test_request_cancellation(server: RemoteOpenAIServer,
 
     client = server.get_async_client(timeout=5)
     response = await client.chat.completions.create(messages=chat_input,
                                                     model=model_name,
-                                                    max_tokens=10)
+                                                    max_tokens=10,
+                                                    extra_body=beam_search_args)
     assert len(response.choices) == 1
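Taken together, the parameterization and the skip rule yield four effective server/request combinations, which can be enumerated with a standalone sketch of the same logic:

    # The PyTorch backend fixes the beam width at startup, so request-level
    # use_beam_search must match the server's enable_beam_search; the TRT
    # server is always launched beam-capable and accepts both request kinds.
    servers = [("trt", True), ("pytorch", False), ("pytorch", True)]
    for backend, enable_beam_search in servers:
        server_id = f"{backend}-{'with' if enable_beam_search else 'no'}_beam_search"
        for use_beam_search in (False, True):
            if backend == "pytorch" and use_beam_search != enable_beam_search:
                continue  # pytest.skip in the real test
            print(f"{server_id}: use_beam_search={use_beam_search}")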