diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 5dfb1382a7..876858bfec 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -929,6 +929,12 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
     class Store:
         new_tokens: torch.Tensor
         """Shape: See cpp DecoderState.getAllNewTokens()"""
+        max_lengths_tensor: torch.Tensor
+        """Shape: batch_size
+        Usage: Stores the maximum lengths for each request"""
+        end_ids: torch.Tensor
+        """Shape: batch_size
+        Usage: Stores the end ids for each request"""
         finish_reasons: torch.Tensor
         """Shape: max_tokens, batch_size, beam_width
         Usage: Stores the currently estimated finish_reasons for each request"""
@@ -974,6 +980,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
         # Tensors necessary for all sampling methods
         new_tokens = int_tensor(self.NEW_TOKENS_SHAPE)
         finish_reasons = int_tensor(self.NEW_TOKENS_SHAPE)
+        max_lengths_tensor = int_tensor(self.max_num_sequences)
+        end_ids = int_tensor(self.max_num_sequences)
 
         # Only used for logprobs processing or beam search
         sampled_log_probs = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.float32)
@@ -1007,6 +1015,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
         return self.Store(
             new_tokens=new_tokens,
             finish_reasons=finish_reasons,
+            max_lengths_tensor=max_lengths_tensor,
+            end_ids=end_ids,
             cache_indirection=cache_indirection,
             cache_indirection_buffer=cache_indirection_buffer,
             cum_log_probs=cum_log_probs,
@@ -1072,6 +1082,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
                 FinishReason.CANCELLED,
             ]  # `in FinishReason` clashes with PyBind11: `TypeError: 'pybind11_type' object is not iterable`
         }
+        self._max_tokens_offset = torch.arange(
+            1, self.max_tokens + 1, device="cuda", dtype=torch.int32
+        ).view(-1, 1, 1)
 
         self._grouped_sampler_cls: Type[GroupedStrategySampler]
         if IS_FLASHINFER_AVAILABLE and not args.disable_flashinfer_sampling:
@@ -1525,14 +1538,47 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
         return num_accepted_draft_tokens - 1
 
-    def setup_sampler_step(self, requests: ScheduledRequests):
+    def _is_new_request(self, request: LlmRequest) -> bool:
+        return (
+            not request.is_finished
+            and not request.py_is_draft
+            and (
+                (request.is_context_init_state and request.is_last_context_chunk)
+                or request.is_disagg_generation_transmission_complete
+            )
+        )
+
+    @override
+    def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
         """Setup the sampler step for the requests
 
         Args:
             requests: list[LlmRequest].
                 The requests to setup the sampler step for
         """
         if self._use_beam_search:
-            self._prepare_beam_search(requests.all_requests())
+            self._prepare_beam_search(scheduled_requests.all_requests())
+
+        seq_slots: list[int] = []
+        max_lens: list[int] = []
+        end_ids: list[int] = []
+        for request in scheduled_requests.context_requests:
+            if self._is_new_request(request):
+                seq_slots.append(request.py_seq_slot)
+                max_lens.append(
+                    min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
+                )
+                end_ids.append(request.py_end_id if request.py_end_id is not None else -1)
+        if len(seq_slots) > 0:
+            full_list = [seq_slots, max_lens, end_ids]
+            # perform only a single copy
+            full_list_tensor = torch.tensor(full_list, device="cpu", dtype=torch.int32).to(
+                device="cuda", non_blocking=True
+            )
+            seq_slots_tensor = full_list_tensor[0]
+            max_lens_tensor = full_list_tensor[1]
+            end_ids_tensor = full_list_tensor[2]
+            self.store.max_lengths_tensor[seq_slots_tensor] = max_lens_tensor
+            self.store.end_ids[seq_slots_tensor] = end_ids_tensor
 
     def _prepare_beam_search(
         self,
@@ -1544,10 +1590,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
         initialize/reset the buffers for the request
         """
         for request in requests:
-            if not request.is_finished and (
-                (request.is_context_init_state and request.is_last_context_chunk)
-                or request.is_disagg_generation_transmission_complete
-            ):
+            if self._is_new_request(request):
                 if request.py_return_log_probs and request.py_num_logprobs > 1:
                     raise ValueError("Beam search does not support multiple logprobs")
                 self.store.cache_indirection[request.py_seq_slot, :, request.py_prompt_len].fill_(0)
@@ -2083,13 +2126,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
             dtype=torch.int64,  # for index_fill_
             pin_memory=True,
         )
-        # necessary for beam search
-        seq_lens_host = (
-            torch.tensor(
-                [r.max_beam_num_tokens for r in requests], dtype=torch.int32, pin_memory=True
-            )
-            if self._use_beam_search
-            else None
+        # necessary for beam search and max_length checks
+        seq_lens_host = torch.tensor(
+            [r.max_beam_num_tokens for r in requests], dtype=torch.int32, pin_memory=True
         )
         new_tokens_host = self._process_requests(
             scheduled_requests,
@@ -2102,12 +2141,14 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
 
         finish_reasons = self.store.finish_reasons
         seq_slots = seq_slots_host.to(device="cuda", non_blocking=True)
+        seq_lens = seq_lens_host.to(device="cuda", non_blocking=True)
         first_finish_reasons = self.store.first_finish_reasons if self._use_beam_search else None
 
         self._write_finish_reasons(
             requests,
             finish_reasons=finish_reasons,
             seq_slots=seq_slots,
+            seq_lens=seq_lens,
             new_tokens=new_tokens,
             first_finish_reasons=first_finish_reasons,
             predecessor_beams=self.store.predecessor_beams,
@@ -2760,6 +2801,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
         *,
         finish_reasons: torch.Tensor,
         seq_slots: torch.Tensor,
+        seq_lens: torch.Tensor,
         new_tokens: torch.Tensor,
         first_finish_reasons: torch.Tensor | None = None,
         predecessor_beams: torch.Tensor | None = None,
@@ -2775,7 +2817,11 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
             new_tokens: a buffer containing the newly generated tokens.
                 Shape: (max_tokens, max_batch_size, max_beam_width)
         """
-        tokens = new_tokens[:, seq_slots.to(device=new_tokens.device, non_blocking=True)]
+
+        # seq_slots and seq_lens must be on the same device as new_tokens
+        assert seq_slots.device == new_tokens.device
+        assert seq_lens.device == new_tokens.device
+        tokens = new_tokens[:, seq_slots]
         # we need to fill with NOT_FINISHED so we can differentiate between previous requests
         # that had the same seq slot
         finish_reasons.index_fill_(1, seq_slots, FinishReason.NOT_FINISHED.value)
@@ -2801,12 +2847,12 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
         )
 
         batched_finish_reasons = torch.where(
-            self._are_max_length(requests),
+            self._are_max_length(seq_lens, self.store.max_lengths_tensor[seq_slots]),
             self._reason_tensors[FinishReason.LENGTH],
             batched_finish_reasons,
         )
         batched_finish_reasons = torch.where(
-            self._are_end_id(requests, tokens),
+            self._are_end_id(self.store.end_ids[seq_slots], tokens),
             self._reason_tensors[FinishReason.END_ID],
             batched_finish_reasons,
         )
@@ -2822,57 +2868,26 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
             )
             first_finish_reasons[seq_slots] = batched_first_finish_reasons
 
-    def _are_end_id(self, requests: list[LlmRequest], tokens: torch.Tensor) -> torch.Tensor:
-        end_ids_tensor = (
-            torch.tensor(
-                [
-                    ([req.py_end_id if req.py_end_id is not None else -1] * self.max_beam_width)
-                    for req in requests
-                ]
-                * self.max_tokens,
-                pin_memory=True,
-                dtype=tokens.dtype,
-            )
-            .view(self.max_tokens, len(requests), self.max_beam_width)
-            .to(device="cuda", non_blocking=True)
-        )
-        return tokens == end_ids_tensor
+    def _are_end_id(self, end_ids: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
+        return tokens == end_ids.view(1, -1, 1).expand(self.max_tokens, -1, self.max_beam_width)
 
-    def _are_max_length(self, requests: list[LlmRequest]) -> torch.Tensor:
+    def _are_max_length(self, seq_lens: torch.Tensor, max_seq_lens: torch.Tensor) -> torch.Tensor:
         """Checks which sequences are at or beyond the max length
 
         Args:
-            requests: the requests to check the max length of
-
+            seq_lens: the current sequence lengths of the requests to check
+            max_seq_lens: the maximum allowed sequence lengths of those requests
         Returns:
             A tensor of shape (max_tokens, len(requests), max_beam_width) where each element is
             True if the sequence is at or beyond the max length, False otherwise
         """
-        lengths_tensor = torch.tensor(
-            [
-                [
-                    [
-                        (req.get_num_tokens(beam_idx) + num_tokens)
-                        for beam_idx in range(self.max_beam_width)
-                    ]
-                    for req in requests
-                ]
-                for num_tokens in range(1, self.max_tokens + 1)
-            ]
+        lengths_tensor = (seq_lens.view(1, -1, 1) + self._max_tokens_offset).expand(
+            self.max_tokens, -1, self.max_beam_width
         )
-        max_lengths_tensor = torch.tensor(
-            [
-                (
-                    [min(req.py_max_new_tokens + req.orig_prompt_len, self.max_seq_len)]
-                    * self.max_beam_width
-                )
-                for req in requests
-            ]
-            * self.max_tokens
-        ).view(self.max_tokens, len(requests), self.max_beam_width)
-        return (
-            (lengths_tensor >= max_lengths_tensor).pin_memory().to(device="cuda", non_blocking=True)
+        max_lengths_tensor = max_seq_lens.view(1, -1, 1).expand(
+            self.max_tokens, -1, self.max_beam_width
         )
+        return lengths_tensor >= max_lengths_tensor
 
     _PAD_ID = -1
     """Pad with negative, doesn't matter what"""
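
Note on the sampler.py hunks above: `_are_end_id` and `_are_max_length` now compare against preallocated device tensors via broadcasting instead of rebuilding pinned host tensors from the request list on every step. Below is a minimal, self-contained sketch of the shape arithmetic, assuming toy sizes and plain CPU tensors in place of the sampler's CUDA buffers; all values and variable names are made up for illustration.

import torch

# Toy sizes standing in for the sampler's max_tokens, batch size, and max_beam_width.
max_tokens, batch, beam = 3, 2, 1

# seq_lens: tokens already held by each request (max_beam_num_tokens in the patch);
# max_lens: per-request cap, i.e. min(max_seq_len, orig_prompt_len + py_max_new_tokens).
seq_lens = torch.tensor([5, 9], dtype=torch.int32)
max_lens = torch.tensor([7, 10], dtype=torch.int32)

# Mirrors self._max_tokens_offset: after accepting k new tokens, a sequence holds seq_len + k.
offset = torch.arange(1, max_tokens + 1, dtype=torch.int32).view(-1, 1, 1)

lengths = (seq_lens.view(1, -1, 1) + offset).expand(max_tokens, -1, beam)
limits = max_lens.view(1, -1, 1).expand(max_tokens, -1, beam)
at_max_length = lengths >= limits  # (max_tokens, batch, beam), True where LENGTH applies

# END_ID check: compare every candidate token against its request's end id (-1 means "no end id").
end_ids = torch.tensor([2, -1], dtype=torch.int32)
new_tokens = torch.tensor([[[4], [7]], [[2], [7]], [[3], [7]]], dtype=torch.int32)
is_end = new_tokens == end_ids.view(1, -1, 1).expand(max_tokens, -1, beam)

print(at_max_length.squeeze(-1))  # [[False, True], [True, True], [True, True]]
print(is_end.squeeze(-1))         # [[False, False], [True, False], [False, False]]
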
diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py
index 845455d229..85ad553a50 100644
--- a/tensorrt_llm/_torch/speculative/mtp.py
+++ b/tensorrt_llm/_torch/speculative/mtp.py
@@ -224,11 +224,18 @@ class MTPSampler(TorchSampler):
         next_draft_tokens: torch.Tensor
         new_tokens_lens: torch.Tensor
         max_total_draft_tokens: torch.Tensor
-        finish_reasons: None = None  # Necessary to satisfy the interface of TorchSampler.Store
+        # Necessary to satisfy the interface of TorchSampler.Store
+        finish_reasons: None = None
+        end_ids: None = None
+        max_lengths_tensor: None = None
 
         def __post_init__(self):
             pass  # finish_reasons has no size to compare against new_tokens in MTPSampler
 
+    def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
+        # MTPSampler does not need to set up additional buffers before the sampler step
+        pass
+
     def __init__(self, args: TorchSampler.Args, *, nextn: int):
         self.mapping = None
         self.draft_len = nextn
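
The `setup_sampler_step` change in sampler.py packs the per-request slot, max-length, and end-id lists into one tensor so that only a single host-to-device copy is issued per step. A rough sketch of that pattern follows, assuming small hypothetical values and CPU stand-ins for the CUDA store buffers (`store_max_lengths` and `store_end_ids` are illustrative names, not the real attributes).

import torch

# CPU stand-ins for the preallocated per-slot store buffers (CUDA int tensors in the sampler).
max_num_sequences = 8
store_max_lengths = torch.zeros(max_num_sequences, dtype=torch.int32)
store_end_ids = torch.full((max_num_sequences,), -1, dtype=torch.int32)

# Host-side metadata gathered while iterating over the new context requests.
seq_slots = [3, 5]
max_lens = [128, 96]
end_ids = [2, -1]

# Pack the three lists into one (3, n) tensor so a single copy (followed by
# .to("cuda", non_blocking=True) in the real code) moves everything at once.
packed = torch.tensor([seq_slots, max_lens, end_ids], dtype=torch.int32)
slot_idx = packed[0].long()  # index tensors are safest as int64
store_max_lengths[slot_idx] = packed[1]
store_end_ids[slot_idx] = packed[2]

print(store_max_lengths.tolist())  # [0, 0, 0, 128, 0, 96, 0, 0]
print(store_end_ids.tolist())      # [-1, -1, -1, 2, -1, -1, -1, -1]
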
diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py
index dbfef3f6ad..03fd014323 100644
--- a/tests/unittest/_torch/sampler/test_torch_sampler.py
+++ b/tests/unittest/_torch/sampler/test_torch_sampler.py
@@ -701,16 +701,40 @@ class RequestCase:
         seq_slots = torch.tensor(
             [req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64
         )
+        seq_lens = torch.tensor(
+            [req.request.max_beam_num_tokens for req in requests], dtype=torch.int32, device="cuda"
+        )
         new_tokens = torch.tensor(
             [req.new_tokens for req in requests], dtype=torch.int32, device="cuda"
         ).T
         sampler.store.new_tokens[:, seq_slots, BEAM] = new_tokens
+        max_seq_lens = torch.tensor(
+            [
+                min(
+                    sampler.max_seq_len, req.request.orig_prompt_len + req.request.py_max_new_tokens
+                )
+                for req in requests
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        end_ids = torch.tensor(
+            [
+                req.request.py_end_id if req.request.py_end_id is not None else -1
+                for req in requests
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens
+        sampler.store.end_ids[seq_slots] = end_ids
 
         def run():
             sampler._write_finish_reasons(
                 [req.request for req in requests],
                 finish_reasons=sampler.store.finish_reasons,
                 new_tokens=sampler.store.new_tokens,
+                seq_lens=seq_lens,
                 seq_slots=seq_slots,
             )
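
For completeness, a pytest-style sketch of the kind of equivalence check that could sit next to the updated unit test: it compares the broadcast-based length check against a naive per-request loop that mirrors the removed list-comprehension implementation. Helper names and values are hypothetical, and everything runs on CPU rather than on the sampler's CUDA tensors.

import torch

def are_max_length_broadcast(seq_lens, max_lens, max_tokens, beam):
    """Vectorized check, mirroring the patched _are_max_length."""
    offset = torch.arange(1, max_tokens + 1, dtype=torch.int32).view(-1, 1, 1)
    lengths = (seq_lens.view(1, -1, 1) + offset).expand(max_tokens, -1, beam)
    limits = max_lens.view(1, -1, 1).expand(max_tokens, -1, beam)
    return lengths >= limits

def are_max_length_loop(seq_lens, max_lens, max_tokens, beam):
    """Per-request Python loop, mirroring the removed implementation."""
    out = torch.zeros(max_tokens, len(seq_lens), beam, dtype=torch.bool)
    for t in range(max_tokens):
        for i, (length, limit) in enumerate(zip(seq_lens.tolist(), max_lens.tolist())):
            out[t, i, :] = (length + t + 1) >= limit
    return out

def test_broadcast_matches_loop():
    seq_lens = torch.tensor([5, 9, 1], dtype=torch.int32)
    max_lens = torch.tensor([7, 10, 64], dtype=torch.int32)
    vectorized = are_max_length_broadcast(seq_lens, max_lens, max_tokens=4, beam=2)
    looped = are_max_length_loop(seq_lens, max_lens, max_tokens=4, beam=2)
    assert torch.equal(vectorized, looped)
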