fix: only set _mpi_session if world_size is > 1 (#5253)

Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
Aurelien Chartier 2025-06-17 19:21:41 -07:00 committed by GitHub
parent 627062c265
commit e1e5f725fc


@@ -43,7 +43,7 @@ from mpi4py.MPI import COMM_WORLD
 from tensorrt_llm import SamplingParams
 from tensorrt_llm._torch.llm import LLM as PyTorchLLM
-from tensorrt_llm._utils import global_mpi_rank
+from tensorrt_llm._utils import global_mpi_rank, global_mpi_size
 from tensorrt_llm.llmapi import LLM
 from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_dict
@@ -139,9 +139,10 @@ class TritonPythonModel:
             f"[trtllm] rank{global_mpi_rank()} is starting trtllm engine with args: {self.llm_engine_args}"
         )
-        mpi_session = MpiCommSession(comm=COMM_WORLD,
-                                     n_workers=COMM_WORLD.Get_size())
-        self.llm_engine_args["_mpi_session"] = mpi_session
+        if global_mpi_size() > 1:
+            mpi_session = MpiCommSession(comm=COMM_WORLD,
+                                         n_workers=COMM_WORLD.Get_size())
+            self.llm_engine_args["_mpi_session"] = mpi_session
         # Starting the TRT-LLM engine with LLM API and its event thread running the AsyncIO event loop.
         self._init_engine()
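
For reference, a minimal sketch of the gating this change introduces, stripped of the Triton model class: the MPI session is created and passed to the LLM API only when more than one MPI rank is running. The MpiCommSession import path and the stand-alone llm_engine_args dict are assumptions for illustration; in model.py both are already in scope.

from mpi4py.MPI import COMM_WORLD

from tensorrt_llm._utils import global_mpi_size
# Assumed import path; model.py imports MpiCommSession elsewhere in the file.
from tensorrt_llm.llmapi.mpi_session import MpiCommSession

llm_engine_args = {}  # stand-in for self.llm_engine_args in TritonPythonModel

# Attach an explicit MPI session only when more than one rank is present;
# on a single rank the LLM API sets up its own session, so nothing is added.
if global_mpi_size() > 1:
    llm_engine_args["_mpi_session"] = MpiCommSession(
        comm=COMM_WORLD, n_workers=COMM_WORLD.Get_size())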