[TRTLLM-10030][chore] improve assert in sampler (#11475)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
This commit is contained in:
mpikulski 2026-02-13 14:54:28 +01:00 committed by GitHub
parent b67dcd8fef
commit 37c53425c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -596,10 +596,10 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):
beam_width_in = group_key[1]
else:
beam_width_in = 1
if group_logit_indices is None:
assert logits.size(0) == beam_width_in * len(strategies)
else:
if group_logit_indices is not None:
logits = logits[group_logit_indices]
assert logits.size(0) == beam_width_in * len(strategies)
assert all(strategy == group_key for strategy in strategies), "group must be consistent"