diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py index 66f04c3bef..0660ee7972 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py @@ -596,10 +596,10 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]): beam_width_in = group_key[1] else: beam_width_in = 1 - if group_logit_indices is None: - assert logits.size(0) == beam_width_in * len(strategies) - else: + + if group_logit_indices is not None: logits = logits[group_logit_indices] + assert logits.size(0) == beam_width_in * len(strategies) assert all(strategy == group_key for strategy in strategies), "group must be consistent"