[None][fix] make health_generate work with beam search (#11097)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
commit f0ca62b175 (parent 02b80bfd58)
@@ -348,6 +348,13 @@ class OpenAIServer:
 
     async def health_generate(self, raw_request: Request) -> Response:
         """Health check that performs a minimal generation."""
+        extra_args = {}
+        if self.llm.args.max_beam_width > 1:
+            extra_args = dict(
+                use_beam_search=True,
+                best_of=self.llm.args.max_beam_width,
+                n=1,
+            )
         try:
             # Create a minimal chat request
             health_request = ChatCompletionRequest(
@@ -355,7 +362,8 @@ class OpenAIServer:
                 model=self.model,
                 max_completion_tokens=1,  # Request only 1 token out
                 stream=False,
-                temperature=0.0  # Deterministic output
+                temperature=0.0,  # Deterministic output
+                **extra_args,
             )
 
             # Call the chat completion logic
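The two hunks above adjust OpenAIServer.health_generate so that, when the server was launched with max_beam_width > 1, the internal probe request explicitly matches the configured beam width (best_of equal to max_beam_width, n=1). A minimal sketch of exercising the probe from outside, assuming the handler is exposed as a GET route at /health_generate on localhost:8000 (route path, HTTP method, and port are assumptions, not taken from this diff):

# Minimal sketch (assumptions noted above): poll the generation health check
# of a running server.
import requests

resp = requests.get("http://localhost:8000/health_generate", timeout=60)
# With this fix, the probe succeeds even when the server was started with a
# beam width greater than 1, because the internal request opts into beam search.
print(resp.status_code)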
@@ -10,11 +10,23 @@ from ..test_llm import get_model_path
 from .openai_server import RemoteOpenAIServer
 
 
-@pytest.fixture(scope="module", params=["trt", "pytorch"])
-def backend(request):
+@pytest.fixture(scope="module",
+                params=[("trt", True), ("pytorch", False), ("pytorch", True)],
+                ids=lambda p: f"{p[0]}-{'with' if p[1] else 'no'}_beam_search")
+def backend_and_beam_search(request):
     return request.param
 
 
+@pytest.fixture(scope="module")
+def backend(backend_and_beam_search):
+    return backend_and_beam_search[0]
+
+
+@pytest.fixture(scope="module")
+def enable_beam_search(backend_and_beam_search):
+    return backend_and_beam_search[1]
+
+
 @pytest.fixture(scope="module")
 def model_name(backend):
     # Note: TRT backend does not support Qwen3-0.6B-Base,
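With this change the module-scoped fixtures are parametrized over backend/beam-search combinations rather than over the backend alone, and the ids lambda names each configuration. A tiny standalone snippet showing the test IDs it produces:

# Illustration of the IDs generated by the `ids` lambda above.
params = [("trt", True), ("pytorch", False), ("pytorch", True)]
ids = [f"{p[0]}-{'with' if p[1] else 'no'}_beam_search" for p in params]
print(ids)
# ['trt-with_beam_search', 'pytorch-no_beam_search', 'pytorch-with_beam_search']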
@@ -31,6 +43,11 @@ def max_batch_size(request):
     return request.param
 
 
+@pytest.fixture(scope="module")
+def max_beam_width(enable_beam_search):
+    return 4 if enable_beam_search else 1
+
+
 # Note: In the model Qwen3-0.6B-Base, "max_position_embeddings" is 32768,
 # so the inferred max_seq_len is 32768.
 @pytest.fixture(scope="module", params=["32768"])
@@ -39,12 +56,12 @@ def max_seq_len(request):
 
 
 @pytest.fixture(scope="module")
-def server(model_name: str, backend: str, max_batch_size: str,
-           max_seq_len: str):
+def server(model_name: str, backend: str, max_batch_size: str, max_seq_len: str,
+           enable_beam_search: bool, max_beam_width: int):
     model_path = get_model_path(model_name)
     args = ["--backend", f"{backend}"]
-    if backend != "pytorch":
-        args.extend(["--max_beam_width", "4"])
+    if enable_beam_search:
+        args.extend(["--max_beam_width", str(max_beam_width)])
     if max_batch_size is not None:
         args.extend(["--max_batch_size", max_batch_size])
     if max_seq_len is not None:
@@ -82,10 +99,24 @@ def test_model(client: openai.OpenAI, model_name: str):
     assert model.id == model_name.split('/')[-1]
 
 
+@pytest.mark.parametrize("use_beam_search", [False, True])
 # reference: https://github.com/vllm-project/vllm/blob/44f990515b124272f87954fc763d90697d8aa1db/tests/entrypoints/openai/test_basic.py#L123
 @pytest.mark.asyncio
-async def test_request_cancellation(server: RemoteOpenAIServer,
-                                    model_name: str):
+async def test_request_cancellation(server: RemoteOpenAIServer, model_name: str,
+                                    use_beam_search: bool,
+                                    enable_beam_search: bool, backend: str,
+                                    max_beam_width: int):
+    if backend == "pytorch" and use_beam_search != enable_beam_search:
+        pytest.skip("PyTorch backend fixes beam width on startup")
+
+    beam_search_args = {}
+    if use_beam_search:
+        beam_search_args = dict(
+            use_beam_search=True,
+            n=1,
+            best_of=max_beam_width,
+        )
+
     # clunky test: send an ungodly amount of load in with short timeouts
     # then ensure that it still responds quickly afterwards
     chat_input = [{"role": "user", "content": "Write a long story"}]
@@ -94,17 +125,23 @@ async def test_request_cancellation(server: RemoteOpenAIServer,
     client = server.get_async_client()
     response = await client.chat.completions.create(messages=chat_input,
                                                      model=model_name,
-                                                     max_tokens=10000)
+                                                     max_tokens=10000,
+                                                     extra_body=beam_search_args)
 
     client = server.get_async_client(timeout=0.5, max_retries=3)
     tasks = []
     # Request about 2 million tokens
     for _ in range(200):
         task = asyncio.create_task(
-            client.chat.completions.create(messages=chat_input,
-                                           model=model_name,
-                                           max_tokens=10000,
-                                           extra_body={"min_tokens": 10000}))
+            client.chat.completions.create(
+                messages=chat_input,
+                model=model_name,
+                max_tokens=10000,
+                extra_body={
+                    "min_tokens": 10000,
+                    **beam_search_args,
+                },
+            ))
         tasks.append(task)
 
     done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
@@ -122,6 +159,7 @@ async def test_request_cancellation(server: RemoteOpenAIServer,
     client = server.get_async_client(timeout=5)
     response = await client.chat.completions.create(messages=chat_input,
                                                      model=model_name,
-                                                     max_tokens=10)
+                                                     max_tokens=10,
+                                                     extra_body=beam_search_args)
 
     assert len(response.choices) == 1
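The updated test drives beam search the same way a client application would: through TensorRT-LLM-specific fields passed in the OpenAI client's extra_body. A standalone sketch of that pattern (base URL, API key, and model name are placeholders, not values from the test):

# Sketch of requesting beam search through the OpenAI-compatible endpoint.
# The extra_body fields mirror the ones used in the test above; the server
# must have been started with --max_beam_width >= best_of.
import asyncio
import openai

async def main():
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="dummy")
    response = await client.chat.completions.create(
        model="Qwen3-0.6B-Base",
        messages=[{"role": "user", "content": "Write a long story"}],
        max_tokens=10,
        extra_body={"use_beam_search": True, "n": 1, "best_of": 4},
    )
    print(response.choices[0].message.content)

asyncio.run(main())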