# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
import sys
from unittest.mock import MagicMock

import cloudpickle
import mpi4py
import pytest

from tensorrt_llm import mpi_rank
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import \
    KvCacheConnectorManager
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests

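# The per-rank test bodies below are defined as nested closures, which the
# standard pickle module used by mpi4py cannot serialize. Registering this
# module for pickle-by-value and switching mpi4py's pickler to cloudpickle
# lets those closures be shipped to the MPI pool workers.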
cloudpickle.register_pickle_by_value(sys.modules[__name__])
mpi4py.MPI.pickle.__init__(
    cloudpickle.dumps,
    cloudpickle.loads,
    pickle.HIGHEST_PROTOCOL,
)


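# Run `fun` once per rank on the MPI pool executor and gather the results.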
def run_across_mpi(executor, fun, num_ranks):
    return list(executor.starmap(fun, [() for i in range(num_ranks)]))


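# A request should only be reported as finished once *every* rank's connector
# worker has reported it; until then get_finished() returns nothing.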
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
# TODO(jthomson04): I don't have the slightest idea why this test is leaking threads.
@pytest.mark.threadleak(enabled=False)
def test_connector_manager_get_finished_allgather(mpi_pool_executor):

    def test():
        worker = MagicMock()

        if mpi_rank() == 0:
            scheduler = MagicMock()

            scheduler.request_finished.return_value = True
        else:
            scheduler = None

        manager = KvCacheConnectorManager(worker, scheduler=scheduler)

        req = MagicMock()

        req.request_id = 42

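        # The rank-0 scheduler's request_finished() returns True, so the
        # request is not done yet: the manager should keep asking the workers
        # about request 42 until every rank reports it finished.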
        manager.request_finished(req, [])

        # To start, make both workers return nothing.
        worker.get_finished.return_value = ([], [])

        assert manager.get_finished() == []

        assert worker.get_finished.call_count == 1
        assert worker.get_finished.call_args[0] == ([42], [])

        worker.get_finished.reset_mock()

        # Now, only return the request id on one worker.
        if mpi_rank() == 0:
            worker.get_finished.return_value = ([42], [])
        else:
            worker.get_finished.return_value = ([], [])

        # It should still return nothing, since rank 1 is still saving.
        assert manager.get_finished() == []

        assert worker.get_finished.call_count == 1
        assert worker.get_finished.call_args[0] == ([], [])

        # Now, also return it on worker 1.
        if mpi_rank() == 0:
            worker.get_finished.return_value = ([], [])
        else:
            worker.get_finished.return_value = ([42], [])

        assert manager.get_finished() == [req]

    run_across_mpi(mpi_pool_executor, test, 2)


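# The number of new matched tokens is decided by the scheduler on rank 0 and
# should be visible to every rank.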
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_connector_manager_num_matched_tokens(mpi_pool_executor):

    def test():
        worker = MagicMock()

        if mpi_rank() == 0:
            scheduler = MagicMock()
            scheduler.get_num_new_matched_tokens.return_value = (16, True)
        else:
            scheduler = None

        manager = KvCacheConnectorManager(worker, scheduler=scheduler)

        req = MagicMock()

        req.request_id = 42

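        # Both ranks should observe the value produced by the rank-0
        # scheduler, even though rank 1 has no scheduler of its own.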
        assert manager.get_num_new_matched_tokens(req, 32) == 16

        if mpi_rank() == 0:
            assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req,
                                                                         32)

    run_across_mpi(mpi_pool_executor, test, 2)


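# take_scheduled_requests_pending_load() should pull requests that are still
# waiting on the connector out of the scheduled context batch and leave the
# rest untouched.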
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_connector_manager_take_scheduled_requests(mpi_pool_executor):

    def test():
        worker = MagicMock()

        if mpi_rank() == 0:
            scheduler = MagicMock()
        else:
            scheduler = None

        manager = KvCacheConnectorManager(worker, scheduler=scheduler)

        scheduled_requests = ScheduledRequests()

        req0 = MagicMock()
        req0.request_id = 0

        req1 = MagicMock()
        req1.request_id = 1

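        # The boolean returned alongside the token count appears to mark the
        # KV load as asynchronous: req0 gets (16, True) and req1 gets
        # (32, False) below.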
        if mpi_rank() == 0:
            scheduler.get_num_new_matched_tokens.return_value = (16, True)

        assert manager.get_num_new_matched_tokens(req0, 0) == 16
        if mpi_rank() == 0:
            assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req0,
                                                                         0)

            scheduler.get_num_new_matched_tokens.reset_mock()
            scheduler.get_num_new_matched_tokens.return_value = (32, False)

        assert manager.get_num_new_matched_tokens(req1, 0) == 32
        if mpi_rank() == 0:
            assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req1,
                                                                         0)

        scheduled_requests.context_requests = [req0, req1]

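        # req0 reported an asynchronous load, so it should be taken out of the
        # context batch until the load completes; req1 stays scheduled.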
        manager.take_scheduled_requests_pending_load(scheduled_requests)

        assert scheduled_requests.context_requests == [req1]

    run_across_mpi(mpi_pool_executor, test, 2)