mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
fix: only set _mpi_session if world_size is > 1 (#5253)
Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
This commit is contained in:
parent
627062c265
commit
e1e5f725fc
@ -43,7 +43,7 @@ from mpi4py.MPI import COMM_WORLD
|
||||
|
||||
from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch.llm import LLM as PyTorchLLM
|
||||
from tensorrt_llm._utils import global_mpi_rank
|
||||
from tensorrt_llm._utils import global_mpi_rank, global_mpi_size
|
||||
from tensorrt_llm.llmapi import LLM
|
||||
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_dict
|
||||
|
||||
@ -139,9 +139,10 @@ class TritonPythonModel:
|
||||
f"[trtllm] rank{global_mpi_rank()} is starting trtllm engine with args: {self.llm_engine_args}"
|
||||
)
|
||||
|
||||
mpi_session = MpiCommSession(comm=COMM_WORLD,
|
||||
n_workers=COMM_WORLD.Get_size())
|
||||
self.llm_engine_args["_mpi_session"] = mpi_session
|
||||
if global_mpi_size() > 1:
|
||||
mpi_session = MpiCommSession(comm=COMM_WORLD,
|
||||
n_workers=COMM_WORLD.Get_size())
|
||||
self.llm_engine_args["_mpi_session"] = mpi_session
|
||||
|
||||
# Starting the TRT-LLM engine with LLM API and its event thread running the AsyncIO event loop.
|
||||
self._init_engine()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user