diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index 53732ec852..912fb84ac5 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -348,6 +348,13 @@ class OpenAIServer:
     async def health_generate(self, raw_request: Request) -> Response:
         """Health check that performs a minimal generation."""
+        extra_args = {}
+        if self.llm.args.max_beam_width > 1:
+            extra_args = dict(
+                use_beam_search=True,
+                best_of=self.llm.args.max_beam_width,
+                n=1,
+            )
 
         try:
             # Create a minimal chat request
             health_request = ChatCompletionRequest(
@@ -355,7 +362,8 @@ class OpenAIServer:
                 model=self.model,
                 max_completion_tokens=1,  # Request only 1 token out
                 stream=False,
-                temperature=0.0  # Deterministic output
+                temperature=0.0,  # Deterministic output
+                **extra_args,
             )
 
             # Call the chat completion logic
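Why the server change is needed: when the engine starts with `max_beam_width > 1`, the PyTorch backend fixes the beam width at startup (the test below skips for the same reason), so the minimal health-check generation has to opt into beam search rather than default to greedy sampling. Below is a minimal sketch of the resulting payload construction; `build_health_payload` is a hypothetical helper and `max_beam_width` stands in for `self.llm.args.max_beam_width`:

```python
# Hypothetical sketch, not the server code itself: how the health check's
# sampling arguments are assembled depending on the engine's beam width.
def build_health_payload(max_beam_width: int) -> dict:
    extra_args = {}
    if max_beam_width > 1:
        # Mirror the diff: explore `max_beam_width` beams (best_of) but
        # return a single choice (n=1).
        extra_args = dict(use_beam_search=True,
                          best_of=max_beam_width,
                          n=1)
    return dict(
        messages=[{"role": "user", "content": "hi"}],  # assumed prompt
        max_completion_tokens=1,  # request only 1 token out
        stream=False,
        temperature=0.0,  # deterministic output
        **extra_args,
    )


# Engines without beam search keep the original minimal request untouched.
assert "use_beam_search" not in build_health_payload(1)
assert build_health_payload(4)["best_of"] == 4
```

Keeping `n=1` while setting `best_of` to the configured beam width means `/health_generate` still returns a single choice even though all beams are explored.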
@pytest.fixture(scope="module", params=["32768"]) @@ -39,12 +56,12 @@ def max_seq_len(request): @pytest.fixture(scope="module") -def server(model_name: str, backend: str, max_batch_size: str, - max_seq_len: str): +def server(model_name: str, backend: str, max_batch_size: str, max_seq_len: str, + enable_beam_search: bool, max_beam_width: int): model_path = get_model_path(model_name) args = ["--backend", f"{backend}"] - if backend != "pytorch": - args.extend(["--max_beam_width", "4"]) + if enable_beam_search: + args.extend(["--max_beam_width", str(max_beam_width)]) if max_batch_size is not None: args.extend(["--max_batch_size", max_batch_size]) if max_seq_len is not None: @@ -82,10 +99,24 @@ def test_model(client: openai.OpenAI, model_name: str): assert model.id == model_name.split('/')[-1] +@pytest.mark.parametrize("use_beam_search", [False, True]) # reference: https://github.com/vllm-project/vllm/blob/44f990515b124272f87954fc763d90697d8aa1db/tests/entrypoints/openai/test_basic.py#L123 @pytest.mark.asyncio -async def test_request_cancellation(server: RemoteOpenAIServer, - model_name: str): +async def test_request_cancellation(server: RemoteOpenAIServer, model_name: str, + use_beam_search: bool, + enable_beam_search: bool, backend: str, + max_beam_width: int): + if backend == "pytorch" and use_beam_search != enable_beam_search: + pytest.skip("PyTorch backend fixes beam width on startup") + + beam_search_args = {} + if use_beam_search: + beam_search_args = dict( + use_beam_search=True, + n=1, + best_of=max_beam_width, + ) + # clunky test: send an ungodly amount of load in with short timeouts # then ensure that it still responds quickly afterwards chat_input = [{"role": "user", "content": "Write a long story"}] @@ -94,17 +125,23 @@ async def test_request_cancellation(server: RemoteOpenAIServer, client = server.get_async_client() response = await client.chat.completions.create(messages=chat_input, model=model_name, - max_tokens=10000) + max_tokens=10000, + extra_body=beam_search_args) client = server.get_async_client(timeout=0.5, max_retries=3) tasks = [] # Request about 2 million tokens for _ in range(200): task = asyncio.create_task( - client.chat.completions.create(messages=chat_input, - model=model_name, - max_tokens=10000, - extra_body={"min_tokens": 10000})) + client.chat.completions.create( + messages=chat_input, + model=model_name, + max_tokens=10000, + extra_body={ + "min_tokens": 10000, + **beam_search_args, + }, + )) tasks.append(task) done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) @@ -122,6 +159,7 @@ async def test_request_cancellation(server: RemoteOpenAIServer, client = server.get_async_client(timeout=5) response = await client.chat.completions.create(messages=chat_input, model=model_name, - max_tokens=10) + max_tokens=10, + extra_body=beam_search_args) assert len(response.choices) == 1