Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-16 07:53:55 +08:00)
[TRTLLM-10030][perf] beam search (remove GPU sync + fix batching + refactor) (#11276)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
This commit is contained in:
parent e483c7263d, commit 719e82c429
@@ -21,7 +21,7 @@ from concurrent import futures
 from dataclasses import dataclass
 from functools import cached_property
 from itertools import repeat
-from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, cast
+from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeAlias, TypeVar, cast
 
 import numpy as np
 import torch
@@ -264,42 +264,11 @@ class RequestGroupValue:
     need_processed_logprobs: torch.Tensor
     need_raw_logprobs: torch.Tensor
 
-    def __iter__(self):
-        return iter(
-            (
-                self.indices,
-                self.strategies,
-                self.speculation_needs_probs_indices,
-                self.need_processed_logprobs,
-                self.need_raw_logprobs,
-            )
-        )
-
-    def __len__(self):
-        return 5
-
 
 @dataclass(kw_only=True, frozen=True)
 class RequestGroupValueWithMetadata(RequestGroupValue):
     metadata: StrategyMetadata | None
-
-    @override
-    def __iter__(self):
-        return iter(
-            (
-                self.indices,
-                self.strategies,
-                self.speculation_needs_probs_indices,
-                self.need_processed_logprobs,
-                self.need_raw_logprobs,
-                self.metadata,
-            )
-        )
-
-    @override
-    def __len__(self):
-        return 6
 
 
 class EarlyStopWithMMResult(Sampler):
     """
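With the tuple protocol (__iter__/__len__) dropped from RequestGroupValue and RequestGroupValueWithMetadata, client code now reads grouped values by field name rather than by positional unpacking (see the loop refactor further down in this diff). A minimal sketch of the resulting access pattern, using a simplified stand-in dataclass rather than the real types:

from dataclasses import dataclass

import torch


@dataclass(kw_only=True, frozen=True)
class GroupValueSketch:
    # Simplified stand-in for RequestGroupValueWithMetadata; the field set is illustrative.
    indices: torch.Tensor
    strategies: list
    metadata: object | None = None


value = GroupValueSketch(indices=torch.tensor([0, 2, 3]), strategies=["greedy"] * 3)

# Attribute access replaces positional unpacking, so adding or reordering fields
# no longer silently breaks call sites that unpack a fixed-length tuple.
group_req_indices = value.indices
group_strategies = value.strategies
group_metadata = value.metadata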
@ -307,7 +276,7 @@ class EarlyStopWithMMResult(Sampler):
|
||||
"""
|
||||
|
||||
@override
|
||||
def sample_async(
|
||||
def sample_async( # type: ignore
|
||||
self,
|
||||
scheduled_requests: ScheduledRequests,
|
||||
model_outputs,
|
||||
@ -322,7 +291,7 @@ class EarlyStopWithMMResult(Sampler):
|
||||
return SampleStateWithMMResult(scheduled_requests=scheduled_requests, data=data)
|
||||
|
||||
@override
|
||||
def update_requests(
|
||||
def update_requests( # type: ignore
|
||||
self,
|
||||
state: SampleStateWithMMResult,
|
||||
resource_manager: Optional[ResourceManager] = None,
|
||||
@ -341,9 +310,9 @@ class EarlyStopWithMMResult(Sampler):
|
||||
request.state = LlmRequestState.GENERATION_COMPLETE
|
||||
# NOTE: This is a hack: set finish reason manually and set the beam 0
|
||||
request.set_finished_reason(FinishReason.LENGTH, 0)
|
||||
if len(mm_embedding) != sum(request.multimodal_lengths):
|
||||
if len(mm_embedding) != sum(request.multimodal_lengths): # type: ignore
|
||||
raise ValueError(
|
||||
f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}"
|
||||
f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}" # type: ignore
|
||||
)
|
||||
|
||||
request.py_result.append_mm_embeddings(mm_embedding)
|
||||
@ -385,7 +354,7 @@ def _get_max_beam_width(request: LlmRequest) -> int:
|
||||
sampling_config = request.sampling_config
|
||||
max_beam_width = sampling_config.beam_width
|
||||
if sampling_config.beam_width_array is not None:
|
||||
max_beam_width = max(max_beam_width, sampling_config.beam_width_array.max())
|
||||
max_beam_width = max(max_beam_width, sampling_config.beam_width_array.max()) # type: ignore
|
||||
return max_beam_width
|
||||
|
||||
|
||||
@@ -416,23 +385,24 @@ def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
     # We try to cache the resolved strategy on the request object, as it's not cheap enough to
     # resolve it on every iteration.
     if hasattr(request, "py_sampling_strategy"):
-        return request.py_sampling_strategy
+        return request.py_sampling_strategy  # type: ignore
 
     params = _request_get_sampling_params(request)
     sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
     if _request_sampling_params_cachable(params):
-        request.py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
+        request.py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)  # type: ignore
     return sampling_strategy
 
 
 def _group_requests_by_strategy_key(
     requests: Iterable[LlmRequest],
     *,
-    strategy_to_key: Callable[[Strategy, bool], GenericStrategyKeyType],
+    strategy_to_key: Callable[[Strategy], GenericStrategyKeyType],
     pin_memory: bool = False,
     vocab_size: int,
 ) -> dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValue]:
-    # NB: Client code relies on request indices in returned torch.Tensor being sorted.
+    # NB: Client code relies on request indices in returned torch.Tensor being ordered
+    # by ascending "requests" index.
     RequestGroupValueBuilder = namedtuple(
         "RequestGroupValueBuilder",
         [
@@ -444,8 +414,8 @@ def _group_requests_by_strategy_key(
         ],
     )
 
-    group_dict: dict[RequestGroupKey, RequestGroupValueBuilder] = defaultdict(
-        lambda: RequestGroupValueBuilder([], [], [], [], [])
+    group_dict: dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValueBuilder] = (
+        defaultdict(lambda: RequestGroupValueBuilder([], [], [], [], []))
     )
 
     for req_index, req in enumerate(requests):
@@ -460,7 +430,7 @@ def _group_requests_by_strategy_key(
         )
         need_raw_logprobs = req.py_logprobs_mode == LogprobMode.RAW and req.return_log_probs
         needs_probs = speculation_needs_probs or need_processed_logprobs
-        strategy_key = strategy_to_key(strategy, needs_probs)
+        strategy_key = strategy_to_key(strategy)
        group_dict_entry = group_dict[
            RequestGroupKey(strategy_key=strategy_key, needs_probs=needs_probs)
        ]
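The grouping helper collects, per (strategy key, needs_probs) pair, the request indices and strategies in ascending request order. A rough standalone sketch of that bookkeeping (the names and the strategy tuples are simplified assumptions, not the module's actual types):

from collections import defaultdict, namedtuple

GroupKey = namedtuple("GroupKey", ["strategy_key", "needs_probs"])


def group_by_strategy(requests, strategy_to_key):
    # Append request indices in iteration order, so each group's index list
    # stays sorted ascending -- downstream code relies on that ordering.
    groups = defaultdict(lambda: {"indices": [], "strategies": []})
    for req_index, (strategy, needs_probs) in enumerate(requests):
        key = GroupKey(strategy_key=strategy_to_key(strategy), needs_probs=needs_probs)
        groups[key]["indices"].append(req_index)
        groups[key]["strategies"].append(strategy)
    return groups


# Example: two greedy requests and one top-p request that also needs probs.
reqs = [(("greedy", None), False), (("top_p", 0.9), True), (("greedy", None), False)]
print(group_by_strategy(reqs, strategy_to_key=lambda s: s[0]))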
@@ -768,7 +738,7 @@ DEFAULT_BEAM_IDX = 0
 # Step index to use when no speculative decoding is used but a step index is required
 DEFAULT_STEP_IDX = 0
 
-FinishReasonsList = list[list[int]]
+FinishReasonsList: TypeAlias = list[list[list[int]]]
 
 
 @dataclass(kw_only=True)
@@ -797,7 +767,7 @@ class SamplingRequestsMetadata:
 @dataclass(kw_only=True)
 class SampleStateTensorsHostTorch(SampleStateTensors):
     finish_reasons: torch.Tensor
-    first_finish_reasons: torch.Tensor
+    first_finish_reasons: torch.Tensor | None
     logprobs_state: LogProbsState | None = None
 
     def finish_reasons_list(self) -> FinishReasonsList:
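The widened alias reflects how finish reasons are consumed later in this diff, where a value is read per sequence slot, decode step, and beam (finish_reasons[request.py_seq_slot][step][beam_idx]). A toy illustration with made-up dimensions and placeholder reason codes:

# FinishReasonsList is indexed as [seq_slot][step][beam]; entries hold FinishReason values.
# The shape below (2 seq slots, 1 decode step, 2 beams) is made up for illustration,
# and NOT_FINISHED / LENGTH stand in for the real enum values.
NOT_FINISHED, LENGTH = 0, 1  # placeholder codes, not the actual FinishReason numbers

finish_reasons = [
    [[NOT_FINISHED, NOT_FINISHED]],  # seq slot 0, step 0, beams 0..1
    [[LENGTH, NOT_FINISHED]],        # seq slot 1, step 0, beams 0..1
]

reason = finish_reasons[1][0][0]  # seq_slot=1, step=0, beam_idx=0 -> LENGTH placeholder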
@ -808,7 +778,7 @@ class SampleStateTensorsHostTorch(SampleStateTensors):
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class SampleStateTorch(SampleState):
|
||||
host: SampleStateTensorsHostTorch
|
||||
host: SampleStateTensorsHostTorch # type: ignore
|
||||
beam_histories: list[BeamHistory | None] | None = None
|
||||
|
||||
|
||||
@ -827,7 +797,7 @@ class AsyncWorkerMixin:
|
||||
def _async_worker_init(self, enable_async_worker: bool):
|
||||
self._enable_async_worker = enable_async_worker
|
||||
self._async_worker = None
|
||||
self._async_worker_futures: list[futures.Future[any]] = []
|
||||
self._async_worker_futures: list[futures.Future[Any]] = []
|
||||
|
||||
def async_worker_enabled(self):
|
||||
return getattr(self, "_enable_async_worker", False)
|
||||
@ -853,6 +823,7 @@ class AsyncWorkerMixin:
|
||||
def async_worker_stop(self):
|
||||
assert self.async_worker_enabled()
|
||||
if self._async_worker_active():
|
||||
assert self._async_worker is not None
|
||||
self._async_worker.shutdown(wait=True)
|
||||
self._async_worker = None
|
||||
|
||||
@ -888,6 +859,7 @@ class AsyncWorkerMixin:
|
||||
copy_ready.record()
|
||||
|
||||
# Submit the copy to the async worker thread
|
||||
assert self._async_worker is not None
|
||||
result = self._async_worker.submit(
|
||||
self._async_copy_to_host, copy_ready, dest, src_snapshot
|
||||
)
|
||||
@ -980,8 +952,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
# Tensors necessary for all sampling methods
|
||||
new_tokens = int_tensor(self.NEW_TOKENS_SHAPE)
|
||||
finish_reasons = int_tensor(self.NEW_TOKENS_SHAPE)
|
||||
max_lengths_tensor = int_tensor(self.max_num_sequences)
|
||||
end_ids = int_tensor(self.max_num_sequences)
|
||||
max_lengths_tensor = int_tensor((self.max_num_sequences,))
|
||||
end_ids = int_tensor((self.max_num_sequences,))
|
||||
|
||||
# Only used for logprobs processing or beam search
|
||||
sampled_log_probs = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.float32)
|
||||
@ -1236,6 +1208,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
!= FinishReason.NOT_FINISHED.value
|
||||
).sum() == request.sampling_config.beam_width:
|
||||
request.state = LlmRequestState.GENERATION_COMPLETE
|
||||
assert request.py_seq_slot is not None
|
||||
for beam_idx in range(request.sampling_config.beam_width):
|
||||
request.set_finished_reason(
|
||||
FinishReason(
|
||||
@ -1255,6 +1228,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
When using streaming a requests tokens may be altered, leading to wrong results when called multiple times.
|
||||
Store the original tokens in a separate buffer to use them as a consistent basis
|
||||
when updating the tokens in a request."""
|
||||
assert self.store.original_tokens is not None
|
||||
assert new_tokens.device == self.store.original_tokens.device, (
|
||||
"new_tokens and original_tokens must be on the same device"
|
||||
)
|
||||
@ -1311,6 +1285,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
assert beam_idx == DEFAULT_BEAM_IDX, (
|
||||
"beam search does not need to explicitly handle sampled log probs"
|
||||
)
|
||||
assert sampled_log_probs_indices_list is not None
|
||||
assert sampled_log_probs_vals_list is not None
|
||||
assert sampled_log_probs_rank_list is not None
|
||||
if sampled_log_probs_indices_list[step_idx] not in logprobs:
|
||||
logprobs[sampled_log_probs_indices_list[step_idx]] = Logprob(
|
||||
logprob=sampled_log_probs_vals_list[step_idx],
|
||||
@ -1388,6 +1365,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
beam_width = request.sampling_config.beam_width
|
||||
assert request.py_num_logprobs is not None, "request.py_num_logprobs must be provided"
|
||||
assert logprobs_state_list is not None, "logprobs_state_list must be provided"
|
||||
assert request.py_seq_slot is not None
|
||||
token_log_probs = self._store_logprobs_list_to_request(
|
||||
logprobs_state_list, request.py_seq_slot, beam_width, count, request.py_num_logprobs
|
||||
)
|
||||
@ -1396,6 +1374,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
def finish_if_reason(
|
||||
self, request: LlmRequest, finish_reasons: FinishReasonsList, *, step: int, beam_idx: int
|
||||
) -> bool:
|
||||
assert request.py_seq_slot is not None
|
||||
reason = FinishReason(finish_reasons[request.py_seq_slot][step][beam_idx])
|
||||
valid_reasons = {FinishReason.END_ID, FinishReason.LENGTH, FinishReason.STOP_WORDS}
|
||||
if reason in valid_reasons:
|
||||
@ -1548,6 +1527,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
@override
|
||||
def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
|
||||
"""Setup the sampler step for the requests
|
||||
@ -1563,6 +1543,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
end_ids: list[int] = []
|
||||
for request in scheduled_requests.context_requests:
|
||||
if self._is_new_request(request):
|
||||
assert request.py_seq_slot is not None
|
||||
seq_slots.append(request.py_seq_slot)
|
||||
max_lens.append(
|
||||
min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
|
||||
@ -1593,6 +1574,13 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
if self._is_new_request(request):
|
||||
if request.py_return_log_probs and request.py_num_logprobs > 1:
|
||||
raise ValueError("Beam search does not support multiple logprobs")
|
||||
assert self.store.cache_indirection is not None
|
||||
assert self.store.cum_log_probs is not None
|
||||
assert self.store.sampled_log_probs is not None
|
||||
assert self.store.sampled_log_prob_ranks is not None
|
||||
assert self.store.predecessor_beams is not None
|
||||
assert self.store.first_finish_reasons is not None
|
||||
assert self.store.original_tokens is not None
|
||||
self.store.cache_indirection[request.py_seq_slot, :, request.py_prompt_len].fill_(0)
|
||||
self.store.cum_log_probs[request.py_seq_slot].fill_(0)
|
||||
self.store.sampled_log_probs[request.py_seq_slot].fill_(0)
|
||||
@ -1765,9 +1753,11 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
)
|
||||
if hasattr(request.py_result._log_probs, "log_probs"):
|
||||
logprobs_list = request.py_result.log_probs
|
||||
assert logprobs_list is not None
|
||||
for beam_idx, beam_logprobs in enumerate(logprobs_list):
|
||||
for token_idx, token_logprobs in enumerate(beam_logprobs):
|
||||
for key, value in token_logprobs.items():
|
||||
assert value.rank is not None
|
||||
logprobs_tensor[beam_idx, token_idx, value.rank - 1] = value.logprob
|
||||
logprobs_indices_tensor[beam_idx, token_idx, value.rank - 1] = key
|
||||
return logprobs_tensor, logprobs_indices_tensor
|
||||
@ -1795,6 +1785,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
if num_generated_tokens == 0 or request.state == LlmRequestState.GENERATION_COMPLETE:
|
||||
# early return if no tokens have been generated yet or the request is already finished
|
||||
return None
|
||||
assert self.store.cache_indirection is not None
|
||||
assert self.store.original_tokens is not None
|
||||
assert self.store.sampled_log_probs is not None
|
||||
cache_indirection = self.store.cache_indirection[
|
||||
request.py_seq_slot, :num_beams, prompt_length:num_tokens
|
||||
]
|
||||
@ -1802,7 +1795,13 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
request.py_seq_slot, :num_beams, prompt_length:num_tokens
|
||||
]
|
||||
new_path = torch.zeros_like(current_path)
|
||||
|
||||
# initialize each beam with its own index
|
||||
|
||||
# Gather the correct tokens and logprobs for each beam
|
||||
torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path)
|
||||
if request.py_return_log_probs:
|
||||
assert self.store.cum_log_probs is not None
|
||||
current_logprobs, current_logprobs_indices = self._get_logprobs_from_request(request)
|
||||
# concatenate the newly generated logprobs and newly
|
||||
# generated tokens to the current logprobs and logprobs indices
|
||||
@ -1823,11 +1822,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
# Initialize the buffers to store the results
|
||||
new_logprobs = torch.zeros_like(current_logprobs)
|
||||
new_logprobs_indices = torch.zeros_like(current_logprobs_indices)
|
||||
# initialize each beam with its own index
|
||||
|
||||
# Gather the correct tokens and logprobs for each beam
|
||||
torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path)
|
||||
if request.py_return_log_probs:
|
||||
cache_indirection_for_logprobs = cache_indirection.unsqueeze(-1).expand(
|
||||
-1, -1, current_logprobs.shape[2]
|
||||
)
|
||||
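For reference, the torch.gather in the hunk above rebuilds each beam's token history: cache_indirection[beam, t] names the beam whose token at position t belongs to this beam's final hypothesis. A self-contained sketch with invented values:

import torch

# Tokens generated so far, one row per beam (2 beams, 4 generated positions).
current_path = torch.tensor([[10, 11, 12, 13],
                             [20, 21, 22, 23]])

# cache_indirection[beam, t] = which beam's token at position t belongs to
# this beam's hypothesis (values invented for illustration).
cache_indirection = torch.tensor([[0, 0, 1, 0],
                                  [0, 1, 1, 1]])

new_path = torch.zeros_like(current_path)
torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path)
# new_path:
# [[10, 11, 22, 13],
#  [10, 21, 22, 23]]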
@ -1877,6 +1872,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
{beam_history.tokens.shape[0]} != {beam_width}"
|
||||
)
|
||||
if request.py_return_log_probs:
|
||||
assert beam_history.logprobs is not None
|
||||
assert beam_history.logprobs_indices is not None
|
||||
assert beam_history.cum_logprobs is not None
|
||||
assert beam_history.logprobs.shape[0] == beam_width, (
|
||||
f"Beam_history.logprobs.shape[0] should equal beam width: \
|
||||
{beam_history.logprobs.shape[0]} != {beam_width}"
|
||||
@ -1895,6 +1893,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
for beam_idx in range(beam_width):
|
||||
gen_token_list.append(beam_history.tokens[beam_idx, : valid_tokens[beam_idx]].tolist())
|
||||
if request.py_return_log_probs:
|
||||
assert beam_history.logprobs_indices is not None
|
||||
assert beam_history.logprobs is not None
|
||||
gen_log_probs_list.append(
|
||||
self._convert_logprobs_tensor_to_list(
|
||||
beam_history.logprobs_indices[
|
||||
@ -1910,6 +1910,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
if request.py_return_log_probs:
|
||||
# cum_log_probs will not change when padding with end tokens.
|
||||
# Therefore, we do not need to correct it
|
||||
assert beam_history.cum_logprobs is not None
|
||||
request.py_result.set_log_probs(
|
||||
gen_log_probs_list, cum_log_probs=beam_history.cum_logprobs.tolist()
|
||||
)
|
||||
@ -1920,7 +1921,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
grouped_requests: dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValue],
|
||||
seq_slots: torch.Tensor,
|
||||
seq_lens: torch.Tensor | None,
|
||||
get_metadata_type_for_group_fn: Callable[[GenericStrategyKeyType], Type[StrategyMetadata]],
|
||||
get_metadata_type_for_group_fn: Callable[
|
||||
[GenericStrategyKeyType], Type[StrategyMetadata] | None
|
||||
],
|
||||
) -> dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValueWithMetadata]:
|
||||
grouped_requests_with_metadata: dict[
|
||||
RequestGroupKey[GenericStrategyKeyType], RequestGroupValueWithMetadata
|
||||
@ -1929,6 +1932,12 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
metadata_type = get_metadata_type_for_group_fn(key.strategy_key)
|
||||
if metadata_type is BeamSearchMetadata:
|
||||
assert seq_lens is not None, "seq_lens is required for beam search"
|
||||
assert self.store.cache_indirection is not None
|
||||
assert self.store.cache_indirection_buffer is not None
|
||||
assert self.store.cum_log_probs is not None
|
||||
assert self.store.sampled_log_probs is not None
|
||||
assert self.store.first_finish_reasons is not None
|
||||
assert self.store.predecessor_beams is not None
|
||||
metadata = BeamSearchMetadata(
|
||||
cache_indirection=self.store.cache_indirection,
|
||||
cache_indirection_buffer=self.store.cache_indirection_buffer,
|
||||
@ -1982,7 +1991,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
_, cumsum = request.py_stop_words_list
|
||||
if -1 in cumsum:
|
||||
cumsum = cumsum[: cumsum.index(-1)]
|
||||
longest_stop_word_len = np.max(np.diff(cumsum, prepend=0), initial=0)
|
||||
longest_stop_word_len = np.max(np.diff(cumsum, prepend=0), initial=0).item()
|
||||
return longest_stop_word_len > 1
|
||||
return False
|
||||
|
||||
@ -2010,9 +2019,10 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
@torch.inference_mode()
|
||||
def update_requests(
|
||||
self,
|
||||
state: SampleStateTorch,
|
||||
state: Sampler.SampleState,
|
||||
resource_manager: Optional[ResourceManager] = None,
|
||||
) -> None:
|
||||
state = cast(SampleStateTorch, state)
|
||||
assert isinstance(state, SampleStateTorch)
|
||||
if state.sampler_event:
|
||||
state.sampler_event.synchronize()
|
||||
@ -2036,7 +2046,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
if beam_histories is not None and beam_histories[req_idx] is not None:
|
||||
self._finalize_beam(
|
||||
req,
|
||||
beam_histories[req_idx],
|
||||
cast(BeamHistory, beam_histories[req_idx]),
|
||||
)
|
||||
else:
|
||||
for beam_idx in range(req.sampling_config.beam_width):
|
||||
@ -2055,7 +2065,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
if beam_histories is not None and beam_histories[req_idx] is not None:
|
||||
self._finalize_beam(
|
||||
req,
|
||||
beam_histories[req_idx],
|
||||
cast(BeamHistory, beam_histories[req_idx]),
|
||||
)
|
||||
else:
|
||||
for beam_idx in range(req.sampling_config.beam_width):
|
||||
@ -2100,6 +2110,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
self.max_tokens,
|
||||
self.max_topk_logprobs,
|
||||
)
|
||||
assert self.store.topk_vals is not None
|
||||
assert self.store.topk_indices is not None
|
||||
self.store.topk_vals.resize_(self.TOPK_LOGPROBS_SHAPE)
|
||||
self.store.topk_indices.resize_(self.TOPK_LOGPROBS_SHAPE)
|
||||
|
||||
@ -2155,8 +2167,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
)
|
||||
finish_reasons_host = self._copy_to_host(finish_reasons)
|
||||
|
||||
beam_histories = [None] * len(requests)
|
||||
beam_histories: list[BeamHistory | None] = [None] * len(requests)
|
||||
if self._use_beam_search:
|
||||
assert first_finish_reasons is not None
|
||||
assert seq_lens_host is not None, "seq_lens is required for beam search"
|
||||
assert self.store.first_finish_reasons is not None, (
|
||||
"first_finish_reasons must be provided"
|
||||
@ -2167,6 +2180,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
self._maybe_create_beam_histories(
|
||||
requests, finish_reasons=first_finish_reasons, beam_histories=beam_histories
|
||||
)
|
||||
else:
|
||||
first_finish_reasons_host = None
|
||||
|
||||
# copy logprobs to host
|
||||
logprobs_state: LogProbsState | None = None
|
||||
@ -2204,9 +2219,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
host=SampleStateTensorsHostTorch(
|
||||
new_tokens=new_tokens_host,
|
||||
finish_reasons=finish_reasons_host,
|
||||
first_finish_reasons=None
|
||||
if not self._use_beam_search
|
||||
else first_finish_reasons_host,
|
||||
first_finish_reasons=first_finish_reasons_host,
|
||||
logprobs_state=logprobs_state,
|
||||
),
|
||||
sampler_event=sampler_event,
|
||||
@@ -2393,14 +2406,19 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
         )
         batch_req_idx_offset_start = 0
         batch_next_tokens_offset_start = 0
-        for (strategy_key, needs_probs), (
-            group_req_indices,
-            group_strategies,
-            group_speculation_needs_probs_indices,
-            group_need_processed_logprobs,
-            group_need_raw_logprobs,
-            group_metadata,
-        ) in grouped_requests_with_metadata.items():
+        for (
+            strategy_key,
+            needs_probs,
+        ), group_val_with_metadata in grouped_requests_with_metadata.items():
+            group_req_indices = group_val_with_metadata.indices
+            group_strategies = group_val_with_metadata.strategies
+            group_speculation_needs_probs_indices = (
+                group_val_with_metadata.speculation_needs_probs_indices
+            )
+            group_need_processed_logprobs = group_val_with_metadata.need_processed_logprobs
+            group_need_raw_logprobs = group_val_with_metadata.need_raw_logprobs
+            group_metadata = group_val_with_metadata.metadata
 
             # group_req_indices: Indices of 'requests' entries having the same sampling
             # strategy, ordered ascending.
             batch_req_idx_offset_end = batch_req_idx_offset_start + group_req_indices.size(0)
@ -2420,22 +2438,19 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
# indices for accessing logits within the current group
|
||||
group_logit_indexer = _PackedStepIndexer(
|
||||
num_steps=req_num_generated_tokens[group_req_indices],
|
||||
max_steps=req_num_generated_tokens.max() * self.max_beam_width,
|
||||
max_steps=cast(
|
||||
int, req_num_generated_tokens.max().item() * self.max_beam_width
|
||||
),
|
||||
)
|
||||
logit_indices_for_processed_logprobs_cuda = (
|
||||
None
|
||||
if not any_request_needs_processed_logprobs
|
||||
else group_logit_indexer[need_processed_logprobs_indices].to(
|
||||
logits_cuda.device, non_blocking=True
|
||||
)
|
||||
)
|
||||
logit_indices_for_raw_logprobs_cuda = (
|
||||
None
|
||||
if not any_request_needs_raw_logprobs
|
||||
else group_logit_indexer[need_raw_logprobs_indices].to(
|
||||
logits_cuda.device, non_blocking=True
|
||||
)
|
||||
)
|
||||
logit_indices_for_processed_logprobs_cuda = group_logit_indexer[
|
||||
need_processed_logprobs_indices
|
||||
].to(logits_cuda.device, non_blocking=True)
|
||||
logit_indices_for_raw_logprobs_cuda = group_logit_indexer[
|
||||
need_raw_logprobs_indices
|
||||
].to(logits_cuda.device, non_blocking=True)
|
||||
else:
|
||||
logit_indices_for_processed_logprobs_cuda = None
|
||||
logit_indices_for_raw_logprobs_cuda = None
|
||||
|
||||
group_logits_cuda_indices = logits_cuda_indexer[group_req_indices]
|
||||
# NB: Assuming that group_req_indices are sorted
|
||||
@ -2518,13 +2533,13 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
processed_logits_cuda = torch.where(
|
||||
current_softmax_cuda > 0, current_logits_cuda, float("-inf")
|
||||
)
|
||||
if group_temperature_cuda is not None:
|
||||
if isinstance(group_temperature_cuda, torch.Tensor):
|
||||
processed_logits_cuda /= group_temperature_cuda[
|
||||
logit_indices_for_processed_logprobs_cuda
|
||||
]
|
||||
else:
|
||||
processed_logits_cuda /= group_temperature_cuda
|
||||
temperature_for_processed_logprobs = group_temperature_cuda
|
||||
if isinstance(temperature_for_processed_logprobs, torch.Tensor):
|
||||
temperature_for_processed_logprobs = cast(torch.Tensor, group_temperature_cuda)[
|
||||
logit_indices_for_processed_logprobs_cuda
|
||||
].unsqueeze(-1)
|
||||
if temperature_for_processed_logprobs is not None:
|
||||
processed_logits_cuda /= temperature_for_processed_logprobs
|
||||
logit_indices_for_processed_logprobs_cuda += batch_next_tokens_offset_start
|
||||
batch_logits_for_logprobs_cuda[logit_indices_for_processed_logprobs_cuda] = (
|
||||
processed_logits_cuda
|
||||
@ -2768,6 +2783,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
def _longest_stop_word_len(requests: Iterable[LlmRequest]) -> int:
|
||||
max_stop_word_len = 0
|
||||
for req in requests:
|
||||
assert req.py_stop_words_list is not None
|
||||
_, cumsum = req.py_stop_words_list
|
||||
if -1 in cumsum:
|
||||
cumsum = cumsum[: cumsum.index(-1)]
|
||||
@@ -2971,13 +2987,14 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
         padded_tokens = self._padded_old_tokens(requests, tokens, predecessor_beams)
 
         for request_idx, request in enumerate(requests):
+            assert request.py_stop_words_list is not None
             swl, ends = request.py_stop_words_list
             if -1 in ends:
                 ends = ends[: ends.index(-1)]
             lens = np.diff(ends, prepend=0)
             max_len = np.max(lens)
 
-            words = torch.zeros(len(lens), max_len, dtype=torch.int32, pin_memory=True)
+            words = torch.zeros(len(lens), max_len.item(), dtype=torch.int32, pin_memory=True)
             for step, (start, length) in enumerate(zip([0] + ends, lens)):
                 words[step, :length] = torch.tensor(swl[start : start + length], dtype=torch.int32)
             words_device = words.to("cuda", non_blocking=True)
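For context on the loop above: stop words arrive flattened as a single token list plus cumulative end offsets (with -1 padding), and get unpacked into a padded per-word tensor. A standalone sketch of that decoding with made-up token ids (CPU-only; the real code pins memory and copies to CUDA asynchronously):

import numpy as np
import torch

# Flattened stop-word encoding: all stop-word tokens concatenated, plus the
# cumulative end offset of each word ("-1" marks unused slots).
swl = [11, 12, 7, 21, 22, 23]
ends = [2, 3, 6, -1]

if -1 in ends:
    ends = ends[: ends.index(-1)]
lens = np.diff(ends, prepend=0)  # lengths per stop word: [2, 1, 3]
max_len = np.max(lens)

words = torch.zeros(len(lens), int(max_len), dtype=torch.int32)
for row, (start, length) in enumerate(zip([0] + ends, lens)):
    words[row, :length] = torch.tensor(swl[start : start + length], dtype=torch.int32)

print(words)  # [[11, 12, 0], [7, 0, 0], [21, 22, 23]]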
@ -3119,6 +3136,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
).scatter_(dim=0, index=padded_indices_cuda, src=sampled_rank_cuda)
|
||||
|
||||
if self._use_beam_search:
|
||||
assert self.store.sampled_log_prob_indices is not None
|
||||
local_group_req_indices_with_beam_search = torch.tensor(
|
||||
[
|
||||
req_id
|
||||
@ -3186,8 +3204,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
|
||||
logits_cuda = self._apply_min_length_penalty(
|
||||
logits_cuda,
|
||||
requests,
|
||||
sampling_requests_metadata.req_num_steps,
|
||||
sampling_requests_metadata.req_num_beams,
|
||||
sampling_requests_metadata.req_num_steps.tolist(),
|
||||
sampling_requests_metadata.req_num_beams.tolist(),
|
||||
)
|
||||
|
||||
# Fast path for greedy sampling
|
||||
@ -3303,7 +3321,7 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
|
||||
class SampleStateTRTLLM(SampleState):
|
||||
finalize_events: dict[str, CudaEvent] | None = None
|
||||
"""`Optional` to accommodate `_forward_step_inter_pp` which creates a `SampleState` without `finalize_events`"""
|
||||
host: Optional[SampleStateTensorsHostTRTLLM] = None
|
||||
host: Optional[SampleStateTensorsHostTRTLLM] = None # type: ignore
|
||||
|
||||
|
||||
class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
@ -3337,7 +3355,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
self.logits_datatype = DataType.FLOAT
|
||||
self.decoding_mode = decoding_mode
|
||||
self.decoding_config = decoding_config if decoding_config else DecodingConfig(decoding_mode)
|
||||
max_attn_window = kv_cache_config.max_attention_window
|
||||
max_attn_window = kv_cache_config.max_attention_window # type: ignore
|
||||
self.max_seq_len = max_seq_len
|
||||
self.max_attention_window = (
|
||||
max(max_attn_window) if max_attn_window is not None else max_seq_len
|
||||
@ -3416,8 +3434,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
|
||||
def _instantiate_algorithms(self):
|
||||
self.algs = Algorithms()
|
||||
self.algs.decoder = GptDecoderBatched(stream=self.store["torch_stream"])
|
||||
self.algs.decoder.setup(
|
||||
self.algs.decoder = GptDecoderBatched(stream=self.store["torch_stream"]) # type: ignore
|
||||
self.algs.decoder.setup( # type: ignore
|
||||
mode=self.decoding_mode,
|
||||
max_num_sequences=self.max_num_sequences,
|
||||
max_beam_width=self.max_beam_width,
|
||||
@ -3425,18 +3443,18 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
model_config=self.model_config,
|
||||
world_config=self.world_config,
|
||||
)
|
||||
self.algs.create_new_decoder_requests = CreateNewDecoderRequests(
|
||||
self.algs.create_new_decoder_requests = CreateNewDecoderRequests( # type: ignore
|
||||
speculative_decoding_fast_logits=False,
|
||||
is_leader_in_orch_mode=False,
|
||||
is_normalize_log_probs=False,
|
||||
)
|
||||
self.algs.make_decoding_batch_input_output = MakeDecodingBatchInputOutput()
|
||||
self.algs.make_decoding_batch_input_output = MakeDecodingBatchInputOutput() # type: ignore
|
||||
|
||||
@torch.inference_mode()
|
||||
@nvtx_range("setup_sampler_step")
|
||||
def setup_sampler_step(self, requests):
|
||||
def setup_sampler_step(self, requests): # type: ignore
|
||||
batch_slots, sampling_configs, lookahead_prompt, lookahead_algo_configs = (
|
||||
self.algs.create_new_decoder_requests(
|
||||
self.algs.create_new_decoder_requests( # type: ignore
|
||||
self.model_config,
|
||||
self.world_config,
|
||||
self.decoding_config,
|
||||
@ -3445,7 +3463,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
self.store["decoder_input_buffers"][self.micro_batch_idx],
|
||||
self.store["decoder_state"],
|
||||
self.store["cuda_stream"],
|
||||
self.algs.decoder.decoder_stream,
|
||||
self.algs.decoder.decoder_stream, # type: ignore
|
||||
self.max_seq_len,
|
||||
self.beam_width(requests.context_requests),
|
||||
)
|
||||
@ -3454,7 +3472,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
local_batch_size = len(batch_slots)
|
||||
if local_batch_size > 0:
|
||||
sampling_config = make_sampling_config(sampling_configs)
|
||||
self.algs.decoder.underlying_decoder().setup(
|
||||
self.algs.decoder.underlying_decoder().setup( # type: ignore
|
||||
sampling_config,
|
||||
local_batch_size,
|
||||
batch_slots,
|
||||
@ -3468,9 +3486,9 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
batch_size = len(adp)
|
||||
if batch_size == 0:
|
||||
return
|
||||
config = make_sampling_config([r.sampling_config for r in adp])
|
||||
config = make_sampling_config([r.sampling_config for r in adp]) # type: ignore
|
||||
slots = torch.tensor([r.py_seq_slot for r in adp], dtype=torch.int32)
|
||||
self.algs.decoder.underlying_decoder().setup(config, batch_size, slots)
|
||||
self.algs.decoder.underlying_decoder().setup(config, batch_size, slots) # type: ignore
|
||||
|
||||
def get_cache_indirection(self) -> torch.Tensor | None:
|
||||
return self.store["decoder_state"].cache_indirection_output
|
||||
@ -3519,7 +3537,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
self.store["buffer_manager"],
|
||||
)
|
||||
|
||||
self.algs.decoder.forward_async(
|
||||
self.algs.decoder.forward_async( # type: ignore
|
||||
self.store["decoder_state"],
|
||||
self.store["decoder_input_buffers"][self.micro_batch_idx],
|
||||
)
|
||||
@ -3574,7 +3592,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
|
||||
@torch.inference_mode()
|
||||
@override
|
||||
def update_requests(
|
||||
def update_requests( # type: ignore
|
||||
self,
|
||||
state: SampleStateTRTLLM,
|
||||
resource_manager: Optional[ResourceManager] = None,
|
||||
@ -3598,8 +3616,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
@nvtx_range("update_requests_single_beam_single_step")
|
||||
def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
|
||||
"""Specialization of update_requests for single beam and single step"""
|
||||
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
|
||||
finish_reasons = state.host.finish_reasons.flatten().tolist()
|
||||
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist() # type: ignore
|
||||
finish_reasons = state.host.finish_reasons.flatten().tolist() # type: ignore
|
||||
|
||||
reqs = [
|
||||
r for r in state.scheduled_requests.context_requests if not r.is_context_init_state
|
||||
@ -3618,7 +3636,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
seq_slots = []
|
||||
seq_slots_need_log_probs = []
|
||||
for request in reqs:
|
||||
if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0):
|
||||
if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0): # type: ignore
|
||||
continue
|
||||
|
||||
reqs_with_new_tokens.append(request)
|
||||
@ -3628,19 +3646,19 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
seq_slots_need_log_probs.append(request.py_seq_slot)
|
||||
|
||||
# [maxTokensPerStep, batchSize, maxBeamWidth]
|
||||
new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist()
|
||||
new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist() # type: ignore
|
||||
add_new_tokens_to_requests(reqs_with_new_tokens, new_tokens, 0)
|
||||
|
||||
# Log probs
|
||||
if state.host.log_probs is not None:
|
||||
if state.host.log_probs is not None: # type: ignore
|
||||
# [batchSize, maxBeamWidth]
|
||||
seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1
|
||||
seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1 # type: ignore
|
||||
# [batchSize, maxBeamWidth, maxSequenceLength]
|
||||
log_probs_host = state.host.log_probs[
|
||||
log_probs_host = state.host.log_probs[ # type: ignore
|
||||
seq_slots_need_log_probs, 0, seq_last_idx
|
||||
].tolist()
|
||||
# [batchSize, maxBeamWidth]
|
||||
cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist()
|
||||
cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist() # type: ignore
|
||||
|
||||
log_probs_idx = 0
|
||||
for request, new_token in zip(reqs_with_new_tokens, new_tokens):
|
||||
@ -3659,7 +3677,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
|
||||
for request in reqs:
|
||||
request.py_decoding_iter += 1
|
||||
finished_state = FinishedState(finish_reasons[request.py_seq_slot])
|
||||
finished_state = FinishedState(finish_reasons[request.py_seq_slot]) # type: ignore
|
||||
if finished_state.is_finished:
|
||||
request.state = LlmRequestState.GENERATION_COMPLETE
|
||||
finish_reason = finished_state.to_finish_reason()
|
||||
@ -3672,14 +3690,14 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
state: SampleStateTRTLLM,
|
||||
beam_width: int,
|
||||
):
|
||||
new_tokens_host = state.host.new_tokens.tolist()
|
||||
finished_sum_host = state.host.finished_sum.tolist()
|
||||
finish_reasons = state.host.finish_reasons.flatten().tolist()
|
||||
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
|
||||
new_tokens_host = state.host.new_tokens.tolist() # type: ignore
|
||||
finished_sum_host = state.host.finished_sum.tolist() # type: ignore
|
||||
finish_reasons = state.host.finish_reasons.flatten().tolist() # type: ignore
|
||||
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist() # type: ignore
|
||||
cum_log_probs_host = (
|
||||
state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None
|
||||
state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None # type: ignore
|
||||
)
|
||||
log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None
|
||||
log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None # type: ignore
|
||||
finalize_events = state.finalize_events
|
||||
|
||||
reqs = [
|
||||
@ -3700,7 +3718,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
cum_log_probs = []
|
||||
|
||||
for beam_idx in range(beam_width):
|
||||
seq_len = sequence_lengths_host_data[seq_slot * beam_width + beam_idx]
|
||||
seq_len = sequence_lengths_host_data[seq_slot * beam_width + beam_idx] # type: ignore
|
||||
num_new_tokens[beam_idx] = min(
|
||||
num_generated_tokens, seq_len - request.get_num_tokens(beam_idx)
|
||||
)
|
||||
@ -3709,7 +3727,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
new_token = add_token(request, new_tokens_host, beam_idx=beam_idx, step=step)
|
||||
|
||||
if request.py_return_log_probs:
|
||||
assert state.host.log_probs is not None
|
||||
assert state.host.log_probs is not None # type: ignore
|
||||
# NOTE: Log probs with drafting has not been tested yet.
|
||||
begin_log_probs_offset = (
|
||||
request.prompt_len if request.sampling_config.beam_width == 1 else 0
|
||||
@ -3720,7 +3738,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
log_probs[beam_idx].append(
|
||||
{
|
||||
new_token: Logprob(
|
||||
logprob=log_probs_host[seq_slot][beam_idx][
|
||||
logprob=log_probs_host[seq_slot][beam_idx][ # type: ignore
|
||||
begin_log_probs_offset + current_token
|
||||
],
|
||||
rank=1,
|
||||
@ -3729,9 +3747,9 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
)
|
||||
|
||||
if request.py_return_log_probs:
|
||||
cum_log_probs.append(cum_log_probs_host[seq_slot][beam_idx])
|
||||
cum_log_probs.append(cum_log_probs_host[seq_slot][beam_idx]) # type: ignore
|
||||
|
||||
finished_state = FinishedState(finish_reasons[seq_slot * beam_width + beam_idx])
|
||||
finished_state = FinishedState(finish_reasons[seq_slot * beam_width + beam_idx]) # type: ignore
|
||||
if finished_state.is_finished:
|
||||
finish_reason = finished_state.to_finish_reason()
|
||||
request.set_finished_reason(finish_reason, beam_idx)
|
||||
@ -3748,7 +3766,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
if request.state != LlmRequestState.GENERATION_COMPLETE:
|
||||
request.py_decoding_iter += 1
|
||||
|
||||
if finished_sum_host[seq_slot] == beam_width:
|
||||
if finished_sum_host[seq_slot] == beam_width: # type: ignore
|
||||
request.state = LlmRequestState.GENERATION_COMPLETE
|
||||
for request in reqs:
|
||||
if finalize_events is not None and request.request_id in finalize_events:
|
||||
@ -3761,7 +3779,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
):
|
||||
"""Finalizes the request. This is necessary for beam search."""
|
||||
seq_slot = request.py_seq_slot
|
||||
event = self.algs.decoder.finalize(
|
||||
event = self.algs.decoder.finalize( # type: ignore
|
||||
self.store["decoder_state"], seq_slot, request.sampling_config, streaming
|
||||
)
|
||||
return event
|
||||
@ -3775,15 +3793,15 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
beam_width = request.sampling_config.beam_width
|
||||
# synchronize on the finalize event before continuing the post processing.
|
||||
# should be unnecessary, as already wait for the sampler event in update_requests
|
||||
state.finalize_events[request.request_id].synchronize()
|
||||
state.finalize_events[request.request_id].synchronize() # type: ignore
|
||||
|
||||
# Get these values again, as they might have changed during the finalize step
|
||||
output_ids_host = state.host.gathered_ids
|
||||
sequence_lengths_host = state.host.sequence_lengths
|
||||
output_ids_host = state.host.gathered_ids # type: ignore
|
||||
sequence_lengths_host = state.host.sequence_lengths # type: ignore
|
||||
|
||||
if request.py_return_log_probs:
|
||||
log_probs_host = state.host.log_probs
|
||||
cum_log_probs_host = state.host.cum_log_probs
|
||||
log_probs_host = state.host.log_probs # type: ignore
|
||||
cum_log_probs_host = state.host.cum_log_probs # type: ignore
|
||||
|
||||
generated_tokens = [[0]] * beam_width
|
||||
log_probs = [[] for _ in range(beam_width)]
|
||||
@ -3796,11 +3814,11 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
sequence_lengths_host[seq_slot, beam_idx].item() - request.py_prompt_len
|
||||
)
|
||||
end = begin + generated_length
|
||||
generated_tokens[beam_idx] = output_ids_host[seq_slot, beam_idx, begin:end].tolist()
|
||||
generated_tokens[beam_idx] = output_ids_host[seq_slot, beam_idx, begin:end].tolist() # type: ignore
|
||||
|
||||
# get the correct log probs for beam search
|
||||
if request.py_return_log_probs:
|
||||
cum_log_probs.append(cum_log_probs_host[seq_slot, beam_idx].item())
|
||||
cum_log_probs.append(cum_log_probs_host[seq_slot, beam_idx].item()) # type: ignore
|
||||
|
||||
begin_log_probs_offset = (
|
||||
request.prompt_len if request.sampling_config.beam_width == 1 else 0
|
||||
@ -3809,7 +3827,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
|
||||
log_probs[beam_idx].append(
|
||||
{
|
||||
token: Logprob(
|
||||
logprob=log_probs_host[seq_slot, beam_idx][
|
||||
logprob=log_probs_host[seq_slot, beam_idx][ # type: ignore
|
||||
begin_log_probs_offset + current_token
|
||||
].item(),
|
||||
rank=1,
|
||||
|
||||
@ -20,6 +20,7 @@ referring to types like LlmRequest.
|
||||
|
||||
import abc
|
||||
import sys
|
||||
from collections.abc import Hashable
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, Literal, Optional, Type, TypeAlias, TypeVar, cast
|
||||
|
||||
@ -289,9 +290,8 @@ def beam_search_sampling_batch(
|
||||
beam_width_out: int,
|
||||
beam_search_args: BeamSearchMetadata,
|
||||
temperature: float | None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_probs: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Sample <beam_width> tokens for each request in parallel.
|
||||
"""
|
||||
@ -518,19 +518,18 @@ def sample(
|
||||
beam_width_out=beam_width_out,
|
||||
beam_search_args=group_metadata,
|
||||
temperature=temperature,
|
||||
generator=generator,
|
||||
return_probs=return_probs,
|
||||
)
|
||||
return tokens, softmax, temperature
|
||||
|
||||
|
||||
GenericStrategyKeyType = TypeVar("GenericStrategyKeyType")
|
||||
GenericStrategyKeyType = TypeVar("GenericStrategyKeyType", bound=Hashable)
|
||||
|
||||
|
||||
class GroupedStrategySampler(Generic[GenericStrategyKeyType], abc.ABC):
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> GenericStrategyKeyType:
|
||||
def strategy_grouping_key(strategy: Strategy) -> GenericStrategyKeyType:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@ -552,6 +551,13 @@ class GroupedStrategySampler(Generic[GenericStrategyKeyType], abc.ABC):
|
||||
return_probs: bool,
|
||||
group_metadata: StrategyMetadata | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, float | torch.Tensor | None]:
|
||||
"""Sample grouped strategies.
|
||||
|
||||
Returns:
|
||||
- Sampled tokens
|
||||
- Processed probs (whenever return_probs=True)
|
||||
- Temperature (used to compute processed _log_ probs)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -560,7 +566,7 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KEY_TYPE:
|
||||
def strategy_grouping_key(strategy: Strategy) -> STRATEGY_KEY_TYPE:
|
||||
return strategy
|
||||
|
||||
@override
|
||||
@ -585,7 +591,7 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):
|
||||
generator: torch.Generator | None = None,
|
||||
return_probs: bool,
|
||||
group_metadata: StrategyMetadata | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, float | None]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, float | torch.Tensor | None]:
|
||||
if group_key[0] == "beam_search":
|
||||
beam_width_in = group_key[1]
|
||||
else:
|
||||
|
||||
@ -20,7 +20,7 @@ referring to types like LlmRequest.
|
||||
|
||||
import abc
|
||||
import sys
|
||||
from typing import Optional, Type, TypeAlias, cast
|
||||
from typing import Literal, Optional, Type, TypeAlias, cast
|
||||
|
||||
import flashinfer.sampling
|
||||
import torch
|
||||
@ -61,6 +61,9 @@ class _StrategyImpls:
|
||||
def computes_probs(cls) -> bool:
|
||||
pass
|
||||
|
||||
def get_temperature(self) -> torch.Tensor | None:
|
||||
return getattr(self, "_temperature", None)
|
||||
|
||||
@abc.abstractmethod
|
||||
def sample(
|
||||
self,
|
||||
@ -177,8 +180,8 @@ class _StrategyImpls:
|
||||
class BeamSearchMixin(StrategyImpl):
|
||||
def __init__(
|
||||
self,
|
||||
beam_width_in: torch.Tensor,
|
||||
beam_width_out: torch.Tensor,
|
||||
beam_width_in: int,
|
||||
beam_width_out: int,
|
||||
temperature: torch.Tensor,
|
||||
):
|
||||
self._beam_width_in = beam_width_in
|
||||
@ -192,12 +195,8 @@ class _StrategyImpls:
|
||||
) -> "_StrategyImpls.BeamSearchMixin":
|
||||
assert all(strat[0] == "beam_search" for strat in strategies)
|
||||
narrowed_strats = cast(list[BeamSearch], strategies)
|
||||
beam_width_in = cls._make_tensor(
|
||||
[strat[1] for strat in narrowed_strats], torch.int32, cuda_device
|
||||
)
|
||||
beam_width_out = cls._make_tensor(
|
||||
[strat[2] for strat in narrowed_strats], torch.int32, cuda_device
|
||||
)
|
||||
(beam_width_in,) = set(strat[1] for strat in narrowed_strats)
|
||||
(beam_width_out,) = set(strat[2] for strat in narrowed_strats)
|
||||
temperature = cls._make_tensor(
|
||||
[strat[3] or 1.0 for strat in narrowed_strats], torch.float32, cuda_device
|
||||
)
|
||||
@ -215,22 +214,15 @@ class _StrategyImpls:
|
||||
assert group_metadata is not None and isinstance(group_metadata, BeamSearchMetadata), (
|
||||
"BeamSearchMetadata is required for beam_search_sampling_batch"
|
||||
)
|
||||
assert torch.unique(self._beam_width_in).numel() == 1, (
|
||||
"beam_width_in must be the same for all strategies"
|
||||
)
|
||||
assert torch.unique(self._beam_width_out).numel() == 1, (
|
||||
"beam_width_out must be the same for all strategies"
|
||||
)
|
||||
logits = self._prepare_logits_with_temperature(
|
||||
logits, group_logit_indices, self._temperature
|
||||
)
|
||||
# Convert from 1 temperature per request to 1 temperature per (request, beam)
|
||||
temperature = self._temperature.repeat_interleave(self._beam_width_in)
|
||||
logits = self._prepare_logits_with_temperature(logits, group_logit_indices, temperature)
|
||||
return beam_search_sampling_batch(
|
||||
logits,
|
||||
beam_width_in=self._beam_width_in[0],
|
||||
beam_width_out=self._beam_width_out[0],
|
||||
beam_width_in=self._beam_width_in,
|
||||
beam_width_out=self._beam_width_out,
|
||||
beam_search_args=group_metadata,
|
||||
temperature=None,
|
||||
generator=generator,
|
||||
return_probs=self.computes_probs(),
|
||||
)
|
||||
|
||||
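The repeat_interleave introduced above turns one temperature per request into one per (request, beam) logits row, so temperature scaling can be applied in a single batched division. A toy illustration with invented values:

import torch

# One temperature per request; every request in this group runs beam_width_in beams.
temperature_per_request = torch.tensor([0.7, 1.0])
beam_width_in = 3

# One temperature per (request, beam): [0.7, 0.7, 0.7, 1.0, 1.0, 1.0]
temperature_per_beam = temperature_per_request.repeat_interleave(beam_width_in)

# Logits are laid out with beam_width_in rows per request, so the scaling
# lines up row-for-row after broadcasting.
logits = torch.randn(len(temperature_per_request) * beam_width_in, 8)
scaled = logits / temperature_per_beam.unsqueeze(-1)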
@ -648,84 +640,52 @@ class _StrategyImpls:
|
||||
pass
|
||||
|
||||
|
||||
def _create_beam_search_specialized_cls(
|
||||
beam_width_in: torch.Tensor,
|
||||
beam_width_out: torch.Tensor,
|
||||
return_probs: bool,
|
||||
) -> Type[_StrategyImpls.BeamSearchMixin]:
|
||||
"""Create a class that implements BeamSearchMixin with static parameters for grouping."""
|
||||
|
||||
class BeamSearchSpecialized(
|
||||
_StrategyImpls.BeamSearchWithProbs if return_probs else _StrategyImpls.BeamSearchSampleOnly
|
||||
):
|
||||
static_beam_width_in = beam_width_in
|
||||
static_beam_width_out = beam_width_out
|
||||
|
||||
@override
|
||||
def __hash__(self) -> int:
|
||||
return hash((super(), self.static_beam_width_in, self.static_beam_width_out))
|
||||
|
||||
@override
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return (
|
||||
super().__eq__(other)
|
||||
and self.static_beam_width_in == other.static_beam_width_in
|
||||
and self.static_beam_width_out == other.static_beam_width_out
|
||||
)
|
||||
|
||||
return BeamSearchSpecialized
|
||||
_STRATEGY_KEY_TYPE: TypeAlias = (
|
||||
Literal["temperature"]
|
||||
| Literal["top_k"]
|
||||
| Literal["top_p"]
|
||||
| Literal["top_k_top_p"]
|
||||
| Literal["greedy"]
|
||||
| tuple[Literal["beam_search"], int, int]
|
||||
)
|
||||
|
||||
|
||||
class FlashInferGroupedStrategySampler(GroupedStrategySampler[Type[_StrategyImpls.StrategyImpl]]):
|
||||
class FlashInferGroupedStrategySampler(GroupedStrategySampler[_STRATEGY_KEY_TYPE]):
|
||||
"""Implements batched sampling with FlashInfer.sampling kernels.
|
||||
|
||||
Note: Currently, FlashInfer.sampling appears to have limited CUDA graph
|
||||
support, see https://github.com/flashinfer-ai/flashinfer/issues/978.
|
||||
"""
|
||||
|
||||
STRATEGY_KEY_TYPE: TypeAlias = Type[_StrategyImpls.StrategyImpl]
|
||||
STRATEGY_KEY_TYPE: TypeAlias = _STRATEGY_KEY_TYPE
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KEY_TYPE:
|
||||
if return_probs:
|
||||
match strategy:
|
||||
case ("top_k", _, _):
|
||||
return _StrategyImpls.TopKWithProbs
|
||||
case ("top_p", _, _):
|
||||
return _StrategyImpls.TopPWithProbs
|
||||
case ("top_k_top_p", _, _, _):
|
||||
return _StrategyImpls.TopKTopPWithProbs
|
||||
case ("temperature", _):
|
||||
return _StrategyImpls.TemperatureOnlyWithProbs
|
||||
case ("greedy", None):
|
||||
return _StrategyImpls.GreedyWithProbs
|
||||
case ("beam_search", beam_width_in, beam_width_out, _):
|
||||
return _create_beam_search_specialized_cls(beam_width_in, beam_width_out, True)
|
||||
else:
|
||||
match strategy:
|
||||
case ("top_p", _, _):
|
||||
return _StrategyImpls.TopPSampleOnly
|
||||
case ("top_k", _, _):
|
||||
return _StrategyImpls.TopKSampleOnly
|
||||
case ("top_k_top_p", _, _, _):
|
||||
return _StrategyImpls.TopKTopPSampleOnly
|
||||
case ("temperature", _):
|
||||
return _StrategyImpls.TemperatureOnlySampleOnly
|
||||
case ("greedy", None):
|
||||
return _StrategyImpls.GreedySampleOnly
|
||||
case ("beam_search", beam_width_in, beam_width_out, _):
|
||||
return _create_beam_search_specialized_cls(beam_width_in, beam_width_out, False)
|
||||
def strategy_grouping_key(strategy: Strategy) -> STRATEGY_KEY_TYPE:
|
||||
match strategy:
|
||||
case (
|
||||
("top_k", _, _)
|
||||
| ("top_p", _, _)
|
||||
| ("top_k_top_p", _, _, _)
|
||||
| ("temperature", _)
|
||||
| ("greedy", None)
|
||||
):
|
||||
return strategy[0]
|
||||
case ("beam_search", beam_width_in, beam_width_out, _):
|
||||
return (strategy[0], beam_width_in, beam_width_out)
|
||||
case _:
|
||||
raise NotImplementedError("Unsupported strategy encountered")
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def get_metadata_type_for_group(
|
||||
strategy_key: STRATEGY_KEY_TYPE,
|
||||
) -> Type[StrategyMetadata] | None:
|
||||
if issubclass(strategy_key, _StrategyImpls.BeamSearchMixin):
|
||||
return BeamSearchMetadata
|
||||
else:
|
||||
return None
|
||||
match strategy_key:
|
||||
case ("beam_search", _, _):
|
||||
return BeamSearchMetadata
|
||||
case _:
|
||||
return None
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
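After this refactor, grouping keys are plain hashable values (a short string, or a ("beam_search", beam_width_in, beam_width_out) tuple) matched structurally, instead of dynamically created specialized classes. A reduced sketch of the same dispatch, with simplified strategy tuples that are assumptions rather than the module's exact shapes:

def strategy_grouping_key(strategy):
    # Non-beam strategies group by name only; beam search also groups by widths,
    # so requests with different beam widths never share a batch.
    match strategy:
        case ("top_k", _, _) | ("top_p", _, _) | ("temperature", _) | ("greedy", None):
            return strategy[0]
        case ("beam_search", beam_width_in, beam_width_out, _):
            return ("beam_search", beam_width_in, beam_width_out)
        case _:
            raise NotImplementedError("Unsupported strategy encountered")


print(strategy_grouping_key(("top_k", 40, 1.0)))           # "top_k"
print(strategy_grouping_key(("beam_search", 4, 4, None)))  # ("beam_search", 4, 4)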
@ -739,28 +699,52 @@ class FlashInferGroupedStrategySampler(GroupedStrategySampler[Type[_StrategyImpl
|
||||
return_probs: bool,
|
||||
group_metadata: StrategyMetadata | None = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
if hasattr(group_key, "static_beam_width_in"):
|
||||
beam_width_in = group_key.static_beam_width_in
|
||||
beam_width_in = 1
|
||||
if return_probs:
|
||||
match group_key:
|
||||
case "top_k":
|
||||
strategy_impl_cls = _StrategyImpls.TopKWithProbs
|
||||
case "top_p":
|
||||
strategy_impl_cls = _StrategyImpls.TopPWithProbs
|
||||
case "top_k_top_p":
|
||||
strategy_impl_cls = _StrategyImpls.TopKTopPWithProbs
|
||||
case "temperature":
|
||||
strategy_impl_cls = _StrategyImpls.TemperatureOnlyWithProbs
|
||||
case "greedy":
|
||||
strategy_impl_cls = _StrategyImpls.GreedyWithProbs
|
||||
case ("beam_search", beam_width_in_key, _):
|
||||
beam_width_in = beam_width_in_key
|
||||
strategy_impl_cls = _StrategyImpls.BeamSearchWithProbs
|
||||
case _:
|
||||
raise NotImplementedError("Unsupported strategy key encountered")
|
||||
else:
|
||||
beam_width_in = 1
|
||||
match group_key:
|
||||
case "top_p":
|
||||
strategy_impl_cls = _StrategyImpls.TopPSampleOnly
|
||||
case "top_k":
|
||||
strategy_impl_cls = _StrategyImpls.TopKSampleOnly
|
||||
case "top_k_top_p":
|
||||
strategy_impl_cls = _StrategyImpls.TopKTopPSampleOnly
|
||||
case "temperature":
|
||||
strategy_impl_cls = _StrategyImpls.TemperatureOnlySampleOnly
|
||||
case "greedy":
|
||||
strategy_impl_cls = _StrategyImpls.GreedySampleOnly
|
||||
case ("beam_search", beam_width_in_key, _):
|
||||
beam_width_in = beam_width_in_key
|
||||
strategy_impl_cls = _StrategyImpls.BeamSearchSampleOnly
|
||||
case _:
|
||||
raise NotImplementedError("Unsupported strategy key encountered")
|
||||
|
||||
if group_logit_indices is None:
|
||||
assert logits.size(0) == beam_width_in * len(strategies)
|
||||
else:
|
||||
assert group_logit_indices.size(0) == beam_width_in * len(strategies)
|
||||
|
||||
assert return_probs == group_key.computes_probs()
|
||||
|
||||
strategy_impl_cls = group_key
|
||||
sampling_object = strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device)
|
||||
next_tokens, softmax = sampling_object.sample(
|
||||
strategy_impl = strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device)
|
||||
next_tokens, softmax = strategy_impl.sample(
|
||||
logits,
|
||||
group_logit_indices=group_logit_indices,
|
||||
generator=generator,
|
||||
group_metadata=group_metadata,
|
||||
)
|
||||
temperature = (
|
||||
sampling_object._temperature.unsqueeze(-1)
|
||||
if sampling_object._temperature is not None
|
||||
else None
|
||||
)
|
||||
return next_tokens, softmax, temperature
|
||||
return next_tokens, softmax, strategy_impl.get_temperature()
|
||||
|
||||
@ -436,7 +436,6 @@ def test_beam_search_sampling_batch_basic():
|
||||
beam_width_out=beam_width,
|
||||
beam_search_args=beam_search_args,
|
||||
temperature=temperature,
|
||||
generator=None,
|
||||
return_probs=True,
|
||||
)
|
||||
|
||||
|
||||
@ -35,7 +35,6 @@ import torch
|
||||
from scipy.stats import power_divergence
|
||||
from utils.util import assert_no_cuda_sync, force_ampere
|
||||
|
||||
from tensorrt_llm._torch.pyexecutor import sampling_utils_flashinfer
|
||||
from tensorrt_llm._torch.pyexecutor.llm_request import convert_wordlist
|
||||
from tensorrt_llm._torch.pyexecutor.sampler import (
|
||||
GREEDY,
|
||||
@ -1563,11 +1562,10 @@ class TestBatchedSampling:
|
||||
return_probs: bool,
|
||||
group_metadata: StrategyMetadata | None = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert issubclass(group_key, sampling_utils_flashinfer._StrategyImpls.StrategyImpl)
|
||||
assert generator is sampler.get_generator(logits.device)
|
||||
nonlocal flashinfer_keys_seen
|
||||
assert group_key not in flashinfer_keys_seen
|
||||
flashinfer_keys_seen.add(group_key)
|
||||
assert (group_key, return_probs) not in flashinfer_keys_seen
|
||||
flashinfer_keys_seen.add((group_key, return_probs))
|
||||
return sample_grouped_strategies_orig(
|
||||
group_key,
|
||||
strategies,
|
||||
|
||||