mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

[None][fix] restore list[list[list[int]]] in add_token (#8502)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>

commit 87eb5086fb (parent: 85d5aa7763)
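
This change restores the convention of materializing the sampler's host-side new-token buffer as nested Python lists once per iteration, so that add_token does plain list indexing instead of per-token tensor access. The sketch below is a standalone micro-benchmark of that trade-off with hypothetical shapes; it is illustrative only and not code from the repository.

import time

import torch

# Hypothetical host buffer shaped [steps, seq_slots, beams].
new_tokens = torch.randint(0, 32000, (8, 256, 1))

# Tensor path: one .item() call per token read.
start = time.perf_counter()
for step in range(8):
    for slot in range(256):
        _ = int(new_tokens[step][slot][0].item())
tensor_s = time.perf_counter() - start

# List path: convert once, then plain nested-list indexing.
start = time.perf_counter()
new_tokens_list = new_tokens.tolist()
for step in range(8):
    for slot in range(256):
        _ = new_tokens_list[step][slot][0]
list_s = time.perf_counter() - start

print(f"tensor indexing: {tensor_s:.4f}s, list indexing (incl. tolist): {list_s:.4f}s")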
@@ -290,10 +290,13 @@ def _group_requests_by_strategy_key(
 }
 
 
-def add_token(request: LlmRequest, new_tokens: torch.Tensor, *, beam: int, step: int = 0) -> int:
+def add_token(
+    request: LlmRequest, new_tokens: list[list[list[int]]], *, beam: int, step: int = 0
+) -> int:
+    # NB: Accessing nested lists faster than torch.Tensor or numpy.ndarray
     seq_slot = request.py_seq_slot
     assert seq_slot is not None
-    new_token = cast(int, new_tokens[step][seq_slot][beam].item())
+    new_token = new_tokens[step][seq_slot][beam]
     request.add_new_token(new_token, beam)
     return new_token
 
@@ -700,7 +703,7 @@ class TorchSampler(Sampler):
     def _process_draft_tokens_greedy(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens: list[list[list[int]]],
     ) -> int:
         new_token = add_token(request, new_tokens, beam=self.BEAM)
         stop = self._handle_stop_criteria(request, new_token)
@@ -722,7 +725,8 @@ class TorchSampler(Sampler):
     def _process_draft_tokens_tree(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_tensor: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
         spec_tree_manager: SpecTreeManager,
     ) -> int:
         """Tree verification for draft token tree based speculative decoding.
@@ -757,7 +761,7 @@ class TorchSampler(Sampler):
             # TODO: For the last layer of the dynamic tree, we need to resampling all the draft tokens.
             cur_layer_num_nodes = sum(spec_tree_manager.get_top_k_list(cur_draft_layer_idx))
             for i in range(cur_layer_num_nodes):
-                new_token = add_token(request, new_tokens, beam=0, step=i)
+                new_token = add_token(request, new_tokens_list, beam=0, step=i)
             return 0
         else:
             # handle the target model request
@@ -767,7 +771,9 @@ class TorchSampler(Sampler):
             eagle_paths = spec_tree_manager.get_eagle_paths(seq_slot)
 
             all_draft_tokens = request.py_draft_tokens  # [max_total_draft_tokens]
-            all_target_tokens = new_tokens[:, seq_slot, :].squeeze(-1)  # [max_total_draft_tokens]
+            all_target_tokens = new_tokens_tensor[:, seq_slot, :].squeeze(
+                -1
+            )  # [max_total_draft_tokens]
             assert all_target_tokens.shape[0] == spec_tree_manager.max_total_draft_tokens + 1
 
             longest_accepted_len = 0
@@ -800,13 +806,15 @@ class TorchSampler(Sampler):
             if longest_accepted_len == 0:
                 # No draft tokens are accepted.
                 # Take the top-1 token of the first layer as the next new token.
-                new_token = add_token(request, new_tokens, beam=0, step=0)
+                new_token = add_token(request, new_tokens_list, beam=0, step=0)
                 return 0
             else:
                 # Take the longest accepted path as the next new token.
                 num_accepted_draft_tokens = 0
                 for idx in eagle_paths[longest_match_path_idx][:longest_accepted_len]:
-                    new_token = add_token(request, new_tokens, beam=0, step=cast(int, idx.item()))
+                    new_token = add_token(
+                        request, new_tokens_list, beam=0, step=cast(int, idx.item())
+                    )
                     num_accepted_draft_tokens += 1
                     if self._handle_stop_criteria(request, new_token):
                         break
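
With this change _process_draft_tokens_tree receives two views of the same buffer: new_tokens_tensor for vectorized slicing of a sequence slot's target tokens, and new_tokens_list for the per-step add_token calls. The toy sketch below illustrates the dual-view idea; the names and shapes are assumptions, not taken from the repository.

import torch

# Hypothetical buffer shaped [max_total_draft_tokens + 1, num_seq_slots, beams].
new_tokens_tensor = torch.randint(0, 32000, (5, 4, 1))
new_tokens_list = new_tokens_tensor.tolist()  # same data as nested Python lists

seq_slot = 2
# Tensor view: convenient for comparing a whole candidate path at once.
all_target_tokens = new_tokens_tensor[:, seq_slot, :].squeeze(-1)  # shape [5]

# List view: cheap scalar reads for the accepted steps, no per-token .item() calls.
accepted_steps = [0, 3]
accepted_tokens = [new_tokens_list[step][seq_slot][0] for step in accepted_steps]
print(all_target_tokens.tolist(), accepted_tokens)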
@@ -876,8 +884,10 @@ class TorchSampler(Sampler):
     def _process_draft_tokens_rejection_sampling(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
+        new_tokens_tensor: torch.Tensor,
     ) -> int:
+        assert request.py_draft_logits is not None
         # FIXME: Passing a dummy vocab_size could result in unnecessary
         # filtering of vocab_size logits, out of vocab_size in
         # total. The 'sample' below should generally be avoided
@@ -893,7 +903,9 @@ class TorchSampler(Sampler):
             request.py_draft_logits,
             generator=generator,
         )
+        assert draft_probs is not None
         target_probs = request.py_target_probs
+        assert target_probs is not None
         d2t = getattr(request, "d2t", None)
         if d2t is not None:
             vocab_d = draft_probs.shape[-1]
@@ -927,7 +939,7 @@ class TorchSampler(Sampler):
         num_accepted = num_initially_accepted
         for i in range(num_accepted):
             new_token = request.py_draft_tokens[i]
-            new_tokens[i, request.seq_slot, self.BEAM] = new_token
+            new_tokens_tensor[i, request.seq_slot, self.BEAM] = new_token
             request.add_new_token(new_token, self.BEAM)
             stop = self._handle_stop_criteria(request, new_token)
             if stop:
@@ -935,10 +947,10 @@ class TorchSampler(Sampler):
                 return num_accepted
         if sample_last:
             new_token = sample_rejected(draft_probs, target_probs, generator, num_accepted)
-            new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
+            new_tokens_tensor[num_accepted, request.seq_slot, self.BEAM] = new_token
             request.add_new_token(new_token, self.BEAM)
         else:
-            new_token = add_token(request, new_tokens, beam=self.BEAM, step=num_accepted)
+            new_token = add_token(request, new_tokens_list, beam=self.BEAM, step=num_accepted)
         stop = self._handle_stop_criteria(request, new_token)
 
         return num_accepted
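
The rejection-sampling path keeps both views as well: accepted draft tokens are written into new_tokens_tensor in place, while add_token reads from new_tokens_list. One relevant detail, shown in the toy example below, is that tolist() produces a copy, so tensor writes made after the conversion are not visible through the list snapshot; reading the split this way is my own interpretation rather than something stated in the commit.

import torch

new_tokens_tensor = torch.zeros((3, 2, 1), dtype=torch.long)
new_tokens_list = new_tokens_tensor.tolist()  # a copy, not a view

# An in-place tensor write (as done for accepted draft tokens) does not show up
# in the list snapshot taken earlier.
new_tokens_tensor[0, 1, 0] = 42
print(new_tokens_tensor[0, 1, 0].item())  # 42
print(new_tokens_list[0][1][0])           # still 0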
@@ -946,7 +958,8 @@ class TorchSampler(Sampler):
     def process_draft_tokens(
         self,
         request: LlmRequest,
-        new_tokens: torch.Tensor,
+        new_tokens_tensor: torch.Tensor,
+        new_tokens_list: list[list[list[int]]],
         resource_manager: Optional[ResourceManager] = None,
     ) -> int:
         if (
@@ -957,14 +970,19 @@ class TorchSampler(Sampler):
             if spec_tree_manager is not None:
                 num_accepted = self._process_draft_tokens_tree(
                     request,
-                    new_tokens=new_tokens,
+                    new_tokens_tensor=new_tokens_tensor,
+                    new_tokens_list=new_tokens_list,
                     spec_tree_manager=spec_tree_manager,
                 )
             else:
-                num_accepted = self._process_draft_tokens_greedy(request, new_tokens=new_tokens)
+                num_accepted = self._process_draft_tokens_greedy(
+                    request, new_tokens=new_tokens_list
+                )
             return num_accepted
         else:
-            return self._process_draft_tokens_rejection_sampling(request, new_tokens)
+            return self._process_draft_tokens_rejection_sampling(
+                request, new_tokens_list=new_tokens_list, new_tokens_tensor=new_tokens_tensor
+            )
 
     @override
     def update_requests(
@@ -976,7 +994,9 @@ class TorchSampler(Sampler):
         if state.sampler_event:
             state.sampler_event.synchronize()
 
+        assert state.host is not None
         new_tokens = state.host.new_tokens
+        new_tokens_list = new_tokens.tolist()
 
         for req in state.scheduled_requests.context_requests:
             if (
@@ -984,7 +1004,7 @@ class TorchSampler(Sampler):
                 or req.context_remaining_length != 0
             ):
                 continue
-            new_token = add_token(req, new_tokens, beam=self.BEAM)
+            new_token = add_token(req, new_tokens_list, beam=self.BEAM)
             self._handle_stop_criteria(req, new_token)
             self.handle_logprobs(req, state, beam=self.BEAM, count=1)
             req.py_decoding_iter += 1
@@ -993,7 +1013,12 @@ class TorchSampler(Sampler):
             if req.state == LlmRequestState.GENERATION_COMPLETE:
                 continue
             processed = 1
-            num_accepted = self.process_draft_tokens(req, new_tokens, resource_manager)
+            num_accepted = self.process_draft_tokens(
+                req,
+                new_tokens_tensor=new_tokens,
+                new_tokens_list=new_tokens_list,
+                resource_manager=resource_manager,
+            )
             if get_draft_token_length(req) > 0:
                 req.py_num_accepted_draft_tokens = num_accepted
                 req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
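
In update_requests the list view is produced once per sampler sync, after the sampler event has been synchronized, and both views are then passed down to process_draft_tokens. Below is a minimal sketch of that calling pattern with simplified stand-ins; the real sampler state, request bookkeeping, and method signatures are omitted.

import torch


def process_draft_tokens(new_tokens_tensor: torch.Tensor,
                         new_tokens_list: list[list[list[int]]]) -> int:
    # Downstream code picks whichever view suits it: the tensor for vectorized
    # slices, the nested list for cheap scalar reads.
    return new_tokens_list[0][0][0]


# Hypothetical host buffer shaped [steps, seq_slots, beams].
host_new_tokens = torch.randint(0, 32000, (2, 3, 1))
new_tokens_list = host_new_tokens.tolist()  # convert once per sampler sync
token = process_draft_tokens(new_tokens_tensor=host_new_tokens,
                             new_tokens_list=new_tokens_list)
print(token)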
@@ -1911,7 +1936,7 @@ class TRTLLMSampler(Sampler):
         state: SampleStateTRTLLM,
         beam_width: int,
     ):
-        new_tokens_host = state.host.new_tokens
+        new_tokens_host = state.host.new_tokens.tolist()
         finished_sum_host = state.host.finished_sum.tolist()
         finish_reasons = state.host.finish_reasons.flatten().tolist()
         sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
@@ -256,7 +256,7 @@ class MTPSampler(TorchSampler):
         assert isinstance(state, SampleStateMTP)
 
         state.sampler_event.synchronize()
-        new_tokens = state.host.new_tokens
+        new_tokens = state.host.new_tokens.tolist()
         new_tokens_lens_list = state.host.new_tokens_lens.tolist()
         next_draft_tokens_list = state.host.next_draft_tokens.tolist()
         beam_idx = self.BEAM
@@ -45,9 +45,11 @@ def run_test(eagle_model_dir, max_seq_len, beam_width, use_dynamic_tree,
         max_beam_width=beam_width,
     ))
 
+    input_new_tokens_list = input_new_tokens.tolist()
     num_accepted_draft_tokens = torch_sampler._process_draft_tokens_tree(
         request=input_request,
-        new_tokens=input_new_tokens,
+        new_tokens_tensor=input_new_tokens,
+        new_tokens_list=input_new_tokens_list,
         spec_tree_manager=spec_tree_manager)
 
     print(f"num_accepted_draft_tokens: {num_accepted_draft_tokens}")
@@ -1,4 +1,5 @@
 import unittest
+from typing import cast
 
 import numpy as np
 import torch
@@ -24,8 +25,11 @@ def test_get_rejected_indices():
     sampled_regular = []
     for _ in range(num_iter):
         draft_tokens = [
-            torch.multinomial(draft_probs, num_samples=1,
-                              generator=generator).item()
+            cast(
+                int,
+                torch.multinomial(draft_probs,
+                                  num_samples=1,
+                                  generator=generator).item())
         ]
         rejected_indices = get_rejected_indices(draft_probs, target_probs,
                                                 generator, draft_tokens)
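
The test now wraps the multinomial draw in cast(int, ...) because Tensor.item() is typed as returning a generic Python number; the cast narrows the type for static checkers and has no runtime effect. A small self-contained illustration:

from typing import cast

import torch

probs = torch.tensor([0.1, 0.7, 0.2])
# Tensor.item() is annotated as returning a plain Python number; cast(int, ...) tells
# the type checker this particular value is an int, without changing runtime behavior.
token = cast(int, torch.multinomial(probs, num_samples=1).item())
assert isinstance(token, int)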