fix: redrafter sampling (#3278)

* Fix redrafter sampling

Signed-off-by: Ivan Sorokin <isorokin@nvidia.com>

* Rename redrafter beam search var

Signed-off-by: Ivan Sorokin <isorokin@nvidia.com>

* Remove _beam_search_candidates_v0

Signed-off-by: Ivan Sorokin <isorokin@nvidia.com>

* Remove unused import

Signed-off-by: Ivan Sorokin <isorokin@nvidia.com>

---------

Signed-off-by: Ivan Sorokin <isorokin@nvidia.com>
Ivan Sorokin 2025-04-08 02:49:32 +03:00 committed by GitHub
parent ba019a43d6
commit d40fce474a
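The renames in the diff below track what the drafter beam search now returns for each drafted token: normalized log-probabilities (log_softmax output, as the deleted _beam_search_candidates_v0 path already computed) rather than raw drafter logits. A minimal standalone sketch of that quantity, in plain PyTorch with a hypothetical candidate_log_probs helper rather than the TensorRT-LLM graph ops edited below:

import torch
import torch.nn.functional as F


def candidate_log_probs(drafter_logits: torch.Tensor) -> torch.Tensor:
    # drafter_logits: [batch, num_beams, vocab] raw scores from the drafter head.
    # Normalizing over the vocabulary turns each beam's row into a proper
    # log-distribution, which is the quantity downstream sampling expects.
    return F.log_softmax(drafter_logits, dim=-1)  # [batch, num_beams, vocab]


log_p = candidate_log_probs(torch.randn(2, 3, 8))
assert torch.allclose(log_p.exp().sum(-1), torch.ones(2, 3), atol=1e-5)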

@@ -1,4 +1,3 @@
-import warnings
 from typing import Tuple
 import numpy as np
@@ -423,7 +422,7 @@ def _beam_search_candidates(prompt_state: Tensor, init_token: Tensor,
     state_shape = shape(context, cast_to_dtype=INT_DTYPE_STR)  # [bs, nb, H]
     state = expand(expand_dims(constant_to_tensor_(0.0, dtype=dtype), [0, 1]),
                    state_shape)  # [bs, nb, H]
-    logits_token_in_beam = None
+    log_p_token_in_beam = None
     candidate_length = beam_length - 1
     for _ in range(candidate_length):
         state = (
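Inside the loop this hunk opens (continued in the next hunk), the per-step bookkeeping is unchanged apart from the name: the distributions collected so far are re-gathered with the beam indices that survived this step's top-k, and the current step is then concatenated along dim 2. A rough plain-PyTorch equivalent of that accumulation, with toy shapes and torch.gather standing in for _gather_beams:

import torch

bs, nb, V, steps = 2, 3, 8, 4
acc = None  # becomes [bs, nb, i, V], playing the role of log_p_token_in_beam
for _ in range(steps):
    step_log_p = torch.randn(bs, nb, V).log_softmax(-1)  # this step, per source beam
    beam_idx = torch.randint(0, nb, (bs, nb))  # beams chosen by this step's top-k
    # gather the current step's rows for the surviving beams -> [bs, nb, 1, V]
    cur = torch.gather(step_log_p, 1,
                       beam_idx[..., None].expand(bs, nb, V)).unsqueeze(2)
    if acc is None:  # first iteration
        acc = cur
    else:
        # re-order the history with the same beam indices, then append this step
        acc = torch.gather(
            acc, 1, beam_idx[..., None, None].expand(bs, nb, acc.shape[2], V))
        acc = torch.cat([acc, cur], dim=2)  # [bs, nb, i+1, V]
print(acc.shape)  # torch.Size([2, 3, 4, 8])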
@@ -464,126 +463,22 @@ def _beam_search_candidates(prompt_state: Tensor, init_token: Tensor,
         state = _gather_beams(state, gather_indices, batch_size,
                               num_beams)  # [bs, nb, H]
-        cur_logits_token_in_beam = unsqueeze(
-            _gather_beams(logits_new_token, gather_indices, batch_size,
+        cur_log_p_token_in_beam = unsqueeze(
+            _gather_beams(log_p_new_token, gather_indices, batch_size,
                           num_beams), 2)  # [bs, nb, 1, V]
-        if logits_token_in_beam is None:  # first iteration
-            logits_token_in_beam = cur_logits_token_in_beam
+        if log_p_token_in_beam is None:  # first iteration
+            log_p_token_in_beam = cur_log_p_token_in_beam
         else:
-            logits_token_in_beam = concat(
+            log_p_token_in_beam = concat(
                 [
-                    _gather_beams(logits_token_in_beam, gather_indices,
+                    _gather_beams(log_p_token_in_beam, gather_indices,
                                   batch_size,
                                   num_beams),  # prev_top_logits [bs, nb, i, V]
-                    cur_logits_token_in_beam
+                    cur_log_p_token_in_beam
                 ],
                 dim=2)  # [bs, nb, i+1, V]
         last_tokens = top_token_ids  # [bs, nb]
-    return beams, logits_token_in_beam
-def _beam_search_candidates_v0(x: Tensor, init_token: Tensor,
-                               embedding: Embedding, drafter: Module,
-                               num_beams: int, beam_length: int,
-                               is_rnn: bool) -> Tuple[Tensor, Tensor]:
-    '''
-    x: [bs, H]
-    init_token: [bs]
-    Returns:
-        draft_tokens: (batch, num_beams, beam_length)
-            Draft tokens are appended at after the true token.
-        draft_probs: (batch, num_beams, beam_length - 1, vocab_size)
-            Probabilities for the draft_tokens.
-    '''
-    warnings.warn(
-        "This version of beam search is deprecated and will be removed in the future."
-    )
-    NEG_INF = -50000.0
-    batch_size = shape(x, 0, INT_DTYPE_STR)
-    vocab_size = embedding.num_embeddings
-    scores = constant(
-        numpy_array([0.0] + [NEG_INF] * (num_beams - 1),
-                    trt_dtype=x.dtype))  # [nb]
-    scores = expand(unsqueeze(scores, 0), concat([batch_size,
-                                                  num_beams]))  # [bs, nb]
-    x = _add_decoding_dim(x, num_beams)  # [bs, H] => [bs, nb, H]
-    if init_token.ndim() == 1:
-        init_token = unsqueeze(init_token, -1)  # [bs] => [bs, 1]
-    draft_tokens = _add_decoding_dim(
-        init_token, num_beams=num_beams)  # [bs, 1] => [bs, nb, 1]
-    last_tokens = squeeze(draft_tokens, -1)  # [bs, nb]
-    emb_shape = shape(x, cast_to_dtype=INT_DTYPE_STR)  # [bs, nb, H]
-    prev_embs = expand(
-        expand_dims(constant_to_tensor_(0.0, dtype=x.dtype), [0, 1]),
-        emb_shape)  # [bs, nb, H]
-    draft_probs = None
-    candidate_length = beam_length - 1
-    assert candidate_length > 0
-    for i in range(candidate_length):
-        # gather embeddings
-        cur_embs = embedding(last_tokens)
-        if is_rnn:
-            prev_embs = drafter.rnn_embed(cur_embs,
-                                          None if i == 0 else prev_embs)
-        else:
-            prev_embs = cur_embs + prev_embs  # [bs, nb, H]
-        h = concat([x, prev_embs], dim=-1)  # [bs, nb, 2H]
-        # run drafter
-        new_flat_logits = drafter(_flatten_decoding_dim(
-            h))  # [bs, nb, 2H] => [bs*nb, 2H] => [bs*nb, V]
-        new_flat_log_probs = log_softmax(new_flat_logits, dim=-1)  # [bs*nb, V]
-        # compute probabilities and flatten the beams for topk
-        candidate_log_probs = _unflatten_decoding_dim(
-            new_flat_log_probs, num_beams)  # [bs*nb, V] => [bs, nb, V]
-        log_probs = candidate_log_probs + unsqueeze(scores, 2)  # [bs, nb, V]
-        flat_log_probs = view(log_probs,
-                              concat([batch_size, num_beams * vocab_size
-                                      ]))  # [bs, nb, V] => [bs, nb*V]
-        # get topk choices from all beams
-        topk_log_probs, topk_indices = topk(
-            flat_log_probs, k=num_beams,
-            dim=-1)  # [bs, nb*V] => [bs, nb], [bs, nb]
-        # topk_ids = modulo(topk_indices, vocab_size)  # [bs. nb]
-        topk_beam_indices = floordiv(topk_indices, vocab_size)  # [bs, nb]
-        topk_ids = topk_indices - (topk_beam_indices * vocab_size
-                                   )  # topk_indices % vocab_size, [bs. nb]
-        # get the common indices to gather beams
-        batch_size = shape(x, 0, INT_DTYPE_STR)
-        gather_indices = _get_indices_for_gather_beams(batch_size,
-                                                       topk_beam_indices,
-                                                       num_beams)
-        # update running draft_tokens, embeddings, and logits
-        topk_tokens = _gather_beams(draft_tokens, gather_indices, batch_size,
-                                    num_beams)  # [bs, nb] OR [bs, nb, 1]
-        if topk_tokens.ndim() == 2:
-            topk_tokens = unsqueeze(topk_tokens, -1)  # [bs, nb, 1]
-        new_tokens = unsqueeze(topk_ids, -1)
-        draft_tokens = concat([topk_tokens, new_tokens], -1)  # [bs, nb, 1+_+1]
-        scores = topk_log_probs  # save the new scores
-        prev_embs = _gather_beams(prev_embs, gather_indices, batch_size,
-                                  num_beams)  # [bs, nb, H]
-        topk_probs = _gather_beams(candidate_log_probs, gather_indices,
-                                   batch_size, num_beams)  # [bs, nb, V]
-        topk_probs = unsqueeze(topk_probs, 2)  # [bs, nb, V] => [bs, nb, 1, V]
-        if draft_probs is None:
-            draft_probs = topk_probs  # [bs, nb, 1, V]
-        else:
-            draft_probs = _gather_beams(draft_probs, gather_indices, batch_size,
-                                        num_beams)  # [bs, nb, _, V]
-            draft_probs = concat([draft_probs, topk_probs],
-                                 dim=2)  # => [bs, nb, _+1, V]
-        # mark the new draft_tokens as last for the next iteration
-        last_tokens = topk_ids  # [bs, nb]
-    return draft_tokens, draft_probs  # draft_tokens [bs, nb, bl], draft_logits [bs, nb, bl-1, V]
+    return beams, log_p_token_in_beam
 def _top_1_logits(logits: Tensor, NINF=-50000.0) -> Tensor:
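For reference, the core selection step shared by the surviving _beam_search_candidates and the deleted _v0: add the running per-beam scores, fold the beam and vocab axes together, take a single top-k per batch row, then split each flat index back into a source-beam index (floordiv) and a token id (the subtraction that stands in for a modulo). A small plain-PyTorch sketch with toy sizes:

import torch

bs, nb, V = 2, 3, 8
# only beam 0 is live at the first step, mirroring the [0.0, NEG_INF, ...] init
scores = torch.full((bs, nb), -50000.0)
scores[:, 0] = 0.0
step_log_p = torch.randn(bs, nb, V).log_softmax(-1)  # [bs, nb, V]

total = step_log_p + scores.unsqueeze(-1)   # [bs, nb, V]
flat = total.view(bs, nb * V)               # [bs, nb*V]
top_log_p, top_idx = flat.topk(nb, dim=-1)  # [bs, nb] each

beam_idx = top_idx // V            # which beam each winner extends
token_id = top_idx - beam_idx * V  # same as top_idx % V
print(beam_idx, token_id, top_log_p.shape)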