Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
fix: Fix trtllm sampler beam width bug (#4507)
* Fix TRTLLMSampler.

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>

* Added type hint.

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>

---------

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
parent ff0f37bcf8
commit cc3f8e6431
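The root cause is visible in the second hunk below: `self.beam_width`, the method object itself, was passed to the decoder batch setup where an integer beam width was expected. A minimal sketch of that bug pattern, using stand-in names (`Request`, `make_batch`) rather than the real TensorRT-LLM API:

```python
from dataclasses import dataclass


@dataclass
class Request:
    beam_width: int


class Sampler:
    @staticmethod
    def beam_width(requests) -> int:
        # Beam width of the first request; 0 when the batch is empty.
        for req in requests:
            return req.beam_width
        return 0

    def make_batch(self, requests):
        buggy = self.beam_width            # the method object, not an int
        fixed = self.beam_width(requests)  # the intended integer
        return buggy, fixed


sampler = Sampler()
buggy, fixed = sampler.make_batch([Request(beam_width=4)])
print(callable(buggy), fixed)  # -> True 4
```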
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from dataclasses import dataclass
 
 import torch
@@ -561,7 +562,7 @@ class TRTLLMSampler(Sampler):
             self.model_config, self.world_config, self.decoding_config,
             requests, self.store["buffer_manager"], self.logits_datatype,
             self.store["decoder_input_buffers"],
-            self.algs.decoder.decoder_state, self.beam_width,
+            self.algs.decoder.decoder_state, self.beam_width(requests),
             self.store["cuda_stream"])
 
         if len(decoder_requests):
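Passing the uncalled method compiles and runs fine until something actually treats the value as an integer, which makes this class of bug easy to miss. An illustrative repro with a stand-in callable (not the real call site):

```python
import torch

beam_width = lambda requests: 4  # stand-in for the uncalled method object
logits = torch.zeros(8)
try:
    logits.reshape((2, beam_width, -1))  # fails: shape entries must be ints
except TypeError as exc:
    print(exc)
print(logits.reshape((2, beam_width([]), -1)).shape)  # torch.Size([2, 4, 1])
```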
@@ -578,15 +579,15 @@ class TRTLLMSampler(Sampler):
                              decoder_requests)
 
     @staticmethod
-    def beam_width(scheduled_requests: ScheduledRequests) -> int:
-        for req in scheduled_requests.all_requests:
+    def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int:
+        for req in scheduled_requests:
             return req.sampling_config.beam_width
-        raise ValueError("No beam width found")
+        return 0
 
     def sample_async(self, scheduled_requests: ScheduledRequests,
                      model_outputs) -> SampleStateTRTLLM:
         batch_size = scheduled_requests.batch_size
-        beam_width = self.beam_width(scheduled_requests)
+        beam_width = self.beam_width(scheduled_requests.all_requests)
 
         logits = model_outputs["logits"].reshape((batch_size, beam_width, -1))
 
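With this hunk the helper accepts any `Iterable[LlmRequest]` instead of unwrapping a `ScheduledRequests` container itself (hence the new `Iterable` import), so call sites now pass `scheduled_requests.all_requests` explicitly, and an empty batch yields 0 rather than raising. A standalone sketch of those semantics, with `LlmRequest` and `SamplingConfig` mocked here (the real classes live in TensorRT-LLM):

```python
from collections.abc import Iterable
from dataclasses import dataclass


@dataclass
class SamplingConfig:
    beam_width: int


@dataclass
class LlmRequest:
    sampling_config: SamplingConfig


def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int:
    # First request wins: the batch is assumed to share one beam width.
    for req in scheduled_requests:
        return req.sampling_config.beam_width
    return 0  # empty batch: no longer raises ValueError


print(beam_width([LlmRequest(SamplingConfig(beam_width=2))]))  # -> 2
print(beam_width([]))                                          # -> 0
```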
@@ -659,7 +660,7 @@ class TRTLLMSampler(Sampler):
 
         scheduled_requests = state.scheduled_requests
         assert scheduled_requests.batch_size > 0
-        beam_width = self.beam_width(scheduled_requests)
+        beam_width = self.beam_width(scheduled_requests.all_requests)
         sampler_event = state.sampler_event
 
         if sampler_event: