mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-16 15:55:08 +08:00
[TRTLLM-10030][chore] improve assert in sampler (#11475)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
This commit is contained in:
parent
b67dcd8fef
commit
37c53425c1
@ -596,10 +596,10 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):
|
||||
beam_width_in = group_key[1]
|
||||
else:
|
||||
beam_width_in = 1
|
||||
if group_logit_indices is None:
|
||||
assert logits.size(0) == beam_width_in * len(strategies)
|
||||
else:
|
||||
|
||||
if group_logit_indices is not None:
|
||||
logits = logits[group_logit_indices]
|
||||
assert logits.size(0) == beam_width_in * len(strategies)
|
||||
|
||||
assert all(strategy == group_key for strategy in strategies), "group must be consistent"
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user