[TRTLLM-8376][feat] top-p optimization (removes redundant softmax) (#9411)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
mpikulski 2025-11-25 18:46:48 +01:00 committed by GitHub
parent 8da59103d6
commit c5f52ab304
2 changed files with 39 additions and 14 deletions
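In brief, the change avoids the second full-vocabulary softmax: once the sorted probabilities and their cumulative sums are available, the top-p distribution can be obtained by zeroing the tail and dividing by the cumulative probability at the last kept index. The following is a minimal standalone sketch of that equivalence, not the TensorRT-LLM code: it assumes a scalar top_p in (0, 1) shared by all rows and locates the cut-off with a simple mask count instead of the searchsorted call used in the diff.

import torch

def top_p_probs_two_softmax(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Old-style approach: mask logits outside the nucleus, then softmax again.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum > top_p
    remove[:, 1:] = remove[:, :-1].clone()  # keep the first token that crosses top_p
    remove[:, 0] = False
    remove = remove.scatter(1, sorted_indices, remove)  # back to vocab order
    return torch.softmax(logits.masked_fill(remove, float("-inf")), dim=-1)

def top_p_probs_renormalized(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # New-style approach: reuse the sorted softmax and renormalize the kept prefix.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    probs_sorted = torch.softmax(sorted_logits, dim=-1)
    cum = torch.cumsum(probs_sorted, dim=-1)
    remove = cum >= top_p  # monotone along each row; at least one True for top_p < 1
    # index of the first True == last index to keep (count of leading False values)
    last_keep = (~remove).sum(dim=-1, keepdim=True)
    remove.scatter_(1, last_keep, False)  # keep the token that crosses top_p
    probs_sorted = probs_sorted.masked_fill(remove, 0.0)
    probs = torch.zeros_like(probs_sorted).scatter_(1, sorted_indices, probs_sorted)
    kept_mass = cum.gather(1, last_keep)  # probability mass of the kept prefix
    return probs / kept_mass

logits = torch.randn(4, 1000)
a = top_p_probs_two_softmax(logits, 0.9)
b = top_p_probs_renormalized(logits, 0.9)
print(torch.allclose(a, b, atol=1e-6))  # expected: True (up to float rounding)

Up to floating-point rounding (and behaviour at exact ties, where the comparison changed from > to >=), both functions produce the same distribution; the second one never touches the full logits again after the initial sorted softmax.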


@@ -171,27 +171,47 @@ def top_k_top_p_sampling_batch(
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         # compute cumulative probability distribution of each sample
-        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+        probs_sorted = torch.softmax(sorted_logits, dim=-1)
+        cumulative_probs = torch.cumsum(probs_sorted, dim=-1)
         # get the location of top_p
-        # NB: Currently selecting the smallest index with cumulative_probs > top_p.
+        # NB: Currently selecting the smallest index with cumulative_probs >= top_p.
         #     Thus, top_p -> 0 resembles greedy; agreement requires torch.sort(..., stable=True).
-        sorted_indices_to_remove = cumulative_probs > top_p
-        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
-        sorted_indices_to_remove[:, 0] = 0
-        # set the logits to -inf for token indices outside top_p
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
-        )
+        mask_to_remove = cumulative_probs >= top_p  # at least one 'True' per row
+        last_index_to_keep = torch.searchsorted(
+            mask_to_remove.to(torch.int8, non_blocking=True),
+            torch.ones((1,), dtype=torch.int8, device=mask_to_remove.device).expand(
+                (mask_to_remove.size(0), 1)
+            ),
+            right=False,
+            out_int32=True,
+        )
+        mask_to_remove.scatter_(
+            1,
+            last_index_to_keep,
+            torch.zeros((1,), dtype=torch.bool, device=mask_to_remove.device).expand_as(
+                last_index_to_keep
+            ),
+        )
-        logits = logits.masked_fill(indices_to_remove, float("-inf"))
-    # compute probability distribution
-    softmax = torch.softmax(logits, dim=-1)
+        # mask not selected probs
+        probs_sorted.masked_fill_(mask_to_remove, 0.0)
+        probs = torch.empty_like(probs_sorted)
+        probs.scatter_(1, sorted_indices, probs_sorted)
+        probs /= cumulative_probs[  # renormalize probs
+            torch.arange(
+                cumulative_probs.size(0), dtype=torch.int32, device=cumulative_probs.device
+            ),  # needed for advanced indexing
+            last_index_to_keep.squeeze(-1),
+        ].unsqueeze(-1)
+        del logits  # do not use, inconsistent with probs
+    else:
+        # compute probability distribution
+        probs = torch.softmax(logits, dim=-1)
     # sample from the distribution and generate result of [batch_size, 1]
-    next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1)
-    return next_tokens, softmax
+    next_tokens = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)
+    return next_tokens, probs
def greedy_search_sampling_batch(
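One detail of the hunk above worth noting: last_index_to_keep is found with torch.searchsorted rather than nonzero or argmax. Since cumulative_probs is non-decreasing along the sorted axis, the boolean mask cumulative_probs >= top_p is itself sorted within each row, so casting it to int8 and searching for the first 1 returns the index of the first True for every row in one batched call, entirely on the device (no host synchronization of the kind nonzero would require). A small self-contained illustration with made-up values, not the TensorRT-LLM code:

import torch

# The mask is monotone (non-decreasing) along the last dimension, as it is
# for `cumulative_probs >= top_p` in the diff above.
mask = torch.tensor([[0, 0, 1, 1, 1],
                     [0, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1]], dtype=torch.bool)

ones = torch.ones((1,), dtype=torch.int8).expand(mask.size(0), 1)
first_true = torch.searchsorted(mask.to(torch.int8), ones, right=False, out_int32=True)
print(first_true.squeeze(-1))  # tensor([2, 1, 0], dtype=torch.int32)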


@@ -503,6 +503,11 @@ def device_sleep(
 def assert_no_cuda_sync(
     sync_timeout_s: float = 5,
 ) -> Generator[None, None, None]:
     """Check that the function does not stream synchronize."""
+    if int(os.environ.get("CUDA_LAUNCH_BLOCKING", 0)):
+        print("CUDA_LAUNCH_BLOCKING set, skipping 'assert_no_cuda_sync'")
+        yield None
+        return
+
     # NB: This implementation only assumes that the CUDA operations performed
     # in the guarded scope use the currently selected CUDA stream. This
     # should also cover custom Torch ops as well as non-Torch kernels.
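Regarding the guard added above: with CUDA_LAUNCH_BLOCKING=1 every kernel launch synchronizes, so an assertion that a scope performs no stream synchronization cannot hold, and the helper degrades to a no-op. A rough sketch of how such a context manager might be shaped around that early exit; only the environment-variable check mirrors the diff, the detection logic is elided and purely a placeholder:

import os
from contextlib import contextmanager
from typing import Generator

@contextmanager
def assert_no_cuda_sync(sync_timeout_s: float = 5) -> Generator[None, None, None]:
    """Check that the function does not stream synchronize."""
    if int(os.environ.get("CUDA_LAUNCH_BLOCKING", 0)):
        # Every launch already blocks, so "no sync" cannot be asserted meaningfully.
        print("CUDA_LAUNCH_BLOCKING set, skipping 'assert_no_cuda_sync'")
        yield None
        return
    try:
        # (placeholder) arm detection on the currently selected CUDA stream here
        yield None
    finally:
        # (placeholder) verify within sync_timeout_s that no synchronization occurred
        pass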