TensorRT-LLMs/tests/unittest/utils/test_util.py
mpikulski 7c8ba71b49
[TRTLLM-8832][feat] fully async _select_generated_logits with tests (#8628)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
2025-10-27 16:15:32 +01:00

67 lines
1.9 KiB
Python

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
import pytest
import torch
from .util import (DeviceSleepCtl, assert_no_cuda_sync, device_sleep,
force_ampere)
@force_ampere
@pytest.mark.parametrize("cancel", [False, True])
def test_device_sleep(cancel: bool):
    """Time a ``device_sleep`` call with CUDA events, with and without cancel.

    Two timing-enabled CUDA events are recorded immediately before and after
    the ``device_sleep`` call, and the elapsed time between them is compared
    with the requested sleep duration:

    * ``cancel=False``: the elapsed time must be at least the requested
      duration.
    * ``cancel=True``: ``DeviceSleepCtl.cancel()`` is invoked right after the
      end event is recorded, and the elapsed time must be shorter than the
      requested duration.

    NOTE(review): this presumes ``device_sleep`` enqueues its work
    asynchronously on the current stream -- confirm against ``util.py``.
    """
    requested_s = 0.3

    # Drain any pending GPU work so it does not leak into the measurement.
    torch.cuda.synchronize()

    t_begin, t_end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    ctl = DeviceSleepCtl()

    t_begin.record()
    device_sleep(requested_s, ctl=ctl, spin_s=0.01)
    t_end.record()

    if cancel:
        ctl.cancel()

    # Wait for the end event so that elapsed_time() can be queried.
    t_end.synchronize()
    # torch.cuda.Event.elapsed_time reports milliseconds; convert to seconds.
    measured_s = t_begin.elapsed_time(t_end) / 1000

    if not cancel:
        assert measured_s >= requested_s
    else:
        assert measured_s < requested_s
@force_ampere
@pytest.mark.parametrize("uut_syncs", [False, True])
def test_assert_no_cuda_sync(uut_syncs: bool):
    """Check that ``assert_no_cuda_sync`` flags code that synchronizes.

    The guarded body calls ``torch.cuda.synchronize()`` only when
    ``uut_syncs`` is set. In that case the test expects an ``AssertionError``
    whose message matches "sync code should return quickly"; otherwise the
    guarded body must complete without raising.
    """
    if uut_syncs:
        expectation = pytest.raises(
            AssertionError, match="sync code should return quickly"
        )
    else:
        expectation = nullcontext()

    with expectation, assert_no_cuda_sync(sync_timeout_s=0.2):
        # Unit under test: optionally force a device synchronization.
        if uut_syncs:
            torch.cuda.synchronize()