[https://nvbugs/5622938][feat] Run sample_async on extra stream. (#10215)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Author: Yuxian Qiu (committed by GitHub)
Date:   2026-01-09 18:15:18 +08:00
Parent: 78bb245554
Commit: 80f261ea36

@@ -250,6 +250,9 @@ class PyExecutor:
         self.send_schedule_handler = None
         self.pp_scheduler_max_retry_count = int(
             os.environ.get("TLLM_PP_SCHEDULER_MAX_RETRY_COUNT", 10))
+        self.sample_stream = torch.cuda.Stream()
+        self.start_sample_event = torch.cuda.Event()
+        self.finish_sample_event = torch.cuda.Event()
         # Set of request IDs that are currently in flight across all micro batches.
         # The scheduler will avoid scheduling requests that are already in flight.
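
The three new attributes form a handshake between the default stream and a dedicated sampling stream. As a minimal standalone sketch of the torch.cuda.Event record/wait semantics the next hunk relies on (toy tensors, not the executor's real state; requires a CUDA device):

import torch

side_stream = torch.cuda.Stream()
ready = torch.cuda.Event()

x = torch.randn(1024, device="cuda")
y = x * 2              # producer work enqueued on the default stream
ready.record()         # mark the point on the default stream where y is valid

with torch.cuda.stream(side_stream):
    ready.wait()       # side_stream blocks until the mark is reached
    z = y.sum()        # ordered after the producer kernel, so safe to read

torch.cuda.synchronize()
print(z.item())

Waiting on an event that has never been recorded is a no-op, which is why the executor below can wait on finish_sample_event on the very first iteration, before anything has been recorded.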
@@ -1068,8 +1071,25 @@ class PyExecutor:
                     guided_decoder_failed_requests = self.guided_decoder.execute(
                         batch_outputs['logits'])
-                sample_state = self._sample_async(
-                    scheduled_batch, batch_outputs)
+                if os.environ.get("TRTLLM_PP_MULTI_STREAM_SAMPLE",
+                                  "1") == "1":
+                    # Wait for the previous sampling pass to finish.
+                    self.finish_sample_event.wait()
+                    # Copy the batch outputs used as sampler inputs so the
+                    # next forward step cannot overwrite them.
+                    batch_outputs_copy = {
+                        name: tensor.clone()
+                        for name, tensor in batch_outputs.items()
+                    }
+                    self.start_sample_event.record()
+                    with torch.cuda.stream(self.sample_stream):
+                        self.start_sample_event.wait()
+                        sample_state = self._sample_async(
+                            scheduled_batch, batch_outputs_copy)
+                        self.finish_sample_event.record()
+                else:
+                    sample_state = self._sample_async(
+                        scheduled_batch, batch_outputs)
                 assert sample_state is not None, "Sampling failed"
                 # Handle guided decoder errors after _sample_async to avoid state conflicts.
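
Putting the pieces together, the control flow overlaps sampling with the next forward step: the default stream clones the outputs, the side stream samples from the clones, and the finish event keeps successive iterations from racing. A standalone sketch of that pattern, where forward and sample are hypothetical stand-ins for the executor's model step and _sample_async:

import torch

def forward(step):
    # Hypothetical stand-in for the model's forward pass (default stream).
    return {"logits": torch.randn(8, 128, device="cuda") + step}

def sample(outputs):
    # Hypothetical stand-in for _sample_async.
    return outputs["logits"].argmax(dim=-1)

sample_stream = torch.cuda.Stream()
start_sample = torch.cuda.Event()
finish_sample = torch.cuda.Event()

for step in range(4):
    batch_outputs = forward(step)
    # Block the default stream until the previous sampling pass is done
    # with its cloned inputs (a no-op on the first iteration).
    finish_sample.wait()
    # Clone so the next forward step cannot overwrite the sampler's inputs.
    outputs_copy = {k: v.clone() for k, v in batch_outputs.items()}
    start_sample.record()                  # clones are ready at this point
    with torch.cuda.stream(sample_stream):
        start_sample.wait()                # order sampling after the clones
        tokens = sample(outputs_copy)
        finish_sample.record()             # releases the next iteration's wait

torch.cuda.synchronize()
print(tokens.shape)

The clone is the price of the overlap: without it, a forward pass launched while sampling is still running on sample_stream could overwrite batch_outputs in place. The TRTLLM_PP_MULTI_STREAM_SAMPLE environment variable (default "1") gates the whole path, so the previous single-stream behavior remains one export away.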