fix: Fix trtllm sampler beam width bug (#4507)

* Fix TRTLLMSampler: pass the computed beam width instead of the bound method.

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>

* Added type hint.

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>

---------

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Author: Daniel Cámpora
Date: 2025-05-21 08:21:39 +02:00
Committed by: GitHub
Parent: ff0f37bcf8
Commit: cc3f8e6431

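The bug in brief: self.beam_width is a method, but it was being passed to the decoder setup as-is, so a bound-method object arrived where an integer beam width was expected. The fix calls the helper with the batch's requests, widens its parameter type from ScheduledRequests to any iterable of LlmRequest, and returns 0 for an empty batch instead of raising. A minimal sketch of the failure mode, using simplified stand-in names rather than the real TRT-LLM classes:

    class SamplerSketch:
        @staticmethod
        def beam_width(requests) -> int:
            # First request determines the beam width; empty batch -> 0.
            for req in requests:
                return req.sampling_config.beam_width
            return 0

        def setup_decoder(self, requests):
            broken = self.beam_width           # a bound method, not an int
            fixed = self.beam_width(requests)  # the actual beam width
            assert callable(broken) and isinstance(fixed, int)
            return fixed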

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from dataclasses import dataclass
 import torch
@@ -561,7 +562,7 @@ class TRTLLMSampler(Sampler):
             self.model_config, self.world_config, self.decoding_config,
             requests, self.store["buffer_manager"], self.logits_datatype,
             self.store["decoder_input_buffers"],
-            self.algs.decoder.decoder_state, self.beam_width,
+            self.algs.decoder.decoder_state, self.beam_width(requests),
             self.store["cuda_stream"])

         if len(decoder_requests):
@@ -578,15 +579,15 @@
             decoder_requests)

     @staticmethod
-    def beam_width(scheduled_requests: ScheduledRequests) -> int:
-        for req in scheduled_requests.all_requests:
+    def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int:
+        for req in scheduled_requests:
             return req.sampling_config.beam_width
-        raise ValueError("No beam width found")
+        return 0

     def sample_async(self, scheduled_requests: ScheduledRequests,
                      model_outputs) -> SampleStateTRTLLM:
         batch_size = scheduled_requests.batch_size
-        beam_width = self.beam_width(scheduled_requests)
+        beam_width = self.beam_width(scheduled_requests.all_requests)

         logits = model_outputs["logits"].reshape((batch_size, beam_width, -1))
@@ -659,7 +660,7 @@ class TRTLLMSampler(Sampler):
         scheduled_requests = state.scheduled_requests

         assert scheduled_requests.batch_size > 0
-        beam_width = self.beam_width(scheduled_requests)
+        beam_width = self.beam_width(scheduled_requests.all_requests)
         sampler_event = state.sampler_event

         if sampler_event:
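
For reference, a runnable sketch of how the recovered beam width drives the logits reshape in sample_async. Request and SamplingConfig below are hypothetical stand-ins for the TRT-LLM request types, not the real API:

    from dataclasses import dataclass, field
    from typing import Iterable
    import torch

    @dataclass
    class SamplingConfig:
        beam_width: int = 1

    @dataclass
    class Request:
        sampling_config: SamplingConfig = field(default_factory=SamplingConfig)

    def beam_width(requests: Iterable[Request]) -> int:
        # Mirrors the fixed helper: first request wins; empty batch -> 0.
        for req in requests:
            return req.sampling_config.beam_width
        return 0

    batch = [Request(SamplingConfig(beam_width=4)) for _ in range(2)]
    bw = beam_width(batch)                      # 4
    vocab = 32000
    flat = torch.zeros(len(batch) * bw, vocab)  # one row per (request, beam)
    logits = flat.reshape((len(batch), bw, -1)) # (batch_size, beam_width, vocab)
    assert logits.shape == (2, 4, vocab)
    assert beam_width([]) == 0                  # empty batch no longer raises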