TensorRT-LLMs/tests/hlapi/mpi_test_task.py
Kaiyu Xie bf0a5afc92
Update TensorRT-LLM (#1598)
* Update TensorRT-LLM
2024-05-14 16:43:41 +08:00

43 lines
1.0 KiB
Python

#!/usr/bin/env python
from mpi4py import MPI
from mpi4py.futures import MPICommExecutor
from tensorrt_llm._utils import mpi_rank, mpi_world_size
from tensorrt_llm.hlapi.mpi_session import MpiCommSession, MPINodeState
class MpiTask:
    """Root-side driver that pushes a small stateful task through an MPI session.

    The task bumps the ``MPINodeState.state`` counter and reports it together
    with the rank it ran on; ``run`` then checks the collected reports.
    """

    def __init__(self):
        # The session is created with a fixed worker count of 4.
        # NOTE(review): presumably this must match the number of ranks the
        # script is launched with -- confirm against MpiCommSession.
        self.executor = MpiCommSession(n_workers=4)

    @staticmethod
    def task():
        """Increment ``MPINodeState.state`` and return ``(rank, state)``.

        A ``None`` state is treated as 0 before the increment, so the first
        call in a process reports a state of 1.
        """
        previous = MPINodeState.state
        MPINodeState.state = (0 if previous is None else previous) + 1

        print(f"rank: {mpi_rank()}, state: {MPINodeState.state}")
        return mpi_rank(), MPINodeState.state

    def run(self):
        """Submit ``task`` through the session and verify the reports.

        After sorting by rank, the results must be exactly ``(rank, 1)`` for
        every rank in ``range(mpi_world_size())``.
        """
        # NOTE(review): presumably submit_sync blocks and returns one result
        # per worker -- confirm against MpiCommSession.
        gathered = self.executor.submit_sync(MpiTask.task)
        by_rank = sorted(gathered, key=lambda pair: pair[0])

        expected = [(rank, 1) for rank in range(mpi_world_size())]
        assert by_rank == expected, by_rank
def main():
    """Run the test on rank 0 and park every other rank as an MPI worker."""
    if mpi_rank() != 0:
        # Worker node: entering the executor context is the only work done
        # here; the body is intentionally empty.
        # NOTE(review): per the mpi4py.futures docs, non-root ranks are
        # expected to serve submitted tasks while inside this context --
        # confirm.
        with MPICommExecutor(MPI.COMM_WORLD):
            pass
        return

    # Root node: build the task driver and run its check.
    MpiTask().run()
# Script entry point: only run the test when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()