Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 18:21:52 +08:00)
[TRTLLM-10312][perf] Improve performance of _write_finish_reasons in TorchSampler (#10459)
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Parent: 80dd6e70c6
Commit: 7d31532850
@@ -929,6 +929,12 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
     class Store:
         new_tokens: torch.Tensor
         """Shape: See cpp DecoderState.getAllNewTokens()"""
+        max_lengths_tensor: torch.Tensor
+        """Shape: batch_size
+        Usage: Stores the maximum lengths for each request"""
+        end_ids: torch.Tensor
+        """Shape: batch_size
+        Usage: Stores the end ids for each request"""
         finish_reasons: torch.Tensor
         """Shape: max_tokens, batch_size, beam_width
         Usage: Stores the currently estimated finish_reasons for each request"""
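The two new Store fields follow the same slot-indexed buffer pattern as the existing ones: per-request scalars are written once into a preallocated GPU tensor at the request's seq slot, so the per-step finish checks can read them without another host round trip. A minimal CPU sketch of the pattern (illustrative names and sizes, not the TorchSampler API):

import torch

MAX_NUM_SEQUENCES = 8  # one slot per in-flight request

# preallocated once, like the Store's int_tensor(...) buffers
max_lengths = torch.zeros(MAX_NUM_SEQUENCES, dtype=torch.int32)
end_ids = torch.zeros(MAX_NUM_SEQUENCES, dtype=torch.int32)

# when requests land in slots 2 and 5, write their scalars once
slots = torch.tensor([2, 5])
max_lengths[slots] = torch.tensor([128, 64], dtype=torch.int32)
end_ids[slots] = torch.tensor([2, -1], dtype=torch.int32)  # -1 encodes "no end id"

# every later sampling step gathers them directly, no host round trip
print(max_lengths[slots], end_ids[slots])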
@@ -974,6 +980,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
         # Tensors necessary for all sampling methods
         new_tokens = int_tensor(self.NEW_TOKENS_SHAPE)
         finish_reasons = int_tensor(self.NEW_TOKENS_SHAPE)
+        max_lengths_tensor = int_tensor(self.max_num_sequences)
+        end_ids = int_tensor(self.max_num_sequences)

         # Only used for logprobs processing or beam search
         sampled_log_probs = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.float32)
@@ -1007,6 +1015,8 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
         return self.Store(
             new_tokens=new_tokens,
             finish_reasons=finish_reasons,
+            max_lengths_tensor=max_lengths_tensor,
+            end_ids=end_ids,
             cache_indirection=cache_indirection,
             cache_indirection_buffer=cache_indirection_buffer,
             cum_log_probs=cum_log_probs,
@@ -1072,6 +1082,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
                 FinishReason.CANCELLED,
             ]  # `in FinishReason` clashes with PyBind11: `TypeError: 'pybind11_type' object is not iterable`
         }
+        self._max_tokens_offset = torch.arange(
+            1, self.max_tokens + 1, device="cuda", dtype=torch.int32
+        ).view(-1, 1, 1)

         self._grouped_sampler_cls: Type[GroupedStrategySampler]
         if IS_FLASHINFER_AVAILABLE and not args.disable_flashinfer_sampling:
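The precomputed _max_tokens_offset turns the per-step max-length check into pure broadcasting. A small CPU demo of the shape arithmetic (the real buffers live on CUDA):

import torch

max_tokens, batch, beam = 3, 2, 2
offset = torch.arange(1, max_tokens + 1, dtype=torch.int32).view(-1, 1, 1)  # (3, 1, 1)
seq_lens = torch.tensor([10, 20], dtype=torch.int32)  # current length per request

# candidate length after the k-th new token, broadcast to (max_tokens, batch, beam)
lengths = (seq_lens.view(1, -1, 1) + offset).expand(max_tokens, batch, beam)
print(lengths[:, 0, 0])  # tensor([11, 12, 13], dtype=torch.int32)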
@@ -1525,14 +1538,47 @@ class TorchSampler(Sampler, AsyncWorkerMixin):

         return num_accepted_draft_tokens - 1

-    def setup_sampler_step(self, requests: ScheduledRequests):
+    def _is_new_request(self, request: LlmRequest) -> bool:
+        return (
+            not request.is_finished
+            and not request.py_is_draft
+            and (
+                (request.is_context_init_state and request.is_last_context_chunk)
+                or request.is_disagg_generation_transmission_complete
+            )
+        )
+
+    @override
+    def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
         """Setup the sampler step for the requests

         Args:
             scheduled_requests: ScheduledRequests. The requests to set up the sampler step for
         """
         if self._use_beam_search:
-            self._prepare_beam_search(requests.all_requests())
+            self._prepare_beam_search(scheduled_requests.all_requests())
+
+        seq_slots: list[int] = []
+        max_lens: list[int] = []
+        end_ids: list[int] = []
+        for request in scheduled_requests.context_requests:
+            if self._is_new_request(request):
+                seq_slots.append(request.py_seq_slot)
+                max_lens.append(
+                    min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
+                )
+                end_ids.append(request.py_end_id if request.py_end_id is not None else -1)
+        if len(seq_slots) > 0:
+            full_list = [seq_slots, max_lens, end_ids]
+            # perform only a single copy
+            full_list_tensor = torch.tensor(full_list, device="cpu", dtype=torch.int32).to(
+                device="cuda", non_blocking=True
+            )
+            seq_slots_tensor = full_list_tensor[0]
+            max_lens_tensor = full_list_tensor[1]
+            end_ids_tensor = full_list_tensor[2]
+            self.store.max_lengths_tensor[seq_slots_tensor] = max_lens_tensor
+            self.store.end_ids[seq_slots_tensor] = end_ids_tensor

     def _prepare_beam_search(
         self,
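The `# perform only a single copy` comment is the heart of this part of the optimization: packing the three Python lists into one (3, n) CPU tensor turns three host-to-device transfers into one, and the rows are then free views of the same device buffer. A self-contained sketch (with a CPU fallback so it runs anywhere):

import torch

seq_slots, max_lens, end_ids = [0, 3], [128, 64], [2, -1]

# one (3, n) tensor on the host, hence a single upload
packed = torch.tensor([seq_slots, max_lens, end_ids], dtype=torch.int32)
device = "cuda" if torch.cuda.is_available() else "cpu"
packed = packed.to(device=device, non_blocking=True)

# rows are views into the same device buffer, no further copies
seq_slots_t, max_lens_t, end_ids_t = packed[0], packed[1], packed[2]
print(seq_slots_t, max_lens_t, end_ids_t)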
@@ -1544,10 +1590,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
        initialize/reset the buffers for the request
        """
        for request in requests:
-            if not request.is_finished and (
-                (request.is_context_init_state and request.is_last_context_chunk)
-                or request.is_disagg_generation_transmission_complete
-            ):
+            if self._is_new_request(request):
                if request.py_return_log_probs and request.py_num_logprobs > 1:
                    raise ValueError("Beam search does not support multiple logprobs")
                self.store.cache_indirection[request.py_seq_slot, :, request.py_prompt_len].fill_(0)
@@ -2083,13 +2126,9 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
             dtype=torch.int64,  # for index_fill_
             pin_memory=True,
         )
-        # necessary for beam search
-        seq_lens_host = (
-            torch.tensor(
-                [r.max_beam_num_tokens for r in requests], dtype=torch.int32, pin_memory=True
-            )
-            if self._use_beam_search
-            else None
-        )
+        # necessary for beam search and max_length checks
+        seq_lens_host = torch.tensor(
+            [r.max_beam_num_tokens for r in requests], dtype=torch.int32, pin_memory=True
+        )
         new_tokens_host = self._process_requests(
             scheduled_requests,
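seq_lens_host is now built unconditionally (the max-length check needs it, not just beam search) and allocated pinned, which is what lets the caller's later non_blocking copy actually overlap with other work. A short sketch of the pinned-then-async idiom (CPU fallback included so it runs anywhere):

import torch

vals = [17, 33, 5]  # e.g. max_beam_num_tokens per request
use_cuda = torch.cuda.is_available()

# pin_memory=True requests page-locked host memory; without it,
# non_blocking=True silently degrades to a synchronous copy
host = torch.tensor(vals, dtype=torch.int32, pin_memory=use_cuda)
dev = host.to(device="cuda", non_blocking=True) if use_cuda else host
print(dev)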
@@ -2102,12 +2141,14 @@ class TorchSampler(Sampler, AsyncWorkerMixin):

         finish_reasons = self.store.finish_reasons
         seq_slots = seq_slots_host.to(device="cuda", non_blocking=True)
+        seq_lens = seq_lens_host.to(device="cuda", non_blocking=True)
         first_finish_reasons = self.store.first_finish_reasons if self._use_beam_search else None

         self._write_finish_reasons(
             requests,
             finish_reasons=finish_reasons,
             seq_slots=seq_slots,
+            seq_lens=seq_lens,
             new_tokens=new_tokens,
             first_finish_reasons=first_finish_reasons,
             predecessor_beams=self.store.predecessor_beams,
@@ -2760,6 +2801,7 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
         *,
         finish_reasons: torch.Tensor,
         seq_slots: torch.Tensor,
+        seq_lens: torch.Tensor,
         new_tokens: torch.Tensor,
         first_finish_reasons: torch.Tensor | None = None,
         predecessor_beams: torch.Tensor | None = None,
@@ -2775,7 +2817,11 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
             new_tokens: a buffer containing the newly generated tokens.
                 Shape: (max_tokens, max_batch_size, max_beam_width)
         """
-        tokens = new_tokens[:, seq_slots.to(device=new_tokens.device, non_blocking=True)]
+
+        # Seq Slots should be on the same device as new_tokens
+        assert seq_slots.device == new_tokens.device
+        assert seq_lens.device == new_tokens.device
+        tokens = new_tokens[:, seq_slots]

         # we need to fill with NOT_FINISHED so we can differentiate between previous requests that had the same seq slot
         finish_reasons.index_fill_(1, seq_slots, FinishReason.NOT_FINISHED.value)
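Because seq slots are recycled between requests, stale reasons from a previous occupant must be wiped for exactly the slots in this batch before new ones are written; index_fill_ does that in a single kernel along the slot dimension. A toy illustration with plain ints standing in for FinishReason values:

import torch

NOT_FINISHED, END_ID = 0, 2  # stand-ins, not the real enum values
max_tokens, max_batch, beam = 2, 4, 1

# pretend slots 1 and 3 still hold reasons from finished requests
finish_reasons = torch.full((max_tokens, max_batch, beam), END_ID)
seq_slots = torch.tensor([1, 3])  # slots reused by the current batch

# dim=1 is the slot dimension: clear just those columns, other slots stay intact
finish_reasons.index_fill_(1, seq_slots, NOT_FINISHED)
print(finish_reasons[:, :, 0])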
@@ -2801,12 +2847,12 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
         )

         batched_finish_reasons = torch.where(
-            self._are_max_length(requests),
+            self._are_max_length(seq_lens, self.store.max_lengths_tensor[seq_slots]),
             self._reason_tensors[FinishReason.LENGTH],
             batched_finish_reasons,
         )
         batched_finish_reasons = torch.where(
-            self._are_end_id(requests, tokens),
+            self._are_end_id(self.store.end_ids[seq_slots], tokens),
             self._reason_tensors[FinishReason.END_ID],
             batched_finish_reasons,
         )
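The reasons are layered by successive torch.where calls, so the condition applied last wins wherever several hold at once; END_ID is applied after LENGTH and therefore takes precedence. A toy version of the chain:

import torch

NOT_FINISHED, LENGTH, END_ID = 0, 1, 2  # stand-in codes, not the real enum values

at_max_len = torch.tensor([True, False, True])
is_end_id = torch.tensor([False, False, True])

reasons = torch.full((3,), NOT_FINISHED)
reasons = torch.where(at_max_len, LENGTH, reasons)
reasons = torch.where(is_end_id, END_ID, reasons)  # applied last, wins on overlap
print(reasons)  # tensor([1, 0, 2]): the third sequence hit both, reported as END_ID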
@@ -2822,57 +2868,26 @@ class TorchSampler(Sampler, AsyncWorkerMixin):
             )
             first_finish_reasons[seq_slots] = batched_first_finish_reasons

-    def _are_end_id(self, requests: list[LlmRequest], tokens: torch.Tensor) -> torch.Tensor:
-        end_ids_tensor = (
-            torch.tensor(
-                [
-                    ([req.py_end_id if req.py_end_id is not None else -1] * self.max_beam_width)
-                    for req in requests
-                ]
-                * self.max_tokens,
-                pin_memory=True,
-                dtype=tokens.dtype,
-            )
-            .view(self.max_tokens, len(requests), self.max_beam_width)
-            .to(device="cuda", non_blocking=True)
-        )
-        return tokens == end_ids_tensor
+    def _are_end_id(self, end_ids: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
+        return tokens == end_ids.view(1, -1, 1).expand(self.max_tokens, -1, self.max_beam_width)

-    def _are_max_length(self, requests: list[LlmRequest]) -> torch.Tensor:
+    def _are_max_length(self, seq_lens: torch.Tensor, max_seq_lens: torch.Tensor) -> torch.Tensor:
         """Checks which sequences are at or beyond the max length

         Args:
-            requests: the requests to check the max length of
-
+            seq_lens: the sequence lengths of the requests to check the max length of
+            max_seq_lens: the maximum sequence lengths of the requests to check the max length of
         Returns:
             A tensor of shape (max_tokens, len(requests), max_beam_width)
             where each element is True if the sequence is at or beyond the max length, False otherwise
         """
-        lengths_tensor = torch.tensor(
-            [
-                [
-                    [
-                        (req.get_num_tokens(beam_idx) + num_tokens)
-                        for beam_idx in range(self.max_beam_width)
-                    ]
-                    for req in requests
-                ]
-                for num_tokens in range(1, self.max_tokens + 1)
-            ]
-        )
+        lengths_tensor = (seq_lens.view(1, -1, 1) + self._max_tokens_offset).expand(
+            self.max_tokens, -1, self.max_beam_width
+        )
-        max_lengths_tensor = torch.tensor(
-            [
-                (
-                    [min(req.py_max_new_tokens + req.orig_prompt_len, self.max_seq_len)]
-                    * self.max_beam_width
-                )
-                for req in requests
-            ]
-            * self.max_tokens
-        ).view(self.max_tokens, len(requests), self.max_beam_width)
-        return (
-            (lengths_tensor >= max_lengths_tensor).pin_memory().to(device="cuda", non_blocking=True)
-        )
+        max_lengths_tensor = max_seq_lens.view(1, -1, 1).expand(
+            self.max_tokens, -1, self.max_beam_width
+        )
+        return lengths_tensor >= max_lengths_tensor

     _PAD_ID = -1
     """Pad with negative, doesn't matter what"""
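The rewritten helpers replace per-call torch.tensor(...) construction from nested Python lists (a fresh host allocation plus an upload on every decoding step) with view/expand over tensors that already sit on the device; expand materializes nothing. A quick check that the broadcast formulation produces the expected mask:

import torch

max_tokens, beam = 3, 2
seq_lens = torch.tensor([10, 6], dtype=torch.int32)      # current tokens per request
max_seq_lens = torch.tensor([12, 7], dtype=torch.int32)  # per-request caps
offset = torch.arange(1, max_tokens + 1, dtype=torch.int32).view(-1, 1, 1)

lengths = (seq_lens.view(1, -1, 1) + offset).expand(max_tokens, -1, beam)
caps = max_seq_lens.view(1, -1, 1).expand(max_tokens, -1, beam)
mask = lengths >= caps  # (max_tokens, batch, beam)
print(mask[:, 0, 0])  # lengths 11,12,13 vs cap 12 -> [False, True, True]
print(mask[:, 1, 0])  # lengths 7,8,9 vs cap 7 -> [True, True, True]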
@@ -224,11 +224,18 @@ class MTPSampler(TorchSampler):
         next_draft_tokens: torch.Tensor
         new_tokens_lens: torch.Tensor
         max_total_draft_tokens: torch.Tensor
-        finish_reasons: None = None  # Necessary to satisfy the interface of TorchSampler.Store
+        # Necessary to satisfy the interface of TorchSampler.Store
+        finish_reasons: None = None
+        end_ids: None = None
+        max_lengths_tensor: None = None

         def __post_init__(self):
             pass  # finish_reasons has no size to compare against new_tokens in MTPSampler

+    def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
+        # MTPSampler does not need to setup additional buffers before the sampler step
+        pass
+
     def __init__(self, args: TorchSampler.Args, *, nextn: int):
         self.mapping = None
         self.draft_len = nextn
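MTPSampler.Store pins the new fields to None on purpose: shared TorchSampler code can still touch the attributes, while the None annotation makes clear that finish-reason tracking is not maintained here. A generic sketch of this dataclass-override pattern (illustrative classes, not the real ones):

from dataclasses import dataclass

import torch

@dataclass
class BaseStore:
    new_tokens: torch.Tensor
    end_ids: torch.Tensor | None

@dataclass
class SlimStore(BaseStore):
    # the attribute still exists, it just never carries a buffer
    end_ids: None = None

store = SlimStore(new_tokens=torch.zeros(1, dtype=torch.int32))
assert store.end_ids is None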
@@ -701,16 +701,40 @@ class RequestCase:
         seq_slots = torch.tensor(
             [req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64
         )
+        seq_lens = torch.tensor(
+            [req.request.max_beam_num_tokens for req in requests], dtype=torch.int32, device="cuda"
+        )
         new_tokens = torch.tensor(
             [req.new_tokens for req in requests], dtype=torch.int32, device="cuda"
         ).T
         sampler.store.new_tokens[:, seq_slots, BEAM] = new_tokens
+        max_seq_lens = torch.tensor(
+            [
+                min(
+                    sampler.max_seq_len, req.request.orig_prompt_len + req.request.py_max_new_tokens
+                )
+                for req in requests
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        end_ids = torch.tensor(
+            [
+                req.request.py_end_id if req.request.py_end_id is not None else -1
+                for req in requests
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens
+        sampler.store.end_ids[seq_slots] = end_ids

         def run():
             sampler._write_finish_reasons(
                 [req.request for req in requests],
                 finish_reasons=sampler.store.finish_reasons,
                 new_tokens=sampler.store.new_tokens,
+                seq_lens=seq_lens,
                 seq_slots=seq_slots,
             )
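Since this is a perf commit, the run() closure above is presumably handed to a timing harness; a common way to time such GPU-side work is CUDA events (a generic sketch, not the repository's benchmark code):

import torch

def time_cuda_ms(fn, iters=100, warmup=10):
    for _ in range(warmup):  # warm up allocator and kernels first
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()  # events complete asynchronously
    return start.elapsed_time(end) / iters  # milliseconds per call

# e.g.: print(f"{time_cuda_ms(run):.3f} ms per _write_finish_reasons call")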