Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-18 16:55:08 +08:00)
[TRTLLM-10030][test] ensure that TorchSampler does not sync (#11508)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
parent d72f8098fe · commit 08c7103fc4
@@ -12,21 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os
import pathlib as _pl
from typing import Any
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Callable, Generator, cast

import pytest
import torch
from test_beam_search_util import (BeamSearchTestOutput, DummyConfigLoader,
DummyWeightLoader, get_expected_outputs)
from utils.llm_data import llm_models_root
from utils.util import assert_no_cuda_sync, force_ampere
from utils.util import assert_no_cuda_sync, force_ampere, run_test_with_warmup

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm import LLM, SamplingParams, TorchLlmArgs
from tensorrt_llm._torch.models.checkpoints import HfCheckpointLoader
from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
SamplingConfig)
from tensorrt_llm._torch.pyexecutor.sampler import BeamHistory, TorchSampler
from tensorrt_llm._torch.pyexecutor.sampler import (BeamHistory,
SampleStateTorch,
TorchSampler)
from tensorrt_llm._torch.pyexecutor.sampling_utils import (
BEAM_SEARCH_PAD_TOKEN, BeamSearchMetadata, FinishReason,
beam_search_sampling_batch)
@@ -68,43 +75,63 @@ def model_kwargs(fixed_params, sampling_information) -> dict[str, Any]:
)


def _build_llm(fixed_params, input_prompts, model_kwargs):
@pytest.fixture(scope="module", params=[False, True])
def with_cuda_graph_and_overlap(request):
return request.param


def _build_llm(fixed_params, input_prompts, llm_kwargs):
return LLM(
**model_kwargs,
kv_cache_config=KvCacheConfig(max_tokens=10000),
**llm_kwargs,
kv_cache_config=KvCacheConfig(
max_tokens=10000, # type: ignore
),
max_batch_size=fixed_params["max_beam_width"] * len(
input_prompts
), # use small batch size to prevent large buffers from possibly hiding wrong data accesses.
max_seq_len=32,
max_beam_width=fixed_params["max_beam_width"],
disable_overlap_scheduler=True,
cuda_graph_config=None,
)


@pytest.fixture(scope="module")
def llm(fixed_params, input_prompts, model_kwargs):
llm = _build_llm(fixed_params, input_prompts, model_kwargs)
yield llm
llm.shutdown()
@contextmanager
def _single_process_context():
os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] = "1"
try:
yield
finally:
del os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"]


# NB: It is important that all tests instantiating 'LLM' with
# TLLM_WORKER_USE_SINGLE_PROCESS=1 (i.e., single_process=True below)
# use this fixture. Otherwise, more than one such 'LLM' object
# could be alive at any given point in time and this has been
# found to result in corruption of the cache_indirection tensors.
@pytest.fixture(scope="module")
def llm_cuda_graph(fixed_params, input_prompts, model_kwargs):
llm = LLM(
**model_kwargs,
kv_cache_config=KvCacheConfig(max_tokens=10000),
max_batch_size=fixed_params["max_beam_width"] * len(
input_prompts
), # use small batch size to prevent large buffers from possibly hiding wrong data accesses.
max_seq_len=32,
max_beam_width=fixed_params["max_beam_width"],
disable_overlap_scheduler=False,
cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8],
enable_padding=True),
)
yield llm
llm.shutdown()
def llm(fixed_params, input_prompts, model_kwargs, single_process: bool,
with_cuda_graph_and_overlap: bool):
gc.collect(
2) # force destruction of any other LLM instances (cf. comment above)
with _single_process_context() if single_process else nullcontext():
llm = _build_llm(
fixed_params,
input_prompts,
llm_kwargs=(
(dict(
disable_overlap_scheduler=True,
cuda_graph_config=None,
) if not with_cuda_graph_and_overlap else dict(
disable_overlap_scheduler=False,
cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8],
enable_padding=True),
))
|
deepcopy( # LLM.shutdown resets checkpoint_loader.config_loader
model_kwargs)),
)
with llm:
yield llm


def check_generation_logits(beam: CompletionOutput,
@@ -125,6 +152,7 @@ def check_generation_logits(beam: CompletionOutput,
def check_logprobs(beam: CompletionOutput, sampling_params: SamplingParams,
valid_tokens: int | None) -> None:
"""Check if the logprobs have the correct shape"""
assert beam.logprobs is not None
if sampling_params.logprobs is not None:
generated_tokens = valid_tokens if valid_tokens is not None else sampling_params.max_tokens
assert len(
@@ -132,6 +160,7 @@ def check_logprobs(beam: CompletionOutput, sampling_params: SamplingParams,
) == generated_tokens, f"expected {generated_tokens} logprobs, but got {len(beam.logprobs)}"
log_sum = 0.0
for logprob_dict in (beam.logprobs):
assert isinstance(logprob_dict, dict)
for logprob_value in logprob_dict.values():
log_sum += logprob_value.logprob
assert log_sum == beam.cumulative_logprob, f"expected {beam.cumulative_logprob} logprob, but got {log_sum}"
@@ -145,6 +174,7 @@ def check_cache_indirection(beam: CompletionOutput,
prompt_length: int, beam_idx: int,
valid_tokens: int | None) -> None:
"""Check if the cache indirection seen by the model is the same as the expected cache indirection"""
assert beam.additional_generation_outputs is not None
cache_indirection = beam.additional_generation_outputs["cache_indirection"]
assert cache_indirection is not None, "cache indirection should not be None"
assert cache_indirection.shape[
@@ -207,6 +237,11 @@ def check_context_logits(output: GenerationResult,
assert output.context_logits is None, "context logits should be None"


@pytest.fixture(scope="module", params=[False, True])
def single_process(request) -> bool:
return request.param


def validate_output(output: GenerationResult, input_prompt: list[int],
sampling_params: SamplingParams) -> None:
"""Perform several checks on the output of a single prompt"""
@@ -226,11 +261,63 @@ def validate_output(output: GenerationResult, input_prompt: list[int],


def validate_outputs(llm: LLM, input_prompts: list[list[int]],
sampling_params: SamplingParams) -> None:
sampling_params: SamplingParams,
monkeypatch: pytest.MonkeyPatch,
check_no_sync: bool) -> None:
"""Generate outputs for a list of prompts and validate the outputs"""
outputs = llm.generate(input_prompts, sampling_params=sampling_params)
num_prompts = len(input_prompts)

outputs = llm.generate(deepcopy(input_prompts),
sampling_params=deepcopy(sampling_params))

if check_no_sync:
del outputs # treat previous .generate as warmup, ignore results

with monkeypatch.context() as patcher:
sample_async_orig = TorchSampler.sample_async
update_requests_orig = TorchSampler.update_requests

_sample_async_hook_called = False
_update_requests_hook_called = False

def _sample_async_hook(*args, **kwargs):
nonlocal _sample_async_hook_called
_sample_async_hook_called = True

with assert_no_cuda_sync():
return sample_async_orig(*args, **kwargs)

def _update_requests_hook(self, state: SampleStateTorch, *args,
**kwargs):
nonlocal _update_requests_hook_called
_update_requests_hook_called = True

# await sampling event outside sync-check (because this syncs)
sampler_event = state.sampler_event
if sampler_event:
sampler_event.synchronize()

with assert_no_cuda_sync():
state.sampler_event = None
try:
return update_requests_orig(self, state, *args,
**kwargs)
finally:
state.sampler_event = sampler_event

# Intercept sampler methods to check that they do not sync (requires
# TLLM_WORKER_USE_SINGLE_PROCESS).
patcher.setattr(TorchSampler, "sample_async", _sample_async_hook)
patcher.setattr(TorchSampler, "update_requests",
_update_requests_hook)

outputs = llm.generate(deepcopy(input_prompts),
sampling_params=deepcopy(sampling_params))

assert _sample_async_hook_called
assert _update_requests_hook_called

num_prompts = len(input_prompts)
assert isinstance(outputs, list)
assert len(
outputs
) == num_prompts, f"expected {num_prompts} outputs, but got {len(outputs)}"
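For orientation, the hook-based no-sync check in the hunk above can be reduced to a small standalone pattern: patch a method with pytest's monkeypatch and run the original implementation inside a guard that forbids host synchronization. The sketch below is illustrative only; `Sampler.do_work` is a hypothetical stand-in for the patched TorchSampler methods, and torch.cuda.set_sync_debug_mode stands in for the repository's assert_no_cuda_sync helper.

# Minimal sketch of the no-sync hook pattern (hypothetical names, CUDA device assumed).
import pytest
import torch


class Sampler:

    def do_work(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1  # pure device work, no host sync


def test_do_work_does_not_sync(monkeypatch: pytest.MonkeyPatch):
    do_work_orig = Sampler.do_work
    called = False

    def _hook(self, *args, **kwargs):
        nonlocal called
        called = True
        # Fail loudly if the wrapped call performs a blocking device sync.
        torch.cuda.set_sync_debug_mode("error")
        try:
            return do_work_orig(self, *args, **kwargs)
        finally:
            torch.cuda.set_sync_debug_mode("default")

    with monkeypatch.context() as patcher:
        patcher.setattr(Sampler, "do_work", _hook)
        Sampler().do_work(torch.ones(8, device="cuda"))
    assert called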
@@ -247,73 +334,45 @@ def validate_outputs(llm: LLM, input_prompts: list[list[int]],
@pytest.mark.parametrize("gather_generation_logits", [True, False])
@pytest.mark.parametrize("gather_context_logits", [True, False])
@pytest.mark.parametrize("num_output_beams", [1, 2])
@pytest.mark.parametrize("num_prompts", [1, 2])
@pytest.mark.parametrize("num_prompts", [1, 2, 3])
@pytest.mark.parametrize("stop_token_ids", [[15], None])
@pytest.mark.threadleak(enabled=False)
def test_beam_search_e2e(
gather_context_logits: bool,
gather_generation_logits: bool,
return_log_probs: bool,
num_output_beams: int,
num_prompts: int,
llm,
fixed_params,
input_prompts,
) -> None:
if return_log_probs and num_prompts > 1 and llm.args.sampler_type == "TRTLLMSampler":
pytest.skip(
"Beam search currently does not support return_log_probs with multiple prompts"
)
if return_log_probs and llm.args.sampler_type == "TRTLLMSampler":
pytest.skip(
"Beam search on TRTLLMSampler does not correctly handle log_probs if called multiple times"
)

# create sampling parameters
# additional_model_outputs is used to gather the cache indirection from the model.
sampling_params = SamplingParams(
max_tokens=fixed_params["max_tokens"],
n=num_output_beams,
best_of=fixed_params["max_beam_width"],
use_beam_search=True,
return_context_logits=gather_context_logits,
return_generation_logits=gather_generation_logits,
logprobs=return_log_probs,
end_id=-1,
additional_model_outputs=["cache_indirection"],
)
validate_outputs(llm, input_prompts[:num_prompts], sampling_params)


@pytest.mark.parametrize("return_log_probs", [True, False])
@pytest.mark.parametrize("gather_generation_logits", [True, False])
@pytest.mark.parametrize("gather_context_logits", [True, False])
@pytest.mark.parametrize("num_output_beams", [1, 2])
@pytest.mark.parametrize("num_prompts", [1, 2, 3])
@pytest.mark.parametrize("stop_token_ids", [[15], None])
@pytest.mark.threadleak(enabled=False)
def test_beam_search_e2e_cuda_graph_and_overlap(
gather_context_logits: bool,
gather_generation_logits: bool,
return_log_probs: bool,
num_output_beams: int,
num_prompts: int,
stop_token_ids: list[int] | None,
llm_cuda_graph,
llm: LLM,
fixed_params,
input_prompts,
single_process: bool,
monkeypatch: pytest.MonkeyPatch,
) -> None:
if return_log_probs and num_prompts > 1 and llm_cuda_graph.args.sampler_type == "TRTLLMSampler":
llm_args = cast(TorchLlmArgs, llm.args)
check_no_sync = single_process # single_process only used for sync check
if check_no_sync and cast(TorchLlmArgs,
llm.args).sampler_type != "TorchSampler":
pytest.skip("Sync check only supported for TorchSampler")
if check_no_sync and (cast(TorchLlmArgs, llm.args).sampler_type
== "TorchSampler") and stop_token_ids is not None:
# FIXME: Fix TorchSampler._are_stop_words
pytest.skip("Stop word handling in TorchSampler syncs")

if return_log_probs and num_prompts > 1 and llm_args.sampler_type == "TRTLLMSampler":
pytest.skip(
"Beam search currently does not support return_log_probs with multiple prompts"
)
if return_log_probs and llm_cuda_graph.args.sampler_type == "TRTLLMSampler":
if return_log_probs and llm_args.sampler_type == "TRTLLMSampler":
pytest.skip(
"Beam search on TRTLLMSampler does not correctly handle log_probs if called multiple times"
)
if stop_token_ids is not None and llm_cuda_graph.args.sampler_type == "TRTLLMSampler":
if stop_token_ids is not None and llm_args.sampler_type == "TRTLLMSampler":
pytest.skip(
"Beam search on TRTLLMSampler does not correctly handle stop_token_ids"
)

# create sampling parameters
# additional_model_outputs is used to gather the cache indirection from the model.
sampling_params = SamplingParams(
@@ -328,8 +387,13 @@ def test_beam_search_e2e_cuda_graph_and_overlap(
stop_token_ids=stop_token_ids,
additional_model_outputs=["cache_indirection"],
)
validate_outputs(llm_cuda_graph, input_prompts[:num_prompts],
sampling_params)
validate_outputs(
llm,
input_prompts[:num_prompts],
sampling_params,
check_no_sync=check_no_sync,
monkeypatch=monkeypatch,
)


###########################################################################
@@ -440,6 +504,7 @@ def test_beam_search_sampling_batch_basic():
)

# Validate output shapes
assert softmax is not None
expected_tokens_shape = (batch_size, beam_width)
assert next_tokens.shape == expected_tokens_shape, (
f"Expected shape {expected_tokens_shape}, got {next_tokens.shape}")
@@ -564,9 +629,11 @@ def create_default_sampler(test_params: GeneralTestParams) -> TorchSampler:
assert max_seq_len > test_params.seq_len, "Max sequence length must be greater than sequence length"
assert max_batch_size > test_params.batch_size, "Max batch size must be greater than batch size"
assert max_batch_size > test_params.seq_slot, "Max batch size must be greater than sequence slot"
assert sampler.store.cache_indirection is not None
assert sampler.store.cache_indirection.shape == (
max_batch_size, max_beam_width,
max_seq_len), "Cache indirection shape mismatch"
assert sampler.store.original_tokens is not None
assert sampler.store.original_tokens.shape == (
max_batch_size, max_beam_width,
max_seq_len), "Original tokens shape mismatch"
@@ -579,121 +646,154 @@ def test_create_beam_history():
This test verifies that beam history is correctly reconstructed by following
the cache_indirection backwards to obtain the correct token sequence.
"""
test_params = GeneralTestParams()
request = create_default_request(test_params)
sampler = create_default_sampler(test_params)

# Extract parameters from the test parameters
beam_width = test_params.beam_width
prompt_len = test_params.prompt_len
num_generated_tokens = test_params.num_generated_tokens
seq_slot = test_params.seq_slot
vocab_size = test_params.vocab_size
num_logprobs = test_params.num_logprobs + 1
cache_indirection = sampler.store.cache_indirection
original_tokens = sampler.store.original_tokens
original_logprobs = torch.zeros(
(beam_width, num_generated_tokens, num_logprobs),
dtype=torch.float32,
device=original_tokens.device)
original_logprob_indices = torch.zeros(
(beam_width, num_generated_tokens, num_logprobs),
dtype=torch.int32,
device=original_tokens.device)
original_cum_logprobs = sampler.store.cum_log_probs
@contextmanager
def _uut_provider(
is_warmup: bool) -> Generator[Callable[[], None], None, None]:
test_params = GeneralTestParams()
request = create_default_request(test_params)
sampler = create_default_sampler(test_params)

# Fill the request with some random tokens that will be overwritten by the beam search sampling
# Beam history is created before add_token is called
request.set_generated_tokens(
torch.randint(0,
vocab_size, (beam_width, num_generated_tokens - 1),
dtype=torch.int32).tolist())
# random fill
torch.manual_seed(42)
original_tokens[seq_slot, :beam_width, prompt_len:prompt_len +
num_generated_tokens] = torch.randint(
0,
beam_width, (beam_width, num_generated_tokens),
dtype=torch.int32)
assert original_tokens.sum(
) > 0, "Original tokens must not only contain zeros. Otherwise change the seed."
# Extract parameters from the test parameters
beam_width = test_params.beam_width
prompt_len = test_params.prompt_len
num_generated_tokens = test_params.num_generated_tokens
seq_slot = test_params.seq_slot
vocab_size = test_params.vocab_size
num_logprobs = test_params.num_logprobs + 1
cache_indirection = sampler.store.cache_indirection
assert cache_indirection is not None
original_tokens = sampler.store.original_tokens
assert original_tokens is not None
original_logprobs = torch.zeros(
(beam_width, num_generated_tokens, num_logprobs),
dtype=torch.float32,
device=original_tokens.device)
original_logprob_indices = torch.zeros(
(beam_width, num_generated_tokens, num_logprobs),
dtype=torch.int32,
device=original_tokens.device)
original_cum_logprobs = sampler.store.cum_log_probs
assert original_cum_logprobs is not None

original_logprobs[:beam_width] = torch.randn(
(beam_width, num_generated_tokens, original_logprobs.shape[-1]),
dtype=torch.float32)
original_logprob_indices[:beam_width] = torch.randint(
0,
vocab_size,
(beam_width, num_generated_tokens, original_logprobs.shape[-1]),
dtype=torch.int32)
assert (original_logprobs != 0).sum(
) > 0, "Original log probs must not only contain zeros. Otherwise change the seed."
assert (original_logprob_indices).sum(
) > 0, "Original log prob indices must not only contain zeros. Otherwise change the seed."
# Fill the request with some random tokens that will be overwritten by the beam search sampling
# Beam history is created before add_token is called
request.set_generated_tokens(
torch.randint(0,
vocab_size, (beam_width, num_generated_tokens - 1),
dtype=torch.int32).tolist())
# random fill
torch.manual_seed(42)
original_tokens[seq_slot, :beam_width, prompt_len:prompt_len +
num_generated_tokens] = torch.randint(
0,
beam_width, (beam_width, num_generated_tokens),
dtype=torch.int32)
assert original_tokens.sum(
) > 0, "Original tokens must not only contain zeros. Otherwise change the seed."

# set the logprobs in the request:
token_logprobs = sampler._convert_logprobs_tensor_to_list(
original_logprob_indices[:beam_width, :num_generated_tokens - 1],
original_logprobs[:beam_width, :num_generated_tokens - 1],
)
request.py_result.set_log_probs(
token_logprobs,
cum_log_probs=torch.zeros_like(
original_cum_logprobs[seq_slot, :beam_width]).tolist())
original_logprobs[:beam_width] = torch.randn(
(beam_width, num_generated_tokens, original_logprobs.shape[-1]),
dtype=torch.float32)
original_logprob_indices[:beam_width] = torch.randint(
0,
vocab_size,
(beam_width, num_generated_tokens, original_logprobs.shape[-1]),
dtype=torch.int32)
assert (original_logprobs != 0).sum(
) > 0, "Original log probs must not only contain zeros. Otherwise change the seed."
assert (original_logprob_indices).sum(
) > 0, "Original log prob indices must not only contain zeros. Otherwise change the seed."

original_cum_logprobs[seq_slot, :beam_width] = torch.randn(
(beam_width, ), dtype=torch.float32)
assert (original_cum_logprobs != 0).sum(
) > 0, "Original cumulative log probs must not only contain zeros. Otherwise change the seed."
# set the logprobs in the request:
token_logprobs = sampler._convert_logprobs_tensor_to_list(
original_logprob_indices[:beam_width, :num_generated_tokens - 1],
original_logprobs[:beam_width, :num_generated_tokens - 1],
)
request.py_result.set_log_probs(
token_logprobs,
cum_log_probs=torch.zeros_like(
original_cum_logprobs[seq_slot, :beam_width]).tolist())

cache_indirection[seq_slot, :beam_width, prompt_len:prompt_len +
num_generated_tokens] = torch.randint(
0,
beam_width, (beam_width, num_generated_tokens),
dtype=torch.int32)
assert cache_indirection[
seq_slot, :beam_width,
prompt_len:prompt_len + num_generated_tokens].sum(
) > 0, "Deterministic offsets must not only contain zeros. Otherwise change the seed."
original_cum_logprobs[seq_slot, :beam_width] = torch.randn(
(beam_width, ), dtype=torch.float32)
assert (original_cum_logprobs != 0).sum(
) > 0, "Original cumulative log probs must not only contain zeros. Otherwise change the seed."

# set the new log probs and tokens for the beam search sampling
sampler.store.sampled_log_probs[
seq_slot, :beam_width] = original_logprobs[:beam_width,
num_generated_tokens - 1,
0:1]
sampler.store.new_tokens[
0,
seq_slot, :beam_width] = original_logprob_indices[:beam_width,
num_generated_tokens -
1, 0]
# test
beam_history_builder = sampler._prepare_beam_history(
request, finish_reasons=torch.ones((beam_width, ), dtype=torch.int))
torch.cuda.synchronize()
beam_history = beam_history_builder()
cache_indirection[seq_slot, :beam_width, prompt_len:prompt_len +
num_generated_tokens] = torch.randint(
0,
beam_width, (beam_width, num_generated_tokens),
dtype=torch.int32)
assert cache_indirection[
seq_slot, :beam_width,
prompt_len:prompt_len + num_generated_tokens].sum(
) > 0, "Deterministic offsets must not only contain zeros. Otherwise change the seed."

# expected selection:
# Currently beam history only contains the generated tokens, not the prompt tokens.
expected_tokens = torch.zeros(
(sampler.max_beam_width, num_generated_tokens), dtype=torch.int32)
expected_logprobs = torch.zeros(
(beam_width, num_generated_tokens, original_logprobs.shape[-1]),
dtype=torch.float32)
for gen_idx in range(num_generated_tokens):
token_idx = prompt_len + gen_idx
expected_tokens[:, gen_idx] = original_tokens[
seq_slot, cache_indirection[seq_slot, :, token_idx], token_idx]
expected_logprobs[:, gen_idx] = original_logprobs[cache_indirection[
seq_slot, :beam_width, token_idx], gen_idx]
# set the new log probs and tokens for the beam search sampling
assert sampler.store.sampled_log_probs is not None
sampler.store.sampled_log_probs[
seq_slot, :beam_width] = original_logprobs[:beam_width,
num_generated_tokens - 1,
0:1]
sampler.store.new_tokens[
0, seq_slot, :
beam_width] = original_logprob_indices[:beam_width,
num_generated_tokens - 1, 0]

torch.testing.assert_close(beam_history.tokens[:beam_width],
expected_tokens[:beam_width])
# test logprobs as well
torch.testing.assert_close(beam_history.logprobs[:beam_width],
expected_logprobs[:beam_width])
torch.testing.assert_close(
beam_history.cum_logprobs[:beam_width],
original_cum_logprobs[seq_slot, :beam_width].to("cpu"))
@dataclass
class UutResult:
beam_history_builder: Callable[[], BeamHistory | None] | None

@dataclass
class UutResultWrapper:
result: UutResult | None = None

res = UutResultWrapper()

# test
def _uut(res=res):
res.result = UutResult(
beam_history_builder=sampler._prepare_beam_history(
request,
finish_reasons=torch.ones((beam_width, ), dtype=torch.int),
), )

yield _uut

torch.cuda.synchronize()
assert res.result is not None
beam_history_builder = res.result.beam_history_builder
assert beam_history_builder is not None
beam_history = beam_history_builder()
assert beam_history is not None

# expected selection:
# Currently beam history only contains the generated tokens, not the prompt tokens.
expected_tokens = torch.zeros(
(sampler.max_beam_width, num_generated_tokens), dtype=torch.int32)
expected_logprobs = torch.zeros(
(beam_width, num_generated_tokens, original_logprobs.shape[-1]),
dtype=torch.float32)
for gen_idx in range(num_generated_tokens):
token_idx = prompt_len + gen_idx
expected_tokens[:, gen_idx] = original_tokens[
seq_slot, cache_indirection[seq_slot, :, token_idx], token_idx]
expected_logprobs[:, gen_idx] = original_logprobs[cache_indirection[
seq_slot, :beam_width, token_idx], gen_idx]

torch.testing.assert_close(beam_history.tokens[:beam_width],
expected_tokens[:beam_width])
# test logprobs as well
assert beam_history.logprobs is not None
torch.testing.assert_close(beam_history.logprobs[:beam_width],
expected_logprobs[:beam_width])
assert beam_history.cum_logprobs is not None
torch.testing.assert_close(
beam_history.cum_logprobs[:beam_width],
original_cum_logprobs[seq_slot, :beam_width].to("cpu"))

run_test_with_warmup(_uut_provider, max_sync_s=1)
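The `_uut_provider` pattern introduced above (setup before `yield`, a zero-argument unit under test as the yielded value, and assertions after the `yield` once the caller has let the device work complete) generalizes beyond this test. A minimal, self-contained sketch of such a provider, assuming a CUDA device and using purely hypothetical names, might look like:

# Illustrative UUT-provider sketch (hypothetical names, not part of this commit).
from contextlib import contextmanager
from typing import Callable, Generator

import torch


@contextmanager
def _example_uut_provider(
        is_warmup: bool) -> Generator[Callable[[], None], None, None]:
    x = torch.ones(16, device="cuda")
    result = {}

    def _uut():
        # device-only work; results are inspected after the yield
        result["sum"] = x.sum()

    yield _uut

    # runs after the harness has executed _uut (and, if requested, verified
    # that it did not sync); safe to synchronize and check results here
    torch.cuda.synchronize()
    assert result["sum"].item() == 16.0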
def test_finish_beams():
@@ -702,81 +802,94 @@ def test_finish_beams():
This test verifies that beams are correctly finalized.
"""

test_params = GeneralTestParams()
beam_width = test_params.beam_width
num_generated_tokens = test_params.num_generated_tokens
test_params.seq_len
end_id = test_params.end_id
batch_size = test_params.batch_size
vocab_size = test_params.vocab_size
num_logprobs = 1
request = create_default_request(test_params)
sampler = create_default_sampler(test_params)
store_device = sampler.store.cache_indirection.device
@contextmanager
def _uut_provider(
is_warmup: bool) -> Generator[Callable[[], None], None, None]:
test_params = GeneralTestParams()
beam_width = test_params.beam_width
num_generated_tokens = test_params.num_generated_tokens
end_id = test_params.end_id
batch_size = test_params.batch_size
vocab_size = test_params.vocab_size
num_logprobs = 1
request = create_default_request(test_params)
sampler = create_default_sampler(test_params)
assert sampler.store.cache_indirection is not None

request.set_generated_tokens(
torch.randint(0,
vocab_size, (beam_width, num_generated_tokens),
dtype=torch.int32).tolist())
request.set_generated_tokens(
torch.randint(0,
vocab_size, (beam_width, num_generated_tokens),
dtype=torch.int32).tolist())

torch.manual_seed(42)
# Do not keep end_id tokens in the tensor. This would interfere with the test.
tokens = torch.randint(
0,
end_id, (batch_size, sampler.max_beam_width, num_generated_tokens),
dtype=torch.int32,
device=store_device)
logprobs = torch.randn((batch_size, sampler.max_beam_width,
num_generated_tokens, num_logprobs),
dtype=torch.float32,
device=store_device)
cum_logprobs = logprobs[..., 0].sum(dim=-1)
torch.manual_seed(42)
# Do not keep end_id tokens in the tensor. This would interfere with the test.
tokens = torch.randint(
0,
end_id, (batch_size, sampler.max_beam_width, num_generated_tokens),
dtype=torch.int32)
logprobs = torch.randn((batch_size, sampler.max_beam_width,
num_generated_tokens, num_logprobs),
dtype=torch.float32)
cum_logprobs = logprobs[..., 0].sum(dim=-1)

# assert that the buffers are different from zero. Otherwise the test may pass if the function does not work.
assert tokens.sum(
) > 0, "Tokens must not only contain zeros. Otherwise change the seed."
assert torch.any(logprobs != 0) and torch.any(
cum_logprobs != 0
), "Log probs and cumulative log probs must not only contain zeros. Otherwise change the seed."
# assert that the buffers are different from zero. Otherwise the test may pass if the function does not work.
assert tokens.sum(
) > 0, "Tokens must not only contain zeros. Otherwise change the seed."
assert torch.any(logprobs != 0) and torch.any(
cum_logprobs != 0
), "Log probs and cumulative log probs must not only contain zeros. Otherwise change the seed."

tokens[batch_size - 1, 0, num_generated_tokens //
2:] = BEAM_SEARCH_PAD_TOKEN # simulate early finished beam
tokens[batch_size - 1, 0, num_generated_tokens //
2:] = BEAM_SEARCH_PAD_TOKEN # simulate early finished beam

for batch_idx in range(batch_size):
beam_history = BeamHistory(
tokens=tokens[batch_idx, :beam_width],
logprobs=logprobs[batch_idx, :beam_width],
cum_logprobs=cum_logprobs[batch_idx, :beam_width])
request.py_return_log_probs = False
prompt_len = request.py_prompt_len

if batch_idx < batch_size - 1:
# requests are not finished yet
sampler._finalize_beam(request, beam_history)
final_tokens = torch.tensor(request.get_tokens(),
device=store_device,
dtype=torch.int32)[:, prompt_len:]
torch.testing.assert_close(final_tokens,
tokens[batch_idx, :beam_width])
# Test the case where end_ids are present in the output
else:
sampler._finalize_beam(request, beam_history)
token_history = []

# Given input for beam 0: [ token, token, ..., token, BEAM_SEARCH_PAD_TOKEN, BEAM_SEARCH_PAD_TOKEN, ..., BEAM_SEARCH_PAD_TOKEN]
# Expected output for beam 0: [ token, token, ..., token]
final_tokens_1p = torch.tensor(request.get_tokens()[1:],
device=store_device,
dtype=torch.int32)[:, prompt_len:]
final_tokens_0 = torch.tensor(request.get_tokens()[0],
device=store_device,
dtype=torch.int32)[prompt_len:]
torch.testing.assert_close(final_tokens_1p, tokens[batch_idx,
1:beam_width])
torch.testing.assert_close(final_tokens_0.shape[0],
num_generated_tokens // 2)
torch.testing.assert_close(
final_tokens_0, tokens[batch_idx,
0, :num_generated_tokens // 2])
# test
def _uut():
nonlocal token_history

for batch_idx in range(batch_size):
beam_history = BeamHistory(
tokens=tokens[batch_idx, :beam_width],
logprobs=logprobs[batch_idx, :beam_width],
cum_logprobs=cum_logprobs[batch_idx, :beam_width])
request.py_return_log_probs = False

sampler._finalize_beam(request, beam_history)

token_history.append(deepcopy(request.get_tokens()))

yield _uut

for batch_idx in range(batch_size):
batch_final_tokens = token_history[batch_idx]

if batch_idx < batch_size - 1:
# requests are not finished yet
final_tokens = torch.tensor(batch_final_tokens,
dtype=torch.int32)[:, prompt_len:]
torch.testing.assert_close(final_tokens,
tokens[batch_idx, :beam_width])
# Test the case where end_ids are present in the output
else:
# Given input for beam 0: [ token, token, ..., token, BEAM_SEARCH_PAD_TOKEN, BEAM_SEARCH_PAD_TOKEN, ..., BEAM_SEARCH_PAD_TOKEN]
# Expected output for beam 0: [ token, token, ..., token]
final_tokens_1p = torch.tensor(batch_final_tokens[1:],
dtype=torch.int32)[:,
prompt_len:]
final_tokens_0 = torch.tensor(batch_final_tokens[0],
dtype=torch.int32)[prompt_len:]
torch.testing.assert_close(final_tokens_1p,
tokens[batch_idx, 1:beam_width])
torch.testing.assert_close(final_tokens_0.shape[0],
num_generated_tokens // 2)
torch.testing.assert_close(
final_tokens_0, tokens[batch_idx,
0, :num_generated_tokens // 2])

run_test_with_warmup(_uut_provider, max_sync_s=1)


@force_ampere # Save H100 resource
@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from typing import Any, cast

import torch
from transformers.configuration_utils import PretrainedConfig

from tensorrt_llm._torch.attention_backend import TrtllmAttentionMetadata
from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
@@ -49,6 +50,7 @@ class DummyConfig(PretrainedConfig):
class DummyModel(torch.nn.Module):
def __init__(self, model_config: ModelConfig):
super().__init__()
assert model_config.pretrained_config is not None
self.dtype = model_config.pretrained_config.torch_dtype
self.model_config = model_config

@@ -67,9 +69,11 @@ class DummyModel(torch.nn.Module):
attn_metadata: AttentionMetadata,
return_context_logits: bool = False,
**kwargs,
) -> torch.Tensor:
) -> dict[str, torch.Tensor]:
num_batch_tokens = input_ids.size(0)

assert self.config is not None
assert attn_metadata.seq_lens_cuda is not None
vocab_size = self.config.vocab_size
last_tokens = (
torch.cumsum(
@@ -98,18 +102,18 @@ class DummyModel(torch.nn.Module):

num_context_requests = attn_metadata.num_contexts
# each beam has its own attn_metadata.seq_lens_cuda entry
num_generation_requests = (
last_tokens.shape[0] - num_context_requests
) // attn_metadata.beam_width
beam_width = cast(TrtllmAttentionMetadata, attn_metadata).beam_width
num_generation_requests = (last_tokens.shape[0] - num_context_requests) // beam_width
num_requests = num_generation_requests + num_context_requests

# return cache indirection, as additional model output.
# each sequence should only return a 1D cache indirection tensor
assert attn_metadata.cache_indirection is not None
context_cache_indirection = attn_metadata.cache_indirection[:num_context_requests, 0]
generation_cache_indirection = attn_metadata.cache_indirection[
num_context_requests:num_requests
].view(
num_generation_requests * attn_metadata.beam_width,
num_generation_requests * beam_width,
attn_metadata.cache_indirection.shape[-1],
)
return {
@@ -130,7 +134,7 @@ class DummyModel(torch.nn.Module):

@register_checkpoint_weight_loader("DUMMY_FORMAT")
class DummyWeightLoader(BaseWeightLoader):
def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]:
def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]: # type: ignore
"""Load weights from your dummy format.
Args:
checkpoint_dir: Directory containing checkpoint files
@@ -15,25 +15,14 @@
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from itertools import product
from typing import (
Callable,
ContextManager,
Final,
Generator,
Optional,
Protocol,
Type,
TypeVar,
Union,
cast,
)
from typing import Callable, ContextManager, Final, Generator, Optional, Type, TypeVar, Union, cast

import flashinfer.sampling
import numpy as np
import pytest
import torch
from scipy.stats import power_divergence
from utils.util import assert_no_cuda_sync, force_ampere
from utils.util import UutProvider, assert_no_cuda_sync, force_ampere, run_test_with_warmup

from tensorrt_llm._torch.pyexecutor.llm_request import convert_wordlist
from tensorrt_llm._torch.pyexecutor.sampler import (
@@ -363,53 +352,6 @@ class TestStrategySelection:
assert torch_sampler.should_provide_draft_probs(request) == (not is_greedy)


class UutProvider(Protocol):
def __call__(self, is_warmup: bool) -> ContextManager[Callable[[], None]]: ...


def _run_test_with_warmup(
uut_provider: UutProvider,
warmup_sizes_bytes: tuple[int] = (4 * 2**30,),
*,
max_sync_s: Optional[float],
):
"""Run UUT including setup and warmup.

This is mainly used to check that the UUT does not CUDA device sync. Thus,
given that PyTorch's caching memory allocator can device sync when it runs
out of cached GPU memory segments, the warmup allocates some GPU memory.

The warmup also runs the test once. This avoids issues with things like lazy loading
of device code. The UUT provider can use the 'is_warmup' argument to adapt its
behavior to the warmup and final test runs.

If max_sync_s is provided, this helper checks that the UUT does not device sync,
assuming that the sync (CPU) part of the code takes no longer than max_sync_s
seconds to complete.

It is the user's responsibility to ensure that the amount of submitted work
does not exceed the CUDA driver/device queue capacity, which would make
the execution appear synchronous.
"""
with torch.cuda.Stream():
with uut_provider(is_warmup=True) as uut:
bufs = []
for warmup_size in warmup_sizes_bytes:
bufs.append(
torch.ones(warmup_size, device=torch.cuda.current_device(), dtype=torch.int8)
)
del bufs
uut()

with uut_provider(is_warmup=False) as uut:
with (
assert_no_cuda_sync(sync_timeout_s=max_sync_s)
if max_sync_s is not None
else nullcontext()
):
uut()


@force_ampere
@pytest.mark.parametrize(
"draft_len, with_ctx, with_gen",
@@ -630,7 +572,7 @@ def test_select_generated_logits(draft_len: int, with_ctx: bool, with_gen: bool)
torch.testing.assert_close(res.result.req_offsets.to("cpu"), expected_req_offsets)
torch.testing.assert_close(res.result.selected_logits.to("cpu"), expected_logits)

_run_test_with_warmup(_test_runner, max_sync_s=0.3)
run_test_with_warmup(_test_runner, max_sync_s=0.3)


class TestFinishReasons:
@@ -852,7 +794,7 @@ class TestFinishReasons:
]
)

_run_test_with_warmup(uut_provider, max_sync_s=0.5)
run_test_with_warmup(uut_provider, max_sync_s=0.5)

@classmethod
def test_are_stop_words_isnt_called_when_no_stop_words(cls, monkeypatch: pytest.MonkeyPatch):
@@ -879,13 +821,13 @@ class TestFinishReasons:
],
extra_context=lambda: raising_stop_words_ctx(True),
)
_run_test_with_warmup(uut_provider_with_stop_words, max_sync_s=0.5)
run_test_with_warmup(uut_provider_with_stop_words, max_sync_s=0.5)

uut_provider_without_stop_words = cls.RequestCase.build(
[cls.RequestCase(prompt=[1], new_tokens=[4], finish_reasons=[cls.NOT_FINISHED])],
extra_context=lambda: raising_stop_words_ctx(False),
)
_run_test_with_warmup(uut_provider_without_stop_words, max_sync_s=0.5)
run_test_with_warmup(uut_provider_without_stop_words, max_sync_s=0.5)


class TestBatchedSampling:
@@ -1532,7 +1474,7 @@ class TestBatchedSampling:

logit_offset += steps

_run_test_with_warmup(
run_test_with_warmup(
_uut_provider,
max_sync_s=None, # NB: assert_no_cuda_sync called in TestBatchedSampler._sample
)
@@ -2247,7 +2189,7 @@ class TestBatchedSampling:
num_samples=num_samples,
)

_run_test_with_warmup(
run_test_with_warmup(
_uut_provider,
max_sync_s=None, # NB: assert_no_cuda_sync called in TestBatchedSampler._sample
)
@@ -2443,4 +2385,4 @@ class TestBatchedSampling:
torch.testing.assert_close(new_tokens_host[:steps, seq_slot], req_tokens.cpu())
input_offset += steps

_run_test_with_warmup(_uut_provider, max_sync_s=0.2)
run_test_with_warmup(_uut_provider, max_sync_s=0.2)
@@ -18,11 +18,11 @@ import math
import os
import time
import unittest
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from difflib import SequenceMatcher
from pathlib import Path
from typing import Any, Generator
from typing import Any, Callable, ContextManager, Generator, Protocol

import psutil
import pynvml
@@ -520,6 +520,12 @@ def device_sleep(
time.sleep(spin_s)


class UutProvider(Protocol):

def __call__(self, is_warmup: bool) -> ContextManager[Callable[[], None]]:
...


@contextmanager
def assert_no_cuda_sync(
sync_timeout_s: float = 5, ) -> Generator[None, None, None]:
@@ -563,6 +569,47 @@ def assert_no_cuda_sync(
scope_finished_event.synchronize()


def run_test_with_warmup(
uut_provider: UutProvider,
warmup_sizes_bytes: tuple[int] = (4 * 2**30, ),
*,
max_sync_s: float | None,
):
"""Run UUT including setup and warmup.

This is mainly used to check that the UUT does not CUDA device sync. Thus,
given that PyTorch's caching memory allocator can device sync when it runs
out of cached GPU memory segments, the warmup allocates some GPU memory.

The warmup also runs the test once. This avoids issues with things like lazy loading
of device code. The UUT provider can use the 'is_warmup' argument to adapt its
behavior to the warmup and final test runs.

If max_sync_s is provided, this helper checks that the UUT does not device sync,
assuming that the sync (CPU) part of the code takes no longer than max_sync_s
seconds to complete.

It is the user's responsibility to ensure that the amount of submitted work
does not exceed the CUDA driver/device queue capacity, which would make
the execution appear synchronous.
"""
with torch.cuda.Stream():
with uut_provider(is_warmup=True) as uut:
bufs = []
for warmup_size in warmup_sizes_bytes:
bufs.append(
torch.ones(warmup_size,
device=torch.cuda.current_device(),
dtype=torch.int8))
del bufs
uut()

with uut_provider(is_warmup=False) as uut:
with (assert_no_cuda_sync(sync_timeout_s=max_sync_s)
if max_sync_s is not None else nullcontext()):
uut()


_pynvmlInited = False
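A hypothetical usage of the new `run_test_with_warmup` helper, for orientation: the provider is entered twice (once for the warmup run, once for the checked run), and with `max_sync_s` set the second run executes under `assert_no_cuda_sync`. The `_noop_provider` and `test_noop_does_not_sync` names below are illustrative and not part of this commit.

# Illustrative sketch only; assumes a CUDA device and the utils.util helpers from this commit.
from contextlib import contextmanager
from typing import Callable, Generator

import torch

from utils.util import run_test_with_warmup


@contextmanager
def _noop_provider(
        is_warmup: bool) -> Generator[Callable[[], None], None, None]:
    x = torch.zeros(32, device="cuda")

    def _uut():
        x.add_(1)  # asynchronous device op, no host sync

    yield _uut


def test_noop_does_not_sync():
    # warmup run, then checked run wrapped in assert_no_cuda_sync
    run_test_with_warmup(_noop_provider, max_sync_s=0.5)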