fix: reshape token_ids for lp in torch backend (#4239)

reshape token_ids

Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
This commit is contained in:
Erin 2025-05-12 17:43:47 -07:00 committed by GitHub
parent 035d915fea
commit 4becf32360
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2035,6 +2035,7 @@ class PyTorchModelEngine(ModelEngine):
# Reshape to align w/ the shape used in the TRT backend,
# so the same logit processors can be used across both backends.
logits_row = logits_row.view(1, 1, -1)
token_ids = [token_ids]
for lp in logits_processors:
lp_params = inspect.signature(lp).parameters