diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 08b32c15d1..c93a849ff6 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from dataclasses import dataclass
 
 import torch
@@ -561,7 +562,7 @@ class TRTLLMSampler(Sampler):
             self.model_config, self.world_config, self.decoding_config,
             requests, self.store["buffer_manager"], self.logits_datatype,
             self.store["decoder_input_buffers"],
-            self.algs.decoder.decoder_state, self.beam_width,
+            self.algs.decoder.decoder_state, self.beam_width(requests),
             self.store["cuda_stream"])
 
         if len(decoder_requests):
@@ -578,15 +579,15 @@ class TRTLLMSampler(Sampler):
                 decoder_requests)
 
     @staticmethod
-    def beam_width(scheduled_requests: ScheduledRequests) -> int:
-        for req in scheduled_requests.all_requests:
+    def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int:
+        for req in scheduled_requests:
             return req.sampling_config.beam_width
-        raise ValueError("No beam width found")
+        return 0
 
     def sample_async(self, scheduled_requests: ScheduledRequests,
                      model_outputs) -> SampleStateTRTLLM:
         batch_size = scheduled_requests.batch_size
-        beam_width = self.beam_width(scheduled_requests)
+        beam_width = self.beam_width(scheduled_requests.all_requests)
 
         logits = model_outputs["logits"].reshape((batch_size, beam_width, -1))
 
@@ -659,7 +660,7 @@ class TRTLLMSampler(Sampler):
         scheduled_requests = state.scheduled_requests
         assert scheduled_requests.batch_size > 0
 
-        beam_width = self.beam_width(scheduled_requests)
+        beam_width = self.beam_width(scheduled_requests.all_requests)
 
         sampler_event = state.sampler_event
        if sampler_event:
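
Note: the change above generalizes beam_width from taking a ScheduledRequests object to taking any Iterable[LlmRequest], so call sites can pass either scheduled_requests.all_requests or a raw requests list, and an empty iterable now yields 0 instead of raising ValueError. A minimal, self-contained sketch of the new semantics follows; the SamplingConfig and LlmRequest dataclasses here are simplified stand-ins for the real tensorrt_llm classes, used only for illustration.

    # Sketch only: SamplingConfig/LlmRequest below are assumed stand-ins,
    # not the actual tensorrt_llm definitions.
    from collections.abc import Iterable
    from dataclasses import dataclass
    
    
    @dataclass
    class SamplingConfig:
        beam_width: int
    
    
    @dataclass
    class LlmRequest:
        sampling_config: SamplingConfig
    
    
    def beam_width(requests: Iterable[LlmRequest]) -> int:
        # Return the beam width of the first request; an empty iterable
        # now yields 0 instead of raising.
        for req in requests:
            return req.sampling_config.beam_width
        return 0
    
    
    assert beam_width([LlmRequest(SamplingConfig(4))]) == 4
    assert beam_width([]) == 0  # previously: ValueError("No beam width found")

As in the patched code, the helper inspects only the first request, which assumes all requests in a batch share one beam width.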