Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[https://nvbugs/5622938][feat] Run sample_async on extra stream. (#10215)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
parent 78bb245554
commit 80f261ea36
@@ -250,6 +250,9 @@ class PyExecutor:
         self.send_schedule_handler = None
         self.pp_scheduler_max_retry_count = int(
             os.environ.get("TLLM_PP_SCHEDULER_MAX_RETRY_COUNT", 10))
+        self.sample_stream = torch.cuda.Stream()
+        self.start_sample_event = torch.cuda.Event()
+        self.finish_sample_event = torch.cuda.Event()
 
         # Set of request IDs that are currently in flight across all micro batches.
         # The scheduler will avoid scheduling requests that are already in flight.
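The hunk above gives the executor a dedicated sampling stream plus a pair of CUDA events used to hand work off to it. As a refresher on the primitives involved, here is a minimal, self-contained sketch (not TensorRT-LLM code; side_stream, start_evt and finish_evt are illustrative names) of how torch.cuda.Event.record() and Event.wait() order work between the default stream and a side stream:

import torch

assert torch.cuda.is_available()

side_stream = torch.cuda.Stream()   # plays the role of self.sample_stream
start_evt = torch.cuda.Event()      # plays the role of self.start_sample_event
finish_evt = torch.cuda.Event()     # plays the role of self.finish_sample_event

x = torch.randn(1024, 1024, device="cuda")

# Producer work on the default stream.
y = x @ x
start_evt.record()                  # marks "inputs for the side stream are ready"

with torch.cuda.stream(side_stream):
    start_evt.wait()                # side stream waits for the producer work above
    z = y.relu().sum()              # can overlap with later default-stream work
    finish_evt.record()             # marks "side-stream work is done"

# The default stream keeps enqueueing work without waiting for z ...
w = x + 1.0

finish_evt.wait()                   # ... until the result is actually needed
torch.cuda.synchronize()
print(z.item(), tuple(w.shape))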
@@ -1068,8 +1071,25 @@ class PyExecutor:
                     guided_decoder_failed_requests = self.guided_decoder.execute(
                         batch_outputs['logits'])
 
-                sample_state = self._sample_async(
-                    scheduled_batch, batch_outputs)
+                if os.environ.get("TRTLLM_PP_MULTI_STREAM_SAMPLE",
+                                  "1") == "1":
+                    # Wait for the previous sample to finish.
+                    self.finish_sample_event.wait()
+                    # Copy the batch outputs as sampler inputs
+                    # to avoid next forward step overwriting them.
+                    batch_outputs_copy = {
+                        name: tensor.clone()
+                        for name, tensor in batch_outputs.items()
+                    }
+                    self.start_sample_event.record()
+                    with torch.cuda.stream(self.sample_stream):
+                        self.start_sample_event.wait()
+                        sample_state = self._sample_async(
+                            scheduled_batch, batch_outputs_copy)
+                        self.finish_sample_event.record()
+                else:
+                    sample_state = self._sample_async(
+                        scheduled_batch, batch_outputs)
                 assert sample_state is not None, "Sampling failed"
 
                 # Handle guided decoder errors after _sample_async to avoid state conflicts.
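Putting the two hunks together: when TRTLLM_PP_MULTI_STREAM_SAMPLE is enabled (default "1"), the batch outputs are cloned, sampling for the current micro batch is enqueued on self.sample_stream, and the main stream is immediately free to start the next forward step. Below is a minimal sketch of that overlap pattern; forward() and sample() are placeholders standing in for the model forward and _sample_async, and the whole snippet is an illustration, not the repository's code:

import torch

assert torch.cuda.is_available()

sample_stream = torch.cuda.Stream()
start_sample_event = torch.cuda.Event()
finish_sample_event = torch.cuda.Event()

weight = torch.randn(2048, 2048, device="cuda")

def forward(step: int) -> dict:
    # Placeholder for the model forward that produces batch_outputs.
    hidden = torch.randn(64, 2048, device="cuda")
    return {"logits": hidden @ weight}

def sample(outputs: dict) -> torch.Tensor:
    # Placeholder for _sample_async; reads only the cloned outputs.
    return outputs["logits"].argmax(dim=-1)

tokens = None
for step in range(4):
    batch_outputs = forward(step)              # default stream

    # Wait for the previous sample to finish before cloning new inputs.
    # Waiting on a never-recorded event (first iteration) is a no-op.
    finish_sample_event.wait()
    batch_outputs_copy = {
        name: tensor.clone() for name, tensor in batch_outputs.items()
    }
    start_sample_event.record()                # copies are ready

    with torch.cuda.stream(sample_stream):
        start_sample_event.wait()
        tokens = sample(batch_outputs_copy)    # overlaps with the next forward
        finish_sample_event.record()

    # No wait on finish_sample_event here: the default stream proceeds
    # straight into the next iteration's forward. The wait-before-clone
    # above is what keeps reuse of the previous copies safe.

torch.cuda.synchronize()
print(tokens.shape)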