Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
fix: redrafter sampling (#3278)
* Fix redrafter sampling
* Rename redrafter beam search var
* Remove _beam_search_candidates_v0
* Remove unused import

Signed-off-by: Ivan Sorokin <isorokin@nvidia.com>
This commit is contained in:
parent ba019a43d6
commit d40fce474a
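As the rename suggests, the tensor that _beam_search_candidates returns alongside the beams holds per-step log-probabilities (the output of log_softmax), not raw logits, and the downstream sampling/acceptance step needs a normalized distribution. A minimal sketch of the distinction in plain NumPy (illustrative only, not code from this commit):

# Illustrative sketch (not from this commit): why the candidate tensor
# carries log-probabilities rather than raw logits.
import numpy as np

def log_softmax(logits, axis=-1):
    # Numerically stable log-softmax: log p = z - logsumexp(z).
    z = logits - logits.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

logits = np.array([2.0, 1.0, -1.0])  # unnormalized drafter outputs
log_p = log_softmax(logits)          # what log_p_token_in_beam stores
p = np.exp(log_p)
assert np.isclose(p.sum(), 1.0)      # a proper distribution for sampling
# Treating raw logits as log-probabilities would skew the acceptance test.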
@@ -1,4 +1,3 @@
-import warnings
 from typing import Tuple
 
 import numpy as np
@@ -423,7 +422,7 @@ def _beam_search_candidates(prompt_state: Tensor, init_token: Tensor,
     state_shape = shape(context, cast_to_dtype=INT_DTYPE_STR)  # [bs, nb, H]
     state = expand(expand_dims(constant_to_tensor_(0.0, dtype=dtype), [0, 1]),
                    state_shape)  # [bs, nb, H]
-    logits_token_in_beam = None
+    log_p_token_in_beam = None
     candidate_length = beam_length - 1
     for _ in range(candidate_length):
         state = (
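The hunk below carries the rename through the search loop. Per step, the loop keeps a [bs, nb, i+1, V] stack of per-step distributions that must stay aligned with the surviving beams, so it first re-gathers the previous stack by the winning parent-beam indices and only then appends the current step. A NumPy sketch of that gather-then-concat pattern (shapes and names are illustrative assumptions, not the commit's code):

import numpy as np

bs, nb, V = 2, 3, 5
prev = np.random.rand(bs, nb, 2, V)          # log p for steps 0..1, per beam
parent = np.random.randint(0, nb, (bs, nb))  # winning parent beam per slot
# Re-gather parent beams (roughly what _gather_beams does):
reordered = prev[np.arange(bs)[:, None], parent]  # [bs, nb, 2, V]
cur = np.random.rand(bs, nb, 1, V)           # this step's log p, per beam
acc = np.concatenate([reordered, cur], axis=2)    # [bs, nb, 3, V]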
@@ -464,126 +463,22 @@ def _beam_search_candidates(prompt_state: Tensor, init_token: Tensor,
         state = _gather_beams(state, gather_indices, batch_size,
                               num_beams)  # [bs, nb, H]
 
-        cur_logits_token_in_beam = unsqueeze(
-            _gather_beams(logits_new_token, gather_indices, batch_size,
+        cur_log_p_token_in_beam = unsqueeze(
+            _gather_beams(log_p_new_token, gather_indices, batch_size,
                           num_beams), 2)  # [bs, nb, 1, V]
-        if logits_token_in_beam is None:  # first iteration
-            logits_token_in_beam = cur_logits_token_in_beam
+        if log_p_token_in_beam is None:  # first iteration
+            log_p_token_in_beam = cur_log_p_token_in_beam
         else:
-            logits_token_in_beam = concat(
+            log_p_token_in_beam = concat(
                 [
-                    _gather_beams(logits_token_in_beam, gather_indices,
+                    _gather_beams(log_p_token_in_beam, gather_indices,
                                   batch_size,
                                   num_beams),  # prev_top_logits [bs, nb, i, V]
-                    cur_logits_token_in_beam
+                    cur_log_p_token_in_beam
                 ],
                 dim=2)  # [bs, nb, i+1, V]
         last_tokens = top_token_ids  # [bs, nb]
-    return beams, logits_token_in_beam
-
-
-def _beam_search_candidates_v0(x: Tensor, init_token: Tensor,
-                               embedding: Embedding, drafter: Module,
-                               num_beams: int, beam_length: int,
-                               is_rnn: bool) -> Tuple[Tensor, Tensor]:
-    '''
-    x: [bs, H]
-    init_token: [bs]
-
-    Returns:
-        draft_tokens: (batch, num_beams, beam_length)
-            Draft tokens are appended after the true token.
-
-        draft_probs: (batch, num_beams, beam_length - 1, vocab_size)
-            Probabilities for the draft_tokens.
-    '''
-    warnings.warn(
-        "This version of beam search is deprecated and will be removed in the future."
-    )
-    NEG_INF = -50000.0
-    batch_size = shape(x, 0, INT_DTYPE_STR)
-    vocab_size = embedding.num_embeddings
-    scores = constant(
-        numpy_array([0.0] + [NEG_INF] * (num_beams - 1),
-                    trt_dtype=x.dtype))  # [nb]
-    scores = expand(unsqueeze(scores, 0), concat([batch_size,
-                                                  num_beams]))  # [bs, nb]
-    x = _add_decoding_dim(x, num_beams)  # [bs, H] => [bs, nb, H]
-    if init_token.ndim() == 1:
-        init_token = unsqueeze(init_token, -1)  # [bs] => [bs, 1]
-    draft_tokens = _add_decoding_dim(
-        init_token, num_beams=num_beams)  # [bs, 1] => [bs, nb, 1]
-    last_tokens = squeeze(draft_tokens, -1)  # [bs, nb]
-    emb_shape = shape(x, cast_to_dtype=INT_DTYPE_STR)  # [bs, nb, H]
-    prev_embs = expand(
-        expand_dims(constant_to_tensor_(0.0, dtype=x.dtype), [0, 1]),
-        emb_shape)  # [bs, nb, H]
-    draft_probs = None
-    candidate_length = beam_length - 1
-    assert candidate_length > 0
-    for i in range(candidate_length):
-        # gather embeddings
-        cur_embs = embedding(last_tokens)
-        if is_rnn:
-            prev_embs = drafter.rnn_embed(cur_embs,
-                                          None if i == 0 else prev_embs)
-        else:
-            prev_embs = cur_embs + prev_embs  # [bs, nb, H]
-        h = concat([x, prev_embs], dim=-1)  # [bs, nb, 2H]
-
-        # run drafter
-        new_flat_logits = drafter(_flatten_decoding_dim(
-            h))  # [bs, nb, 2H] => [bs*nb, 2H] => [bs*nb, V]
-        new_flat_log_probs = log_softmax(new_flat_logits, dim=-1)  # [bs*nb, V]
-
-        # compute probabilities and flatten the beams for topk
-        candidate_log_probs = _unflatten_decoding_dim(
-            new_flat_log_probs, num_beams)  # [bs*nb, V] => [bs, nb, V]
-        log_probs = candidate_log_probs + unsqueeze(scores, 2)  # [bs, nb, V]
-        flat_log_probs = view(log_probs,
-                              concat([batch_size, num_beams * vocab_size
-                                      ]))  # [bs, nb, V] => [bs, nb*V]
-
-        # get topk choices from all beams
-        topk_log_probs, topk_indices = topk(
-            flat_log_probs, k=num_beams,
-            dim=-1)  # [bs, nb*V] => [bs, nb], [bs, nb]
-
-        # topk_ids = modulo(topk_indices, vocab_size)  # [bs, nb]
-        topk_beam_indices = floordiv(topk_indices, vocab_size)  # [bs, nb]
-        topk_ids = topk_indices - (topk_beam_indices * vocab_size
-                                   )  # topk_indices % vocab_size, [bs, nb]
-
-        # get the common indices to gather beams
-        batch_size = shape(x, 0, INT_DTYPE_STR)
-        gather_indices = _get_indices_for_gather_beams(batch_size,
-                                                       topk_beam_indices,
-                                                       num_beams)
-
-        # update running draft_tokens, embeddings, and logits
-        topk_tokens = _gather_beams(draft_tokens, gather_indices, batch_size,
-                                    num_beams)  # [bs, nb] OR [bs, nb, 1]
-        if topk_tokens.ndim() == 2:
-            topk_tokens = unsqueeze(topk_tokens, -1)  # [bs, nb, 1]
-        new_tokens = unsqueeze(topk_ids, -1)
-        draft_tokens = concat([topk_tokens, new_tokens], -1)  # [bs, nb, 1+_+1]
-        scores = topk_log_probs  # save the new scores
-        prev_embs = _gather_beams(prev_embs, gather_indices, batch_size,
-                                  num_beams)  # [bs, nb, H]
-        topk_probs = _gather_beams(candidate_log_probs, gather_indices,
-                                   batch_size, num_beams)  # [bs, nb, V]
-        topk_probs = unsqueeze(topk_probs, 2)  # [bs, nb, V] => [bs, nb, 1, V]
-        if draft_probs is None:
-            draft_probs = topk_probs  # [bs, nb, 1, V]
-        else:
-            draft_probs = _gather_beams(draft_probs, gather_indices,
-                                        batch_size, num_beams)  # [bs, nb, _, V]
-            draft_probs = concat([draft_probs, topk_probs],
-                                 dim=2)  # => [bs, nb, _+1, V]
-
-        # mark the new draft_tokens as last for the next iteration
-        last_tokens = topk_ids  # [bs, nb]
-    return draft_tokens, draft_probs  # draft_tokens [bs, nb, bl], draft_logits [bs, nb, bl-1, V]
+    return beams, log_p_token_in_beam
 
 
 def _top_1_logits(logits: Tensor, NINF=-50000.0) -> Tensor:
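For reference, the core trick in the removed _beam_search_candidates_v0 is ranking all num_beams * vocab_size continuations in one topk over the flattened beam axis, then splitting each winning index back into a (beam, token) pair; v0 spelled the modulo as topk_indices - topk_beam_indices * vocab_size. A self-contained NumPy sketch of that decomposition (illustrative only, not the commit's code):

import numpy as np

bs, nb, V = 2, 3, 7
log_probs = np.random.rand(bs, nb, V)          # per-beam, per-token scores
flat = log_probs.reshape(bs, nb * V)           # [bs, nb*V]: all continuations
topk_idx = np.argsort(-flat, axis=-1)[:, :nb]  # indices of the nb best
beam_idx = topk_idx // V                       # which beam each winner extends
token_id = topk_idx - beam_idx * V             # equivalently topk_idx % V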