# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import sys
from unittest.mock import MagicMock

import cloudpickle
import mpi4py
import pytest

from tensorrt_llm import mpi_rank
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import \
    KvCacheConnectorManager
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests

cloudpickle.register_pickle_by_value(sys.modules[__name__])
mpi4py.MPI.pickle.__init__(
    cloudpickle.dumps,
    cloudpickle.loads,
    pickle.HIGHEST_PROTOCOL,
)


def run_across_mpi(executor, fun, num_ranks):
    return list(executor.starmap(fun, [() for i in range(num_ranks)]))


@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
# TODO(jthomson04): I don't have the slightest idea why this test is leaking threads.
@pytest.mark.threadleak(enabled=False)
def test_connector_manager_get_finished_allgather(mpi_pool_executor):

    def test():
        worker = MagicMock()

        if mpi_rank() == 0:
            scheduler = MagicMock()
            scheduler.request_finished.return_value = True
        else:
            scheduler = None

        manager = KvCacheConnectorManager(worker, scheduler=scheduler)

        req = MagicMock()
        req.request_id = 42

        manager.request_finished(req, [])

        # To start, make both workers return nothing.
        worker.get_finished.return_value = ([], [])

        assert manager.get_finished() == []

        assert worker.get_finished.call_count == 1
        assert worker.get_finished.call_args[0] == ([42], [])

        worker.get_finished.reset_mock()

        # Now, only return the request id on one worker.
        if mpi_rank() == 0:
            worker.get_finished.return_value = ([42], [])
        else:
            worker.get_finished.return_value = ([], [])

        # It should still return nothing, since rank 1 is still saving.
        assert manager.get_finished() == []

        assert worker.get_finished.call_count == 1
        assert worker.get_finished.call_args[0] == ([], [])

        # Now, also return it on worker 1.
        if mpi_rank() == 0:
            worker.get_finished.return_value = ([], [])
        else:
            worker.get_finished.return_value = ([42], [])

        assert manager.get_finished() == [req]

    run_across_mpi(mpi_pool_executor, test, 2)


@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_connector_manager_num_matched_tokens(mpi_pool_executor):

    def test():
        worker = MagicMock()

        if mpi_rank() == 0:
            scheduler = MagicMock()
            scheduler.get_num_new_matched_tokens.return_value = (16, True)
        else:
            scheduler = None

        manager = KvCacheConnectorManager(worker, scheduler=scheduler)

        req = MagicMock()
        req.request_id = 42

        # Only rank 0 has a scheduler; all ranks should still observe its result.
        assert manager.get_num_new_matched_tokens(req, 32) == 16

        if mpi_rank() == 0:
            assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req, 32)

    run_across_mpi(mpi_pool_executor, test, 2)


@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_connector_manager_take_scheduled_requests(mpi_pool_executor):

    def test():
        worker = MagicMock()

        if mpi_rank() == 0:
            scheduler = MagicMock()
        else:
            scheduler = None

        manager = KvCacheConnectorManager(worker, scheduler=scheduler)

        scheduled_requests = ScheduledRequests()

        req0 = MagicMock()
        req0.request_id = 0

        req1 = MagicMock()
        req1.request_id = 1

        # req0 has matched tokens that are being loaded asynchronously.
        if mpi_rank() == 0:
            scheduler.get_num_new_matched_tokens.return_value = (16, True)

        assert manager.get_num_new_matched_tokens(req0, 0) == 16

        if mpi_rank() == 0:
            assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req0, 0)

            # req1 has matched tokens, but is not loaded asynchronously.
            scheduler.get_num_new_matched_tokens.reset_mock()
            scheduler.get_num_new_matched_tokens.return_value = (32, False)

        assert manager.get_num_new_matched_tokens(req1, 0) == 32

        if mpi_rank() == 0:
            assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req1, 0)

        scheduled_requests.context_requests = [req0, req1]

        # Only req0 is pending an async load, so it should be withheld from the
        # scheduled context requests.
        manager.take_scheduled_requests_pending_load(scheduled_requests)

        assert scheduled_requests.context_requests == [req1]

    run_across_mpi(mpi_pool_executor, test, 2)