Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[TRTLLM-8376][feat] top-p optimization (removes redundant softmax) (#9411)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
parent 8da59103d6
commit c5f52ab304
@@ -171,27 +171,47 @@ def top_k_top_p_sampling_batch(
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
 
         # compute cumulative probability distribution of each sample
-        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+        probs_sorted = torch.softmax(sorted_logits, dim=-1)
+        cumulative_probs = torch.cumsum(probs_sorted, dim=-1)
 
         # get the location of top_p
-        # NB: Currently selecting the smallest index with cumulative_probs > top_p.
+        # NB: Currently selecting the smallest index with cumulative_probs >= top_p.
         # Thus, top_p -> 0 resembles greedy; agreement requires torch.sort(..., stable=True).
-        sorted_indices_to_remove = cumulative_probs > top_p
-        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
-        sorted_indices_to_remove[:, 0] = 0
-
-        # set the logits to -inf for token indices outside top_p
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
-        )
-        logits = logits.masked_fill(indices_to_remove, float("-inf"))
-
-    # compute probability distribution
-    softmax = torch.softmax(logits, dim=-1)
+        mask_to_remove = cumulative_probs >= top_p  # at least one 'True' per row
+        last_index_to_keep = torch.searchsorted(
+            mask_to_remove.to(torch.int8, non_blocking=True),
+            torch.ones((1,), dtype=torch.int8, device=mask_to_remove.device).expand(
+                (mask_to_remove.size(0), 1)
+            ),
+            right=False,
+            out_int32=True,
+        )
+        mask_to_remove.scatter_(
+            1,
+            last_index_to_keep,
+            torch.zeros((1,), dtype=torch.bool, device=mask_to_remove.device).expand_as(
+                last_index_to_keep
+            ),
+        )
+        # mask not selected probs
+        probs_sorted.masked_fill_(mask_to_remove, 0.0)
+        probs = torch.empty_like(probs_sorted)
+        probs.scatter_(1, sorted_indices, probs_sorted)
+        probs /= cumulative_probs[  # renormalize probs
+            torch.arange(
+                cumulative_probs.size(0), dtype=torch.int32, device=cumulative_probs.device
+            ),  # needed for advanced indexing
+            last_index_to_keep.squeeze(-1),
+        ].unsqueeze(-1)
+        del logits  # do not use, inconsistent with probs
+    else:
+        # compute probability distribution
+        probs = torch.softmax(logits, dim=-1)
 
     # sample from the distribution and generate result of [batch_size, 1]
-    next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1)
-    return next_tokens, softmax
+    next_tokens = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)
+    return next_tokens, probs
 
 
 def greedy_search_sampling_batch(
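What the hunk above buys: the old path masked non-nucleus logits to -inf and ran a second full-vocabulary softmax, while the new path reuses the softmax already computed over the sorted logits and merely renormalizes the kept entries by the cumulative mass at the cutoff index, which yields the same distribution. A minimal standalone sketch of that equivalence (hypothetical sizes and top_p value, not the repo's test code):

import torch

torch.manual_seed(0)
batch, vocab, top_p = 4, 16, 0.8
logits = torch.randn(batch, vocab)

sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
probs_sorted = torch.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(probs_sorted, dim=-1)

# first index whose cumulative probability crosses top_p; kept per the PR's
# ">=" convention (cumulative_probs is non-decreasing, so counting works)
first_cross = (cumulative_probs < top_p).sum(dim=-1).clamp(max=vocab - 1)
remove = cumulative_probs >= top_p
remove[torch.arange(batch), first_cross] = False  # keep the crossing token

# old path: -inf mask plus a second softmax over the full vocabulary
ref = torch.softmax(sorted_logits.masked_fill(remove, float("-inf")), dim=-1)

# new path: zero the tail and divide by the kept cumulative mass
new = probs_sorted.masked_fill(remove, 0.0)
new = new / cumulative_probs[torch.arange(batch), first_cross].unsqueeze(-1)

print(torch.allclose(ref, new, atol=1e-6))  # True: same distribution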
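The torch.searchsorted call in the new code is also worth a note: because each row of mask_to_remove is sorted (False entries, then True entries, since cumulative_probs is non-decreasing), casting it to int8 and binary-searching for the first 1 finds the cutoff index per row entirely on-device, with no host synchronization. A tiny illustration with hand-picked masks (hypothetical values):

import torch

# each row is False (keep) then True (remove), i.e. already sorted
mask = torch.tensor([
    [False, False, True, True, True],
    [False, True, True, True, True],
])

# leftmost insertion point of 1 == index of the first True per row
first_true = torch.searchsorted(
    mask.to(torch.int8),
    torch.ones((2, 1), dtype=torch.int8),
    right=False,
)
print(first_true.squeeze(-1))  # tensor([2, 1])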
@@ -503,6 +503,11 @@ def device_sleep(
 def assert_no_cuda_sync(
         sync_timeout_s: float = 5, ) -> Generator[None, None, None]:
     """Check that the function does not stream synchronize."""
+    if int(os.environ.get("CUDA_LAUNCH_BLOCKING", 0)):
+        print("CUDA_LAUNCH_BLOCKING set, skipping 'assert_no_cuda_sync'")
+        yield None
+        return
+
     # NB: This implementation only assumes that the CUDA operations performed
     # in the guarded scope use the currently selected CUDA stream. This
     # should also cover custom Torch ops as well as non-Torch kernels.
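As an aside, the stream-watching approach described in the NB comment can be approximated with PyTorch's built-in sync debug mode; the sketch below is a simplified stand-in for illustration (requires a CUDA device), not the repo's implementation:

import contextlib

import torch

@contextlib.contextmanager
def no_sync_guard():
    # set_sync_debug_mode("error") makes PyTorch raise on host-GPU
    # synchronizing calls (e.g. Tensor.item(), Tensor.cpu(), nonzero())
    prev = torch.cuda.get_sync_debug_mode()
    torch.cuda.set_sync_debug_mode("error")
    try:
        yield
    finally:
        torch.cuda.set_sync_debug_mode(prev)

# usage: sampling should stay asynchronous; a stray .item() inside
# the guarded block would raise instead of silently synchronizing
with no_sync_guard():
    probs = torch.softmax(torch.randn(2, 8, device="cuda"), dim=-1)
    tokens = torch.multinomial(probs, num_samples=1)

Unlike the stream-based check in the diff, this only catches synchronizations issued through PyTorch itself, not custom non-Torch kernels.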