diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 876858bfec..ac29abe979 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -21,7 +21,7 @@ from concurrent import futures from dataclasses import dataclass from functools import cached_property from itertools import repeat -from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, cast +from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeAlias, TypeVar, cast import numpy as np import torch @@ -264,42 +264,11 @@ class RequestGroupValue: need_processed_logprobs: torch.Tensor need_raw_logprobs: torch.Tensor - def __iter__(self): - return iter( - ( - self.indices, - self.strategies, - self.speculation_needs_probs_indices, - self.need_processed_logprobs, - self.need_raw_logprobs, - ) - ) - - def __len__(self): - return 5 - @dataclass(kw_only=True, frozen=True) class RequestGroupValueWithMetadata(RequestGroupValue): metadata: StrategyMetadata | None - @override - def __iter__(self): - return iter( - ( - self.indices, - self.strategies, - self.speculation_needs_probs_indices, - self.need_processed_logprobs, - self.need_raw_logprobs, - self.metadata, - ) - ) - - @override - def __len__(self): - return 6 - class EarlyStopWithMMResult(Sampler): """ @@ -307,7 +276,7 @@ class EarlyStopWithMMResult(Sampler): """ @override - def sample_async( + def sample_async( # type: ignore self, scheduled_requests: ScheduledRequests, model_outputs, @@ -322,7 +291,7 @@ class EarlyStopWithMMResult(Sampler): return SampleStateWithMMResult(scheduled_requests=scheduled_requests, data=data) @override - def update_requests( + def update_requests( # type: ignore self, state: SampleStateWithMMResult, resource_manager: Optional[ResourceManager] = None, @@ -341,9 +310,9 @@ class EarlyStopWithMMResult(Sampler): request.state = LlmRequestState.GENERATION_COMPLETE # NOTE: This is a hack: set finish reason manually and set the beam 0 request.set_finished_reason(FinishReason.LENGTH, 0) - if len(mm_embedding) != sum(request.multimodal_lengths): + if len(mm_embedding) != sum(request.multimodal_lengths): # type: ignore raise ValueError( - f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}" + f"mm_embedding shape mismatch: {len(mm_embedding)} != {sum(request.multimodal_lengths)}" # type: ignore ) request.py_result.append_mm_embeddings(mm_embedding) @@ -385,7 +354,7 @@ def _get_max_beam_width(request: LlmRequest) -> int: sampling_config = request.sampling_config max_beam_width = sampling_config.beam_width if sampling_config.beam_width_array is not None: - max_beam_width = max(max_beam_width, sampling_config.beam_width_array.max()) + max_beam_width = max(max_beam_width, sampling_config.beam_width_array.max()) # type: ignore return max_beam_width @@ -416,23 +385,24 @@ def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy: # We try to cache the resolved strategy on the request object, as it's not cheap enough to # resolve it on every iteration. 
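# --- Editor's aside: illustrative sketch, not part of the patch ---
# The hunks above delete the hand-written __iter__/__len__ from RequestGroupValue and
# RequestGroupValueWithMetadata; later hunks in this diff read the group fields by name
# (group_val_with_metadata.indices, .strategies, ...) instead of tuple-unpacking them.
# The toy dataclasses below are hypothetical stand-ins showing why that is the safer
# shape: positional unpacking has to track the exact field count, while attribute access
# keeps working when a subclass adds a field such as `metadata`.
from dataclasses import dataclass


@dataclass(kw_only=True, frozen=True)
class GroupValue:
    indices: list
    strategies: list


@dataclass(kw_only=True, frozen=True)
class GroupValueWithMetadata(GroupValue):
    metadata: object = None


def consume(value: GroupValue) -> None:
    # Unaffected by extra fields on subclasses; no __iter__/__len__ needed.
    print(value.indices, value.strategies)


consume(GroupValueWithMetadata(indices=[0, 1], strategies=["greedy", "greedy"], metadata="x"))
# --- end aside ---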
if hasattr(request, "py_sampling_strategy"): - return request.py_sampling_strategy + return request.py_sampling_strategy # type: ignore params = _request_get_sampling_params(request) sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size) if _request_sampling_params_cachable(params): - request.py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size) + request.py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size) # type: ignore return sampling_strategy def _group_requests_by_strategy_key( requests: Iterable[LlmRequest], *, - strategy_to_key: Callable[[Strategy, bool], GenericStrategyKeyType], + strategy_to_key: Callable[[Strategy], GenericStrategyKeyType], pin_memory: bool = False, vocab_size: int, ) -> dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValue]: - # NB: Client code relies on request indices in returned torch.Tensor being sorted. + # NB: Client code relies on request indices in returned torch.Tensor being ordered + # by ascending "requests" index. RequestGroupValueBuilder = namedtuple( "RequestGroupValueBuilder", [ @@ -444,8 +414,8 @@ def _group_requests_by_strategy_key( ], ) - group_dict: dict[RequestGroupKey, RequestGroupValueBuilder] = defaultdict( - lambda: RequestGroupValueBuilder([], [], [], [], []) + group_dict: dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValueBuilder] = ( + defaultdict(lambda: RequestGroupValueBuilder([], [], [], [], [])) ) for req_index, req in enumerate(requests): @@ -460,7 +430,7 @@ def _group_requests_by_strategy_key( ) need_raw_logprobs = req.py_logprobs_mode == LogprobMode.RAW and req.return_log_probs needs_probs = speculation_needs_probs or need_processed_logprobs - strategy_key = strategy_to_key(strategy, needs_probs) + strategy_key = strategy_to_key(strategy) group_dict_entry = group_dict[ RequestGroupKey(strategy_key=strategy_key, needs_probs=needs_probs) ] @@ -768,7 +738,7 @@ DEFAULT_BEAM_IDX = 0 # Step index to use when no speculative decoding is used but a step index is required DEFAULT_STEP_IDX = 0 -FinishReasonsList = list[list[int]] +FinishReasonsList: TypeAlias = list[list[list[int]]] @dataclass(kw_only=True) @@ -797,7 +767,7 @@ class SamplingRequestsMetadata: @dataclass(kw_only=True) class SampleStateTensorsHostTorch(SampleStateTensors): finish_reasons: torch.Tensor - first_finish_reasons: torch.Tensor + first_finish_reasons: torch.Tensor | None logprobs_state: LogProbsState | None = None def finish_reasons_list(self) -> FinishReasonsList: @@ -808,7 +778,7 @@ class SampleStateTensorsHostTorch(SampleStateTensors): @dataclass(kw_only=True) class SampleStateTorch(SampleState): - host: SampleStateTensorsHostTorch + host: SampleStateTensorsHostTorch # type: ignore beam_histories: list[BeamHistory | None] | None = None @@ -827,7 +797,7 @@ class AsyncWorkerMixin: def _async_worker_init(self, enable_async_worker: bool): self._enable_async_worker = enable_async_worker self._async_worker = None - self._async_worker_futures: list[futures.Future[any]] = [] + self._async_worker_futures: list[futures.Future[Any]] = [] def async_worker_enabled(self): return getattr(self, "_enable_async_worker", False) @@ -853,6 +823,7 @@ class AsyncWorkerMixin: def async_worker_stop(self): assert self.async_worker_enabled() if self._async_worker_active(): + assert self._async_worker is not None self._async_worker.shutdown(wait=True) self._async_worker = None @@ -888,6 +859,7 @@ class AsyncWorkerMixin: copy_ready.record() # Submit the copy to the async worker thread + assert 
self._async_worker is not None result = self._async_worker.submit( self._async_copy_to_host, copy_ready, dest, src_snapshot ) @@ -980,8 +952,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin): # Tensors necessary for all sampling methods new_tokens = int_tensor(self.NEW_TOKENS_SHAPE) finish_reasons = int_tensor(self.NEW_TOKENS_SHAPE) - max_lengths_tensor = int_tensor(self.max_num_sequences) - end_ids = int_tensor(self.max_num_sequences) + max_lengths_tensor = int_tensor((self.max_num_sequences,)) + end_ids = int_tensor((self.max_num_sequences,)) # Only used for logprobs processing or beam search sampled_log_probs = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.float32) @@ -1236,6 +1208,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin): != FinishReason.NOT_FINISHED.value ).sum() == request.sampling_config.beam_width: request.state = LlmRequestState.GENERATION_COMPLETE + assert request.py_seq_slot is not None for beam_idx in range(request.sampling_config.beam_width): request.set_finished_reason( FinishReason( @@ -1255,6 +1228,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin): When using streaming a requests tokens may be altered, leading to wrong results when called multiple times. Store the original tokens in a separate buffer to use them as a consistent basis when updating the tokens in a request.""" + assert self.store.original_tokens is not None assert new_tokens.device == self.store.original_tokens.device, ( "new_tokens and original_tokens must be on the same device" ) @@ -1311,6 +1285,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin): assert beam_idx == DEFAULT_BEAM_IDX, ( "beam search does not need to explicitly handle sampled log probs" ) + assert sampled_log_probs_indices_list is not None + assert sampled_log_probs_vals_list is not None + assert sampled_log_probs_rank_list is not None if sampled_log_probs_indices_list[step_idx] not in logprobs: logprobs[sampled_log_probs_indices_list[step_idx]] = Logprob( logprob=sampled_log_probs_vals_list[step_idx], @@ -1388,6 +1365,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin): beam_width = request.sampling_config.beam_width assert request.py_num_logprobs is not None, "request.py_num_logprobs must be provided" assert logprobs_state_list is not None, "logprobs_state_list must be provided" + assert request.py_seq_slot is not None token_log_probs = self._store_logprobs_list_to_request( logprobs_state_list, request.py_seq_slot, beam_width, count, request.py_num_logprobs ) @@ -1396,6 +1374,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin): def finish_if_reason( self, request: LlmRequest, finish_reasons: FinishReasonsList, *, step: int, beam_idx: int ) -> bool: + assert request.py_seq_slot is not None reason = FinishReason(finish_reasons[request.py_seq_slot][step][beam_idx]) valid_reasons = {FinishReason.END_ID, FinishReason.LENGTH, FinishReason.STOP_WORDS} if reason in valid_reasons: @@ -1548,6 +1527,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin): ) ) + @override @override def setup_sampler_step(self, scheduled_requests: ScheduledRequests): """Setup the sampler step for the requests @@ -1563,6 +1543,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin): end_ids: list[int] = [] for request in scheduled_requests.context_requests: if self._is_new_request(request): + assert request.py_seq_slot is not None seq_slots.append(request.py_seq_slot) max_lens.append( min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens) @@ -1593,6 +1574,13 @@ class TorchSampler(Sampler, AsyncWorkerMixin): if 
self._is_new_request(request): if request.py_return_log_probs and request.py_num_logprobs > 1: raise ValueError("Beam search does not support multiple logprobs") + assert self.store.cache_indirection is not None + assert self.store.cum_log_probs is not None + assert self.store.sampled_log_probs is not None + assert self.store.sampled_log_prob_ranks is not None + assert self.store.predecessor_beams is not None + assert self.store.first_finish_reasons is not None + assert self.store.original_tokens is not None self.store.cache_indirection[request.py_seq_slot, :, request.py_prompt_len].fill_(0) self.store.cum_log_probs[request.py_seq_slot].fill_(0) self.store.sampled_log_probs[request.py_seq_slot].fill_(0) @@ -1765,9 +1753,11 @@ class TorchSampler(Sampler, AsyncWorkerMixin): ) if hasattr(request.py_result._log_probs, "log_probs"): logprobs_list = request.py_result.log_probs + assert logprobs_list is not None for beam_idx, beam_logprobs in enumerate(logprobs_list): for token_idx, token_logprobs in enumerate(beam_logprobs): for key, value in token_logprobs.items(): + assert value.rank is not None logprobs_tensor[beam_idx, token_idx, value.rank - 1] = value.logprob logprobs_indices_tensor[beam_idx, token_idx, value.rank - 1] = key return logprobs_tensor, logprobs_indices_tensor @@ -1795,6 +1785,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin): if num_generated_tokens == 0 or request.state == LlmRequestState.GENERATION_COMPLETE: # early return if no tokens have been generated yet or the request is already finished return None + assert self.store.cache_indirection is not None + assert self.store.original_tokens is not None + assert self.store.sampled_log_probs is not None cache_indirection = self.store.cache_indirection[ request.py_seq_slot, :num_beams, prompt_length:num_tokens ] @@ -1802,7 +1795,13 @@ class TorchSampler(Sampler, AsyncWorkerMixin): request.py_seq_slot, :num_beams, prompt_length:num_tokens ] new_path = torch.zeros_like(current_path) + + # initialize each beam with its own index + + # Gather the correct tokens and logprobs for each beam + torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path) if request.py_return_log_probs: + assert self.store.cum_log_probs is not None current_logprobs, current_logprobs_indices = self._get_logprobs_from_request(request) # concatenate the newly generated logprobs and newly # generated tokens to the current logprobs and logprobs indices @@ -1823,11 +1822,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin): # Initialize the buffers to store the results new_logprobs = torch.zeros_like(current_logprobs) new_logprobs_indices = torch.zeros_like(current_logprobs_indices) - # initialize each beam with its own index - # Gather the correct tokens and logprobs for each beam - torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path) - if request.py_return_log_probs: cache_indirection_for_logprobs = cache_indirection.unsqueeze(-1).expand( -1, -1, current_logprobs.shape[2] ) @@ -1877,6 +1872,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin): {beam_history.tokens.shape[0]} != {beam_width}" ) if request.py_return_log_probs: + assert beam_history.logprobs is not None + assert beam_history.logprobs_indices is not None + assert beam_history.cum_logprobs is not None assert beam_history.logprobs.shape[0] == beam_width, ( f"Beam_history.logprobs.shape[0] should equal beam width: \ {beam_history.logprobs.shape[0]} != {beam_width}" @@ -1895,6 +1893,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin): for beam_idx in 
range(beam_width): gen_token_list.append(beam_history.tokens[beam_idx, : valid_tokens[beam_idx]].tolist()) if request.py_return_log_probs: + assert beam_history.logprobs_indices is not None + assert beam_history.logprobs is not None gen_log_probs_list.append( self._convert_logprobs_tensor_to_list( beam_history.logprobs_indices[ @@ -1910,6 +1910,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin): if request.py_return_log_probs: # cum_log_probs will not change when padding with end tokens. # Therefore, we do not need to correct it + assert beam_history.cum_logprobs is not None request.py_result.set_log_probs( gen_log_probs_list, cum_log_probs=beam_history.cum_logprobs.tolist() ) @@ -1920,7 +1921,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin): grouped_requests: dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValue], seq_slots: torch.Tensor, seq_lens: torch.Tensor | None, - get_metadata_type_for_group_fn: Callable[[GenericStrategyKeyType], Type[StrategyMetadata]], + get_metadata_type_for_group_fn: Callable[ + [GenericStrategyKeyType], Type[StrategyMetadata] | None + ], ) -> dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValueWithMetadata]: grouped_requests_with_metadata: dict[ RequestGroupKey[GenericStrategyKeyType], RequestGroupValueWithMetadata @@ -1929,6 +1932,12 @@ class TorchSampler(Sampler, AsyncWorkerMixin): metadata_type = get_metadata_type_for_group_fn(key.strategy_key) if metadata_type is BeamSearchMetadata: assert seq_lens is not None, "seq_lens is required for beam search" + assert self.store.cache_indirection is not None + assert self.store.cache_indirection_buffer is not None + assert self.store.cum_log_probs is not None + assert self.store.sampled_log_probs is not None + assert self.store.first_finish_reasons is not None + assert self.store.predecessor_beams is not None metadata = BeamSearchMetadata( cache_indirection=self.store.cache_indirection, cache_indirection_buffer=self.store.cache_indirection_buffer, @@ -1982,7 +1991,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin): _, cumsum = request.py_stop_words_list if -1 in cumsum: cumsum = cumsum[: cumsum.index(-1)] - longest_stop_word_len = np.max(np.diff(cumsum, prepend=0), initial=0) + longest_stop_word_len = np.max(np.diff(cumsum, prepend=0), initial=0).item() return longest_stop_word_len > 1 return False @@ -2010,9 +2019,10 @@ class TorchSampler(Sampler, AsyncWorkerMixin): @torch.inference_mode() def update_requests( self, - state: SampleStateTorch, + state: Sampler.SampleState, resource_manager: Optional[ResourceManager] = None, ) -> None: + state = cast(SampleStateTorch, state) assert isinstance(state, SampleStateTorch) if state.sampler_event: state.sampler_event.synchronize() @@ -2036,7 +2046,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin): if beam_histories is not None and beam_histories[req_idx] is not None: self._finalize_beam( req, - beam_histories[req_idx], + cast(BeamHistory, beam_histories[req_idx]), ) else: for beam_idx in range(req.sampling_config.beam_width): @@ -2055,7 +2065,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin): if beam_histories is not None and beam_histories[req_idx] is not None: self._finalize_beam( req, - beam_histories[req_idx], + cast(BeamHistory, beam_histories[req_idx]), ) else: for beam_idx in range(req.sampling_config.beam_width): @@ -2100,6 +2110,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin): self.max_tokens, self.max_topk_logprobs, ) + assert self.store.topk_vals is not None + assert self.store.topk_indices is not None 
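# --- Editor's aside: illustrative sketch, not part of the patch ---
# The hunk above appends .item() to np.max(np.diff(cumsum, prepend=0), initial=0), and a
# later hunk does the same for the stop-word max_len. NumPy reductions return 0-d NumPy
# scalars rather than Python ints; .item() converts them, which keeps the comparison and
# the downstream shape argument plainly typed. Made-up numbers:
import numpy as np

ends = [2, 5, 9]                      # hypothetical cumulative stop-word lengths
lens = np.diff(ends, prepend=0)       # array([2, 3, 4])
longest = np.max(lens, initial=0)     # 0-d NumPy scalar
print(type(longest).__name__)         # int64
print(type(longest.item()).__name__)  # int
# --- end aside ---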
self.store.topk_vals.resize_(self.TOPK_LOGPROBS_SHAPE) self.store.topk_indices.resize_(self.TOPK_LOGPROBS_SHAPE) @@ -2155,8 +2167,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin): ) finish_reasons_host = self._copy_to_host(finish_reasons) - beam_histories = [None] * len(requests) + beam_histories: list[BeamHistory | None] = [None] * len(requests) if self._use_beam_search: + assert first_finish_reasons is not None assert seq_lens_host is not None, "seq_lens is required for beam search" assert self.store.first_finish_reasons is not None, ( "first_finish_reasons must be provided" @@ -2167,6 +2180,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin): self._maybe_create_beam_histories( requests, finish_reasons=first_finish_reasons, beam_histories=beam_histories ) + else: + first_finish_reasons_host = None # copy logprobs to host logprobs_state: LogProbsState | None = None @@ -2204,9 +2219,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin): host=SampleStateTensorsHostTorch( new_tokens=new_tokens_host, finish_reasons=finish_reasons_host, - first_finish_reasons=None - if not self._use_beam_search - else first_finish_reasons_host, + first_finish_reasons=first_finish_reasons_host, logprobs_state=logprobs_state, ), sampler_event=sampler_event, @@ -2393,14 +2406,19 @@ class TorchSampler(Sampler, AsyncWorkerMixin): ) batch_req_idx_offset_start = 0 batch_next_tokens_offset_start = 0 - for (strategy_key, needs_probs), ( - group_req_indices, - group_strategies, - group_speculation_needs_probs_indices, - group_need_processed_logprobs, - group_need_raw_logprobs, - group_metadata, - ) in grouped_requests_with_metadata.items(): + for ( + strategy_key, + needs_probs, + ), group_val_with_metadata in grouped_requests_with_metadata.items(): + group_req_indices = group_val_with_metadata.indices + group_strategies = group_val_with_metadata.strategies + group_speculation_needs_probs_indices = ( + group_val_with_metadata.speculation_needs_probs_indices + ) + group_need_processed_logprobs = group_val_with_metadata.need_processed_logprobs + group_need_raw_logprobs = group_val_with_metadata.need_raw_logprobs + group_metadata = group_val_with_metadata.metadata + # group_req_indices: Indices of 'requests' entries having the same sampling # strategy, ordered ascending. 
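# --- Editor's aside: illustrative sketch, not part of the patch ---
# The loop rewritten above consumes the dict built by _group_requests_by_strategy_key
# earlier in this diff: requests are bucketed under a (strategy_key, needs_probs) pair,
# and because requests are visited in order, every bucket's index list stays ascending,
# which the "NB" comment in that function says downstream code relies on. A minimal
# standalone version of that grouping pattern (names invented):
from collections import defaultdict


def group_by_key(strategies, needs_probs_flags):
    groups = defaultdict(list)
    for req_index, (strategy, needs_probs) in enumerate(zip(strategies, needs_probs_flags)):
        # Appending while iterating in request order keeps each bucket sorted ascending.
        groups[(strategy, needs_probs)].append(req_index)
    return groups


for key, indices in group_by_key(["greedy", "top_k", "greedy"], [False, True, False]).items():
    print(key, indices)
# ('greedy', False) [0, 2]
# ('top_k', True) [1]
# --- end aside ---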
batch_req_idx_offset_end = batch_req_idx_offset_start + group_req_indices.size(0) @@ -2420,22 +2438,19 @@ class TorchSampler(Sampler, AsyncWorkerMixin): # indices for accessing logits within the current group group_logit_indexer = _PackedStepIndexer( num_steps=req_num_generated_tokens[group_req_indices], - max_steps=req_num_generated_tokens.max() * self.max_beam_width, + max_steps=cast( + int, req_num_generated_tokens.max().item() * self.max_beam_width + ), ) - logit_indices_for_processed_logprobs_cuda = ( - None - if not any_request_needs_processed_logprobs - else group_logit_indexer[need_processed_logprobs_indices].to( - logits_cuda.device, non_blocking=True - ) - ) - logit_indices_for_raw_logprobs_cuda = ( - None - if not any_request_needs_raw_logprobs - else group_logit_indexer[need_raw_logprobs_indices].to( - logits_cuda.device, non_blocking=True - ) - ) + logit_indices_for_processed_logprobs_cuda = group_logit_indexer[ + need_processed_logprobs_indices + ].to(logits_cuda.device, non_blocking=True) + logit_indices_for_raw_logprobs_cuda = group_logit_indexer[ + need_raw_logprobs_indices + ].to(logits_cuda.device, non_blocking=True) + else: + logit_indices_for_processed_logprobs_cuda = None + logit_indices_for_raw_logprobs_cuda = None group_logits_cuda_indices = logits_cuda_indexer[group_req_indices] # NB: Assuming that group_req_indices are sorted @@ -2518,13 +2533,13 @@ class TorchSampler(Sampler, AsyncWorkerMixin): processed_logits_cuda = torch.where( current_softmax_cuda > 0, current_logits_cuda, float("-inf") ) - if group_temperature_cuda is not None: - if isinstance(group_temperature_cuda, torch.Tensor): - processed_logits_cuda /= group_temperature_cuda[ - logit_indices_for_processed_logprobs_cuda - ] - else: - processed_logits_cuda /= group_temperature_cuda + temperature_for_processed_logprobs = group_temperature_cuda + if isinstance(temperature_for_processed_logprobs, torch.Tensor): + temperature_for_processed_logprobs = cast(torch.Tensor, group_temperature_cuda)[ + logit_indices_for_processed_logprobs_cuda + ].unsqueeze(-1) + if temperature_for_processed_logprobs is not None: + processed_logits_cuda /= temperature_for_processed_logprobs logit_indices_for_processed_logprobs_cuda += batch_next_tokens_offset_start batch_logits_for_logprobs_cuda[logit_indices_for_processed_logprobs_cuda] = ( processed_logits_cuda @@ -2768,6 +2783,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin): def _longest_stop_word_len(requests: Iterable[LlmRequest]) -> int: max_stop_word_len = 0 for req in requests: + assert req.py_stop_words_list is not None _, cumsum = req.py_stop_words_list if -1 in cumsum: cumsum = cumsum[: cumsum.index(-1)] @@ -2971,13 +2987,14 @@ class TorchSampler(Sampler, AsyncWorkerMixin): padded_tokens = self._padded_old_tokens(requests, tokens, predecessor_beams) for request_idx, request in enumerate(requests): + assert request.py_stop_words_list is not None swl, ends = request.py_stop_words_list if -1 in ends: ends = ends[: ends.index(-1)] lens = np.diff(ends, prepend=0) max_len = np.max(lens) - words = torch.zeros(len(lens), max_len, dtype=torch.int32, pin_memory=True) + words = torch.zeros(len(lens), max_len.item(), dtype=torch.int32, pin_memory=True) for step, (start, length) in enumerate(zip([0] + ends, lens)): words[step, :length] = torch.tensor(swl[start : start + length], dtype=torch.int32) words_device = words.to("cuda", non_blocking=True) @@ -3119,6 +3136,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin): ).scatter_(dim=0, index=padded_indices_cuda, 
src=sampled_rank_cuda) if self._use_beam_search: + assert self.store.sampled_log_prob_indices is not None local_group_req_indices_with_beam_search = torch.tensor( [ req_id @@ -3186,8 +3204,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin): logits_cuda = self._apply_min_length_penalty( logits_cuda, requests, - sampling_requests_metadata.req_num_steps, - sampling_requests_metadata.req_num_beams, + sampling_requests_metadata.req_num_steps.tolist(), + sampling_requests_metadata.req_num_beams.tolist(), ) # Fast path for greedy sampling @@ -3303,7 +3321,7 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors): class SampleStateTRTLLM(SampleState): finalize_events: dict[str, CudaEvent] | None = None """`Optional` to accommodate `_forward_step_inter_pp` which creates a `SampleState` without `finalize_events`""" - host: Optional[SampleStateTensorsHostTRTLLM] = None + host: Optional[SampleStateTensorsHostTRTLLM] = None # type: ignore class TRTLLMSampler(Sampler, AsyncWorkerMixin): @@ -3337,7 +3355,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): self.logits_datatype = DataType.FLOAT self.decoding_mode = decoding_mode self.decoding_config = decoding_config if decoding_config else DecodingConfig(decoding_mode) - max_attn_window = kv_cache_config.max_attention_window + max_attn_window = kv_cache_config.max_attention_window # type: ignore self.max_seq_len = max_seq_len self.max_attention_window = ( max(max_attn_window) if max_attn_window is not None else max_seq_len @@ -3416,8 +3434,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): def _instantiate_algorithms(self): self.algs = Algorithms() - self.algs.decoder = GptDecoderBatched(stream=self.store["torch_stream"]) - self.algs.decoder.setup( + self.algs.decoder = GptDecoderBatched(stream=self.store["torch_stream"]) # type: ignore + self.algs.decoder.setup( # type: ignore mode=self.decoding_mode, max_num_sequences=self.max_num_sequences, max_beam_width=self.max_beam_width, @@ -3425,18 +3443,18 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): model_config=self.model_config, world_config=self.world_config, ) - self.algs.create_new_decoder_requests = CreateNewDecoderRequests( + self.algs.create_new_decoder_requests = CreateNewDecoderRequests( # type: ignore speculative_decoding_fast_logits=False, is_leader_in_orch_mode=False, is_normalize_log_probs=False, ) - self.algs.make_decoding_batch_input_output = MakeDecodingBatchInputOutput() + self.algs.make_decoding_batch_input_output = MakeDecodingBatchInputOutput() # type: ignore @torch.inference_mode() @nvtx_range("setup_sampler_step") - def setup_sampler_step(self, requests): + def setup_sampler_step(self, requests): # type: ignore batch_slots, sampling_configs, lookahead_prompt, lookahead_algo_configs = ( - self.algs.create_new_decoder_requests( + self.algs.create_new_decoder_requests( # type: ignore self.model_config, self.world_config, self.decoding_config, @@ -3445,7 +3463,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): self.store["decoder_input_buffers"][self.micro_batch_idx], self.store["decoder_state"], self.store["cuda_stream"], - self.algs.decoder.decoder_stream, + self.algs.decoder.decoder_stream, # type: ignore self.max_seq_len, self.beam_width(requests.context_requests), ) @@ -3454,7 +3472,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): local_batch_size = len(batch_slots) if local_batch_size > 0: sampling_config = make_sampling_config(sampling_configs) - self.algs.decoder.underlying_decoder().setup( + self.algs.decoder.underlying_decoder().setup( # type: ignore 
sampling_config, local_batch_size, batch_slots, @@ -3468,9 +3486,9 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): batch_size = len(adp) if batch_size == 0: return - config = make_sampling_config([r.sampling_config for r in adp]) + config = make_sampling_config([r.sampling_config for r in adp]) # type: ignore slots = torch.tensor([r.py_seq_slot for r in adp], dtype=torch.int32) - self.algs.decoder.underlying_decoder().setup(config, batch_size, slots) + self.algs.decoder.underlying_decoder().setup(config, batch_size, slots) # type: ignore def get_cache_indirection(self) -> torch.Tensor | None: return self.store["decoder_state"].cache_indirection_output @@ -3519,7 +3537,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): self.store["buffer_manager"], ) - self.algs.decoder.forward_async( + self.algs.decoder.forward_async( # type: ignore self.store["decoder_state"], self.store["decoder_input_buffers"][self.micro_batch_idx], ) @@ -3574,7 +3592,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): @torch.inference_mode() @override - def update_requests( + def update_requests( # type: ignore self, state: SampleStateTRTLLM, resource_manager: Optional[ResourceManager] = None, @@ -3598,8 +3616,8 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): @nvtx_range("update_requests_single_beam_single_step") def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM): """Specialization of update_requests for single beam and single step""" - sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist() - finish_reasons = state.host.finish_reasons.flatten().tolist() + sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist() # type: ignore + finish_reasons = state.host.finish_reasons.flatten().tolist() # type: ignore reqs = [ r for r in state.scheduled_requests.context_requests if not r.is_context_init_state @@ -3618,7 +3636,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): seq_slots = [] seq_slots_need_log_probs = [] for request in reqs: - if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0): + if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0): # type: ignore continue reqs_with_new_tokens.append(request) @@ -3628,19 +3646,19 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): seq_slots_need_log_probs.append(request.py_seq_slot) # [maxTokensPerStep, batchSize, maxBeamWidth] - new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist() + new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist() # type: ignore add_new_tokens_to_requests(reqs_with_new_tokens, new_tokens, 0) # Log probs - if state.host.log_probs is not None: + if state.host.log_probs is not None: # type: ignore # [batchSize, maxBeamWidth] - seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1 + seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1 # type: ignore # [batchSize, maxBeamWidth, maxSequenceLength] - log_probs_host = state.host.log_probs[ + log_probs_host = state.host.log_probs[ # type: ignore seq_slots_need_log_probs, 0, seq_last_idx ].tolist() # [batchSize, maxBeamWidth] - cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist() + cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist() # type: ignore log_probs_idx = 0 for request, new_token in zip(reqs_with_new_tokens, new_tokens): @@ -3659,7 +3677,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): for request in reqs: request.py_decoding_iter += 1 - finished_state = 
FinishedState(finish_reasons[request.py_seq_slot]) + finished_state = FinishedState(finish_reasons[request.py_seq_slot]) # type: ignore if finished_state.is_finished: request.state = LlmRequestState.GENERATION_COMPLETE finish_reason = finished_state.to_finish_reason() @@ -3672,14 +3690,14 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): state: SampleStateTRTLLM, beam_width: int, ): - new_tokens_host = state.host.new_tokens.tolist() - finished_sum_host = state.host.finished_sum.tolist() - finish_reasons = state.host.finish_reasons.flatten().tolist() - sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist() + new_tokens_host = state.host.new_tokens.tolist() # type: ignore + finished_sum_host = state.host.finished_sum.tolist() # type: ignore + finish_reasons = state.host.finish_reasons.flatten().tolist() # type: ignore + sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist() # type: ignore cum_log_probs_host = ( - state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None + state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None # type: ignore ) - log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None + log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None # type: ignore finalize_events = state.finalize_events reqs = [ @@ -3700,7 +3718,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): cum_log_probs = [] for beam_idx in range(beam_width): - seq_len = sequence_lengths_host_data[seq_slot * beam_width + beam_idx] + seq_len = sequence_lengths_host_data[seq_slot * beam_width + beam_idx] # type: ignore num_new_tokens[beam_idx] = min( num_generated_tokens, seq_len - request.get_num_tokens(beam_idx) ) @@ -3709,7 +3727,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): new_token = add_token(request, new_tokens_host, beam_idx=beam_idx, step=step) if request.py_return_log_probs: - assert state.host.log_probs is not None + assert state.host.log_probs is not None # type: ignore # NOTE: Log probs with drafting has not been tested yet. 
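# --- Editor's aside: illustrative sketch, not part of the patch ---
# The single-beam, single-step path above reads state.host.new_tokens[0, seq_slots, 0]
# from the [maxTokensPerStep, batchSize, maxBeamWidth] host tensor, i.e. step 0 and beam 0
# for just the slots that produced a fresh token. Toy tensor with invented values:
import torch

new_tokens = torch.arange(2 * 4 * 3).reshape(2, 4, 3)  # [steps, slots, beams]
seq_slots = [3, 1]                                      # slots that advanced this step
print(new_tokens[0, seq_slots, 0].tolist())             # [9, 3]
# --- end aside ---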
begin_log_probs_offset = ( request.prompt_len if request.sampling_config.beam_width == 1 else 0 @@ -3720,7 +3738,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): log_probs[beam_idx].append( { new_token: Logprob( - logprob=log_probs_host[seq_slot][beam_idx][ + logprob=log_probs_host[seq_slot][beam_idx][ # type: ignore begin_log_probs_offset + current_token ], rank=1, @@ -3729,9 +3747,9 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): ) if request.py_return_log_probs: - cum_log_probs.append(cum_log_probs_host[seq_slot][beam_idx]) + cum_log_probs.append(cum_log_probs_host[seq_slot][beam_idx]) # type: ignore - finished_state = FinishedState(finish_reasons[seq_slot * beam_width + beam_idx]) + finished_state = FinishedState(finish_reasons[seq_slot * beam_width + beam_idx]) # type: ignore if finished_state.is_finished: finish_reason = finished_state.to_finish_reason() request.set_finished_reason(finish_reason, beam_idx) @@ -3748,7 +3766,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): if request.state != LlmRequestState.GENERATION_COMPLETE: request.py_decoding_iter += 1 - if finished_sum_host[seq_slot] == beam_width: + if finished_sum_host[seq_slot] == beam_width: # type: ignore request.state = LlmRequestState.GENERATION_COMPLETE for request in reqs: if finalize_events is not None and request.request_id in finalize_events: @@ -3761,7 +3779,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): ): """Finalizes the request. This is necessary for beam search.""" seq_slot = request.py_seq_slot - event = self.algs.decoder.finalize( + event = self.algs.decoder.finalize( # type: ignore self.store["decoder_state"], seq_slot, request.sampling_config, streaming ) return event @@ -3775,15 +3793,15 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): beam_width = request.sampling_config.beam_width # synchronize on the finalize event before continuing the post processing. 
# should be unnecessary, as already wait for the sampler event in update_requests - state.finalize_events[request.request_id].synchronize() + state.finalize_events[request.request_id].synchronize() # type: ignore # Get these values again, as they might have changed during the finalize step - output_ids_host = state.host.gathered_ids - sequence_lengths_host = state.host.sequence_lengths + output_ids_host = state.host.gathered_ids # type: ignore + sequence_lengths_host = state.host.sequence_lengths # type: ignore if request.py_return_log_probs: - log_probs_host = state.host.log_probs - cum_log_probs_host = state.host.cum_log_probs + log_probs_host = state.host.log_probs # type: ignore + cum_log_probs_host = state.host.cum_log_probs # type: ignore generated_tokens = [[0]] * beam_width log_probs = [[] for _ in range(beam_width)] @@ -3796,11 +3814,11 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): sequence_lengths_host[seq_slot, beam_idx].item() - request.py_prompt_len ) end = begin + generated_length - generated_tokens[beam_idx] = output_ids_host[seq_slot, beam_idx, begin:end].tolist() + generated_tokens[beam_idx] = output_ids_host[seq_slot, beam_idx, begin:end].tolist() # type: ignore # get the correct log probs for beam search if request.py_return_log_probs: - cum_log_probs.append(cum_log_probs_host[seq_slot, beam_idx].item()) + cum_log_probs.append(cum_log_probs_host[seq_slot, beam_idx].item()) # type: ignore begin_log_probs_offset = ( request.prompt_len if request.sampling_config.beam_width == 1 else 0 @@ -3809,7 +3827,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin): log_probs[beam_idx].append( { token: Logprob( - logprob=log_probs_host[seq_slot, beam_idx][ + logprob=log_probs_host[seq_slot, beam_idx][ # type: ignore begin_log_probs_offset + current_token ].item(), rank=1, diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py index f2156c07aa..66f04c3bef 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py @@ -20,6 +20,7 @@ referring to types like LlmRequest. import abc import sys +from collections.abc import Hashable from dataclasses import dataclass from typing import Generic, Literal, Optional, Type, TypeAlias, TypeVar, cast @@ -289,9 +290,8 @@ def beam_search_sampling_batch( beam_width_out: int, beam_search_args: BeamSearchMetadata, temperature: float | None, - generator: Optional[torch.Generator] = None, return_probs: bool = True, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ Sample tokens for each request in parallel. 
""" @@ -518,19 +518,18 @@ def sample( beam_width_out=beam_width_out, beam_search_args=group_metadata, temperature=temperature, - generator=generator, return_probs=return_probs, ) return tokens, softmax, temperature -GenericStrategyKeyType = TypeVar("GenericStrategyKeyType") +GenericStrategyKeyType = TypeVar("GenericStrategyKeyType", bound=Hashable) class GroupedStrategySampler(Generic[GenericStrategyKeyType], abc.ABC): @staticmethod @abc.abstractmethod - def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> GenericStrategyKeyType: + def strategy_grouping_key(strategy: Strategy) -> GenericStrategyKeyType: raise NotImplementedError @staticmethod @@ -552,6 +551,13 @@ class GroupedStrategySampler(Generic[GenericStrategyKeyType], abc.ABC): return_probs: bool, group_metadata: StrategyMetadata | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, float | torch.Tensor | None]: + """Sample grouped strategies. + + Returns: + - Sampled tokens + - Processed probs (whenever return_probs=True) + - Temperature (used to compute processed _log_ probs) + """ raise NotImplementedError @@ -560,7 +566,7 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]): @override @staticmethod - def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KEY_TYPE: + def strategy_grouping_key(strategy: Strategy) -> STRATEGY_KEY_TYPE: return strategy @override @@ -585,7 +591,7 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]): generator: torch.Generator | None = None, return_probs: bool, group_metadata: StrategyMetadata | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None, float | None]: + ) -> tuple[torch.Tensor, torch.Tensor | None, float | torch.Tensor | None]: if group_key[0] == "beam_search": beam_width_in = group_key[1] else: diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py index 75d5edb40e..77cdcf0a38 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py +++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py @@ -20,7 +20,7 @@ referring to types like LlmRequest. 
import abc import sys -from typing import Optional, Type, TypeAlias, cast +from typing import Literal, Optional, Type, TypeAlias, cast import flashinfer.sampling import torch @@ -61,6 +61,9 @@ class _StrategyImpls: def computes_probs(cls) -> bool: pass + def get_temperature(self) -> torch.Tensor | None: + return getattr(self, "_temperature", None) + @abc.abstractmethod def sample( self, @@ -177,8 +180,8 @@ class _StrategyImpls: class BeamSearchMixin(StrategyImpl): def __init__( self, - beam_width_in: torch.Tensor, - beam_width_out: torch.Tensor, + beam_width_in: int, + beam_width_out: int, temperature: torch.Tensor, ): self._beam_width_in = beam_width_in @@ -192,12 +195,8 @@ class _StrategyImpls: ) -> "_StrategyImpls.BeamSearchMixin": assert all(strat[0] == "beam_search" for strat in strategies) narrowed_strats = cast(list[BeamSearch], strategies) - beam_width_in = cls._make_tensor( - [strat[1] for strat in narrowed_strats], torch.int32, cuda_device - ) - beam_width_out = cls._make_tensor( - [strat[2] for strat in narrowed_strats], torch.int32, cuda_device - ) + (beam_width_in,) = set(strat[1] for strat in narrowed_strats) + (beam_width_out,) = set(strat[2] for strat in narrowed_strats) temperature = cls._make_tensor( [strat[3] or 1.0 for strat in narrowed_strats], torch.float32, cuda_device ) @@ -215,22 +214,15 @@ class _StrategyImpls: assert group_metadata is not None and isinstance(group_metadata, BeamSearchMetadata), ( "BeamSearchMetadata is required for beam_search_sampling_batch" ) - assert torch.unique(self._beam_width_in).numel() == 1, ( - "beam_width_in must be the same for all strategies" - ) - assert torch.unique(self._beam_width_out).numel() == 1, ( - "beam_width_out must be the same for all strategies" - ) - logits = self._prepare_logits_with_temperature( - logits, group_logit_indices, self._temperature - ) + # Convert from 1 temperature per request to 1 temperature per (request, beam) + temperature = self._temperature.repeat_interleave(self._beam_width_in) + logits = self._prepare_logits_with_temperature(logits, group_logit_indices, temperature) return beam_search_sampling_batch( logits, - beam_width_in=self._beam_width_in[0], - beam_width_out=self._beam_width_out[0], + beam_width_in=self._beam_width_in, + beam_width_out=self._beam_width_out, beam_search_args=group_metadata, temperature=None, - generator=generator, return_probs=self.computes_probs(), ) @@ -648,84 +640,52 @@ class _StrategyImpls: pass -def _create_beam_search_specialized_cls( - beam_width_in: torch.Tensor, - beam_width_out: torch.Tensor, - return_probs: bool, -) -> Type[_StrategyImpls.BeamSearchMixin]: - """Create a class that implements BeamSearchMixin with static parameters for grouping.""" - - class BeamSearchSpecialized( - _StrategyImpls.BeamSearchWithProbs if return_probs else _StrategyImpls.BeamSearchSampleOnly - ): - static_beam_width_in = beam_width_in - static_beam_width_out = beam_width_out - - @override - def __hash__(self) -> int: - return hash((super(), self.static_beam_width_in, self.static_beam_width_out)) - - @override - def __eq__(self, other: object) -> bool: - return ( - super().__eq__(other) - and self.static_beam_width_in == other.static_beam_width_in - and self.static_beam_width_out == other.static_beam_width_out - ) - - return BeamSearchSpecialized +_STRATEGY_KEY_TYPE: TypeAlias = ( + Literal["temperature"] + | Literal["top_k"] + | Literal["top_p"] + | Literal["top_k_top_p"] + | Literal["greedy"] + | tuple[Literal["beam_search"], int, int] +) -class 
FlashInferGroupedStrategySampler(GroupedStrategySampler[Type[_StrategyImpls.StrategyImpl]]): +class FlashInferGroupedStrategySampler(GroupedStrategySampler[_STRATEGY_KEY_TYPE]): """Implements batched sampling with FlashInfer.sampling kernels. Note: Currently, FlashInfer.sampling appears to have limited CUDA graph support, see https://github.com/flashinfer-ai/flashinfer/issues/978. """ - STRATEGY_KEY_TYPE: TypeAlias = Type[_StrategyImpls.StrategyImpl] + STRATEGY_KEY_TYPE: TypeAlias = _STRATEGY_KEY_TYPE @override @staticmethod - def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KEY_TYPE: - if return_probs: - match strategy: - case ("top_k", _, _): - return _StrategyImpls.TopKWithProbs - case ("top_p", _, _): - return _StrategyImpls.TopPWithProbs - case ("top_k_top_p", _, _, _): - return _StrategyImpls.TopKTopPWithProbs - case ("temperature", _): - return _StrategyImpls.TemperatureOnlyWithProbs - case ("greedy", None): - return _StrategyImpls.GreedyWithProbs - case ("beam_search", beam_width_in, beam_width_out, _): - return _create_beam_search_specialized_cls(beam_width_in, beam_width_out, True) - else: - match strategy: - case ("top_p", _, _): - return _StrategyImpls.TopPSampleOnly - case ("top_k", _, _): - return _StrategyImpls.TopKSampleOnly - case ("top_k_top_p", _, _, _): - return _StrategyImpls.TopKTopPSampleOnly - case ("temperature", _): - return _StrategyImpls.TemperatureOnlySampleOnly - case ("greedy", None): - return _StrategyImpls.GreedySampleOnly - case ("beam_search", beam_width_in, beam_width_out, _): - return _create_beam_search_specialized_cls(beam_width_in, beam_width_out, False) + def strategy_grouping_key(strategy: Strategy) -> STRATEGY_KEY_TYPE: + match strategy: + case ( + ("top_k", _, _) + | ("top_p", _, _) + | ("top_k_top_p", _, _, _) + | ("temperature", _) + | ("greedy", None) + ): + return strategy[0] + case ("beam_search", beam_width_in, beam_width_out, _): + return (strategy[0], beam_width_in, beam_width_out) + case _: + raise NotImplementedError("Unsupported strategy encountered") @override @staticmethod def get_metadata_type_for_group( strategy_key: STRATEGY_KEY_TYPE, ) -> Type[StrategyMetadata] | None: - if issubclass(strategy_key, _StrategyImpls.BeamSearchMixin): - return BeamSearchMetadata - else: - return None + match strategy_key: + case ("beam_search", _, _): + return BeamSearchMetadata + case _: + return None @override @staticmethod @@ -739,28 +699,52 @@ class FlashInferGroupedStrategySampler(GroupedStrategySampler[Type[_StrategyImpl return_probs: bool, group_metadata: StrategyMetadata | None = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - if hasattr(group_key, "static_beam_width_in"): - beam_width_in = group_key.static_beam_width_in + beam_width_in = 1 + if return_probs: + match group_key: + case "top_k": + strategy_impl_cls = _StrategyImpls.TopKWithProbs + case "top_p": + strategy_impl_cls = _StrategyImpls.TopPWithProbs + case "top_k_top_p": + strategy_impl_cls = _StrategyImpls.TopKTopPWithProbs + case "temperature": + strategy_impl_cls = _StrategyImpls.TemperatureOnlyWithProbs + case "greedy": + strategy_impl_cls = _StrategyImpls.GreedyWithProbs + case ("beam_search", beam_width_in_key, _): + beam_width_in = beam_width_in_key + strategy_impl_cls = _StrategyImpls.BeamSearchWithProbs + case _: + raise NotImplementedError("Unsupported strategy key encountered") else: - beam_width_in = 1 + match group_key: + case "top_p": + strategy_impl_cls = _StrategyImpls.TopPSampleOnly + case "top_k": + 
strategy_impl_cls = _StrategyImpls.TopKSampleOnly + case "top_k_top_p": + strategy_impl_cls = _StrategyImpls.TopKTopPSampleOnly + case "temperature": + strategy_impl_cls = _StrategyImpls.TemperatureOnlySampleOnly + case "greedy": + strategy_impl_cls = _StrategyImpls.GreedySampleOnly + case ("beam_search", beam_width_in_key, _): + beam_width_in = beam_width_in_key + strategy_impl_cls = _StrategyImpls.BeamSearchSampleOnly + case _: + raise NotImplementedError("Unsupported strategy key encountered") + if group_logit_indices is None: assert logits.size(0) == beam_width_in * len(strategies) else: assert group_logit_indices.size(0) == beam_width_in * len(strategies) - assert return_probs == group_key.computes_probs() - - strategy_impl_cls = group_key - sampling_object = strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device) - next_tokens, softmax = sampling_object.sample( + strategy_impl = strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device) + next_tokens, softmax = strategy_impl.sample( logits, group_logit_indices=group_logit_indices, generator=generator, group_metadata=group_metadata, ) - temperature = ( - sampling_object._temperature.unsqueeze(-1) - if sampling_object._temperature is not None - else None - ) - return next_tokens, softmax, temperature + return next_tokens, softmax, strategy_impl.get_temperature() diff --git a/tests/unittest/_torch/sampler/test_beam_search.py b/tests/unittest/_torch/sampler/test_beam_search.py index f4b1b09da3..1489b3c564 100644 --- a/tests/unittest/_torch/sampler/test_beam_search.py +++ b/tests/unittest/_torch/sampler/test_beam_search.py @@ -436,7 +436,6 @@ def test_beam_search_sampling_batch_basic(): beam_width_out=beam_width, beam_search_args=beam_search_args, temperature=temperature, - generator=None, return_probs=True, ) diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py index 03fd014323..61c79e77a6 100644 --- a/tests/unittest/_torch/sampler/test_torch_sampler.py +++ b/tests/unittest/_torch/sampler/test_torch_sampler.py @@ -35,7 +35,6 @@ import torch from scipy.stats import power_divergence from utils.util import assert_no_cuda_sync, force_ampere -from tensorrt_llm._torch.pyexecutor import sampling_utils_flashinfer from tensorrt_llm._torch.pyexecutor.llm_request import convert_wordlist from tensorrt_llm._torch.pyexecutor.sampler import ( GREEDY, @@ -1563,11 +1562,10 @@ class TestBatchedSampling: return_probs: bool, group_metadata: StrategyMetadata | None = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - assert issubclass(group_key, sampling_utils_flashinfer._StrategyImpls.StrategyImpl) assert generator is sampler.get_generator(logits.device) nonlocal flashinfer_keys_seen - assert group_key not in flashinfer_keys_seen - flashinfer_keys_seen.add(group_key) + assert (group_key, return_probs) not in flashinfer_keys_seen + flashinfer_keys_seen.add((group_key, return_probs)) return sample_grouped_strategies_orig( group_key, strategies,
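# --- Editor's aside: illustrative sketch, not part of the patch ---
# The FlashInferGroupedStrategySampler refactor above replaces dynamically created
# strategy-impl classes with small hashable grouping keys: plain strategies collapse to
# their tag string, while beam search keeps its beam widths in the key so groups with
# different widths are not merged. A standalone sketch of that dispatch shape (the tuples
# below stand in for the real Strategy values):
def grouping_key(strategy):
    match strategy:
        case (
            ("top_k", _, _)
            | ("top_p", _, _)
            | ("top_k_top_p", _, _, _)
            | ("temperature", _)
            | ("greedy", None)
        ):
            return strategy[0]
        case ("beam_search", beam_width_in, beam_width_out, _):
            return (strategy[0], beam_width_in, beam_width_out)
        case _:
            raise NotImplementedError("Unsupported strategy encountered")


print(grouping_key(("top_k", 50, 0.7)))          # 'top_k'
print(grouping_key(("greedy", None)))            # 'greedy'
print(grouping_key(("beam_search", 4, 4, 0.7)))  # ('beam_search', 4, 4)
# --- end aside ---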