mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Frontend] Consolidate beam search by BeamSearchMixin. (#42946)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
@@ -2003,7 +2003,7 @@ steps:
|
||||
- vllm/model_executor/layers
|
||||
- vllm/sampling_metadata.py
|
||||
- vllm/v1/sample/
|
||||
- vllm/beam_search.py
|
||||
- vllm/entrypoints/generate/beam_search/
|
||||
- tests/samplers
|
||||
- tests/conftest.py
|
||||
- vllm/_aiter_ops.py
|
||||
|
||||
@@ -11,6 +11,7 @@ steps:
|
||||
- vllm/sampling_metadata.py
|
||||
- tests/samplers
|
||||
- tests/conftest.py
|
||||
- vllm/entrypoints/generate/beam_search
|
||||
commands:
|
||||
# VLLM_USE_FLASHINFER_SAMPLER defaults to 1 now, so we need to pin both
|
||||
# values explicitly to still cover the PyTorch-native (Triton) path.
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import Any
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm import PromptType, RequestOutput, TextPrompt, TokensPrompt
|
||||
from vllm.inputs import EngineInput
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
|
||||
from .utils import (
|
||||
BeamSearchInstance,
|
||||
BeamSearchOutput,
|
||||
BeamSearchSequence,
|
||||
create_sort_beams_key_function,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BeamSearchOfflineMixin(ABC):
    """Offline inference for beam search.

    Mixin that implements ``beam_search`` on top of four hooks the host class
    must provide (``_preprocess_cmpl``, ``_lora_request_to_seq``,
    ``_params_to_seq`` and ``_render_and_run_requests``) and a ``renderer``
    attribute that supplies the tokenizer.
    """

    renderer: BaseRenderer

    def beam_search(
        self,
        prompts: list[TokensPrompt | TextPrompt],
        params: BeamSearchParams,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        use_tqdm: bool = False,
        concurrency_limit: int | None = None,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Every search step issues one single-token generation request per live
        beam, asking for ``2 * beam_width`` logprobs, and keeps the
        ``beam_width`` best continuations of each prompt.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any.
            use_tqdm: Whether to use tqdm to display the progress bar.
                Ignored (with a warning) when ``concurrency_limit`` is set.
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.

        Returns:
            One ``BeamSearchOutput`` per prompt, in prompt order, holding at
            most ``beam_width`` sequences ordered best-first by the sort key.
            An empty ``prompts`` list yields an empty list.

        Raises:
            ValueError: If ``concurrency_limit`` is given and is less than 1.
            NotImplementedError: If a preprocessed prompt is of type
                ``"embeds"``.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # A non-positive limit used to surface as an opaque
        # "range() arg 3 must not be zero" error (0) or to silently skip all
        # generation (negative); reject it up front instead.
        if concurrency_limit is not None and concurrency_limit < 1:
            raise ValueError(
                f"concurrency_limit must be a positive integer or None, "
                f"got {concurrency_limit}."
            )

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        engine_inputs = self._preprocess_cmpl(prompts)
        lora_requests = self._lora_request_to_seq(lora_request, len(engine_inputs))

        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar."
            )
            use_tqdm = False

        if concurrency_limit is None:
            # Clamp to 1 so an empty prompt list does not produce a zero step
            # for the batching ``range`` below.
            concurrency_limit = max(len(engine_inputs), 1)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        sampling_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
        )
        instances: list[BeamSearchInstance] = []

        for lora_req, prompt in zip(lora_requests, engine_inputs):
            if prompt["type"] == "embeds":
                raise NotImplementedError(
                    "Embedding prompt not supported for beam search"
                )

            instances.append(
                BeamSearchInstance(
                    prompt,
                    lora_request=lora_req,
                    logprobs=None,
                ),
            )

        # Prompts are searched in consecutive groups of ``concurrency_limit``;
        # each group runs all of its token steps before the next one starts.
        for prompt_start in range(0, len(instances), concurrency_limit):
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress."
                )
            for _ in token_iter:
                # Flatten the live beams of every instance into one request
                # list (linear-time, unlike ``sum(lists, [])``).
                all_beams: list[BeamSearchSequence] = list(
                    itertools.chain.from_iterable(
                        instance.beams for instance in instances_batch
                    )
                )

                if not all_beams:
                    # Nothing left to extend in this batch.
                    break

                # ``pos`` holds the running beam counts, so consecutive pairs
                # give each instance's slice of ``all_beams`` / ``output``.
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch
                    )
                )
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:])
                )

                # only runs for one step
                # we don't need to use tqdm here
                output = self._render_and_run_requests(
                    prompts=(beam.get_prompt() for beam in all_beams),
                    params=self._params_to_seq(sampling_params, len(all_beams)),
                    output_type=RequestOutput,
                    lora_requests=[beam.lora_request for beam in all_beams],
                    use_tqdm=False,
                )

                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            continue

                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                current_beam.orig_prompt,
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                lora_request=current_beam.lora_request,
                                cum_logprob=current_beam.cum_logprob
                                + logprob_obj.logprob,
                            )

                            if token_id == eos_token_id and not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)

                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
                    instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive when the step budget ran out compete with the
            # EOS-terminated ones for the final ranking.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)

            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    @abstractmethod
    def _preprocess_cmpl(
        self,
        prompts: Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[EngineInput]:
        """Convert user prompts into engine inputs (one per prompt).

        ``beam_search`` reads the ``"type"`` key of every returned item.
        """
        raise NotImplementedError

    @abstractmethod
    def _lora_request_to_seq(
        self,
        lora_request: LoRARequest | None | Sequence[LoRARequest | None],
        num_requests: int,
    ) -> Sequence[LoRARequest | None]:
        """Return the LoRA request(s) as a sequence.

        ``beam_search`` zips the result with the engine inputs, so it is
        expected to line up with them one-to-one.
        """
        raise NotImplementedError

    @abstractmethod
    def _params_to_seq(
        self,
        params: SamplingParams,
        num_requests: int,
    ) -> Sequence[SamplingParams]:
        """Return sampling params as a sequence for ``num_requests`` requests."""
        raise NotImplementedError

    @abstractmethod
    def _render_and_run_requests(
        self,
        prompts: Iterable[EngineInput],
        params: Sequence[SamplingParams],
        output_type: type[RequestOutput],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ):
        """Run the given requests and return their outputs.

        ``beam_search`` indexes the result positionally, so the i-th output
        must correspond to the i-th prompt.
        """
        raise NotImplementedError
|
||||
@@ -0,0 +1,225 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
from abc import ABC
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm import CompletionOutput, RequestOutput
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.inputs import EngineInput
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.async_utils import collect_from_async_generator
|
||||
|
||||
from .utils import BeamSearchSequence, create_sort_beams_key_function
|
||||
|
||||
|
||||
class BeamSearchOnlineMixin(ABC):
    """Online serving for beam search.

    Mixin that implements an async ``beam_search`` using the host class's
    ``renderer`` (for the tokenizer) and ``engine_client`` (to run
    single-token generation requests).
    """

    renderer: BaseRenderer
    engine_client: EngineClient

    async def beam_search(
        self,
        prompt: EngineInput,
        request_id: str,
        params: BeamSearchParams,
        lora_request: LoRARequest | None = None,
        trace_headers: Mapping[str, str] | None = None,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Run beam search for one prompt and yield a single final output.

        Each step sends one single-token request per live beam (asking for
        ``2 * beam_width`` logprobs), pools all returned candidates, moves
        EOS candidates to the completed list (unless ``ignore_eos``) and keeps
        the ``beam_width`` best remaining candidates as the next live beams.

        Args:
            prompt: The engine input to search from. For ``"enc_dec"`` inputs
                the decoder prompt supplies the prompt text and token IDs.
            request_id: ID reported on the yielded output; per-beam request
                IDs are derived from it.
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any.
            trace_headers: Forwarded to every engine request.

        Yields:
            Exactly one finished ``RequestOutput``: either the best beams, or
            an output with ``finish_reason="error"`` if any per-beam request
            reported an error.

        Raises:
            NotImplementedError: If ``prompt`` is of type ``"embeds"``.
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
        include_stop_str_in_output = params.include_stop_str_in_output

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        if prompt["type"] == "embeds":
            raise NotImplementedError("Embedding prompt not supported for beam search")

        # Extract prompt tokens and text based on model type
        decoder_prompt = (
            prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
        )
        prompt_text = decoder_prompt.get("prompt")
        prompt_token_ids = decoder_prompt["prompt_token_ids"]

        tokenized_length = len(prompt_token_ids)

        logprobs_num = 2 * beam_width
        sampling_params = SamplingParams(
            logprobs=logprobs_num,
            max_tokens=1,
            temperature=temperature,
        )
        all_beams = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=prompt_token_ids,
                cum_logprob=0,
                logprobs=[],
                lora_request=lora_request,
            )
        ]
        completed = []

        for _ in range(max_tokens):
            if not all_beams:
                # Every candidate finished (or was dropped); nothing to extend.
                break

            request_id_batch = f"{request_id}-{random_uuid()}"

            tasks = [
                asyncio.create_task(
                    collect_from_async_generator(
                        self.engine_client.generate(
                            beam.get_prompt(),
                            sampling_params,
                            f"{request_id_batch}-beam-{i}",
                            lora_request=beam.lora_request,
                            trace_headers=trace_headers,
                        )
                    )
                )
                for i, beam in enumerate(all_beams)
            ]

            try:
                gathered = await asyncio.gather(*tasks)
            except BaseException:
                # One request failed (or we were cancelled): do not leave the
                # sibling requests of this step running in the background.
                for task in tasks:
                    task.cancel()
                raise

            # Only the first item collected from each request is used.
            output = [x[0] for x in gathered]

            # Flat candidate pool for this step. ``candidate_owner[k]`` is the
            # index (into ``all_beams`` / ``output``) of the beam that produced
            # candidate ``k``. Tracking it explicitly keeps the mapping correct
            # even when beams return different numbers of logprob entries
            # (e.g. no logprobs at all, or more than ``logprobs_num``).
            candidate_token_ids: list[int] = []
            candidate_logprobs: list[float] = []
            candidate_owner: list[int] = []

            # Iterate through all beam inference results
            for i, result in enumerate(output):
                current_beam = all_beams[i]

                # check for error finish reason and abort beam search
                if result.outputs[0].finish_reason == "error":
                    # yield error output and terminate beam search
                    yield RequestOutput(
                        request_id=request_id,
                        prompt=prompt_text,
                        outputs=[
                            CompletionOutput(
                                index=0,
                                text="",
                                token_ids=[],
                                cumulative_logprob=None,
                                logprobs=None,
                                finish_reason="error",
                            )
                        ],
                        finished=True,
                        prompt_token_ids=prompt_token_ids,
                        prompt_logprobs=None,
                    )
                    return

                if result.outputs[0].logprobs is not None:
                    logprobs = result.outputs[0].logprobs[0]
                    for token_id, obj in logprobs.items():
                        candidate_token_ids.append(token_id)
                        candidate_logprobs.append(
                            current_beam.cum_logprob + obj.logprob
                        )
                        candidate_owner.append(i)

            if not candidate_owner:
                # No beam produced logprobs this step, so there is nothing to
                # rank; previously this crashed inside ``np.argpartition``.
                all_beams = []
                break

            all_beams_token_id = np.array(candidate_token_ids)
            all_beams_logprob = np.array(candidate_logprobs, dtype=float)
            # True for candidates already moved to ``completed``; they must
            # not be picked again as live beams.
            finished_mask = np.zeros(len(candidate_owner), dtype=bool)

            # Handle the token for the end of sentence (EOS)
            if not ignore_eos:
                # Get the index position of eos token in all generated results
                eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
                for idx in eos_idx:
                    owner = candidate_owner[int(idx)]
                    current_beam = all_beams[owner]
                    result = output[owner]
                    assert result.outputs[0].logprobs is not None
                    logprobs_entry = result.outputs[0].logprobs[0]
                    completed.append(
                        BeamSearchSequence(
                            orig_prompt=prompt,
                            tokens=current_beam.tokens + [eos_token_id]
                            if include_stop_str_in_output
                            else current_beam.tokens,
                            logprobs=current_beam.logprobs + [logprobs_entry],
                            cum_logprob=float(all_beams_logprob[idx]),
                            finish_reason="stop",
                            stop_reason=eos_token_id,
                        )
                    )
                finished_mask[eos_idx] = True

            # Processing non-EOS tokens
            # Get indices of the top beam_width probabilities. argpartition
            # requires kth < size, so it is only used when there are more
            # live candidates than beams to keep.
            live_idx = np.flatnonzero(~finished_mask)
            if live_idx.size > beam_width:
                keep = np.argpartition(
                    np.negative(all_beams_logprob[live_idx]), beam_width
                )[:beam_width]
                live_idx = live_idx[keep]

            new_beams = []
            for idx in live_idx:
                owner = candidate_owner[int(idx)]
                current_beam = all_beams[owner]
                result = output[owner]
                token_id = int(all_beams_token_id[idx])
                assert result.outputs[0].logprobs is not None
                logprobs_entry = result.outputs[0].logprobs[0]
                new_beams.append(
                    BeamSearchSequence(
                        orig_prompt=prompt,
                        tokens=current_beam.tokens + [token_id],
                        logprobs=current_beam.logprobs + [logprobs_entry],
                        lora_request=current_beam.lora_request,
                        cum_logprob=float(all_beams_logprob[idx]),
                    )
                )

            all_beams = new_beams

        # Beams still alive when the loop ends compete with the EOS-terminated
        # ones for the final ranking.
        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            if beam.tokens[-1] == eos_token_id and not ignore_eos:
                # Skip the eos token in the text.
                tokens = beam.tokens[tokenized_length:-1]
            else:
                tokens = beam.tokens[tokenized_length:]
            beam.text = tokenizer.decode(tokens)

        yield RequestOutput(
            request_id=request_id,
            prompt=prompt_text,
            outputs=[
                CompletionOutput(
                    text=beam.text,  # type: ignore
                    cumulative_logprob=beam.cum_logprob,
                    token_ids=beam.tokens[tokenized_length:],
                    index=i,
                    logprobs=beam.logprobs,
                    # Beams without an explicit finish reason ran out of steps.
                    finish_reason=beam.finish_reason
                    if beam.finish_reason is not None
                    else "length",
                    stop_reason=beam.stop_reason,
                )
                for (i, beam) in enumerate(best_beams)
            ],
            finished=True,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,
        )
|
||||
+3
-168
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -12,12 +11,6 @@ from pydantic import ValidationError
|
||||
from tqdm.auto import tqdm
|
||||
from typing_extensions import TypeVar, overload
|
||||
|
||||
from vllm.beam_search import (
|
||||
BeamSearchInstance,
|
||||
BeamSearchOutput,
|
||||
BeamSearchSequence,
|
||||
create_sort_beams_key_function,
|
||||
)
|
||||
from vllm.config import (
|
||||
AttentionConfig,
|
||||
CompilationConfig,
|
||||
@@ -45,13 +38,12 @@ from vllm.entrypoints.chat_utils import (
|
||||
ChatTemplateContentFormatOption,
|
||||
load_chat_template,
|
||||
)
|
||||
from vllm.entrypoints.generate.beam_search.offline import BeamSearchOfflineMixin
|
||||
from vllm.entrypoints.pooling.offline import PoolingOfflineMixin
|
||||
from vllm.entrypoints.utils import log_non_default_args
|
||||
from vllm.inputs import (
|
||||
EngineInput,
|
||||
PromptType,
|
||||
TextPrompt,
|
||||
TokensPrompt,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -65,7 +57,7 @@ from vllm.renderers.inputs.preprocess import (
|
||||
parse_model_prompt,
|
||||
prompt_to_seq,
|
||||
)
|
||||
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.counter import Counter
|
||||
@@ -89,7 +81,7 @@ _P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
|
||||
_R = TypeVar("_R", default=Any)
|
||||
|
||||
|
||||
class LLM(PoolingOfflineMixin):
|
||||
class LLM(BeamSearchOfflineMixin, PoolingOfflineMixin):
|
||||
"""An LLM for generating texts from given prompts and sampling parameters.
|
||||
|
||||
This class includes a tokenizer, a language model (possibly distributed
|
||||
@@ -675,163 +667,6 @@ class LLM(PoolingOfflineMixin):
|
||||
"""
|
||||
return self.llm_engine.apply_model(func)
|
||||
|
||||
def beam_search(
|
||||
self,
|
||||
prompts: list[TokensPrompt | TextPrompt],
|
||||
params: BeamSearchParams,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
use_tqdm: bool = False,
|
||||
concurrency_limit: int | None = None,
|
||||
) -> list[BeamSearchOutput]:
|
||||
"""
|
||||
Generate sequences using beam search.
|
||||
|
||||
Args:
|
||||
prompts: A list of prompts. Each prompt can be a string or a list
|
||||
of token IDs.
|
||||
params: The beam search parameters.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||
concurrency_limit: The maximum number of concurrent requests.
|
||||
If None, the number of concurrent requests is unlimited.
|
||||
"""
|
||||
# TODO: how does beam search work together with length penalty,
|
||||
# frequency, penalty, and stopping criteria, etc.?
|
||||
beam_width = params.beam_width
|
||||
max_tokens = params.max_tokens
|
||||
temperature = params.temperature
|
||||
ignore_eos = params.ignore_eos
|
||||
length_penalty = params.length_penalty
|
||||
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
|
||||
|
||||
engine_inputs = self._preprocess_cmpl(prompts)
|
||||
lora_requests = self._lora_request_to_seq(lora_request, len(engine_inputs))
|
||||
|
||||
if use_tqdm and concurrency_limit is not None:
|
||||
logger.warning(
|
||||
"Progress bar is not supported when using concurrency_limit. "
|
||||
"Disabling progress bar."
|
||||
)
|
||||
use_tqdm = False
|
||||
|
||||
if concurrency_limit is None:
|
||||
concurrency_limit = len(engine_inputs)
|
||||
|
||||
# generate 2 * beam_width candidates at each step
|
||||
# following the huggingface transformers implementation
|
||||
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
|
||||
sampling_params = SamplingParams(
|
||||
logprobs=2 * beam_width,
|
||||
max_tokens=1,
|
||||
temperature=temperature,
|
||||
skip_clone=True, # Internal beam search, safe to skip clone
|
||||
)
|
||||
instances: list[BeamSearchInstance] = []
|
||||
|
||||
for lora_req, prompt in zip(lora_requests, engine_inputs):
|
||||
if prompt["type"] == "embeds":
|
||||
raise NotImplementedError(
|
||||
"Embedding prompt not supported for beam search"
|
||||
)
|
||||
|
||||
instances.append(
|
||||
BeamSearchInstance(
|
||||
prompt,
|
||||
lora_request=lora_req,
|
||||
logprobs=None,
|
||||
),
|
||||
)
|
||||
|
||||
for prompt_start in range(0, len(instances), concurrency_limit):
|
||||
instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
|
||||
|
||||
token_iter = range(max_tokens)
|
||||
if use_tqdm:
|
||||
token_iter = tqdm(
|
||||
token_iter, desc="Beam search", unit="token", unit_scale=False
|
||||
)
|
||||
logger.warning(
|
||||
"The progress bar shows the upper bound on token steps and "
|
||||
"may finish early due to stopping conditions. It does not "
|
||||
"reflect instance-level progress."
|
||||
)
|
||||
for _ in token_iter:
|
||||
all_beams: list[BeamSearchSequence] = list(
|
||||
sum((instance.beams for instance in instances_batch), [])
|
||||
)
|
||||
pos = [0] + list(
|
||||
itertools.accumulate(
|
||||
len(instance.beams) for instance in instances_batch
|
||||
)
|
||||
)
|
||||
instance_start_and_end: list[tuple[int, int]] = list(
|
||||
zip(pos[:-1], pos[1:])
|
||||
)
|
||||
|
||||
if len(all_beams) == 0:
|
||||
break
|
||||
|
||||
# only runs for one step
|
||||
# we don't need to use tqdm here
|
||||
output = self._render_and_run_requests(
|
||||
prompts=(beam.get_prompt() for beam in all_beams),
|
||||
params=self._params_to_seq(sampling_params, len(all_beams)),
|
||||
output_type=RequestOutput,
|
||||
lora_requests=[beam.lora_request for beam in all_beams],
|
||||
use_tqdm=False,
|
||||
)
|
||||
|
||||
for (start, end), instance in zip(
|
||||
instance_start_and_end, instances_batch
|
||||
):
|
||||
instance_new_beams = []
|
||||
for i in range(start, end):
|
||||
current_beam = all_beams[i]
|
||||
result = output[i]
|
||||
|
||||
if result.outputs[0].logprobs is not None:
|
||||
# if `result.outputs[0].logprobs` is None, it means
|
||||
# the sequence is completed because of the
|
||||
# max-model-len or abortion. we don't need to add
|
||||
# it to the new beams.
|
||||
logprobs = result.outputs[0].logprobs[0]
|
||||
for token_id, logprob_obj in logprobs.items():
|
||||
new_beam = BeamSearchSequence(
|
||||
current_beam.orig_prompt,
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
logprobs=current_beam.logprobs + [logprobs],
|
||||
lora_request=current_beam.lora_request,
|
||||
cum_logprob=current_beam.cum_logprob
|
||||
+ logprob_obj.logprob,
|
||||
)
|
||||
|
||||
if token_id == eos_token_id and not ignore_eos:
|
||||
instance.completed.append(new_beam)
|
||||
else:
|
||||
instance_new_beams.append(new_beam)
|
||||
sorted_beams = sorted(
|
||||
instance_new_beams, key=sort_beams_key, reverse=True
|
||||
)
|
||||
instance.beams = sorted_beams[:beam_width]
|
||||
|
||||
outputs = []
|
||||
for instance in instances:
|
||||
instance.completed.extend(instance.beams)
|
||||
sorted_completed = sorted(
|
||||
instance.completed, key=sort_beams_key, reverse=True
|
||||
)
|
||||
best_beams = sorted_completed[:beam_width]
|
||||
|
||||
for beam in best_beams:
|
||||
beam.text = tokenizer.decode(beam.tokens)
|
||||
|
||||
outputs.append(BeamSearchOutput(sequences=best_beams))
|
||||
|
||||
return outputs
|
||||
|
||||
def _preprocess_cmpl(
|
||||
self,
|
||||
prompts: Sequence[PromptType],
|
||||
|
||||
@@ -1,25 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Awaitable, Mapping
|
||||
from collections.abc import Awaitable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
from openai.types.responses import ToolChoiceFunction
|
||||
from pydantic import ConfigDict, TypeAdapter, ValidationError
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.generate.beam_search.online import BeamSearchOnlineMixin
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
BatchChatCompletionRequest,
|
||||
@@ -56,7 +54,6 @@ from vllm.inputs import EngineInput, PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob, PromptLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.renderers import ChatParams, TokenizeParams
|
||||
from vllm.renderers.inputs.preprocess import (
|
||||
extract_prompt_components,
|
||||
@@ -71,7 +68,6 @@ from vllm.tracing import (
|
||||
log_tracing_disabled_warning,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.async_utils import collect_from_async_generator
|
||||
from vllm.utils.mistral import is_mistral_tool_parser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -133,7 +129,7 @@ class ServeContext(Generic[RequestT]):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class OpenAIServing:
|
||||
class OpenAIServing(BeamSearchOnlineMixin):
|
||||
request_id_prefix: ClassVar[str] = """
|
||||
A short string prepended to every request’s ID.
|
||||
"""
|
||||
@@ -174,205 +170,6 @@ class OpenAIServing:
|
||||
# Never fail server startup over the fingerprint.
|
||||
self.system_fingerprint = None
|
||||
|
||||
async def beam_search(
|
||||
self,
|
||||
prompt: EngineInput,
|
||||
request_id: str,
|
||||
params: BeamSearchParams,
|
||||
lora_request: LoRARequest | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
beam_width = params.beam_width
|
||||
max_tokens = params.max_tokens
|
||||
ignore_eos = params.ignore_eos
|
||||
temperature = params.temperature
|
||||
length_penalty = params.length_penalty
|
||||
include_stop_str_in_output = params.include_stop_str_in_output
|
||||
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
|
||||
|
||||
if prompt["type"] == "embeds":
|
||||
raise NotImplementedError("Embedding prompt not supported for beam search")
|
||||
|
||||
# Extract prompt tokens and text based on model type
|
||||
decoder_prompt = (
|
||||
prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
|
||||
)
|
||||
prompt_text = decoder_prompt.get("prompt")
|
||||
prompt_token_ids = decoder_prompt["prompt_token_ids"]
|
||||
|
||||
tokenized_length = len(prompt_token_ids)
|
||||
|
||||
logprobs_num = 2 * beam_width
|
||||
sampling_params = SamplingParams(
|
||||
logprobs=logprobs_num,
|
||||
max_tokens=1,
|
||||
temperature=temperature,
|
||||
)
|
||||
all_beams = [
|
||||
BeamSearchSequence(
|
||||
orig_prompt=prompt,
|
||||
tokens=prompt_token_ids,
|
||||
cum_logprob=0,
|
||||
logprobs=[],
|
||||
lora_request=lora_request,
|
||||
)
|
||||
]
|
||||
completed = []
|
||||
|
||||
for _ in range(max_tokens):
|
||||
tasks = []
|
||||
request_id_batch = f"{request_id}-{random_uuid()}"
|
||||
|
||||
for i, beam in enumerate(all_beams):
|
||||
prompt_item = beam.get_prompt()
|
||||
lora_request_item = beam.lora_request
|
||||
request_id_item = f"{request_id_batch}-beam-{i}"
|
||||
task = asyncio.create_task(
|
||||
collect_from_async_generator(
|
||||
self.engine_client.generate(
|
||||
prompt_item,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request_item,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
output = [x[0] for x in await asyncio.gather(*tasks)]
|
||||
|
||||
new_beams = []
|
||||
# Store all new tokens generated by beam
|
||||
all_beams_token_id = []
|
||||
# Store the cumulative probability of all tokens
|
||||
# generated by beam search
|
||||
all_beams_logprob = []
|
||||
# Iterate through all beam inference results
|
||||
for i, result in enumerate(output):
|
||||
current_beam = all_beams[i]
|
||||
|
||||
# check for error finish reason and abort beam search
|
||||
if result.outputs[0].finish_reason == "error":
|
||||
# yield error output and terminate beam search
|
||||
yield RequestOutput(
|
||||
request_id=request_id,
|
||||
prompt=prompt_text,
|
||||
outputs=[
|
||||
CompletionOutput(
|
||||
index=0,
|
||||
text="",
|
||||
token_ids=[],
|
||||
cumulative_logprob=None,
|
||||
logprobs=None,
|
||||
finish_reason="error",
|
||||
)
|
||||
],
|
||||
finished=True,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=None,
|
||||
)
|
||||
return
|
||||
|
||||
if result.outputs[0].logprobs is not None:
|
||||
logprobs = result.outputs[0].logprobs[0]
|
||||
all_beams_token_id.extend(list(logprobs.keys()))
|
||||
all_beams_logprob.extend(
|
||||
[
|
||||
current_beam.cum_logprob + obj.logprob
|
||||
for obj in logprobs.values()
|
||||
]
|
||||
)
|
||||
|
||||
# Handle the token for the end of sentence (EOS)
|
||||
all_beams_token_id = np.array(all_beams_token_id)
|
||||
all_beams_logprob = np.array(all_beams_logprob)
|
||||
|
||||
if not ignore_eos:
|
||||
# Get the index position of eos token in all generated results
|
||||
eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
|
||||
for idx in eos_idx:
|
||||
current_beam = all_beams[idx // logprobs_num]
|
||||
result = output[idx // logprobs_num]
|
||||
assert result.outputs[0].logprobs is not None
|
||||
logprobs_entry = result.outputs[0].logprobs[0]
|
||||
completed.append(
|
||||
BeamSearchSequence(
|
||||
orig_prompt=prompt,
|
||||
tokens=current_beam.tokens + [eos_token_id]
|
||||
if include_stop_str_in_output
|
||||
else current_beam.tokens,
|
||||
logprobs=current_beam.logprobs + [logprobs_entry],
|
||||
cum_logprob=float(all_beams_logprob[idx]),
|
||||
finish_reason="stop",
|
||||
stop_reason=eos_token_id,
|
||||
)
|
||||
)
|
||||
# After processing, set the log probability of the eos condition
|
||||
# to negative infinity.
|
||||
all_beams_logprob[eos_idx] = -np.inf
|
||||
|
||||
# Processing non-EOS tokens
|
||||
# Get indices of the top beam_width probabilities
|
||||
topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[
|
||||
:beam_width
|
||||
]
|
||||
|
||||
for idx in topn_idx:
|
||||
current_beam = all_beams[idx // logprobs_num]
|
||||
result = output[idx // logprobs_num]
|
||||
token_id = int(all_beams_token_id[idx])
|
||||
assert result.outputs[0].logprobs is not None
|
||||
logprobs_entry = result.outputs[0].logprobs[0]
|
||||
new_beams.append(
|
||||
BeamSearchSequence(
|
||||
orig_prompt=prompt,
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
logprobs=current_beam.logprobs + [logprobs_entry],
|
||||
lora_request=current_beam.lora_request,
|
||||
cum_logprob=float(all_beams_logprob[idx]),
|
||||
)
|
||||
)
|
||||
|
||||
all_beams = new_beams
|
||||
|
||||
completed.extend(all_beams)
|
||||
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
|
||||
best_beams = sorted_completed[:beam_width]
|
||||
|
||||
for beam in best_beams:
|
||||
if beam.tokens[-1] == eos_token_id and not ignore_eos:
|
||||
# Skip the eos token in the text.
|
||||
tokens = beam.tokens[tokenized_length:-1]
|
||||
else:
|
||||
tokens = beam.tokens[tokenized_length:]
|
||||
beam.text = tokenizer.decode(tokens)
|
||||
|
||||
yield RequestOutput(
|
||||
request_id=request_id,
|
||||
prompt=prompt_text,
|
||||
outputs=[
|
||||
CompletionOutput(
|
||||
text=beam.text, # type: ignore
|
||||
cumulative_logprob=beam.cum_logprob,
|
||||
token_ids=beam.tokens[tokenized_length:],
|
||||
index=i,
|
||||
logprobs=beam.logprobs,
|
||||
finish_reason=beam.finish_reason
|
||||
if beam.finish_reason is not None
|
||||
else "length",
|
||||
stop_reason=beam.stop_reason,
|
||||
)
|
||||
for (i, beam) in enumerate(best_beams)
|
||||
],
|
||||
finished=True,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_error_response(
|
||||
message: str | Exception,
|
||||
|
||||
Reference in New Issue
Block a user