TensorRT-LLMs/tests/unittest/llmapi/_run_mpi_comm_task.py
Yan Chunwei ad4226d946
fix: trtllm-bench build trt engine on slurm (#3825)
* add submit_sync to RemoteMpiSessionClient

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

add barrier

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

fix comment

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

disable test

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

* fix

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

---------

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
2025-04-27 22:26:23 +08:00

34 lines
964 B
Python

import os
import time
from typing import Literal
import click
from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient
from tensorrt_llm.llmapi.utils import print_colored
@click.command()
@click.option("--task_type",
              type=click.Choice(["submit", "submit_sync"]),
              default="submit")
def main(task_type: Literal["submit", "submit_sync"]):
    """Submit a test task through a RemoteMpiCommSessionClient, then shut it down.

    The client is connected to the address found in the
    TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR environment variable. With
    --task_type=submit the task is sent via client.submit(); with
    --task_type=submit_sync it is sent via client.submit_sync() and the
    returned value is printed.
    """
    # Look the address up with .get(): indexing os.environ raises a bare
    # KeyError for a missing variable, so the previous `assert ... is not None`
    # could never report its message (and asserts vanish under `python -O`).
    ipc_addr = os.environ.get('TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR')
    if not ipc_addr:
        raise RuntimeError("TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set")

    tasks = [0]
    client = RemoteMpiCommSessionClient(ipc_addr)

    for task in tasks:
        if task_type == "submit":
            client.submit(print_colored, f"{task}\n", "green")
        elif task_type == "submit_sync":
            res = client.submit_sync(print_colored, f"{task}\n", "green")
            print(res)

    # NOTE(review): presumably this pause gives a task sent with submit()
    # time to run before the client is torn down -- confirm with the author.
    time.sleep(10)
    client.shutdown()
# Run the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()