[TRTLLM-10030][perf] beam search (remove GPU sync + fix batching + refactor) (#11276)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
mpikulski 2026-02-05 15:33:51 +01:00 committed by GitHub
parent e483c7263d
commit 719e82c429
5 changed files with 253 additions and 248 deletions
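The "remove GPU sync" part of this change leans on the AsyncWorkerMixin pattern visible in the hunks below: device-to-host copies are launched with non_blocking=True into pinned memory, a CUDA event is recorded, and a worker thread waits on that event instead of the main thread. The following is only a minimal sketch of that pattern with hypothetical tensor and helper names, not the actual _copy_to_host/_async_copy_to_host implementation:

```python
from concurrent import futures

import torch


def copy_to_host_async(src: torch.Tensor, worker: futures.ThreadPoolExecutor) -> futures.Future:
    """Launch a non-blocking D2H copy; let a worker thread wait for it.

    The main thread never calls event/stream synchronize, so it can keep
    enqueueing GPU work while the copy completes.
    """
    dest = torch.empty(src.shape, dtype=src.dtype, device="cpu", pin_memory=True)
    dest.copy_(src, non_blocking=True)  # async copy into pinned host memory

    copy_ready = torch.cuda.Event()
    copy_ready.record()  # marks the point in the current stream where the copy was enqueued

    def wait_and_return() -> torch.Tensor:
        copy_ready.synchronize()  # blocks only the worker thread
        return dest

    return worker.submit(wait_and_return)


# Usage sketch (names are illustrative):
#   worker = futures.ThreadPoolExecutor(max_workers=1)
#   fut = copy_to_host_async(finish_reasons_cuda, worker)
#   ... enqueue more GPU work ...
#   finish_reasons_host = fut.result()
```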

View File

@ -21,7 +21,7 @@ from concurrent import futures
from dataclasses import dataclass
from functools import cached_property
from itertools import repeat
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, cast
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeAlias, TypeVar, cast
import numpy as np
import torch
@ -264,42 +264,11 @@ class RequestGroupValue:
need_processed_logprobs: torch.Tensor
need_raw_logprobs: torch.Tensor
def __iter__(self):
return iter(
(
self.indices,
self.strategies,
self.speculation_needs_probs_indices,
self.need_processed_logprobs,
self.need_raw_logprobs,
)
)
def __len__(self):
return 5
@dataclass(kw_only=True, frozen=True)
class RequestGroupValueWithMetadata(RequestGroupValue):
metadata: StrategyMetadata | None
@override
def __iter__(self):
return iter(
(
self.indices,
self.strategies,
self.speculation_needs_probs_indices,
self.need_processed_logprobs,
self.need_raw_logprobs,
self.metadata,
)
)
@override
def __len__(self):
return 6
class EarlyStopWithMMResult(Sampler):
"""
@ -307,7 +276,7 @@ class EarlyStopWithMMResult(Sampler):
"""
@override
def sample_async(
def sample_async( # type: ignore
self,
scheduled_requests: ScheduledRequests,
model_outputs,
@ -322,7 +291,7 @@ class EarlyStopWithMMResult(Sampler):
return SampleStateWithMMResult(scheduled_requests=scheduled_requests, data=data)
@override
def update_requests(
def update_requests( # type: ignore
self,
state: SampleStateWithMMResult,
resource_manager: Optional[ResourceManager] = None,
@ -341,9 +310,9 @@ class EarlyStopWithMMResult(Sampler):
request.state = LlmRequestState.GENERATION_COMPLETE
# NOTE: This is a hack: set the finish reason manually for beam 0
request.set_finished_reason(FinishReason.LENGTH, 0)
if len(mm_embedding) != sum(request.multimodal_lengths):
if len(mm_embedding) != sum(request.multimodal_lengths): # type: ignore
raise ValueError(
f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}"
f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}" # type: ignore
)
request.py_result.append_mm_embeddings(mm_embedding)
@ -385,7 +354,7 @@ def _get_max_beam_width(request: LlmRequest) -> int:
sampling_config = request.sampling_config
max_beam_width = sampling_config.beam_width
if sampling_config.beam_width_array is not None:
max_beam_width = max(max_beam_width, sampling_config.beam_width_array.max())
max_beam_width = max(max_beam_width, sampling_config.beam_width_array.max()) # type: ignore
return max_beam_width
@ -416,23 +385,24 @@ def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
# We try to cache the resolved strategy on the request object, as it's not cheap enough to
# resolve it on every iteration.
if hasattr(request, "py_sampling_strategy"):
return request.py_sampling_strategy
return request.py_sampling_strategy # type: ignore
params = _request_get_sampling_params(request)
sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
if _request_sampling_params_cachable(params):
request.py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
request.py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size) # type: ignore
return sampling_strategy
def _group_requests_by_strategy_key(
requests: Iterable[LlmRequest],
*,
strategy_to_key: Callable[[Strategy, bool], GenericStrategyKeyType],
strategy_to_key: Callable[[Strategy], GenericStrategyKeyType],
pin_memory: bool = False,
vocab_size: int,
) -> dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValue]:
# NB: Client code relies on request indices in returned torch.Tensor being sorted.
# NB: Client code relies on request indices in returned torch.Tensor being ordered
# by ascending "requests" index.
RequestGroupValueBuilder = namedtuple(
"RequestGroupValueBuilder",
[
@ -444,8 +414,8 @@ def _group_requests_by_strategy_key(
],
)
group_dict: dict[RequestGroupKey, RequestGroupValueBuilder] = defaultdict(
lambda: RequestGroupValueBuilder([], [], [], [], [])
group_dict: dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValueBuilder] = (
defaultdict(lambda: RequestGroupValueBuilder([], [], [], [], []))
)
for req_index, req in enumerate(requests):
@ -460,7 +430,7 @@ def _group_requests_by_strategy_key(
)
need_raw_logprobs = req.py_logprobs_mode == LogprobMode.RAW and req.return_log_probs
needs_probs = speculation_needs_probs or need_processed_logprobs
strategy_key = strategy_to_key(strategy, needs_probs)
strategy_key = strategy_to_key(strategy)
group_dict_entry = group_dict[
RequestGroupKey(strategy_key=strategy_key, needs_probs=needs_probs)
]
@ -768,7 +738,7 @@ DEFAULT_BEAM_IDX = 0
# Step index to use when no speculative decoding is used but a step index is required
DEFAULT_STEP_IDX = 0
FinishReasonsList = list[list[int]]
FinishReasonsList: TypeAlias = list[list[list[int]]]
@dataclass(kw_only=True)
@ -797,7 +767,7 @@ class SamplingRequestsMetadata:
@dataclass(kw_only=True)
class SampleStateTensorsHostTorch(SampleStateTensors):
finish_reasons: torch.Tensor
first_finish_reasons: torch.Tensor
first_finish_reasons: torch.Tensor | None
logprobs_state: LogProbsState | None = None
def finish_reasons_list(self) -> FinishReasonsList:
@ -808,7 +778,7 @@ class SampleStateTensorsHostTorch(SampleStateTensors):
@dataclass(kw_only=True)
class SampleStateTorch(SampleState):
host: SampleStateTensorsHostTorch
host: SampleStateTensorsHostTorch # type: ignore
beam_histories: list[BeamHistory | None] | None = None
@ -827,7 +797,7 @@ class AsyncWorkerMixin:
def _async_worker_init(self, enable_async_worker: bool):
self._enable_async_worker = enable_async_worker
self._async_worker = None
self._async_worker_futures: list[futures.Future[any]] = []
self._async_worker_futures: list[futures.Future[Any]] = []
def async_worker_enabled(self):
return getattr(self, "_enable_async_worker", False)
@ -853,6 +823,7 @@ class AsyncWorkerMixin:
def async_worker_stop(self):
assert self.async_worker_enabled()
if self._async_worker_active():
assert self._async_worker is not None
self._async_worker.shutdown(wait=True)
self._async_worker = None
@ -888,6 +859,7 @@ class AsyncWorkerMixin:
copy_ready.record()
# Submit the copy to the async worker thread
assert self._async_worker is not None
result = self._async_worker.submit(
self._async_copy_to_host, copy_ready, dest, src_snapshot
)
@ -980,8 +952,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
# Tensors necessary for all sampling methods
new_tokens = int_tensor(self.NEW_TOKENS_SHAPE)
finish_reasons = int_tensor(self.NEW_TOKENS_SHAPE)
max_lengths_tensor = int_tensor(self.max_num_sequences)
end_ids = int_tensor(self.max_num_sequences)
max_lengths_tensor = int_tensor((self.max_num_sequences,))
end_ids = int_tensor((self.max_num_sequences,))
# Only used for logprobs processing or beam search
sampled_log_probs = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.float32)
@ -1236,6 +1208,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
!= FinishReason.NOT_FINISHED.value
).sum() == request.sampling_config.beam_width:
request.state = LlmRequestState.GENERATION_COMPLETE
assert request.py_seq_slot is not None
for beam_idx in range(request.sampling_config.beam_width):
request.set_finished_reason(
FinishReason(
@ -1255,6 +1228,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
When using streaming, a request's tokens may be altered, leading to wrong results when called multiple times.
Store the original tokens in a separate buffer to use them as a consistent basis
when updating the tokens in a request."""
assert self.store.original_tokens is not None
assert new_tokens.device == self.store.original_tokens.device, (
"new_tokens and original_tokens must be on the same device"
)
@ -1311,6 +1285,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
assert beam_idx == DEFAULT_BEAM_IDX, (
"beam search does not need to explicitly handle sampled log probs"
)
assert sampled_log_probs_indices_list is not None
assert sampled_log_probs_vals_list is not None
assert sampled_log_probs_rank_list is not None
if sampled_log_probs_indices_list[step_idx] not in logprobs:
logprobs[sampled_log_probs_indices_list[step_idx]] = Logprob(
logprob=sampled_log_probs_vals_list[step_idx],
@ -1388,6 +1365,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
beam_width = request.sampling_config.beam_width
assert request.py_num_logprobs is not None, "request.py_num_logprobs must be provided"
assert logprobs_state_list is not None, "logprobs_state_list must be provided"
assert request.py_seq_slot is not None
token_log_probs = self._store_logprobs_list_to_request(
logprobs_state_list, request.py_seq_slot, beam_width, count, request.py_num_logprobs
)
@ -1396,6 +1374,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
def finish_if_reason(
self, request: LlmRequest, finish_reasons: FinishReasonsList, *, step: int, beam_idx: int
) -> bool:
assert request.py_seq_slot is not None
reason = FinishReason(finish_reasons[request.py_seq_slot][step][beam_idx])
valid_reasons = {FinishReason.END_ID, FinishReason.LENGTH, FinishReason.STOP_WORDS}
if reason in valid_reasons:
@ -1548,6 +1527,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
)
)
@override
def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
"""Setup the sampler step for the requests
@ -1563,6 +1543,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
end_ids: list[int] = []
for request in scheduled_requests.context_requests:
if self._is_new_request(request):
assert request.py_seq_slot is not None
seq_slots.append(request.py_seq_slot)
max_lens.append(
min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
@ -1593,6 +1574,13 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
if self._is_new_request(request):
if request.py_return_log_probs and request.py_num_logprobs > 1:
raise ValueError("Beam search does not support multiple logprobs")
assert self.store.cache_indirection is not None
assert self.store.cum_log_probs is not None
assert self.store.sampled_log_probs is not None
assert self.store.sampled_log_prob_ranks is not None
assert self.store.predecessor_beams is not None
assert self.store.first_finish_reasons is not None
assert self.store.original_tokens is not None
self.store.cache_indirection[request.py_seq_slot, :, request.py_prompt_len].fill_(0)
self.store.cum_log_probs[request.py_seq_slot].fill_(0)
self.store.sampled_log_probs[request.py_seq_slot].fill_(0)
@ -1765,9 +1753,11 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
)
if hasattr(request.py_result._log_probs, "log_probs"):
logprobs_list = request.py_result.log_probs
assert logprobs_list is not None
for beam_idx, beam_logprobs in enumerate(logprobs_list):
for token_idx, token_logprobs in enumerate(beam_logprobs):
for key, value in token_logprobs.items():
assert value.rank is not None
logprobs_tensor[beam_idx, token_idx, value.rank - 1] = value.logprob
logprobs_indices_tensor[beam_idx, token_idx, value.rank - 1] = key
return logprobs_tensor, logprobs_indices_tensor
@ -1795,6 +1785,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
if num_generated_tokens == 0 or request.state == LlmRequestState.GENERATION_COMPLETE:
# early return if no tokens have been generated yet or the request is already finished
return None
assert self.store.cache_indirection is not None
assert self.store.original_tokens is not None
assert self.store.sampled_log_probs is not None
cache_indirection = self.store.cache_indirection[
request.py_seq_slot, :num_beams, prompt_length:num_tokens
]
@ -1802,7 +1795,13 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
request.py_seq_slot, :num_beams, prompt_length:num_tokens
]
new_path = torch.zeros_like(current_path)
# initialize each beam with its own index
# Gather the correct tokens and logprobs for each beam
torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path)
if request.py_return_log_probs:
assert self.store.cum_log_probs is not None
current_logprobs, current_logprobs_indices = self._get_logprobs_from_request(request)
# concatenate the newly generated logprobs and newly
# generated tokens to the current logprobs and logprobs indices
@ -1823,11 +1822,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
# Initialize the buffers to store the results
new_logprobs = torch.zeros_like(current_logprobs)
new_logprobs_indices = torch.zeros_like(current_logprobs_indices)
# initialize each beam with its own index
# Gather the correct tokens and logprobs for each beam
torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path)
if request.py_return_log_probs:
cache_indirection_for_logprobs = cache_indirection.unsqueeze(-1).expand(
-1, -1, current_logprobs.shape[2]
)
@ -1877,6 +1872,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
{beam_history.tokens.shape[0]} != {beam_width}"
)
if request.py_return_log_probs:
assert beam_history.logprobs is not None
assert beam_history.logprobs_indices is not None
assert beam_history.cum_logprobs is not None
assert beam_history.logprobs.shape[0] == beam_width, (
f"Beam_history.logprobs.shape[0] should equal beam width: \
{beam_history.logprobs.shape[0]} != {beam_width}"
@ -1895,6 +1893,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
for beam_idx in range(beam_width):
gen_token_list.append(beam_history.tokens[beam_idx, : valid_tokens[beam_idx]].tolist())
if request.py_return_log_probs:
assert beam_history.logprobs_indices is not None
assert beam_history.logprobs is not None
gen_log_probs_list.append(
self._convert_logprobs_tensor_to_list(
beam_history.logprobs_indices[
@ -1910,6 +1910,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
if request.py_return_log_probs:
# cum_log_probs will not change when padding with end tokens.
# Therefore, we do not need to correct it
assert beam_history.cum_logprobs is not None
request.py_result.set_log_probs(
gen_log_probs_list, cum_log_probs=beam_history.cum_logprobs.tolist()
)
@ -1920,7 +1921,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
grouped_requests: dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValue],
seq_slots: torch.Tensor,
seq_lens: torch.Tensor | None,
get_metadata_type_for_group_fn: Callable[[GenericStrategyKeyType], Type[StrategyMetadata]],
get_metadata_type_for_group_fn: Callable[
[GenericStrategyKeyType], Type[StrategyMetadata] | None
],
) -> dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValueWithMetadata]:
grouped_requests_with_metadata: dict[
RequestGroupKey[GenericStrategyKeyType], RequestGroupValueWithMetadata
@ -1929,6 +1932,12 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
metadata_type = get_metadata_type_for_group_fn(key.strategy_key)
if metadata_type is BeamSearchMetadata:
assert seq_lens is not None, "seq_lens is required for beam search"
assert self.store.cache_indirection is not None
assert self.store.cache_indirection_buffer is not None
assert self.store.cum_log_probs is not None
assert self.store.sampled_log_probs is not None
assert self.store.first_finish_reasons is not None
assert self.store.predecessor_beams is not None
metadata = BeamSearchMetadata(
cache_indirection=self.store.cache_indirection,
cache_indirection_buffer=self.store.cache_indirection_buffer,
@ -1982,7 +1991,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
_, cumsum = request.py_stop_words_list
if -1 in cumsum:
cumsum = cumsum[: cumsum.index(-1)]
longest_stop_word_len = np.max(np.diff(cumsum, prepend=0), initial=0)
longest_stop_word_len = np.max(np.diff(cumsum, prepend=0), initial=0).item()
return longest_stop_word_len > 1
return False
@ -2010,9 +2019,10 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
@torch.inference_mode()
def update_requests(
self,
state: SampleStateTorch,
state: Sampler.SampleState,
resource_manager: Optional[ResourceManager] = None,
) -> None:
state = cast(SampleStateTorch, state)
assert isinstance(state, SampleStateTorch)
if state.sampler_event:
state.sampler_event.synchronize()
@ -2036,7 +2046,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
if beam_histories is not None and beam_histories[req_idx] is not None:
self._finalize_beam(
req,
beam_histories[req_idx],
cast(BeamHistory, beam_histories[req_idx]),
)
else:
for beam_idx in range(req.sampling_config.beam_width):
@ -2055,7 +2065,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
if beam_histories is not None and beam_histories[req_idx] is not None:
self._finalize_beam(
req,
beam_histories[req_idx],
cast(BeamHistory, beam_histories[req_idx]),
)
else:
for beam_idx in range(req.sampling_config.beam_width):
@ -2100,6 +2110,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
self.max_tokens,
self.max_topk_logprobs,
)
assert self.store.topk_vals is not None
assert self.store.topk_indices is not None
self.store.topk_vals.resize_(self.TOPK_LOGPROBS_SHAPE)
self.store.topk_indices.resize_(self.TOPK_LOGPROBS_SHAPE)
@ -2155,8 +2167,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
)
finish_reasons_host = self._copy_to_host(finish_reasons)
beam_histories = [None] * len(requests)
beam_histories: list[BeamHistory | None] = [None] * len(requests)
if self._use_beam_search:
assert first_finish_reasons is not None
assert seq_lens_host is not None, "seq_lens is required for beam search"
assert self.store.first_finish_reasons is not None, (
"first_finish_reasons must be provided"
@ -2167,6 +2180,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
self._maybe_create_beam_histories(
requests, finish_reasons=first_finish_reasons, beam_histories=beam_histories
)
else:
first_finish_reasons_host = None
# copy logprobs to host
logprobs_state: LogProbsState | None = None
@ -2204,9 +2219,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
host=SampleStateTensorsHostTorch(
new_tokens=new_tokens_host,
finish_reasons=finish_reasons_host,
first_finish_reasons=None
if not self._use_beam_search
else first_finish_reasons_host,
first_finish_reasons=first_finish_reasons_host,
logprobs_state=logprobs_state,
),
sampler_event=sampler_event,
@ -2393,14 +2406,19 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
)
batch_req_idx_offset_start = 0
batch_next_tokens_offset_start = 0
for (strategy_key, needs_probs), (
group_req_indices,
group_strategies,
group_speculation_needs_probs_indices,
group_need_processed_logprobs,
group_need_raw_logprobs,
group_metadata,
) in grouped_requests_with_metadata.items():
for (
strategy_key,
needs_probs,
), group_val_with_metadata in grouped_requests_with_metadata.items():
group_req_indices = group_val_with_metadata.indices
group_strategies = group_val_with_metadata.strategies
group_speculation_needs_probs_indices = (
group_val_with_metadata.speculation_needs_probs_indices
)
group_need_processed_logprobs = group_val_with_metadata.need_processed_logprobs
group_need_raw_logprobs = group_val_with_metadata.need_raw_logprobs
group_metadata = group_val_with_metadata.metadata
# group_req_indices: Indices of 'requests' entries having the same sampling
# strategy, ordered ascending.
batch_req_idx_offset_end = batch_req_idx_offset_start + group_req_indices.size(0)
@ -2420,22 +2438,19 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
# indices for accessing logits within the current group
group_logit_indexer = _PackedStepIndexer(
num_steps=req_num_generated_tokens[group_req_indices],
max_steps=req_num_generated_tokens.max() * self.max_beam_width,
max_steps=cast(
int, req_num_generated_tokens.max().item() * self.max_beam_width
),
)
logit_indices_for_processed_logprobs_cuda = (
None
if not any_request_needs_processed_logprobs
else group_logit_indexer[need_processed_logprobs_indices].to(
logits_cuda.device, non_blocking=True
)
)
logit_indices_for_raw_logprobs_cuda = (
None
if not any_request_needs_raw_logprobs
else group_logit_indexer[need_raw_logprobs_indices].to(
logits_cuda.device, non_blocking=True
)
)
logit_indices_for_processed_logprobs_cuda = group_logit_indexer[
need_processed_logprobs_indices
].to(logits_cuda.device, non_blocking=True)
logit_indices_for_raw_logprobs_cuda = group_logit_indexer[
need_raw_logprobs_indices
].to(logits_cuda.device, non_blocking=True)
else:
logit_indices_for_processed_logprobs_cuda = None
logit_indices_for_raw_logprobs_cuda = None
group_logits_cuda_indices = logits_cuda_indexer[group_req_indices]
# NB: Assuming that group_req_indices are sorted
@ -2518,13 +2533,13 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
processed_logits_cuda = torch.where(
current_softmax_cuda > 0, current_logits_cuda, float("-inf")
)
if group_temperature_cuda is not None:
if isinstance(group_temperature_cuda, torch.Tensor):
processed_logits_cuda /= group_temperature_cuda[
logit_indices_for_processed_logprobs_cuda
]
else:
processed_logits_cuda /= group_temperature_cuda
temperature_for_processed_logprobs = group_temperature_cuda
if isinstance(temperature_for_processed_logprobs, torch.Tensor):
temperature_for_processed_logprobs = cast(torch.Tensor, group_temperature_cuda)[
logit_indices_for_processed_logprobs_cuda
].unsqueeze(-1)
if temperature_for_processed_logprobs is not None:
processed_logits_cuda /= temperature_for_processed_logprobs
logit_indices_for_processed_logprobs_cuda += batch_next_tokens_offset_start
batch_logits_for_logprobs_cuda[logit_indices_for_processed_logprobs_cuda] = (
processed_logits_cuda
@ -2768,6 +2783,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
def _longest_stop_word_len(requests: Iterable[LlmRequest]) -> int:
max_stop_word_len = 0
for req in requests:
assert req.py_stop_words_list is not None
_, cumsum = req.py_stop_words_list
if -1 in cumsum:
cumsum = cumsum[: cumsum.index(-1)]
@ -2971,13 +2987,14 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
padded_tokens = self._padded_old_tokens(requests, tokens, predecessor_beams)
for request_idx, request in enumerate(requests):
assert request.py_stop_words_list is not None
swl, ends = request.py_stop_words_list
if -1 in ends:
ends = ends[: ends.index(-1)]
lens = np.diff(ends, prepend=0)
max_len = np.max(lens)
words = torch.zeros(len(lens), max_len, dtype=torch.int32, pin_memory=True)
words = torch.zeros(len(lens), max_len.item(), dtype=torch.int32, pin_memory=True)
for step, (start, length) in enumerate(zip([0] + ends, lens)):
words[step, :length] = torch.tensor(swl[start : start + length], dtype=torch.int32)
words_device = words.to("cuda", non_blocking=True)
@ -3119,6 +3136,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
).scatter_(dim=0, index=padded_indices_cuda, src=sampled_rank_cuda)
if self._use_beam_search:
assert self.store.sampled_log_prob_indices is not None
local_group_req_indices_with_beam_search = torch.tensor(
[
req_id
@ -3186,8 +3204,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
logits_cuda = self._apply_min_length_penalty(
logits_cuda,
requests,
sampling_requests_metadata.req_num_steps,
sampling_requests_metadata.req_num_beams,
sampling_requests_metadata.req_num_steps.tolist(),
sampling_requests_metadata.req_num_beams.tolist(),
)
# Fast path for greedy sampling
@ -3303,7 +3321,7 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
class SampleStateTRTLLM(SampleState):
finalize_events: dict[str, CudaEvent] | None = None
"""`Optional` to accommodate `_forward_step_inter_pp` which creates a `SampleState` without `finalize_events`"""
host: Optional[SampleStateTensorsHostTRTLLM] = None
host: Optional[SampleStateTensorsHostTRTLLM] = None # type: ignore
class TRTLLMSampler(Sampler, AsyncWorkerMixin):
@ -3337,7 +3355,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
self.logits_datatype = DataType.FLOAT
self.decoding_mode = decoding_mode
self.decoding_config = decoding_config if decoding_config else DecodingConfig(decoding_mode)
max_attn_window = kv_cache_config.max_attention_window
max_attn_window = kv_cache_config.max_attention_window # type: ignore
self.max_seq_len = max_seq_len
self.max_attention_window = (
max(max_attn_window) if max_attn_window is not None else max_seq_len
@ -3416,8 +3434,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
def _instantiate_algorithms(self):
self.algs = Algorithms()
self.algs.decoder = GptDecoderBatched(stream=self.store["torch_stream"])
self.algs.decoder.setup(
self.algs.decoder = GptDecoderBatched(stream=self.store["torch_stream"]) # type: ignore
self.algs.decoder.setup( # type: ignore
mode=self.decoding_mode,
max_num_sequences=self.max_num_sequences,
max_beam_width=self.max_beam_width,
@ -3425,18 +3443,18 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
model_config=self.model_config,
world_config=self.world_config,
)
self.algs.create_new_decoder_requests = CreateNewDecoderRequests(
self.algs.create_new_decoder_requests = CreateNewDecoderRequests( # type: ignore
speculative_decoding_fast_logits=False,
is_leader_in_orch_mode=False,
is_normalize_log_probs=False,
)
self.algs.make_decoding_batch_input_output = MakeDecodingBatchInputOutput()
self.algs.make_decoding_batch_input_output = MakeDecodingBatchInputOutput() # type: ignore
@torch.inference_mode()
@nvtx_range("setup_sampler_step")
def setup_sampler_step(self, requests):
def setup_sampler_step(self, requests): # type: ignore
batch_slots, sampling_configs, lookahead_prompt, lookahead_algo_configs = (
self.algs.create_new_decoder_requests(
self.algs.create_new_decoder_requests( # type: ignore
self.model_config,
self.world_config,
self.decoding_config,
@ -3445,7 +3463,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
self.store["decoder_input_buffers"][self.micro_batch_idx],
self.store["decoder_state"],
self.store["cuda_stream"],
self.algs.decoder.decoder_stream,
self.algs.decoder.decoder_stream, # type: ignore
self.max_seq_len,
self.beam_width(requests.context_requests),
)
@ -3454,7 +3472,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
local_batch_size = len(batch_slots)
if local_batch_size > 0:
sampling_config = make_sampling_config(sampling_configs)
self.algs.decoder.underlying_decoder().setup(
self.algs.decoder.underlying_decoder().setup( # type: ignore
sampling_config,
local_batch_size,
batch_slots,
@ -3468,9 +3486,9 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
batch_size = len(adp)
if batch_size == 0:
return
config = make_sampling_config([r.sampling_config for r in adp])
config = make_sampling_config([r.sampling_config for r in adp]) # type: ignore
slots = torch.tensor([r.py_seq_slot for r in adp], dtype=torch.int32)
self.algs.decoder.underlying_decoder().setup(config, batch_size, slots)
self.algs.decoder.underlying_decoder().setup(config, batch_size, slots) # type: ignore
def get_cache_indirection(self) -> torch.Tensor | None:
return self.store["decoder_state"].cache_indirection_output
@ -3519,7 +3537,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
self.store["buffer_manager"],
)
self.algs.decoder.forward_async(
self.algs.decoder.forward_async( # type: ignore
self.store["decoder_state"],
self.store["decoder_input_buffers"][self.micro_batch_idx],
)
@ -3574,7 +3592,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
@torch.inference_mode()
@override
def update_requests(
def update_requests( # type: ignore
self,
state: SampleStateTRTLLM,
resource_manager: Optional[ResourceManager] = None,
@ -3598,8 +3616,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
@nvtx_range("update_requests_single_beam_single_step")
def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
"""Specialization of update_requests for single beam and single step"""
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
finish_reasons = state.host.finish_reasons.flatten().tolist()
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist() # type: ignore
finish_reasons = state.host.finish_reasons.flatten().tolist() # type: ignore
reqs = [
r for r in state.scheduled_requests.context_requests if not r.is_context_init_state
@ -3618,7 +3636,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
seq_slots = []
seq_slots_need_log_probs = []
for request in reqs:
if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0):
if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0): # type: ignore
continue
reqs_with_new_tokens.append(request)
@ -3628,19 +3646,19 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
seq_slots_need_log_probs.append(request.py_seq_slot)
# [maxTokensPerStep, batchSize, maxBeamWidth]
new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist()
new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist() # type: ignore
add_new_tokens_to_requests(reqs_with_new_tokens, new_tokens, 0)
# Log probs
if state.host.log_probs is not None:
if state.host.log_probs is not None: # type: ignore
# [batchSize, maxBeamWidth]
seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1
seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1 # type: ignore
# [batchSize, maxBeamWidth, maxSequenceLength]
log_probs_host = state.host.log_probs[
log_probs_host = state.host.log_probs[ # type: ignore
seq_slots_need_log_probs, 0, seq_last_idx
].tolist()
# [batchSize, maxBeamWidth]
cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist()
cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist() # type: ignore
log_probs_idx = 0
for request, new_token in zip(reqs_with_new_tokens, new_tokens):
@ -3659,7 +3677,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
for request in reqs:
request.py_decoding_iter += 1
finished_state = FinishedState(finish_reasons[request.py_seq_slot])
finished_state = FinishedState(finish_reasons[request.py_seq_slot]) # type: ignore
if finished_state.is_finished:
request.state = LlmRequestState.GENERATION_COMPLETE
finish_reason = finished_state.to_finish_reason()
@ -3672,14 +3690,14 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
state: SampleStateTRTLLM,
beam_width: int,
):
new_tokens_host = state.host.new_tokens.tolist()
finished_sum_host = state.host.finished_sum.tolist()
finish_reasons = state.host.finish_reasons.flatten().tolist()
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
new_tokens_host = state.host.new_tokens.tolist() # type: ignore
finished_sum_host = state.host.finished_sum.tolist() # type: ignore
finish_reasons = state.host.finish_reasons.flatten().tolist() # type: ignore
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist() # type: ignore
cum_log_probs_host = (
state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None
state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None # type: ignore
)
log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None
log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None # type: ignore
finalize_events = state.finalize_events
reqs = [
@ -3700,7 +3718,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
cum_log_probs = []
for beam_idx in range(beam_width):
seq_len = sequence_lengths_host_data[seq_slot * beam_width + beam_idx]
seq_len = sequence_lengths_host_data[seq_slot * beam_width + beam_idx] # type: ignore
num_new_tokens[beam_idx] = min(
num_generated_tokens, seq_len - request.get_num_tokens(beam_idx)
)
@ -3709,7 +3727,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
new_token = add_token(request, new_tokens_host, beam_idx=beam_idx, step=step)
if request.py_return_log_probs:
assert state.host.log_probs is not None
assert state.host.log_probs is not None # type: ignore
# NOTE: Log probs with drafting has not been tested yet.
begin_log_probs_offset = (
request.prompt_len if request.sampling_config.beam_width == 1 else 0
@ -3720,7 +3738,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
log_probs[beam_idx].append(
{
new_token: Logprob(
logprob=log_probs_host[seq_slot][beam_idx][
logprob=log_probs_host[seq_slot][beam_idx][ # type: ignore
begin_log_probs_offset + current_token
],
rank=1,
@ -3729,9 +3747,9 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
)
if request.py_return_log_probs:
cum_log_probs.append(cum_log_probs_host[seq_slot][beam_idx])
cum_log_probs.append(cum_log_probs_host[seq_slot][beam_idx]) # type: ignore
finished_state = FinishedState(finish_reasons[seq_slot * beam_width + beam_idx])
finished_state = FinishedState(finish_reasons[seq_slot * beam_width + beam_idx]) # type: ignore
if finished_state.is_finished:
finish_reason = finished_state.to_finish_reason()
request.set_finished_reason(finish_reason, beam_idx)
@ -3748,7 +3766,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
if request.state != LlmRequestState.GENERATION_COMPLETE:
request.py_decoding_iter += 1
if finished_sum_host[seq_slot] == beam_width:
if finished_sum_host[seq_slot] == beam_width: # type: ignore
request.state = LlmRequestState.GENERATION_COMPLETE
for request in reqs:
if finalize_events is not None and request.request_id in finalize_events:
@ -3761,7 +3779,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
):
"""Finalizes the request. This is necessary for beam search."""
seq_slot = request.py_seq_slot
event = self.algs.decoder.finalize(
event = self.algs.decoder.finalize( # type: ignore
self.store["decoder_state"], seq_slot, request.sampling_config, streaming
)
return event
@ -3775,15 +3793,15 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
beam_width = request.sampling_config.beam_width
# synchronize on the finalize event before continuing the post-processing.
# should be unnecessary, as we already wait for the sampler event in update_requests
state.finalize_events[request.request_id].synchronize()
state.finalize_events[request.request_id].synchronize() # type: ignore
# Get these values again, as they might have changed during the finalize step
output_ids_host = state.host.gathered_ids
sequence_lengths_host = state.host.sequence_lengths
output_ids_host = state.host.gathered_ids # type: ignore
sequence_lengths_host = state.host.sequence_lengths # type: ignore
if request.py_return_log_probs:
log_probs_host = state.host.log_probs
cum_log_probs_host = state.host.cum_log_probs
log_probs_host = state.host.log_probs # type: ignore
cum_log_probs_host = state.host.cum_log_probs # type: ignore
generated_tokens = [[0]] * beam_width
log_probs = [[] for _ in range(beam_width)]
@ -3796,11 +3814,11 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
sequence_lengths_host[seq_slot, beam_idx].item() - request.py_prompt_len
)
end = begin + generated_length
generated_tokens[beam_idx] = output_ids_host[seq_slot, beam_idx, begin:end].tolist()
generated_tokens[beam_idx] = output_ids_host[seq_slot, beam_idx, begin:end].tolist() # type: ignore
# get the correct log probs for beam search
if request.py_return_log_probs:
cum_log_probs.append(cum_log_probs_host[seq_slot, beam_idx].item())
cum_log_probs.append(cum_log_probs_host[seq_slot, beam_idx].item()) # type: ignore
begin_log_probs_offset = (
request.prompt_len if request.sampling_config.beam_width == 1 else 0
@ -3809,7 +3827,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
log_probs[beam_idx].append(
{
token: Logprob(
logprob=log_probs_host[seq_slot, beam_idx][
logprob=log_probs_host[seq_slot, beam_idx][ # type: ignore
begin_log_probs_offset + current_token
].item(),
rank=1,

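For context on the sampler.py refactor above: RequestGroupValue lost its __iter__/__len__ overrides because the sample_async loop now reads the grouped values by attribute instead of tuple-unpacking them, so adding a field (metadata) no longer requires keeping the tuple protocol in sync. A toy sketch of the resulting access pattern, with simplified class and field names rather than the real types:

```python
from dataclasses import dataclass

import torch


@dataclass(kw_only=True, frozen=True)
class GroupValue:
    """Simplified stand-in for RequestGroupValue(WithMetadata)."""

    indices: torch.Tensor
    strategies: list
    metadata: object | None = None


groups = {
    ("greedy", False): GroupValue(
        indices=torch.tensor([0, 2]), strategies=[("greedy", None), ("greedy", None)]
    ),
    ("top_k", True): GroupValue(indices=torch.tensor([1]), strategies=[("top_k", 8, 1.0)]),
}

# Attribute access replaces positional unpacking of the dict values.
for (strategy_key, needs_probs), group in groups.items():
    group_req_indices = group.indices  # ordered by ascending request index
    group_metadata = group.metadata
```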
View File

@ -20,6 +20,7 @@ referring to types like LlmRequest.
import abc
import sys
from collections.abc import Hashable
from dataclasses import dataclass
from typing import Generic, Literal, Optional, Type, TypeAlias, TypeVar, cast
@ -289,9 +290,8 @@ def beam_search_sampling_batch(
beam_width_out: int,
beam_search_args: BeamSearchMetadata,
temperature: float | None,
generator: Optional[torch.Generator] = None,
return_probs: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Sample <beam_width> tokens for each request in parallel.
"""
@ -518,19 +518,18 @@ def sample(
beam_width_out=beam_width_out,
beam_search_args=group_metadata,
temperature=temperature,
generator=generator,
return_probs=return_probs,
)
return tokens, softmax, temperature
GenericStrategyKeyType = TypeVar("GenericStrategyKeyType")
GenericStrategyKeyType = TypeVar("GenericStrategyKeyType", bound=Hashable)
class GroupedStrategySampler(Generic[GenericStrategyKeyType], abc.ABC):
@staticmethod
@abc.abstractmethod
def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> GenericStrategyKeyType:
def strategy_grouping_key(strategy: Strategy) -> GenericStrategyKeyType:
raise NotImplementedError
@staticmethod
@ -552,6 +551,13 @@ class GroupedStrategySampler(Generic[GenericStrategyKeyType], abc.ABC):
return_probs: bool,
group_metadata: StrategyMetadata | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, float | torch.Tensor | None]:
"""Sample grouped strategies.
Returns:
- Sampled tokens
- Processed probs (whenever return_probs=True)
- Temperature (used to compute processed _log_ probs)
"""
raise NotImplementedError
@ -560,7 +566,7 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):
@override
@staticmethod
def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KEY_TYPE:
def strategy_grouping_key(strategy: Strategy) -> STRATEGY_KEY_TYPE:
return strategy
@override
@ -585,7 +591,7 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):
generator: torch.Generator | None = None,
return_probs: bool,
group_metadata: StrategyMetadata | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, float | None]:
) -> tuple[torch.Tensor, torch.Tensor | None, float | torch.Tensor | None]:
if group_key[0] == "beam_search":
beam_width_in = group_key[1]
else:

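A note on the sampling_utils.py signature changes above: strategy_grouping_key no longer takes return_probs (needs_probs lives in RequestGroupKey instead), and GenericStrategyKeyType is now bound to Hashable because the returned key ends up inside a dict key. A minimal sketch of that grouping, with a hypothetical toy_key standing in for the real strategy_to_key:

```python
from collections import defaultdict
from collections.abc import Hashable

# Toy strategies in the tuple shape used by the sampler, e.g. ("greedy", None)
# or ("beam_search", beam_width_in, beam_width_out, temperature).
strategies = [("greedy", None), ("top_k", 8, 1.0), ("beam_search", 2, 2, 0.7), ("greedy", None)]


def toy_key(strategy: tuple) -> Hashable:
    # Beam search keeps its widths in the key so groups with different beam
    # shapes are never batched together; other strategies are keyed by name.
    if strategy[0] == "beam_search":
        return (strategy[0], strategy[1], strategy[2])
    return strategy[0]


# The Hashable bound exists precisely so the key can be used as a dict key.
groups: dict[Hashable, list[int]] = defaultdict(list)
for req_index, strategy in enumerate(strategies):
    groups[toy_key(strategy)].append(req_index)

assert dict(groups) == {"greedy": [0, 3], "top_k": [1], ("beam_search", 2, 2): [2]}
```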
View File

@ -20,7 +20,7 @@ referring to types like LlmRequest.
import abc
import sys
from typing import Optional, Type, TypeAlias, cast
from typing import Literal, Optional, Type, TypeAlias, cast
import flashinfer.sampling
import torch
@ -61,6 +61,9 @@ class _StrategyImpls:
def computes_probs(cls) -> bool:
pass
def get_temperature(self) -> torch.Tensor | None:
return getattr(self, "_temperature", None)
@abc.abstractmethod
def sample(
self,
@ -177,8 +180,8 @@ class _StrategyImpls:
class BeamSearchMixin(StrategyImpl):
def __init__(
self,
beam_width_in: torch.Tensor,
beam_width_out: torch.Tensor,
beam_width_in: int,
beam_width_out: int,
temperature: torch.Tensor,
):
self._beam_width_in = beam_width_in
@ -192,12 +195,8 @@ class _StrategyImpls:
) -> "_StrategyImpls.BeamSearchMixin":
assert all(strat[0] == "beam_search" for strat in strategies)
narrowed_strats = cast(list[BeamSearch], strategies)
beam_width_in = cls._make_tensor(
[strat[1] for strat in narrowed_strats], torch.int32, cuda_device
)
beam_width_out = cls._make_tensor(
[strat[2] for strat in narrowed_strats], torch.int32, cuda_device
)
(beam_width_in,) = set(strat[1] for strat in narrowed_strats)
(beam_width_out,) = set(strat[2] for strat in narrowed_strats)
temperature = cls._make_tensor(
[strat[3] or 1.0 for strat in narrowed_strats], torch.float32, cuda_device
)
@ -215,22 +214,15 @@ class _StrategyImpls:
assert group_metadata is not None and isinstance(group_metadata, BeamSearchMetadata), (
"BeamSearchMetadata is required for beam_search_sampling_batch"
)
assert torch.unique(self._beam_width_in).numel() == 1, (
"beam_width_in must be the same for all strategies"
)
assert torch.unique(self._beam_width_out).numel() == 1, (
"beam_width_out must be the same for all strategies"
)
logits = self._prepare_logits_with_temperature(
logits, group_logit_indices, self._temperature
)
# Convert from 1 temperature per request to 1 temperature per (request, beam)
temperature = self._temperature.repeat_interleave(self._beam_width_in)
logits = self._prepare_logits_with_temperature(logits, group_logit_indices, temperature)
return beam_search_sampling_batch(
logits,
beam_width_in=self._beam_width_in[0],
beam_width_out=self._beam_width_out[0],
beam_width_in=self._beam_width_in,
beam_width_out=self._beam_width_out,
beam_search_args=group_metadata,
temperature=None,
generator=generator,
return_probs=self.computes_probs(),
)
@ -648,84 +640,52 @@ class _StrategyImpls:
pass
def _create_beam_search_specialized_cls(
beam_width_in: torch.Tensor,
beam_width_out: torch.Tensor,
return_probs: bool,
) -> Type[_StrategyImpls.BeamSearchMixin]:
"""Create a class that implements BeamSearchMixin with static parameters for grouping."""
class BeamSearchSpecialized(
_StrategyImpls.BeamSearchWithProbs if return_probs else _StrategyImpls.BeamSearchSampleOnly
):
static_beam_width_in = beam_width_in
static_beam_width_out = beam_width_out
@override
def __hash__(self) -> int:
return hash((super(), self.static_beam_width_in, self.static_beam_width_out))
@override
def __eq__(self, other: object) -> bool:
return (
super().__eq__(other)
and self.static_beam_width_in == other.static_beam_width_in
and self.static_beam_width_out == other.static_beam_width_out
)
return BeamSearchSpecialized
_STRATEGY_KEY_TYPE: TypeAlias = (
Literal["temperature"]
| Literal["top_k"]
| Literal["top_p"]
| Literal["top_k_top_p"]
| Literal["greedy"]
| tuple[Literal["beam_search"], int, int]
)
class FlashInferGroupedStrategySampler(GroupedStrategySampler[Type[_StrategyImpls.StrategyImpl]]):
class FlashInferGroupedStrategySampler(GroupedStrategySampler[_STRATEGY_KEY_TYPE]):
"""Implements batched sampling with FlashInfer.sampling kernels.
Note: Currently, FlashInfer.sampling appears to have limited CUDA graph
support, see https://github.com/flashinfer-ai/flashinfer/issues/978.
"""
STRATEGY_KEY_TYPE: TypeAlias = Type[_StrategyImpls.StrategyImpl]
STRATEGY_KEY_TYPE: TypeAlias = _STRATEGY_KEY_TYPE
@override
@staticmethod
def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KEY_TYPE:
if return_probs:
match strategy:
case ("top_k", _, _):
return _StrategyImpls.TopKWithProbs
case ("top_p", _, _):
return _StrategyImpls.TopPWithProbs
case ("top_k_top_p", _, _, _):
return _StrategyImpls.TopKTopPWithProbs
case ("temperature", _):
return _StrategyImpls.TemperatureOnlyWithProbs
case ("greedy", None):
return _StrategyImpls.GreedyWithProbs
case ("beam_search", beam_width_in, beam_width_out, _):
return _create_beam_search_specialized_cls(beam_width_in, beam_width_out, True)
else:
match strategy:
case ("top_p", _, _):
return _StrategyImpls.TopPSampleOnly
case ("top_k", _, _):
return _StrategyImpls.TopKSampleOnly
case ("top_k_top_p", _, _, _):
return _StrategyImpls.TopKTopPSampleOnly
case ("temperature", _):
return _StrategyImpls.TemperatureOnlySampleOnly
case ("greedy", None):
return _StrategyImpls.GreedySampleOnly
case ("beam_search", beam_width_in, beam_width_out, _):
return _create_beam_search_specialized_cls(beam_width_in, beam_width_out, False)
def strategy_grouping_key(strategy: Strategy) -> STRATEGY_KEY_TYPE:
match strategy:
case (
("top_k", _, _)
| ("top_p", _, _)
| ("top_k_top_p", _, _, _)
| ("temperature", _)
| ("greedy", None)
):
return strategy[0]
case ("beam_search", beam_width_in, beam_width_out, _):
return (strategy[0], beam_width_in, beam_width_out)
case _:
raise NotImplementedError("Unsupported strategy encountered")
@override
@staticmethod
def get_metadata_type_for_group(
strategy_key: STRATEGY_KEY_TYPE,
) -> Type[StrategyMetadata] | None:
if issubclass(strategy_key, _StrategyImpls.BeamSearchMixin):
return BeamSearchMetadata
else:
return None
match strategy_key:
case ("beam_search", _, _):
return BeamSearchMetadata
case _:
return None
@override
@staticmethod
@ -739,28 +699,52 @@ class FlashInferGroupedStrategySampler(GroupedStrategySampler[Type[_StrategyImpl
return_probs: bool,
group_metadata: StrategyMetadata | None = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
if hasattr(group_key, "static_beam_width_in"):
beam_width_in = group_key.static_beam_width_in
beam_width_in = 1
if return_probs:
match group_key:
case "top_k":
strategy_impl_cls = _StrategyImpls.TopKWithProbs
case "top_p":
strategy_impl_cls = _StrategyImpls.TopPWithProbs
case "top_k_top_p":
strategy_impl_cls = _StrategyImpls.TopKTopPWithProbs
case "temperature":
strategy_impl_cls = _StrategyImpls.TemperatureOnlyWithProbs
case "greedy":
strategy_impl_cls = _StrategyImpls.GreedyWithProbs
case ("beam_search", beam_width_in_key, _):
beam_width_in = beam_width_in_key
strategy_impl_cls = _StrategyImpls.BeamSearchWithProbs
case _:
raise NotImplementedError("Unsupported strategy key encountered")
else:
beam_width_in = 1
match group_key:
case "top_p":
strategy_impl_cls = _StrategyImpls.TopPSampleOnly
case "top_k":
strategy_impl_cls = _StrategyImpls.TopKSampleOnly
case "top_k_top_p":
strategy_impl_cls = _StrategyImpls.TopKTopPSampleOnly
case "temperature":
strategy_impl_cls = _StrategyImpls.TemperatureOnlySampleOnly
case "greedy":
strategy_impl_cls = _StrategyImpls.GreedySampleOnly
case ("beam_search", beam_width_in_key, _):
beam_width_in = beam_width_in_key
strategy_impl_cls = _StrategyImpls.BeamSearchSampleOnly
case _:
raise NotImplementedError("Unsupported strategy key encountered")
if group_logit_indices is None:
assert logits.size(0) == beam_width_in * len(strategies)
else:
assert group_logit_indices.size(0) == beam_width_in * len(strategies)
assert return_probs == group_key.computes_probs()
strategy_impl_cls = group_key
sampling_object = strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device)
next_tokens, softmax = sampling_object.sample(
strategy_impl = strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device)
next_tokens, softmax = strategy_impl.sample(
logits,
group_logit_indices=group_logit_indices,
generator=generator,
group_metadata=group_metadata,
)
temperature = (
sampling_object._temperature.unsqueeze(-1)
if sampling_object._temperature is not None
else None
)
return next_tokens, softmax, temperature
return next_tokens, softmax, strategy_impl.get_temperature()

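For the sampling_utils_flashinfer.py change above: BeamSearchMixin now stores beam_width_in/out as plain ints and expands the per-request temperature to a per-(request, beam) tensor with repeat_interleave before dividing the logits. A tiny illustration with made-up values:

```python
import torch

# One temperature per request...
temperature = torch.tensor([0.7, 1.0, 1.3])
beam_width_in = 2

# ...expanded to one temperature per (request, beam), matching a logits layout
# where each request contributes beam_width_in consecutive rows.
per_beam_temperature = temperature.repeat_interleave(beam_width_in)
print(per_beam_temperature)  # tensor([0.7000, 0.7000, 1.0000, 1.0000, 1.3000, 1.3000])
```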
View File

@ -436,7 +436,6 @@ def test_beam_search_sampling_batch_basic():
beam_width_out=beam_width,
beam_search_args=beam_search_args,
temperature=temperature,
generator=None,
return_probs=True,
)

View File

@ -35,7 +35,6 @@ import torch
from scipy.stats import power_divergence
from utils.util import assert_no_cuda_sync, force_ampere
from tensorrt_llm._torch.pyexecutor import sampling_utils_flashinfer
from tensorrt_llm._torch.pyexecutor.llm_request import convert_wordlist
from tensorrt_llm._torch.pyexecutor.sampler import (
GREEDY,
@ -1563,11 +1562,10 @@ class TestBatchedSampling:
return_probs: bool,
group_metadata: StrategyMetadata | None = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert issubclass(group_key, sampling_utils_flashinfer._StrategyImpls.StrategyImpl)
assert generator is sampler.get_generator(logits.device)
nonlocal flashinfer_keys_seen
assert group_key not in flashinfer_keys_seen
flashinfer_keys_seen.add(group_key)
assert (group_key, return_probs) not in flashinfer_keys_seen
flashinfer_keys_seen.add((group_key, return_probs))
return sample_grouped_strategies_orig(
group_key,
strategies,