Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[https://nvbugs/5829097][fix] Re-init TRTLLM sampler to use sample stream in multi-stream cases. (#10918)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
parent 9d65b8bf24
commit 9fcc93ea7b
@@ -54,7 +54,7 @@ from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
 from .model_engine import ModelEngine
 from .resource_manager import ResourceManager
 from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState,
-                      SampleStateTensors)
+                      SampleStateTensors, TRTLLMSampler)
 from .scheduler import (RequestScheduler, ScheduledRequests,
                         SerializableSchedulerOutput)

@@ -371,9 +371,19 @@ class PyExecutor:
         self.send_schedule_handler = None
         self.pp_scheduler_max_retry_count = int(
             os.environ.get("TLLM_PP_SCHEDULER_MAX_RETRY_COUNT", 10))
+        self.pp_multi_stream_sample = os.environ.get(
+            "TRTLLM_PP_MULTI_STREAM_SAMPLE", "1") == "1"
         self.sample_stream = torch.cuda.Stream()
         self.start_sample_event = torch.cuda.Event()
         self.finish_sample_event = torch.cuda.Event()
+        if (self.dist.pp_size > 1 and self.pp_multi_stream_sample
+                and isinstance(self.sampler, TRTLLMSampler)):
+            # TRTLLM sampler uses default stream for store and algorithms.
+            # To enable multi-stream sampling, we need to re-initialize
+            # the sampler store and algorithms on the sample stream.
+            with torch.cuda.stream(self.sample_stream):
+                self.sampler._initialize_store()
+                self.sampler._instantiate_algorithms()

         # Set of request IDs that are currently in flight across all micro batches.
         # The scheduler will avoid scheduling requests that are already in flight.
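The re-initialization above relies on PyTorch's current-stream semantics: CUDA work issued while a torch.cuda.stream(...) context is active is enqueued on that stream, so running _initialize_store() and _instantiate_algorithms() inside the context moves their allocations and setup kernels off the default stream. Below is a minimal sketch of that pattern only; SamplerStore and its sizes are hypothetical stand-ins, not TensorRT-LLM code.

# Minimal sketch, not TensorRT-LLM code: SamplerStore and its sizes are
# hypothetical; only the stream-scoping pattern mirrors the diff above.
import torch


class SamplerStore:
    """Hypothetical store that allocates buffers and warms up kernels at init."""

    def __init__(self, vocab_size=32000, max_tokens=64):
        # Allocations and kernels below target whichever stream is current
        # when __init__ executes.
        self.logits_buf = torch.empty(max_tokens, vocab_size, device="cuda")
        self.probs_buf = torch.softmax(self.logits_buf, dim=-1)  # warm-up kernel


if torch.cuda.is_available():
    sample_stream = torch.cuda.Stream()
    # Re-initializing under the side stream makes the store's CUDA work
    # target sample_stream instead of the default stream.
    with torch.cuda.stream(sample_stream):
        store = SamplerStore()
    sample_stream.synchronize()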
@@ -1216,8 +1226,7 @@ class PyExecutor:
                 guided_decoder_failed_requests = self.guided_decoder.execute(
                     batch_outputs['logits'])

-                if os.environ.get("TRTLLM_PP_MULTI_STREAM_SAMPLE",
-                                  "1") == "1":
+                if self.pp_multi_stream_sample:
                     # Wait for the previous sample to finish.
                     self.finish_sample_event.wait()
                     # Copy the batch outputs as sampler inputs
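For context, the start_sample_event / finish_sample_event pair created in __init__ forms a handshake between the default stream, which produces the batch outputs, and the sample stream. The sketch below shows that overlap pattern in isolation; it assumes a CUDA device, the shapes are illustrative, and argmax is a placeholder for the real sampler, so only the event choreography follows the code above.

# Minimal sketch of the event handshake; shapes and the argmax "sampling"
# are illustrative, only the event ordering mirrors the diff above.
import torch

sample_stream = torch.cuda.Stream()
start_sample_event = torch.cuda.Event()
finish_sample_event = torch.cuda.Event()
logits_snapshot = None

for step in range(3):
    # Stand-in for the model forward pass on the default stream.
    logits = torch.randn(8, 32000, device="cuda")

    # Wait for the previous sample to finish before overwriting its input.
    # (Waiting on a never-recorded event is a no-op on the first iteration.)
    finish_sample_event.wait()
    logits_snapshot = logits.clone()   # copy the batch outputs as sampler inputs
    start_sample_event.record()        # snapshot is ready for the sample stream

    with torch.cuda.stream(sample_stream):
        start_sample_event.wait()      # sample stream waits for the snapshot
        tokens = torch.argmax(logits_snapshot, dim=-1)  # placeholder sampling
        finish_sample_event.record()

torch.cuda.synchronize()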