TensorRT-LLM/tests/hlapi/mpi_test_task.py
Kaiyu Xie 9bd15f1937
TensorRT-LLM v0.10 update
* TensorRT-LLM Release 0.10.0

---------

Co-authored-by: Loki <lokravi@amazon.com>
Co-authored-by: meghagarwal <16129366+megha95@users.noreply.github.com>
2024-06-05 20:43:25 +08:00

42 lines
1016 B
Python

#!/usr/bin/env python
from mpi4py.futures import MPICommExecutor
from tensorrt_llm._utils import mpi_comm, mpi_rank, mpi_world_size
from tensorrt_llm.hlapi.mpi_session import MpiCommSession, MPINodeState
class MpiTask:
    """Fan a counter-bumping task out over a 4-worker MPI session.

    Each rank should observe its per-node state exactly once, so the
    collected results must be ``(rank, 1)`` for every rank.
    """

    def __init__(self):
        # Session backed by 4 MPI worker ranks; the root submits tasks to it.
        self.executor = MpiCommSession(n_workers=4)

    @staticmethod
    def task():
        # Treat an uninitialised (None) state as zero, then bump it.
        MPINodeState.state = (MPINodeState.state or 0) + 1
        print(f"rank: {mpi_rank()}, state: {MPINodeState.state}")
        return (mpi_rank(), MPINodeState.state)

    def run(self):
        # Submit synchronously, then order replies by rank for comparison.
        replies = sorted(self.executor.submit_sync(MpiTask.task),
                         key=lambda pair: pair[0])
        expected = [(rank, 1) for rank in range(mpi_world_size())]
        assert replies == expected, replies
def main():
    """Run the MPI task test: rank 0 drives, all other ranks serve."""
    if mpi_rank() != 0:
        # Worker ranks: park in the executor until the root releases them.
        with MPICommExecutor(mpi_comm()) as executor:
            pass
        return
    # Root rank: build the session and run the assertion-based test.
    MpiTask().run()
# Script entry point — every MPI rank executes main(); behavior diverges
# inside main() based on the rank.
if __name__ == '__main__':
    main()