[TRTLLM-10030][perf] avoid syncs in beam search + other improvements (#11349)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
This commit is contained in:
parent 2b60cc181c
commit 196d94a419
@@ -766,6 +766,13 @@ class BeamHistory:
cum_logprobs: torch.Tensor | None = None

BeamHistoryBuilder: TypeAlias = Callable[[], BeamHistory | None]
"""Builder for BeamHistory.

Used to defer possibly unnecessary host-tensor construction until update_requests().
"""

@dataclass(kw_only=True)
class SamplingRequestsMetadata:
req_num_generated_tokens: torch.Tensor
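Illustrative note (not part of the commit): the BeamHistoryBuilder indirection is a plain deferral pattern — sample_async() packages the work into a closure, and update_requests() decides later whether to run it at all. A minimal sketch of the idea, with invented names that have no relation to the actual classes:

from typing import Callable

def make_builder(need_history: Callable[[], bool],
                 build: Callable[[], dict]) -> Callable[[], dict | None]:
    # Nothing expensive happens here; both the decision and the work are deferred.
    def builder() -> dict | None:
        return build() if need_history() else None
    return builder

# Later, once the asynchronously produced inputs are known to be valid:
builder = make_builder(lambda: True, lambda: {"tokens": [[1, 2, 3]]})
print(builder())  # {'tokens': [[1, 2, 3]]}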
@@ -789,7 +796,7 @@ class SampleStateTensorsHostTorch(SampleStateTensors):

@dataclass(kw_only=True)
class SampleStateTorch(SampleState[SampleStateTensorsHostTorch, SampleStateTensors]):
beam_histories: list[BeamHistory | None] | None = None
beam_history_builders: list[BeamHistoryBuilder | None] | None = None

class AsyncWorkerMixin:

@@ -1249,9 +1256,6 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
self,
token_tensor: torch.Tensor,
logprobs_tensor: torch.Tensor,
sampled_log_probs_indices: torch.Tensor | None,
sampled_log_probs_vals: torch.Tensor | None,
sampled_log_probs_rank: torch.Tensor | None,
) -> list[list[dict[int, Logprob]]]:
"""Convert the logprobs tensor to a list of lists of dictionaries of Logprob objects

@@ -1260,9 +1264,6 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
args:
token_tensor: torch.Tensor. Shape: beam_width, num_tokens, num_logprobs
logprobs_tensor: torch.Tensor. Shape: beam_width, num_tokens, num_logprobs
sampled_log_probs_indices: torch.Tensor | None. Shape: num_tokens
sampled_log_probs_vals: torch.Tensor | None. Shape: num_tokens
sampled_log_probs_rank: torch.Tensor | None. Shape: num_tokens
output:
list[list[dict[int, Logprob]]]. Shape: (beam_width, num_tokens)
"""
@@ -1274,38 +1275,13 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
token_log_probs: list[list[dict[int, Logprob]]] = []
token_list = token_tensor.tolist()
logprobs_list = logprobs_tensor.tolist()
sampled_log_probs_indices_list: list[int] | None = None
sampled_log_probs_vals_list: list[float] | None = None
sampled_log_probs_rank_list: list[int] | None = None
if sampled_log_probs_indices is not None:
sampled_log_probs_indices_list = sampled_log_probs_indices.tolist()
assert sampled_log_probs_vals is not None, "sampled_log_probs_vals must be provided"
assert sampled_log_probs_rank is not None, "sampled_log_probs_rank must be provided"
sampled_log_probs_vals_list = sampled_log_probs_vals.tolist()
sampled_log_probs_rank_list = sampled_log_probs_rank.tolist()
for beam_idx in range(token_tensor.shape[0]):
beam_token_log_probs: list[dict[int, Logprob]] = []
for step_idx, (topk_token, topk_logprob) in enumerate(
zip(token_list[beam_idx], logprobs_list[beam_idx])
):
for topk_token, topk_logprob in zip(token_list[beam_idx], logprobs_list[beam_idx]):
logprobs = {
token: Logprob(logprob=logprob, rank=rank + 1)
for rank, (token, logprob) in enumerate(zip(topk_token, topk_logprob))
}
if sampled_log_probs_indices is not None:
assert beam_idx == DEFAULT_BEAM_IDX, (
"beam search does not need to explicitly handle sampled log probs"
)
assert sampled_log_probs_indices_list is not None
assert sampled_log_probs_vals_list is not None
assert sampled_log_probs_rank_list is not None
if sampled_log_probs_indices_list[step_idx] not in logprobs:
logprobs[sampled_log_probs_indices_list[step_idx]] = Logprob(
logprob=sampled_log_probs_vals_list[step_idx],
rank=max(
token_tensor.shape[2] + 1, sampled_log_probs_rank_list[step_idx]
),
)
beam_token_log_probs.append(logprobs)
token_log_probs.append(beam_token_log_probs)
@@ -1732,7 +1708,12 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
request, new_tokens_list=new_tokens_list, new_tokens_tensor=new_tokens_tensor
)

def _get_logprobs_from_request(self, request: LlmRequest) -> tuple[torch.Tensor, torch.Tensor]:
def _get_logprobs_from_request(
self,
request: LlmRequest,
pin_memory: bool = True,
preallocate_extra_steps: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Extract the logprobs from the request

Returns:

@@ -1743,25 +1724,27 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
assert request.py_num_logprobs == 0, (
"Beam search only supports returning the sampled logprob per token"
)
logprobs_tensor = torch.empty(
logprobs_tensor_full = torch.empty(
(
request.sampling_config.beam_width,
num_generated_tokens,
num_generated_tokens + preallocate_extra_steps,
request.py_num_logprobs + 1,
),
device="cuda",
pin_memory=pin_memory,
dtype=torch.float32,
)
logprobs_indices_tensor = torch.empty(
logprobs_indices_tensor_full = torch.empty(
(
request.sampling_config.beam_width,
num_generated_tokens,
num_generated_tokens + preallocate_extra_steps,
request.py_num_logprobs + 1,
),
device="cuda",
pin_memory=pin_memory,
dtype=torch.int32,
)
if hasattr(request.py_result._log_probs, "log_probs"):
logprobs_tensor = logprobs_tensor_full[:, :-preallocate_extra_steps, :]
logprobs_indices_tensor = logprobs_indices_tensor_full[:, :-preallocate_extra_steps, :]
if logprobs_tensor.numel() > 0:
logprobs_list = request.py_result.log_probs
assert logprobs_list is not None
for beam_idx, beam_logprobs in enumerate(logprobs_list):
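Illustrative note (not part of the commit): the new pin_memory and preallocate_extra_steps parameters support the pattern used later in the diff — the host buffer is allocated with room for one extra decoding step, existing logprobs are written through a sliced view, and the newest step is later filled in place instead of growing the tensor with torch.cat. A toy illustration of that layout, with invented sizes rather than the real helper:

import torch

beam_width, steps_so_far, topk = 2, 5, 1

# One extra step is reserved up front; pinned memory keeps later D2H copies async.
full = torch.empty((beam_width, steps_so_far + 1, topk),
                   dtype=torch.float32,
                   pin_memory=torch.cuda.is_available())

existing = full[:, :-1, :]           # view over the already-known steps
existing.fill_(0.5)                  # stands in for copying stored logprobs
full[:, -1, :] = torch.randn(beam_width, topk)  # newest step written in place, no cat
print(full.shape)  # torch.Size([2, 6, 1])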
@@ -1770,12 +1753,14 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
assert value.rank is not None
logprobs_tensor[beam_idx, token_idx, value.rank - 1] = value.logprob
logprobs_indices_tensor[beam_idx, token_idx, value.rank - 1] = key
return logprobs_tensor, logprobs_indices_tensor
return logprobs_tensor_full, logprobs_indices_tensor_full

def _create_beam_history(
def _prepare_beam_history(
self,
request: LlmRequest,
) -> BeamHistory | None:
*,
finish_reasons: torch.Tensor,
) -> BeamHistoryBuilder | None:
"""Correct the stored tokens for each beam and return it as a BeamHistory object.

Beam Search sampling only adds new tokens to the beam.

@@ -1784,9 +1769,30 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
If logprobs are requested, the function also corrects the stored logprobs for each beam.
The function returns a BeamHistory object that contains the corrected tokens and logprobs for each beam.

Note: To defer the decision whether or not to skip BeamHistory construction until update_requests(), only
a builder (BeamHistoryBuilder) is returned here. The builder contains host tensors which are
being populated asynchronously. Hence, it can only be invoked after async D2H copies have completed,
e.g., after awaiting state.sampler_event in update_requests.

arguments:
request: The request to create the beam history for
finish_reasons: The first finish reason encountered for each beam of the request.
Shape: (max_tokens, max_beam_width)
"""

# Gather data used for skipping beam history processing
need_finalize_due_to_stop_words = self._check_stop_words_length(request)
if need_finalize_due_to_stop_words:
need_history = torch.tensor(True)
else:
should_stop = self._check_beam_search_stop_criteria(
request,
finish_reasons=finish_reasons,
)
need_history = should_stop
# enqueue async D2H copy
need_history = self._copy_to_host(need_history)

num_tokens = request.max_beam_num_tokens + 1 # last token is not yet added
prompt_length = request.py_prompt_len
num_generated_tokens = num_tokens - prompt_length
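Illustrative note (not part of the commit): here the stop decision itself joins the asynchronous pipeline — need_history stays a tensor, is copied device-to-host without blocking, and is only read inside the builder. A hedged sketch of the ordering this relies on, assuming _copy_to_host enqueues a non_blocking copy into pinned memory on the sampling stream (the actual AsyncWorkerMixin implementation may differ):

import torch

if torch.cuda.is_available():
    stream = torch.cuda.current_stream()
    need_history_dev = torch.tensor(True, device="cuda")

    # producer side (sample_async): enqueue the copy, then record an event
    need_history_host = torch.empty((), dtype=torch.bool, pin_memory=True)
    need_history_host.copy_(need_history_dev, non_blocking=True)
    sampler_event = torch.cuda.Event()
    sampler_event.record(stream)

    # consumer side (update_requests): wait once, then host values are safe to read
    sampler_event.synchronize()
    print(need_history_host.item())  # True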
@@ -1795,73 +1801,98 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
if num_generated_tokens == 0 or request.state == LlmRequestState.GENERATION_COMPLETE:
# early return if no tokens have been generated yet or the request is already finished
return None

assert self.store.cache_indirection is not None
assert self.store.original_tokens is not None
assert self.store.sampled_log_probs is not None
cache_indirection = self.store.cache_indirection[
request.py_seq_slot, :num_beams, prompt_length:num_tokens
]
current_path = self.store.original_tokens[
request.py_seq_slot, :num_beams, prompt_length:num_tokens
]
new_path = torch.zeros_like(current_path)

# initialize each beam with its own index
# enqueue async D2H copies
cache_indirection = self._copy_to_host(cache_indirection)
current_path = self._copy_to_host(current_path)

def _post_process_path() -> torch.Tensor:
# Gather the correct tokens for each beam
new_path = torch.zeros_like(current_path)
torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path)
return new_path

# Gather the correct tokens and logprobs for each beam
torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path)
if request.py_return_log_probs:
assert self.store.sampled_log_probs is not None
assert self.store.cum_log_probs is not None
current_logprobs, current_logprobs_indices = self._get_logprobs_from_request(request)
# concatenate the newly generated logprobs and newly
# generated tokens to the current logprobs and logprobs indices
current_logprobs = torch.cat(
[
current_logprobs,
self.store.sampled_log_probs[request.py_seq_slot, :num_beams].view(-1, 1, 1),
],
dim=1,
)
current_logprobs_indices = torch.cat(
[
current_logprobs_indices,
self.store.new_tokens[0, request.py_seq_slot, :num_beams].view(-1, 1, 1),
],
dim=1,
)
# Initialize the buffers to store the results
new_logprobs = torch.zeros_like(current_logprobs)
new_logprobs_indices = torch.zeros_like(current_logprobs_indices)

cache_indirection_for_logprobs = cache_indirection.unsqueeze(-1).expand(
-1, -1, current_logprobs.shape[2]
)
torch.gather(
input=current_logprobs,
dim=0,
index=cache_indirection_for_logprobs,
out=new_logprobs,
)
torch.gather(
input=current_logprobs_indices,
dim=0,
index=cache_indirection_for_logprobs,
out=new_logprobs_indices,
sampled_log_probs = self.store.sampled_log_probs[request.py_seq_slot, :num_beams].view(
-1, 1
)
sampled_logprobs_indices = self.store.new_tokens[
0, request.py_seq_slot, :num_beams
].view(-1, 1)
cum_logprobs = self.store.cum_log_probs[request.py_seq_slot, :num_beams]

# enqueue async D2H copies
sampled_log_probs = self._copy_to_host(sampled_log_probs)
sampled_logprobs_indices = self._copy_to_host(sampled_logprobs_indices)
cum_logprobs = self._copy_to_host(cum_logprobs)

def _maybe_postprocess_logprobs() -> tuple[
torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
]:
# Gather the correct logprobs for each beam

current_logprobs, current_logprobs_indices = self._get_logprobs_from_request(
request, preallocate_extra_steps=1
)
# concatenate the newly generated logprobs and newly
# generated tokens to the current logprobs and logprobs indices
current_logprobs[:, -1, :].copy_(sampled_log_probs)
current_logprobs_indices[:, -1, :].copy_(sampled_logprobs_indices)

# Initialize the buffers to store the results
new_logprobs = torch.zeros_like(current_logprobs)
new_logprobs_indices = torch.zeros_like(current_logprobs_indices)

cache_indirection_for_logprobs = cache_indirection.unsqueeze(-1).expand(
-1, -1, current_logprobs.shape[2]
)
torch.gather(
input=current_logprobs,
dim=0,
index=cache_indirection_for_logprobs,
out=new_logprobs,
)
torch.gather(
input=current_logprobs_indices,
dim=0,
index=cache_indirection_for_logprobs,
out=new_logprobs_indices,
)
return new_logprobs, new_logprobs_indices, cum_logprobs

else:

def _maybe_postprocess_logprobs() -> tuple[
torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
]:
return None, None, None

def _builder() -> BeamHistory | None:
if not need_history.item():
return None

new_path = _post_process_path()
new_logprobs, new_logprobs_indices, cum_logprobs = _maybe_postprocess_logprobs()

return BeamHistory(
tokens=new_path,
logprobs=new_logprobs,
logprobs_indices=new_logprobs_indices,
cum_logprobs=cum_logprobs,
)
else:
return BeamHistory(
tokens=new_path,
logprobs=None,
logprobs_indices=None,
cum_logprobs=None,
)

return _builder

def _finalize_beam(
self,
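Illustrative note (not part of the commit): the torch.gather calls above reorder per-beam data according to the cache indirection — for every (beam, step) position, the indirection names the source beam whose token (or logprob) belongs there. A tiny CPU-only example with made-up values:

import torch

# current_path[beam, step]: tokens as originally written for each beam
current_path = torch.tensor([[10, 11, 12],
                             [20, 21, 22]])
# cache_indirection[beam, step]: source beam to read from at each step
cache_indirection = torch.tensor([[0, 0, 1],
                                  [0, 1, 1]])

new_path = torch.zeros_like(current_path)
torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path)
print(new_path)
# tensor([[10, 11, 22],
#         [10, 21, 22]])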
@@ -1897,23 +1928,19 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
f"Beam_history.cum_logprobs.shape[0] should equal beam width: \
{beam_history.cum_logprobs.shape[0]} != {beam_width}"
)
valid_tokens = (beam_history.tokens != BEAM_SEARCH_PAD_TOKEN).sum(dim=-1)
valid_tokens = (beam_history.tokens != BEAM_SEARCH_PAD_TOKEN).sum(dim=-1).tolist()
gen_token_list = []
gen_log_probs_list = []
for beam_idx in range(beam_width):
gen_token_list.append(beam_history.tokens[beam_idx, : valid_tokens[beam_idx]].tolist())
beam_valid_tokens = valid_tokens[beam_idx]
gen_token_list.append(beam_history.tokens[beam_idx, :beam_valid_tokens].tolist())
if request.py_return_log_probs:
assert beam_history.logprobs_indices is not None
assert beam_history.logprobs is not None
gen_log_probs_list.append(
self._convert_logprobs_tensor_to_list(
beam_history.logprobs_indices[
beam_idx : beam_idx + 1, : valid_tokens[beam_idx]
],
beam_history.logprobs[beam_idx : beam_idx + 1, : valid_tokens[beam_idx]],
None,
None,
None,
beam_history.logprobs_indices[beam_idx : beam_idx + 1, :beam_valid_tokens],
beam_history.logprobs[beam_idx : beam_idx + 1, :beam_valid_tokens],
)[0]
)
request.set_generated_tokens(gen_token_list)
@@ -1989,11 +2016,14 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
self,
request: LlmRequest,
finish_reasons: torch.Tensor,
) -> bool:
"""Check if the stop criteria is met for the request"""
) -> torch.Tensor:
"""Check if the stop criteria is met for the request.

Returns a boolean tensor of shape (), whose value is computed asynchronously.
"""
return (
finish_reasons[: request.sampling_config.beam_width] > 0
).sum().item() == request.sampling_config.beam_width # NB: This syncs
).sum() == request.sampling_config.beam_width

def _check_stop_words_length(self, request: LlmRequest) -> bool:
"""Check if the stop words length is greater than 1"""
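Illustrative note (not part of the commit): this signature change is where the sync named in the commit title goes away — `.item()` on a CUDA tensor forces the CPU to wait for the device, whereas returning the comparison as a tensor lets the value reach the host through the same asynchronous copies as everything else and be read later. A small contrast with toy data:

import torch

if torch.cuda.is_available():
    finish_reasons = torch.tensor([1, 2, 0, 0], device="cuda")
    beam_width = 2

    # before: .item() blocks until all queued GPU work has completed
    stopped_sync = (finish_reasons[:beam_width] > 0).sum().item() == beam_width

    # after: keep the result on device; it can be copied to pinned host memory
    # asynchronously and only inspected once the sampler event has been awaited
    stopped_async = (finish_reasons[:beam_width] > 0).sum() == beam_width
    print(stopped_sync, bool(stopped_async))  # bool() here is only for the demo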
@@ -2006,23 +2036,20 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
return False

@nvtx_range("maybe_create_beam_histories")
def _maybe_create_beam_histories(
def _prepare_beam_histories(
self,
requests: list[LlmRequest],
finish_reasons: torch.Tensor,
beam_histories: list[BeamHistory | None],
) -> None:
"""Create the corrected tokens and logprobs for each beam of a request
) -> list[BeamHistoryBuilder | None]:
"""Create the corrected tokens and logprobs for each beam of a request.

This function creates a beam history object containing the corrected
tokens and logprobs for each beam of a request"""
for req_idx, req in enumerate(requests):
should_stop = self._check_beam_search_stop_criteria(
req, finish_reasons=finish_reasons[req.py_seq_slot]
)
need_finalize_due_to_stop_words = self._check_stop_words_length(req)
if should_stop or req.streaming or need_finalize_due_to_stop_words:
beam_histories[req_idx] = self._create_beam_history(req)
The builders returned by this function create a beam history object containing
the corrected tokens and logprobs for each beam of a request.
"""
return [
self._prepare_beam_history(req, finish_reasons=finish_reasons[req.py_seq_slot])
for req in requests
]

@override
@nvtx_range("update_requests")
@@ -2040,22 +2067,31 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
finish_reasons = state.host.finish_reasons_list()

new_tokens_list = new_tokens.tolist()
beam_histories = state.beam_histories

logprobs_state_list: LogProbsStateList | None = None
if state.host.logprobs_state is not None:
logprobs_state_list = LogProbsStateList.from_logprobs_state(state.host.logprobs_state)

beam_history_builders = state.beam_history_builders
assert (beam_history_builders is not None) == self._use_beam_search

def _maybe_build_beam_history(req_idx: int) -> BeamHistory | None:
if (
beam_history_builders is not None
and (beam_history_builder := beam_history_builders[req_idx]) is not None
):
return beam_history_builder()
else:
return None

for req_idx, req in enumerate(state.scheduled_requests.context_requests):
if (
req.state == LlmRequestState.GENERATION_COMPLETE
or req.context_remaining_length != 0
):
continue
if beam_histories is not None and beam_histories[req_idx] is not None:
self._finalize_beam(
req,
cast(BeamHistory, beam_histories[req_idx]),
)
if (beam_history := _maybe_build_beam_history(req_idx)) is not None:
self._finalize_beam(req, beam_history)
else:
for beam_idx in range(req.sampling_config.beam_width):
add_token(req, new_tokens_list, beam_idx=beam_idx)
@@ -2069,12 +2105,10 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
):
if req.state == LlmRequestState.GENERATION_COMPLETE:
continue

if req.sampling_config.beam_width > 1:
if beam_histories is not None and beam_histories[req_idx] is not None:
self._finalize_beam(
req,
cast(BeamHistory, beam_histories[req_idx]),
)
if (beam_history := _maybe_build_beam_history(req_idx)) is not None:
self._finalize_beam(req, beam_history)
else:
for beam_idx in range(req.sampling_config.beam_width):
# Beam search does not support speculative decoding.

@@ -2083,7 +2117,6 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
self._handle_finish_reasons(req, state.host.finish_reasons, finish_reasons)
req.py_num_accepted_draft_tokens = 0
req.py_rewind_len = 0

else:
processed = 1
num_accepted = self.process_draft_tokens(
@@ -2175,7 +2208,7 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
)
finish_reasons_host = self._copy_to_host(finish_reasons)

beam_histories: list[BeamHistory | None] = [None] * len(requests)
beam_history_builders = None
if self._use_beam_search:
assert first_finish_reasons is not None
assert seq_lens_host is not None, "seq_lens is required for beam search"

@@ -2185,8 +2218,8 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
seq_lens = seq_lens_host.to(device="cuda", non_blocking=True)
first_finish_reasons_host = self._copy_to_host(self.store.first_finish_reasons)
self._update_original_tokens(seq_slots, seq_lens, new_tokens)
self._maybe_create_beam_histories(
requests, finish_reasons=first_finish_reasons, beam_histories=beam_histories
beam_history_builders = self._prepare_beam_histories(
requests, finish_reasons=first_finish_reasons
)
else:
first_finish_reasons_host = None

@@ -2231,7 +2264,7 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
logprobs_state=logprobs_state,
),
sampler_event=sampler_event,
beam_histories=beam_histories,
beam_history_builders=beam_history_builders,
)

@staticmethod
@@ -2885,12 +2918,11 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
if first_finish_reasons is not None:
# store the first stop reason for each beam of a seq_slot.
batched_first_finish_reasons = first_finish_reasons[seq_slots]
batched_first_finish_reasons = torch.where(
first_finish_reasons[seq_slots, ...] = torch.where(
batched_first_finish_reasons == FinishReason.NOT_FINISHED.value,
batched_finish_reasons,
batched_first_finish_reasons,
)
first_finish_reasons[seq_slots] = batched_first_finish_reasons

def _are_end_id(self, end_ids: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
return tokens == end_ids.view(1, -1, 1).expand(self.max_tokens, -1, self.max_beam_width)
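Illustrative note (not part of the commit): the hunk above folds the read-modify-write of first_finish_reasons into a single indexed assignment — slots still marked NOT_FINISHED take the newly observed reason, already-finished slots keep the reason recorded first. A toy version of that update, with made-up values and 0 standing in for FinishReason.NOT_FINISHED:

import torch

NOT_FINISHED = 0
first = torch.tensor([[0, 3],
                      [2, 0],
                      [0, 0]])   # per (seq_slot, beam): first finish reason seen so far
new = torch.tensor([[1, 1],
                    [1, 1]])     # reasons observed this step for the scheduled slots
seq_slots = torch.tensor([0, 2])

current = first[seq_slots]
first[seq_slots, ...] = torch.where(current == NOT_FINISHED, new, current)
print(first)
# tensor([[1, 3],
#         [2, 0],
#         [1, 1]])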
@@ -635,9 +635,6 @@ def test_create_beam_history():
token_logprobs = sampler._convert_logprobs_tensor_to_list(
original_logprob_indices[:beam_width, :num_generated_tokens - 1],
original_logprobs[:beam_width, :num_generated_tokens - 1],
None,
None,
None,
)
request.py_result.set_log_probs(
token_logprobs,
@@ -670,18 +667,18 @@ def test_create_beam_history():
num_generated_tokens -
1, 0]
# test
beam_history = sampler._create_beam_history(request)
beam_history_builder = sampler._prepare_beam_history(
request, finish_reasons=torch.ones((beam_width, ), dtype=torch.int))
torch.cuda.synchronize()
beam_history = beam_history_builder()

# expected selection:
# Currently beam history only contains the generated tokens, not the prompt tokens.
expected_tokens = torch.zeros(
(sampler.max_beam_width, num_generated_tokens),
dtype=torch.int32,
device=original_tokens.device)
(sampler.max_beam_width, num_generated_tokens), dtype=torch.int32)
expected_logprobs = torch.zeros(
(beam_width, num_generated_tokens, original_logprobs.shape[-1]),
dtype=torch.float32,
device=original_logprobs.device)
dtype=torch.float32)
for gen_idx in range(num_generated_tokens):
token_idx = prompt_len + gen_idx
expected_tokens[:, gen_idx] = original_tokens[
@@ -694,10 +691,9 @@ def test_create_beam_history():
# test logprobs as well
torch.testing.assert_close(beam_history.logprobs[:beam_width],
expected_logprobs[:beam_width])
torch.testing.assert_close(beam_history.cum_logprobs[:beam_width],
original_cum_logprobs[seq_slot, :beam_width])

return
torch.testing.assert_close(
beam_history.cum_logprobs[:beam_width],
original_cum_logprobs[seq_slot, :beam_width].to("cpu"))


def test_finish_beams():