diff --git a/tests/unittest/_torch/sampler/test_beam_search.py b/tests/unittest/_torch/sampler/test_beam_search.py index 169fb0a6de..b987dcfb0b 100644 --- a/tests/unittest/_torch/sampler/test_beam_search.py +++ b/tests/unittest/_torch/sampler/test_beam_search.py @@ -12,21 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc +import os import pathlib as _pl -from typing import Any +from contextlib import contextmanager, nullcontext +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Callable, Generator, cast import pytest import torch from test_beam_search_util import (BeamSearchTestOutput, DummyConfigLoader, DummyWeightLoader, get_expected_outputs) from utils.llm_data import llm_models_root -from utils.util import assert_no_cuda_sync, force_ampere +from utils.util import assert_no_cuda_sync, force_ampere, run_test_with_warmup -from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm import LLM, SamplingParams, TorchLlmArgs from tensorrt_llm._torch.models.checkpoints import HfCheckpointLoader from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest, SamplingConfig) -from tensorrt_llm._torch.pyexecutor.sampler import BeamHistory, TorchSampler +from tensorrt_llm._torch.pyexecutor.sampler import (BeamHistory, + SampleStateTorch, + TorchSampler) from tensorrt_llm._torch.pyexecutor.sampling_utils import ( BEAM_SEARCH_PAD_TOKEN, BeamSearchMetadata, FinishReason, beam_search_sampling_batch) @@ -68,43 +75,63 @@ def model_kwargs(fixed_params, sampling_information) -> dict[str, Any]: ) -def _build_llm(fixed_params, input_prompts, model_kwargs): +@pytest.fixture(scope="module", params=[False, True]) +def with_cuda_graph_and_overlap(request): + return request.param + + +def _build_llm(fixed_params, input_prompts, llm_kwargs): return LLM( - **model_kwargs, - kv_cache_config=KvCacheConfig(max_tokens=10000), + **llm_kwargs, + kv_cache_config=KvCacheConfig( + max_tokens=10000, # type: ignore + ), max_batch_size=fixed_params["max_beam_width"] * len( input_prompts ), # use small batch size to prevent large buffers from possibly hiding wrong data accesses. max_seq_len=32, max_beam_width=fixed_params["max_beam_width"], - disable_overlap_scheduler=True, - cuda_graph_config=None, ) -@pytest.fixture(scope="module") -def llm(fixed_params, input_prompts, model_kwargs): - llm = _build_llm(fixed_params, input_prompts, model_kwargs) - yield llm - llm.shutdown() +@contextmanager +def _single_process_context(): + os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] = "1" + try: + yield + finally: + del os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] +# NB: It is important that all tests instantiating 'LLM' with +# TLLM_WORKER_USE_SINGLE_PROCESS=1 (i.e., single_process=True below) +# use this fixture. Otherwise, more than one such 'LLM' object +# could be alive at any given point in time and this has been +# found to result in corruption of the cache_indirection tensors. @pytest.fixture(scope="module") -def llm_cuda_graph(fixed_params, input_prompts, model_kwargs): - llm = LLM( - **model_kwargs, - kv_cache_config=KvCacheConfig(max_tokens=10000), - max_batch_size=fixed_params["max_beam_width"] * len( - input_prompts - ), # use small batch size to prevent large buffers from possibly hiding wrong data accesses. 
- max_seq_len=32, - max_beam_width=fixed_params["max_beam_width"], - disable_overlap_scheduler=False, - cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8], - enable_padding=True), - ) - yield llm - llm.shutdown() +def llm(fixed_params, input_prompts, model_kwargs, single_process: bool, + with_cuda_graph_and_overlap: bool): + gc.collect( + 2) # force destruction of any other LLM instances (cf. comment above) + with _single_process_context() if single_process else nullcontext(): + llm = _build_llm( + fixed_params, + input_prompts, + llm_kwargs=( + (dict( + disable_overlap_scheduler=True, + cuda_graph_config=None, + ) if not with_cuda_graph_and_overlap else dict( + disable_overlap_scheduler=False, + cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8], + enable_padding=True), + )) + | + deepcopy( # LLM.shutdown resets checkpoint_loader.config_loader + model_kwargs)), + ) + with llm: + yield llm def check_generation_logits(beam: CompletionOutput, @@ -125,6 +152,7 @@ def check_generation_logits(beam: CompletionOutput, def check_logprobs(beam: CompletionOutput, sampling_params: SamplingParams, valid_tokens: int | None) -> None: """Check if the logprobs have the correct shape""" + assert beam.logprobs is not None if sampling_params.logprobs is not None: generated_tokens = valid_tokens if valid_tokens is not None else sampling_params.max_tokens assert len( @@ -132,6 +160,7 @@ def check_logprobs(beam: CompletionOutput, sampling_params: SamplingParams, ) == generated_tokens, f"expected {generated_tokens} logprobs, but got {len(beam.logprobs)}" log_sum = 0.0 for logprob_dict in (beam.logprobs): + assert isinstance(logprob_dict, dict) for logprob_value in logprob_dict.values(): log_sum += logprob_value.logprob assert log_sum == beam.cumulative_logprob, f"expected {beam.cumulative_logprob} logprob, but got {log_sum}" @@ -145,6 +174,7 @@ def check_cache_indirection(beam: CompletionOutput, prompt_length: int, beam_idx: int, valid_tokens: int | None) -> None: """Check if the cache indirection seen by the model is the same as the expected cache indirection""" + assert beam.additional_generation_outputs is not None cache_indirection = beam.additional_generation_outputs["cache_indirection"] assert cache_indirection is not None, "cache indirection should not be None" assert cache_indirection.shape[ @@ -207,6 +237,11 @@ def check_context_logits(output: GenerationResult, assert output.context_logits is None, "context logits should be None" +@pytest.fixture(scope="module", params=[False, True]) +def single_process(request) -> bool: + return request.param + + def validate_output(output: GenerationResult, input_prompt: list[int], sampling_params: SamplingParams) -> None: """Perform several checks on the output of a single prompt""" @@ -226,11 +261,63 @@ def validate_output(output: GenerationResult, input_prompt: list[int], def validate_outputs(llm: LLM, input_prompts: list[list[int]], - sampling_params: SamplingParams) -> None: + sampling_params: SamplingParams, + monkeypatch: pytest.MonkeyPatch, + check_no_sync: bool) -> None: """Generate outputs for a list of prompts and validate the outputs""" - outputs = llm.generate(input_prompts, sampling_params=sampling_params) - num_prompts = len(input_prompts) + outputs = llm.generate(deepcopy(input_prompts), + sampling_params=deepcopy(sampling_params)) + + if check_no_sync: + del outputs # treat previous .generate as warmup, ignore results + + with monkeypatch.context() as patcher: + sample_async_orig = TorchSampler.sample_async + update_requests_orig = 
TorchSampler.update_requests + + _sample_async_hook_called = False + _update_requests_hook_called = False + + def _sample_async_hook(*args, **kwargs): + nonlocal _sample_async_hook_called + _sample_async_hook_called = True + + with assert_no_cuda_sync(): + return sample_async_orig(*args, **kwargs) + + def _update_requests_hook(self, state: SampleStateTorch, *args, + **kwargs): + nonlocal _update_requests_hook_called + _update_requests_hook_called = True + + # await sampling event outside sync-check (because this syncs) + sampler_event = state.sampler_event + if sampler_event: + sampler_event.synchronize() + + with assert_no_cuda_sync(): + state.sampler_event = None + try: + return update_requests_orig(self, state, *args, + **kwargs) + finally: + state.sampler_event = sampler_event + + # Intercept sampler methods to check that they do not sync (requires + # TLLM_WORKER_USE_SINGLE_PROCESS). + patcher.setattr(TorchSampler, "sample_async", _sample_async_hook) + patcher.setattr(TorchSampler, "update_requests", + _update_requests_hook) + + outputs = llm.generate(deepcopy(input_prompts), + sampling_params=deepcopy(sampling_params)) + + assert _sample_async_hook_called + assert _update_requests_hook_called + + num_prompts = len(input_prompts) + assert isinstance(outputs, list) assert len( outputs ) == num_prompts, f"expected {num_prompts} outputs, but got {len(outputs)}" @@ -247,73 +334,45 @@ def validate_outputs(llm: LLM, input_prompts: list[list[int]], @pytest.mark.parametrize("gather_generation_logits", [True, False]) @pytest.mark.parametrize("gather_context_logits", [True, False]) @pytest.mark.parametrize("num_output_beams", [1, 2]) -@pytest.mark.parametrize("num_prompts", [1, 2]) +@pytest.mark.parametrize("num_prompts", [1, 2, 3]) +@pytest.mark.parametrize("stop_token_ids", [[15], None]) @pytest.mark.threadleak(enabled=False) def test_beam_search_e2e( - gather_context_logits: bool, - gather_generation_logits: bool, - return_log_probs: bool, - num_output_beams: int, - num_prompts: int, - llm, - fixed_params, - input_prompts, -) -> None: - if return_log_probs and num_prompts > 1 and llm.args.sampler_type == "TRTLLMSampler": - pytest.skip( - "Beam search currently does not support return_log_probs with multiple prompts" - ) - if return_log_probs and llm.args.sampler_type == "TRTLLMSampler": - pytest.skip( - "Beam search on TRTLLMSampler does not correctly handle log_probs if called multiple times" - ) - - # create sampling parameters - # additional_model_outputs is used to gather the cache indirection from the model. 
- sampling_params = SamplingParams( - max_tokens=fixed_params["max_tokens"], - n=num_output_beams, - best_of=fixed_params["max_beam_width"], - use_beam_search=True, - return_context_logits=gather_context_logits, - return_generation_logits=gather_generation_logits, - logprobs=return_log_probs, - end_id=-1, - additional_model_outputs=["cache_indirection"], - ) - validate_outputs(llm, input_prompts[:num_prompts], sampling_params) - - -@pytest.mark.parametrize("return_log_probs", [True, False]) -@pytest.mark.parametrize("gather_generation_logits", [True, False]) -@pytest.mark.parametrize("gather_context_logits", [True, False]) -@pytest.mark.parametrize("num_output_beams", [1, 2]) -@pytest.mark.parametrize("num_prompts", [1, 2, 3]) -@pytest.mark.parametrize("stop_token_ids", [[15], None]) -@pytest.mark.threadleak(enabled=False) -def test_beam_search_e2e_cuda_graph_and_overlap( gather_context_logits: bool, gather_generation_logits: bool, return_log_probs: bool, num_output_beams: int, num_prompts: int, stop_token_ids: list[int] | None, - llm_cuda_graph, + llm: LLM, fixed_params, input_prompts, + single_process: bool, + monkeypatch: pytest.MonkeyPatch, ) -> None: - if return_log_probs and num_prompts > 1 and llm_cuda_graph.args.sampler_type == "TRTLLMSampler": + llm_args = cast(TorchLlmArgs, llm.args) + check_no_sync = single_process # single_process only used for sync check + if check_no_sync and cast(TorchLlmArgs, + llm.args).sampler_type != "TorchSampler": + pytest.skip("Sync check only supported for TorchSampler") + if check_no_sync and (cast(TorchLlmArgs, llm.args).sampler_type + == "TorchSampler") and stop_token_ids is not None: + # FIXME: Fix TorchSampler._are_stop_words + pytest.skip("Stop word handling in TorchSampler syncs") + + if return_log_probs and num_prompts > 1 and llm_args.sampler_type == "TRTLLMSampler": pytest.skip( "Beam search currently does not support return_log_probs with multiple prompts" ) - if return_log_probs and llm_cuda_graph.args.sampler_type == "TRTLLMSampler": + if return_log_probs and llm_args.sampler_type == "TRTLLMSampler": pytest.skip( "Beam search on TRTLLMSampler does not correctly handle log_probs if called multiple times" ) - if stop_token_ids is not None and llm_cuda_graph.args.sampler_type == "TRTLLMSampler": + if stop_token_ids is not None and llm_args.sampler_type == "TRTLLMSampler": pytest.skip( "Beam search on TRTLLMSampler does not correctly handle stop_token_ids" ) + # create sampling parameters # additional_model_outputs is used to gather the cache indirection from the model. 
sampling_params = SamplingParams( @@ -328,8 +387,13 @@ def test_beam_search_e2e_cuda_graph_and_overlap( stop_token_ids=stop_token_ids, additional_model_outputs=["cache_indirection"], ) - validate_outputs(llm_cuda_graph, input_prompts[:num_prompts], - sampling_params) + validate_outputs( + llm, + input_prompts[:num_prompts], + sampling_params, + check_no_sync=check_no_sync, + monkeypatch=monkeypatch, + ) ########################################################################### @@ -440,6 +504,7 @@ def test_beam_search_sampling_batch_basic(): ) # Validate output shapes + assert softmax is not None expected_tokens_shape = (batch_size, beam_width) assert next_tokens.shape == expected_tokens_shape, ( f"Expected shape {expected_tokens_shape}, got {next_tokens.shape}") @@ -564,9 +629,11 @@ def create_default_sampler(test_params: GeneralTestParams) -> TorchSampler: assert max_seq_len > test_params.seq_len, "Max sequence length must be greater than sequence length" assert max_batch_size > test_params.batch_size, "Max batch size must be greater than batch size" assert max_batch_size > test_params.seq_slot, "Max batch size must be greater than sequence slot" + assert sampler.store.cache_indirection is not None assert sampler.store.cache_indirection.shape == ( max_batch_size, max_beam_width, max_seq_len), "Cache indirection shape mismatch" + assert sampler.store.original_tokens is not None assert sampler.store.original_tokens.shape == ( max_batch_size, max_beam_width, max_seq_len), "Original tokens shape mismatch" @@ -579,121 +646,154 @@ def test_create_beam_history(): This test verifies that beam history is correctly reconstructed by following the cache_indirection backwards to obtain the correct token sequence. """ - test_params = GeneralTestParams() - request = create_default_request(test_params) - sampler = create_default_sampler(test_params) - # Extract parameters from the test parameters - beam_width = test_params.beam_width - prompt_len = test_params.prompt_len - num_generated_tokens = test_params.num_generated_tokens - seq_slot = test_params.seq_slot - vocab_size = test_params.vocab_size - num_logprobs = test_params.num_logprobs + 1 - cache_indirection = sampler.store.cache_indirection - original_tokens = sampler.store.original_tokens - original_logprobs = torch.zeros( - (beam_width, num_generated_tokens, num_logprobs), - dtype=torch.float32, - device=original_tokens.device) - original_logprob_indices = torch.zeros( - (beam_width, num_generated_tokens, num_logprobs), - dtype=torch.int32, - device=original_tokens.device) - original_cum_logprobs = sampler.store.cum_log_probs + @contextmanager + def _uut_provider( + is_warmup: bool) -> Generator[Callable[[], None], None, None]: + test_params = GeneralTestParams() + request = create_default_request(test_params) + sampler = create_default_sampler(test_params) - # Fill the request with some random tokens that will be overwritten by the beam search sampling - # Beam history is created before add_token is called - request.set_generated_tokens( - torch.randint(0, - vocab_size, (beam_width, num_generated_tokens - 1), - dtype=torch.int32).tolist()) - # random fill - torch.manual_seed(42) - original_tokens[seq_slot, :beam_width, prompt_len:prompt_len + - num_generated_tokens] = torch.randint( - 0, - beam_width, (beam_width, num_generated_tokens), - dtype=torch.int32) - assert original_tokens.sum( - ) > 0, "Original tokens must not only contain zeros. Otherwise change the seed." 
+ # Extract parameters from the test parameters + beam_width = test_params.beam_width + prompt_len = test_params.prompt_len + num_generated_tokens = test_params.num_generated_tokens + seq_slot = test_params.seq_slot + vocab_size = test_params.vocab_size + num_logprobs = test_params.num_logprobs + 1 + cache_indirection = sampler.store.cache_indirection + assert cache_indirection is not None + original_tokens = sampler.store.original_tokens + assert original_tokens is not None + original_logprobs = torch.zeros( + (beam_width, num_generated_tokens, num_logprobs), + dtype=torch.float32, + device=original_tokens.device) + original_logprob_indices = torch.zeros( + (beam_width, num_generated_tokens, num_logprobs), + dtype=torch.int32, + device=original_tokens.device) + original_cum_logprobs = sampler.store.cum_log_probs + assert original_cum_logprobs is not None - original_logprobs[:beam_width] = torch.randn( - (beam_width, num_generated_tokens, original_logprobs.shape[-1]), - dtype=torch.float32) - original_logprob_indices[:beam_width] = torch.randint( - 0, - vocab_size, - (beam_width, num_generated_tokens, original_logprobs.shape[-1]), - dtype=torch.int32) - assert (original_logprobs != 0).sum( - ) > 0, "Original log probs must not only contain zeros. Otherwise change the seed." - assert (original_logprob_indices).sum( - ) > 0, "Original log prob indices must not only contain zeros. Otherwise change the seed." + # Fill the request with some random tokens that will be overwritten by the beam search sampling + # Beam history is created before add_token is called + request.set_generated_tokens( + torch.randint(0, + vocab_size, (beam_width, num_generated_tokens - 1), + dtype=torch.int32).tolist()) + # random fill + torch.manual_seed(42) + original_tokens[seq_slot, :beam_width, prompt_len:prompt_len + + num_generated_tokens] = torch.randint( + 0, + beam_width, (beam_width, num_generated_tokens), + dtype=torch.int32) + assert original_tokens.sum( + ) > 0, "Original tokens must not only contain zeros. Otherwise change the seed." - # set the logprobs in the request: - token_logprobs = sampler._convert_logprobs_tensor_to_list( - original_logprob_indices[:beam_width, :num_generated_tokens - 1], - original_logprobs[:beam_width, :num_generated_tokens - 1], - ) - request.py_result.set_log_probs( - token_logprobs, - cum_log_probs=torch.zeros_like( - original_cum_logprobs[seq_slot, :beam_width]).tolist()) + original_logprobs[:beam_width] = torch.randn( + (beam_width, num_generated_tokens, original_logprobs.shape[-1]), + dtype=torch.float32) + original_logprob_indices[:beam_width] = torch.randint( + 0, + vocab_size, + (beam_width, num_generated_tokens, original_logprobs.shape[-1]), + dtype=torch.int32) + assert (original_logprobs != 0).sum( + ) > 0, "Original log probs must not only contain zeros. Otherwise change the seed." + assert (original_logprob_indices).sum( + ) > 0, "Original log prob indices must not only contain zeros. Otherwise change the seed." - original_cum_logprobs[seq_slot, :beam_width] = torch.randn( - (beam_width, ), dtype=torch.float32) - assert (original_cum_logprobs != 0).sum( - ) > 0, "Original cumulative log probs must not only contain zeros. Otherwise change the seed." 
+ # set the logprobs in the request: + token_logprobs = sampler._convert_logprobs_tensor_to_list( + original_logprob_indices[:beam_width, :num_generated_tokens - 1], + original_logprobs[:beam_width, :num_generated_tokens - 1], + ) + request.py_result.set_log_probs( + token_logprobs, + cum_log_probs=torch.zeros_like( + original_cum_logprobs[seq_slot, :beam_width]).tolist()) - cache_indirection[seq_slot, :beam_width, prompt_len:prompt_len + - num_generated_tokens] = torch.randint( - 0, - beam_width, (beam_width, num_generated_tokens), - dtype=torch.int32) - assert cache_indirection[ - seq_slot, :beam_width, - prompt_len:prompt_len + num_generated_tokens].sum( - ) > 0, "Deterministic offsets must not only contain zeros. Otherwise change the seed." + original_cum_logprobs[seq_slot, :beam_width] = torch.randn( + (beam_width, ), dtype=torch.float32) + assert (original_cum_logprobs != 0).sum( + ) > 0, "Original cumulative log probs must not only contain zeros. Otherwise change the seed." - # set the new log probs and tokens for the beam search sampling - sampler.store.sampled_log_probs[ - seq_slot, :beam_width] = original_logprobs[:beam_width, - num_generated_tokens - 1, - 0:1] - sampler.store.new_tokens[ - 0, - seq_slot, :beam_width] = original_logprob_indices[:beam_width, - num_generated_tokens - - 1, 0] - # test - beam_history_builder = sampler._prepare_beam_history( - request, finish_reasons=torch.ones((beam_width, ), dtype=torch.int)) - torch.cuda.synchronize() - beam_history = beam_history_builder() + cache_indirection[seq_slot, :beam_width, prompt_len:prompt_len + + num_generated_tokens] = torch.randint( + 0, + beam_width, (beam_width, num_generated_tokens), + dtype=torch.int32) + assert cache_indirection[ + seq_slot, :beam_width, + prompt_len:prompt_len + num_generated_tokens].sum( + ) > 0, "Deterministic offsets must not only contain zeros. Otherwise change the seed." - # expected selection: - # Currently beam history only contains the generated tokens, not the prompt tokens. 
- expected_tokens = torch.zeros( - (sampler.max_beam_width, num_generated_tokens), dtype=torch.int32) - expected_logprobs = torch.zeros( - (beam_width, num_generated_tokens, original_logprobs.shape[-1]), - dtype=torch.float32) - for gen_idx in range(num_generated_tokens): - token_idx = prompt_len + gen_idx - expected_tokens[:, gen_idx] = original_tokens[ - seq_slot, cache_indirection[seq_slot, :, token_idx], token_idx] - expected_logprobs[:, gen_idx] = original_logprobs[cache_indirection[ - seq_slot, :beam_width, token_idx], gen_idx] + # set the new log probs and tokens for the beam search sampling + assert sampler.store.sampled_log_probs is not None + sampler.store.sampled_log_probs[ + seq_slot, :beam_width] = original_logprobs[:beam_width, + num_generated_tokens - 1, + 0:1] + sampler.store.new_tokens[ + 0, seq_slot, : + beam_width] = original_logprob_indices[:beam_width, + num_generated_tokens - 1, 0] - torch.testing.assert_close(beam_history.tokens[:beam_width], - expected_tokens[:beam_width]) - # test logprobs as well - torch.testing.assert_close(beam_history.logprobs[:beam_width], - expected_logprobs[:beam_width]) - torch.testing.assert_close( - beam_history.cum_logprobs[:beam_width], - original_cum_logprobs[seq_slot, :beam_width].to("cpu")) + @dataclass + class UutResult: + beam_history_builder: Callable[[], BeamHistory | None] | None + + @dataclass + class UutResultWrapper: + result: UutResult | None = None + + res = UutResultWrapper() + + # test + def _uut(res=res): + res.result = UutResult( + beam_history_builder=sampler._prepare_beam_history( + request, + finish_reasons=torch.ones((beam_width, ), dtype=torch.int), + ), ) + + yield _uut + + torch.cuda.synchronize() + assert res.result is not None + beam_history_builder = res.result.beam_history_builder + assert beam_history_builder is not None + beam_history = beam_history_builder() + assert beam_history is not None + + # expected selection: + # Currently beam history only contains the generated tokens, not the prompt tokens. + expected_tokens = torch.zeros( + (sampler.max_beam_width, num_generated_tokens), dtype=torch.int32) + expected_logprobs = torch.zeros( + (beam_width, num_generated_tokens, original_logprobs.shape[-1]), + dtype=torch.float32) + for gen_idx in range(num_generated_tokens): + token_idx = prompt_len + gen_idx + expected_tokens[:, gen_idx] = original_tokens[ + seq_slot, cache_indirection[seq_slot, :, token_idx], token_idx] + expected_logprobs[:, gen_idx] = original_logprobs[cache_indirection[ + seq_slot, :beam_width, token_idx], gen_idx] + + torch.testing.assert_close(beam_history.tokens[:beam_width], + expected_tokens[:beam_width]) + # test logprobs as well + assert beam_history.logprobs is not None + torch.testing.assert_close(beam_history.logprobs[:beam_width], + expected_logprobs[:beam_width]) + assert beam_history.cum_logprobs is not None + torch.testing.assert_close( + beam_history.cum_logprobs[:beam_width], + original_cum_logprobs[seq_slot, :beam_width].to("cpu")) + + run_test_with_warmup(_uut_provider, max_sync_s=1) def test_finish_beams(): @@ -702,81 +802,94 @@ def test_finish_beams(): This test verifies that beams are correctly finalized. 
""" - test_params = GeneralTestParams() - beam_width = test_params.beam_width - num_generated_tokens = test_params.num_generated_tokens - test_params.seq_len - end_id = test_params.end_id - batch_size = test_params.batch_size - vocab_size = test_params.vocab_size - num_logprobs = 1 - request = create_default_request(test_params) - sampler = create_default_sampler(test_params) - store_device = sampler.store.cache_indirection.device + @contextmanager + def _uut_provider( + is_warmup: bool) -> Generator[Callable[[], None], None, None]: + test_params = GeneralTestParams() + beam_width = test_params.beam_width + num_generated_tokens = test_params.num_generated_tokens + end_id = test_params.end_id + batch_size = test_params.batch_size + vocab_size = test_params.vocab_size + num_logprobs = 1 + request = create_default_request(test_params) + sampler = create_default_sampler(test_params) + assert sampler.store.cache_indirection is not None - request.set_generated_tokens( - torch.randint(0, - vocab_size, (beam_width, num_generated_tokens), - dtype=torch.int32).tolist()) + request.set_generated_tokens( + torch.randint(0, + vocab_size, (beam_width, num_generated_tokens), + dtype=torch.int32).tolist()) - torch.manual_seed(42) - # Do not keep end_id tokens in the tensor. This would interfere with the test. - tokens = torch.randint( - 0, - end_id, (batch_size, sampler.max_beam_width, num_generated_tokens), - dtype=torch.int32, - device=store_device) - logprobs = torch.randn((batch_size, sampler.max_beam_width, - num_generated_tokens, num_logprobs), - dtype=torch.float32, - device=store_device) - cum_logprobs = logprobs[..., 0].sum(dim=-1) + torch.manual_seed(42) + # Do not keep end_id tokens in the tensor. This would interfere with the test. + tokens = torch.randint( + 0, + end_id, (batch_size, sampler.max_beam_width, num_generated_tokens), + dtype=torch.int32) + logprobs = torch.randn((batch_size, sampler.max_beam_width, + num_generated_tokens, num_logprobs), + dtype=torch.float32) + cum_logprobs = logprobs[..., 0].sum(dim=-1) - # assert that the buffers are different from zero. Otherwise the test may pass if the function does not work. - assert tokens.sum( - ) > 0, "Tokens must not only contain zeros. Otherwise change the seed." - assert torch.any(logprobs != 0) and torch.any( - cum_logprobs != 0 - ), "Log probs and cumulative log probs must not only contain zeros. Otherwise change the seed." + # assert that the buffers are different from zero. Otherwise the test may pass if the function does not work. + assert tokens.sum( + ) > 0, "Tokens must not only contain zeros. Otherwise change the seed." + assert torch.any(logprobs != 0) and torch.any( + cum_logprobs != 0 + ), "Log probs and cumulative log probs must not only contain zeros. Otherwise change the seed." 
- tokens[batch_size - 1, 0, num_generated_tokens // - 2:] = BEAM_SEARCH_PAD_TOKEN # simulate early finished beam + tokens[batch_size - 1, 0, num_generated_tokens // + 2:] = BEAM_SEARCH_PAD_TOKEN # simulate early finished beam - for batch_idx in range(batch_size): - beam_history = BeamHistory( - tokens=tokens[batch_idx, :beam_width], - logprobs=logprobs[batch_idx, :beam_width], - cum_logprobs=cum_logprobs[batch_idx, :beam_width]) - request.py_return_log_probs = False prompt_len = request.py_prompt_len - if batch_idx < batch_size - 1: - # requests are not finished yet - sampler._finalize_beam(request, beam_history) - final_tokens = torch.tensor(request.get_tokens(), - device=store_device, - dtype=torch.int32)[:, prompt_len:] - torch.testing.assert_close(final_tokens, - tokens[batch_idx, :beam_width]) - # Test the case where end_ids are present in the output - else: - sampler._finalize_beam(request, beam_history) + token_history = [] - # Given input for beam 0: [ token, token, ..., token, BEAM_SEARCH_PAD_TOKEN, BEAM_SEARCH_PAD_TOKEN, ..., BEAM_SEARCH_PAD_TOKEN] - # Expected output for beam 0: [ token, token, ..., token] - final_tokens_1p = torch.tensor(request.get_tokens()[1:], - device=store_device, - dtype=torch.int32)[:, prompt_len:] - final_tokens_0 = torch.tensor(request.get_tokens()[0], - device=store_device, - dtype=torch.int32)[prompt_len:] - torch.testing.assert_close(final_tokens_1p, tokens[batch_idx, - 1:beam_width]) - torch.testing.assert_close(final_tokens_0.shape[0], - num_generated_tokens // 2) - torch.testing.assert_close( - final_tokens_0, tokens[batch_idx, - 0, :num_generated_tokens // 2]) + # test + def _uut(): + nonlocal token_history + + for batch_idx in range(batch_size): + beam_history = BeamHistory( + tokens=tokens[batch_idx, :beam_width], + logprobs=logprobs[batch_idx, :beam_width], + cum_logprobs=cum_logprobs[batch_idx, :beam_width]) + request.py_return_log_probs = False + + sampler._finalize_beam(request, beam_history) + + token_history.append(deepcopy(request.get_tokens())) + + yield _uut + + for batch_idx in range(batch_size): + batch_final_tokens = token_history[batch_idx] + + if batch_idx < batch_size - 1: + # requests are not finished yet + final_tokens = torch.tensor(batch_final_tokens, + dtype=torch.int32)[:, prompt_len:] + torch.testing.assert_close(final_tokens, + tokens[batch_idx, :beam_width]) + # Test the case where end_ids are present in the output + else: + # Given input for beam 0: [ token, token, ..., token, BEAM_SEARCH_PAD_TOKEN, BEAM_SEARCH_PAD_TOKEN, ..., BEAM_SEARCH_PAD_TOKEN] + # Expected output for beam 0: [ token, token, ..., token] + final_tokens_1p = torch.tensor(batch_final_tokens[1:], + dtype=torch.int32)[:, + prompt_len:] + final_tokens_0 = torch.tensor(batch_final_tokens[0], + dtype=torch.int32)[prompt_len:] + torch.testing.assert_close(final_tokens_1p, + tokens[batch_idx, 1:beam_width]) + torch.testing.assert_close(final_tokens_0.shape[0], + num_generated_tokens // 2) + torch.testing.assert_close( + final_tokens_0, tokens[batch_idx, + 0, :num_generated_tokens // 2]) + + run_test_with_warmup(_uut_provider, max_sync_s=1) @force_ampere # Save H100 resource diff --git a/tests/unittest/_torch/sampler/test_beam_search_util.py b/tests/unittest/_torch/sampler/test_beam_search_util.py index 1fc0239fb3..877c7043eb 100644 --- a/tests/unittest/_torch/sampler/test_beam_search_util.py +++ b/tests/unittest/_torch/sampler/test_beam_search_util.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations 
under the License. -from typing import Any +from typing import Any, cast import torch from transformers.configuration_utils import PretrainedConfig +from tensorrt_llm._torch.attention_backend import TrtllmAttentionMetadata from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader @@ -49,6 +50,7 @@ class DummyConfig(PretrainedConfig): class DummyModel(torch.nn.Module): def __init__(self, model_config: ModelConfig): super().__init__() + assert model_config.pretrained_config is not None self.dtype = model_config.pretrained_config.torch_dtype self.model_config = model_config @@ -67,9 +69,11 @@ class DummyModel(torch.nn.Module): attn_metadata: AttentionMetadata, return_context_logits: bool = False, **kwargs, - ) -> torch.Tensor: + ) -> dict[str, torch.Tensor]: num_batch_tokens = input_ids.size(0) + assert self.config is not None + assert attn_metadata.seq_lens_cuda is not None vocab_size = self.config.vocab_size last_tokens = ( torch.cumsum( @@ -98,18 +102,18 @@ class DummyModel(torch.nn.Module): num_context_requests = attn_metadata.num_contexts # each beam has its own attn_metadata.seq_lens_cuda entry - num_generation_requests = ( - last_tokens.shape[0] - num_context_requests - ) // attn_metadata.beam_width + beam_width = cast(TrtllmAttentionMetadata, attn_metadata).beam_width + num_generation_requests = (last_tokens.shape[0] - num_context_requests) // beam_width num_requests = num_generation_requests + num_context_requests # return cache indirection, as additional model output. # each sequence should only return a 1D cache indirection tensor + assert attn_metadata.cache_indirection is not None context_cache_indirection = attn_metadata.cache_indirection[:num_context_requests, 0] generation_cache_indirection = attn_metadata.cache_indirection[ num_context_requests:num_requests ].view( - num_generation_requests * attn_metadata.beam_width, + num_generation_requests * beam_width, attn_metadata.cache_indirection.shape[-1], ) return { @@ -130,7 +134,7 @@ class DummyModel(torch.nn.Module): @register_checkpoint_weight_loader("DUMMY_FORMAT") class DummyWeightLoader(BaseWeightLoader): - def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]: + def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]: # type: ignore """Load weights from your dummy format. 
Args: checkpoint_dir: Directory containing checkpoint files diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py index 5aa6a92eda..59d4ccc316 100644 --- a/tests/unittest/_torch/sampler/test_torch_sampler.py +++ b/tests/unittest/_torch/sampler/test_torch_sampler.py @@ -15,25 +15,14 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from itertools import product -from typing import ( - Callable, - ContextManager, - Final, - Generator, - Optional, - Protocol, - Type, - TypeVar, - Union, - cast, -) +from typing import Callable, ContextManager, Final, Generator, Optional, Type, TypeVar, Union, cast import flashinfer.sampling import numpy as np import pytest import torch from scipy.stats import power_divergence -from utils.util import assert_no_cuda_sync, force_ampere +from utils.util import UutProvider, assert_no_cuda_sync, force_ampere, run_test_with_warmup from tensorrt_llm._torch.pyexecutor.llm_request import convert_wordlist from tensorrt_llm._torch.pyexecutor.sampler import ( @@ -363,53 +352,6 @@ class TestStrategySelection: assert torch_sampler.should_provide_draft_probs(request) == (not is_greedy) -class UutProvider(Protocol): - def __call__(self, is_warmup: bool) -> ContextManager[Callable[[], None]]: ... - - -def _run_test_with_warmup( - uut_provider: UutProvider, - warmup_sizes_bytes: tuple[int] = (4 * 2**30,), - *, - max_sync_s: Optional[float], -): - """Run UUT including setup and warmup. - - This is mainly used to check that the UUT does not CUDA device sync. Thus, - given that PyTorch's caching memory allocator can device sync when it runs - out of cached GPU memory segments, the warmup allocates some GPU memory. - - The warmup also runs the test once. This avoids issues with things like lazy loading - of device code. The UUT provider can use the 'is_warmup' argument to adapt its - behavior to the warmup and final test runs. - - If max_sync_s is provided, this helper checks that the UUT does not device sync, - assuming that the sync (CPU) part of the code takes no longer than max_sync_s - seconds to complete. - - It is the user's responsibility to ensure that the amount of submitted work - does not exceed the CUDA driver/device queue capacity, which would make - the execution appear synchronous. 
- """ - with torch.cuda.Stream(): - with uut_provider(is_warmup=True) as uut: - bufs = [] - for warmup_size in warmup_sizes_bytes: - bufs.append( - torch.ones(warmup_size, device=torch.cuda.current_device(), dtype=torch.int8) - ) - del bufs - uut() - - with uut_provider(is_warmup=False) as uut: - with ( - assert_no_cuda_sync(sync_timeout_s=max_sync_s) - if max_sync_s is not None - else nullcontext() - ): - uut() - - @force_ampere @pytest.mark.parametrize( "draft_len, with_ctx, with_gen", @@ -630,7 +572,7 @@ def test_select_generated_logits(draft_len: int, with_ctx: bool, with_gen: bool) torch.testing.assert_close(res.result.req_offsets.to("cpu"), expected_req_offsets) torch.testing.assert_close(res.result.selected_logits.to("cpu"), expected_logits) - _run_test_with_warmup(_test_runner, max_sync_s=0.3) + run_test_with_warmup(_test_runner, max_sync_s=0.3) class TestFinishReasons: @@ -852,7 +794,7 @@ class TestFinishReasons: ] ) - _run_test_with_warmup(uut_provider, max_sync_s=0.5) + run_test_with_warmup(uut_provider, max_sync_s=0.5) @classmethod def test_are_stop_words_isnt_called_when_no_stop_words(cls, monkeypatch: pytest.MonkeyPatch): @@ -879,13 +821,13 @@ class TestFinishReasons: ], extra_context=lambda: raising_stop_words_ctx(True), ) - _run_test_with_warmup(uut_provider_with_stop_words, max_sync_s=0.5) + run_test_with_warmup(uut_provider_with_stop_words, max_sync_s=0.5) uut_provider_without_stop_words = cls.RequestCase.build( [cls.RequestCase(prompt=[1], new_tokens=[4], finish_reasons=[cls.NOT_FINISHED])], extra_context=lambda: raising_stop_words_ctx(False), ) - _run_test_with_warmup(uut_provider_without_stop_words, max_sync_s=0.5) + run_test_with_warmup(uut_provider_without_stop_words, max_sync_s=0.5) class TestBatchedSampling: @@ -1532,7 +1474,7 @@ class TestBatchedSampling: logit_offset += steps - _run_test_with_warmup( + run_test_with_warmup( _uut_provider, max_sync_s=None, # NB: assert_no_cuda_sync called in TestBatchedSampler._sample ) @@ -2247,7 +2189,7 @@ class TestBatchedSampling: num_samples=num_samples, ) - _run_test_with_warmup( + run_test_with_warmup( _uut_provider, max_sync_s=None, # NB: assert_no_cuda_sync called in TestBatchedSampler._sample ) @@ -2443,4 +2385,4 @@ class TestBatchedSampling: torch.testing.assert_close(new_tokens_host[:steps, seq_slot], req_tokens.cpu()) input_offset += steps - _run_test_with_warmup(_uut_provider, max_sync_s=0.2) + run_test_with_warmup(_uut_provider, max_sync_s=0.2) diff --git a/tests/unittest/utils/util.py b/tests/unittest/utils/util.py index 2c720a7328..5bfbd0fff0 100644 --- a/tests/unittest/utils/util.py +++ b/tests/unittest/utils/util.py @@ -18,11 +18,11 @@ import math import os import time import unittest -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from dataclasses import dataclass from difflib import SequenceMatcher from pathlib import Path -from typing import Any, Generator +from typing import Any, Callable, ContextManager, Generator, Protocol import psutil import pynvml @@ -520,6 +520,12 @@ def device_sleep( time.sleep(spin_s) +class UutProvider(Protocol): + + def __call__(self, is_warmup: bool) -> ContextManager[Callable[[], None]]: + ... 
+ + @contextmanager def assert_no_cuda_sync( sync_timeout_s: float = 5, ) -> Generator[None, None, None]: @@ -563,6 +569,47 @@ def assert_no_cuda_sync( scope_finished_event.synchronize() +def run_test_with_warmup( + uut_provider: UutProvider, + warmup_sizes_bytes: tuple[int] = (4 * 2**30, ), + *, + max_sync_s: float | None, +): + """Run UUT including setup and warmup. + + This is mainly used to check that the UUT does not CUDA device sync. Thus, + given that PyTorch's caching memory allocator can device sync when it runs + out of cached GPU memory segments, the warmup allocates some GPU memory. + + The warmup also runs the test once. This avoids issues with things like lazy loading + of device code. The UUT provider can use the 'is_warmup' argument to adapt its + behavior to the warmup and final test runs. + + If max_sync_s is provided, this helper checks that the UUT does not device sync, + assuming that the sync (CPU) part of the code takes no longer than max_sync_s + seconds to complete. + + It is the user's responsibility to ensure that the amount of submitted work + does not exceed the CUDA driver/device queue capacity, which would make + the execution appear synchronous. + """ + with torch.cuda.Stream(): + with uut_provider(is_warmup=True) as uut: + bufs = [] + for warmup_size in warmup_sizes_bytes: + bufs.append( + torch.ones(warmup_size, + device=torch.cuda.current_device(), + dtype=torch.int8)) + del bufs + uut() + + with uut_provider(is_warmup=False) as uut: + with (assert_no_cuda_sync(sync_timeout_s=max_sync_s) + if max_sync_s is not None else nullcontext()): + uut() + + _pynvmlInited = False
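
For reference, a minimal sketch (not part of the diff) of how a test can drive the run_test_with_warmup / UutProvider helpers that this change moves into tests/unittest/utils/util.py. The helper names, the is_warmup parameter, and the import path come from the hunks above; the test name and the toy element-wise computation are placeholders.

from contextlib import contextmanager
from typing import Callable, Generator

import torch

from utils.util import run_test_with_warmup


@contextmanager
def _uut_provider(is_warmup: bool) -> Generator[Callable[[], None], None, None]:
    # Setup runs for both the warmup and the final run; allocate inputs here
    # so that the unit under test only launches asynchronous GPU work.
    x = torch.randn(1024, device=torch.cuda.current_device())
    result: dict[str, torch.Tensor] = {}

    def _uut() -> None:
        # Must not trigger a device sync (no .item()/.cpu()/implicit syncs),
        # otherwise assert_no_cuda_sync inside run_test_with_warmup fails.
        result["y"] = x * 2.0

    yield _uut

    # Code after the yield runs outside the sync-checked region, so it may
    # synchronize and verify results (cf. test_create_beam_history above).
    if not is_warmup:
        torch.cuda.synchronize()
        torch.testing.assert_close(result["y"], x * 2.0)


def test_scaling_does_not_sync() -> None:
    # The first invocation (is_warmup=True) pre-populates the caching
    # allocator and loads kernels; the second run is wrapped in
    # assert_no_cuda_sync with a 0.5 s sync budget.
    run_test_with_warmup(_uut_provider, max_sync_s=0.5)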