From 9fcc93ea7bbec1b921e1dd2f6681c5bd36c2158f Mon Sep 17 00:00:00 2001
From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Date: Sat, 24 Jan 2026 14:04:10 +0800
Subject: [PATCH] [https://nvbugs/5829097][fix] Re-init TRTLLM sampler to use
 sample stream in multi-stream cases. (#10918)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/py_executor.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index eaf6e05098..39f62285e1 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -54,7 +54,7 @@ from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
 from .model_engine import ModelEngine
 from .resource_manager import ResourceManager
 from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState,
-                      SampleStateTensors)
+                      SampleStateTensors, TRTLLMSampler)
 from .scheduler import (RequestScheduler, ScheduledRequests,
                         SerializableSchedulerOutput)
 
@@ -371,9 +371,19 @@ class PyExecutor:
         self.send_schedule_handler = None
         self.pp_scheduler_max_retry_count = int(
             os.environ.get("TLLM_PP_SCHEDULER_MAX_RETRY_COUNT", 10))
+        self.pp_multi_stream_sample = os.environ.get(
+            "TRTLLM_PP_MULTI_STREAM_SAMPLE", "1") == "1"
         self.sample_stream = torch.cuda.Stream()
         self.start_sample_event = torch.cuda.Event()
         self.finish_sample_event = torch.cuda.Event()
+        if (self.dist.pp_size > 1 and self.pp_multi_stream_sample
+                and isinstance(self.sampler, TRTLLMSampler)):
+            # TRTLLM sampler uses default stream for store and algorithms.
+            # To enable multi-stream sampling, we need to re-initialize
+            # the sampler store and algorithms on the sample stream.
+            with torch.cuda.stream(self.sample_stream):
+                self.sampler._initialize_store()
+                self.sampler._instantiate_algorithms()
 
         # Set of request IDs that are currently in flight across all micro batches.
         # The scheduler will avoid scheduling requests that are already in flight.
@@ -1216,8 +1226,7 @@ class PyExecutor:
                     guided_decoder_failed_requests = self.guided_decoder.execute(
                         batch_outputs['logits'])
 
-                if os.environ.get("TRTLLM_PP_MULTI_STREAM_SAMPLE",
-                                  "1") == "1":
+                if self.pp_multi_stream_sample:
                     # Wait for the previous sample to finish.
                     self.finish_sample_event.wait()
                     # Copy the batch outputs as sampler inputs
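
Background note (not part of the commit): as the added comment in the patch says, work launched while a non-default stream is current runs on that stream, so a sampler whose buffers and algorithms were set up on the default stream must be re-initialized under the sample stream before its work can overlap with the main loop. The standalone sketch below illustrates the event-ordered side-stream pattern that the executor's sample_stream / start_sample_event / finish_sample_event rely on. The helper name overlapped_sample and the argmax stand-in for real sampling are illustrative assumptions; only torch.cuda.Stream, torch.cuda.Event, wait_event, record, and synchronize are actual PyTorch APIs.

    import torch

    def overlapped_sample(logits: torch.Tensor,
                          sample_stream: torch.cuda.Stream,
                          ready: torch.cuda.Event,
                          done: torch.cuda.Event) -> torch.Tensor:
        # The default stream produced `logits`; record the point the
        # side stream must wait for before reading it.
        ready.record()
        with torch.cuda.stream(sample_stream):
            # Order the side stream after the producer so `logits` is
            # fully written before sampling starts.
            sample_stream.wait_event(ready)
            tokens = torch.argmax(logits, dim=-1)  # stand-in for real sampling work
            # Record completion so the caller (or the next iteration)
            # can wait on it instead of synchronizing the whole device.
            done.record()
        return tokens

    if torch.cuda.is_available():
        stream = torch.cuda.Stream()
        ready, done = torch.cuda.Event(), torch.cuda.Event()
        logits = torch.randn(4, 32000, device="cuda")
        tokens = overlapped_sample(logits, stream, ready, done)
        done.synchronize()  # host-side wait before consuming `tokens`
        print(tokens.shape)

In this sketch the producer/consumer ordering is carried entirely by the two events, which mirrors why the patch gates the side-stream path on pp_multi_stream_sample and re-runs the sampler's initialization inside torch.cuda.stream(self.sample_stream).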