mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
42 lines
1.2 KiB
Python
42 lines
1.2 KiB
Python
import os
|
|
import subprocess # nosec B404
|
|
|
|
import pytest
|
|
|
|
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
|
|
from tensorrt_llm.llmapi.mpi_session import MPINodeState
|
|
|
|
|
|
def task0():
    """Advance this process's ``MPINodeState.state`` counter and return it.

    The counter starts from ``None``; the first call on a process treats it
    as 0, so successive calls on the same process return 1, 2, 3, ...
    """
    current = MPINodeState.state
    if current is None:
        current = 0
    MPINodeState.state = current + 1
    return MPINodeState.state
|
|
|
|
|
|
@pytest.mark.skipif(not ENABLE_MULTI_DEVICE, reason="multi-device required")
def test_mpi_session_basic():
    """Check that MPI pool workers keep their state between submissions.

    ``task0`` increments ``MPINodeState.state`` in the process that runs it.
    After one ``submit_sync`` every worker is expected to report 1; after a
    second submission every worker is expected to report 2.
    """
    from tensorrt_llm.llmapi.mpi_session import MpiPoolSession

    n_workers = 4
    executor = MpiPoolSession(n_workers)
    try:
        results = executor.submit_sync(task0)
        # Derive the expectation from n_workers so the two stay in sync.
        assert results == [1] * n_workers, results

        results = executor.submit_sync(task0)
        assert results == [2] * n_workers, results
    finally:
        # Release the worker pool even if an assertion above fails, so a
        # failing run does not leave MPI worker processes behind.
        executor.shutdown()
|
|
|
|
|
|
@pytest.mark.skipif(not ENABLE_MULTI_DEVICE, reason="multi-device required")
def test_mpi_session_multi_node():
    """Run ``mpi_test_task.py`` under ``mpirun`` and require a zero exit code.

    The task script is looked up in the same directory as this test file.
    ``subprocess.run(check=True)`` raises ``CalledProcessError`` if
    ``mpirun`` exits with a non-zero status.
    """
    import sys

    nworkers = 4
    test_case_file = os.path.join(os.path.dirname(__file__), "mpi_test_task.py")
    # Pass an argv list instead of a shell string. The path is then handed
    # to mpirun verbatim, so spaces or shell metacharacters in it are safe.
    # sys.executable makes every rank use the interpreter running this test,
    # rather than whichever `python` happens to be first on PATH.
    command = [
        "mpirun",
        "--allow-run-as-root",
        "-n",
        str(nworkers),
        sys.executable,
        test_case_file,
    ]
    subprocess.run(command, check=True, env=os.environ)  # nosec B603
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the tests directly, in order, without going through pytest.
    for _test_fn in (test_mpi_session_basic, test_mpi_session_multi_node):
        _test_fn()
|