TensorRT-LLMs/tests/hlapi/test_mpi_session.py
Kaiyu Xie 9bd15f1937
TensorRT-LLM v0.10 update
* TensorRT-LLM Release 0.10.0

---------

Co-authored-by: Loki <lokravi@amazon.com>
Co-authored-by: meghagarwal <16129366+megha95@users.noreply.github.com>
2024-06-05 20:43:25 +08:00

37 lines
975 B
Python

import os
import subprocess # nosec B404
from tensorrt_llm.hlapi.mpi_session import MPINodeState
def task0():
    """Bump the counter stored on ``MPINodeState.state`` and return it.

    While the stored state is still ``None`` the count starts from zero,
    so the first call made after that returns 1 and each later call in
    the same process returns one more than the previous call.
    """
    previous = MPINodeState.state
    if previous is None:
        previous = 0
    MPINodeState.state = previous + 1
    return MPINodeState.state
def test_mpi_session_basic():
    """Check that ``task0`` state persists between submissions to one pool.

    ``task0`` is submitted twice through the same ``MpiPoolSession``. The
    assertions expect a list of four equal counters per round: all 1 on
    the first round and all 2 on the second.
    """
    from tensorrt_llm.hlapi.mpi_session import MpiPoolSession

    n_workers = 4
    session = MpiPoolSession(n_workers)

    # Round k must report k four times; on mismatch the actual list is
    # used as the assertion message.
    for expected_count in (1, 2):
        outcome = session.submit_sync(task0)
        assert outcome == [expected_count] * 4, outcome
def test_mpi_session_multi_node():
    """Run ``mpi_test_task.py`` under ``mpirun`` with four processes.

    The script is resolved relative to the directory of this test file and
    the current process environment is forwarded to the child.

    Raises:
        subprocess.CalledProcessError: if ``mpirun`` exits with a non-zero
            status (``check=True``), which fails the test.
    """
    nworkers = 4
    test_case_file = os.path.join(os.path.dirname(__file__), "mpi_test_task.py")
    # Pass an argv list instead of a shell string: the script path is then
    # handed to mpirun verbatim, so a checkout directory containing spaces
    # or shell metacharacters cannot split or alter the command.
    command = [
        "mpirun",
        "--allow-run-as-root",
        "-n",
        str(nworkers),
        "python",
        test_case_file,
    ]
    subprocess.run(command, check=True, env=os.environ)  # nosec B603
if __name__ == '__main__':
    # Direct execution: run each test in definition order; the first
    # failure propagates and stops the run.
    for _test in (test_mpi_session_basic, test_mpi_session_multi_node):
        _test()