Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-13 06:23:57 +08:00
[None][feat] Use list instead of torch tensor for new tokens in update requests (#7730)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
commit 9f1d9b7b18
parent 6a36349964
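The change swaps per-element reads from a host torch.Tensor for plain Python list indexing: the host buffer is converted once with .tolist(), after which every token read in the request-update path is a native list access rather than a tensor indexing call that crosses into ATen and back. A minimal sketch of the difference, with illustrative shapes that are not taken from the commit:

import timeit

import torch

# Illustrative host buffer, indexed [step, seq_slot, beam] as in the sampler.
host = torch.randint(0, 32000, (4, 64, 1))

def read_tensor() -> int:
    return int(host[2, 17, 0])  # tensor indexing + item conversion on every read

tokens = host.tolist()  # convert once per update

def read_list() -> int:
    return tokens[2][17][0]  # plain nested-list access, already a Python int

print(timeit.timeit(read_tensor, number=10_000))
print(timeit.timeit(read_list, number=10_000))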
@@ -340,13 +340,13 @@ def sample(strategy: Strategy,
 
 
 def add_token(request: LlmRequest,
-              new_tokens: torch.Tensor,
+              new_tokens: list[list[list[int]]],
               *,
               beam: int,
               step: int = 0) -> int:
     seq_slot = request.py_seq_slot
     assert seq_slot is not None
-    new_token = int(new_tokens[step, seq_slot, beam])
+    new_token = new_tokens[step][seq_slot][beam]
     request.add_new_token(new_token, beam)
     return new_token
 
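For reference, a toy invocation of the new add_token signature. The stub request below is hypothetical and models only the two members the hunk shows add_token touching (py_seq_slot and add_new_token); add_token itself is copied from the updated version above so the snippet runs standalone.

class StubRequest:
    # Hypothetical stand-in for LlmRequest; only the members used by
    # add_token are modeled.
    def __init__(self, seq_slot: int):
        self.py_seq_slot = seq_slot
        self.tokens: list[int] = []

    def add_new_token(self, token: int, beam: int) -> None:
        self.tokens.append(token)

def add_token(request, new_tokens: list[list[list[int]]], *,
              beam: int, step: int = 0) -> int:
    # Same body as the updated function above.
    seq_slot = request.py_seq_slot
    assert seq_slot is not None
    new_token = new_tokens[step][seq_slot][beam]
    request.add_new_token(new_token, beam)
    return new_token

new_tokens = [[[11], [22]], [[33], [44]]]  # indexed [step][seq_slot][beam]
req = StubRequest(seq_slot=1)
assert add_token(req, new_tokens, beam=0, step=0) == 22
assert req.tokens == [22]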
@@ -513,7 +513,7 @@ class TorchSampler(Sampler):
             for i in range(num_accepted):
                 new_token = request.py_draft_tokens[i]
-                new_tokens[i, request.seq_slot, self.BEAM] = new_token
-                request.add_new_token(new_token, self.BEAM)
+                new_tokens[i][request.seq_slot][self.BEAM] = new_token
+                new_token = add_token(request, new_tokens, beam=self.BEAM, step=i)
                 stop = self._handle_stop_criteria(request, new_token)
                 if stop:
                     num_accepted = i + 1
@@ -522,14 +522,11 @@ class TorchSampler(Sampler):
                 new_token = sample_rejected(draft_probs, target_probs, generator,
                                             num_accepted)
-                new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
-                request.add_new_token(new_token, self.BEAM)
-                stop = self._handle_stop_criteria(request, new_token)
-            else:
-                new_token = add_token(request,
-                                      new_tokens,
-                                      beam=self.BEAM,
-                                      step=num_accepted)
-                stop = self._handle_stop_criteria(request, new_token)
+                new_tokens[num_accepted][request.seq_slot][self.BEAM] = new_token
+            new_token = add_token(request,
+                                  new_tokens,
+                                  beam=self.BEAM,
+                                  step=num_accepted)
+            stop = self._handle_stop_criteria(request, new_token)
 
         return num_accepted
 
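Worth noting in the hunk above: the corrected token returned by sample_rejected is written back into the nested list before add_token re-reads it, so the accepted and rejected paths now funnel through a single read-and-append sequence instead of two. A toy illustration of that write-back invariant, with made-up values:

new_tokens = [[[7]], [[9]]]  # indexed [step][seq_slot][beam]
corrected = 5                # stand-in for the sample_rejected(...) result
step, seq_slot, beam = 1, 0, 0
new_tokens[step][seq_slot][beam] = corrected
# add_token(request, new_tokens, beam=beam, step=step) would now pick up
# the corrected token rather than the stale sampled one.
assert new_tokens[step][seq_slot][beam] == corrected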
@@ -545,7 +542,8 @@ class TorchSampler(Sampler):
         assert isinstance(state, SampleState)
         if state.sampler_event:
             state.sampler_event.synchronize()
-        new_tokens = state.host.new_tokens
 
+        new_tokens = state.host.new_tokens.tolist()
+
         for req in state.scheduled_requests.context_requests:
             if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
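The same conversion pattern repeats in MTPSampler below: wait on the sampler event, call .tolist() once, and let the per-request update loops run over nested Python lists. A rough sketch of that shape, with hypothetical buffer sizes:

import torch

# Hypothetical [steps, seq_slots, beams] host buffer, converted once
# after the device-to-host copy has been synchronized.
host_new_tokens = torch.randint(0, 32000, (3, 8, 1))
new_tokens = host_new_tokens.tolist()
for seq_slot in range(8):
    token = new_tokens[0][seq_slot][0]  # plain int, no tensor dispatch
    assert isinstance(token, int)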
@@ -250,7 +250,7 @@ class MTPSampler(TorchSampler):
         assert isinstance(state, SampleStateMTP)
 
         state.sampler_event.synchronize()
-        new_tokens = state.host.new_tokens
+        new_tokens = state.host.new_tokens.tolist()
         new_tokens_lens_list = state.host.new_tokens_lens.tolist()
         next_draft_tokens_list = state.host.next_draft_tokens.tolist()
         beam_idx = self.BEAM