[TRTLLM-10030][chore] refactor finish reasons tests (#11445)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
mpikulski 2026-02-12 08:32:50 +01:00 committed by GitHub
parent 3c1323442b
commit d0f3c412ff


@@ -373,7 +373,8 @@ class UutProvider(Protocol):
def _run_test_with_warmup(
uut_provider: UutProvider,
warmup_sizes_bytes: tuple[int] = (4 * 2**30,),
max_sync_s: Optional[float] = None,
*,
max_sync_s: Optional[float],
):
"""Run UUT including setup and warmup.
@@ -635,231 +636,259 @@ def test_select_generated_logits(draft_len: int, with_ctx: bool, with_gen: bool)
_run_test_with_warmup(_test_runner, max_sync_s=0.3)
MAX_NUM_SEQUENCES = 128
NOT_FINISHED = FinishReason.NOT_FINISHED
STOP_WORDS = FinishReason.STOP_WORDS
END_ID = FinishReason.END_ID
LENGTH = FinishReason.LENGTH
BEAM = 0


class RequestCase:
    MAX_NEW_TOKENS = 10
    seq_slots = torch.randperm(MAX_NUM_SEQUENCES).tolist()

    def __init__(
        self,
        *,
        prompt: list[int],
        new_tokens: list[int],
        finish_reasons: list[FinishReason],
        max_new_tokens: int = MAX_NEW_TOKENS,
        end_id: Optional[int] = None,
        num_draft_tokens: int | None = None,
        stop_words_list: Optional[list[list[int]]] = None,
    ):
        seq_slot = self.seq_slots.pop()  # random seq slot in MAX_NUM_SEQUENCES
        self.prompt = prompt
        if num_draft_tokens is None:
            num_draft_tokens = len(new_tokens) - 1
        self.request = LlmRequest(
            request_id=seq_slot,
            seq_slot=seq_slot,
            input_tokens=prompt,
            max_new_tokens=max_new_tokens,
            stop_words_list=convert_wordlist(stop_words_list)
            if stop_words_list is not None
            else None,
            end_id=end_id,
            sampling_config=SamplingConfig(),
            is_streaming=False,
            draft_tokens=new_tokens[:num_draft_tokens],
        )
        assert len(new_tokens) == len(finish_reasons)
        self.new_tokens = new_tokens
        self.finish_reasons = finish_reasons

    def __repr__(self):
        return f"RequestCase({self.prompt=}, {self.new_tokens=}, {self.finish_reasons=}, \
            {self.request.max_new_tokens=}, {self.request.end_id=}, {self.request.stop_words_list=})"

    @staticmethod
    def setup(requests: list["RequestCase"]):
        max_tokens = set(len(req.new_tokens) for req in requests)
        assert len(max_tokens) == 1
        max_draft_len = max_tokens.pop() - 1
        sampler_args = TorchSampler.Args(
            max_seq_len=20,
            max_draft_len=max_draft_len,
            max_total_draft_tokens=max_draft_len,
            # Fill with many more max requests than below,
            # so we can test that write_finish_reasons uses seq_slots correctly
            max_num_sequences=MAX_NUM_SEQUENCES,
            max_beam_width=1,
            disable_overlap_scheduler=False,
        )
        sampler = TorchSampler(args=sampler_args)
        # fill with garbage value so we can observe that finish reasons are filled
        # with NOT_FINISHED before we write to them.
        sampler.store.finish_reasons.fill_(205)
        seq_slots = torch.tensor(
            [req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64
        )
        seq_lens = torch.tensor(
            [req.request.max_beam_num_tokens for req in requests], dtype=torch.int32, device="cuda"
        )
        new_tokens = torch.tensor(
            [req.new_tokens for req in requests], dtype=torch.int32, device="cuda"
        ).T
        sampler.store.new_tokens[:, seq_slots, BEAM] = new_tokens
        max_seq_lens = torch.tensor(
            [
                min(
                    sampler.max_seq_len, req.request.orig_prompt_len + req.request.py_max_new_tokens
                )
                for req in requests
            ],
            dtype=torch.int32,
            device="cuda",
        )
        end_ids = torch.tensor(
            [
                req.request.py_end_id if req.request.py_end_id is not None else -1
                for req in requests
            ],
            dtype=torch.int32,
            device="cuda",
        )
        sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens
        sampler.store.end_ids[seq_slots] = end_ids

        def run():
            sampler._write_finish_reasons(
                [req.request for req in requests],
                finish_reasons=sampler.store.finish_reasons,
                new_tokens=sampler.store.new_tokens,
                seq_lens=seq_lens,
                seq_slots=seq_slots,
            )

            reasons = sampler.store.finish_reasons[:, seq_slots, BEAM].T.tolist()

            for actual, request in zip(reasons, requests, strict=True):
                expected = request.finish_reasons
                msg = f"actual={[FinishReason(reason) for reason in actual]} != expected={expected}\nFor {request}"
                assert actual == [reason.value for reason in expected], msg

        return run, sampler


def test_write_finish_reasons():
    """We don't really care about the finish reason past the first infraction, because we're not going to use it,
    although in some instance it is written anyway."""
    run, _ = RequestCase.setup(
        [
            RequestCase(
                prompt=[13, 14],
                new_tokens=[60, 61, 62],
                # We pre-fill the finish reasons with NOT_FINISHED.
                finish_reasons=[NOT_FINISHED, NOT_FINISHED, NOT_FINISHED],
            ),
            RequestCase(
                prompt=[7, 8, 6],
                stop_words_list=[[12, 13]],
                new_tokens=[12, 13, 60],
                finish_reasons=[NOT_FINISHED, STOP_WORDS, NOT_FINISHED],
            ),
            RequestCase(
                prompt=[7, 8, 6],
                stop_words_list=[[12, 13]],
                new_tokens=[60, 12, 13],
                # The request has stop words, but no draft is created
                # Tokens at indices greater than 0 should be ignored
                num_draft_tokens=0,
                finish_reasons=[NOT_FINISHED, NOT_FINISHED, NOT_FINISHED],
            ),
            RequestCase(
                prompt=[1, 2, 3, 4],
                end_id=99,
                new_tokens=[55, 99, 58],
                finish_reasons=[NOT_FINISHED, END_ID, NOT_FINISHED],
            ),
            RequestCase(
                prompt=[4, 5, 6],
                max_new_tokens=2,
                new_tokens=[56, 57, 59],
                # The LENGTH check happens to not have an early exit
                finish_reasons=[NOT_FINISHED, LENGTH, LENGTH],
            ),
            RequestCase(
                prompt=[1, 12],
                stop_words_list=[[12, 13], [14, 15]],
                new_tokens=[13, 14, 15],
                # We don't use early exit to avoid stream synchronization for stop words
                finish_reasons=[STOP_WORDS, NOT_FINISHED, STOP_WORDS],
            ),
            RequestCase(
                prompt=[1, 12],
                stop_words_list=[[12, 13, 14, 15], [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]],
                new_tokens=[13, 14, 15],
                # Stop words of different lengths are handled correctly with respect to padding of stop words and tokens
                finish_reasons=[NOT_FINISHED, NOT_FINISHED, STOP_WORDS],
            ),
            RequestCase(
                prompt=[1],
                max_new_tokens=2,
                end_id=99,
                stop_words_list=[[1, 12]],
                new_tokens=[12, 99, 63],
                # Different infractions are written to different places as
                # we don't have an early exit between infractions
                finish_reasons=[STOP_WORDS, END_ID, LENGTH],
            ),
            RequestCase(
                prompt=[1, 12, 56, 67, 68, 234, 678],
                stop_words_list=[[12, 56, 67, 68, 234, 678, 129, 182]],
                new_tokens=[129, 182, 600],
                # Notice the offending stop sequence is concatenated, as we lookback
                finish_reasons=[NOT_FINISHED, STOP_WORDS, NOT_FINISHED],
            ),
            RequestCase(
                prompt=[1, 12],
                end_id=99,
                max_new_tokens=1,
                stop_words_list=[[1, 12, 99]],
                new_tokens=[99, 100, 101],
                # The latest infraction check overrides the earlier infraction checks,
                # hence the first finish_reason is END_ID
                finish_reasons=[END_ID, LENGTH, LENGTH],
            ),
        ]
    )
    run()


def test_are_stop_words_isnt_called_when_no_stop_words():
    """We don't want to call are_stop_words when there are no stop words because it's expensive"""

    def stop_words_that_raises(*args, **kwargs):
        raise AssertionError

    run_with_stop_words, sampler = RequestCase.setup(
        [
            RequestCase(
                prompt=[1], stop_words_list=[[1]], new_tokens=[4], finish_reasons=[NOT_FINISHED]
            )
        ]
    )
    sampler._are_stop_words = stop_words_that_raises
    with pytest.raises(AssertionError):
        run_with_stop_words()

    run_without_stop_words, sampler = RequestCase.setup(
        [RequestCase(prompt=[1], new_tokens=[4], finish_reasons=[NOT_FINISHED])]
    )
    sampler._are_stop_words = stop_words_that_raises
    _ = run_without_stop_words()


class TestFinishReasons:
    NOT_FINISHED = FinishReason.NOT_FINISHED
    STOP_WORDS = FinishReason.STOP_WORDS
    END_ID = FinishReason.END_ID
    LENGTH = FinishReason.LENGTH

    class RequestCase:
        MAX_NEW_TOKENS = 10
        MAX_NUM_SEQUENCES = 128
        seq_slots = torch.randperm(MAX_NUM_SEQUENCES).tolist()
        BEAM = 0

        def __init__(
            self,
            *,
            prompt: list[int],
            new_tokens: list[int],
            finish_reasons: list[FinishReason],
            max_new_tokens: int = MAX_NEW_TOKENS,
            end_id: Optional[int] = None,
            num_draft_tokens: int | None = None,
            stop_words_list: Optional[list[list[int]]] = None,
        ):
            seq_slot = self.seq_slots.pop()  # random seq slot in MAX_NUM_SEQUENCES
            self.prompt = prompt
            if num_draft_tokens is None:
                num_draft_tokens = len(new_tokens) - 1
            self.request = LlmRequest(
                request_id=seq_slot,
                seq_slot=seq_slot,
                input_tokens=prompt,
                max_new_tokens=max_new_tokens,
                stop_words_list=convert_wordlist(stop_words_list)
                if stop_words_list is not None
                else None,
                end_id=end_id,
                sampling_config=SamplingConfig(),
                is_streaming=False,
                draft_tokens=new_tokens[:num_draft_tokens],
            )
            assert len(new_tokens) == len(finish_reasons)
            self.new_tokens = new_tokens
            self.finish_reasons = finish_reasons

        def __repr__(self):
            return f"RequestCase({self.prompt=}, {self.new_tokens=}, {self.finish_reasons=}, \
                {self.request.max_new_tokens=}, {self.request.end_id=}, {self.request.stop_words_list=})"

        @classmethod
        def build(
            cls,
            requests: list["TestFinishReasons.RequestCase"],
            *,
            check_no_cuda_sync: bool = True,
            extra_context: Callable[[], ContextManager] | None = None,
        ) -> UutProvider:
            @contextmanager
            def _uut_provider(is_warmup: bool) -> Generator[Callable[[], None], None, None]:
                max_tokens = set(len(req.new_tokens) for req in requests)
                assert len(max_tokens) == 1
                max_draft_len = max_tokens.pop() - 1
                sampler_args = TorchSampler.Args(
                    max_seq_len=20,
                    max_draft_len=max_draft_len,
                    max_total_draft_tokens=max_draft_len,
                    # Fill with many more max requests than below,
                    # so we can test that write_finish_reasons uses seq_slots correctly
                    max_num_sequences=cls.MAX_NUM_SEQUENCES,
                    max_beam_width=1,
                    disable_overlap_scheduler=False,
                )
                sampler = TorchSampler(args=sampler_args)
                # fill with garbage value so we can observe that finish reasons are filled
                # with NOT_FINISHED before we write to them.
                sampler.store.finish_reasons.fill_(205)
                seq_slots = torch.tensor(
                    [req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64
                )
                seq_lens = torch.tensor(
                    [req.request.max_beam_num_tokens for req in requests],
                    dtype=torch.int32,
                    device="cuda",
                )
                new_tokens = torch.tensor(
                    [req.new_tokens for req in requests], dtype=torch.int32, device="cuda"
                ).T
                sampler.store.new_tokens[:, seq_slots, cls.BEAM] = new_tokens
                max_seq_lens = torch.tensor(
                    [
                        min(
                            sampler.max_seq_len,
                            req.request.orig_prompt_len + req.request.py_max_new_tokens,
                        )
                        for req in requests
                    ],
                    dtype=torch.int32,
                    device="cuda",
                )
                end_ids = torch.tensor(
                    [
                        req.request.py_end_id if req.request.py_end_id is not None else -1
                        for req in requests
                    ],
                    dtype=torch.int32,
                    device="cuda",
                )
                sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens
                sampler.store.end_ids[seq_slots] = end_ids

                def _uut():
                    with extra_context() if extra_context is not None else nullcontext():
                        sampler._write_finish_reasons(
                            [req.request for req in requests],
                            finish_reasons=sampler.store.finish_reasons,
                            new_tokens=sampler.store.new_tokens,
                            seq_lens=seq_lens,
                            seq_slots=seq_slots,
                        )

                yield _uut

                reasons = sampler.store.finish_reasons[:, seq_slots, cls.BEAM].T.tolist()
                for actual, request in zip(reasons, requests, strict=True):
                    expected = request.finish_reasons
                    msg = f"actual={[FinishReason(reason) for reason in actual]} != expected={expected}\nFor {request}"
                    assert actual == [reason.value for reason in expected], msg

            return _uut_provider

    @classmethod
    def test_write_finish_reasons(cls):
        """We don't really care about the finish reason past the first infraction, because we're not going to use it,
        although in some instance it is written anyway."""
        uut_provider = cls.RequestCase.build(
            [
                cls.RequestCase(
                    prompt=[13, 14],
                    new_tokens=[60, 61, 62],
                    # We pre-fill the finish reasons with NOT_FINISHED.
                    finish_reasons=[cls.NOT_FINISHED, cls.NOT_FINISHED, cls.NOT_FINISHED],
                ),
                cls.RequestCase(
                    prompt=[7, 8, 6],
                    stop_words_list=[[12, 13]],
                    new_tokens=[12, 13, 60],
                    finish_reasons=[cls.NOT_FINISHED, cls.STOP_WORDS, cls.NOT_FINISHED],
                ),
                cls.RequestCase(
                    prompt=[7, 8, 6],
                    stop_words_list=[[12, 13]],
                    new_tokens=[60, 12, 13],
                    # The request has stop words, but no draft is created
                    # Tokens at indices greater than 0 should be ignored
                    num_draft_tokens=0,
                    finish_reasons=[cls.NOT_FINISHED, cls.NOT_FINISHED, cls.NOT_FINISHED],
                ),
                cls.RequestCase(
                    prompt=[1, 2, 3, 4],
                    end_id=99,
                    new_tokens=[55, 99, 58],
                    finish_reasons=[cls.NOT_FINISHED, cls.END_ID, cls.NOT_FINISHED],
                ),
                cls.RequestCase(
                    prompt=[4, 5, 6],
                    max_new_tokens=2,
                    new_tokens=[56, 57, 59],
                    # The LENGTH check happens to not have an early exit
                    finish_reasons=[cls.NOT_FINISHED, cls.LENGTH, cls.LENGTH],
                ),
                cls.RequestCase(
                    prompt=[1, 12],
                    stop_words_list=[[12, 13], [14, 15]],
                    new_tokens=[13, 14, 15],
                    # We don't use early exit to avoid stream synchronization for stop words
                    finish_reasons=[cls.STOP_WORDS, cls.NOT_FINISHED, cls.STOP_WORDS],
                ),
                cls.RequestCase(
                    prompt=[1, 12],
                    stop_words_list=[
                        [12, 13, 14, 15],
                        [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
                    ],
                    new_tokens=[13, 14, 15],
                    # Stop words of different lengths are handled correctly with respect to padding of stop words
                    # and tokens
                    finish_reasons=[cls.NOT_FINISHED, cls.NOT_FINISHED, cls.STOP_WORDS],
                ),
                cls.RequestCase(
                    prompt=[1],
                    max_new_tokens=2,
                    end_id=99,
                    stop_words_list=[[1, 12]],
                    new_tokens=[12, 99, 63],
                    # Different infractions are written to different places as
                    # we don't have an early exit between infractions
                    finish_reasons=[cls.STOP_WORDS, cls.END_ID, cls.LENGTH],
                ),
                cls.RequestCase(
                    prompt=[1, 12, 56, 67, 68, 234, 678],
                    stop_words_list=[[12, 56, 67, 68, 234, 678, 129, 182]],
                    new_tokens=[129, 182, 600],
                    # Notice the offending stop sequence is concatenated, as we lookback
                    finish_reasons=[cls.NOT_FINISHED, cls.STOP_WORDS, cls.NOT_FINISHED],
                ),
                cls.RequestCase(
                    prompt=[1, 12],
                    end_id=99,
                    max_new_tokens=1,
                    stop_words_list=[[1, 12, 99]],
                    new_tokens=[99, 100, 101],
                    # The latest infraction check overrides the earlier infraction checks,
                    # hence the first finish_reason is END_ID
                    finish_reasons=[cls.END_ID, cls.LENGTH, cls.LENGTH],
                ),
            ]
        )

        _run_test_with_warmup(uut_provider, max_sync_s=0.5)

    @classmethod
    def test_are_stop_words_isnt_called_when_no_stop_words(cls, monkeypatch: pytest.MonkeyPatch):
        """We don't want to call are_stop_words when there are no stop words because it's expensive"""

        def stop_words_that_raises(*args, **kwargs):
            raise AssertionError

        @contextmanager
        def raising_stop_words_ctx(expect_raise: bool) -> Generator[None, None, None]:
            with monkeypatch.context() as patch_ctx:
                patch_ctx.setattr(TorchSampler, "_are_stop_words", stop_words_that_raises)
                with pytest.raises(AssertionError) if expect_raise else nullcontext():
                    yield

        uut_provider_with_stop_words = cls.RequestCase.build(
            [
                cls.RequestCase(
                    prompt=[1],
                    stop_words_list=[[1]],
                    new_tokens=[4],
                    finish_reasons=[cls.NOT_FINISHED],
                )
            ],
            extra_context=lambda: raising_stop_words_ctx(True),
        )
        _run_test_with_warmup(uut_provider_with_stop_words, max_sync_s=0.5)

        uut_provider_without_stop_words = cls.RequestCase.build(
            [cls.RequestCase(prompt=[1], new_tokens=[4], finish_reasons=[cls.NOT_FINISHED])],
            extra_context=lambda: raising_stop_words_ctx(False),
        )
        _run_test_with_warmup(uut_provider_without_stop_words, max_sync_s=0.5)
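Aside (hypothetical helper, not from this commit): the test_write_finish_reasons docstring above says the finish reason past the first infraction is not used. A consumer of the per-step finish reasons would read a row roughly like this, so whatever the kernel writes after the first non-NOT_FINISHED step never matters:

    # Sketch only; the FinishReason import path is an assumption (the hunk does not show the file's imports).
    from tensorrt_llm.bindings.executor import FinishReason

    def first_infraction(step_reasons: list[FinishReason]) -> tuple[int, FinishReason]:
        # Returns (number of accepted steps, reason the sequence stopped, if any).
        for step, reason in enumerate(step_reasons):
            if reason != FinishReason.NOT_FINISHED:
                return step + 1, reason  # the token at `step` is the last one kept
        return len(step_reasons), FinishReason.NOT_FINISHED

    # E.g. for the case expecting [STOP_WORDS, END_ID, LENGTH], only (1, STOP_WORDS) is ever consumed.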
class TestBatchedSampling:
@@ -1485,7 +1514,10 @@ class TestBatchedSampling:
logit_offset += steps
_run_test_with_warmup(_uut_provider)
_run_test_with_warmup(
_uut_provider,
max_sync_s=None, # NB: assert_no_cuda_sync called in TestBatchedSampler._sample
)
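Aside (sketch under assumptions, not the helper used by this file): the NB comment above refers to an assert_no_cuda_sync check inside the sampling path. One way to enforce such a check is PyTorch's sync debug mode, which makes implicitly synchronizing CUDA calls raise while it is active:

    from contextlib import contextmanager

    import torch

    @contextmanager
    def no_cuda_sync_guard():
        # Hypothetical stand-in for assert_no_cuda_sync; not necessarily how this test file implements it.
        previous = torch.cuda.get_sync_debug_mode()
        torch.cuda.set_sync_debug_mode("error")  # raise on implicit host/device synchronization
        try:
            yield
        finally:
            torch.cuda.set_sync_debug_mode(previous)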
def _compute_probs(
self,
@@ -2197,7 +2229,10 @@ class TestBatchedSampling:
num_samples=num_samples,
)
_run_test_with_warmup(_uut_provider)
_run_test_with_warmup(
_uut_provider,
max_sync_s=None, # NB: assert_no_cuda_sync called in TestBatchedSampler._sample
)
@staticmethod
def _build_seq_slot_assignments() -> list[tuple[list[int], int, str]]: