From cc3f8e643149e47ea25f9419215c96b0b4b68d16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20C=C3=A1mpora?= <961215+dcampora@users.noreply.github.com> Date: Wed, 21 May 2025 08:21:39 +0200 Subject: [PATCH] fix: Fix trtllm sampler beam width bug (#4507) * Fix TRTLLMSampler. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> * Added type hint. Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --------- Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/sampler.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 08b32c15d1..c93a849ff6 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from collections.abc import Iterable from dataclasses import dataclass import torch @@ -561,7 +562,7 @@ class TRTLLMSampler(Sampler): self.model_config, self.world_config, self.decoding_config, requests, self.store["buffer_manager"], self.logits_datatype, self.store["decoder_input_buffers"], - self.algs.decoder.decoder_state, self.beam_width, + self.algs.decoder.decoder_state, self.beam_width(requests), self.store["cuda_stream"]) if len(decoder_requests): @@ -578,15 +579,15 @@ class TRTLLMSampler(Sampler): decoder_requests) @staticmethod - def beam_width(scheduled_requests: ScheduledRequests) -> int: - for req in scheduled_requests.all_requests: + def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int: + for req in scheduled_requests: return req.sampling_config.beam_width - raise ValueError("No beam width found") + return 0 def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs) -> SampleStateTRTLLM: batch_size = scheduled_requests.batch_size - beam_width = self.beam_width(scheduled_requests) + beam_width = self.beam_width(scheduled_requests.all_requests) logits = model_outputs["logits"].reshape((batch_size, beam_width, -1)) @@ -659,7 +660,7 @@ class TRTLLMSampler(Sampler): scheduled_requests = state.scheduled_requests assert scheduled_requests.batch_size > 0 - beam_width = self.beam_width(scheduled_requests) + beam_width = self.beam_width(scheduled_requests.all_requests) sampler_event = state.sampler_event if sampler_event: