[https://nvbugs/5829097][fix] Re-init TRTLLM sampler to use sample stream in multi-stream cases. (#10918)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
This commit is contained in:
Yuxian Qiu 2026-01-24 14:04:10 +08:00 committed by GitHub
parent 9d65b8bf24
commit 9fcc93ea7b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -54,7 +54,7 @@ from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
from .model_engine import ModelEngine
from .resource_manager import ResourceManager
from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState,
SampleStateTensors)
SampleStateTensors, TRTLLMSampler)
from .scheduler import (RequestScheduler, ScheduledRequests,
SerializableSchedulerOutput)
@ -371,9 +371,19 @@ class PyExecutor:
self.send_schedule_handler = None
self.pp_scheduler_max_retry_count = int(
os.environ.get("TLLM_PP_SCHEDULER_MAX_RETRY_COUNT", 10))
self.pp_multi_stream_sample = os.environ.get(
"TRTLLM_PP_MULTI_STREAM_SAMPLE", "1") == "1"
self.sample_stream = torch.cuda.Stream()
self.start_sample_event = torch.cuda.Event()
self.finish_sample_event = torch.cuda.Event()
if (self.dist.pp_size > 1 and self.pp_multi_stream_sample
and isinstance(self.sampler, TRTLLMSampler)):
# TRTLLM sampler uses default stream for store and algorithms.
# To enable multi-stream sampling, we need to re-initialize
# the sampler store and algorithms on the sample stream.
with torch.cuda.stream(self.sample_stream):
self.sampler._initialize_store()
self.sampler._instantiate_algorithms()
# Set of request IDs that are currently in flight across all micro batches.
# The scheduler will avoid scheduling requests that are already in flight.
@ -1216,8 +1226,7 @@ class PyExecutor:
guided_decoder_failed_requests = self.guided_decoder.execute(
batch_outputs['logits'])
if os.environ.get("TRTLLM_PP_MULTI_STREAM_SAMPLE",
"1") == "1":
if self.pp_multi_stream_sample:
# Wait for the previous sample to finish.
self.finish_sample_event.wait()
# Copy the batch outputs as sampler inputs