From 9f1d9b7b185304ec309d99402c31ac18e718025a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Daniel=20C=C3=A1mpora?= <961215+dcampora@users.noreply.github.com>
Date: Tue, 23 Sep 2025 16:40:08 +0200
Subject: [PATCH] [None][feat] Use list instead of torch tensor for new tokens in update requests (#7730)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/sampler.py | 22 ++++++++++------------
 tensorrt_llm/_torch/speculative/mtp.py    |  2 +-
 2 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 6e6f57bc21..6e58d5f80a 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -340,13 +340,13 @@ def sample(strategy: Strategy,
 
 
 def add_token(request: LlmRequest,
-              new_tokens: torch.Tensor,
+              new_tokens: list[list[list[int]]],
               *,
               beam: int,
               step: int = 0) -> int:
     seq_slot = request.py_seq_slot
     assert seq_slot is not None
-    new_token = int(new_tokens[step, seq_slot, beam])
+    new_token = new_tokens[step][seq_slot][beam]
     request.add_new_token(new_token, beam)
     return new_token
 
@@ -513,7 +513,7 @@ class TorchSampler(Sampler):
         for i in range(num_accepted):
             new_token = request.py_draft_tokens[i]
             new_tokens[i, request.seq_slot, self.BEAM] = new_token
-            request.add_new_token(new_token, self.BEAM)
+            new_token = add_token(request, new_tokens, beam=self.BEAM, step=i)
             stop = self._handle_stop_criteria(request, new_token)
             if stop:
                 num_accepted = i + 1
@@ -522,14 +522,11 @@ class TorchSampler(Sampler):
             new_token = sample_rejected(draft_probs, target_probs, generator,
                                         num_accepted)
             new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
-            request.add_new_token(new_token, self.BEAM)
-            stop = self._handle_stop_criteria(request, new_token)
-        else:
-            new_token = add_token(request,
-                                  new_tokens,
-                                  beam=self.BEAM,
-                                  step=num_accepted)
-            stop = self._handle_stop_criteria(request, new_token)
+        new_token = add_token(request,
+                              new_tokens,
+                              beam=self.BEAM,
+                              step=num_accepted)
+        stop = self._handle_stop_criteria(request, new_token)
 
         return num_accepted
 
@@ -545,7 +542,8 @@ class TorchSampler(Sampler):
         assert isinstance(state, SampleState)
         if state.sampler_event:
             state.sampler_event.synchronize()
-        new_tokens = state.host.new_tokens
+
+        new_tokens = state.host.new_tokens.tolist()
 
         for req in state.scheduled_requests.context_requests:
             if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py
index 05859cb277..1262b8d502 100644
--- a/tensorrt_llm/_torch/speculative/mtp.py
+++ b/tensorrt_llm/_torch/speculative/mtp.py
@@ -250,7 +250,7 @@ class MTPSampler(TorchSampler):
         assert isinstance(state, SampleStateMTP)
 
         state.sampler_event.synchronize()
-        new_tokens = state.host.new_tokens
+        new_tokens = state.host.new_tokens.tolist()
         new_tokens_lens_list = state.host.new_tokens_lens.tolist()
         next_draft_tokens_list = state.host.next_draft_tokens.tolist()
         beam_idx = self.BEAM
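
Note (illustration only, not part of the applied diff): the patch converts the
host-side new_tokens tensor to a nested Python list once via .tolist(), and
add_token() then reads tokens with plain nested-list indexing instead of a
per-element tensor index plus int() conversion, presumably to keep that
overhead out of the per-request update loop. Below is a minimal standalone
sketch of the two access patterns; the shapes, values, and helper names are
arbitrary assumptions, not code from the repository:

    import torch

    steps, slots, beams = 4, 8, 1  # hypothetical (draft steps, sequence slots, beams)
    host_new_tokens = torch.randint(0, 32000, (steps, slots, beams))

    # Access pattern before the patch: one tensor indexing op and an int()
    # conversion for every token appended to a request.
    def token_from_tensor(new_tokens: torch.Tensor, step: int, seq_slot: int, beam: int) -> int:
        return int(new_tokens[step, seq_slot, beam])

    # Access pattern after the patch: convert the whole tensor once, then use
    # ordinary nested-list indexing, which already yields Python ints.
    new_tokens_list = host_new_tokens.tolist()  # -> list[list[list[int]]]

    def token_from_list(new_tokens: list[list[list[int]]], step: int, seq_slot: int, beam: int) -> int:
        return new_tokens[step][seq_slot][beam]

    # Both paths produce the same token value for a given (step, seq_slot, beam).
    assert token_from_tensor(host_new_tokens, 2, 3, 0) == token_from_list(new_tokens_list, 2, 3, 0)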