[None][fix] restore list[list[list[int]]] in add_token (#8502)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
mpikulski 2025-10-21 04:34:57 +02:00 committed by GitHub
parent 85d5aa7763
commit 87eb5086fb
4 changed files with 54 additions and 23 deletions


@@ -290,10 +290,13 @@ def _group_requests_by_strategy_key(
}
def add_token(request: LlmRequest, new_tokens: torch.Tensor, *, beam: int, step: int = 0) -> int:
def add_token(
request: LlmRequest, new_tokens: list[list[list[int]]], *, beam: int, step: int = 0
) -> int:
# NB: Accessing nested lists faster than torch.Tensor or numpy.ndarray
seq_slot = request.py_seq_slot
assert seq_slot is not None
new_token = cast(int, new_tokens[step][seq_slot][beam].item())
new_token = new_tokens[step][seq_slot][beam]
request.add_new_token(new_token, beam)
return new_token
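
The hunk above restores the nested-list fast path: add_token again receives new_tokens as a plain list[list[list[int]]] and indexes it directly, instead of calling .item() on a tensor element for every token. A minimal, self-contained sketch of the two access patterns (toy shapes and variable names; not the actual TorchSampler state):

# Minimal sketch, assuming the [steps][seq_slots][beams] layout used by add_token.
import torch

new_tokens = torch.zeros(4, 8, 1, dtype=torch.int64)   # host tensor of sampled tokens

# Previous path: a per-token .item() call crosses the tensor API every time.
token_slow = int(new_tokens[0, 3, 0].item())

# Restored path: convert once, then use plain Python list indexing.
new_tokens_list = new_tokens.tolist()                   # list[list[list[int]]]
token_fast = new_tokens_list[0][3][0]                   # already a Python int
assert token_slow == token_fast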
@@ -700,7 +703,7 @@ class TorchSampler(Sampler):
def _process_draft_tokens_greedy(
self,
request: LlmRequest,
new_tokens: torch.Tensor,
new_tokens: list[list[list[int]]],
) -> int:
new_token = add_token(request, new_tokens, beam=self.BEAM)
stop = self._handle_stop_criteria(request, new_token)
@@ -722,7 +725,8 @@ class TorchSampler(Sampler):
def _process_draft_tokens_tree(
self,
request: LlmRequest,
new_tokens: torch.Tensor,
new_tokens_tensor: torch.Tensor,
new_tokens_list: list[list[list[int]]],
spec_tree_manager: SpecTreeManager,
) -> int:
"""Tree verification for draft token tree based speculative decoding.
@@ -757,7 +761,7 @@ class TorchSampler(Sampler):
# TODO: For the last layer of the dynamic tree, we need to resample all the draft tokens.
cur_layer_num_nodes = sum(spec_tree_manager.get_top_k_list(cur_draft_layer_idx))
for i in range(cur_layer_num_nodes):
new_token = add_token(request, new_tokens, beam=0, step=i)
new_token = add_token(request, new_tokens_list, beam=0, step=i)
return 0
else:
# handle the target model request
@@ -767,7 +771,9 @@ class TorchSampler(Sampler):
eagle_paths = spec_tree_manager.get_eagle_paths(seq_slot)
all_draft_tokens = request.py_draft_tokens # [max_total_draft_tokens]
all_target_tokens = new_tokens[:, seq_slot, :].squeeze(-1) # [max_total_draft_tokens]
all_target_tokens = new_tokens_tensor[:, seq_slot, :].squeeze(
-1
) # [max_total_draft_tokens]
assert all_target_tokens.shape[0] == spec_tree_manager.max_total_draft_tokens + 1
longest_accepted_len = 0
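
In the tree-verification path above, both views of the sampled tokens are needed: the tensor is kept for bulk slicing (all candidate target tokens of one sequence slot at once), while per-token appends still go through the list. A rough sketch of that mixed access pattern, with invented shapes and indices:

# Sketch only; shapes, seq_slot and the number of accepted steps are made up.
import torch

steps, slots, beams = 5, 2, 1
new_tokens_tensor = torch.randint(0, 100, (steps, slots, beams))
new_tokens_list = new_tokens_tensor.tolist()

seq_slot = 1
# Bulk read: one slice yields every candidate target token for this slot.
all_target_tokens = new_tokens_tensor[:, seq_slot, :].squeeze(-1)   # shape [steps]

# Per-token appends: plain list indexing, no .item() round-trips.
accepted = [new_tokens_list[step][seq_slot][0] for step in range(3)]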
@@ -800,13 +806,15 @@ class TorchSampler(Sampler):
if longest_accepted_len == 0:
# No draft tokens are accepted.
# Take the top-1 token of the first layer as the next new token.
new_token = add_token(request, new_tokens, beam=0, step=0)
new_token = add_token(request, new_tokens_list, beam=0, step=0)
return 0
else:
# Take the longest accepted path as the next new token.
num_accepted_draft_tokens = 0
for idx in eagle_paths[longest_match_path_idx][:longest_accepted_len]:
new_token = add_token(request, new_tokens, beam=0, step=cast(int, idx.item()))
new_token = add_token(
request, new_tokens_list, beam=0, step=cast(int, idx.item())
)
num_accepted_draft_tokens += 1
if self._handle_stop_criteria(request, new_token):
break
@@ -876,8 +884,10 @@ class TorchSampler(Sampler):
def _process_draft_tokens_rejection_sampling(
self,
request: LlmRequest,
new_tokens: torch.Tensor,
new_tokens_list: list[list[list[int]]],
new_tokens_tensor: torch.Tensor,
) -> int:
assert request.py_draft_logits is not None
# FIXME: Passing a dummy vocab_size could result in unnecessary
# filtering of vocab_size logits, out of vocab_size in
# total. The 'sample' below should generally be avoided
@@ -893,7 +903,9 @@ class TorchSampler(Sampler):
request.py_draft_logits,
generator=generator,
)
assert draft_probs is not None
target_probs = request.py_target_probs
assert target_probs is not None
d2t = getattr(request, "d2t", None)
if d2t is not None:
vocab_d = draft_probs.shape[-1]
@@ -927,7 +939,7 @@ class TorchSampler(Sampler):
num_accepted = num_initially_accepted
for i in range(num_accepted):
new_token = request.py_draft_tokens[i]
new_tokens[i, request.seq_slot, self.BEAM] = new_token
new_tokens_tensor[i, request.seq_slot, self.BEAM] = new_token
request.add_new_token(new_token, self.BEAM)
stop = self._handle_stop_criteria(request, new_token)
if stop:
@@ -935,10 +947,10 @@ class TorchSampler(Sampler):
return num_accepted
if sample_last:
new_token = sample_rejected(draft_probs, target_probs, generator, num_accepted)
new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
new_tokens_tensor[num_accepted, request.seq_slot, self.BEAM] = new_token
request.add_new_token(new_token, self.BEAM)
else:
new_token = add_token(request, new_tokens, beam=self.BEAM, step=num_accepted)
new_token = add_token(request, new_tokens_list, beam=self.BEAM, step=num_accepted)
stop = self._handle_stop_criteria(request, new_token)
return num_accepted
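
The rejection-sampling path needs both views for a different reason: accepted draft tokens are written back into the tensor (so later consumers of new_tokens still see them), while tokens already sampled by the target model are read through the list. A hedged sketch of that split, with hypothetical indices and token values:

# Sketch only; seq_slot, BEAM and the token values are invented for illustration.
import torch

new_tokens_tensor = torch.zeros(3, 4, 1, dtype=torch.int64)
new_tokens_list = new_tokens_tensor.tolist()

seq_slot, BEAM = 2, 0
accepted_draft = [11, 12]                          # hypothetically accepted draft tokens
for i, tok in enumerate(accepted_draft):
    new_tokens_tensor[i, seq_slot, BEAM] = tok     # write path stays on the tensor

# Read path for the next (already-sampled) token uses the list view.
bonus = new_tokens_list[len(accepted_draft)][seq_slot][BEAM]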
@@ -946,7 +958,8 @@ class TorchSampler(Sampler):
def process_draft_tokens(
self,
request: LlmRequest,
new_tokens: torch.Tensor,
new_tokens_tensor: torch.Tensor,
new_tokens_list: list[list[list[int]]],
resource_manager: Optional[ResourceManager] = None,
) -> int:
if (
@@ -957,14 +970,19 @@ class TorchSampler(Sampler):
if spec_tree_manager is not None:
num_accepted = self._process_draft_tokens_tree(
request,
new_tokens=new_tokens,
new_tokens_tensor=new_tokens_tensor,
new_tokens_list=new_tokens_list,
spec_tree_manager=spec_tree_manager,
)
else:
num_accepted = self._process_draft_tokens_greedy(request, new_tokens=new_tokens)
num_accepted = self._process_draft_tokens_greedy(
request, new_tokens=new_tokens_list
)
return num_accepted
else:
return self._process_draft_tokens_rejection_sampling(request, new_tokens)
return self._process_draft_tokens_rejection_sampling(
request, new_tokens_list=new_tokens_list, new_tokens_tensor=new_tokens_tensor
)
@override
def update_requests(
@@ -976,7 +994,9 @@ class TorchSampler(Sampler):
if state.sampler_event:
state.sampler_event.synchronize()
assert state.host is not None
new_tokens = state.host.new_tokens
new_tokens_list = new_tokens.tolist()
for req in state.scheduled_requests.context_requests:
if (
@@ -984,7 +1004,7 @@ class TorchSampler(Sampler):
or req.context_remaining_length != 0
):
continue
new_token = add_token(req, new_tokens, beam=self.BEAM)
new_token = add_token(req, new_tokens_list, beam=self.BEAM)
self._handle_stop_criteria(req, new_token)
self.handle_logprobs(req, state, beam=self.BEAM, count=1)
req.py_decoding_iter += 1
@@ -993,7 +1013,12 @@ class TorchSampler(Sampler):
if req.state == LlmRequestState.GENERATION_COMPLETE:
continue
processed = 1
num_accepted = self.process_draft_tokens(req, new_tokens, resource_manager)
num_accepted = self.process_draft_tokens(
req,
new_tokens_tensor=new_tokens,
new_tokens_list=new_tokens_list,
resource_manager=resource_manager,
)
if get_draft_token_length(req) > 0:
req.py_num_accepted_draft_tokens = num_accepted
req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
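
update_requests now converts the host tensor to a nested list once per call and reuses it for every request, so the per-token cost drops from a tensor .item() call to plain list indexing. A tiny micro-benchmark sketch of why this matters (timings and shapes are arbitrary, not measurements of TRT-LLM):

# Illustration only; the shapes below are invented.
import time
import torch

new_tokens = torch.randint(0, 1000, (8, 256, 1))     # [steps][seq_slots][beams]

t0 = time.perf_counter()
a = [int(new_tokens[s, q, 0].item()) for s in range(8) for q in range(256)]
t1 = time.perf_counter()

lst = new_tokens.tolist()                            # one conversion per update
b = [lst[s][q][0] for s in range(8) for q in range(256)]
t2 = time.perf_counter()

assert a == b
print(f"per-element .item(): {t1 - t0:.6f}s, tolist() once: {t2 - t1:.6f}s")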
@@ -1911,7 +1936,7 @@ class TRTLLMSampler(Sampler):
state: SampleStateTRTLLM,
beam_width: int,
):
new_tokens_host = state.host.new_tokens
new_tokens_host = state.host.new_tokens.tolist()
finished_sum_host = state.host.finished_sum.tolist()
finish_reasons = state.host.finish_reasons.flatten().tolist()
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()


@@ -256,7 +256,7 @@ class MTPSampler(TorchSampler):
assert isinstance(state, SampleStateMTP)
state.sampler_event.synchronize()
new_tokens = state.host.new_tokens
new_tokens = state.host.new_tokens.tolist()
new_tokens_lens_list = state.host.new_tokens_lens.tolist()
next_draft_tokens_list = state.host.next_draft_tokens.tolist()
beam_idx = self.BEAM


@@ -45,9 +45,11 @@ def run_test(eagle_model_dir, max_seq_len, beam_width, use_dynamic_tree,
max_beam_width=beam_width,
))
input_new_tokens_list = input_new_tokens.tolist()
num_accepted_draft_tokens = torch_sampler._process_draft_tokens_tree(
request=input_request,
new_tokens=input_new_tokens,
new_tokens_tensor=input_new_tokens,
new_tokens_list=input_new_tokens_list,
spec_tree_manager=spec_tree_manager)
print(f"num_accepted_draft_tokens: {num_accepted_draft_tokens}")


@@ -1,4 +1,5 @@
import unittest
from typing import cast
import numpy as np
import torch
@@ -24,8 +25,11 @@ def test_get_rejected_indices():
sampled_regular = []
for _ in range(num_iter):
draft_tokens = [
torch.multinomial(draft_probs, num_samples=1,
generator=generator).item()
cast(
int,
torch.multinomial(draft_probs,
num_samples=1,
generator=generator).item())
]
rejected_indices = get_rejected_indices(draft_probs, target_probs,
generator, draft_tokens)
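
The test change above wraps .item() in cast(int, ...). This is purely for the type checker: Tensor.item() is typed as returning a generic Python number as far as the stubs are concerned, so cast narrows it to int without affecting runtime behaviour. A small illustration, assuming the test suite is type-checked (e.g. with mypy):

# Sketch; the probabilities and the use of mypy are assumptions for illustration.
from typing import cast
import torch

probs = torch.tensor([0.25, 0.75])
idx = cast(int, torch.multinomial(probs, num_samples=1).item())
assert isinstance(idx, int)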