# tests/unittest/_torch/test_connector.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import sys
from unittest.mock import MagicMock

import cloudpickle
import mpi4py
import pytest

from tensorrt_llm import mpi_rank
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import \
    KvCacheConnectorManager
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests

cloudpickle.register_pickle_by_value(sys.modules[__name__])

mpi4py.MPI.pickle.__init__(
    cloudpickle.dumps,
    cloudpickle.loads,
    pickle.HIGHEST_PROTOCOL,
)
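
# The tests below define their bodies as local closures that must be sent to
# MPI pool workers, which is why this module is registered with cloudpickle
# by value and mpi4py's pickler is redirected to cloudpickle above.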


def run_across_mpi(executor, fun, num_ranks):
    return list(executor.starmap(fun, [() for _ in range(num_ranks)]))
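

# Note: `mpi_pool_executor` is assumed to be a fixture yielding an mpi4py-style
# pool whose world size comes from each test's parametrization; run_across_mpi
# runs the zero-argument closure `fun` once on every rank and gathers results.
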
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
# TODO(jthomson04): I don't have the slightest idea why this test is leaking threads.
@pytest.mark.threadleak(enabled=False)
def test_connector_manager_get_finished_allgather(mpi_pool_executor):
def test():
worker = MagicMock()
if mpi_rank() == 0:
scheduler = MagicMock()
scheduler.request_finished.return_value = True
else:
scheduler = None
manager = KvCacheConnectorManager(worker, scheduler=scheduler)
req = MagicMock()
req.request_id = 42
manager.request_finished(req, [])
# To start, make both workers return nothing.
worker.get_finished.return_value = ([], [])
assert manager.get_finished() == []
assert worker.get_finished.call_count == 1
assert worker.get_finished.call_args[0] == ([42], [])
worker.get_finished.reset_mock()
# Now, only return the request id on one worker.
if mpi_rank() == 0:
worker.get_finished.return_value = ([42], [])
else:
worker.get_finished.return_value = ([], [])
# It should still return nothing, since rank 1 is still saving.
assert manager.get_finished() == []
assert worker.get_finished.call_count == 1
assert worker.get_finished.call_args[0] == ([], [])
# Now, also return it on worker 1.
if mpi_rank() == 0:
worker.get_finished.return_value = ([], [])
else:
worker.get_finished.return_value = ([42], [])
assert manager.get_finished() == [req]

    run_across_mpi(mpi_pool_executor, test, 2)
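

# For reference, a minimal sketch (illustrative names, not the manager's real
# implementation) of the consensus rule exercised above: a request id only
# counts as finished once *every* rank has reported it.
def _allgather_consensus_sketch(per_rank_finished):
    # per_rank_finished: one list of finished request ids per rank,
    # e.g. [[42], []] while rank 1 is still saving.
    finished = set(per_rank_finished[0])
    for rank_ids in per_rank_finished[1:]:
        finished &= set(rank_ids)
    return sorted(finished)
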
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_connector_manager_num_matched_tokens(mpi_pool_executor):
def test():
worker = MagicMock()
if mpi_rank() == 0:
scheduler = MagicMock()
scheduler.get_num_new_matched_tokens.return_value = (16, True)
else:
scheduler = None
manager = KvCacheConnectorManager(worker, scheduler=scheduler)
req = MagicMock()
req.request_id = 42
assert manager.get_num_new_matched_tokens(req, 32) == 16
if mpi_rank() == 0:
assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req, 32)

    run_across_mpi(mpi_pool_executor, test, 2)
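

# Illustrative sketch (assuming an mpi4py communicator; not the manager's
# actual code) of the leader pattern the test above relies on: only rank 0
# owns a scheduler, and its answer is broadcast so every rank agrees.
def _leader_broadcast_sketch(comm, scheduler, req, num_computed_tokens):
    if comm.Get_rank() == 0:
        num_matched, _is_async = scheduler.get_num_new_matched_tokens(
            req, num_computed_tokens)
    else:
        num_matched = None
    # bcast distributes rank 0's value to all ranks.
    return comm.bcast(num_matched, root=0)
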
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_connector_manager_take_scheduled_requests(mpi_pool_executor):
def test():
worker = MagicMock()
if mpi_rank() == 0:
scheduler = MagicMock()
else:
scheduler = None
manager = KvCacheConnectorManager(worker, scheduler=scheduler)
scheduled_requests = ScheduledRequests()
req0 = MagicMock()
req0.request_id = 0
req1 = MagicMock()
req1.request_id = 1
if mpi_rank() == 0:
scheduler.get_num_new_matched_tokens.return_value = (16, True)
assert manager.get_num_new_matched_tokens(req0, 0) == 16
if mpi_rank() == 0:
assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req0, 0)
scheduler.get_num_new_matched_tokens.reset_mock()
scheduler.get_num_new_matched_tokens.return_value = (32, False)
assert manager.get_num_new_matched_tokens(req1, 0) == 32
if mpi_rank() == 0:
assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req1, 0)
scheduled_requests.context_requests = [req0, req1]
manager.take_scheduled_requests_pending_load(scheduled_requests)
assert scheduled_requests.context_requests == [req1]

    run_across_mpi(mpi_pool_executor, test, 2)
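

# Sketch of the filtering checked above (illustrative names only): requests
# whose scheduler lookup reported an asynchronous KV load in progress are held
# out of the context batch until that load completes.
def _take_pending_load_sketch(context_requests, pending_load_ids):
    return [r for r in context_requests if r.request_id not in pending_load_ids]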