[None][feat] Use list instead of torch tensor for new tokens in update requests (#7730)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Daniel Cámpora, 2025-09-23 16:40:08 +02:00, committed by GitHub
parent 6a36349964
commit 9f1d9b7b18
2 changed files with 11 additions and 13 deletions
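
The change replaces per-element reads from a host-side torch.Tensor with plain Python list indexing: int(t[i, j, k]) pays tensor dispatch overhead on every access, while a single .tolist() converts the buffer once and makes every subsequent lookup a cheap nested-list index. A minimal sketch of the two access patterns (buffer shape and names are illustrative, following the [step][seq_slot][beam] layout used by add_token below):

    import torch

    # Host-side buffer of sampled tokens: [max_steps, max_seq_slots, beam_width]
    new_tokens_tensor = torch.randint(0, 32000, (4, 64, 1))

    # Before: every access dispatches through torch and materializes a scalar.
    tok = int(new_tokens_tensor[0, 5, 0])

    # After: convert once, then index ordinary nested Python lists.
    new_tokens: list[list[list[int]]] = new_tokens_tensor.tolist()
    tok = new_tokens[0][5][0]  # already a Python int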

@@ -340,13 +340,13 @@ def sample(strategy: Strategy,
 
 def add_token(request: LlmRequest,
-              new_tokens: torch.Tensor,
+              new_tokens: list[list[list[int]]],
               *,
               beam: int,
               step: int = 0) -> int:
     seq_slot = request.py_seq_slot
     assert seq_slot is not None
-    new_token = int(new_tokens[step, seq_slot, beam])
+    new_token = new_tokens[step][seq_slot][beam]
     request.add_new_token(new_token, beam)
     return new_token
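
With this hunk applied, add_token is pure Python: callers pass the already-converted nested list, and the int(...) cast is gone because .tolist() already yields Python ints. A self-contained sketch with a stand-in request object (FakeRequest is illustrative; the real LlmRequest is defined elsewhere in the codebase):

    class FakeRequest:
        """Illustrative stand-in for LlmRequest."""

        def __init__(self, seq_slot: int):
            self.py_seq_slot = seq_slot
            self.tokens: list[int] = []

        def add_new_token(self, token: int, beam: int) -> None:
            assert beam == 0  # single-beam sketch
            self.tokens.append(token)


    def add_token(request, new_tokens: list[list[list[int]]], *,
                  beam: int, step: int = 0) -> int:
        seq_slot = request.py_seq_slot
        assert seq_slot is not None
        new_token = new_tokens[step][seq_slot][beam]
        request.add_new_token(new_token, beam)
        return new_token


    req = FakeRequest(seq_slot=5)
    new_tokens = [[[0] for _ in range(8)] for _ in range(2)]  # [step][slot][beam]
    new_tokens[0][5][0] = 1234
    assert add_token(req, new_tokens, beam=0, step=0) == 1234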
@@ -513,7 +513,7 @@ class TorchSampler(Sampler):
         for i in range(num_accepted):
-            new_token = request.py_draft_tokens[i]
-            new_tokens[i, request.seq_slot, self.BEAM] = new_token
-            request.add_new_token(new_token, self.BEAM)
+            new_token = add_token(request, new_tokens, beam=self.BEAM, step=i)
             stop = self._handle_stop_criteria(request, new_token)
             if stop:
                 num_accepted = i + 1
@@ -522,14 +522,11 @@ class TorchSampler(Sampler):
             new_token = sample_rejected(draft_probs, target_probs, generator,
                                         num_accepted)
-            new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
-            request.add_new_token(new_token, self.BEAM)
-            stop = self._handle_stop_criteria(request, new_token)
-        else:
-            new_token = add_token(request,
-                                  new_tokens,
-                                  beam=self.BEAM,
-                                  step=num_accepted)
-            stop = self._handle_stop_criteria(request, new_token)
+        new_token = add_token(request,
+                              new_tokens,
+                              beam=self.BEAM,
+                              step=num_accepted)
+        stop = self._handle_stop_criteria(request, new_token)
         return num_accepted
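
sample_rejected itself is not part of this diff. In standard speculative-decoding rejection sampling, the replacement token at the first rejected position is drawn from the residual distribution max(0, p_target - p_draft), renormalized over the vocabulary; the sketch below shows that generic step, not necessarily the exact implementation here:

    import torch

    def sample_rejected_sketch(draft_probs: torch.Tensor,
                               target_probs: torch.Tensor,
                               generator: torch.Generator,
                               step: int) -> int:
        # Residual distribution over the vocabulary at the rejected step.
        residual = (target_probs[step] - draft_probs[step]).clamp_(min=0)
        residual /= residual.sum()
        return int(torch.multinomial(residual, 1, generator=generator).item())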
@@ -545,7 +542,8 @@ class TorchSampler(Sampler):
         assert isinstance(state, SampleState)
         if state.sampler_event:
             state.sampler_event.synchronize()
-        new_tokens = state.host.new_tokens
+        new_tokens = state.host.new_tokens.tolist()
         for req in state.scheduled_requests.context_requests:
             if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
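
update_requests now pays the tensor-to-list conversion once per batch instead of once per token read. A rough way to observe the difference (machine-dependent timings; a sketch, not a benchmark from the commit):

    import timeit
    import torch

    t = torch.zeros((4, 256, 1), dtype=torch.int64)

    def per_element():
        return [int(t[s, i, 0]) for s in range(4) for i in range(256)]

    def convert_once():
        lst = t.tolist()
        return [lst[s][i][0] for s in range(4) for i in range(256)]

    print(timeit.timeit(per_element, number=100))
    print(timeit.timeit(convert_once, number=100))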

@@ -250,7 +250,7 @@ class MTPSampler(TorchSampler):
         assert isinstance(state, SampleStateMTP)
         state.sampler_event.synchronize()
-        new_tokens = state.host.new_tokens
+        new_tokens = state.host.new_tokens.tolist()
         new_tokens_lens_list = state.host.new_tokens_lens.tolist()
         next_draft_tokens_list = state.host.next_draft_tokens.tolist()
         beam_idx = self.BEAM
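
With this change, all three host buffers on the MTP path are converted up front, so the per-request bookkeeping that follows (not shown in the hunk) touches only Python lists. A self-contained sketch of consuming them together, assuming from the names that new_tokens_lens counts accepted tokens per slot and next_draft_tokens holds the next round's draft tokens:

    from types import SimpleNamespace
    import torch

    # Illustrative host buffers: 2 steps, 1 seq slot, 1 beam.
    host = SimpleNamespace(
        new_tokens=torch.tensor([[[11]], [[22]]]),
        new_tokens_lens=torch.tensor([2]),
        next_draft_tokens=torch.tensor([[7, 8, 9]]),
    )

    new_tokens = host.new_tokens.tolist()
    new_tokens_lens_list = host.new_tokens_lens.tolist()
    next_draft_tokens_list = host.next_draft_tokens.tolist()

    slot, beam_idx = 0, 0
    accepted = [new_tokens[step][slot][beam_idx]
                for step in range(new_tokens_lens_list[slot])]
    assert accepted == [11, 22]
    assert next_draft_tokens_list[slot] == [7, 8, 9]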