Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-16 07:53:55 +08:00
[TRTLLM-10030][chore] refactor finish reasons tests (#11445)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
This commit is contained in:
parent
3c1323442b
commit
d0f3c412ff
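This change reworks the finish-reason tests around the UutProvider pattern that _run_test_with_warmup consumes: a context manager that builds the sampler state, yields the callable under test, and only checks the results after the harness has invoked it. The sketch below only illustrates that shape; make_provider and its contents are illustrative stand-ins, not code from this commit.

from contextlib import contextmanager
from typing import Callable, Generator


@contextmanager
def make_provider(expected: list[int]) -> Generator[Callable[[], None], None, None]:
    # Setup phase: build whatever state the unit under test needs.
    results: list[int] = []

    def _uut() -> None:
        results.extend(expected)  # stand-in for the real unit under test

    # Hand the callable to the harness ...
    yield _uut
    # ... and verify only after the harness has run it (possibly several times).
    assert results[: len(expected)] == expected


# A harness like _run_test_with_warmup enters the provider, calls the yielded
# callable, and exits it, which triggers the checks above.
with make_provider([1, 2, 3]) as run:
    run()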
@@ -373,7 +373,8 @@ class UutProvider(Protocol):
 def _run_test_with_warmup(
     uut_provider: UutProvider,
     warmup_sizes_bytes: tuple[int] = (4 * 2**30,),
-    max_sync_s: Optional[float] = None,
+    *,
+    max_sync_s: Optional[float],
 ):
     """Run UUT including setup and warmup.

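Making max_sync_s keyword-only and removing its default forces every call site to state a synchronization budget explicitly (the hunks below pass 0.3, 0.5, or None). A minimal sketch of the calling convention, using a simplified stand-in signature rather than the real helper:

from typing import Optional


def run_with_warmup_sketch(uut_provider: object, *, max_sync_s: Optional[float]) -> None:
    """Simplified stand-in for _run_test_with_warmup's new signature."""
    # The keyword-only parameter with no default means callers must always
    # spell out their synchronization budget (or pass None to opt out).


run_with_warmup_sketch(object(), max_sync_s=0.3)   # explicit budget
run_with_warmup_sketch(object(), max_sync_s=None)  # explicitly no budget
# run_with_warmup_sketch(object(), 0.3)            # TypeError: must be passed by keyword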
@@ -635,231 +636,259 @@ def test_select_generated_logits(draft_len: int, with_ctx: bool, with_gen: bool)
_run_test_with_warmup(_test_runner, max_sync_s=0.3)

MAX_NUM_SEQUENCES = 128
NOT_FINISHED = FinishReason.NOT_FINISHED
STOP_WORDS = FinishReason.STOP_WORDS
END_ID = FinishReason.END_ID
LENGTH = FinishReason.LENGTH
BEAM = 0
class TestFinishReasons:
NOT_FINISHED = FinishReason.NOT_FINISHED
STOP_WORDS = FinishReason.STOP_WORDS
END_ID = FinishReason.END_ID
LENGTH = FinishReason.LENGTH

class RequestCase:
MAX_NEW_TOKENS = 10
MAX_NUM_SEQUENCES = 128
seq_slots = torch.randperm(MAX_NUM_SEQUENCES).tolist()
BEAM = 0

class RequestCase:
MAX_NEW_TOKENS = 10
seq_slots = torch.randperm(MAX_NUM_SEQUENCES).tolist()
def __init__(
self,
*,
prompt: list[int],
new_tokens: list[int],
finish_reasons: list[FinishReason],
max_new_tokens: int = MAX_NEW_TOKENS,
end_id: Optional[int] = None,
num_draft_tokens: int | None = None,
stop_words_list: Optional[list[list[int]]] = None,
):
seq_slot = self.seq_slots.pop()  # random seq slot in MAX_NUM_SEQUENCES
self.prompt = prompt
if num_draft_tokens is None:
num_draft_tokens = len(new_tokens) - 1
self.request = LlmRequest(
request_id=seq_slot,
seq_slot=seq_slot,
input_tokens=prompt,
max_new_tokens=max_new_tokens,
stop_words_list=convert_wordlist(stop_words_list)
if stop_words_list is not None
else None,
end_id=end_id,
sampling_config=SamplingConfig(),
is_streaming=False,
draft_tokens=new_tokens[:num_draft_tokens],
)
assert len(new_tokens) == len(finish_reasons)
self.new_tokens = new_tokens
self.finish_reasons = finish_reasons

def __repr__(self):
return f"RequestCase({self.prompt=}, {self.new_tokens=}, {self.finish_reasons=}, \
{self.request.max_new_tokens=}, {self.request.end_id=}, {self.request.stop_words_list=})"

@staticmethod
def setup(requests: list["RequestCase"]):
max_tokens = set(len(req.new_tokens) for req in requests)
assert len(max_tokens) == 1
max_draft_len = max_tokens.pop() - 1
sampler_args = TorchSampler.Args(
max_seq_len=20,
max_draft_len=max_draft_len,
max_total_draft_tokens=max_draft_len,
# Fill with many more max requests than below,
# so we can test that write_finish_reasons uses seq_slots correctly
max_num_sequences=MAX_NUM_SEQUENCES,
max_beam_width=1,
disable_overlap_scheduler=False,
)
sampler = TorchSampler(args=sampler_args)

# fill with garbage value so we can observe that finish reasons are filled
# with NOT_FINISHED before we write to them.
sampler.store.finish_reasons.fill_(205)
seq_slots = torch.tensor(
[req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64
)
seq_lens = torch.tensor(
[req.request.max_beam_num_tokens for req in requests], dtype=torch.int32, device="cuda"
)
new_tokens = torch.tensor(
[req.new_tokens for req in requests], dtype=torch.int32, device="cuda"
).T
sampler.store.new_tokens[:, seq_slots, BEAM] = new_tokens
max_seq_lens = torch.tensor(
[
min(
sampler.max_seq_len, req.request.orig_prompt_len + req.request.py_max_new_tokens
@classmethod
def build(
cls,
requests: list["TestFinishReasons.RequestCase"],
*,
check_no_cuda_sync: bool = True,
extra_context: Callable[[], ContextManager] | None = None,
) -> UutProvider:
@contextmanager
def _uut_provider(is_warmup: bool) -> Generator[Callable[[], None], None, None]:
max_tokens = set(len(req.new_tokens) for req in requests)
assert len(max_tokens) == 1
max_draft_len = max_tokens.pop() - 1
sampler_args = TorchSampler.Args(
max_seq_len=20,
max_draft_len=max_draft_len,
max_total_draft_tokens=max_draft_len,
# Fill with many more max requests than below,
# so we can test that write_finish_reasons uses seq_slots correctly
max_num_sequences=cls.MAX_NUM_SEQUENCES,
max_beam_width=1,
disable_overlap_scheduler=False,
)
for req in requests
],
dtype=torch.int32,
device="cuda",
)
end_ids = torch.tensor(
sampler = TorchSampler(args=sampler_args)

# fill with garbage value so we can observe that finish reasons are filled
# with NOT_FINISHED before we write to them.
sampler.store.finish_reasons.fill_(205)
seq_slots = torch.tensor(
[req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64
)
seq_lens = torch.tensor(
[req.request.max_beam_num_tokens for req in requests],
dtype=torch.int32,
device="cuda",
)
new_tokens = torch.tensor(
[req.new_tokens for req in requests], dtype=torch.int32, device="cuda"
).T
sampler.store.new_tokens[:, seq_slots, cls.BEAM] = new_tokens
max_seq_lens = torch.tensor(
[
min(
sampler.max_seq_len,
req.request.orig_prompt_len + req.request.py_max_new_tokens,
)
for req in requests
],
dtype=torch.int32,
device="cuda",
)
end_ids = torch.tensor(
[
req.request.py_end_id if req.request.py_end_id is not None else -1
for req in requests
],
dtype=torch.int32,
device="cuda",
)
sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens
sampler.store.end_ids[seq_slots] = end_ids

def _uut():
with extra_context() if extra_context is not None else nullcontext():
sampler._write_finish_reasons(
[req.request for req in requests],
finish_reasons=sampler.store.finish_reasons,
new_tokens=sampler.store.new_tokens,
seq_lens=seq_lens,
seq_slots=seq_slots,
)

yield _uut

reasons = sampler.store.finish_reasons[:, seq_slots, cls.BEAM].T.tolist()

for actual, request in zip(reasons, requests, strict=True):
expected = request.finish_reasons
msg = f"actual={[FinishReason(reason) for reason in actual]} != expected={expected}\nFor {request}"
assert actual == [reason.value for reason in expected], msg

return _uut_provider

@classmethod
def test_write_finish_reasons(cls):
"""We don't really care about the finish reason past the first infraction, because we're not going to use it,
although in some instances it is written anyway."""
uut_provider = cls.RequestCase.build(
[
req.request.py_end_id if req.request.py_end_id is not None else -1
for req in requests
],
dtype=torch.int32,
device="cuda",
cls.RequestCase(
prompt=[13, 14],
new_tokens=[60, 61, 62],
# We pre-fill the finish reasons with NOT_FINISHED.
finish_reasons=[cls.NOT_FINISHED, cls.NOT_FINISHED, cls.NOT_FINISHED],
),
cls.RequestCase(
prompt=[7, 8, 6],
stop_words_list=[[12, 13]],
new_tokens=[12, 13, 60],
finish_reasons=[cls.NOT_FINISHED, cls.STOP_WORDS, cls.NOT_FINISHED],
),
cls.RequestCase(
prompt=[7, 8, 6],
stop_words_list=[[12, 13]],
new_tokens=[60, 12, 13],
# The request has stop words, but no draft is created
# Tokens at indices greater than 0 should be ignored
num_draft_tokens=0,
finish_reasons=[cls.NOT_FINISHED, cls.NOT_FINISHED, cls.NOT_FINISHED],
),
cls.RequestCase(
prompt=[1, 2, 3, 4],
end_id=99,
new_tokens=[55, 99, 58],
finish_reasons=[cls.NOT_FINISHED, cls.END_ID, cls.NOT_FINISHED],
),
cls.RequestCase(
prompt=[4, 5, 6],
max_new_tokens=2,
new_tokens=[56, 57, 59],
# The LENGTH check happens to not have an early exit
finish_reasons=[cls.NOT_FINISHED, cls.LENGTH, cls.LENGTH],
),
cls.RequestCase(
prompt=[1, 12],
stop_words_list=[[12, 13], [14, 15]],
new_tokens=[13, 14, 15],
# We don't use early exit to avoid stream synchronization for stop words
finish_reasons=[cls.STOP_WORDS, cls.NOT_FINISHED, cls.STOP_WORDS],
),
cls.RequestCase(
prompt=[1, 12],
stop_words_list=[
[12, 13, 14, 15],
[14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
],
new_tokens=[13, 14, 15],
# Stop words of different lengths are handled correctly with respect to padding of stop words
# and tokens
finish_reasons=[cls.NOT_FINISHED, cls.NOT_FINISHED, cls.STOP_WORDS],
),
cls.RequestCase(
prompt=[1],
max_new_tokens=2,
end_id=99,
stop_words_list=[[1, 12]],
new_tokens=[12, 99, 63],
# Different infractions are written to different places as
# we don't have an early exit between infractions
finish_reasons=[cls.STOP_WORDS, cls.END_ID, cls.LENGTH],
),
cls.RequestCase(
prompt=[1, 12, 56, 67, 68, 234, 678],
stop_words_list=[[12, 56, 67, 68, 234, 678, 129, 182]],
new_tokens=[129, 182, 600],
# Notice the offending stop sequence is concatenated, as we lookback
finish_reasons=[cls.NOT_FINISHED, cls.STOP_WORDS, cls.NOT_FINISHED],
),
cls.RequestCase(
prompt=[1, 12],
end_id=99,
max_new_tokens=1,
stop_words_list=[[1, 12, 99]],
new_tokens=[99, 100, 101],
# The latest infraction check overrides the earlier infraction checks,
# hence the first finish_reason is END_ID
finish_reasons=[cls.END_ID, cls.LENGTH, cls.LENGTH],
),
]
)
sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens
sampler.store.end_ids[seq_slots] = end_ids

def run():
sampler._write_finish_reasons(
[req.request for req in requests],
finish_reasons=sampler.store.finish_reasons,
new_tokens=sampler.store.new_tokens,
seq_lens=seq_lens,
seq_slots=seq_slots,
)
_run_test_with_warmup(uut_provider, max_sync_s=0.5)

reasons = sampler.store.finish_reasons[:, seq_slots, BEAM].T.tolist()
@classmethod
def test_are_stop_words_isnt_called_when_no_stop_words(cls, monkeypatch: pytest.MonkeyPatch):
"""We don't want to call are_stop_words when there are no stop words because it's expensive"""

for actual, request in zip(reasons, requests, strict=True):
expected = request.finish_reasons
msg = f"actual={[FinishReason(reason) for reason in actual]} != expected={expected}\nFor {request}"
assert actual == [reason.value for reason in expected], msg
def stop_words_that_raises(*args, **kwargs):
raise AssertionError

return run, sampler
@contextmanager
def raising_stop_words_ctx(expect_raise: bool) -> Generator[None, None, None]:
with monkeypatch.context() as patch_ctx:
patch_ctx.setattr(TorchSampler, "_are_stop_words", stop_words_that_raises)
with pytest.raises(AssertionError) if expect_raise else nullcontext():
yield

uut_provider_with_stop_words = cls.RequestCase.build(
[
cls.RequestCase(
prompt=[1],
stop_words_list=[[1]],
new_tokens=[4],
finish_reasons=[cls.NOT_FINISHED],
)
],
extra_context=lambda: raising_stop_words_ctx(True),
)
_run_test_with_warmup(uut_provider_with_stop_words, max_sync_s=0.5)

def test_write_finish_reasons():
"""We don't really care about the finish reason past the first infraction, because we're not going to use it,
although in some instances it is written anyway."""
run, _ = RequestCase.setup(
[
RequestCase(
prompt=[13, 14],
new_tokens=[60, 61, 62],
# We pre-fill the finish reasons with NOT_FINISHED.
finish_reasons=[NOT_FINISHED, NOT_FINISHED, NOT_FINISHED],
),
RequestCase(
prompt=[7, 8, 6],
stop_words_list=[[12, 13]],
new_tokens=[12, 13, 60],
finish_reasons=[NOT_FINISHED, STOP_WORDS, NOT_FINISHED],
),
RequestCase(
prompt=[7, 8, 6],
stop_words_list=[[12, 13]],
new_tokens=[60, 12, 13],
# The request has stop words, but no draft is created
# Tokens at indices greater than 0 should be ignored
num_draft_tokens=0,
finish_reasons=[NOT_FINISHED, NOT_FINISHED, NOT_FINISHED],
),
RequestCase(
prompt=[1, 2, 3, 4],
end_id=99,
new_tokens=[55, 99, 58],
finish_reasons=[NOT_FINISHED, END_ID, NOT_FINISHED],
),
RequestCase(
prompt=[4, 5, 6],
max_new_tokens=2,
new_tokens=[56, 57, 59],
# The LENGTH check happens to not have an early exit
finish_reasons=[NOT_FINISHED, LENGTH, LENGTH],
),
RequestCase(
prompt=[1, 12],
stop_words_list=[[12, 13], [14, 15]],
new_tokens=[13, 14, 15],
# We don't use early exit to avoid stream synchronization for stop words
finish_reasons=[STOP_WORDS, NOT_FINISHED, STOP_WORDS],
),
RequestCase(
prompt=[1, 12],
stop_words_list=[[12, 13, 14, 15], [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]],
new_tokens=[13, 14, 15],
# Stop words of different lengths are handled correctly with respect to padding of stop words and tokens
finish_reasons=[NOT_FINISHED, NOT_FINISHED, STOP_WORDS],
),
RequestCase(
prompt=[1],
max_new_tokens=2,
end_id=99,
stop_words_list=[[1, 12]],
new_tokens=[12, 99, 63],
# Different infractions are written to different places as
# we don't have an early exit between infractions
finish_reasons=[STOP_WORDS, END_ID, LENGTH],
),
RequestCase(
prompt=[1, 12, 56, 67, 68, 234, 678],
stop_words_list=[[12, 56, 67, 68, 234, 678, 129, 182]],
new_tokens=[129, 182, 600],
# Notice the offending stop sequence is concatenated, as we lookback
finish_reasons=[NOT_FINISHED, STOP_WORDS, NOT_FINISHED],
),
RequestCase(
prompt=[1, 12],
end_id=99,
max_new_tokens=1,
stop_words_list=[[1, 12, 99]],
new_tokens=[99, 100, 101],
# The latest infraction check overrides the earlier infraction checks,
# hence the first finish_reason is END_ID
finish_reasons=[END_ID, LENGTH, LENGTH],
),
]
)
run()

def test_are_stop_words_isnt_called_when_no_stop_words():
"""We don't want to call are_stop_words when there are no stop words because it's expensive"""

def stop_words_that_raises(*args, **kwargs):
raise AssertionError

run_with_stop_words, sampler = RequestCase.setup(
[
RequestCase(
prompt=[1], stop_words_list=[[1]], new_tokens=[4], finish_reasons=[NOT_FINISHED]
)
]
)
sampler._are_stop_words = stop_words_that_raises
with pytest.raises(AssertionError):
run_with_stop_words()

run_without_stop_words, sampler = RequestCase.setup(
[RequestCase(prompt=[1], new_tokens=[4], finish_reasons=[NOT_FINISHED])]
)
sampler._are_stop_words = stop_words_that_raises
_ = run_without_stop_words()
uut_provider_without_stop_words = cls.RequestCase.build(
[cls.RequestCase(prompt=[1], new_tokens=[4], finish_reasons=[cls.NOT_FINISHED])],
extra_context=lambda: raising_stop_words_ctx(False),
)
_run_test_with_warmup(uut_provider_without_stop_words, max_sync_s=0.5)


class TestBatchedSampling:
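The RequestCase expectations above encode how stop words (including lookback into the prompt), end_id, and max_new_tokens translate into per-step finish reasons, with a later check overwriting an earlier one. The reference model below is only a reading aid inferred from those expectations; it ignores the draft-token acceptance cases and is not the sampler's implementation. expected_reason is a hypothetical helper.

def expected_reason(prompt, new_tokens, step, *, stop_words=(), end_id=None, max_new_tokens=10):
    # Reference model of the per-step finish reasons encoded by the cases above.
    # The check order (stop words, then length, then end id) is inferred from the
    # expected values: a later match overwrites an earlier one, and stop-word
    # matching may look back into the prompt.
    seq = list(prompt) + list(new_tokens[: step + 1])
    reason = "NOT_FINISHED"
    for sw in stop_words:
        if len(seq) >= len(sw) and seq[-len(sw):] == list(sw):
            reason = "STOP_WORDS"
    if step + 1 >= max_new_tokens:
        reason = "LENGTH"
    if end_id is not None and new_tokens[step] == end_id:
        reason = "END_ID"
    return reason


# Example: the lookback case from the list above finishes with STOP_WORDS at step 1.
assert expected_reason(
    [1, 12, 56, 67, 68, 234, 678], [129, 182, 600], 1,
    stop_words=[[12, 56, 67, 68, 234, 678, 129, 182]],
) == "STOP_WORDS"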
@@ -1485,7 +1514,10 @@ class TestBatchedSampling:

 logit_offset += steps

-_run_test_with_warmup(_uut_provider)
+_run_test_with_warmup(
+    _uut_provider,
+    max_sync_s=None,  # NB: assert_no_cuda_sync called in TestBatchedSampler._sample
+)

 def _compute_probs(
     self,
@@ -2197,7 +2229,10 @@ class TestBatchedSampling:
     num_samples=num_samples,
 )

-_run_test_with_warmup(_uut_provider)
+_run_test_with_warmup(
+    _uut_provider,
+    max_sync_s=None,  # NB: assert_no_cuda_sync called in TestBatchedSampler._sample
+)

 @staticmethod
 def _build_seq_slot_assignments() -> list[tuple[list[int], int, str]]:
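The stop-words test in the finish-reasons hunk patches TorchSampler._are_stop_words with a function that raises, so the test fails if the expensive stop-word check runs for requests that have no stop words. A self-contained sketch of that monkeypatching idea, using an illustrative Sampler class rather than the real one:

from contextlib import nullcontext

import pytest


class Sampler:
    """Illustrative stand-in for the sampler; not TensorRT-LLM code."""

    def _are_stop_words(self, requests):
        return [bool(r.stop_words) for r in requests]

    def write_finish_reasons(self, requests):
        # Only pay for the stop-word check when some request actually has stop words.
        if any(r.stop_words for r in requests):
            self._are_stop_words(requests)


class Req:
    def __init__(self, stop_words=None):
        self.stop_words = stop_words


def test_stop_word_check_is_skipped(monkeypatch: pytest.MonkeyPatch):
    def raising_stop_words(*args, **kwargs):
        raise AssertionError("stop-word check should not have run")

    for stop_words, expect_raise in (([[1]], True), (None, False)):
        with monkeypatch.context() as patch_ctx:
            patch_ctx.setattr(Sampler, "_are_stop_words", raising_stop_words)
            with pytest.raises(AssertionError) if expect_raise else nullcontext():
                Sampler().write_finish_reasons([Req(stop_words)])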