From 87eb5086fb04f6da0330a6523df724c1da58c6db Mon Sep 17 00:00:00 2001
From: mpikulski <206748156+ixlmar@users.noreply.github.com>
Date: Tue, 21 Oct 2025 04:34:57 +0200
Subject: [PATCH] [None][fix] restore list[list[list[int]]] in add_token (#8502)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/sampler.py    | 63 +++++++++++++------
 tensorrt_llm/_torch/speculative/mtp.py       |  2 +-
 .../test_draft_token_tree_verification.py    |  4 +-
 .../test_torch_rejection_sampling.py         |  8 ++-
 4 files changed, 54 insertions(+), 23 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 0c37b2856c..b52ca5f459 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -290,10 +290,13 @@ def _group_requests_by_strategy_key(
     }
 
 
-def add_token(request: LlmRequest, new_tokens: torch.Tensor, *, beam: int, step: int = 0) -> int:
+def add_token(
+    request: LlmRequest, new_tokens: list[list[list[int]]], *, beam: int, step: int = 0
+) -> int:
+    # NB: Accessing nested lists faster than torch.Tensor or numpy.ndarray
     seq_slot = request.py_seq_slot
     assert seq_slot is not None
-    new_token = cast(int, new_tokens[step][seq_slot][beam].item())
+    new_token = new_tokens[step][seq_slot][beam]
     request.add_new_token(new_token, beam)
     return new_token
 
@@ -700,7 +703,7 @@ class TorchSampler(Sampler):
     def _process_draft_tokens_greedy(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens: list[list[list[int]]],
     ) -> int:
         new_token = add_token(request, new_tokens, beam=self.BEAM)
         stop = self._handle_stop_criteria(request, new_token)
@@ -722,7 +725,8 @@ class TorchSampler(Sampler):
     def _process_draft_tokens_tree(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_tensor: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
         spec_tree_manager: SpecTreeManager,
     ) -> int:
         """Tree verification for draft token tree based speculative decoding.
@@ -757,7 +761,7 @@ class TorchSampler(Sampler):
             # TODO: For the last layer of the dynamic tree, we need to resampling all the draft tokens.
             cur_layer_num_nodes = sum(spec_tree_manager.get_top_k_list(cur_draft_layer_idx))
             for i in range(cur_layer_num_nodes):
-                new_token = add_token(request, new_tokens, beam=0, step=i)
+                new_token = add_token(request, new_tokens_list, beam=0, step=i)
             return 0
         else:
             # handle the target model request
@@ -767,7 +771,9 @@ class TorchSampler(Sampler):
             eagle_paths = spec_tree_manager.get_eagle_paths(seq_slot)
 
             all_draft_tokens = request.py_draft_tokens  # [max_total_draft_tokens]
-            all_target_tokens = new_tokens[:, seq_slot, :].squeeze(-1)  # [max_total_draft_tokens]
+            all_target_tokens = new_tokens_tensor[:, seq_slot, :].squeeze(
+                -1
+            )  # [max_total_draft_tokens]
             assert all_target_tokens.shape[0] == spec_tree_manager.max_total_draft_tokens + 1
 
             longest_accepted_len = 0
@@ -800,13 +806,15 @@ class TorchSampler(Sampler):
             if longest_accepted_len == 0:
                 # No draft tokens are accepted.
                 # Take the top-1 token of the first layer as the next new token.
-                new_token = add_token(request, new_tokens, beam=0, step=0)
+                new_token = add_token(request, new_tokens_list, beam=0, step=0)
                 return 0
             else:
                 # Take the longest accepted path as the next new token.
                 num_accepted_draft_tokens = 0
                 for idx in eagle_paths[longest_match_path_idx][:longest_accepted_len]:
-                    new_token = add_token(request, new_tokens, beam=0, step=cast(int, idx.item()))
+                    new_token = add_token(
+                        request, new_tokens_list, beam=0, step=cast(int, idx.item())
+                    )
                     num_accepted_draft_tokens += 1
                     if self._handle_stop_criteria(request, new_token):
                         break
@@ -876,8 +884,10 @@ class TorchSampler(Sampler):
     def _process_draft_tokens_rejection_sampling(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
+        new_tokens_tensor: torch.Tensor,
     ) -> int:
+        assert request.py_draft_logits is not None
         # FIXME: Passing a dummy vocab_size could result in unnecessary
         # filtering of vocab_size logits, out of vocab_size in
         # total. The 'sample' below should generally be avoided
@@ -893,7 +903,9 @@ class TorchSampler(Sampler):
             request.py_draft_logits,
             generator=generator,
         )
+        assert draft_probs is not None
         target_probs = request.py_target_probs
+        assert target_probs is not None
         d2t = getattr(request, "d2t", None)
         if d2t is not None:
             vocab_d = draft_probs.shape[-1]
@@ -927,7 +939,7 @@ class TorchSampler(Sampler):
         num_accepted = num_initially_accepted
         for i in range(num_accepted):
             new_token = request.py_draft_tokens[i]
-            new_tokens[i, request.seq_slot, self.BEAM] = new_token
+            new_tokens_tensor[i, request.seq_slot, self.BEAM] = new_token
             request.add_new_token(new_token, self.BEAM)
             stop = self._handle_stop_criteria(request, new_token)
             if stop:
@@ -935,10 +947,10 @@ class TorchSampler(Sampler):
                 return num_accepted
         if sample_last:
             new_token = sample_rejected(draft_probs, target_probs, generator, num_accepted)
-            new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
+            new_tokens_tensor[num_accepted, request.seq_slot, self.BEAM] = new_token
             request.add_new_token(new_token, self.BEAM)
         else:
-            new_token = add_token(request, new_tokens, beam=self.BEAM, step=num_accepted)
+            new_token = add_token(request, new_tokens_list, beam=self.BEAM, step=num_accepted)
         stop = self._handle_stop_criteria(request, new_token)
 
         return num_accepted
@@ -946,7 +958,8 @@ class TorchSampler(Sampler):
     def process_draft_tokens(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_tensor: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
         resource_manager: Optional[ResourceManager] = None,
     ) -> int:
         if (
@@ -957,14 +970,19 @@ class TorchSampler(Sampler):
             if spec_tree_manager is not None:
                 num_accepted = self._process_draft_tokens_tree(
                     request,
-                    new_tokens=new_tokens,
+                    new_tokens_tensor=new_tokens_tensor,
+                    new_tokens_list=new_tokens_list,
                     spec_tree_manager=spec_tree_manager,
                 )
             else:
-                num_accepted = self._process_draft_tokens_greedy(request, new_tokens=new_tokens)
+                num_accepted = self._process_draft_tokens_greedy(
+                    request, new_tokens=new_tokens_list
+                )
             return num_accepted
         else:
-            return self._process_draft_tokens_rejection_sampling(request, new_tokens)
+            return self._process_draft_tokens_rejection_sampling(
+                request, new_tokens_list=new_tokens_list, new_tokens_tensor=new_tokens_tensor
+            )
 
     @override
     def update_requests(
@@ -976,7 +994,9 @@ class TorchSampler(Sampler):
         if state.sampler_event:
             state.sampler_event.synchronize()
 
+        assert state.host is not None
         new_tokens = state.host.new_tokens
+        new_tokens_list = new_tokens.tolist()
 
         for req in state.scheduled_requests.context_requests:
             if (
@@ -984,7 +1004,7 @@ class TorchSampler(Sampler):
                 or req.context_remaining_length != 0
             ):
                 continue
-            new_token = add_token(req, new_tokens, beam=self.BEAM)
+            new_token = add_token(req, new_tokens_list, beam=self.BEAM)
             self._handle_stop_criteria(req, new_token)
             self.handle_logprobs(req, state, beam=self.BEAM, count=1)
             req.py_decoding_iter += 1
@@ -993,7 +1013,12 @@ class TorchSampler(Sampler):
             if req.state == LlmRequestState.GENERATION_COMPLETE:
                 continue
             processed = 1
-            num_accepted = self.process_draft_tokens(req, new_tokens, resource_manager)
+            num_accepted = self.process_draft_tokens(
+                req,
+                new_tokens_tensor=new_tokens,
+                new_tokens_list=new_tokens_list,
+                resource_manager=resource_manager,
+            )
             if get_draft_token_length(req) > 0:
                 req.py_num_accepted_draft_tokens = num_accepted
                 req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
@@ -1911,7 +1936,7 @@ class TRTLLMSampler(Sampler):
         state: SampleStateTRTLLM,
         beam_width: int,
     ):
-        new_tokens_host = state.host.new_tokens
+        new_tokens_host = state.host.new_tokens.tolist()
         finished_sum_host = state.host.finished_sum.tolist()
         finish_reasons = state.host.finish_reasons.flatten().tolist()
         sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py
index 991fb67275..29fbac2aaa 100644
--- a/tensorrt_llm/_torch/speculative/mtp.py
+++ b/tensorrt_llm/_torch/speculative/mtp.py
@@ -256,7 +256,7 @@ class MTPSampler(TorchSampler):
         assert isinstance(state, SampleStateMTP)
         state.sampler_event.synchronize()
 
-        new_tokens = state.host.new_tokens
+        new_tokens = state.host.new_tokens.tolist()
         new_tokens_lens_list = state.host.new_tokens_lens.tolist()
         next_draft_tokens_list = state.host.next_draft_tokens.tolist()
         beam_idx = self.BEAM
diff --git a/tests/unittest/_torch/speculative/test_draft_token_tree_verification.py b/tests/unittest/_torch/speculative/test_draft_token_tree_verification.py
index bfeab8c427..d84aeef548 100644
--- a/tests/unittest/_torch/speculative/test_draft_token_tree_verification.py
+++ b/tests/unittest/_torch/speculative/test_draft_token_tree_verification.py
@@ -45,9 +45,11 @@ def run_test(eagle_model_dir, max_seq_len, beam_width, use_dynamic_tree,
             max_beam_width=beam_width,
         ))
 
+    input_new_tokens_list = input_new_tokens.tolist()
     num_accepted_draft_tokens = torch_sampler._process_draft_tokens_tree(
         request=input_request,
-        new_tokens=input_new_tokens,
+        new_tokens_tensor=input_new_tokens,
+        new_tokens_list=input_new_tokens_list,
         spec_tree_manager=spec_tree_manager)
 
     print(f"num_accepted_draft_tokens: {num_accepted_draft_tokens}")
diff --git a/tests/unittest/_torch/speculative/test_torch_rejection_sampling.py b/tests/unittest/_torch/speculative/test_torch_rejection_sampling.py
index 6891d3f9dd..e28d00dd69 100644
--- a/tests/unittest/_torch/speculative/test_torch_rejection_sampling.py
+++ b/tests/unittest/_torch/speculative/test_torch_rejection_sampling.py
@@ -1,4 +1,5 @@
 import unittest
+from typing import cast
 
 import numpy as np
 import torch
@@ -24,8 +25,11 @@ def test_get_rejected_indices():
     sampled_regular = []
     for _ in range(num_iter):
         draft_tokens = [
-            torch.multinomial(draft_probs, num_samples=1,
-                              generator=generator).item()
+            cast(
+                int,
+                torch.multinomial(draft_probs,
+                                  num_samples=1,
+                                  generator=generator).item())
        ]
         rejected_indices = get_rejected_indices(draft_probs, target_probs,
                                                 generator, draft_tokens)
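
Note on why the change helps (a minimal sketch, not part of the diff; the sizes and the timing harness below are assumptions for illustration): add_token runs once per accepted token per request inside update_requests, and the patch replaces per-element tensor reads (new_tokens[step][seq_slot][beam].item()) with a single host-side new_tokens.tolist() followed by plain nested-list indexing, matching the restored "NB" comment in add_token.

import time

import torch

# Illustrative sizes only; the real layout is [steps, seq_slots, beams].
steps, slots, beams = 8, 64, 1
new_tokens = torch.randint(0, 32000, (steps, slots, beams))

# Old path: index the tensor and call .item() once per scalar read.
start = time.perf_counter()
for step in range(steps):
    for slot in range(slots):
        _ = new_tokens[step][slot][0].item()
item_ms = (time.perf_counter() - start) * 1e3

# Restored path: one bulk .tolist(), then plain Python list indexing.
start = time.perf_counter()
new_tokens_list = new_tokens.tolist()  # single bulk conversion
for step in range(steps):
    for slot in range(slots):
        _ = new_tokens_list[step][slot][0]
list_ms = (time.perf_counter() - start) * 1e3

print(f"per-element .item(): {item_ms:.3f} ms")
print(f".tolist() + indexing: {list_ms:.3f} ms")

The .item() path pays tensor-dispatch overhead on every read, while the list path pays one up-front conversion; this is also why the patch threads both views through the sampler: new_tokens_tensor for slicing and slice assignment, new_tokens_list for scalar reads in add_token.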