[None][refactor] Unify the usage of MPIDist and TorchDist. (#10380)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
parent f841b43cde
commit 39cefd6125
@@ -10,9 +10,8 @@ import torch
 import yaml
 
 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
-from tensorrt_llm._torch.distributed import MPIDist, TorchDist
 from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
-from tensorrt_llm._utils import local_mpi_rank, mpi_disabled, mpi_rank, mpi_world_size
+from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
 from tensorrt_llm.logger import logger
 from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, Runner, mark_ranges
 
@@ -192,8 +191,7 @@ run_pack = runner.create_run_pack(
 )
 if args.enable_autotuner:
     cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
-    dist = TorchDist(mapping=mapping) if mpi_disabled() else MPIDist(mapping=mapping)
-    AutoTuner.get().setup_distributed_state(mapping, dist)
+    AutoTuner.get().setup_distributed_state(mapping)
     with autotune(cache_path=cache_path):
         run_pack()
 else:
@@ -17,15 +17,12 @@ import struct
 import sys
 from typing import List, Tuple
-
-from tensorrt_llm._utils import mpi_disabled
-
 try:
     from cuda.bindings import driver as cuda
     from cuda.bindings import runtime as cudart
 except ImportError:
     from cuda import cuda, cudart
 
 from ._utils import mpi_comm
 from .logger import logger
 from .mapping import Mapping
 
@@ -107,15 +104,9 @@ class IpcMemory:
             size += alignment - (size % alignment)
         return size
 
-        if mpi_disabled():
-            from tensorrt_llm._utils import torch_comm
+        from tensorrt_llm._torch.distributed.communicator import Distributed
 
-            allgather = torch_comm().tp_allgather
-        else:
-            comm = mpi_comm().Split(
-                mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
-            )
-            allgather = comm.allgather
+        dist = Distributed.get(mapping)
 
         # see allocateIpcMemory in cpp/tensorrt_llm/runtime/ipcUtils.cpp for alignment reason
         # 1 << 21 is 2MB
@@ -126,7 +117,7 @@ class IpcMemory:
         _raise_if_error(cudart.cudaMemset(local_ptr, 0, aligned_size)[0])
         error, local_handle = cudart.cudaIpcGetMemHandle(local_ptr)
         _raise_if_error(error)
-        handles_reserved = allgather(local_handle.reserved)
+        handles_reserved = dist.tp_allgather(local_handle.reserved)
 
         handles = []
         for reserved in handles_reserved:
@@ -9,7 +9,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 
-from tensorrt_llm._utils import get_free_port as _get_free_port
+from tensorrt_llm._utils import get_free_port
 
 from ..utils.logger import ad_logger
 
@@ -69,10 +69,6 @@ def all_gather_object(object_list, object, group=None):
     return dist.all_gather_object(object_list, object, group=group)
 
 
-def get_free_port():
-    return _get_free_port()
-
-
 def get_world_size() -> int:
     return dist.get_world_size()
 
@@ -47,10 +47,10 @@ from tensorrt_llm.llmapi.llm_args import (
 )
 from tensorrt_llm.llmapi.tokenizer import TokenizerBase
 
-from ...._utils import mpi_rank, mpi_world_size
+from ...._utils import get_free_port, mpi_rank, mpi_world_size
 from ....bindings.internal.batch_manager import CacheType
 from ....mapping import Mapping
-from ...distributed import MPIDist
+from ...distributed import Distributed
 from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine
 from ...pyexecutor.py_executor import PyExecutor
 from ...pyexecutor.resource_manager import (
@@ -68,7 +68,7 @@ from ...pyexecutor.scheduler import (
     SimpleScheduler,
 )
 from ..custom_ops.attention_interface import SequenceInfo
-from ..distributed import common as dist
+from ..distributed.common import initialize_or_skip
 from ..llm_args import LlmArgs
 from ..transform.optimizer import InferenceOptimizer
 from ..utils.logger import ad_logger
@@ -880,7 +880,7 @@ def share_target_weights_with_draft(
 
 
 def create_draft_model_engine_maybe(
-    ad_config: LlmArgs, target_engine: ADEngine, dist_mapping: Mapping, mpi_dist: MPIDist
+    ad_config: LlmArgs, target_engine: ADEngine, dist_mapping: Mapping, dist: Distributed
 ) -> Optional[PyTorchModelEngine]:
     """Create a draft model engine for speculative decoding.
 
@@ -888,7 +888,7 @@ def create_draft_model_engine_maybe(
         ad_config: The AutoDeploy LLM configuration
         engine: The target model engine (ADEngine)
         dist_mapping: The distributed mapping configuration
-        mpi_dist: The MPI distribution object
+        dist: The distribution object
 
     Returns:
         PyTorchModelEngine configured as a draft model, or None if not needed
@@ -925,7 +925,7 @@ def create_draft_model_engine_maybe(
         llm_args=draft_llm_args,
         mapping=dist_mapping,
         attn_runtime_features=attn_runtime_features,
-        dist=mpi_dist,
+        dist=dist,
         spec_config=draft_spec_config,
         is_draft_model=True,
         drafting_loop_wrapper=drafting_loop_wrapper,
@@ -1004,14 +1004,14 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     world_size = mpi_world_size()
     rank = mpi_rank()
     dist_mapping = Mapping(rank=rank, world_size=world_size, tp_size=world_size)
-    mpi_dist = MPIDist(dist_mapping)
+    dist = Distributed.get(dist_mapping)
     ad_logger.set_rank(rank)
     torch.cuda.set_device(rank)
-    port = mpi_dist.broadcast(dist.get_free_port())  # use MPI broadcast to pick a free port
-    dist.initialize_or_skip(rank, world_size, port)
+    port = dist.broadcast(get_free_port())  # use MPI broadcast to pick a free port
+    initialize_or_skip(rank, world_size, port)
 
     # Setup AutoTuner with distributed state for allreduce autotuning
-    AutoTuner.get().setup_distributed_state(dist_mapping, mpi_dist)
+    AutoTuner.get().setup_distributed_state(dist_mapping)
 
     # some config
     assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"
@@ -1044,7 +1044,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     )
 
     draft_model_engine = create_draft_model_engine_maybe(
-        ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, mpi_dist=mpi_dist
+        ad_config=ad_config, target_engine=engine, dist_mapping=dist_mapping, dist=dist
     )
 
     spec_resource_manager = (
@@ -1171,7 +1171,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
         scheduler,
         model_engine=engine,
         sampler=sampler,
-        dist=mpi_dist,
+        dist=dist,
         max_num_sequences=max_num_sequences,
         disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
         max_input_len=ad_config.max_input_len,
@@ -1072,9 +1072,7 @@ class AutoTuner:
                 stream.synchronize()
                 if tuning_config.distributed_tuning_strategy == DistributedTuningStrategy.MERGE:
                     # Currently only AllReduce will use this strategy, and only MPI parallel will enable tuning.
-                    # TODO: Unified tp barrier for both MPIDist and TorchDist.
-                    if hasattr(self._dist, "tp_comm"):
-                        self._dist.tp_comm.barrier()
+                    self._dist.tp_barrier()
 
                 # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
                 if use_cuda_graph:
@@ -1495,10 +1493,14 @@ class AutoTuner:
         else:
             raise RuntimeError("Unknown error type: {}".format(error))
 
-    def setup_distributed_state(self, mapping: Mapping, dist: Distributed):
+    def setup_distributed_state(self,
+                                mapping: Mapping,
+                                dist: Optional[Distributed] = ...):
         """Setup distributed communication state for autotuning."""
         self.mapping = mapping
-        self._dist = dist
+        # Create dist only when dist is not provided.
+        # Use the provided dist even if it is None. This is useful for testing.
+        self._dist = Distributed.get(mapping) if dist is ... else dist
         self._debug_logger(
             f"[AutoTuner] Whether using distributed tuning: {self._is_distributed()}"
         )
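Note on the new signature above: setup_distributed_state now defaults dist to the Ellipsis object rather than None, so "argument omitted" can be told apart from an explicit None (which is useful for tests that want to force the non-distributed path). A minimal sketch of that sentinel pattern, with hypothetical names, not the actual AutoTuner code:

from typing import Optional


class Backend:
    """Stand-in for a communicator such as Distributed."""


def setup(dist: Optional[Backend] = ...):
    # `...` means the caller passed nothing: build a default backend.
    # An explicit None is kept as-is, which disables distributed behavior.
    return Backend() if dist is ... else dist


print(type(setup()).__name__)  # Backend, created on demand
print(setup(None))             # None, explicitly disabled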
@@ -1,8 +1,8 @@
-import copy
 import math
 import pickle  # nosec B403
 from abc import ABC, abstractmethod
-from functools import wraps
+from enum import IntEnum
+from functools import lru_cache, wraps
 from typing import List, Optional
 
 import numpy as np
@@ -32,11 +32,59 @@ except ModuleNotFoundError:
     from tensorrt_llm import ray_stub as ray
 
 
+class ReduceOp(IntEnum):
+    SUM = 0
+    PRODUCT = 1
+    MIN = 2
+    MAX = 3
+    BAND = 4
+    BOR = 5
+    BXOR = 6
+
+
+_reduce_op_to_torch_dict = {
+    ReduceOp.SUM: torch.distributed.ReduceOp.SUM,
+    ReduceOp.PRODUCT: torch.distributed.ReduceOp.PRODUCT,
+    ReduceOp.MIN: torch.distributed.ReduceOp.MIN,
+    ReduceOp.MAX: torch.distributed.ReduceOp.MAX,
+    ReduceOp.BAND: torch.distributed.ReduceOp.BAND,
+    ReduceOp.BOR: torch.distributed.ReduceOp.BOR,
+    ReduceOp.BXOR: torch.distributed.ReduceOp.BXOR,
+}
+
+
+def reduce_op_to_torch(op: ReduceOp) -> torch.distributed.ReduceOp:
+    return _reduce_op_to_torch_dict[op]
+
+
+_reduce_op_to_mpi_dict = {
+    ReduceOp.SUM: MPI.SUM,
+    ReduceOp.PRODUCT: MPI.PROD,
+    ReduceOp.MIN: MPI.MIN,
+    ReduceOp.MAX: MPI.MAX,
+    ReduceOp.BAND: MPI.BAND,
+    ReduceOp.BOR: MPI.BOR,
+    ReduceOp.BXOR: MPI.BXOR,
+}
+
+
+def reduce_op_to_mpi(op: ReduceOp) -> MPI.Op:
+    return _reduce_op_to_mpi_dict[op]
+
+
 class Distributed(ABC):
 
     def __init__(self, mapping: Mapping):
         self.mapping = mapping
 
+    @staticmethod
+    @lru_cache(maxsize=None)
+    def get(mapping: Mapping) -> "Distributed":
+        if mpi_disabled():
+            return TorchDist(mapping)
+        else:
+            return MPIDist(mapping)
+
     @property
     def rank(self):
         return self.mapping.rank
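The hunk above is the core of the refactor: a backend-neutral ReduceOp enum, translation tables to the torch.distributed and MPI reduce ops, and a cached Distributed.get(mapping) factory (lru_cache) that returns TorchDist when MPI is disabled and MPIDist otherwise. A hedged usage sketch of how call sites in this commit use the facade; the Mapping arguments mirror the single-rank tests later in this diff, and a properly initialized MPI or torch.distributed environment is assumed:

from tensorrt_llm._torch.distributed.communicator import Distributed, ReduceOp
from tensorrt_llm.mapping import Mapping

mapping = Mapping(world_size=1, rank=0)  # single-rank example, as in the tests below
dist = Distributed.get(mapping)          # TorchDist if mpi_disabled() else MPIDist

# Backend-neutral collectives; the op is converted to an MPI or torch op internally.
smallest = dist.allreduce(42, op=ReduceOp.MIN)
dist.tp_barrier()
dist.barrier()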
@@ -109,6 +157,14 @@
     def cp_config(self):
         return self.mapping.cp_config
 
+    @abstractmethod
+    def barrier(self):
+        pass
+
+    @abstractmethod
+    def tp_barrier(self):
+        pass
+
     @abstractmethod
     def broadcast(self, obj, root=0):
         pass
@@ -117,6 +173,10 @@
     def allgather(self, obj, root=0):
         pass
 
+    @abstractmethod
+    def allreduce(self, obj, op: ReduceOp = ReduceOp.SUM):
+        pass
+
     @abstractmethod
     def tp_broadcast(self, obj, root=0, **kwargs):
         pass
@@ -363,24 +423,9 @@ class MPIDist(Distributed):
 
     def __init__(self, mapping: Mapping):
         super().__init__(mapping)
-        self.create_cp_comm()
-        # Repurpose CP ranks to TP for Helix so that the right comms are created.
-        mapping_with_cp = None
-        if self.mapping.has_cp_helix():
-            logger.info(
-                f"[MPIDist::__init__] Repurposing CP ranks to TP for Helix.")
-            mapping_with_cp = copy.deepcopy(self.mapping)
-            self.mapping = self.mapping.repurpose_helix_cp_to_tp()
-
-        self.create_tp_comm()
-        self.create_pp_comm()
-
-        # Restore the original mapping.
-        if mapping_with_cp is not None:
-            logger.info(
-                f"[MPIDist::__init__] Restoring original mapping undoing Helix manipulation."
-            )
-            self.mapping = mapping_with_cp
+        self._cp_comm = None
+        self._tp_comm = None
+        self._pp_comm = None
 
     def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
         comm = mpi_comm()
@@ -392,6 +437,9 @@ class MPIDist(Distributed):
     def barrier(self):
         mpi_barrier()
 
+    def tp_barrier(self):
+        self.tp_comm.Barrier()
+
     def isend(self, buf: np.ndarray, dest, tag=0):
         # non-blocking send numpy buffer
         return mpi_isend(buf, dest, tag)
@@ -413,17 +461,32 @@ class MPIDist(Distributed):
     def recv_object(self, src, tag=0):
         return mpi_recv_object(src, tag)
 
-    def create_tp_comm(self):
-        new_group = mpi_comm().group.Incl(self.mapping.tp_group)
-        self.tp_comm = mpi_comm().Create_group(new_group)
+    @property
+    def tp_comm(self):
+        if self._tp_comm is None:
+            mapping = self.mapping
+            if mapping.has_cp_helix():
+                mapping = mapping.repurpose_helix_cp_to_tp()
+            new_group = mpi_comm().group.Incl(mapping.tp_group)
+            self._tp_comm = mpi_comm().Create_group(new_group)
+        return self._tp_comm
 
-    def create_pp_comm(self):
-        new_group = mpi_comm().group.Incl(self.mapping.pp_group)
-        self.pp_comm = mpi_comm().Create_group(new_group)
+    @property
+    def pp_comm(self):
+        if self._pp_comm is None:
+            mapping = self.mapping
+            if mapping.has_cp_helix():
+                mapping = mapping.repurpose_helix_cp_to_tp()
+            new_group = mpi_comm().group.Incl(mapping.pp_group)
+            self._pp_comm = mpi_comm().Create_group(new_group)
+        return self._pp_comm
 
-    def create_cp_comm(self):
-        new_group = mpi_comm().group.Incl(self.mapping.cp_group)
-        self.cp_comm = mpi_comm().Create_group(new_group)
+    @property
+    def cp_comm(self):
+        if self._cp_comm is None:
+            new_group = mpi_comm().group.Incl(self.mapping.cp_group)
+            self._cp_comm = mpi_comm().Create_group(new_group)
+        return self._cp_comm
 
     def cp_allgather(self, obj):
         return self.cp_comm.allgather(obj)
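The eager create_tp_comm/create_pp_comm/create_cp_comm calls are replaced by lazily cached properties, so sub-communicators are only split on first use and the Helix CP-to-TP repurposing no longer needs the deepcopy-and-restore dance in __init__. A generic sketch of that lazy cached-property pattern, with hypothetical names standing in for the real MPI calls:

class LazyComm:
    def __init__(self, group):
        self._group = group
        self._comm = None                    # nothing is created up front

    @property
    def comm(self):
        if self._comm is None:               # first access pays the construction cost
            self._comm = tuple(self._group)  # stand-in for mpi_comm().Create_group(...)
        return self._comm


c = LazyComm(range(4))
assert c._comm is None                       # not built yet
assert c.comm == (0, 1, 2, 3)                # built on demand
assert c.comm is c.comm                      # and cached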
@@ -460,6 +523,10 @@ class MPIDist(Distributed):
     def pp_broadcast(self, obj, root=0):
         return self.pp_comm.bcast(obj, root)
 
+    def allreduce(self, obj, op: ReduceOp = ReduceOp.SUM):
+        reduce_op = reduce_op_to_mpi(op)
+        return mpi_comm().allreduce(obj, reduce_op)
+
 
 class MultiHandleWrapper:
     """
@@ -610,6 +677,10 @@ class TorchDist(Distributed):
     def barrier(self):
         dist.barrier()
 
+    @log_op
+    def tp_barrier(self):
+        dist.barrier(group=self.mapping.tp_group_pg)
+
     @log_op
     def isend(self, buf: np.ndarray, dest, tag=0):
         # non-blocking send numpy buffer
@@ -673,14 +744,16 @@ class TorchDist(Distributed):
         return MultiHandleWrapper(works)
 
     @log_op
-    def allreduce(self,
-                  obj: int | float | torch.Tensor,
-                  op=torch.distributed.ReduceOp.SUM):
+    def allreduce(
+        self,
+        obj: int | float | torch.Tensor,
+        op: ReduceOp = ReduceOp.SUM,
+    ):
         is_base_type = isinstance(obj, int) or isinstance(obj, float)
         if is_base_type:
             obj = torch.tensor(obj)
 
-        dist.all_reduce(obj, op=op)
+        dist.all_reduce(obj, op=reduce_op_to_torch(op))
 
         if is_base_type:
             obj = obj.item()
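TorchDist.allreduce keeps its convenience handling of plain Python numbers but now takes the shared ReduceOp and converts it with reduce_op_to_torch just before calling torch.distributed.all_reduce. A stripped-down sketch of the wrap-scalar / reduce / unwrap flow; the collective is faked with a lambda so no process group is needed:

import torch


def allreduce_like(obj, reduce_fn):
    # Sketch of the scalar handling above, not the real TorchDist method.
    is_base_type = isinstance(obj, (int, float))
    if is_base_type:
        obj = torch.tensor(obj)  # collectives operate on tensors
    reduce_fn(obj)               # stands in for dist.all_reduce(obj, op=reduce_op_to_torch(op))
    if is_base_type:
        obj = obj.item()         # hand scalars back as scalars
    return obj


print(allreduce_like(3, lambda t: t.mul_(4)))  # 12, as if 4 ranks each contributed 3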
@@ -10,7 +10,7 @@ from tensorrt_llm.llmapi.llm_args import (BaseSparseAttentionConfig,
 from tensorrt_llm.mapping import Mapping
 
 from ...inputs.multimodal import MultimodalParams
-from ..distributed import MPIDist
+from ..distributed import Distributed
 from ..expert_statistic import ExpertStatistic
 from ..memory_buffer_utils import get_memory_buffers
 from ..modules.multi_stream_utils import with_multi_stream
@@ -75,7 +75,7 @@ class CUDAGraphRunnerConfig:
     enable_attention_dp: bool
     batch_size: int
     mapping: Optional[Mapping]
-    dist: Optional[MPIDist]
+    dist: Optional[Distributed]
     kv_cache_manager_key: Any
     sparse_attention_config: Optional[BaseSparseAttentionConfig] = None
 
@@ -35,7 +35,7 @@ from ..attention_backend.vanilla import VanillaAttentionMetadata
 from ..autotuner import AutoTuner, autotune
 from ..compilation.backend import Backend
 from ..compilation.utils import capture_piecewise_cuda_graph
-from ..distributed import MPIDist
+from ..distributed import Distributed
 from ..distributed.communicator import init_pp_comm
 from ..expert_statistic import ExpertStatistic
 from ..memory_buffer_utils import with_shared_pool
@@ -134,7 +134,7 @@ class PyTorchModelEngine(ModelEngine):
                  llm_args: TorchLlmArgs,
                  mapping: Optional[Mapping] = None,
                  attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
-                 dist: Optional[MPIDist] = None,
+                 dist: Optional[Distributed] = None,
                  spec_config: Optional["DecodingBaseConfig"] = None,
                  is_draft_model: bool = False,
                  drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
@@ -13,7 +13,7 @@ from strenum import StrEnum
 
 import tensorrt_llm
 from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
-from tensorrt_llm._utils import get_sm_version, mpi_disabled
+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy,
                                           ContextChunkingPolicy,
                                           GuidedDecodingConfig, LoadFormat,
@@ -27,7 +27,7 @@ from tensorrt_llm.quantization import QuantAlgo
 
 from ..attention_backend.interface import AttentionRuntimeFeatures
 from ..attention_backend.trtllm import TrtllmAttention
-from ..distributed import MPIDist, TorchDist
+from ..distributed import Distributed
 from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter,
                            get_spec_resource_manager)
 from ..virtual_memory import ExecutorMemoryType, RestoreMode
@@ -303,10 +303,7 @@ def create_py_executor(
                 "when only processing vision encoder inputs.")
 
     mapping = _get_mapping(llm_args.parallel_config.to_mapping())
-    if mpi_disabled():
-        dist = TorchDist(mapping=mapping)
-    else:
-        dist = MPIDist(mapping=mapping)
+    dist = Distributed.get(mapping)
 
     vm_pools = {}
     enable_sleep = llm_args.enable_sleep
@@ -10,8 +10,7 @@ import torch
 
 import tensorrt_llm
 import tensorrt_llm.bindings
-from tensorrt_llm._utils import mpi_disabled
-from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
+from tensorrt_llm._torch.distributed.communicator import Distributed, ReduceOp
 from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig,
                                           PybindMirror)
 from tensorrt_llm.lora_helper import LoraConfig
@@ -28,11 +27,6 @@ from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig,
                           get_draft_token_length)
 from .scheduler import ScheduledRequests
 
-if ENABLE_MULTI_DEVICE:
-    from mpi4py import MPI
-
-    from tensorrt_llm._utils import mpi_comm
-
 BufferManagerCpp = tensorrt_llm.bindings.internal.runtime.BufferManager
 KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager
 CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType
@@ -803,12 +797,11 @@ class KVCacheManager(BaseResourceManager):
 
         if mapping.world_size > 1:
             # make sure all ranks use same value for maxTokens
-            if mpi_disabled():
-                from tensorrt_llm._utils import torch_comm
-                max_tokens = torch_comm().allreduce(
-                    max_tokens, op=torch.distributed.ReduceOp.MIN)
-            else:
-                max_tokens = mpi_comm().allreduce(max_tokens, op=MPI.MIN)
+            dist = Distributed.get(mapping)
+            max_tokens = dist.allreduce(
+                max_tokens,
+                op=ReduceOp.MIN,
+            )
 
         # get number of blocks
         blocks_in_primary_pool = int(max_tokens // tokens_per_block)
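With the facade in place, KVCacheManager no longer branches on mpi_disabled(): every rank contributes its locally computed max_tokens and a MIN all-reduce guarantees all ranks size the pool from the same, smallest value. The agreement logic itself, shown without any communicator and with hypothetical numbers:

per_rank_max_tokens = [8192, 7936, 8192, 8064]  # hypothetical per-rank estimates
agreed = min(per_rank_max_tokens)               # what dist.allreduce(x, op=ReduceOp.MIN) computes
tokens_per_block = 32
blocks_in_primary_pool = int(agreed // tokens_per_block)
print(agreed, blocks_in_primary_pool)           # 7936 248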
@@ -29,10 +29,10 @@ import tensorrt_llm as tllm
 from tensorrt_llm import Mapping
 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
-                                             MPIDist, TorchDist)
+                                             Distributed)
 from tensorrt_llm._torch.modules.rms_norm import RMSNorm
 from tensorrt_llm._utils import (get_sm_version, local_mpi_rank, local_mpi_size,
-                                 mpi_disabled, nvtx_range)
+                                 nvtx_range)
 from tensorrt_llm.bindings.internal.runtime import delay_kernel
 from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy
 from tensorrt_llm.logger import logger
@@ -41,7 +41,7 @@ from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
 
 def profile_allreduce(
     mapping: Mapping,
-    dist: TorchDist | MPIDist,
+    dist: Distributed,
     enable_cudagraph: bool = False,
     inner_loop=200,
     outer_loop=10,
@@ -137,14 +137,10 @@ def allreduce_benchmark(
     cudart.cudaSetDevice(local_rank)
 
     mapping = Mapping(world_size, rank, gpus_per_node, tp_size=world_size)
-    if mpi_disabled():
-        dist = TorchDist(mapping=mapping)
-    else:
-        dist = MPIDist(mapping=mapping)
 
     logger.set_rank(mapping.rank)
 
-    AutoTuner.get().setup_distributed_state(mapping, dist)
+    AutoTuner.get().setup_distributed_state(mapping)
 
     sm_version = get_sm_version()
 
@@ -6,10 +6,11 @@ from _dist_test_utils import get_device_counts
 from torch.export import export
 
 from tensorrt_llm._torch.auto_deploy.custom_ops.trtllm_dist import is_trtllm_op_available
-from tensorrt_llm._torch.auto_deploy.distributed import common as dist
+from tensorrt_llm._torch.auto_deploy.distributed.common import initialize_or_skip
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
+from tensorrt_llm._utils import get_free_port
 from tensorrt_llm.llmapi.mpi_session import MpiPoolSession
 
 # needed since MPI executor pool leaks a thread (_manager_spawn) on shutdown
@@ -66,7 +67,7 @@ def _test_allreduce_fusion(port: int, ModuleCls, strategy: str):
     if not is_trtllm_op_available():
         pytest.skip("Require trtllm ops to run test_allreduce_fusion.")
 
-    _, _ = dist.initialize_or_skip(port=port)
+    _, _ = initialize_or_skip(port=port)
 
     # Testing tensors
     dtype = torch.float16
@@ -146,7 +147,7 @@ def _test_allreduce_fusion(port: int, ModuleCls, strategy: str):
 def test_allreduce_fusion(device_count, ModuleCls, strategy):
     if device_count <= 1:
         pytest.skip("Require multi GPUs to run test_allreduce_fusion.")
-    port = dist.get_free_port()
+    port = get_free_port()
 
     n_workers = device_count
     mpi_pool = MpiPoolSession(n_workers=n_workers)
@@ -17,10 +17,9 @@ from tensorrt_llm._torch.autotuner import (AutoTuner, DistributedTuningStrategy,
                                            FakeTensor, OptimizationProfile,
                                            StaticDim, TunableRunner,
                                            TuningConfig, autotune)
-from tensorrt_llm._torch.distributed.communicator import MPIDist, TorchDist
+from tensorrt_llm._torch.distributed import Distributed
 from tensorrt_llm._torch.utils import (get_power_of_2_num_tokens_buckets,
                                        next_positive_power_of_2)
-from tensorrt_llm._utils import mpi_disabled
 from tensorrt_llm.bindings.internal.runtime import delay_kernel
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
@@ -720,14 +719,11 @@ def _distributed_worker_function(world_size, strategy):
                       rank=rank,
                       tp_size=world_size,
                       pp_size=1)
-    if mpi_disabled():
-        dist = TorchDist(mapping=mapping)
-    else:
-        dist = MPIDist(mapping=mapping)
+    dist = Distributed.get(mapping)
 
     tuner = AutoTuner.get()
    tuner.clear_cache()
-    tuner.setup_distributed_state(mapping, dist)
+    tuner.setup_distributed_state(mapping)
 
     x = torch.randn(16, 32, device='cuda')
     w = torch.randn(32, 64, device='cuda')
@@ -24,7 +24,6 @@ from utils.util import (check_accuracy, skip_blackwell, skip_blackwell_geforce,
                         skip_pre_hopper)
 
 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
-from tensorrt_llm._torch.distributed import MPIDist, TorchDist
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import \
     CuteDslFusedMoE
@@ -45,7 +44,7 @@ from tensorrt_llm._torch.modules.fused_moe.quantization import \
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
     IS_TRITON_KERNELS_AVAILABLE
 from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
-from tensorrt_llm._utils import get_sm_version, mpi_disabled, mpi_rank
+from tensorrt_llm._utils import get_sm_version, mpi_rank
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
 
@@ -105,12 +104,7 @@ def test_fused_moe(moe_backend,
 
     mapping = mapping or Mapping()
     mapping.rank = mpi_rank()
-    if mpi_disabled():
-        dist = TorchDist(mapping=mapping)
-    else:
-        dist = MPIDist(mapping=mapping)
-
-    AutoTuner.get().setup_distributed_state(mapping, dist)
+    AutoTuner.get().setup_distributed_state(mapping)
 
     torch.cuda.set_device(mapping.rank)
 
@@ -27,11 +27,9 @@ import tensorrt_llm
 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                              AllReduceParams, AllReduceStrategy,
-                                             MoEAllReduce, MoEAllReduceParams,
-                                             MPIDist, TorchDist)
+                                             MoEAllReduce, MoEAllReduceParams)
 from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
 from tensorrt_llm._torch.modules.rms_norm import RMSNorm
-from tensorrt_llm._utils import mpi_disabled
 from tensorrt_llm.mapping import Mapping
 
 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
@@ -133,12 +131,8 @@ def run_allreduce_op(
         tp_size=tensor_parallel_size,
         rank=tensor_parallel_rank,
     )
-    if mpi_disabled():
-        dist = TorchDist(mapping=mapping)
-    else:
-        dist = MPIDist(mapping=mapping)
 
-    AutoTuner.get().setup_distributed_state(mapping, dist)
+    AutoTuner.get().setup_distributed_state(mapping)
     linear = Linear(
         in_features=hidden_size,
         out_features=hidden_size,
@@ -6,7 +6,7 @@ import torch
 import tensorrt_llm
 import tensorrt_llm.bindings
 import tensorrt_llm.bindings.executor as trtllm
-from tensorrt_llm._torch.distributed import MPIDist
+from tensorrt_llm._torch.distributed import Distributed
 from tensorrt_llm._torch.pyexecutor.kv_cache_transceiver import \
     create_kv_cache_transceiver
 from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
@@ -79,7 +79,7 @@ def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
 
     cache_transceiver_config = CacheTransceiverConfig(backend=backend,
                                                       max_tokens_in_buffer=512)
-    dist = MPIDist(mapping=mapping)
+    dist = Distributed.get(mapping)
     kv_cache_transceiver_ctx = create_kv_cache_transceiver(
         mapping, dist, kv_cache_manager_ctx, attention_type,
         cache_transceiver_config)
@@ -139,7 +139,7 @@ def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
 def test_cancel_request_in_transmission(attention_type):
     # Init kv_cache manager and cache transceiver
     mapping = Mapping(world_size=1, rank=0)
-    dist = MPIDist(mapping=mapping)
+    dist = Distributed.get(mapping)
     ctx_kv_cache_dtype, gen_kv_cache_dtype = DataType.HALF, DataType.HALF
     kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
     kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)