From 301d986473a0ffc1df563422e01eac4a1efd59e0 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Tue, 19 May 2026 15:37:40 +0800 Subject: [PATCH] [Frontend] Consolidate beam search by BeamSearchMixin. (#42946) Signed-off-by: wang.yuqi --- .buildkite/test-amd.yaml | 2 +- .buildkite/test_areas/samplers.yaml | 1 + vllm/entrypoints/generate/__init__.py | 0 .../generate/beam_search/__init__.py | 0 .../generate/beam_search/offline.py | 226 ++++++++++++++++++ .../generate/beam_search/online.py | 225 +++++++++++++++++ .../generate/beam_search/utils.py} | 0 vllm/entrypoints/llm.py | 171 +------------ vllm/entrypoints/openai/engine/serving.py | 209 +--------------- 9 files changed, 459 insertions(+), 375 deletions(-) create mode 100644 vllm/entrypoints/generate/__init__.py create mode 100644 vllm/entrypoints/generate/beam_search/__init__.py create mode 100644 vllm/entrypoints/generate/beam_search/offline.py create mode 100644 vllm/entrypoints/generate/beam_search/online.py rename vllm/{beam_search.py => entrypoints/generate/beam_search/utils.py} (100%) diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index b47ccd3f5a2..e53ca5023dc 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -2003,7 +2003,7 @@ steps: - vllm/model_executor/layers - vllm/sampling_metadata.py - vllm/v1/sample/ - - vllm/beam_search.py + - vllm/entrypoints/generate/beam_search/ - tests/samplers - tests/conftest.py - vllm/_aiter_ops.py diff --git a/.buildkite/test_areas/samplers.yaml b/.buildkite/test_areas/samplers.yaml index 52e79ec854d..6ec6f8efd35 100644 --- a/.buildkite/test_areas/samplers.yaml +++ b/.buildkite/test_areas/samplers.yaml @@ -11,6 +11,7 @@ steps: - vllm/sampling_metadata.py - tests/samplers - tests/conftest.py + - vllm/entrypoints/generate/beam_search commands: # VLLM_USE_FLASHINFER_SAMPLER defaults to 1 now, so we need to pin both # values explicitly to still cover the PyTorch-native (Triton) path. 
diff --git a/vllm/entrypoints/generate/__init__.py b/vllm/entrypoints/generate/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/vllm/entrypoints/generate/beam_search/__init__.py b/vllm/entrypoints/generate/beam_search/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/vllm/entrypoints/generate/beam_search/offline.py b/vllm/entrypoints/generate/beam_search/offline.py
new file mode 100644
index 00000000000..29a0402fa13
--- /dev/null
+++ b/vllm/entrypoints/generate/beam_search/offline.py
@@ -0,0 +1,245 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import itertools
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Iterable, Sequence
+from typing import Any
+
+from tqdm import tqdm
+
+from vllm import PromptType, RequestOutput, TextPrompt, TokensPrompt
+from vllm.inputs import EngineInput
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.renderers import BaseRenderer
+from vllm.sampling_params import BeamSearchParams, SamplingParams
+
+from .utils import (
+    BeamSearchInstance,
+    BeamSearchOutput,
+    BeamSearchSequence,
+    create_sort_beams_key_function,
+)
+
+logger = init_logger(__name__)
+
+
+class BeamSearchOfflineMixin(ABC):
+    """Offline inference for beam search"""
+
+    renderer: BaseRenderer
+
+    def beam_search(
+        self,
+        prompts: list[TokensPrompt | TextPrompt],
+        params: BeamSearchParams,
+        lora_request: list[LoRARequest] | LoRARequest | None = None,
+        use_tqdm: bool = False,
+        concurrency_limit: int | None = None,
+    ) -> list[BeamSearchOutput]:
+        """
+        Generate sequences using beam search.
+
+        Args:
+            prompts: A list of prompts. Each prompt can be a string or a list
+                of token IDs.
+            params: The beam search parameters.
+            lora_request: LoRA request to use for generation, if any.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+            concurrency_limit: The maximum number of concurrent requests.
+                If None, the number of concurrent requests is unlimited.
+
+        Returns:
+            One `BeamSearchOutput` per preprocessed prompt, each holding up
+            to `beam_width` sequences ordered by descending sort key.
+
+        Raises:
+            NotImplementedError: If a prompt is an embeddings prompt.
+            ValueError: If `concurrency_limit` is not a positive integer.
+        """
+        # TODO: how does beam search work together with length penalty,
+        # frequency, penalty, and stopping criteria, etc.?
+        if concurrency_limit is not None and concurrency_limit <= 0:
+            # A zero step makes the batching `range()` below raise, and a
+            # negative step silently skips generation altogether.
+            raise ValueError("concurrency_limit must be a positive integer.")
+
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        temperature = params.temperature
+        ignore_eos = params.ignore_eos
+        length_penalty = params.length_penalty
+
+        tokenizer = self.renderer.get_tokenizer()
+        eos_token_id = tokenizer.eos_token_id
+        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
+
+        engine_inputs = self._preprocess_cmpl(prompts)
+        lora_requests = self._lora_request_to_seq(lora_request, len(engine_inputs))
+
+        if use_tqdm and concurrency_limit is not None:
+            logger.warning(
+                "Progress bar is not supported when using concurrency_limit. "
+                "Disabling progress bar."
+            )
+            use_tqdm = False
+
+        if concurrency_limit is None:
+            # Run every prompt in one batch. Clamp to 1 so that an empty
+            # prompt list does not give the batching `range()` a zero step.
+            concurrency_limit = max(len(engine_inputs), 1)
+
+        # generate 2 * beam_width candidates at each step
+        # following the huggingface transformers implementation
+        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534  # noqa
+        sampling_params = SamplingParams(
+            logprobs=2 * beam_width,
+            max_tokens=1,
+            temperature=temperature,
+            skip_clone=True,  # Internal beam search, safe to skip clone
+        )
+        instances: list[BeamSearchInstance] = []
+
+        for lora_req, prompt in zip(lora_requests, engine_inputs):
+            if prompt["type"] == "embeds":
+                raise NotImplementedError(
+                    "Embedding prompt not supported for beam search"
+                )
+
+            instances.append(
+                BeamSearchInstance(
+                    prompt,
+                    lora_request=lora_req,
+                    logprobs=None,
+                ),
+            )
+
+        for prompt_start in range(0, len(instances), concurrency_limit):
+            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
+
+            token_iter = range(max_tokens)
+            if use_tqdm:
+                token_iter = tqdm(
+                    token_iter, desc="Beam search", unit="token", unit_scale=False
+                )
+                logger.warning(
+                    "The progress bar shows the upper bound on token steps and "
+                    "may finish early due to stopping conditions. It does not "
+                    "reflect instance-level progress."
+                )
+            for _ in token_iter:
+                # Flatten the live beams of every instance in the batch.
+                all_beams: list[BeamSearchSequence] = list(
+                    itertools.chain.from_iterable(
+                        instance.beams for instance in instances_batch
+                    )
+                )
+                pos = [0] + list(
+                    itertools.accumulate(
+                        len(instance.beams) for instance in instances_batch
+                    )
+                )
+                instance_start_and_end: list[tuple[int, int]] = list(
+                    zip(pos[:-1], pos[1:])
+                )
+
+                if len(all_beams) == 0:
+                    break
+
+                # only runs for one step
+                # we don't need to use tqdm here
+                output = self._render_and_run_requests(
+                    prompts=(beam.get_prompt() for beam in all_beams),
+                    params=self._params_to_seq(sampling_params, len(all_beams)),
+                    output_type=RequestOutput,
+                    lora_requests=[beam.lora_request for beam in all_beams],
+                    use_tqdm=False,
+                )
+
+                for (start, end), instance in zip(
+                    instance_start_and_end, instances_batch
+                ):
+                    instance_new_beams = []
+                    for i in range(start, end):
+                        current_beam = all_beams[i]
+                        result = output[i]
+
+                        if result.outputs[0].logprobs is not None:
+                            # if `result.outputs[0].logprobs` is None, it means
+                            # the sequence is completed because of the
+                            # max-model-len or abortion. we don't need to add
+                            # it to the new beams.
+                            logprobs = result.outputs[0].logprobs[0]
+                            for token_id, logprob_obj in logprobs.items():
+                                new_beam = BeamSearchSequence(
+                                    current_beam.orig_prompt,
+                                    tokens=current_beam.tokens + [token_id],
+                                    logprobs=current_beam.logprobs + [logprobs],
+                                    lora_request=current_beam.lora_request,
+                                    cum_logprob=current_beam.cum_logprob
+                                    + logprob_obj.logprob,
+                                )
+
+                                if token_id == eos_token_id and not ignore_eos:
+                                    instance.completed.append(new_beam)
+                                else:
+                                    instance_new_beams.append(new_beam)
+                    sorted_beams = sorted(
+                        instance_new_beams, key=sort_beams_key, reverse=True
+                    )
+                    instance.beams = sorted_beams[:beam_width]
+
+        outputs = []
+        for instance in instances:
+            instance.completed.extend(instance.beams)
+            sorted_completed = sorted(
+                instance.completed, key=sort_beams_key, reverse=True
+            )
+            best_beams = sorted_completed[:beam_width]
+
+            for beam in best_beams:
+                beam.text = tokenizer.decode(beam.tokens)
+
+            outputs.append(BeamSearchOutput(sequences=best_beams))
+
+        return outputs
+
+    # Hooks below are provided by the class this mixin is combined with.
+    @abstractmethod
+    def _preprocess_cmpl(
+        self,
+        prompts: Sequence[PromptType],
+        tokenization_kwargs: dict[str, Any] | None = None,
+        mm_processor_kwargs: dict[str, Any] | None = None,
+    ) -> Sequence[EngineInput]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _lora_request_to_seq(
+        self,
+        lora_request: LoRARequest | None | Sequence[LoRARequest | None],
+        num_requests: int,
+    ) -> Sequence[LoRARequest | None]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _params_to_seq(
+        self,
+        params: SamplingParams,
+        num_requests: int,
+    ) -> Sequence[SamplingParams]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def _render_and_run_requests(
+        self,
+        prompts: Iterable[EngineInput],
+        params: Sequence[SamplingParams],
+        output_type: type[RequestOutput],
+        *,
+        lora_requests: Sequence[LoRARequest | None] | None = None,
+        priorities: Sequence[int] | None = None,
+        use_tqdm: bool | Callable[..., tqdm] = True,
+    ):
+        raise NotImplementedError
diff --git
a/vllm/entrypoints/generate/beam_search/online.py b/vllm/entrypoints/generate/beam_search/online.py
new file mode 100644
index 00000000000..1daef9529be
--- /dev/null
+++ b/vllm/entrypoints/generate/beam_search/online.py
@@ -0,0 +1,267 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import asyncio
+from abc import ABC
+from collections.abc import AsyncGenerator, Mapping
+
+import numpy as np
+
+from vllm import CompletionOutput, RequestOutput
+from vllm.engine.protocol import EngineClient
+from vllm.inputs import EngineInput
+from vllm.lora.request import LoRARequest
+from vllm.renderers import BaseRenderer
+from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.utils import random_uuid
+from vllm.utils.async_utils import collect_from_async_generator
+
+from .utils import BeamSearchSequence, create_sort_beams_key_function
+
+
+class BeamSearchOnlineMixin(ABC):
+    """online serving for beam search"""
+
+    renderer: BaseRenderer
+    engine_client: EngineClient
+
+    async def beam_search(
+        self,
+        prompt: EngineInput,
+        request_id: str,
+        params: BeamSearchParams,
+        lora_request: LoRARequest | None = None,
+        trace_headers: Mapping[str, str] | None = None,
+    ) -> AsyncGenerator[RequestOutput, None]:
+        """
+        Run beam search for one prompt and yield a single final output.
+
+        Every step sends one single-token request per live beam through
+        `engine_client` and keeps the `beam_width` best non-EOS
+        continuations as the beams of the next step.
+
+        Args:
+            prompt: The engine input to generate from.
+            request_id: ID set on the yielded output. Per-step engine
+                requests use IDs derived from it.
+            params: The beam search parameters.
+            lora_request: LoRA request to use for generation, if any.
+            trace_headers: Tracing headers forwarded to the engine.
+
+        Yields:
+            One finished `RequestOutput` holding up to `beam_width`
+            completions, or one output whose `finish_reason` is "error"
+            if any per-step request finished with an error.
+
+        Raises:
+            NotImplementedError: If `prompt` is an embeddings prompt.
+        """
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        ignore_eos = params.ignore_eos
+        temperature = params.temperature
+        length_penalty = params.length_penalty
+        include_stop_str_in_output = params.include_stop_str_in_output
+
+        tokenizer = self.renderer.get_tokenizer()
+        eos_token_id = tokenizer.eos_token_id
+        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
+
+        if prompt["type"] == "embeds":
+            raise NotImplementedError("Embedding prompt not supported for beam search")
+
+        # Extract prompt tokens and text based on model type
+        decoder_prompt = (
+            prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
+        )
+        prompt_text = decoder_prompt.get("prompt")
+        prompt_token_ids = decoder_prompt["prompt_token_ids"]
+
+        tokenized_length = len(prompt_token_ids)
+
+        logprobs_num = 2 * beam_width
+        sampling_params = SamplingParams(
+            logprobs=logprobs_num,
+            max_tokens=1,
+            temperature=temperature,
+        )
+        all_beams = [
+            BeamSearchSequence(
+                orig_prompt=prompt,
+                tokens=prompt_token_ids,
+                cum_logprob=0,
+                logprobs=[],
+                lora_request=lora_request,
+            )
+        ]
+        completed = []
+
+        for _ in range(max_tokens):
+            tasks = []
+            request_id_batch = f"{request_id}-{random_uuid()}"
+
+            for i, beam in enumerate(all_beams):
+                prompt_item = beam.get_prompt()
+                lora_request_item = beam.lora_request
+                request_id_item = f"{request_id_batch}-beam-{i}"
+                task = asyncio.create_task(
+                    collect_from_async_generator(
+                        self.engine_client.generate(
+                            prompt_item,
+                            sampling_params,
+                            request_id_item,
+                            lora_request=lora_request_item,
+                            trace_headers=trace_headers,
+                        )
+                    )
+                )
+                tasks.append(task)
+
+            output = [x[0] for x in await asyncio.gather(*tasks)]
+
+            new_beams = []
+            # Store all new tokens generated by beam
+            all_beams_token_id = []
+            # Store the cumulative probability of all tokens
+            # generated by beam search
+            all_beams_logprob = []
+            # Store, for every candidate, the index of the beam that produced
+            # it. A beam contributes no candidates when its logprobs are None
+            # and may contribute a number other than `logprobs_num`, so the
+            # owner cannot be derived as `idx // logprobs_num`.
+            all_beams_owner = []
+            # Iterate through all beam inference results
+            for i, result in enumerate(output):
+                current_beam = all_beams[i]
+
+                # check for error finish reason and abort beam search
+                if result.outputs[0].finish_reason == "error":
+                    # yield error output and terminate beam search
+                    yield RequestOutput(
+                        request_id=request_id,
+                        prompt=prompt_text,
+                        outputs=[
+                            CompletionOutput(
+                                index=0,
+                                text="",
+                                token_ids=[],
+                                cumulative_logprob=None,
+                                logprobs=None,
+                                finish_reason="error",
+                            )
+                        ],
+                        finished=True,
+                        prompt_token_ids=prompt_token_ids,
+                        prompt_logprobs=None,
+                    )
+                    return
+
+                if result.outputs[0].logprobs is not None:
+                    logprobs = result.outputs[0].logprobs[0]
+                    all_beams_token_id.extend(logprobs.keys())
+                    all_beams_logprob.extend(
+                        current_beam.cum_logprob + obj.logprob
+                        for obj in logprobs.values()
+                    )
+                    all_beams_owner.extend([i] * len(logprobs))
+
+            if not all_beams_token_id:
+                # No beam produced a continuation, so nothing can be
+                # extended. Stop here and rank what has been collected.
+                break
+
+            # Handle the token for the end of sentence (EOS)
+            all_beams_token_id = np.array(all_beams_token_id)
+            all_beams_logprob = np.array(all_beams_logprob)
+
+            if not ignore_eos:
+                # Get the index position of eos token in all generated results
+                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
+                for idx in eos_idx:
+                    owner = all_beams_owner[idx]
+                    current_beam = all_beams[owner]
+                    result = output[owner]
+                    assert result.outputs[0].logprobs is not None
+                    logprobs_entry = result.outputs[0].logprobs[0]
+                    completed.append(
+                        BeamSearchSequence(
+                            orig_prompt=prompt,
+                            tokens=current_beam.tokens + [eos_token_id]
+                            if include_stop_str_in_output
+                            else current_beam.tokens,
+                            logprobs=current_beam.logprobs + [logprobs_entry],
+                            cum_logprob=float(all_beams_logprob[idx]),
+                            finish_reason="stop",
+                            stop_reason=eos_token_id,
+                        )
+                    )
+                # After processing, set the log probability of the eos condition
+                # to negative infinity.
+                all_beams_logprob[eos_idx] = -np.inf
+
+            # Processing non-EOS tokens
+            # Get indices of the top beam_width probabilities. `argpartition`
+            # needs `kth` to be smaller than the array length, so take every
+            # candidate when there are at most beam_width of them.
+            if len(all_beams_logprob) > beam_width:
+                topn_idx = np.argpartition(
+                    np.negative(all_beams_logprob), beam_width
+                )[:beam_width]
+            else:
+                topn_idx = np.arange(len(all_beams_logprob))
+
+            for idx in topn_idx:
+                token_id = int(all_beams_token_id[idx])
+                if not ignore_eos and token_id == eos_token_id:
+                    # Already recorded in `completed`; never extend past EOS.
+                    continue
+                owner = all_beams_owner[idx]
+                current_beam = all_beams[owner]
+                result = output[owner]
+                assert result.outputs[0].logprobs is not None
+                logprobs_entry = result.outputs[0].logprobs[0]
+                new_beams.append(
+                    BeamSearchSequence(
+                        orig_prompt=prompt,
+                        tokens=current_beam.tokens + [token_id],
+                        logprobs=current_beam.logprobs + [logprobs_entry],
+                        lora_request=current_beam.lora_request,
+                        cum_logprob=float(all_beams_logprob[idx]),
+                    )
+                )
+
+            all_beams = new_beams
+
+        completed.extend(all_beams)
+        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
+        best_beams = sorted_completed[:beam_width]
+
+        for beam in best_beams:
+            if beam.tokens[-1] == eos_token_id and not ignore_eos:
+                # Skip the eos token in the text.
+                tokens = beam.tokens[tokenized_length:-1]
+            else:
+                tokens = beam.tokens[tokenized_length:]
+            beam.text = tokenizer.decode(tokens)
+
+        yield RequestOutput(
+            request_id=request_id,
+            prompt=prompt_text,
+            outputs=[
+                CompletionOutput(
+                    text=beam.text,  # type: ignore
+                    cumulative_logprob=beam.cum_logprob,
+                    token_ids=beam.tokens[tokenized_length:],
+                    index=i,
+                    logprobs=beam.logprobs,
+                    finish_reason=beam.finish_reason
+                    if beam.finish_reason is not None
+                    else "length",
+                    stop_reason=beam.stop_reason,
+                )
+                for (i, beam) in enumerate(best_beams)
+            ],
+            finished=True,
+            prompt_token_ids=prompt_token_ids,
+            prompt_logprobs=None,
+        )
diff --git a/vllm/beam_search.py b/vllm/entrypoints/generate/beam_search/utils.py
similarity index 100%
rename from vllm/beam_search.py
rename to vllm/entrypoints/generate/beam_search/utils.py
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 9c342fc4808..0e29f38c0bf 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import itertools
 from collections.abc import Callable, Iterable, Sequence
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
@@ -12,12 +11,6 @@ from pydantic import ValidationError
 from tqdm.auto import tqdm
 from typing_extensions import TypeVar, overload
 
-from vllm.beam_search import (
-    BeamSearchInstance,
-    BeamSearchOutput,
-    BeamSearchSequence,
-    create_sort_beams_key_function,
-)
 from vllm.config import (
     AttentionConfig,
     CompilationConfig,
@@ -45,13 +38,12 @@ from vllm.entrypoints.chat_utils import (
     ChatTemplateContentFormatOption,
     load_chat_template,
 )
+from vllm.entrypoints.generate.beam_search.offline import BeamSearchOfflineMixin
 from vllm.entrypoints.pooling.offline import PoolingOfflineMixin
 from vllm.entrypoints.utils import log_non_default_args
 from vllm.inputs import (
     EngineInput,
     PromptType,
-    TextPrompt,
-    TokensPrompt,
 )
 from
vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -65,7 +57,7 @@ from vllm.renderers.inputs.preprocess import ( parse_model_prompt, prompt_to_seq, ) -from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.tokenizers import TokenizerLike from vllm.usage.usage_lib import UsageContext from vllm.utils.counter import Counter @@ -89,7 +81,7 @@ _P = TypeVar("_P", bound=SamplingParams | PoolingParams | None) _R = TypeVar("_R", default=Any) -class LLM(PoolingOfflineMixin): +class LLM(BeamSearchOfflineMixin, PoolingOfflineMixin): """An LLM for generating texts from given prompts and sampling parameters. This class includes a tokenizer, a language model (possibly distributed @@ -675,163 +667,6 @@ class LLM(PoolingOfflineMixin): """ return self.llm_engine.apply_model(func) - def beam_search( - self, - prompts: list[TokensPrompt | TextPrompt], - params: BeamSearchParams, - lora_request: list[LoRARequest] | LoRARequest | None = None, - use_tqdm: bool = False, - concurrency_limit: int | None = None, - ) -> list[BeamSearchOutput]: - """ - Generate sequences using beam search. - - Args: - prompts: A list of prompts. Each prompt can be a string or a list - of token IDs. - params: The beam search parameters. - lora_request: LoRA request to use for generation, if any. - use_tqdm: Whether to use tqdm to display the progress bar. - concurrency_limit: The maximum number of concurrent requests. - If None, the number of concurrent requests is unlimited. - """ - # TODO: how does beam search work together with length penalty, - # frequency, penalty, and stopping criteria, etc.? 
- beam_width = params.beam_width - max_tokens = params.max_tokens - temperature = params.temperature - ignore_eos = params.ignore_eos - length_penalty = params.length_penalty - - tokenizer = self.renderer.get_tokenizer() - eos_token_id = tokenizer.eos_token_id - sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) - - engine_inputs = self._preprocess_cmpl(prompts) - lora_requests = self._lora_request_to_seq(lora_request, len(engine_inputs)) - - if use_tqdm and concurrency_limit is not None: - logger.warning( - "Progress bar is not supported when using concurrency_limit. " - "Disabling progress bar." - ) - use_tqdm = False - - if concurrency_limit is None: - concurrency_limit = len(engine_inputs) - - # generate 2 * beam_width candidates at each step - # following the huggingface transformers implementation - # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa - sampling_params = SamplingParams( - logprobs=2 * beam_width, - max_tokens=1, - temperature=temperature, - skip_clone=True, # Internal beam search, safe to skip clone - ) - instances: list[BeamSearchInstance] = [] - - for lora_req, prompt in zip(lora_requests, engine_inputs): - if prompt["type"] == "embeds": - raise NotImplementedError( - "Embedding prompt not supported for beam search" - ) - - instances.append( - BeamSearchInstance( - prompt, - lora_request=lora_req, - logprobs=None, - ), - ) - - for prompt_start in range(0, len(instances), concurrency_limit): - instances_batch = instances[prompt_start : prompt_start + concurrency_limit] - - token_iter = range(max_tokens) - if use_tqdm: - token_iter = tqdm( - token_iter, desc="Beam search", unit="token", unit_scale=False - ) - logger.warning( - "The progress bar shows the upper bound on token steps and " - "may finish early due to stopping conditions. It does not " - "reflect instance-level progress." 
- ) - for _ in token_iter: - all_beams: list[BeamSearchSequence] = list( - sum((instance.beams for instance in instances_batch), []) - ) - pos = [0] + list( - itertools.accumulate( - len(instance.beams) for instance in instances_batch - ) - ) - instance_start_and_end: list[tuple[int, int]] = list( - zip(pos[:-1], pos[1:]) - ) - - if len(all_beams) == 0: - break - - # only runs for one step - # we don't need to use tqdm here - output = self._render_and_run_requests( - prompts=(beam.get_prompt() for beam in all_beams), - params=self._params_to_seq(sampling_params, len(all_beams)), - output_type=RequestOutput, - lora_requests=[beam.lora_request for beam in all_beams], - use_tqdm=False, - ) - - for (start, end), instance in zip( - instance_start_and_end, instances_batch - ): - instance_new_beams = [] - for i in range(start, end): - current_beam = all_beams[i] - result = output[i] - - if result.outputs[0].logprobs is not None: - # if `result.outputs[0].logprobs` is None, it means - # the sequence is completed because of the - # max-model-len or abortion. we don't need to add - # it to the new beams. 
- logprobs = result.outputs[0].logprobs[0] - for token_id, logprob_obj in logprobs.items(): - new_beam = BeamSearchSequence( - current_beam.orig_prompt, - tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + [logprobs], - lora_request=current_beam.lora_request, - cum_logprob=current_beam.cum_logprob - + logprob_obj.logprob, - ) - - if token_id == eos_token_id and not ignore_eos: - instance.completed.append(new_beam) - else: - instance_new_beams.append(new_beam) - sorted_beams = sorted( - instance_new_beams, key=sort_beams_key, reverse=True - ) - instance.beams = sorted_beams[:beam_width] - - outputs = [] - for instance in instances: - instance.completed.extend(instance.beams) - sorted_completed = sorted( - instance.completed, key=sort_beams_key, reverse=True - ) - best_beams = sorted_completed[:beam_width] - - for beam in best_beams: - beam.text = tokenizer.decode(beam.tokens) - - outputs.append(BeamSearchOutput(sequences=best_beams)) - - return outputs - def _preprocess_cmpl( self, prompts: Sequence[PromptType], diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index 6152f915cc3..ff67575fcc6 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -1,25 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio import contextlib import json import time -from collections.abc import AsyncGenerator, Awaitable, Mapping +from collections.abc import Awaitable, Mapping from dataclasses import dataclass, field from http import HTTPStatus from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar -import numpy as np from fastapi import Request from openai.types.responses import ToolChoiceFunction from pydantic import ConfigDict, TypeAdapter, ValidationError from starlette.datastructures import Headers import vllm.envs as envs -from vllm.beam_search import 
BeamSearchSequence, create_sort_beams_key_function from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption +from vllm.entrypoints.generate.beam_search.online import BeamSearchOnlineMixin from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.chat_completion.protocol import ( BatchChatCompletionRequest, @@ -56,7 +54,6 @@ from vllm.inputs import EngineInput, PromptType from vllm.logger import init_logger from vllm.logprobs import Logprob, PromptLogprobs from vllm.lora.request import LoRARequest -from vllm.outputs import CompletionOutput, RequestOutput from vllm.renderers import ChatParams, TokenizeParams from vllm.renderers.inputs.preprocess import ( extract_prompt_components, @@ -71,7 +68,6 @@ from vllm.tracing import ( log_tracing_disabled_warning, ) from vllm.utils import random_uuid -from vllm.utils.async_utils import collect_from_async_generator from vllm.utils.mistral import is_mistral_tool_parser logger = init_logger(__name__) @@ -133,7 +129,7 @@ class ServeContext(Generic[RequestT]): model_config = ConfigDict(arbitrary_types_allowed=True) -class OpenAIServing: +class OpenAIServing(BeamSearchOnlineMixin): request_id_prefix: ClassVar[str] = """ A short string prepended to every request’s ID. """ @@ -174,205 +170,6 @@ class OpenAIServing: # Never fail server startup over the fingerprint. 
self.system_fingerprint = None - async def beam_search( - self, - prompt: EngineInput, - request_id: str, - params: BeamSearchParams, - lora_request: LoRARequest | None = None, - trace_headers: Mapping[str, str] | None = None, - ) -> AsyncGenerator[RequestOutput, None]: - beam_width = params.beam_width - max_tokens = params.max_tokens - ignore_eos = params.ignore_eos - temperature = params.temperature - length_penalty = params.length_penalty - include_stop_str_in_output = params.include_stop_str_in_output - - tokenizer = self.renderer.get_tokenizer() - eos_token_id = tokenizer.eos_token_id - sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) - - if prompt["type"] == "embeds": - raise NotImplementedError("Embedding prompt not supported for beam search") - - # Extract prompt tokens and text based on model type - decoder_prompt = ( - prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"] - ) - prompt_text = decoder_prompt.get("prompt") - prompt_token_ids = decoder_prompt["prompt_token_ids"] - - tokenized_length = len(prompt_token_ids) - - logprobs_num = 2 * beam_width - sampling_params = SamplingParams( - logprobs=logprobs_num, - max_tokens=1, - temperature=temperature, - ) - all_beams = [ - BeamSearchSequence( - orig_prompt=prompt, - tokens=prompt_token_ids, - cum_logprob=0, - logprobs=[], - lora_request=lora_request, - ) - ] - completed = [] - - for _ in range(max_tokens): - tasks = [] - request_id_batch = f"{request_id}-{random_uuid()}" - - for i, beam in enumerate(all_beams): - prompt_item = beam.get_prompt() - lora_request_item = beam.lora_request - request_id_item = f"{request_id_batch}-beam-{i}" - task = asyncio.create_task( - collect_from_async_generator( - self.engine_client.generate( - prompt_item, - sampling_params, - request_id_item, - lora_request=lora_request_item, - trace_headers=trace_headers, - ) - ) - ) - tasks.append(task) - - output = [x[0] for x in await asyncio.gather(*tasks)] - - new_beams = [] - # Store 
all new tokens generated by beam - all_beams_token_id = [] - # Store the cumulative probability of all tokens - # generated by beam search - all_beams_logprob = [] - # Iterate through all beam inference results - for i, result in enumerate(output): - current_beam = all_beams[i] - - # check for error finish reason and abort beam search - if result.outputs[0].finish_reason == "error": - # yield error output and terminate beam search - yield RequestOutput( - request_id=request_id, - prompt=prompt_text, - outputs=[ - CompletionOutput( - index=0, - text="", - token_ids=[], - cumulative_logprob=None, - logprobs=None, - finish_reason="error", - ) - ], - finished=True, - prompt_token_ids=prompt_token_ids, - prompt_logprobs=None, - ) - return - - if result.outputs[0].logprobs is not None: - logprobs = result.outputs[0].logprobs[0] - all_beams_token_id.extend(list(logprobs.keys())) - all_beams_logprob.extend( - [ - current_beam.cum_logprob + obj.logprob - for obj in logprobs.values() - ] - ) - - # Handle the token for the end of sentence (EOS) - all_beams_token_id = np.array(all_beams_token_id) - all_beams_logprob = np.array(all_beams_logprob) - - if not ignore_eos: - # Get the index position of eos token in all generated results - eos_idx = np.where(all_beams_token_id == eos_token_id)[0] - for idx in eos_idx: - current_beam = all_beams[idx // logprobs_num] - result = output[idx // logprobs_num] - assert result.outputs[0].logprobs is not None - logprobs_entry = result.outputs[0].logprobs[0] - completed.append( - BeamSearchSequence( - orig_prompt=prompt, - tokens=current_beam.tokens + [eos_token_id] - if include_stop_str_in_output - else current_beam.tokens, - logprobs=current_beam.logprobs + [logprobs_entry], - cum_logprob=float(all_beams_logprob[idx]), - finish_reason="stop", - stop_reason=eos_token_id, - ) - ) - # After processing, set the log probability of the eos condition - # to negative infinity. 
- all_beams_logprob[eos_idx] = -np.inf - - # Processing non-EOS tokens - # Get indices of the top beam_width probabilities - topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[ - :beam_width - ] - - for idx in topn_idx: - current_beam = all_beams[idx // logprobs_num] - result = output[idx // logprobs_num] - token_id = int(all_beams_token_id[idx]) - assert result.outputs[0].logprobs is not None - logprobs_entry = result.outputs[0].logprobs[0] - new_beams.append( - BeamSearchSequence( - orig_prompt=prompt, - tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + [logprobs_entry], - lora_request=current_beam.lora_request, - cum_logprob=float(all_beams_logprob[idx]), - ) - ) - - all_beams = new_beams - - completed.extend(all_beams) - sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) - best_beams = sorted_completed[:beam_width] - - for beam in best_beams: - if beam.tokens[-1] == eos_token_id and not ignore_eos: - # Skip the eos token in the text. - tokens = beam.tokens[tokenized_length:-1] - else: - tokens = beam.tokens[tokenized_length:] - beam.text = tokenizer.decode(tokens) - - yield RequestOutput( - request_id=request_id, - prompt=prompt_text, - outputs=[ - CompletionOutput( - text=beam.text, # type: ignore - cumulative_logprob=beam.cum_logprob, - token_ids=beam.tokens[tokenized_length:], - index=i, - logprobs=beam.logprobs, - finish_reason=beam.finish_reason - if beam.finish_reason is not None - else "length", - stop_reason=beam.stop_reason, - ) - for (i, beam) in enumerate(best_beams) - ], - finished=True, - prompt_token_ids=prompt_token_ids, - prompt_logprobs=None, - ) - @staticmethod def create_error_response( message: str | Exception,