[TRTLLM-10030][test] ensure that TorchSampler does not sync (#11508)

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
mpikulski 2026-02-16 13:10:40 +01:00 committed by GitHub
parent d72f8098fe
commit 08c7103fc4
4 changed files with 436 additions and 330 deletions

View File

@@ -12,21 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import pathlib as _pl
from typing import Any
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Callable, Generator, cast
import pytest
import torch
from test_beam_search_util import (BeamSearchTestOutput, DummyConfigLoader,
DummyWeightLoader, get_expected_outputs)
from utils.llm_data import llm_models_root
from utils.util import assert_no_cuda_sync, force_ampere
from utils.util import assert_no_cuda_sync, force_ampere, run_test_with_warmup
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm import LLM, SamplingParams, TorchLlmArgs
from tensorrt_llm._torch.models.checkpoints import HfCheckpointLoader
from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
SamplingConfig)
from tensorrt_llm._torch.pyexecutor.sampler import BeamHistory, TorchSampler
from tensorrt_llm._torch.pyexecutor.sampler import (BeamHistory,
SampleStateTorch,
TorchSampler)
from tensorrt_llm._torch.pyexecutor.sampling_utils import (
BEAM_SEARCH_PAD_TOKEN, BeamSearchMetadata, FinishReason,
beam_search_sampling_batch)
@@ -68,43 +75,63 @@ def model_kwargs(fixed_params, sampling_information) -> dict[str, Any]:
)
def _build_llm(fixed_params, input_prompts, model_kwargs):
@pytest.fixture(scope="module", params=[False, True])
def with_cuda_graph_and_overlap(request):
return request.param
def _build_llm(fixed_params, input_prompts, llm_kwargs):
return LLM(
**model_kwargs,
kv_cache_config=KvCacheConfig(max_tokens=10000),
**llm_kwargs,
kv_cache_config=KvCacheConfig(
max_tokens=10000, # type: ignore
),
max_batch_size=fixed_params["max_beam_width"] * len(
input_prompts
), # use small batch size to prevent large buffers from possibly hiding wrong data accesses.
max_seq_len=32,
max_beam_width=fixed_params["max_beam_width"],
disable_overlap_scheduler=True,
cuda_graph_config=None,
)
@pytest.fixture(scope="module")
def llm(fixed_params, input_prompts, model_kwargs):
llm = _build_llm(fixed_params, input_prompts, model_kwargs)
yield llm
llm.shutdown()
@contextmanager
def _single_process_context():
os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] = "1"
try:
yield
finally:
del os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"]
# NB: It is important that all tests instantiating 'LLM' with
# TLLM_WORKER_USE_SINGLE_PROCESS=1 (i.e., single_process=True below)
# use this fixture. Otherwise, more than one such 'LLM' object
# could be alive at any given point in time and this has been
# found to result in corruption of the cache_indirection tensors.
@pytest.fixture(scope="module")
def llm_cuda_graph(fixed_params, input_prompts, model_kwargs):
llm = LLM(
**model_kwargs,
kv_cache_config=KvCacheConfig(max_tokens=10000),
max_batch_size=fixed_params["max_beam_width"] * len(
input_prompts
), # use small batch size to prevent large buffers from possibly hiding wrong data accesses.
max_seq_len=32,
max_beam_width=fixed_params["max_beam_width"],
disable_overlap_scheduler=False,
cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8],
enable_padding=True),
)
yield llm
llm.shutdown()
def llm(fixed_params, input_prompts, model_kwargs, single_process: bool,
with_cuda_graph_and_overlap: bool):
gc.collect(
2) # force destruction of any other LLM instances (cf. comment above)
with _single_process_context() if single_process else nullcontext():
llm = _build_llm(
fixed_params,
input_prompts,
llm_kwargs=(
(dict(
disable_overlap_scheduler=True,
cuda_graph_config=None,
) if not with_cuda_graph_and_overlap else dict(
disable_overlap_scheduler=False,
cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8],
enable_padding=True),
))
|
deepcopy( # LLM.shutdown resets checkpoint_loader.config_loader
model_kwargs)),
)
with llm:
yield llm
def check_generation_logits(beam: CompletionOutput,
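Editor's note: the llm fixture above merges the per-parametrization overrides with the shared model_kwargs via the dict union operator and deep-copies the latter because LLM.shutdown resets checkpoint_loader.config_loader. A minimal sketch of those merge semantics with made-up stand-in values (the nested checkpoint_loader dict here is only illustrative):

from copy import deepcopy

# stand-ins for model_kwargs; checkpoint_loader is a nested, stateful object in the real fixture
base_model_kwargs = {"model": "dummy-model", "checkpoint_loader": {"config_loader": "loader"}}
overrides = {"disable_overlap_scheduler": True, "cuda_graph_config": None}

llm_kwargs = overrides | deepcopy(base_model_kwargs)
llm_kwargs["checkpoint_loader"]["config_loader"] = None  # roughly what LLM.shutdown does
assert base_model_kwargs["checkpoint_loader"]["config_loader"] == "loader"  # original untouched
assert llm_kwargs["disable_overlap_scheduler"] is True  # override-only keys are carried along

Without the deepcopy, the nested loader object would be shared across parametrizations, so the reset performed by one LLM's shutdown would leak into the next test.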
@@ -125,6 +152,7 @@ def check_generation_logits(beam: CompletionOutput,
def check_logprobs(beam: CompletionOutput, sampling_params: SamplingParams,
valid_tokens: int | None) -> None:
"""Check if the logprobs have the correct shape"""
assert beam.logprobs is not None
if sampling_params.logprobs is not None:
generated_tokens = valid_tokens if valid_tokens is not None else sampling_params.max_tokens
assert len(
@@ -132,6 +160,7 @@ def check_logprobs(beam: CompletionOutput, sampling_params: SamplingParams,
) == generated_tokens, f"expected {generated_tokens} logprobs, but got {len(beam.logprobs)}"
log_sum = 0.0
for logprob_dict in (beam.logprobs):
assert isinstance(logprob_dict, dict)
for logprob_value in logprob_dict.values():
log_sum += logprob_value.logprob
assert log_sum == beam.cumulative_logprob, f"expected {beam.cumulative_logprob} logprob, but got {log_sum}"
@@ -145,6 +174,7 @@ def check_cache_indirection(beam: CompletionOutput,
prompt_length: int, beam_idx: int,
valid_tokens: int | None) -> None:
"""Check if the cache indirection seen by the model is the same as the expected cache indirection"""
assert beam.additional_generation_outputs is not None
cache_indirection = beam.additional_generation_outputs["cache_indirection"]
assert cache_indirection is not None, "cache indirection should not be None"
assert cache_indirection.shape[
@@ -207,6 +237,11 @@ def check_context_logits(output: GenerationResult,
assert output.context_logits is None, "context logits should be None"
@pytest.fixture(scope="module", params=[False, True])
def single_process(request) -> bool:
return request.param
def validate_output(output: GenerationResult, input_prompt: list[int],
sampling_params: SamplingParams) -> None:
"""Perform several checks on the output of a single prompt"""
@@ -226,11 +261,63 @@ def validate_output(output: GenerationResult, input_prompt: list[int],
def validate_outputs(llm: LLM, input_prompts: list[list[int]],
sampling_params: SamplingParams) -> None:
sampling_params: SamplingParams,
monkeypatch: pytest.MonkeyPatch,
check_no_sync: bool) -> None:
"""Generate outputs for a list of prompts and validate the outputs"""
outputs = llm.generate(input_prompts, sampling_params=sampling_params)
num_prompts = len(input_prompts)
outputs = llm.generate(deepcopy(input_prompts),
sampling_params=deepcopy(sampling_params))
if check_no_sync:
del outputs # treat previous .generate as warmup, ignore results
with monkeypatch.context() as patcher:
sample_async_orig = TorchSampler.sample_async
update_requests_orig = TorchSampler.update_requests
_sample_async_hook_called = False
_update_requests_hook_called = False
def _sample_async_hook(*args, **kwargs):
nonlocal _sample_async_hook_called
_sample_async_hook_called = True
with assert_no_cuda_sync():
return sample_async_orig(*args, **kwargs)
def _update_requests_hook(self, state: SampleStateTorch, *args,
**kwargs):
nonlocal _update_requests_hook_called
_update_requests_hook_called = True
# await sampling event outside sync-check (because this syncs)
sampler_event = state.sampler_event
if sampler_event:
sampler_event.synchronize()
with assert_no_cuda_sync():
state.sampler_event = None
try:
return update_requests_orig(self, state, *args,
**kwargs)
finally:
state.sampler_event = sampler_event
# Intercept sampler methods to check that they do not sync (requires
# TLLM_WORKER_USE_SINGLE_PROCESS).
patcher.setattr(TorchSampler, "sample_async", _sample_async_hook)
patcher.setattr(TorchSampler, "update_requests",
_update_requests_hook)
outputs = llm.generate(deepcopy(input_prompts),
sampling_params=deepcopy(sampling_params))
assert _sample_async_hook_called
assert _update_requests_hook_called
num_prompts = len(input_prompts)
assert isinstance(outputs, list)
assert len(
outputs
) == num_prompts, f"expected {num_prompts} outputs, but got {len(outputs)}"
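Editor's note: the hooks above are the core of this change: TorchSampler.sample_async and TorchSampler.update_requests are wrapped so that any CUDA device sync inside them fails the test, with the sampler event synchronized before entering the checked region because waiting on it necessarily blocks. A stripped-down sketch of the same interception pattern, using a hypothetical ToySampler in place of TorchSampler (assert_no_cuda_sync is the helper imported from utils.util and needs a CUDA device):

import pytest
import torch

from utils.util import assert_no_cuda_sync


class ToySampler:

    def sample_async(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.argmax(dim=-1)  # asynchronous GPU work, no host-side sync


def test_toy_sampler_does_not_sync(monkeypatch: pytest.MonkeyPatch):
    original = ToySampler.sample_async
    called = False

    def hook(self, *args, **kwargs):
        nonlocal called
        called = True
        with assert_no_cuda_sync():  # intended to fail if the wrapped call blocks on the device
            return original(self, *args, **kwargs)

    monkeypatch.setattr(ToySampler, "sample_async", hook)
    ToySampler().sample_async(torch.randn(4, 8, device="cuda"))
    assert called  # proves the hook actually intercepted the call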
@@ -247,73 +334,45 @@ def validate_outputs(llm: LLM, input_prompts: list[list[int]],
@pytest.mark.parametrize("gather_generation_logits", [True, False])
@pytest.mark.parametrize("gather_context_logits", [True, False])
@pytest.mark.parametrize("num_output_beams", [1, 2])
@pytest.mark.parametrize("num_prompts", [1, 2])
@pytest.mark.parametrize("num_prompts", [1, 2, 3])
@pytest.mark.parametrize("stop_token_ids", [[15], None])
@pytest.mark.threadleak(enabled=False)
def test_beam_search_e2e(
gather_context_logits: bool,
gather_generation_logits: bool,
return_log_probs: bool,
num_output_beams: int,
num_prompts: int,
llm,
fixed_params,
input_prompts,
) -> None:
if return_log_probs and num_prompts > 1 and llm.args.sampler_type == "TRTLLMSampler":
pytest.skip(
"Beam search currently does not support return_log_probs with multiple prompts"
)
if return_log_probs and llm.args.sampler_type == "TRTLLMSampler":
pytest.skip(
"Beam search on TRTLLMSampler does not correctly handle log_probs if called multiple times"
)
# create sampling parameters
# additional_model_outputs is used to gather the cache indirection from the model.
sampling_params = SamplingParams(
max_tokens=fixed_params["max_tokens"],
n=num_output_beams,
best_of=fixed_params["max_beam_width"],
use_beam_search=True,
return_context_logits=gather_context_logits,
return_generation_logits=gather_generation_logits,
logprobs=return_log_probs,
end_id=-1,
additional_model_outputs=["cache_indirection"],
)
validate_outputs(llm, input_prompts[:num_prompts], sampling_params)
@pytest.mark.parametrize("return_log_probs", [True, False])
@pytest.mark.parametrize("gather_generation_logits", [True, False])
@pytest.mark.parametrize("gather_context_logits", [True, False])
@pytest.mark.parametrize("num_output_beams", [1, 2])
@pytest.mark.parametrize("num_prompts", [1, 2, 3])
@pytest.mark.parametrize("stop_token_ids", [[15], None])
@pytest.mark.threadleak(enabled=False)
def test_beam_search_e2e_cuda_graph_and_overlap(
gather_context_logits: bool,
gather_generation_logits: bool,
return_log_probs: bool,
num_output_beams: int,
num_prompts: int,
stop_token_ids: list[int] | None,
llm_cuda_graph,
llm: LLM,
fixed_params,
input_prompts,
single_process: bool,
monkeypatch: pytest.MonkeyPatch,
) -> None:
if return_log_probs and num_prompts > 1 and llm_cuda_graph.args.sampler_type == "TRTLLMSampler":
llm_args = cast(TorchLlmArgs, llm.args)
check_no_sync = single_process # single_process only used for sync check
if check_no_sync and cast(TorchLlmArgs,
llm.args).sampler_type != "TorchSampler":
pytest.skip("Sync check only supported for TorchSampler")
if check_no_sync and (cast(TorchLlmArgs, llm.args).sampler_type
== "TorchSampler") and stop_token_ids is not None:
# FIXME: Fix TorchSampler._are_stop_words
pytest.skip("Stop word handling in TorchSampler syncs")
if return_log_probs and num_prompts > 1 and llm_args.sampler_type == "TRTLLMSampler":
pytest.skip(
"Beam search currently does not support return_log_probs with multiple prompts"
)
if return_log_probs and llm_cuda_graph.args.sampler_type == "TRTLLMSampler":
if return_log_probs and llm_args.sampler_type == "TRTLLMSampler":
pytest.skip(
"Beam search on TRTLLMSampler does not correctly handle log_probs if called multiple times"
)
if stop_token_ids is not None and llm_cuda_graph.args.sampler_type == "TRTLLMSampler":
if stop_token_ids is not None and llm_args.sampler_type == "TRTLLMSampler":
pytest.skip(
"Beam search on TRTLLMSampler does not correctly handle stop_token_ids"
)
# create sampling parameters
# additional_model_outputs is used to gather the cache indirection from the model.
sampling_params = SamplingParams(
@@ -328,8 +387,13 @@ def test_beam_search_e2e_cuda_graph_and_overlap(
stop_token_ids=stop_token_ids,
additional_model_outputs=["cache_indirection"],
)
validate_outputs(llm_cuda_graph, input_prompts[:num_prompts],
sampling_params)
validate_outputs(
llm,
input_prompts[:num_prompts],
sampling_params,
check_no_sync=check_no_sync,
monkeypatch=monkeypatch,
)
###########################################################################
@@ -440,6 +504,7 @@ def test_beam_search_sampling_batch_basic():
)
# Validate output shapes
assert softmax is not None
expected_tokens_shape = (batch_size, beam_width)
assert next_tokens.shape == expected_tokens_shape, (
f"Expected shape {expected_tokens_shape}, got {next_tokens.shape}")
@@ -564,9 +629,11 @@ def create_default_sampler(test_params: GeneralTestParams) -> TorchSampler:
assert max_seq_len > test_params.seq_len, "Max sequence length must be greater than sequence length"
assert max_batch_size > test_params.batch_size, "Max batch size must be greater than batch size"
assert max_batch_size > test_params.seq_slot, "Max batch size must be greater than sequence slot"
assert sampler.store.cache_indirection is not None
assert sampler.store.cache_indirection.shape == (
max_batch_size, max_beam_width,
max_seq_len), "Cache indirection shape mismatch"
assert sampler.store.original_tokens is not None
assert sampler.store.original_tokens.shape == (
max_batch_size, max_beam_width,
max_seq_len), "Original tokens shape mismatch"
@@ -579,121 +646,154 @@ def test_create_beam_history():
This test verifies that beam history is correctly reconstructed by following
the cache_indirection backwards to obtain the correct token sequence.
"""
test_params = GeneralTestParams()
request = create_default_request(test_params)
sampler = create_default_sampler(test_params)
# Extract parameters from the test parameters
beam_width = test_params.beam_width
prompt_len = test_params.prompt_len
num_generated_tokens = test_params.num_generated_tokens
seq_slot = test_params.seq_slot
vocab_size = test_params.vocab_size
num_logprobs = test_params.num_logprobs + 1
cache_indirection = sampler.store.cache_indirection
original_tokens = sampler.store.original_tokens
original_logprobs = torch.zeros(
(beam_width, num_generated_tokens, num_logprobs),
dtype=torch.float32,
device=original_tokens.device)
original_logprob_indices = torch.zeros(
(beam_width, num_generated_tokens, num_logprobs),
dtype=torch.int32,
device=original_tokens.device)
original_cum_logprobs = sampler.store.cum_log_probs
@contextmanager
def _uut_provider(
is_warmup: bool) -> Generator[Callable[[], None], None, None]:
test_params = GeneralTestParams()
request = create_default_request(test_params)
sampler = create_default_sampler(test_params)
# Fill the request with some random tokens that will be overwritten by the beam search sampling
# Beam history is created before add_token is called
request.set_generated_tokens(
torch.randint(0,
vocab_size, (beam_width, num_generated_tokens - 1),
dtype=torch.int32).tolist())
# random fill
torch.manual_seed(42)
original_tokens[seq_slot, :beam_width, prompt_len:prompt_len +
num_generated_tokens] = torch.randint(
0,
beam_width, (beam_width, num_generated_tokens),
dtype=torch.int32)
assert original_tokens.sum(
) > 0, "Original tokens must not only contain zeros. Otherwise change the seed."
# Extract parameters from the test parameters
beam_width = test_params.beam_width
prompt_len = test_params.prompt_len
num_generated_tokens = test_params.num_generated_tokens
seq_slot = test_params.seq_slot
vocab_size = test_params.vocab_size
num_logprobs = test_params.num_logprobs + 1
cache_indirection = sampler.store.cache_indirection
assert cache_indirection is not None
original_tokens = sampler.store.original_tokens
assert original_tokens is not None
original_logprobs = torch.zeros(
(beam_width, num_generated_tokens, num_logprobs),
dtype=torch.float32,
device=original_tokens.device)
original_logprob_indices = torch.zeros(
(beam_width, num_generated_tokens, num_logprobs),
dtype=torch.int32,
device=original_tokens.device)
original_cum_logprobs = sampler.store.cum_log_probs
assert original_cum_logprobs is not None
original_logprobs[:beam_width] = torch.randn(
(beam_width, num_generated_tokens, original_logprobs.shape[-1]),
dtype=torch.float32)
original_logprob_indices[:beam_width] = torch.randint(
0,
vocab_size,
(beam_width, num_generated_tokens, original_logprobs.shape[-1]),
dtype=torch.int32)
assert (original_logprobs != 0).sum(
) > 0, "Original log probs must not only contain zeros. Otherwise change the seed."
assert (original_logprob_indices).sum(
) > 0, "Original log prob indices must not only contain zeros. Otherwise change the seed."
# Fill the request with some random tokens that will be overwritten by the beam search sampling
# Beam history is created before add_token is called
request.set_generated_tokens(
torch.randint(0,
vocab_size, (beam_width, num_generated_tokens - 1),
dtype=torch.int32).tolist())
# random fill
torch.manual_seed(42)
original_tokens[seq_slot, :beam_width, prompt_len:prompt_len +
num_generated_tokens] = torch.randint(
0,
beam_width, (beam_width, num_generated_tokens),
dtype=torch.int32)
assert original_tokens.sum(
) > 0, "Original tokens must not only contain zeros. Otherwise change the seed."
# set the logprobs in the request:
token_logprobs = sampler._convert_logprobs_tensor_to_list(
original_logprob_indices[:beam_width, :num_generated_tokens - 1],
original_logprobs[:beam_width, :num_generated_tokens - 1],
)
request.py_result.set_log_probs(
token_logprobs,
cum_log_probs=torch.zeros_like(
original_cum_logprobs[seq_slot, :beam_width]).tolist())
original_logprobs[:beam_width] = torch.randn(
(beam_width, num_generated_tokens, original_logprobs.shape[-1]),
dtype=torch.float32)
original_logprob_indices[:beam_width] = torch.randint(
0,
vocab_size,
(beam_width, num_generated_tokens, original_logprobs.shape[-1]),
dtype=torch.int32)
assert (original_logprobs != 0).sum(
) > 0, "Original log probs must not only contain zeros. Otherwise change the seed."
assert (original_logprob_indices).sum(
) > 0, "Original log prob indices must not only contain zeros. Otherwise change the seed."
original_cum_logprobs[seq_slot, :beam_width] = torch.randn(
(beam_width, ), dtype=torch.float32)
assert (original_cum_logprobs != 0).sum(
) > 0, "Original cumulative log probs must not only contain zeros. Otherwise change the seed."
# set the logprobs in the request:
token_logprobs = sampler._convert_logprobs_tensor_to_list(
original_logprob_indices[:beam_width, :num_generated_tokens - 1],
original_logprobs[:beam_width, :num_generated_tokens - 1],
)
request.py_result.set_log_probs(
token_logprobs,
cum_log_probs=torch.zeros_like(
original_cum_logprobs[seq_slot, :beam_width]).tolist())
cache_indirection[seq_slot, :beam_width, prompt_len:prompt_len +
num_generated_tokens] = torch.randint(
0,
beam_width, (beam_width, num_generated_tokens),
dtype=torch.int32)
assert cache_indirection[
seq_slot, :beam_width,
prompt_len:prompt_len + num_generated_tokens].sum(
) > 0, "Deterministic offsets must not only contain zeros. Otherwise change the seed."
original_cum_logprobs[seq_slot, :beam_width] = torch.randn(
(beam_width, ), dtype=torch.float32)
assert (original_cum_logprobs != 0).sum(
) > 0, "Original cumulative log probs must not only contain zeros. Otherwise change the seed."
# set the new log probs and tokens for the beam search sampling
sampler.store.sampled_log_probs[
seq_slot, :beam_width] = original_logprobs[:beam_width,
num_generated_tokens - 1,
0:1]
sampler.store.new_tokens[
0,
seq_slot, :beam_width] = original_logprob_indices[:beam_width,
num_generated_tokens -
1, 0]
# test
beam_history_builder = sampler._prepare_beam_history(
request, finish_reasons=torch.ones((beam_width, ), dtype=torch.int))
torch.cuda.synchronize()
beam_history = beam_history_builder()
cache_indirection[seq_slot, :beam_width, prompt_len:prompt_len +
num_generated_tokens] = torch.randint(
0,
beam_width, (beam_width, num_generated_tokens),
dtype=torch.int32)
assert cache_indirection[
seq_slot, :beam_width,
prompt_len:prompt_len + num_generated_tokens].sum(
) > 0, "Deterministic offsets must not only contain zeros. Otherwise change the seed."
# expected selection:
# Currently beam history only contains the generated tokens, not the prompt tokens.
expected_tokens = torch.zeros(
(sampler.max_beam_width, num_generated_tokens), dtype=torch.int32)
expected_logprobs = torch.zeros(
(beam_width, num_generated_tokens, original_logprobs.shape[-1]),
dtype=torch.float32)
for gen_idx in range(num_generated_tokens):
token_idx = prompt_len + gen_idx
expected_tokens[:, gen_idx] = original_tokens[
seq_slot, cache_indirection[seq_slot, :, token_idx], token_idx]
expected_logprobs[:, gen_idx] = original_logprobs[cache_indirection[
seq_slot, :beam_width, token_idx], gen_idx]
# set the new log probs and tokens for the beam search sampling
assert sampler.store.sampled_log_probs is not None
sampler.store.sampled_log_probs[
seq_slot, :beam_width] = original_logprobs[:beam_width,
num_generated_tokens - 1,
0:1]
sampler.store.new_tokens[
0, seq_slot, :
beam_width] = original_logprob_indices[:beam_width,
num_generated_tokens - 1, 0]
torch.testing.assert_close(beam_history.tokens[:beam_width],
expected_tokens[:beam_width])
# test logprobs as well
torch.testing.assert_close(beam_history.logprobs[:beam_width],
expected_logprobs[:beam_width])
torch.testing.assert_close(
beam_history.cum_logprobs[:beam_width],
original_cum_logprobs[seq_slot, :beam_width].to("cpu"))
@dataclass
class UutResult:
beam_history_builder: Callable[[], BeamHistory | None] | None
@dataclass
class UutResultWrapper:
result: UutResult | None = None
res = UutResultWrapper()
# test
def _uut(res=res):
res.result = UutResult(
beam_history_builder=sampler._prepare_beam_history(
request,
finish_reasons=torch.ones((beam_width, ), dtype=torch.int),
), )
yield _uut
torch.cuda.synchronize()
assert res.result is not None
beam_history_builder = res.result.beam_history_builder
assert beam_history_builder is not None
beam_history = beam_history_builder()
assert beam_history is not None
# expected selection:
# Currently beam history only contains the generated tokens, not the prompt tokens.
expected_tokens = torch.zeros(
(sampler.max_beam_width, num_generated_tokens), dtype=torch.int32)
expected_logprobs = torch.zeros(
(beam_width, num_generated_tokens, original_logprobs.shape[-1]),
dtype=torch.float32)
for gen_idx in range(num_generated_tokens):
token_idx = prompt_len + gen_idx
expected_tokens[:, gen_idx] = original_tokens[
seq_slot, cache_indirection[seq_slot, :, token_idx], token_idx]
expected_logprobs[:, gen_idx] = original_logprobs[cache_indirection[
seq_slot, :beam_width, token_idx], gen_idx]
torch.testing.assert_close(beam_history.tokens[:beam_width],
expected_tokens[:beam_width])
# test logprobs as well
assert beam_history.logprobs is not None
torch.testing.assert_close(beam_history.logprobs[:beam_width],
expected_logprobs[:beam_width])
assert beam_history.cum_logprobs is not None
torch.testing.assert_close(
beam_history.cum_logprobs[:beam_width],
original_cum_logprobs[seq_slot, :beam_width].to("cpu"))
run_test_with_warmup(_uut_provider, max_sync_s=1)
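Editor's note: the reconstruction that test_create_beam_history verifies can be pictured with a tiny standalone example (all values made up): for each generated position, cache_indirection records which beam's row of original_tokens the surviving beam was continued from, so gathering along that index recovers every beam's token history.

import torch

gen_len = 3
original_tokens = torch.tensor([[10, 11, 12],
                                [20, 21, 22]], dtype=torch.int32)  # (beam, position)
cache_indirection = torch.tensor([[0, 1, 1],
                                  [0, 0, 1]], dtype=torch.int32)  # (beam, position)

reconstructed = torch.zeros_like(original_tokens)
for pos in range(gen_len):
    src_beams = cache_indirection[:, pos].long()  # beam each output beam was continued from
    reconstructed[:, pos] = original_tokens[src_beams, pos]

assert reconstructed.tolist() == [[10, 21, 22], [10, 11, 22]]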
def test_finish_beams():
@@ -702,81 +802,94 @@ def test_finish_beams():
This test verifies that beams are correctly finalized.
"""
test_params = GeneralTestParams()
beam_width = test_params.beam_width
num_generated_tokens = test_params.num_generated_tokens
test_params.seq_len
end_id = test_params.end_id
batch_size = test_params.batch_size
vocab_size = test_params.vocab_size
num_logprobs = 1
request = create_default_request(test_params)
sampler = create_default_sampler(test_params)
store_device = sampler.store.cache_indirection.device
@contextmanager
def _uut_provider(
is_warmup: bool) -> Generator[Callable[[], None], None, None]:
test_params = GeneralTestParams()
beam_width = test_params.beam_width
num_generated_tokens = test_params.num_generated_tokens
end_id = test_params.end_id
batch_size = test_params.batch_size
vocab_size = test_params.vocab_size
num_logprobs = 1
request = create_default_request(test_params)
sampler = create_default_sampler(test_params)
assert sampler.store.cache_indirection is not None
request.set_generated_tokens(
torch.randint(0,
vocab_size, (beam_width, num_generated_tokens),
dtype=torch.int32).tolist())
request.set_generated_tokens(
torch.randint(0,
vocab_size, (beam_width, num_generated_tokens),
dtype=torch.int32).tolist())
torch.manual_seed(42)
# Do not keep end_id tokens in the tensor. This would interfere with the test.
tokens = torch.randint(
0,
end_id, (batch_size, sampler.max_beam_width, num_generated_tokens),
dtype=torch.int32,
device=store_device)
logprobs = torch.randn((batch_size, sampler.max_beam_width,
num_generated_tokens, num_logprobs),
dtype=torch.float32,
device=store_device)
cum_logprobs = logprobs[..., 0].sum(dim=-1)
torch.manual_seed(42)
# Do not keep end_id tokens in the tensor. This would interfere with the test.
tokens = torch.randint(
0,
end_id, (batch_size, sampler.max_beam_width, num_generated_tokens),
dtype=torch.int32)
logprobs = torch.randn((batch_size, sampler.max_beam_width,
num_generated_tokens, num_logprobs),
dtype=torch.float32)
cum_logprobs = logprobs[..., 0].sum(dim=-1)
# assert that the buffers are different from zero. Otherwise the test may pass if the function does not work.
assert tokens.sum(
) > 0, "Tokens must not only contain zeros. Otherwise change the seed."
assert torch.any(logprobs != 0) and torch.any(
cum_logprobs != 0
), "Log probs and cumulative log probs must not only contain zeros. Otherwise change the seed."
# assert that the buffers are different from zero. Otherwise the test may pass if the function does not work.
assert tokens.sum(
) > 0, "Tokens must not only contain zeros. Otherwise change the seed."
assert torch.any(logprobs != 0) and torch.any(
cum_logprobs != 0
), "Log probs and cumulative log probs must not only contain zeros. Otherwise change the seed."
tokens[batch_size - 1, 0, num_generated_tokens //
2:] = BEAM_SEARCH_PAD_TOKEN # simulate early finished beam
tokens[batch_size - 1, 0, num_generated_tokens //
2:] = BEAM_SEARCH_PAD_TOKEN # simulate early finished beam
for batch_idx in range(batch_size):
beam_history = BeamHistory(
tokens=tokens[batch_idx, :beam_width],
logprobs=logprobs[batch_idx, :beam_width],
cum_logprobs=cum_logprobs[batch_idx, :beam_width])
request.py_return_log_probs = False
prompt_len = request.py_prompt_len
if batch_idx < batch_size - 1:
# requests are not finished yet
sampler._finalize_beam(request, beam_history)
final_tokens = torch.tensor(request.get_tokens(),
device=store_device,
dtype=torch.int32)[:, prompt_len:]
torch.testing.assert_close(final_tokens,
tokens[batch_idx, :beam_width])
# Test the case where end_ids are present in the output
else:
sampler._finalize_beam(request, beam_history)
token_history = []
# Given input for beam 0: [ token, token, ..., token, BEAM_SEARCH_PAD_TOKEN, BEAM_SEARCH_PAD_TOKEN, ..., BEAM_SEARCH_PAD_TOKEN]
# Expected output for beam 0: [ token, token, ..., token]
final_tokens_1p = torch.tensor(request.get_tokens()[1:],
device=store_device,
dtype=torch.int32)[:, prompt_len:]
final_tokens_0 = torch.tensor(request.get_tokens()[0],
device=store_device,
dtype=torch.int32)[prompt_len:]
torch.testing.assert_close(final_tokens_1p, tokens[batch_idx,
1:beam_width])
torch.testing.assert_close(final_tokens_0.shape[0],
num_generated_tokens // 2)
torch.testing.assert_close(
final_tokens_0, tokens[batch_idx,
0, :num_generated_tokens // 2])
# test
def _uut():
nonlocal token_history
for batch_idx in range(batch_size):
beam_history = BeamHistory(
tokens=tokens[batch_idx, :beam_width],
logprobs=logprobs[batch_idx, :beam_width],
cum_logprobs=cum_logprobs[batch_idx, :beam_width])
request.py_return_log_probs = False
sampler._finalize_beam(request, beam_history)
token_history.append(deepcopy(request.get_tokens()))
yield _uut
for batch_idx in range(batch_size):
batch_final_tokens = token_history[batch_idx]
if batch_idx < batch_size - 1:
# requests are not finished yet
final_tokens = torch.tensor(batch_final_tokens,
dtype=torch.int32)[:, prompt_len:]
torch.testing.assert_close(final_tokens,
tokens[batch_idx, :beam_width])
# Test the case where end_ids are present in the output
else:
# Given input for beam 0: [ token, token, ..., token, BEAM_SEARCH_PAD_TOKEN, BEAM_SEARCH_PAD_TOKEN, ..., BEAM_SEARCH_PAD_TOKEN]
# Expected output for beam 0: [ token, token, ..., token]
final_tokens_1p = torch.tensor(batch_final_tokens[1:],
dtype=torch.int32)[:,
prompt_len:]
final_tokens_0 = torch.tensor(batch_final_tokens[0],
dtype=torch.int32)[prompt_len:]
torch.testing.assert_close(final_tokens_1p,
tokens[batch_idx, 1:beam_width])
torch.testing.assert_close(final_tokens_0.shape[0],
num_generated_tokens // 2)
torch.testing.assert_close(
final_tokens_0, tokens[batch_idx,
0, :num_generated_tokens // 2])
run_test_with_warmup(_uut_provider, max_sync_s=1)
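Editor's note: the early-finished branch above boils down to a simple expectation, illustrated here with made-up values (PAD stands in for BEAM_SEARCH_PAD_TOKEN): finalization keeps only the real tokens of the padded beam and leaves the other beams at full length.

PAD = -1  # stand-in for BEAM_SEARCH_PAD_TOKEN
generated = {
    "beam_0": [11, 7, 42, PAD, PAD, PAD],  # finished early, padded out to num_generated_tokens
    "beam_1": [3, 9, 5, 8, 2, 6],  # ran for the full length
}
finalized = {name: [t for t in toks if t != PAD] for name, toks in generated.items()}
assert finalized["beam_0"] == [11, 7, 42]
assert len(finalized["beam_1"]) == 6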
@force_ampere # Save H100 resource

View File

@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, cast
import torch
from transformers.configuration_utils import PretrainedConfig
from tensorrt_llm._torch.attention_backend import TrtllmAttentionMetadata
from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
@@ -49,6 +50,7 @@ class DummyConfig(PretrainedConfig):
class DummyModel(torch.nn.Module):
def __init__(self, model_config: ModelConfig):
super().__init__()
assert model_config.pretrained_config is not None
self.dtype = model_config.pretrained_config.torch_dtype
self.model_config = model_config
@@ -67,9 +69,11 @@ class DummyModel(torch.nn.Module):
attn_metadata: AttentionMetadata,
return_context_logits: bool = False,
**kwargs,
) -> torch.Tensor:
) -> dict[str, torch.Tensor]:
num_batch_tokens = input_ids.size(0)
assert self.config is not None
assert attn_metadata.seq_lens_cuda is not None
vocab_size = self.config.vocab_size
last_tokens = (
torch.cumsum(
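Editor's note: the last_tokens computation starting above (cut off by the hunk) follows the usual cumulative-sum pattern for locating each sequence's final token in a packed batch. A self-contained illustration with made-up lengths, assuming the last-token index is the running total minus one:

import torch

seq_lens = torch.tensor([3, 1, 2])  # three sequences flattened into 6 tokens
last_tokens = torch.cumsum(seq_lens, dim=0) - 1  # index of each sequence's final token
assert last_tokens.tolist() == [2, 3, 5]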
@@ -98,18 +102,18 @@ class DummyModel(torch.nn.Module):
num_context_requests = attn_metadata.num_contexts
# each beam has its own attn_metadata.seq_lens_cuda entry
num_generation_requests = (
last_tokens.shape[0] - num_context_requests
) // attn_metadata.beam_width
beam_width = cast(TrtllmAttentionMetadata, attn_metadata).beam_width
num_generation_requests = (last_tokens.shape[0] - num_context_requests) // beam_width
num_requests = num_generation_requests + num_context_requests
# return cache indirection, as additional model output.
# each sequence should only return a 1D cache indirection tensor
assert attn_metadata.cache_indirection is not None
context_cache_indirection = attn_metadata.cache_indirection[:num_context_requests, 0]
generation_cache_indirection = attn_metadata.cache_indirection[
num_context_requests:num_requests
].view(
num_generation_requests * attn_metadata.beam_width,
num_generation_requests * beam_width,
attn_metadata.cache_indirection.shape[-1],
)
return {
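Editor's note: the slicing and view above can be checked with a shape-only sketch (sizes are arbitrary): context requests contribute one cache_indirection row each, whereas every generation request contributes beam_width rows, flattened so that each returned sequence carries a 1D indirection vector.

import torch

num_context_requests, num_generation_requests, beam_width, max_seq_len = 2, 3, 4, 16
num_requests = num_context_requests + num_generation_requests

cache_indirection = torch.zeros(num_requests, beam_width, max_seq_len, dtype=torch.int32)

context_part = cache_indirection[:num_context_requests, 0]
generation_part = cache_indirection[num_context_requests:num_requests].view(
    num_generation_requests * beam_width, cache_indirection.shape[-1])

assert context_part.shape == (num_context_requests, max_seq_len)  # (2, 16)
assert generation_part.shape == (num_generation_requests * beam_width, max_seq_len)  # (12, 16)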
@@ -130,7 +134,7 @@ class DummyModel(torch.nn.Module):
@register_checkpoint_weight_loader("DUMMY_FORMAT")
class DummyWeightLoader(BaseWeightLoader):
def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]:
def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]: # type: ignore
"""Load weights from your dummy format.
Args:
checkpoint_dir: Directory containing checkpoint files

View File

@@ -15,25 +15,14 @@
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from itertools import product
from typing import (
Callable,
ContextManager,
Final,
Generator,
Optional,
Protocol,
Type,
TypeVar,
Union,
cast,
)
from typing import Callable, ContextManager, Final, Generator, Optional, Type, TypeVar, Union, cast
import flashinfer.sampling
import numpy as np
import pytest
import torch
from scipy.stats import power_divergence
from utils.util import assert_no_cuda_sync, force_ampere
from utils.util import UutProvider, assert_no_cuda_sync, force_ampere, run_test_with_warmup
from tensorrt_llm._torch.pyexecutor.llm_request import convert_wordlist
from tensorrt_llm._torch.pyexecutor.sampler import (
@@ -363,53 +352,6 @@ class TestStrategySelection:
assert torch_sampler.should_provide_draft_probs(request) == (not is_greedy)
class UutProvider(Protocol):
def __call__(self, is_warmup: bool) -> ContextManager[Callable[[], None]]: ...
def _run_test_with_warmup(
uut_provider: UutProvider,
warmup_sizes_bytes: tuple[int] = (4 * 2**30,),
*,
max_sync_s: Optional[float],
):
"""Run UUT including setup and warmup.
This is mainly used to check that the UUT does not trigger a CUDA device sync. Thus,
given that PyTorch's caching memory allocator can device sync when it runs
out of cached GPU memory segments, the warmup allocates some GPU memory.
The warmup also runs the test once. This avoids issues with things like lazy loading
of device code. The UUT provider can use the 'is_warmup' argument to adapt its
behavior to the warmup and final test runs.
If max_sync_s is provided, this helper checks that the UUT does not device sync,
assuming that the synchronous (CPU) part of the code takes no longer than max_sync_s
seconds to complete.
It is the user's responsibility to ensure that the amount of submitted work
does not exceed the CUDA driver/device queue capacity, which would make
the execution appear synchronous.
"""
with torch.cuda.Stream():
with uut_provider(is_warmup=True) as uut:
bufs = []
for warmup_size in warmup_sizes_bytes:
bufs.append(
torch.ones(warmup_size, device=torch.cuda.current_device(), dtype=torch.int8)
)
del bufs
uut()
with uut_provider(is_warmup=False) as uut:
with (
assert_no_cuda_sync(sync_timeout_s=max_sync_s)
if max_sync_s is not None
else nullcontext()
):
uut()
@force_ampere
@pytest.mark.parametrize(
"draft_len, with_ctx, with_gen",
@@ -630,7 +572,7 @@ def test_select_generated_logits(draft_len: int, with_ctx: bool, with_gen: bool)
torch.testing.assert_close(res.result.req_offsets.to("cpu"), expected_req_offsets)
torch.testing.assert_close(res.result.selected_logits.to("cpu"), expected_logits)
_run_test_with_warmup(_test_runner, max_sync_s=0.3)
run_test_with_warmup(_test_runner, max_sync_s=0.3)
class TestFinishReasons:
@@ -852,7 +794,7 @@ class TestFinishReasons:
]
)
_run_test_with_warmup(uut_provider, max_sync_s=0.5)
run_test_with_warmup(uut_provider, max_sync_s=0.5)
@classmethod
def test_are_stop_words_isnt_called_when_no_stop_words(cls, monkeypatch: pytest.MonkeyPatch):
@@ -879,13 +821,13 @@ class TestFinishReasons:
],
extra_context=lambda: raising_stop_words_ctx(True),
)
_run_test_with_warmup(uut_provider_with_stop_words, max_sync_s=0.5)
run_test_with_warmup(uut_provider_with_stop_words, max_sync_s=0.5)
uut_provider_without_stop_words = cls.RequestCase.build(
[cls.RequestCase(prompt=[1], new_tokens=[4], finish_reasons=[cls.NOT_FINISHED])],
extra_context=lambda: raising_stop_words_ctx(False),
)
_run_test_with_warmup(uut_provider_without_stop_words, max_sync_s=0.5)
run_test_with_warmup(uut_provider_without_stop_words, max_sync_s=0.5)
class TestBatchedSampling:
@@ -1532,7 +1474,7 @@ class TestBatchedSampling:
logit_offset += steps
_run_test_with_warmup(
run_test_with_warmup(
_uut_provider,
max_sync_s=None, # NB: assert_no_cuda_sync called in TestBatchedSampler._sample
)
@@ -2247,7 +2189,7 @@ class TestBatchedSampling:
num_samples=num_samples,
)
_run_test_with_warmup(
run_test_with_warmup(
_uut_provider,
max_sync_s=None, # NB: assert_no_cuda_sync called in TestBatchedSampler._sample
)
@@ -2443,4 +2385,4 @@ class TestBatchedSampling:
torch.testing.assert_close(new_tokens_host[:steps, seq_slot], req_tokens.cpu())
input_offset += steps
_run_test_with_warmup(_uut_provider, max_sync_s=0.2)
run_test_with_warmup(_uut_provider, max_sync_s=0.2)

View File

@@ -18,11 +18,11 @@ import math
import os
import time
import unittest
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from difflib import SequenceMatcher
from pathlib import Path
from typing import Any, Generator
from typing import Any, Callable, ContextManager, Generator, Protocol
import psutil
import pynvml
@@ -520,6 +526,12 @@ def device_sleep(
time.sleep(spin_s)
class UutProvider(Protocol):
def __call__(self, is_warmup: bool) -> ContextManager[Callable[[], None]]:
...
@contextmanager
def assert_no_cuda_sync(
sync_timeout_s: float = 5, ) -> Generator[None, None, None]:
@@ -563,6 +569,47 @@ def assert_no_cuda_sync(
scope_finished_event.synchronize()
def run_test_with_warmup(
uut_provider: UutProvider,
warmup_sizes_bytes: tuple[int] = (4 * 2**30, ),
*,
max_sync_s: float | None,
):
"""Run UUT including setup and warmup.
This is mainly used to check that the UUT does not trigger a CUDA device sync. Thus,
given that PyTorch's caching memory allocator can device sync when it runs
out of cached GPU memory segments, the warmup allocates some GPU memory.
The warmup also runs the test once. This avoids issues with things like lazy loading
of device code. The UUT provider can use the 'is_warmup' argument to adapt its
behavior to the warmup and final test runs.
If max_sync_s is provided, this helper checks that the UUT does not device sync,
assuming that the synchronous (CPU) part of the code takes no longer than max_sync_s
seconds to complete.
It is the user's responsibility to ensure that the amount of submitted work
does not exceed the CUDA driver/device queue capacity, which would make
the execution appear synchronous.
"""
with torch.cuda.Stream():
with uut_provider(is_warmup=True) as uut:
bufs = []
for warmup_size in warmup_sizes_bytes:
bufs.append(
torch.ones(warmup_size,
device=torch.cuda.current_device(),
dtype=torch.int8))
del bufs
uut()
with uut_provider(is_warmup=False) as uut:
with (assert_no_cuda_sync(sync_timeout_s=max_sync_s)
if max_sync_s is not None else nullcontext()):
uut()
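Editor's note: a bare-bones usage sketch of run_test_with_warmup as defined above (the toy provider is hypothetical and needs a CUDA device): the provider sets up inputs, yields the unit under test as a zero-argument callable, and validates after the yield, which only runs once the no-sync region has been left, so syncing there is allowed.

from contextlib import contextmanager
from typing import Callable, Generator

import torch


@contextmanager
def _toy_uut_provider(is_warmup: bool) -> Generator[Callable[[], None], None, None]:
    x = torch.arange(8, device="cuda")
    out = {}

    def _uut():
        out["y"] = x * 2  # asynchronous GPU work only, no implicit device sync

    yield _uut

    torch.cuda.synchronize()  # allowed: the checked region has already been exited
    assert out["y"].sum().item() == 56


def test_toy_uut_does_not_sync():
    run_test_with_warmup(_toy_uut_provider, max_sync_s=0.5)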
_pynvmlInited = False