[None][refactor] Unify the usage of MPIDist and TorchDist. (#10380)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Yuxian Qiu 2026-01-14 14:05:47 +08:00 committed by GitHub
parent f841b43cde
commit 39cefd6125
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 162 additions and 131 deletions

View File

@@ -10,9 +10,8 @@ import torch
import yaml
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.distributed import MPIDist, TorchDist
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_disabled, mpi_rank, mpi_world_size
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.logger import logger
from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, Runner, mark_ranges
@@ -192,8 +191,7 @@ run_pack = runner.create_run_pack(
)
if args.enable_autotuner:
cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
dist = TorchDist(mapping=mapping) if mpi_disabled() else MPIDist(mapping=mapping)
AutoTuner.get().setup_distributed_state(mapping, dist)
AutoTuner.get().setup_distributed_state(mapping)
with autotune(cache_path=cache_path):
run_pack()
else:

View File

@@ -17,15 +17,12 @@ import struct
import sys
from typing import List, Tuple
from tensorrt_llm._utils import mpi_disabled
try:
from cuda.bindings import driver as cuda
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cuda, cudart
from ._utils import mpi_comm
from .logger import logger
from .mapping import Mapping
@@ -107,15 +104,9 @@ class IpcMemory:
size += alignment - (size % alignment)
return size
if mpi_disabled():
from tensorrt_llm._utils import torch_comm
from tensorrt_llm._torch.distributed.communicator import Distributed
allgather = torch_comm().tp_allgather
else:
comm = mpi_comm().Split(
mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
)
allgather = comm.allgather
dist = Distributed.get(mapping)
# see allocateIpcMemory in cpp/tensorrt_llm/runtime/ipcUtils.cpp for alignment reason
# 1 << 21 is 2MB
@@ -126,7 +117,7 @@ class IpcMemory:
_raise_if_error(cudart.cudaMemset(local_ptr, 0, aligned_size)[0])
error, local_handle = cudart.cudaIpcGetMemHandle(local_ptr)
_raise_if_error(error)
handles_reserved = allgather(local_handle.reserved)
handles_reserved = dist.tp_allgather(local_handle.reserved)
handles = []
for reserved in handles_reserved:
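
With the manual communicator split removed, the handle exchange goes through the shared Distributed abstraction. A minimal sketch of the caller-side pattern this hunk introduces (the wrapper function is illustrative, not part of the change):

from tensorrt_llm._torch.distributed.communicator import Distributed
from tensorrt_llm.mapping import Mapping

def exchange_ipc_handles(mapping: Mapping, local_handle_bytes: bytes) -> list:
    # Distributed.get() resolves to MPIDist or TorchDist, so IpcMemory no
    # longer branches on mpi_disabled() or splits an MPI communicator itself.
    dist = Distributed.get(mapping)
    # tp_allgather returns one entry per TP rank in the caller's TP group.
    return dist.tp_allgather(local_handle_bytes)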

View File

@ -9,7 +9,7 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from tensorrt_llm._utils import get_free_port as _get_free_port
from tensorrt_llm._utils import get_free_port
from ..utils.logger import ad_logger
@@ -69,10 +69,6 @@ def all_gather_object(object_list, object, group=None):
return dist.all_gather_object(object_list, object, group=group)
def get_free_port():
return _get_free_port()
def get_world_size() -> int:
return dist.get_world_size()

View File

@@ -47,10 +47,10 @@ from tensorrt_llm.llmapi.llm_args import (
)
from tensorrt_llm.llmapi.tokenizer import TokenizerBase
from ...._utils import mpi_rank, mpi_world_size
from ...._utils import get_free_port, mpi_rank, mpi_world_size
from ....bindings.internal.batch_manager import CacheType
from ....mapping import Mapping
from ...distributed import MPIDist
from ...distributed import Distributed
from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine
from ...pyexecutor.py_executor import PyExecutor
from ...pyexecutor.resource_manager import (
@@ -68,7 +68,7 @@ from ...pyexecutor.scheduler import (
SimpleScheduler,
)
from ..custom_ops.attention_interface import SequenceInfo
from ..distributed import common as dist
from ..distributed.common import initialize_or_skip
from ..llm_args import LlmArgs
from ..transform.optimizer import InferenceOptimizer
from ..utils.logger import ad_logger
@@ -880,7 +880,7 @@ def share_target_weights_with_draft(
def create_draft_model_engine_maybe(
ad_config: LlmArgs, target_engine: ADEngine, dist_mapping: Mapping, mpi_dist: MPIDist
ad_config: LlmArgs, target_engine: ADEngine, dist_mapping: Mapping, dist: Distributed
) -> Optional[PyTorchModelEngine]:
"""Create a draft model engine for speculative decoding.
@@ -888,7 +888,7 @@ def create_draft_model_engine_maybe(
ad_config: The AutoDeploy LLM configuration
engine: The target model engine (ADEngine)
dist_mapping: The distributed mapping configuration
mpi_dist: The MPI distribution object
dist: The distribution object
Returns:
PyTorchModelEngine configured as a draft model, or None if not needed
@@ -925,7 +925,7 @@ def create_draft_model_engine_maybe(
llm_args=draft_llm_args,
mapping=dist_mapping,
attn_runtime_features=attn_runtime_features,
dist=mpi_dist,
dist=dist,
spec_config=draft_spec_config,
is_draft_model=True,
drafting_loop_wrapper=drafting_loop_wrapper,
@@ -1004,14 +1004,14 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
world_size = mpi_world_size()
rank = mpi_rank()
dist_mapping = Mapping(rank=rank, world_size=world_size, tp_size=world_size)
mpi_dist = MPIDist(dist_mapping)
dist = Distributed.get(dist_mapping)
ad_logger.set_rank(rank)
torch.cuda.set_device(rank)
port = mpi_dist.broadcast(dist.get_free_port()) # use MPI broadcast to pick a free port
dist.initialize_or_skip(rank, world_size, port)
port = dist.broadcast(get_free_port()) # use MPI broadcast to pick a free port
initialize_or_skip(rank, world_size, port)
# Setup AutoTuner with distributed state for allreduce autotuning
AutoTuner.get().setup_distributed_state(dist_mapping, mpi_dist)
AutoTuner.get().setup_distributed_state(dist_mapping)
# some config
assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"
@@ -1044,7 +1044,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
)
draft_model_engine = create_draft_model_engine_maybe(
ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, dist=dist
)
spec_resource_manager = (
@@ -1171,7 +1171,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
scheduler,
model_engine=engine,
sampler=sampler,
dist=mpi_dist,
dist=dist,
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
max_input_len=ad_config.max_input_len,
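
Condensed, the new AutoDeploy setup path reads as below; this sketch only restates the lines above (it is not the full executor construction):

from tensorrt_llm._torch.auto_deploy.distributed.common import initialize_or_skip
from tensorrt_llm._torch.distributed.communicator import Distributed
from tensorrt_llm._utils import get_free_port, mpi_rank, mpi_world_size
from tensorrt_llm.mapping import Mapping

world_size = mpi_world_size()
rank = mpi_rank()
dist_mapping = Mapping(rank=rank, world_size=world_size, tp_size=world_size)
dist = Distributed.get(dist_mapping)

# Every rank proposes a free port; broadcast() makes rank 0's choice
# authoritative so all ranks join the same process group.
port = dist.broadcast(get_free_port())
initialize_or_skip(rank, world_size, port)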

View File

@@ -1072,9 +1072,7 @@ class AutoTuner:
stream.synchronize()
if tuning_config.distributed_tuning_strategy == DistributedTuningStrategy.MERGE:
# Currently only AllReduce will use this strategy, and only MPI parallel will enable tuning.
# TODO: Unified tp barrier for both MPIDist and TorchDist.
if hasattr(self._dist, "tp_comm"):
self._dist.tp_comm.barrier()
self._dist.tp_barrier()
# Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
if use_cuda_graph:
@@ -1495,10 +1493,14 @@ class AutoTuner:
else:
raise RuntimeError("Unknown error type: {}".format(error))
def setup_distributed_state(self, mapping: Mapping, dist: Distributed):
def setup_distributed_state(self,
mapping: Mapping,
dist: Optional[Distributed] = ...):
"""Setup distributed communication state for autotuning."""
self.mapping = mapping
self._dist = dist
# Create dist only when dist is not provided.
# Use the provided dist even if it is None. This is useful for testing.
self._dist = Distributed.get(mapping) if dist is ... else dist
self._debug_logger(
f"[AutoTuner] Whether using distributed tuning: {self._is_distributed()}"
)
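
Because the default is the Ellipsis sentinel rather than None, the new signature supports three call styles; a small usage sketch based on this hunk (the mapping values are illustrative):

from tensorrt_llm._torch.autotuner import AutoTuner
from tensorrt_llm._torch.distributed.communicator import Distributed
from tensorrt_llm.mapping import Mapping

mapping = Mapping(world_size=1, rank=0, tp_size=1)
tuner = AutoTuner.get()

# Default: the tuner creates the right communicator from the mapping.
tuner.setup_distributed_state(mapping)

# Explicit communicator: reuse one the caller already holds.
tuner.setup_distributed_state(mapping, Distributed.get(mapping))

# Explicit None: keep distributed tuning disabled, e.g. in unit tests.
tuner.setup_distributed_state(mapping, None)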

View File

@@ -1,8 +1,8 @@
import copy
import math
import pickle # nosec B403
from abc import ABC, abstractmethod
from functools import wraps
from enum import IntEnum
from functools import lru_cache, wraps
from typing import List, Optional
import numpy as np
@@ -32,11 +32,59 @@ except ModuleNotFoundError:
from tensorrt_llm import ray_stub as ray
class ReduceOp(IntEnum):
SUM = 0
PRODUCT = 1
MIN = 2
MAX = 3
BAND = 4
BOR = 5
BXOR = 6
_reduce_op_to_torch_dict = {
ReduceOp.SUM: torch.distributed.ReduceOp.SUM,
ReduceOp.PRODUCT: torch.distributed.ReduceOp.PRODUCT,
ReduceOp.MIN: torch.distributed.ReduceOp.MIN,
ReduceOp.MAX: torch.distributed.ReduceOp.MAX,
ReduceOp.BAND: torch.distributed.ReduceOp.BAND,
ReduceOp.BOR: torch.distributed.ReduceOp.BOR,
ReduceOp.BXOR: torch.distributed.ReduceOp.BXOR,
}
def reduce_op_to_torch(op: ReduceOp) -> torch.distributed.ReduceOp:
return _reduce_op_to_torch_dict[op]
_reduce_op_to_mpi_dict = {
ReduceOp.SUM: MPI.SUM,
ReduceOp.PRODUCT: MPI.PROD,
ReduceOp.MIN: MPI.MIN,
ReduceOp.MAX: MPI.MAX,
ReduceOp.BAND: MPI.BAND,
ReduceOp.BOR: MPI.BOR,
ReduceOp.BXOR: MPI.BXOR,
}
def reduce_op_to_mpi(op: ReduceOp) -> MPI.Op:
return _reduce_op_to_mpi_dict[op]
class Distributed(ABC):
def __init__(self, mapping: Mapping):
self.mapping = mapping
@staticmethod
@lru_cache(maxsize=None)
def get(mapping: Mapping) -> "Distributed":
if mpi_disabled():
return TorchDist(mapping)
else:
return MPIDist(mapping)
@property
def rank(self):
return self.mapping.rank
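
The lru_cache on Distributed.get means a given Mapping resolves to one shared communicator instance; a hedged sketch (assuming Mapping is hashable, which lru_cache requires):

from tensorrt_llm._torch.distributed.communicator import Distributed
from tensorrt_llm.mapping import Mapping

mapping = Mapping(world_size=1, rank=0)
dist_a = Distributed.get(mapping)  # MPIDist, or TorchDist when mpi_disabled()
dist_b = Distributed.get(mapping)  # cached: the same object as dist_a
assert dist_a is dist_b
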
@@ -109,6 +157,14 @@ class Distributed(ABC):
def cp_config(self):
return self.mapping.cp_config
@abstractmethod
def barrier(self):
pass
@abstractmethod
def tp_barrier(self):
pass
@abstractmethod
def broadcast(self, obj, root=0):
pass
@@ -117,6 +173,10 @@ class Distributed(ABC):
def allgather(self, obj, root=0):
pass
@abstractmethod
def allreduce(self, obj, op: ReduceOp = ReduceOp.SUM):
pass
@abstractmethod
def tp_broadcast(self, obj, root=0, **kwargs):
pass
@@ -363,24 +423,9 @@ class MPIDist(Distributed):
def __init__(self, mapping: Mapping):
super().__init__(mapping)
self.create_cp_comm()
# Repurpose CP ranks to TP for Helix so that the right comms are created.
mapping_with_cp = None
if self.mapping.has_cp_helix():
logger.info(
f"[MPIDist::__init__] Repurposing CP ranks to TP for Helix.")
mapping_with_cp = copy.deepcopy(self.mapping)
self.mapping = self.mapping.repurpose_helix_cp_to_tp()
self.create_tp_comm()
self.create_pp_comm()
# Restore the original mapping.
if mapping_with_cp is not None:
logger.info(
f"[MPIDist::__init__] Restoring original mapping undoing Helix manipulation."
)
self.mapping = mapping_with_cp
self._cp_comm = None
self._tp_comm = None
self._pp_comm = None
def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
comm = mpi_comm()
@@ -392,6 +437,9 @@ class MPIDist(Distributed):
def barrier(self):
mpi_barrier()
def tp_barrier(self):
self.tp_comm.Barrier()
def isend(self, buf: np.ndarray, dest, tag=0):
# non-blocking send numpy buffer
return mpi_isend(buf, dest, tag)
@@ -413,17 +461,32 @@ class MPIDist(Distributed):
def recv_object(self, src, tag=0):
return mpi_recv_object(src, tag)
def create_tp_comm(self):
new_group = mpi_comm().group.Incl(self.mapping.tp_group)
self.tp_comm = mpi_comm().Create_group(new_group)
@property
def tp_comm(self):
if self._tp_comm is None:
mapping = self.mapping
if mapping.has_cp_helix():
mapping = mapping.repurpose_helix_cp_to_tp()
new_group = mpi_comm().group.Incl(mapping.tp_group)
self._tp_comm = mpi_comm().Create_group(new_group)
return self._tp_comm
def create_pp_comm(self):
new_group = mpi_comm().group.Incl(self.mapping.pp_group)
self.pp_comm = mpi_comm().Create_group(new_group)
@property
def pp_comm(self):
if self._pp_comm is None:
mapping = self.mapping
if mapping.has_cp_helix():
mapping = mapping.repurpose_helix_cp_to_tp()
new_group = mpi_comm().group.Incl(mapping.pp_group)
self._pp_comm = mpi_comm().Create_group(new_group)
return self._pp_comm
def create_cp_comm(self):
new_group = mpi_comm().group.Incl(self.mapping.cp_group)
self.cp_comm = mpi_comm().Create_group(new_group)
@property
def cp_comm(self):
if self._cp_comm is None:
new_group = mpi_comm().group.Incl(self.mapping.cp_group)
self._cp_comm = mpi_comm().Create_group(new_group)
return self._cp_comm
def cp_allgather(self, obj):
return self.cp_comm.allgather(obj)
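
The eager create_tp_comm/create_pp_comm/create_cp_comm calls become cached properties, so sub-communicators are only split on first use, and the Helix CP-to-TP repurposing rebinds a local mapping instead of temporarily swapping self.mapping. A generic sketch of the memoized-property pattern used here (names are illustrative, not from the commit):

class LazySplitExample:
    """Illustrative only; mirrors the tp_comm/pp_comm/cp_comm properties."""

    def __init__(self):
        self._comm = None  # nothing is split at construction time

    @property
    def comm(self):
        if self._comm is None:           # first access pays the split cost
            self._comm = self._split()   # stand-in for mpi_comm().Create_group(...)
        return self._comm

    def _split(self):
        return object()
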
@@ -460,6 +523,10 @@ class MPIDist(Distributed):
def pp_broadcast(self, obj, root=0):
return self.pp_comm.bcast(obj, root)
def allreduce(self, obj, op: ReduceOp = ReduceOp.SUM):
reduce_op = reduce_op_to_mpi(op)
return mpi_comm().allreduce(obj, reduce_op)
class MultiHandleWrapper:
"""
@@ -610,6 +677,10 @@ class TorchDist(Distributed):
def barrier(self):
dist.barrier()
@log_op
def tp_barrier(self):
dist.barrier(group=self.mapping.tp_group_pg)
@log_op
def isend(self, buf: np.ndarray, dest, tag=0):
# non-blocking send numpy buffer
@@ -673,14 +744,16 @@ class TorchDist(Distributed):
return MultiHandleWrapper(works)
@log_op
def allreduce(self,
obj: int | float | torch.Tensor,
op=torch.distributed.ReduceOp.SUM):
def allreduce(
self,
obj: int | float | torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
):
is_base_type = isinstance(obj, int) or isinstance(obj, float)
if is_base_type:
obj = torch.tensor(obj)
dist.all_reduce(obj, op=op)
dist.all_reduce(obj, op=reduce_op_to_torch(op))
if is_base_type:
obj = obj.item()

View File

@@ -10,7 +10,7 @@ from tensorrt_llm.llmapi.llm_args import (BaseSparseAttentionConfig,
from tensorrt_llm.mapping import Mapping
from ...inputs.multimodal import MultimodalParams
from ..distributed import MPIDist
from ..distributed import Distributed
from ..expert_statistic import ExpertStatistic
from ..memory_buffer_utils import get_memory_buffers
from ..modules.multi_stream_utils import with_multi_stream
@@ -75,7 +75,7 @@ class CUDAGraphRunnerConfig:
enable_attention_dp: bool
batch_size: int
mapping: Optional[Mapping]
dist: Optional[MPIDist]
dist: Optional[Distributed]
kv_cache_manager_key: Any
sparse_attention_config: Optional[BaseSparseAttentionConfig] = None

View File

@@ -35,7 +35,7 @@ from ..attention_backend.vanilla import VanillaAttentionMetadata
from ..autotuner import AutoTuner, autotune
from ..compilation.backend import Backend
from ..compilation.utils import capture_piecewise_cuda_graph
from ..distributed import MPIDist
from ..distributed import Distributed
from ..distributed.communicator import init_pp_comm
from ..expert_statistic import ExpertStatistic
from ..memory_buffer_utils import with_shared_pool
@@ -134,7 +134,7 @@ class PyTorchModelEngine(ModelEngine):
llm_args: TorchLlmArgs,
mapping: Optional[Mapping] = None,
attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
dist: Optional[MPIDist] = None,
dist: Optional[Distributed] = None,
spec_config: Optional["DecodingBaseConfig"] = None,
is_draft_model: bool = False,
drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],

View File

@@ -13,7 +13,7 @@ from strenum import StrEnum
import tensorrt_llm
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm._utils import get_sm_version, mpi_disabled
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy,
ContextChunkingPolicy,
GuidedDecodingConfig, LoadFormat,
@@ -27,7 +27,7 @@ from tensorrt_llm.quantization import QuantAlgo
from ..attention_backend.interface import AttentionRuntimeFeatures
from ..attention_backend.trtllm import TrtllmAttention
from ..distributed import MPIDist, TorchDist
from ..distributed import Distributed
from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter,
get_spec_resource_manager)
from ..virtual_memory import ExecutorMemoryType, RestoreMode
@@ -303,10 +303,7 @@ def create_py_executor(
"when only processing vision encoder inputs.")
mapping = _get_mapping(llm_args.parallel_config.to_mapping())
if mpi_disabled():
dist = TorchDist(mapping=mapping)
else:
dist = MPIDist(mapping=mapping)
dist = Distributed.get(mapping)
vm_pools = {}
enable_sleep = llm_args.enable_sleep

View File

@@ -10,8 +10,7 @@ import torch
import tensorrt_llm
import tensorrt_llm.bindings
from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm._torch.distributed.communicator import Distributed, ReduceOp
from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig,
PybindMirror)
from tensorrt_llm.lora_helper import LoraConfig
@@ -28,11 +27,6 @@ from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig,
get_draft_token_length)
from .scheduler import ScheduledRequests
if ENABLE_MULTI_DEVICE:
from mpi4py import MPI
from tensorrt_llm._utils import mpi_comm
BufferManagerCpp = tensorrt_llm.bindings.internal.runtime.BufferManager
KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager
CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType
@@ -803,12 +797,11 @@ class KVCacheManager(BaseResourceManager):
if mapping.world_size > 1:
# make sure all ranks use same value for maxTokens
if mpi_disabled():
from tensorrt_llm._utils import torch_comm
max_tokens = torch_comm().allreduce(
max_tokens, op=torch.distributed.ReduceOp.MIN)
else:
max_tokens = mpi_comm().allreduce(max_tokens, op=MPI.MIN)
dist = Distributed.get(mapping)
max_tokens = dist.allreduce(
max_tokens,
op=ReduceOp.MIN,
)
# get number of blocks
blocks_in_primary_pool = int(max_tokens // tokens_per_block)
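
The maxTokens agreement now uses the backend-neutral allreduce instead of branching between torch_comm() and mpi_comm(); a minimal sketch of the same MIN-reduction (the helper name is illustrative):

from tensorrt_llm._torch.distributed.communicator import Distributed, ReduceOp
from tensorrt_llm.mapping import Mapping

def agree_on_max_tokens(mapping: Mapping, local_max_tokens: int) -> int:
    # Each rank contributes its local estimate; MIN keeps the value every
    # rank can actually fit, regardless of MPI or torch.distributed backend.
    dist = Distributed.get(mapping)
    return dist.allreduce(local_max_tokens, op=ReduceOp.MIN)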

View File

@@ -29,10 +29,10 @@ import tensorrt_llm as tllm
from tensorrt_llm import Mapping
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
MPIDist, TorchDist)
Distributed)
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._utils import (get_sm_version, local_mpi_rank, local_mpi_size,
mpi_disabled, nvtx_range)
nvtx_range)
from tensorrt_llm.bindings.internal.runtime import delay_kernel
from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy
from tensorrt_llm.logger import logger
@@ -41,7 +41,7 @@ from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
def profile_allreduce(
mapping: Mapping,
dist: TorchDist | MPIDist,
dist: Distributed,
enable_cudagraph: bool = False,
inner_loop=200,
outer_loop=10,
@@ -137,14 +137,10 @@ def allreduce_benchmark(
cudart.cudaSetDevice(local_rank)
mapping = Mapping(world_size, rank, gpus_per_node, tp_size=world_size)
if mpi_disabled():
dist = TorchDist(mapping=mapping)
else:
dist = MPIDist(mapping=mapping)
logger.set_rank(mapping.rank)
AutoTuner.get().setup_distributed_state(mapping, dist)
AutoTuner.get().setup_distributed_state(mapping)
sm_version = get_sm_version()

View File

@@ -6,10 +6,11 @@ from _dist_test_utils import get_device_counts
from torch.export import export
from tensorrt_llm._torch.auto_deploy.custom_ops.trtllm_dist import is_trtllm_op_available
from tensorrt_llm._torch.auto_deploy.distributed import common as dist
from tensorrt_llm._torch.auto_deploy.distributed.common import initialize_or_skip
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
from tensorrt_llm._utils import get_free_port
from tensorrt_llm.llmapi.mpi_session import MpiPoolSession
# needed since MPI executor pool leaks a thread (_manager_spawn) on shutdown
@@ -66,7 +67,7 @@ def _test_allreduce_fusion(port: int, ModuleCls, strategy: str):
if not is_trtllm_op_available():
pytest.skip("Require trtllm ops to run test_allreduce_fusion.")
_, _ = dist.initialize_or_skip(port=port)
_, _ = initialize_or_skip(port=port)
# Testing tensors
dtype = torch.float16
@@ -146,7 +147,7 @@ def _test_allreduce_fusion(port: int, ModuleCls, strategy: str):
def test_allreduce_fusion(device_count, ModuleCls, strategy):
if device_count <= 1:
pytest.skip("Require multi GPUs to run test_allreduce_fusion.")
port = dist.get_free_port()
port = get_free_port()
n_workers = device_count
mpi_pool = MpiPoolSession(n_workers=n_workers)

View File

@@ -17,10 +17,9 @@ from tensorrt_llm._torch.autotuner import (AutoTuner, DistributedTuningStrategy,
FakeTensor, OptimizationProfile,
StaticDim, TunableRunner,
TuningConfig, autotune)
from tensorrt_llm._torch.distributed.communicator import MPIDist, TorchDist
from tensorrt_llm._torch.distributed import Distributed
from tensorrt_llm._torch.utils import (get_power_of_2_num_tokens_buckets,
next_positive_power_of_2)
from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.bindings.internal.runtime import delay_kernel
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
@@ -720,14 +719,11 @@ def _distributed_worker_function(world_size, strategy):
rank=rank,
tp_size=world_size,
pp_size=1)
if mpi_disabled():
dist = TorchDist(mapping=mapping)
else:
dist = MPIDist(mapping=mapping)
dist = Distributed.get(mapping)
tuner = AutoTuner.get()
tuner.clear_cache()
tuner.setup_distributed_state(mapping, dist)
tuner.setup_distributed_state(mapping)
x = torch.randn(16, 32, device='cuda')
w = torch.randn(32, 64, device='cuda')

View File

@@ -24,7 +24,6 @@ from utils.util import (check_accuracy, skip_blackwell, skip_blackwell_geforce,
skip_pre_hopper)
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.distributed import MPIDist, TorchDist
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import \
CuteDslFusedMoE
@@ -45,7 +44,7 @@ from tensorrt_llm._torch.modules.fused_moe.quantization import \
from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
IS_TRITON_KERNELS_AVAILABLE
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
from tensorrt_llm._utils import get_sm_version, mpi_disabled, mpi_rank
from tensorrt_llm._utils import get_sm_version, mpi_rank
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
@@ -105,12 +104,7 @@ def test_fused_moe(moe_backend,
mapping = mapping or Mapping()
mapping.rank = mpi_rank()
if mpi_disabled():
dist = TorchDist(mapping=mapping)
else:
dist = MPIDist(mapping=mapping)
AutoTuner.get().setup_distributed_state(mapping, dist)
AutoTuner.get().setup_distributed_state(mapping)
torch.cuda.set_device(mapping.rank)

View File

@@ -27,11 +27,9 @@ import tensorrt_llm
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
AllReduceParams, AllReduceStrategy,
MoEAllReduce, MoEAllReduceParams,
MPIDist, TorchDist)
MoEAllReduce, MoEAllReduceParams)
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.mapping import Mapping
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
@@ -133,12 +131,8 @@ def run_allreduce_op(
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
)
if mpi_disabled():
dist = TorchDist(mapping=mapping)
else:
dist = MPIDist(mapping=mapping)
AutoTuner.get().setup_distributed_state(mapping, dist)
AutoTuner.get().setup_distributed_state(mapping)
linear = Linear(
in_features=hidden_size,
out_features=hidden_size,

View File

@@ -6,7 +6,7 @@ import torch
import tensorrt_llm
import tensorrt_llm.bindings
import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm._torch.distributed import MPIDist
from tensorrt_llm._torch.distributed import Distributed
from tensorrt_llm._torch.pyexecutor.kv_cache_transceiver import \
create_kv_cache_transceiver
from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
@@ -79,7 +79,7 @@ def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
cache_transceiver_config = CacheTransceiverConfig(backend=backend,
max_tokens_in_buffer=512)
dist = MPIDist(mapping=mapping)
dist = Distributed.get(mapping)
kv_cache_transceiver_ctx = create_kv_cache_transceiver(
mapping, dist, kv_cache_manager_ctx, attention_type,
cache_transceiver_config)
@@ -139,7 +139,7 @@ def test_cancel_request_in_transmission(attention_type):
def test_cancel_request_in_transmission(attention_type):
# Init kv_cache manager and cache transceiver
mapping = Mapping(world_size=1, rank=0)
dist = MPIDist(mapping=mapping)
dist = Distributed.get(mapping)
ctx_kv_cache_dtype, gen_kv_cache_dtype = DataType.HALF, DataType.HALF
kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)