[TRTLLM-10030][perf] avoid syncs in beam search + other improvements (#11349)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
mpikulski 2026-02-09 16:13:58 +01:00 committed by GitHub
parent 2b60cc181c
commit 196d94a419
2 changed files with 175 additions and 147 deletions

View File

@@ -766,6 +766,13 @@ class BeamHistory:
cum_logprobs: torch.Tensor | None = None
BeamHistoryBuilder: TypeAlias = Callable[[], BeamHistory | None]
"""Builder for BeamHistory.
Used to defer possibly unnecessary host-tensor construction until update_requests().
"""
@dataclass(kw_only=True)
class SamplingRequestsMetadata:
req_num_generated_tokens: torch.Tensor
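
The BeamHistoryBuilder alias above turns beam-history construction into a zero-argument closure that is only invoked once the asynchronous D2H copies are known to have finished. A minimal, self-contained sketch of that pattern (trimmed-down types and helper names, not the sampler's real buffers):

from dataclasses import dataclass
from typing import Callable, Optional, TypeAlias

import torch


@dataclass(kw_only=True)
class BeamHistory:  # trimmed-down stand-in for the real dataclass
    tokens: torch.Tensor


BeamHistoryBuilder: TypeAlias = Callable[[], Optional[BeamHistory]]


def prepare_history(tokens_gpu: torch.Tensor) -> BeamHistoryBuilder:
    # Enqueue the async D2H copy now, but build nothing yet: the closure keeps the
    # pinned host tensor alive and is only called after the copy has completed
    # (e.g. after awaiting the sampler event in update_requests()).
    tokens_host = torch.empty(
        tokens_gpu.shape, dtype=tokens_gpu.dtype, pin_memory=torch.cuda.is_available()
    )
    tokens_host.copy_(tokens_gpu, non_blocking=True)

    def _builder() -> Optional[BeamHistory]:
        return BeamHistory(tokens=tokens_host)

    return _builder


device = "cuda" if torch.cuda.is_available() else "cpu"
builder = prepare_history(torch.arange(6, device=device).reshape(2, 3))
if torch.cuda.is_available():
    torch.cuda.synchronize()  # stands in for awaiting state.sampler_event
print(builder().tokens)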
@@ -789,7 +796,7 @@ class SampleStateTensorsHostTorch(SampleStateTensors):
@dataclass(kw_only=True)
class SampleStateTorch(SampleState[SampleStateTensorsHostTorch, SampleStateTensors]):
beam_histories: list[BeamHistory | None] | None = None
beam_history_builders: list[BeamHistoryBuilder | None] | None = None
class AsyncWorkerMixin:
@@ -1249,9 +1256,6 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
self,
token_tensor: torch.Tensor,
logprobs_tensor: torch.Tensor,
sampled_log_probs_indices: torch.Tensor | None,
sampled_log_probs_vals: torch.Tensor | None,
sampled_log_probs_rank: torch.Tensor | None,
) -> list[list[dict[int, Logprob]]]:
"""Convert the logprobs tensor to a list of lists of dictionaries of Logprob objects
@@ -1260,9 +1264,6 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
args:
token_tensor: torch.Tensor. Shape: beam_width, num_tokens, num_logprobs
logprobs_tensor: torch.Tensor. Shape: beam_width, num_tokens, num_logprobs
sampled_log_probs_indices: torch.Tensor | None. Shape: num_tokens
sampled_log_probs_vals: torch.Tensor | None. Shape: num_tokens
sampled_log_probs_rank: torch.Tensor | None. Shape: num_tokens
output:
list[list[dict[int, Logprob]]]. Shape: (beam_width, num_tokens)
"""
@@ -1274,38 +1275,13 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
token_log_probs: list[list[dict[int, Logprob]]] = []
token_list = token_tensor.tolist()
logprobs_list = logprobs_tensor.tolist()
sampled_log_probs_indices_list: list[int] | None = None
sampled_log_probs_vals_list: list[float] | None = None
sampled_log_probs_rank_list: list[int] | None = None
if sampled_log_probs_indices is not None:
sampled_log_probs_indices_list = sampled_log_probs_indices.tolist()
assert sampled_log_probs_vals is not None, "sampled_log_probs_vals must be provided"
assert sampled_log_probs_rank is not None, "sampled_log_probs_rank must be provided"
sampled_log_probs_vals_list = sampled_log_probs_vals.tolist()
sampled_log_probs_rank_list = sampled_log_probs_rank.tolist()
for beam_idx in range(token_tensor.shape[0]):
beam_token_log_probs: list[dict[int, Logprob]] = []
for step_idx, (topk_token, topk_logprob) in enumerate(
zip(token_list[beam_idx], logprobs_list[beam_idx])
):
for topk_token, topk_logprob in zip(token_list[beam_idx], logprobs_list[beam_idx]):
logprobs = {
token: Logprob(logprob=logprob, rank=rank + 1)
for rank, (token, logprob) in enumerate(zip(topk_token, topk_logprob))
}
if sampled_log_probs_indices is not None:
assert beam_idx == DEFAULT_BEAM_IDX, (
"beam search does not need to explicitly handle sampled log probs"
)
assert sampled_log_probs_indices_list is not None
assert sampled_log_probs_vals_list is not None
assert sampled_log_probs_rank_list is not None
if sampled_log_probs_indices_list[step_idx] not in logprobs:
logprobs[sampled_log_probs_indices_list[step_idx]] = Logprob(
logprob=sampled_log_probs_vals_list[step_idx],
rank=max(
token_tensor.shape[2] + 1, sampled_log_probs_rank_list[step_idx]
),
)
beam_token_log_probs.append(logprobs)
token_log_probs.append(beam_token_log_probs)
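
With the sampled-logprob arguments removed, _convert_logprobs_tensor_to_list reduces to turning the top-k tensors into rank-keyed dictionaries. A standalone sketch of that conversion, using a simplified Logprob stand-in rather than TensorRT-LLM's class:

from dataclasses import dataclass

import torch


@dataclass
class Logprob:  # simplified stand-in for the real Logprob class
    logprob: float
    rank: int


def to_logprob_dicts(tokens: torch.Tensor, logprobs: torch.Tensor):
    # tokens, logprobs: (beam_width, num_tokens, num_logprobs). One tolist() per
    # tensor, then pure-Python nesting -- no per-element device reads.
    result = []
    for beam_tokens, beam_logprobs in zip(tokens.tolist(), logprobs.tolist()):
        result.append([
            {tok: Logprob(logprob=lp, rank=rank + 1)
             for rank, (tok, lp) in enumerate(zip(step_tokens, step_logprobs))}
            for step_tokens, step_logprobs in zip(beam_tokens, beam_logprobs)
        ])
    return result


toks = torch.tensor([[[5, 9], [7, 1]]])  # 1 beam, 2 steps, top-2
lps = torch.log(torch.tensor([[[0.6, 0.3], [0.8, 0.1]]]))
print(to_logprob_dicts(toks, lps)[0][0])  # {5: Logprob(..., rank=1), 9: Logprob(..., rank=2)}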
@@ -1732,7 +1708,12 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
request, new_tokens_list=new_tokens_list, new_tokens_tensor=new_tokens_tensor
)
def _get_logprobs_from_request(self, request: LlmRequest) -> tuple[torch.Tensor, torch.Tensor]:
def _get_logprobs_from_request(
self,
request: LlmRequest,
pin_memory: bool = True,
preallocate_extra_steps: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Extract the logprobs from the request
Returns:
@@ -1743,25 +1724,27 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
assert request.py_num_logprobs == 0, (
"Beam search only supports returning the sampled logprob per token"
)
logprobs_tensor = torch.empty(
logprobs_tensor_full = torch.empty(
(
request.sampling_config.beam_width,
num_generated_tokens,
num_generated_tokens + preallocate_extra_steps,
request.py_num_logprobs + 1,
),
device="cuda",
pin_memory=pin_memory,
dtype=torch.float32,
)
logprobs_indices_tensor = torch.empty(
logprobs_indices_tensor_full = torch.empty(
(
request.sampling_config.beam_width,
num_generated_tokens,
num_generated_tokens + preallocate_extra_steps,
request.py_num_logprobs + 1,
),
device="cuda",
pin_memory=pin_memory,
dtype=torch.int32,
)
if hasattr(request.py_result._log_probs, "log_probs"):
logprobs_tensor = logprobs_tensor_full[:, :-preallocate_extra_steps, :]
logprobs_indices_tensor = logprobs_indices_tensor_full[:, :-preallocate_extra_steps, :]
if logprobs_tensor.numel() > 0:
logprobs_list = request.py_result.log_probs
assert logprobs_list is not None
for beam_idx, beam_logprobs in enumerate(logprobs_list):
@@ -1770,12 +1753,14 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
assert value.rank is not None
logprobs_tensor[beam_idx, token_idx, value.rank - 1] = value.logprob
logprobs_indices_tensor[beam_idx, token_idx, value.rank - 1] = key
return logprobs_tensor, logprobs_indices_tensor
return logprobs_tensor_full, logprobs_indices_tensor_full
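
The new pin_memory and preallocate_extra_steps parameters let the caller allocate pinned host buffers that already have room for the step about to be appended, so new values are written into a reserved tail slice instead of concatenated. A toy sketch of that allocation strategy (helper name and shapes are illustrative):

import torch


def alloc_logprob_buffers(beam_width: int, num_steps: int, top_k: int, extra_steps: int = 1):
    # Pinned host buffers sized for the known steps plus the step(s) about to arrive;
    # pinning is what allows the later GPU->CPU copies to be truly non-blocking.
    shape = (beam_width, num_steps + extra_steps, top_k)
    pin = torch.cuda.is_available()
    vals = torch.empty(shape, dtype=torch.float32, pin_memory=pin)
    idxs = torch.empty(shape, dtype=torch.int32, pin_memory=pin)
    return vals, idxs


vals, idxs = alloc_logprob_buffers(beam_width=2, num_steps=3, top_k=1)
past_vals = vals[:, :-1, :]   # steps already recorded for the request
new_step = vals[:, -1:, :]    # reserved slot for the step being sampled now
new_step.copy_(torch.full((2, 1, 1), -0.1))  # filled in place, no torch.cat / realloc
print(vals.shape)  # torch.Size([2, 4, 1])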
def _create_beam_history(
def _prepare_beam_history(
self,
request: LlmRequest,
) -> BeamHistory | None:
*,
finish_reasons: torch.Tensor,
) -> BeamHistoryBuilder | None:
"""Correct the stored tokens for each beam and return it as a BeamHistory object.
Beam Search sampling only adds new tokens to the beam.
@@ -1784,9 +1769,30 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
If logprobs are requested, the function also corrects the stored logprobs for each beam.
The function returns a BeamHistory object that contains the corrected tokens and logprobs for each beam.
Note: To defer the decision whether or not to skip BeamHistory construction until update_requests(), only
a builder (BeamHistoryBuilder) is returned here. The builder contains host tensors which are
being populated asynchronously. Hence, it can only be invoked after async D2H copies have completed,
e.g., after awaiting state.sampler_event in update_requests.
arguments:
request: The request to create the beam history for
finish_reasons: The first finish reason encountered for each beam of the request.
Shape: (max_tokens, max_beam_width)
"""
# Gather data used for skipping beam history processing
need_finalize_due_to_stop_words = self._check_stop_words_length(request)
if need_finalize_due_to_stop_words:
need_history = torch.tensor(True)
else:
should_stop = self._check_beam_search_stop_criteria(
request,
finish_reasons=finish_reasons,
)
need_history = should_stop
# enqueue async D2H copy
need_history = self._copy_to_host(need_history)
num_tokens = request.max_beam_num_tokens + 1 # last token is not yet added
prompt_length = request.py_prompt_len
num_generated_tokens = num_tokens - prompt_length
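
The skip decision itself no longer syncs: the stop check stays on the device as a 0-dim tensor, is copied to the host asynchronously, and is only read inside the builder. A standalone rendering of the pattern; copy_to_host below is an illustrative stand-in for the mixin's _copy_to_host helper:

import torch


def copy_to_host(t: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for _copy_to_host: non-blocking D2H into pinned memory.
    host = torch.empty(t.shape, dtype=t.dtype, pin_memory=torch.cuda.is_available())
    host.copy_(t, non_blocking=True)
    return host


device = "cuda" if torch.cuda.is_available() else "cpu"
beam_width = 2
finish_reasons = torch.tensor([3, 1, 0, 0], device=device)  # > 0 means finished

# Before: (finish_reasons[:beam_width] > 0).sum().item() == beam_width  -> host blocks here.
# After: keep the comparison as a 0-dim device tensor and defer the read.
should_stop = (finish_reasons[:beam_width] > 0).sum() == beam_width
need_history = copy_to_host(should_stop)

if torch.cuda.is_available():
    torch.cuda.synchronize()  # in the sampler this is awaiting state.sampler_event
print(bool(need_history.item()))  # True: every beam has a finish reason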
@@ -1795,73 +1801,98 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
if num_generated_tokens == 0 or request.state == LlmRequestState.GENERATION_COMPLETE:
# early return if no tokens have been generated yet or the request is already finished
return None
assert self.store.cache_indirection is not None
assert self.store.original_tokens is not None
assert self.store.sampled_log_probs is not None
cache_indirection = self.store.cache_indirection[
request.py_seq_slot, :num_beams, prompt_length:num_tokens
]
current_path = self.store.original_tokens[
request.py_seq_slot, :num_beams, prompt_length:num_tokens
]
new_path = torch.zeros_like(current_path)
# initialize each beam with its own index
# enqueue async D2H copies
cache_indirection = self._copy_to_host(cache_indirection)
current_path = self._copy_to_host(current_path)
def _post_process_path() -> torch.Tensor:
# Gather the correct tokens for each beam
new_path = torch.zeros_like(current_path)
torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path)
return new_path
# Gather the correct tokens and logprobs for each beam
torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path)
if request.py_return_log_probs:
assert self.store.sampled_log_probs is not None
assert self.store.cum_log_probs is not None
current_logprobs, current_logprobs_indices = self._get_logprobs_from_request(request)
# concatenate the newly generated logprobs and newly
# generated tokens to the current logprobs and logprobs indices
current_logprobs = torch.cat(
[
current_logprobs,
self.store.sampled_log_probs[request.py_seq_slot, :num_beams].view(-1, 1, 1),
],
dim=1,
)
current_logprobs_indices = torch.cat(
[
current_logprobs_indices,
self.store.new_tokens[0, request.py_seq_slot, :num_beams].view(-1, 1, 1),
],
dim=1,
)
# Initialize the buffers to store the results
new_logprobs = torch.zeros_like(current_logprobs)
new_logprobs_indices = torch.zeros_like(current_logprobs_indices)
cache_indirection_for_logprobs = cache_indirection.unsqueeze(-1).expand(
-1, -1, current_logprobs.shape[2]
)
torch.gather(
input=current_logprobs,
dim=0,
index=cache_indirection_for_logprobs,
out=new_logprobs,
)
torch.gather(
input=current_logprobs_indices,
dim=0,
index=cache_indirection_for_logprobs,
out=new_logprobs_indices,
sampled_log_probs = self.store.sampled_log_probs[request.py_seq_slot, :num_beams].view(
-1, 1
)
sampled_logprobs_indices = self.store.new_tokens[
0, request.py_seq_slot, :num_beams
].view(-1, 1)
cum_logprobs = self.store.cum_log_probs[request.py_seq_slot, :num_beams]
# enqueue async D2H copies
sampled_log_probs = self._copy_to_host(sampled_log_probs)
sampled_logprobs_indices = self._copy_to_host(sampled_logprobs_indices)
cum_logprobs = self._copy_to_host(cum_logprobs)
def _maybe_postprocess_logprobs() -> tuple[
torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
]:
# Gather the correct logprobs for each beam
current_logprobs, current_logprobs_indices = self._get_logprobs_from_request(
request, preallocate_extra_steps=1
)
# concatenate the newly generated logprobs and newly
# generated tokens to the current logprobs and logprobs indices
current_logprobs[:, -1, :].copy_(sampled_log_probs)
current_logprobs_indices[:, -1, :].copy_(sampled_logprobs_indices)
# Initialize the buffers to store the results
new_logprobs = torch.zeros_like(current_logprobs)
new_logprobs_indices = torch.zeros_like(current_logprobs_indices)
cache_indirection_for_logprobs = cache_indirection.unsqueeze(-1).expand(
-1, -1, current_logprobs.shape[2]
)
torch.gather(
input=current_logprobs,
dim=0,
index=cache_indirection_for_logprobs,
out=new_logprobs,
)
torch.gather(
input=current_logprobs_indices,
dim=0,
index=cache_indirection_for_logprobs,
out=new_logprobs_indices,
)
return new_logprobs, new_logprobs_indices, cum_logprobs
else:
def _maybe_postprocess_logprobs() -> tuple[
torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
]:
return None, None, None
def _builder() -> BeamHistory | None:
if not need_history.item():
return None
new_path = _post_process_path()
new_logprobs, new_logprobs_indices, cum_logprobs = _maybe_postprocess_logprobs()
return BeamHistory(
tokens=new_path,
logprobs=new_logprobs,
logprobs_indices=new_logprobs_indices,
cum_logprobs=cum_logprobs,
)
else:
return BeamHistory(
tokens=new_path,
logprobs=None,
logprobs_indices=None,
cum_logprobs=None,
)
return _builder
def _finalize_beam(
self,
@@ -1897,23 +1928,19 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
f"Beam_history.cum_logprobs.shape[0] should equal beam width: \
{beam_history.cum_logprobs.shape[0]} != {beam_width}"
)
valid_tokens = (beam_history.tokens != BEAM_SEARCH_PAD_TOKEN).sum(dim=-1)
valid_tokens = (beam_history.tokens != BEAM_SEARCH_PAD_TOKEN).sum(dim=-1).tolist()
gen_token_list = []
gen_log_probs_list = []
for beam_idx in range(beam_width):
gen_token_list.append(beam_history.tokens[beam_idx, : valid_tokens[beam_idx]].tolist())
beam_valid_tokens = valid_tokens[beam_idx]
gen_token_list.append(beam_history.tokens[beam_idx, :beam_valid_tokens].tolist())
if request.py_return_log_probs:
assert beam_history.logprobs_indices is not None
assert beam_history.logprobs is not None
gen_log_probs_list.append(
self._convert_logprobs_tensor_to_list(
beam_history.logprobs_indices[
beam_idx : beam_idx + 1, : valid_tokens[beam_idx]
],
beam_history.logprobs[beam_idx : beam_idx + 1, : valid_tokens[beam_idx]],
None,
None,
None,
beam_history.logprobs_indices[beam_idx : beam_idx + 1, :beam_valid_tokens],
beam_history.logprobs[beam_idx : beam_idx + 1, :beam_valid_tokens],
)[0]
)
request.set_generated_tokens(gen_token_list)
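
_finalize_beam now brings the per-beam valid-token counts to the host with one vectorized comparison and a single .tolist(), rather than reading the count tensor element by element inside the loop. A toy version of that counting (the pad-token value here is a placeholder; the real constant lives in the sampler):

import torch

BEAM_SEARCH_PAD_TOKEN = -1  # placeholder value for illustration

tokens = torch.tensor([[3, 5, 7, BEAM_SEARCH_PAD_TOKEN],
                       [2, 4, BEAM_SEARCH_PAD_TOKEN, BEAM_SEARCH_PAD_TOKEN]])

# One comparison and one host transfer for all beams, instead of a host read per
# beam when indexing a device tensor inside the loop.
valid_tokens = (tokens != BEAM_SEARCH_PAD_TOKEN).sum(dim=-1).tolist()  # [3, 2]

gen_token_list = [tokens[b, :n].tolist() for b, n in enumerate(valid_tokens)]
print(gen_token_list)  # [[3, 5, 7], [2, 4]]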
@@ -1989,11 +2016,14 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
self,
request: LlmRequest,
finish_reasons: torch.Tensor,
) -> bool:
"""Check if the stop criteria is met for the request"""
) -> torch.Tensor:
"""Check if the stop criteria is met for the request.
Returns a boolean tensor of shape (), whose value is computed asynchronously.
"""
return (
finish_reasons[: request.sampling_config.beam_width] > 0
).sum().item() == request.sampling_config.beam_width # NB: This syncs
).sum() == request.sampling_config.beam_width
def _check_stop_words_length(self, request: LlmRequest) -> bool:
"""Check if the stop words length is greater than 1"""
@@ -2006,23 +2036,20 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
return False
@nvtx_range("maybe_create_beam_histories")
def _maybe_create_beam_histories(
def _prepare_beam_histories(
self,
requests: list[LlmRequest],
finish_reasons: torch.Tensor,
beam_histories: list[BeamHistory | None],
) -> None:
"""Create the corrected tokens and logprobs for each beam of a request
) -> list[BeamHistoryBuilder | None]:
"""Create the corrected tokens and logprobs for each beam of a request.
This function creates a beam history object containing the corrected
tokens and logprobs for each beam of a request"""
for req_idx, req in enumerate(requests):
should_stop = self._check_beam_search_stop_criteria(
req, finish_reasons=finish_reasons[req.py_seq_slot]
)
need_finalize_due_to_stop_words = self._check_stop_words_length(req)
if should_stop or req.streaming or need_finalize_due_to_stop_words:
beam_histories[req_idx] = self._create_beam_history(req)
The builders returned by this function create a beam history object containing
the corrected tokens and logprobs for each beam of a request.
"""
return [
self._prepare_beam_history(req, finish_reasons=finish_reasons[req.py_seq_slot])
for req in requests
]
@override
@nvtx_range("update_requests")
@@ -2040,22 +2067,31 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
finish_reasons = state.host.finish_reasons_list()
new_tokens_list = new_tokens.tolist()
beam_histories = state.beam_histories
logprobs_state_list: LogProbsStateList | None = None
if state.host.logprobs_state is not None:
logprobs_state_list = LogProbsStateList.from_logprobs_state(state.host.logprobs_state)
beam_history_builders = state.beam_history_builders
assert (beam_history_builders is not None) == self._use_beam_search
def _maybe_build_beam_history(req_idx: int) -> BeamHistory | None:
if (
beam_history_builders is not None
and (beam_history_builder := beam_history_builders[req_idx]) is not None
):
return beam_history_builder()
else:
return None
for req_idx, req in enumerate(state.scheduled_requests.context_requests):
if (
req.state == LlmRequestState.GENERATION_COMPLETE
or req.context_remaining_length != 0
):
continue
if beam_histories is not None and beam_histories[req_idx] is not None:
self._finalize_beam(
req,
cast(BeamHistory, beam_histories[req_idx]),
)
if (beam_history := _maybe_build_beam_history(req_idx)) is not None:
self._finalize_beam(req, beam_history)
else:
for beam_idx in range(req.sampling_config.beam_width):
add_token(req, new_tokens_list, beam_idx=beam_idx)
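
In update_requests(), builders are looked up and invoked per request, and any of "no builders" (beam search disabled), "no builder for this request", or "builder decided to skip" falls back to the plain add_token path. A condensed, hypothetical rendering of that dispatch:

from typing import Callable, Optional

BeamHistory = dict  # stand-in for the real dataclass
BeamHistoryBuilder = Callable[[], Optional[BeamHistory]]


def maybe_build(builders: Optional[list], req_idx: int) -> Optional[BeamHistory]:
    # Mirrors _maybe_build_beam_history: every "not applicable" case yields None.
    if builders is not None and (builder := builders[req_idx]) is not None:
        return builder()
    return None


builders = [None, lambda: {"tokens": [1, 2, 3]}]
print(maybe_build(builders, 0))  # None -> fall back to add_token per beam
print(maybe_build(builders, 1))  # {'tokens': [1, 2, 3]} -> finalize the beams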
@@ -2069,12 +2105,10 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
):
if req.state == LlmRequestState.GENERATION_COMPLETE:
continue
if req.sampling_config.beam_width > 1:
if beam_histories is not None and beam_histories[req_idx] is not None:
self._finalize_beam(
req,
cast(BeamHistory, beam_histories[req_idx]),
)
if (beam_history := _maybe_build_beam_history(req_idx)) is not None:
self._finalize_beam(req, beam_history)
else:
for beam_idx in range(req.sampling_config.beam_width):
# Beam search does not support speculative decoding.
@@ -2083,7 +2117,6 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
self._handle_finish_reasons(req, state.host.finish_reasons, finish_reasons)
req.py_num_accepted_draft_tokens = 0
req.py_rewind_len = 0
else:
processed = 1
num_accepted = self.process_draft_tokens(
@@ -2175,7 +2208,7 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
)
finish_reasons_host = self._copy_to_host(finish_reasons)
beam_histories: list[BeamHistory | None] = [None] * len(requests)
beam_history_builders = None
if self._use_beam_search:
assert first_finish_reasons is not None
assert seq_lens_host is not None, "seq_lens is required for beam search"
@@ -2185,8 +2218,8 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
seq_lens = seq_lens_host.to(device="cuda", non_blocking=True)
first_finish_reasons_host = self._copy_to_host(self.store.first_finish_reasons)
self._update_original_tokens(seq_slots, seq_lens, new_tokens)
self._maybe_create_beam_histories(
requests, finish_reasons=first_finish_reasons, beam_histories=beam_histories
beam_history_builders = self._prepare_beam_histories(
requests, finish_reasons=first_finish_reasons
)
else:
first_finish_reasons_host = None
@@ -2231,7 +2264,7 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
logprobs_state=logprobs_state,
),
sampler_event=sampler_event,
beam_histories=beam_histories,
beam_history_builders=beam_history_builders,
)
@staticmethod
@@ -2885,12 +2918,11 @@ class TorchSampler(Sampler[SampleStateTorch], AsyncWorkerMixin):
if first_finish_reasons is not None:
# store the first stop reason for each beam of a seq_slot.
batched_first_finish_reasons = first_finish_reasons[seq_slots]
batched_first_finish_reasons = torch.where(
first_finish_reasons[seq_slots, ...] = torch.where(
batched_first_finish_reasons == FinishReason.NOT_FINISHED.value,
batched_finish_reasons,
batched_first_finish_reasons,
)
first_finish_reasons[seq_slots] = batched_first_finish_reasons
def _are_end_id(self, end_ids: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
return tokens == end_ids.view(1, -1, 1).expand(self.max_tokens, -1, self.max_beam_width)
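
The first-finish-reason bookkeeping above now writes the torch.where result straight into the indexed rows instead of materializing an intermediate and re-assigning it. A toy example with made-up reason codes (NOT_FINISHED stands in for FinishReason.NOT_FINISHED.value):

import torch

NOT_FINISHED = 0  # stand-in value for illustration

first_finish_reasons = torch.tensor([[0, 2],    # slot 0: beam 1 already finished (reason 2)
                                     [0, 0],
                                     [3, 0]])   # slot 2: beam 0 already finished (reason 3)
seq_slots = torch.tensor([0, 2])                # slots updated this step
batched_finish_reasons = torch.tensor([[1, 1],  # reasons reported this step
                                       [1, 1]])

batched_first = first_finish_reasons[seq_slots]
# Keep any previously recorded reason, otherwise take the new one, and write the
# result back into the selected slots in a single indexed assignment.
first_finish_reasons[seq_slots, ...] = torch.where(
    batched_first == NOT_FINISHED, batched_finish_reasons, batched_first)
print(first_finish_reasons)  # [[1, 2], [0, 0], [3, 1]]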

View File

@@ -635,9 +635,6 @@ def test_create_beam_history():
token_logprobs = sampler._convert_logprobs_tensor_to_list(
original_logprob_indices[:beam_width, :num_generated_tokens - 1],
original_logprobs[:beam_width, :num_generated_tokens - 1],
None,
None,
None,
)
request.py_result.set_log_probs(
token_logprobs,
@@ -670,18 +667,18 @@ def test_create_beam_history():
num_generated_tokens -
1, 0]
# test
beam_history = sampler._create_beam_history(request)
beam_history_builder = sampler._prepare_beam_history(
request, finish_reasons=torch.ones((beam_width, ), dtype=torch.int))
torch.cuda.synchronize()
beam_history = beam_history_builder()
# expected selection:
# Currently beam history only contains the generated tokens, not the prompt tokens.
expected_tokens = torch.zeros(
(sampler.max_beam_width, num_generated_tokens),
dtype=torch.int32,
device=original_tokens.device)
(sampler.max_beam_width, num_generated_tokens), dtype=torch.int32)
expected_logprobs = torch.zeros(
(beam_width, num_generated_tokens, original_logprobs.shape[-1]),
dtype=torch.float32,
device=original_logprobs.device)
dtype=torch.float32)
for gen_idx in range(num_generated_tokens):
token_idx = prompt_len + gen_idx
expected_tokens[:, gen_idx] = original_tokens[
@@ -694,10 +691,9 @@ def test_create_beam_history():
# test logprobs as well
torch.testing.assert_close(beam_history.logprobs[:beam_width],
expected_logprobs[:beam_width])
torch.testing.assert_close(beam_history.cum_logprobs[:beam_width],
original_cum_logprobs[seq_slot, :beam_width])
return
torch.testing.assert_close(
beam_history.cum_logprobs[:beam_width],
original_cum_logprobs[seq_slot, :beam_width].to("cpu"))
def test_finish_beams():