diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py index 9f64b58771..3557b5c038 100644 --- a/tests/unittest/_torch/sampler/test_torch_sampler.py +++ b/tests/unittest/_torch/sampler/test_torch_sampler.py @@ -373,7 +373,8 @@ class UutProvider(Protocol): def _run_test_with_warmup( uut_provider: UutProvider, warmup_sizes_bytes: tuple[int] = (4 * 2**30,), - max_sync_s: Optional[float] = None, + *, + max_sync_s: Optional[float], ): """Run UUT including setup and warmup. @@ -635,231 +636,259 @@ def test_select_generated_logits(draft_len: int, with_ctx: bool, with_gen: bool) _run_test_with_warmup(_test_runner, max_sync_s=0.3) -MAX_NUM_SEQUENCES = 128 -NOT_FINISHED = FinishReason.NOT_FINISHED -STOP_WORDS = FinishReason.STOP_WORDS -END_ID = FinishReason.END_ID -LENGTH = FinishReason.LENGTH -BEAM = 0 +class TestFinishReasons: + NOT_FINISHED = FinishReason.NOT_FINISHED + STOP_WORDS = FinishReason.STOP_WORDS + END_ID = FinishReason.END_ID + LENGTH = FinishReason.LENGTH + class RequestCase: + MAX_NEW_TOKENS = 10 + MAX_NUM_SEQUENCES = 128 + seq_slots = torch.randperm(MAX_NUM_SEQUENCES).tolist() + BEAM = 0 -class RequestCase: - MAX_NEW_TOKENS = 10 - seq_slots = torch.randperm(MAX_NUM_SEQUENCES).tolist() + def __init__( + self, + *, + prompt: list[int], + new_tokens: list[int], + finish_reasons: list[FinishReason], + max_new_tokens: int = MAX_NEW_TOKENS, + end_id: Optional[int] = None, + num_draft_tokens: int | None = None, + stop_words_list: Optional[list[list[int]]] = None, + ): + seq_slot = self.seq_slots.pop() # random seq slot in MAX_NUM_SEQUENCES + self.prompt = prompt + if num_draft_tokens is None: + num_draft_tokens = len(new_tokens) - 1 + self.request = LlmRequest( + request_id=seq_slot, + seq_slot=seq_slot, + input_tokens=prompt, + max_new_tokens=max_new_tokens, + stop_words_list=convert_wordlist(stop_words_list) + if stop_words_list is not None + else None, + end_id=end_id, + sampling_config=SamplingConfig(), + is_streaming=False, + draft_tokens=new_tokens[:num_draft_tokens], + ) + assert len(new_tokens) == len(finish_reasons) + self.new_tokens = new_tokens + self.finish_reasons = finish_reasons - def __init__( - self, - *, - prompt: list[int], - new_tokens: list[int], - finish_reasons: list[FinishReason], - max_new_tokens: int = MAX_NEW_TOKENS, - end_id: Optional[int] = None, - num_draft_tokens: int | None = None, - stop_words_list: Optional[list[list[int]]] = None, - ): - seq_slot = self.seq_slots.pop() # random seq slot in MAX_NUM_SEQUENCES - self.prompt = prompt - if num_draft_tokens is None: - num_draft_tokens = len(new_tokens) - 1 - self.request = LlmRequest( - request_id=seq_slot, - seq_slot=seq_slot, - input_tokens=prompt, - max_new_tokens=max_new_tokens, - stop_words_list=convert_wordlist(stop_words_list) - if stop_words_list is not None - else None, - end_id=end_id, - sampling_config=SamplingConfig(), - is_streaming=False, - draft_tokens=new_tokens[:num_draft_tokens], - ) - assert len(new_tokens) == len(finish_reasons) - self.new_tokens = new_tokens - self.finish_reasons = finish_reasons + def __repr__(self): + return f"RequestCase({self.prompt=}, {self.new_tokens=}, {self.finish_reasons=}, \ + {self.request.max_new_tokens=}, {self.request.end_id=}, {self.request.stop_words_list=})" - def __repr__(self): - return f"RequestCase({self.prompt=}, {self.new_tokens=}, {self.finish_reasons=}, \ - {self.request.max_new_tokens=}, {self.request.end_id=}, {self.request.stop_words_list=})" - - @staticmethod - def setup(requests: 
list["RequestCase"]): - max_tokens = set(len(req.new_tokens) for req in requests) - assert len(max_tokens) == 1 - max_draft_len = max_tokens.pop() - 1 - sampler_args = TorchSampler.Args( - max_seq_len=20, - max_draft_len=max_draft_len, - max_total_draft_tokens=max_draft_len, - # Fill with many more max requests than below, - # so we can test that write_finish_reasons uses seq_slots correctly - max_num_sequences=MAX_NUM_SEQUENCES, - max_beam_width=1, - disable_overlap_scheduler=False, - ) - sampler = TorchSampler(args=sampler_args) - - # fill with garbage value so we can observe that finish reasons are filled - # with NOT_FINISHED before we write to them. - sampler.store.finish_reasons.fill_(205) - seq_slots = torch.tensor( - [req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64 - ) - seq_lens = torch.tensor( - [req.request.max_beam_num_tokens for req in requests], dtype=torch.int32, device="cuda" - ) - new_tokens = torch.tensor( - [req.new_tokens for req in requests], dtype=torch.int32, device="cuda" - ).T - sampler.store.new_tokens[:, seq_slots, BEAM] = new_tokens - max_seq_lens = torch.tensor( - [ - min( - sampler.max_seq_len, req.request.orig_prompt_len + req.request.py_max_new_tokens + @classmethod + def build( + cls, + requests: list["TestFinishReasons.RequestCase"], + *, + check_no_cuda_sync: bool = True, + extra_context: Callable[[], ContextManager] | None = None, + ) -> UutProvider: + @contextmanager + def _uut_provider(is_warmup: bool) -> Generator[Callable[[], None], None, None]: + max_tokens = set(len(req.new_tokens) for req in requests) + assert len(max_tokens) == 1 + max_draft_len = max_tokens.pop() - 1 + sampler_args = TorchSampler.Args( + max_seq_len=20, + max_draft_len=max_draft_len, + max_total_draft_tokens=max_draft_len, + # Fill with many more max requests than below, + # so we can test that write_finish_reasons uses seq_slots correctly + max_num_sequences=cls.MAX_NUM_SEQUENCES, + max_beam_width=1, + disable_overlap_scheduler=False, ) - for req in requests - ], - dtype=torch.int32, - device="cuda", - ) - end_ids = torch.tensor( + sampler = TorchSampler(args=sampler_args) + + # fill with garbage value so we can observe that finish reasons are filled + # with NOT_FINISHED before we write to them. 
+ sampler.store.finish_reasons.fill_(205) + seq_slots = torch.tensor( + [req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64 + ) + seq_lens = torch.tensor( + [req.request.max_beam_num_tokens for req in requests], + dtype=torch.int32, + device="cuda", + ) + new_tokens = torch.tensor( + [req.new_tokens for req in requests], dtype=torch.int32, device="cuda" + ).T + sampler.store.new_tokens[:, seq_slots, cls.BEAM] = new_tokens + max_seq_lens = torch.tensor( + [ + min( + sampler.max_seq_len, + req.request.orig_prompt_len + req.request.py_max_new_tokens, + ) + for req in requests + ], + dtype=torch.int32, + device="cuda", + ) + end_ids = torch.tensor( + [ + req.request.py_end_id if req.request.py_end_id is not None else -1 + for req in requests + ], + dtype=torch.int32, + device="cuda", + ) + sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens + sampler.store.end_ids[seq_slots] = end_ids + + def _uut(): + with extra_context() if extra_context is not None else nullcontext(): + sampler._write_finish_reasons( + [req.request for req in requests], + finish_reasons=sampler.store.finish_reasons, + new_tokens=sampler.store.new_tokens, + seq_lens=seq_lens, + seq_slots=seq_slots, + ) + + yield _uut + + reasons = sampler.store.finish_reasons[:, seq_slots, cls.BEAM].T.tolist() + + for actual, request in zip(reasons, requests, strict=True): + expected = request.finish_reasons + msg = f"actual={[FinishReason(reason) for reason in actual]} != expected={expected}\nFor {request}" + assert actual == [reason.value for reason in expected], msg + + return _uut_provider + + @classmethod + def test_write_finish_reasons(cls): + """We don't really care about the finish reason past the first infraction, because we're not going to use it, + although in some instance it is written anyway.""" + uut_provider = cls.RequestCase.build( [ - req.request.py_end_id if req.request.py_end_id is not None else -1 - for req in requests - ], - dtype=torch.int32, - device="cuda", + cls.RequestCase( + prompt=[13, 14], + new_tokens=[60, 61, 62], + # We pre-fill the finish reasons with NOT_FINISHED. 
+ finish_reasons=[cls.NOT_FINISHED, cls.NOT_FINISHED, cls.NOT_FINISHED], + ), + cls.RequestCase( + prompt=[7, 8, 6], + stop_words_list=[[12, 13]], + new_tokens=[12, 13, 60], + finish_reasons=[cls.NOT_FINISHED, cls.STOP_WORDS, cls.NOT_FINISHED], + ), + cls.RequestCase( + prompt=[7, 8, 6], + stop_words_list=[[12, 13]], + new_tokens=[60, 12, 13], + # The request has stop words, but no draft is created + # Tokens at indices greater than 0 should be ignored + num_draft_tokens=0, + finish_reasons=[cls.NOT_FINISHED, cls.NOT_FINISHED, cls.NOT_FINISHED], + ), + cls.RequestCase( + prompt=[1, 2, 3, 4], + end_id=99, + new_tokens=[55, 99, 58], + finish_reasons=[cls.NOT_FINISHED, cls.END_ID, cls.NOT_FINISHED], + ), + cls.RequestCase( + prompt=[4, 5, 6], + max_new_tokens=2, + new_tokens=[56, 57, 59], + # The LENGTH check happens to not have an early exit + finish_reasons=[cls.NOT_FINISHED, cls.LENGTH, cls.LENGTH], + ), + cls.RequestCase( + prompt=[1, 12], + stop_words_list=[[12, 13], [14, 15]], + new_tokens=[13, 14, 15], + # We don't use early exit to avoid stream synchronization for stop words + finish_reasons=[cls.STOP_WORDS, cls.NOT_FINISHED, cls.STOP_WORDS], + ), + cls.RequestCase( + prompt=[1, 12], + stop_words_list=[ + [12, 13, 14, 15], + [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + ], + new_tokens=[13, 14, 15], + # Stop words of different lengths are handled correctly with respect to padding of stop words + # and tokens + finish_reasons=[cls.NOT_FINISHED, cls.NOT_FINISHED, cls.STOP_WORDS], + ), + cls.RequestCase( + prompt=[1], + max_new_tokens=2, + end_id=99, + stop_words_list=[[1, 12]], + new_tokens=[12, 99, 63], + # Different infractions are written to different places as + # we don't have an early exit between infractions + finish_reasons=[cls.STOP_WORDS, cls.END_ID, cls.LENGTH], + ), + cls.RequestCase( + prompt=[1, 12, 56, 67, 68, 234, 678], + stop_words_list=[[12, 56, 67, 68, 234, 678, 129, 182]], + new_tokens=[129, 182, 600], + # Notice the offending stop sequence is concatenated, as we lookback + finish_reasons=[cls.NOT_FINISHED, cls.STOP_WORDS, cls.NOT_FINISHED], + ), + cls.RequestCase( + prompt=[1, 12], + end_id=99, + max_new_tokens=1, + stop_words_list=[[1, 12, 99]], + new_tokens=[99, 100, 101], + # The latest infraction check overrides the earlier infraction checks, + # hence the first finish_reason is END_ID + finish_reasons=[cls.END_ID, cls.LENGTH, cls.LENGTH], + ), + ] ) - sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens - sampler.store.end_ids[seq_slots] = end_ids - def run(): - sampler._write_finish_reasons( - [req.request for req in requests], - finish_reasons=sampler.store.finish_reasons, - new_tokens=sampler.store.new_tokens, - seq_lens=seq_lens, - seq_slots=seq_slots, - ) + _run_test_with_warmup(uut_provider, max_sync_s=0.5) - reasons = sampler.store.finish_reasons[:, seq_slots, BEAM].T.tolist() + @classmethod + def test_are_stop_words_isnt_called_when_no_stop_words(cls, monkeypatch: pytest.MonkeyPatch): + """We don't want to call are_stop_words when there are no stop words because it's expensive""" - for actual, request in zip(reasons, requests, strict=True): - expected = request.finish_reasons - msg = f"actual={[FinishReason(reason) for reason in actual]} != expected={expected}\nFor {request}" - assert actual == [reason.value for reason in expected], msg + def stop_words_that_raises(*args, **kwargs): + raise AssertionError - return run, sampler + @contextmanager + def raising_stop_words_ctx(expect_raise: bool) -> Generator[None, None, None]: + with 
monkeypatch.context() as patch_ctx: + patch_ctx.setattr(TorchSampler, "_are_stop_words", stop_words_that_raises) + with pytest.raises(AssertionError) if expect_raise else nullcontext(): + yield + uut_provider_with_stop_words = cls.RequestCase.build( + [ + cls.RequestCase( + prompt=[1], + stop_words_list=[[1]], + new_tokens=[4], + finish_reasons=[cls.NOT_FINISHED], + ) + ], + extra_context=lambda: raising_stop_words_ctx(True), + ) + _run_test_with_warmup(uut_provider_with_stop_words, max_sync_s=0.5) -def test_write_finish_reasons(): - """We don't really care about the finish reason past the first infraction, because we're not going to use it, - although in some instance it is written anyway.""" - run, _ = RequestCase.setup( - [ - RequestCase( - prompt=[13, 14], - new_tokens=[60, 61, 62], - # We pre-fill the finish reasons with NOT_FINISHED. - finish_reasons=[NOT_FINISHED, NOT_FINISHED, NOT_FINISHED], - ), - RequestCase( - prompt=[7, 8, 6], - stop_words_list=[[12, 13]], - new_tokens=[12, 13, 60], - finish_reasons=[NOT_FINISHED, STOP_WORDS, NOT_FINISHED], - ), - RequestCase( - prompt=[7, 8, 6], - stop_words_list=[[12, 13]], - new_tokens=[60, 12, 13], - # The request has stop words, but no draft is created - # Tokens at indices greater than 0 should be ignored - num_draft_tokens=0, - finish_reasons=[NOT_FINISHED, NOT_FINISHED, NOT_FINISHED], - ), - RequestCase( - prompt=[1, 2, 3, 4], - end_id=99, - new_tokens=[55, 99, 58], - finish_reasons=[NOT_FINISHED, END_ID, NOT_FINISHED], - ), - RequestCase( - prompt=[4, 5, 6], - max_new_tokens=2, - new_tokens=[56, 57, 59], - # The LENGTH check happens to not have an early exit - finish_reasons=[NOT_FINISHED, LENGTH, LENGTH], - ), - RequestCase( - prompt=[1, 12], - stop_words_list=[[12, 13], [14, 15]], - new_tokens=[13, 14, 15], - # We don't use early exit to avoid stream synchronization for stop words - finish_reasons=[STOP_WORDS, NOT_FINISHED, STOP_WORDS], - ), - RequestCase( - prompt=[1, 12], - stop_words_list=[[12, 13, 14, 15], [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]], - new_tokens=[13, 14, 15], - # Stop words of different lengths are handled correctly with respect to padding of stop words and tokens - finish_reasons=[NOT_FINISHED, NOT_FINISHED, STOP_WORDS], - ), - RequestCase( - prompt=[1], - max_new_tokens=2, - end_id=99, - stop_words_list=[[1, 12]], - new_tokens=[12, 99, 63], - # Different infractions are written to different places as - # we don't have an early exit between infractions - finish_reasons=[STOP_WORDS, END_ID, LENGTH], - ), - RequestCase( - prompt=[1, 12, 56, 67, 68, 234, 678], - stop_words_list=[[12, 56, 67, 68, 234, 678, 129, 182]], - new_tokens=[129, 182, 600], - # Notice the offending stop sequence is concatenated, as we lookback - finish_reasons=[NOT_FINISHED, STOP_WORDS, NOT_FINISHED], - ), - RequestCase( - prompt=[1, 12], - end_id=99, - max_new_tokens=1, - stop_words_list=[[1, 12, 99]], - new_tokens=[99, 100, 101], - # The latest infraction check overrides the earlier infraction checks, - # hence the first finish_reason is END_ID - finish_reasons=[END_ID, LENGTH, LENGTH], - ), - ] - ) - run() - - -def test_are_stop_words_isnt_called_when_no_stop_words(): - """We don't want to call are_stop_words when there are no stop words because it's expensive""" - - def stop_words_that_raises(*args, **kwargs): - raise AssertionError - - run_with_stop_words, sampler = RequestCase.setup( - [ - RequestCase( - prompt=[1], stop_words_list=[[1]], new_tokens=[4], finish_reasons=[NOT_FINISHED] - ) - ] - ) - sampler._are_stop_words = 
stop_words_that_raises - with pytest.raises(AssertionError): - run_with_stop_words() - - run_without_stop_words, sampler = RequestCase.setup( - [RequestCase(prompt=[1], new_tokens=[4], finish_reasons=[NOT_FINISHED])] - ) - sampler._are_stop_words = stop_words_that_raises - _ = run_without_stop_words() + uut_provider_without_stop_words = cls.RequestCase.build( + [cls.RequestCase(prompt=[1], new_tokens=[4], finish_reasons=[cls.NOT_FINISHED])], + extra_context=lambda: raising_stop_words_ctx(False), + ) + _run_test_with_warmup(uut_provider_without_stop_words, max_sync_s=0.5) class TestBatchedSampling: @@ -1485,7 +1514,10 @@ class TestBatchedSampling: logit_offset += steps - _run_test_with_warmup(_uut_provider) + _run_test_with_warmup( + _uut_provider, + max_sync_s=None, # NB: assert_no_cuda_sync called in TestBatchedSampler._sample + ) def _compute_probs( self, @@ -2197,7 +2229,10 @@ class TestBatchedSampling: num_samples=num_samples, ) - _run_test_with_warmup(_uut_provider) + _run_test_with_warmup( + _uut_provider, + max_sync_s=None, # NB: assert_no_cuda_sync called in TestBatchedSampler._sample + ) @staticmethod def _build_seq_slot_assignments() -> list[tuple[list[int], int, str]]:
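
For context on the pattern this refactor standardizes on: _run_test_with_warmup now takes max_sync_s as a required keyword argument and consumes a UutProvider, i.e. a callable that, given is_warmup, returns a context manager whose entry performs the setup, yields the zero-argument unit under test, and runs its own assertions on exit; RequestCase.build above follows exactly that shape. The sketch below is illustrative only and not part of the patch: run_with_warmup and make_provider are hypothetical, dependency-free stand-ins for _run_test_with_warmup and RequestCase.build, and the real harness additionally performs warmup allocations and, judging by the call sites above, appears to use max_sync_s to bound CUDA synchronization during the measured run.

from contextlib import contextmanager
from typing import Callable, ContextManager, Generator, Optional, Protocol


class UutProvider(Protocol):
    # Structural stand-in for the UutProvider protocol used by the tests.
    def __call__(self, is_warmup: bool) -> ContextManager[Callable[[], None]]: ...


def run_with_warmup(uut_provider: UutProvider, *, max_sync_s: Optional[float]) -> None:
    # Hypothetical harness: one warmup pass, then the measured pass. Entering the
    # provider does the setup; exiting it runs the provider's own assertions.
    for is_warmup in (True, False):
        with uut_provider(is_warmup) as uut:
            uut()  # the real harness would time this and apply max_sync_s here


def make_provider(data: list[int]) -> UutProvider:
    # Hypothetical analogue of RequestCase.build: bake the test case into a
    # contextmanager-based provider that sets up state, yields the UUT, and
    # verifies the result after the harness has run it.
    @contextmanager
    def _uut_provider(is_warmup: bool) -> Generator[Callable[[], None], None, None]:
        results: list[int] = []  # setup phase

        def _uut() -> None:
            results.extend(x * 2 for x in data)

        yield _uut  # the harness invokes _uut between entry and exit

        # verification phase, reached only after the harness ran _uut
        assert results == [x * 2 for x in data]

    return _uut_provider


run_with_warmup(make_provider([1, 2, 3]), max_sync_s=None)

Keeping the assertions after the yield lets the provider own both setup and verification, while the harness alone decides how (and how many times) the yielded callable is executed.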