[Frontend] Consolidate beam search by BeamSearchMixin. (#42946)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
wang.yuqi
2026-05-19 15:37:40 +08:00
committed by GitHub
parent 257af77bc2
commit 301d986473
9 changed files with 459 additions and 375 deletions
+1 -1
View File
@@ -2003,7 +2003,7 @@ steps:
- vllm/model_executor/layers
- vllm/sampling_metadata.py
- vllm/v1/sample/
- vllm/beam_search.py
- vllm/entrypoints/generate/beam_search/
- tests/samplers
- tests/conftest.py
- vllm/_aiter_ops.py
+1
View File
@@ -11,6 +11,7 @@ steps:
- vllm/sampling_metadata.py
- tests/samplers
- tests/conftest.py
- vllm/entrypoints/generate/beam_search
commands:
# VLLM_USE_FLASHINFER_SAMPLER defaults to 1 now, so we need to pin both
# values explicitly to still cover the PyTorch-native (Triton) path.
@@ -0,0 +1,226 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
from typing import Any
from tqdm import tqdm
from vllm import PromptType, RequestOutput, TextPrompt, TokensPrompt
from vllm.inputs import EngineInput
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.renderers import BaseRenderer
from vllm.sampling_params import BeamSearchParams, SamplingParams
from .utils import (
BeamSearchInstance,
BeamSearchOutput,
BeamSearchSequence,
create_sort_beams_key_function,
)
logger = init_logger(__name__)
class BeamSearchOfflineMixin(ABC):
"""Offline inference for beam search"""
renderer: BaseRenderer
def beam_search(
self,
prompts: list[TokensPrompt | TextPrompt],
params: BeamSearchParams,
lora_request: list[LoRARequest] | LoRARequest | None = None,
use_tqdm: bool = False,
concurrency_limit: int | None = None,
) -> list[BeamSearchOutput]:
"""
Generate sequences using beam search.
Args:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
params: The beam search parameters.
lora_request: LoRA request to use for generation, if any.
use_tqdm: Whether to use tqdm to display the progress bar.
concurrency_limit: The maximum number of concurrent requests.
If None, the number of concurrent requests is unlimited.
"""
# TODO: how does beam search work together with length penalty,
# frequency, penalty, and stopping criteria, etc.?
beam_width = params.beam_width
max_tokens = params.max_tokens
temperature = params.temperature
ignore_eos = params.ignore_eos
length_penalty = params.length_penalty
tokenizer = self.renderer.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
engine_inputs = self._preprocess_cmpl(prompts)
lora_requests = self._lora_request_to_seq(lora_request, len(engine_inputs))
if use_tqdm and concurrency_limit is not None:
logger.warning(
"Progress bar is not supported when using concurrency_limit. "
"Disabling progress bar."
)
use_tqdm = False
if concurrency_limit is None:
concurrency_limit = len(engine_inputs)
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
sampling_params = SamplingParams(
logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature,
skip_clone=True, # Internal beam search, safe to skip clone
)
instances: list[BeamSearchInstance] = []
for lora_req, prompt in zip(lora_requests, engine_inputs):
if prompt["type"] == "embeds":
raise NotImplementedError(
"Embedding prompt not supported for beam search"
)
instances.append(
BeamSearchInstance(
prompt,
lora_request=lora_req,
logprobs=None,
),
)
for prompt_start in range(0, len(instances), concurrency_limit):
instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
token_iter = range(max_tokens)
if use_tqdm:
token_iter = tqdm(
token_iter, desc="Beam search", unit="token", unit_scale=False
)
logger.warning(
"The progress bar shows the upper bound on token steps and "
"may finish early due to stopping conditions. It does not "
"reflect instance-level progress."
)
for _ in token_iter:
all_beams: list[BeamSearchSequence] = list(
sum((instance.beams for instance in instances_batch), [])
)
pos = [0] + list(
itertools.accumulate(
len(instance.beams) for instance in instances_batch
)
)
instance_start_and_end: list[tuple[int, int]] = list(
zip(pos[:-1], pos[1:])
)
if len(all_beams) == 0:
break
# only runs for one step
# we don't need to use tqdm here
output = self._render_and_run_requests(
prompts=(beam.get_prompt() for beam in all_beams),
params=self._params_to_seq(sampling_params, len(all_beams)),
output_type=RequestOutput,
lora_requests=[beam.lora_request for beam in all_beams],
use_tqdm=False,
)
for (start, end), instance in zip(
instance_start_and_end, instances_batch
):
instance_new_beams = []
for i in range(start, end):
current_beam = all_beams[i]
result = output[i]
if result.outputs[0].logprobs is not None:
# if `result.outputs[0].logprobs` is None, it means
# the sequence is completed because of the
# max-model-len or abortion. we don't need to add
# it to the new beams.
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
current_beam.orig_prompt,
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob
+ logprob_obj.logprob,
)
if token_id == eos_token_id and not ignore_eos:
instance.completed.append(new_beam)
else:
instance_new_beams.append(new_beam)
sorted_beams = sorted(
instance_new_beams, key=sort_beams_key, reverse=True
)
instance.beams = sorted_beams[:beam_width]
outputs = []
for instance in instances:
instance.completed.extend(instance.beams)
sorted_completed = sorted(
instance.completed, key=sort_beams_key, reverse=True
)
best_beams = sorted_completed[:beam_width]
for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens)
outputs.append(BeamSearchOutput(sequences=best_beams))
return outputs
@abstractmethod
def _preprocess_cmpl(
self,
prompts: Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> Sequence[EngineInput]:
raise NotImplementedError
@abstractmethod
def _lora_request_to_seq(
self,
lora_request: LoRARequest | None | Sequence[LoRARequest | None],
num_requests: int,
) -> Sequence[LoRARequest | None]:
raise NotImplementedError
@abstractmethod
def _params_to_seq(
self,
params: SamplingParams,
num_requests: int,
) -> Sequence[SamplingParams]:
raise NotImplementedError
@abstractmethod
def _render_and_run_requests(
self,
prompts: Iterable[EngineInput],
params: Sequence[SamplingParams],
output_type: type[RequestOutput],
*,
lora_requests: Sequence[LoRARequest | None] | None = None,
priorities: Sequence[int] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
):
raise NotImplementedError
@@ -0,0 +1,225 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from abc import ABC
from collections.abc import AsyncGenerator, Mapping
import numpy as np
from vllm import CompletionOutput, RequestOutput
from vllm.engine.protocol import EngineClient
from vllm.inputs import EngineInput
from vllm.lora.request import LoRARequest
from vllm.renderers import BaseRenderer
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.utils import random_uuid
from vllm.utils.async_utils import collect_from_async_generator
from .utils import BeamSearchSequence, create_sort_beams_key_function
class BeamSearchOnlineMixin(ABC):
"""online serving for beam search"""
renderer: BaseRenderer
engine_client: EngineClient
async def beam_search(
self,
prompt: EngineInput,
request_id: str,
params: BeamSearchParams,
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
) -> AsyncGenerator[RequestOutput, None]:
beam_width = params.beam_width
max_tokens = params.max_tokens
ignore_eos = params.ignore_eos
temperature = params.temperature
length_penalty = params.length_penalty
include_stop_str_in_output = params.include_stop_str_in_output
tokenizer = self.renderer.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
if prompt["type"] == "embeds":
raise NotImplementedError("Embedding prompt not supported for beam search")
# Extract prompt tokens and text based on model type
decoder_prompt = (
prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
)
prompt_text = decoder_prompt.get("prompt")
prompt_token_ids = decoder_prompt["prompt_token_ids"]
tokenized_length = len(prompt_token_ids)
logprobs_num = 2 * beam_width
sampling_params = SamplingParams(
logprobs=logprobs_num,
max_tokens=1,
temperature=temperature,
)
all_beams = [
BeamSearchSequence(
orig_prompt=prompt,
tokens=prompt_token_ids,
cum_logprob=0,
logprobs=[],
lora_request=lora_request,
)
]
completed = []
for _ in range(max_tokens):
tasks = []
request_id_batch = f"{request_id}-{random_uuid()}"
for i, beam in enumerate(all_beams):
prompt_item = beam.get_prompt()
lora_request_item = beam.lora_request
request_id_item = f"{request_id_batch}-beam-{i}"
task = asyncio.create_task(
collect_from_async_generator(
self.engine_client.generate(
prompt_item,
sampling_params,
request_id_item,
lora_request=lora_request_item,
trace_headers=trace_headers,
)
)
)
tasks.append(task)
output = [x[0] for x in await asyncio.gather(*tasks)]
new_beams = []
# Store all new tokens generated by beam
all_beams_token_id = []
# Store the cumulative probability of all tokens
# generated by beam search
all_beams_logprob = []
# Iterate through all beam inference results
for i, result in enumerate(output):
current_beam = all_beams[i]
# check for error finish reason and abort beam search
if result.outputs[0].finish_reason == "error":
# yield error output and terminate beam search
yield RequestOutput(
request_id=request_id,
prompt=prompt_text,
outputs=[
CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
],
finished=True,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None,
)
return
if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
all_beams_token_id.extend(list(logprobs.keys()))
all_beams_logprob.extend(
[
current_beam.cum_logprob + obj.logprob
for obj in logprobs.values()
]
)
# Handle the token for the end of sentence (EOS)
all_beams_token_id = np.array(all_beams_token_id)
all_beams_logprob = np.array(all_beams_logprob)
if not ignore_eos:
# Get the index position of eos token in all generated results
eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
for idx in eos_idx:
current_beam = all_beams[idx // logprobs_num]
result = output[idx // logprobs_num]
assert result.outputs[0].logprobs is not None
logprobs_entry = result.outputs[0].logprobs[0]
completed.append(
BeamSearchSequence(
orig_prompt=prompt,
tokens=current_beam.tokens + [eos_token_id]
if include_stop_str_in_output
else current_beam.tokens,
logprobs=current_beam.logprobs + [logprobs_entry],
cum_logprob=float(all_beams_logprob[idx]),
finish_reason="stop",
stop_reason=eos_token_id,
)
)
# After processing, set the log probability of the eos condition
# to negative infinity.
all_beams_logprob[eos_idx] = -np.inf
# Processing non-EOS tokens
# Get indices of the top beam_width probabilities
topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[
:beam_width
]
for idx in topn_idx:
current_beam = all_beams[idx // logprobs_num]
result = output[idx // logprobs_num]
token_id = int(all_beams_token_id[idx])
assert result.outputs[0].logprobs is not None
logprobs_entry = result.outputs[0].logprobs[0]
new_beams.append(
BeamSearchSequence(
orig_prompt=prompt,
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs_entry],
lora_request=current_beam.lora_request,
cum_logprob=float(all_beams_logprob[idx]),
)
)
all_beams = new_beams
completed.extend(all_beams)
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
best_beams = sorted_completed[:beam_width]
for beam in best_beams:
if beam.tokens[-1] == eos_token_id and not ignore_eos:
# Skip the eos token in the text.
tokens = beam.tokens[tokenized_length:-1]
else:
tokens = beam.tokens[tokenized_length:]
beam.text = tokenizer.decode(tokens)
yield RequestOutput(
request_id=request_id,
prompt=prompt_text,
outputs=[
CompletionOutput(
text=beam.text, # type: ignore
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens[tokenized_length:],
index=i,
logprobs=beam.logprobs,
finish_reason=beam.finish_reason
if beam.finish_reason is not None
else "length",
stop_reason=beam.stop_reason,
)
for (i, beam) in enumerate(best_beams)
],
finished=True,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None,
)
+3 -168
View File
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Callable, Iterable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -12,12 +11,6 @@ from pydantic import ValidationError
from tqdm.auto import tqdm
from typing_extensions import TypeVar, overload
from vllm.beam_search import (
BeamSearchInstance,
BeamSearchOutput,
BeamSearchSequence,
create_sort_beams_key_function,
)
from vllm.config import (
AttentionConfig,
CompilationConfig,
@@ -45,13 +38,12 @@ from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
load_chat_template,
)
from vllm.entrypoints.generate.beam_search.offline import BeamSearchOfflineMixin
from vllm.entrypoints.pooling.offline import PoolingOfflineMixin
from vllm.entrypoints.utils import log_non_default_args
from vllm.inputs import (
EngineInput,
PromptType,
TextPrompt,
TokensPrompt,
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -65,7 +57,7 @@ from vllm.renderers.inputs.preprocess import (
parse_model_prompt,
prompt_to_seq,
)
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.usage.usage_lib import UsageContext
from vllm.utils.counter import Counter
@@ -89,7 +81,7 @@ _P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
_R = TypeVar("_R", default=Any)
class LLM(PoolingOfflineMixin):
class LLM(BeamSearchOfflineMixin, PoolingOfflineMixin):
"""An LLM for generating texts from given prompts and sampling parameters.
This class includes a tokenizer, a language model (possibly distributed
@@ -675,163 +667,6 @@ class LLM(PoolingOfflineMixin):
"""
return self.llm_engine.apply_model(func)
def beam_search(
self,
prompts: list[TokensPrompt | TextPrompt],
params: BeamSearchParams,
lora_request: list[LoRARequest] | LoRARequest | None = None,
use_tqdm: bool = False,
concurrency_limit: int | None = None,
) -> list[BeamSearchOutput]:
"""
Generate sequences using beam search.
Args:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
params: The beam search parameters.
lora_request: LoRA request to use for generation, if any.
use_tqdm: Whether to use tqdm to display the progress bar.
concurrency_limit: The maximum number of concurrent requests.
If None, the number of concurrent requests is unlimited.
"""
# TODO: how does beam search work together with length penalty,
# frequency, penalty, and stopping criteria, etc.?
beam_width = params.beam_width
max_tokens = params.max_tokens
temperature = params.temperature
ignore_eos = params.ignore_eos
length_penalty = params.length_penalty
tokenizer = self.renderer.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
engine_inputs = self._preprocess_cmpl(prompts)
lora_requests = self._lora_request_to_seq(lora_request, len(engine_inputs))
if use_tqdm and concurrency_limit is not None:
logger.warning(
"Progress bar is not supported when using concurrency_limit. "
"Disabling progress bar."
)
use_tqdm = False
if concurrency_limit is None:
concurrency_limit = len(engine_inputs)
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
sampling_params = SamplingParams(
logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature,
skip_clone=True, # Internal beam search, safe to skip clone
)
instances: list[BeamSearchInstance] = []
for lora_req, prompt in zip(lora_requests, engine_inputs):
if prompt["type"] == "embeds":
raise NotImplementedError(
"Embedding prompt not supported for beam search"
)
instances.append(
BeamSearchInstance(
prompt,
lora_request=lora_req,
logprobs=None,
),
)
for prompt_start in range(0, len(instances), concurrency_limit):
instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
token_iter = range(max_tokens)
if use_tqdm:
token_iter = tqdm(
token_iter, desc="Beam search", unit="token", unit_scale=False
)
logger.warning(
"The progress bar shows the upper bound on token steps and "
"may finish early due to stopping conditions. It does not "
"reflect instance-level progress."
)
for _ in token_iter:
all_beams: list[BeamSearchSequence] = list(
sum((instance.beams for instance in instances_batch), [])
)
pos = [0] + list(
itertools.accumulate(
len(instance.beams) for instance in instances_batch
)
)
instance_start_and_end: list[tuple[int, int]] = list(
zip(pos[:-1], pos[1:])
)
if len(all_beams) == 0:
break
# only runs for one step
# we don't need to use tqdm here
output = self._render_and_run_requests(
prompts=(beam.get_prompt() for beam in all_beams),
params=self._params_to_seq(sampling_params, len(all_beams)),
output_type=RequestOutput,
lora_requests=[beam.lora_request for beam in all_beams],
use_tqdm=False,
)
for (start, end), instance in zip(
instance_start_and_end, instances_batch
):
instance_new_beams = []
for i in range(start, end):
current_beam = all_beams[i]
result = output[i]
if result.outputs[0].logprobs is not None:
# if `result.outputs[0].logprobs` is None, it means
# the sequence is completed because of the
# max-model-len or abortion. we don't need to add
# it to the new beams.
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
current_beam.orig_prompt,
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob
+ logprob_obj.logprob,
)
if token_id == eos_token_id and not ignore_eos:
instance.completed.append(new_beam)
else:
instance_new_beams.append(new_beam)
sorted_beams = sorted(
instance_new_beams, key=sort_beams_key, reverse=True
)
instance.beams = sorted_beams[:beam_width]
outputs = []
for instance in instances:
instance.completed.extend(instance.beams)
sorted_completed = sorted(
instance.completed, key=sort_beams_key, reverse=True
)
best_beams = sorted_completed[:beam_width]
for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens)
outputs.append(BeamSearchOutput(sequences=best_beams))
return outputs
def _preprocess_cmpl(
self,
prompts: Sequence[PromptType],
+3 -206
View File
@@ -1,25 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import json
import time
from collections.abc import AsyncGenerator, Awaitable, Mapping
from collections.abc import Awaitable, Mapping
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
import numpy as np
from fastapi import Request
from openai.types.responses import ToolChoiceFunction
from pydantic import ConfigDict, TypeAdapter, ValidationError
from starlette.datastructures import Headers
import vllm.envs as envs
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.generate.beam_search.online import BeamSearchOnlineMixin
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import (
BatchChatCompletionRequest,
@@ -56,7 +54,6 @@ from vllm.inputs import EngineInput, PromptType
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.renderers import ChatParams, TokenizeParams
from vllm.renderers.inputs.preprocess import (
extract_prompt_components,
@@ -71,7 +68,6 @@ from vllm.tracing import (
log_tracing_disabled_warning,
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import collect_from_async_generator
from vllm.utils.mistral import is_mistral_tool_parser
logger = init_logger(__name__)
@@ -133,7 +129,7 @@ class ServeContext(Generic[RequestT]):
model_config = ConfigDict(arbitrary_types_allowed=True)
class OpenAIServing:
class OpenAIServing(BeamSearchOnlineMixin):
request_id_prefix: ClassVar[str] = """
A short string prepended to every requests ID.
"""
@@ -174,205 +170,6 @@ class OpenAIServing:
# Never fail server startup over the fingerprint.
self.system_fingerprint = None
async def beam_search(
self,
prompt: EngineInput,
request_id: str,
params: BeamSearchParams,
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
) -> AsyncGenerator[RequestOutput, None]:
beam_width = params.beam_width
max_tokens = params.max_tokens
ignore_eos = params.ignore_eos
temperature = params.temperature
length_penalty = params.length_penalty
include_stop_str_in_output = params.include_stop_str_in_output
tokenizer = self.renderer.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
if prompt["type"] == "embeds":
raise NotImplementedError("Embedding prompt not supported for beam search")
# Extract prompt tokens and text based on model type
decoder_prompt = (
prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
)
prompt_text = decoder_prompt.get("prompt")
prompt_token_ids = decoder_prompt["prompt_token_ids"]
tokenized_length = len(prompt_token_ids)
logprobs_num = 2 * beam_width
sampling_params = SamplingParams(
logprobs=logprobs_num,
max_tokens=1,
temperature=temperature,
)
all_beams = [
BeamSearchSequence(
orig_prompt=prompt,
tokens=prompt_token_ids,
cum_logprob=0,
logprobs=[],
lora_request=lora_request,
)
]
completed = []
for _ in range(max_tokens):
tasks = []
request_id_batch = f"{request_id}-{random_uuid()}"
for i, beam in enumerate(all_beams):
prompt_item = beam.get_prompt()
lora_request_item = beam.lora_request
request_id_item = f"{request_id_batch}-beam-{i}"
task = asyncio.create_task(
collect_from_async_generator(
self.engine_client.generate(
prompt_item,
sampling_params,
request_id_item,
lora_request=lora_request_item,
trace_headers=trace_headers,
)
)
)
tasks.append(task)
output = [x[0] for x in await asyncio.gather(*tasks)]
new_beams = []
# Store all new tokens generated by beam
all_beams_token_id = []
# Store the cumulative probability of all tokens
# generated by beam search
all_beams_logprob = []
# Iterate through all beam inference results
for i, result in enumerate(output):
current_beam = all_beams[i]
# check for error finish reason and abort beam search
if result.outputs[0].finish_reason == "error":
# yield error output and terminate beam search
yield RequestOutput(
request_id=request_id,
prompt=prompt_text,
outputs=[
CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
],
finished=True,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None,
)
return
if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
all_beams_token_id.extend(list(logprobs.keys()))
all_beams_logprob.extend(
[
current_beam.cum_logprob + obj.logprob
for obj in logprobs.values()
]
)
# Handle the token for the end of sentence (EOS)
all_beams_token_id = np.array(all_beams_token_id)
all_beams_logprob = np.array(all_beams_logprob)
if not ignore_eos:
# Get the index position of eos token in all generated results
eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
for idx in eos_idx:
current_beam = all_beams[idx // logprobs_num]
result = output[idx // logprobs_num]
assert result.outputs[0].logprobs is not None
logprobs_entry = result.outputs[0].logprobs[0]
completed.append(
BeamSearchSequence(
orig_prompt=prompt,
tokens=current_beam.tokens + [eos_token_id]
if include_stop_str_in_output
else current_beam.tokens,
logprobs=current_beam.logprobs + [logprobs_entry],
cum_logprob=float(all_beams_logprob[idx]),
finish_reason="stop",
stop_reason=eos_token_id,
)
)
# After processing, set the log probability of the eos condition
# to negative infinity.
all_beams_logprob[eos_idx] = -np.inf
# Processing non-EOS tokens
# Get indices of the top beam_width probabilities
topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[
:beam_width
]
for idx in topn_idx:
current_beam = all_beams[idx // logprobs_num]
result = output[idx // logprobs_num]
token_id = int(all_beams_token_id[idx])
assert result.outputs[0].logprobs is not None
logprobs_entry = result.outputs[0].logprobs[0]
new_beams.append(
BeamSearchSequence(
orig_prompt=prompt,
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs_entry],
lora_request=current_beam.lora_request,
cum_logprob=float(all_beams_logprob[idx]),
)
)
all_beams = new_beams
completed.extend(all_beams)
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
best_beams = sorted_completed[:beam_width]
for beam in best_beams:
if beam.tokens[-1] == eos_token_id and not ignore_eos:
# Skip the eos token in the text.
tokens = beam.tokens[tokenized_length:-1]
else:
tokens = beam.tokens[tokenized_length:]
beam.text = tokenizer.decode(tokens)
yield RequestOutput(
request_id=request_id,
prompt=prompt_text,
outputs=[
CompletionOutput(
text=beam.text, # type: ignore
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens[tokenized_length:],
index=i,
logprobs=beam.logprobs,
finish_reason=beam.finish_reason
if beam.finish_reason is not None
else "length",
stop_reason=beam.stop_reason,
)
for (i, beam) in enumerate(best_beams)
],
finished=True,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None,
)
@staticmethod
def create_error_response(
message: str | Exception,