refactor: remove ParallelConfig in tensorrt_llm._torch.distributed module (#3370)

* remove tensorrt_llm._torch.distributed.ParallelConfig

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>

* fix ci

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>

* fix ci

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>

* clean

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>

* fix embedding test

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>

* fix

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>

* fix comments

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>

* polish

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>

* fix ci

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>

* rebase

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>

---------

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
Co-authored-by: hlu1 <14827759+hlu1@users.noreply.github.com>
QI JUN authored on 2025-04-12 06:34:20 +08:00; committed by GitHub.
commit d167cbd5bb (parent cf9ceea890)
25 changed files with 406 additions and 575 deletions
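
Every diff below follows the same migration: call sites that previously built a throwaway ParallelConfig out of model_config.mapping (tp_rank, tp_size, pp_size, gpus_per_node) now pass the shared Mapping object directly, with tensor_parallel_mode and gather_output promoted to explicit keyword arguments on the module. A minimal before/after sketch of the new Linear call pattern; the sizes and the 2-way TP mapping are illustrative, not taken from any file in this commit.

import torch

from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode

# Hypothetical 2-way tensor-parallel job; the real rank comes from the launcher.
mapping = Mapping(world_size=2, tp_size=2, rank=0)

# Old style, removed by this commit:
#   Linear(..., parallel_config=ParallelConfig(tensor_parallel_rank=0,
#                                              tensor_parallel_size=2,
#                                              tensor_parallel_mode=TensorParallelMode.COLUMN))
# New style:
proj = Linear(
    1024,                  # in_features
    4096,                  # out_features (sharded across TP ranks in COLUMN mode)
    bias=False,
    dtype=torch.bfloat16,
    mapping=mapping,
    tensor_parallel_mode=TensorParallelMode.COLUMN,
)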


@ -7,7 +7,6 @@ from transformers import OPTConfig
from transformers.activations import ACT2FN
from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.distributed import ParallelConfig, TensorParallelMode
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
DecoderModelForCausalLM,
@ -16,7 +15,7 @@ from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
class LayerNorm(nn.LayerNorm):
@ -70,10 +69,8 @@ class OPTDecoderLayer(DecoderLayer):
config.ffn_dim,
bias=config.enable_bias,
dtype=config.torch_dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=model_config.mapping.tp_rank,
tensor_parallel_size=model_config.mapping.tp_size,
tensor_parallel_mode=TensorParallelMode.COLUMN),
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=model_config.get_quant_config(),
)
self.fc2 = Linear(
@ -81,10 +78,8 @@ class OPTDecoderLayer(DecoderLayer):
config.hidden_size,
bias=config.enable_bias,
dtype=config.torch_dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=model_config.mapping.tp_rank,
tensor_parallel_size=model_config.mapping.tp_size,
tensor_parallel_mode=TensorParallelMode.ROW),
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=model_config.get_quant_config(),
)
self.final_layer_norm = LayerNorm(
@ -150,10 +145,8 @@ class OPTModel(DecoderModel):
config.vocab_size,
config.word_embed_proj_dim,
dtype=config.torch_dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=model_config.mapping.tp_rank,
tensor_parallel_size=model_config.mapping.tp_size,
tensor_parallel_mode=TensorParallelMode.COLUMN))
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN)
self.embed_positions = nn.Embedding(config.max_position_embeddings + 2,
config.hidden_size,
dtype=config.torch_dtype)


@ -9,7 +9,6 @@ from tensorrt_llm.functional import AttentionMaskType
from tensorrt_llm.models.modeling_utils import QuantConfig
from ..distributed import allgather
from ..modules.linear import ParallelConfig
from .flashinfer import FlashInferAttentionMetadata, PlanParams
from .interface import AttentionBackend, AttentionMask, PredefinedAttentionMask
from .vanilla import VanillaAttention
@ -440,12 +439,10 @@ class StarAttention(AttentionBackend[StarAttentionMetadata]):
out_tmp = output
lse = lse.unsqueeze(-1) / np.log2(np.e) # [b * s, nheads, 1]
if metadata.mapping.cp_size != 1:
parallel_cfg = ParallelConfig(
tensor_parallel_size=metadata.mapping.cp_size,
tensor_parallel_rank=metadata.mapping.cp_rank,
pipeline_parallel_size=metadata.mapping.pp_size)
output_tensor = allgather(output, parallel_cfg, gather_dim=0)
lse_tensor = allgather(lse, parallel_cfg, gather_dim=0)
output_tensor = allgather(output,
metadata.mapping,
gather_dim=0)
lse_tensor = allgather(lse, metadata.mapping, gather_dim=0)
output_tensor = output_tensor.to(torch.float32)
else:
lse_tensor = lse


@ -4,18 +4,19 @@ from .common import ReduceOp, get_rank_world_size, is_ompi
# use trtllm distributed ops to improve TP performance if possible
try:
from ....mapping import Mapping
from ...distributed import AllReduce, allgather
from ...modules.linear import AllReduceFusionOp, AllReduceParams, ParallelConfig
from ...modules.linear import AllReduceFusionOp, AllReduceParams
def trtllm_allgather(tensor, dim):
rank, world_size = get_rank_world_size()
p_config = ParallelConfig(tensor_parallel_size=world_size, tensor_parallel_rank=rank)
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
return allgather(tensor, p_config, gather_dim=dim)
def trtllm_allreduce(tensor, op, all_reduce_params=None):
rank, world_size = get_rank_world_size()
assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
p_config = ParallelConfig(tensor_parallel_size=world_size, tensor_parallel_rank=rank)
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
torch_op = AllReduce(p_config)
return torch_op(tensor, all_reduce_params=all_reduce_params)
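
As a usage note for the helpers above: each call now builds a TP-only Mapping (world_size == tp_size) on the fly from get_rank_world_size(). A hedged sketch, written as if inside this same auto_deploy module on a hypothetical 2-rank MPI job (tensor shape is illustrative):

import torch

local = torch.randn(4, 256, device="cuda")
gathered = trtllm_allgather(local, dim=0)          # [8, 256] once both ranks contribute
summed = trtllm_allreduce(local, op=ReduceOp.SUM)  # element-wise sum across ranks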


@ -1,8 +1,6 @@
import atexit
import enum
import os
import threading
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
@ -17,37 +15,6 @@ from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
_thread_local = threading.local()
class TensorParallelMode(str, enum.Enum):
COLUMN = 'column'
ROW = 'row'
@classmethod
def split_dim(cls, mode):
return 1 if mode == cls.ROW else 0
@dataclass(kw_only=True)
class ParallelConfig:
tensor_parallel_size: int = 1
tensor_parallel_rank: int = 0
gpus_per_node: int = 8
tensor_parallel_mode: Optional[TensorParallelMode] = None
gather_output: bool = False
# pipeline parallel parameter in case we have multiple parallel groups
# default to TP-only mode if not specified for backward compatibility
# TODO Remove redundant fields. Keep only parallel_rank, tp_size, pp_size in constructor
# and infer tp_rank, pp_rank, etc. automatically.
pipeline_parallel_size: int = 1
parallel_rank: Optional[int] = None
def __post_init__(self):
self.parallel_size = self.tensor_parallel_size * self.pipeline_parallel_size
if self.pipeline_parallel_size > 1:
assert self.parallel_rank is not None, "parallel_rank must be specified for PP mode"
else:
self.parallel_rank = self.tensor_parallel_rank
def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
if not hasattr(_thread_local, 'allreduce_workspaces'):
_thread_local.allreduce_workspaces = {}
@ -77,7 +44,7 @@ def get_deepseek_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
def allreduce(
input: torch.Tensor,
workspace: Optional[torch.LongTensor],
parallel_config: ParallelConfig,
mapping: Mapping,
strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
config: AllReduceConfig = AllReduceConfig(0),
all_reduce_params: Optional[AllReduceParams] = None
@ -98,7 +65,7 @@ def allreduce(
Args:
input (Tensor): The input tensor.
parallel_config (ParallelConfig): The parallel config.
mapping (Mapping): The parallel mapping.
strategy (AllReduceStrategy): NCCL delegates all-reduce to NCCL while ONESHOT and TWOSHOT are custom latency-optimal algorithms.
AUTO chooses amongst the three based on a message-size heuristic.
config (AllReduceConfig): The config for custom allreduce kernels.
@ -106,19 +73,10 @@ def allreduce(
Returns:
The reduced tensor and an optional intermediate tensor if fused.
'''
if parallel_config.tensor_parallel_size == 1 or (
all_reduce_params is not None
and all_reduce_params.enable_allreduce == False):
if mapping.tp_size == 1 or (all_reduce_params is not None and
all_reduce_params.enable_allreduce == False):
return input
mapping = Mapping(
world_size=parallel_config.parallel_size,
tp_size=parallel_config.tensor_parallel_size,
pp_size=parallel_config.pipeline_parallel_size,
rank=parallel_config.parallel_rank,
gpus_per_node=parallel_config.gpus_per_node,
)
if all_reduce_params is None:
all_reduce_params = AllReduceParams()
is_fused = all_reduce_params.fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM or \
@ -164,7 +122,7 @@ def userbuffers_allreduce_finalize(
def allgather(input: torch.Tensor,
parallel_config: ParallelConfig,
mapping: Mapping,
gather_dim: int = -1) -> torch.Tensor:
'''
Add an operation that performs a collective all-gather.
@ -184,22 +142,14 @@ def allgather(input: torch.Tensor,
Args:
input (Tensor): The input tensor.
parallel_config (ParallelConfig): The parallel config.
mapping (Mapping): The parallel mapping.
gather_dim (int): Gather along given dimension. By default -1.
Returns:
The gathered tensor.
'''
if parallel_config.tensor_parallel_size == 1:
if mapping.tp_size == 1:
return input
mapping = Mapping(
world_size=parallel_config.parallel_size,
tp_size=parallel_config.tensor_parallel_size,
pp_size=parallel_config.pipeline_parallel_size,
rank=parallel_config.parallel_rank,
gpus_per_node=parallel_config.gpus_per_node,
)
output = torch.ops.trtllm.allgather(
input,
mapping.tp_group,
@ -211,26 +161,17 @@ def allgather(input: torch.Tensor,
output = torch.movedim(output, 0, gather_dim)
input_shape = input.size()
output = output.reshape(input_shape[:gather_dim] +
(parallel_config.tensor_parallel_size *
input_shape[gather_dim], ) +
(mapping.tp_size * input_shape[gather_dim], ) +
input_shape[gather_dim + 1:])
return output
def reducescatter(input: torch.Tensor,
parallel_config: ParallelConfig,
mapping: Mapping,
scatter_dim: int = -1) -> torch.Tensor:
if parallel_config.tensor_parallel_size == 1:
if mapping.tp_size == 1:
return input
mapping = Mapping(
world_size=parallel_config.parallel_size,
tp_size=parallel_config.tensor_parallel_size,
pp_size=parallel_config.pipeline_parallel_size,
rank=parallel_config.parallel_rank,
gpus_per_node=parallel_config.gpus_per_node,
)
output = torch.ops.trtllm.reducescatter(
input,
mapping.tp_group,
@ -242,8 +183,7 @@ def reducescatter(input: torch.Tensor,
output = torch.movedim(output, 0, scatter_dim)
input_shape = input.size()
output = output.reshape(input_shape[:scatter_dim] +
(input_shape[scatter_dim] //
parallel_config.tensor_parallel_size, ) +
(input_shape[scatter_dim] // mapping.tp_size, ) +
input_shape[scatter_dim + 1:])
return output
@ -251,25 +191,14 @@ def reducescatter(input: torch.Tensor,
class AllReduce(nn.Module):
def __init__(self,
parallel_config: ParallelConfig,
mapping: Mapping,
strategy: AllReduceStrategy = AllReduceStrategy.AUTO):
super().__init__()
self.parallel_config = parallel_config
self.tp_size = self.parallel_config.tensor_parallel_size
self.rank = self.parallel_config.parallel_rank
self.gpus_per_node = self.parallel_config.gpus_per_node
self.mapping = mapping
self.workspace = None
self.strategy = strategy
if self.tp_size > 1:
mapping = Mapping(
world_size=self.parallel_config.parallel_size,
tp_size=self.tp_size,
pp_size=self.parallel_config.pipeline_parallel_size,
rank=self.rank,
gpus_per_node=self.gpus_per_node,
)
if self.mapping.tp_size > 1:
if self.strategy != AllReduceStrategy.UB:
self.workspace = get_allreduce_workspace(mapping)
@ -281,7 +210,7 @@ class AllReduce(nn.Module):
) -> torch.Tensor:
output = allreduce(input,
self.workspace,
self.parallel_config,
self.mapping,
all_reduce_params=all_reduce_params,
strategy=self.strategy)
return output
@ -289,22 +218,11 @@ class AllReduce(nn.Module):
class DeepseekAllReduce(nn.Module):
def __init__(self, parallel_config: ParallelConfig):
def __init__(self, mapping: Mapping):
super().__init__()
self.parallel_config = parallel_config
self.tp_size = self.parallel_config.tensor_parallel_size
self.tp_rank = self.parallel_config.tensor_parallel_rank
self.gpus_per_node = self.parallel_config.gpus_per_node
self.rank = self.parallel_config.parallel_rank
self.mapping = mapping
self.workspace = None
if self.tp_size > 1:
mapping = Mapping(
world_size=self.parallel_config.parallel_size,
tp_size=self.tp_size,
pp_size=self.parallel_config.pipeline_parallel_size,
rank=self.rank,
gpus_per_node=self.gpus_per_node,
)
if self.mapping.tp_size > 1:
self.workspace = get_deepseek_allreduce_workspace(mapping)
def forward(
@ -331,8 +249,8 @@ class DeepseekAllReduce(nn.Module):
input=hidden_states,
workspace=self.workspace,
reduce_fusion_inputs=reduce_fusion_inputs,
rank=self.parallel_config.tensor_parallel_rank,
nranks=self.parallel_config.tensor_parallel_size,
rank=self.mapping.tp_rank,
nranks=self.mapping.tp_size,
eps=eps,
fusion_op=fusion_op,
)
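
Taken together, the changes above reduce every collective to "pass the Mapping": the ops read tp_group, tp_size and the workspace from it instead of rebuilding a Mapping from ParallelConfig fields. A minimal sketch of the new call pattern, assuming a 2-way TP job started by the usual MPI launcher (shapes and dtype are illustrative):

import torch

from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.distributed import AllReduce, allgather

# The real rank comes from the launcher; 0 is shown for illustration.
mapping = Mapping(world_size=2, tp_size=2, rank=0)

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

# allgather/reducescatter now take the Mapping directly.
gathered = allgather(x, mapping, gather_dim=0)   # [16, 1024] when tp_size == 2

# AllReduce sizes its workspace from mapping.tp_size at construction time.
all_reduce = AllReduce(mapping)
reduced = all_reduce(x)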


@ -15,7 +15,7 @@ from tensorrt_llm.llmapi.utils import enable_llm_debug
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
DeepseekAllReduce, ParallelConfig, allgather)
DeepseekAllReduce, allgather)
from ..model_config import ModelConfig
from ..models.modeling_utils import MissingLayer, ModelConfig, support_pp
from ..modules.attention import MLA
@ -306,13 +306,8 @@ class Deepseekv3MoE(nn.Module):
overridden_tp_size=shared_tp_size,
is_expert=True)
self.parallel_config = ParallelConfig(
tensor_parallel_rank=model_config.mapping.tp_rank,
tensor_parallel_size=model_config.mapping.tp_size,
gpus_per_node=model_config.mapping.gpus_per_node,
pipeline_parallel_size=model_config.mapping.pp_size,
parallel_rank=model_config.mapping.rank)
self.all_reduce = AllReduce(self.parallel_config)
self.mapping = model_config.mapping
self.all_reduce = AllReduce(self.mapping)
self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
self.event_dict = {
key: torch.cuda.Event()
@ -321,14 +316,14 @@ class Deepseekv3MoE(nn.Module):
def compute_routed_output(self, hidden_states, hidden_states_fp4,
all_rank_num_tokens, min_latency_mode):
if self.use_dp and self.parallel_config.tensor_parallel_size > 1:
if self.use_dp and self.mapping.tp_size > 1:
max_num_token = max(all_rank_num_tokens)
hidden_states = torch.nn.functional.pad(
hidden_states,
(0, 0, 0, max_num_token - hidden_states.shape[0]))
if disable_fp4_allgather():
hidden_states = allgather(hidden_states,
self.parallel_config,
self.mapping,
gather_dim=0)
router_logits = self.gate(hidden_states)
@ -383,7 +378,7 @@ class Deepseekv3MoE(nn.Module):
assert shared_output.size() == routed_output.size(
), f'unmatched tensor shape'
final_hidden_states = shared_output + routed_output
if not self.use_dp and self.parallel_config.tensor_parallel_size > 1:
if not self.use_dp and self.mapping.tp_size > 1:
final_hidden_states = self.all_reduce(
final_hidden_states, all_reduce_params=final_all_reduce_params)
@ -470,14 +465,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.parallel_config = ParallelConfig(
tensor_parallel_rank=model_config.mapping.tp_rank,
tensor_parallel_size=model_config.mapping.tp_size,
gpus_per_node=model_config.mapping.gpus_per_node,
pipeline_parallel_size=model_config.mapping.pp_size,
parallel_rank=model_config.mapping.rank)
self.mapping = model_config.mapping
self.layer_idx = layer_idx
self.all_reduce = AllReduce(self.parallel_config)
self.all_reduce = AllReduce(self.mapping)
self.next_layer_layernorm: RMSNorm = None
self.deepseek_allreduce_disabled = os.environ.get(
@ -486,7 +476,7 @@ class DeepseekV3DecoderLayer(DecoderLayer):
self.deepseek_allreduce_disabled = True
if not self.deepseek_allreduce_disabled:
self.deepseek_allreduce = DeepseekAllReduce(self.parallel_config)
self.deepseek_allreduce = DeepseekAllReduce(self.mapping)
def forward(
self,
@ -511,9 +501,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(enable_allreduce=not (
self.fusion_config.PRE_MOE_FUSION or self.fusion_config.
PRE_MLP_FUSION or self.parallel_config.tensor_parallel_size == 1
or self.enable_attention_dp)),
self.fusion_config.PRE_MOE_FUSION
or self.fusion_config.PRE_MLP_FUSION
or self.mapping.tp_size == 1 or self.enable_attention_dp)),
**kwargs,
)
@ -586,9 +576,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
hidden_states_fp4,
all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
final_all_reduce_params=AllReduceParams(enable_allreduce=not (
self.fusion_config.POST_MOE_FUSION or self.fusion_config.
POST_MLP_FUSION or self.parallel_config.tensor_parallel_size
== 1 or self.enable_attention_dp)),
self.fusion_config.POST_MOE_FUSION
or self.fusion_config.POST_MLP_FUSION
or self.mapping.tp_size == 1 or self.enable_attention_dp)),
min_latency_mode=min_latency_mode,
)
else:
@ -596,9 +586,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
hidden_states,
all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
final_all_reduce_params=AllReduceParams(enable_allreduce=not (
self.fusion_config.POST_MOE_FUSION or self.fusion_config.
POST_MLP_FUSION or self.parallel_config.tensor_parallel_size
== 1 or self.enable_attention_dp)),
self.fusion_config.POST_MOE_FUSION
or self.fusion_config.POST_MLP_FUSION
or self.mapping.tp_size == 1 or self.enable_attention_dp)),
min_latency_mode=min_latency_mode,
)
@ -729,8 +719,8 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(enable_allreduce=not (
self.fusion_config.PRE_MOE_FUSION or self.parallel_config.
tensor_parallel_size == 1 or self.enable_attention_dp)),
self.fusion_config.PRE_MOE_FUSION or self.mapping.tp_size == 1
or self.enable_attention_dp)),
**kwargs,
)
@ -762,8 +752,8 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
hidden_states,
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
final_all_reduce_params=AllReduceParams(enable_allreduce=not (
self.fusion_config.POST_MOE_FUSION or self.parallel_config.
tensor_parallel_size == 1 or self.enable_attention_dp)),
self.fusion_config.POST_MOE_FUSION or self.mapping.tp_size == 1
or self.enable_attention_dp)),
)
if self.fusion_config.POST_MOE_FUSION:


@ -6,8 +6,7 @@ from torch import nn
from transformers import Llama4Config, LlamaConfig
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
AllReduceParams, DeepseekAllReduce,
ParallelConfig, TensorParallelMode)
AllReduceParams, DeepseekAllReduce)
from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
@ -19,7 +18,8 @@ from ..modules.embedding import Embedding
from ..modules.fused_moe import (FusedMoE, Llama4RenormalizeMoeRoutingMethod,
MoEWeightLoadingMode)
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, WeightMode, WeightsLoadingConfig
from ..modules.linear import (Linear, TensorParallelMode, WeightMode,
WeightsLoadingConfig)
from ..modules.rms_norm import RMSNorm
from ..modules.rotary_embedding import RotaryEmbedding
from ..speculative import Eagle3SpecMetadata, SpecMetadata
@ -132,11 +132,8 @@ class Llama4MoE(nn.Module):
dtype=config.torch_dtype,
quant_config=None)
self.parallel_config = ParallelConfig(
tensor_parallel_rank=model_config.mapping.tp_rank,
tensor_parallel_size=model_config.mapping.tp_size,
gpus_per_node=model_config.mapping.gpus_per_node)
self.all_reduce = AllReduce(self.parallel_config)
self.mapping = model_config.mapping
self.all_reduce = AllReduce(self.mapping)
self.moe_event = [torch.cuda.Event(), torch.cuda.Event()]
self.aux_stream = aux_stream
@ -177,7 +174,7 @@ class Llama4MoE(nn.Module):
assert shared_output.size() == routed_output.size(
), f'unmatched tensor shape'
final_hidden_states = shared_output + routed_output
if self.parallel_config.tensor_parallel_size > 1:
if self.mapping.tp_size > 1:
final_hidden_states = self.all_reduce(
final_hidden_states, all_reduce_params=final_all_reduce_params)
@ -268,14 +265,11 @@ class LlamaDecoderLayer(DecoderLayer):
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.parallel_config = ParallelConfig(
tensor_parallel_rank=model_config.mapping.tp_rank,
tensor_parallel_size=model_config.mapping.tp_size,
gpus_per_node=model_config.mapping.gpus_per_node)
self.all_reduce = AllReduce(self.parallel_config)
self.mapping = model_config.mapping
self.all_reduce = AllReduce(self.mapping)
self.next_layer_layernorm: RMSNorm = None
self.deepseek_allreduce = DeepseekAllReduce(self.parallel_config)
self.deepseek_allreduce = DeepseekAllReduce(self.mapping)
def forward(
self,
@ -307,9 +301,9 @@ class LlamaDecoderLayer(DecoderLayer):
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(enable_allreduce=not (
self.fusion_config.PRE_MOE_FUSION
or self.parallel_config.tensor_parallel_size == 1)),
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.fusion_config.PRE_MOE_FUSION
or self.mapping.tp_size == 1)),
**kwargs,
)
@ -332,9 +326,8 @@ class LlamaDecoderLayer(DecoderLayer):
hidden_states,
all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
final_all_reduce_params=AllReduceParams(enable_allreduce=not (
self.fusion_config.POST_MOE_FUSION
or self.fusion_config.POST_MLP_FUSION
or self.parallel_config.tensor_parallel_size == 1)),
self.fusion_config.POST_MOE_FUSION or self.fusion_config.
POST_MLP_FUSION or self.mapping.tp_size == 1)),
min_latency_mode=min_latency_mode,
)
if spec_metadata is not None:
@ -388,8 +381,6 @@ class Eagle3LlamaAttention(LlamaAttention):
config = model_config.pretrained_config
tp_size = model_config.mapping.tp_size
tp_rank = model_config.mapping.tp_rank
gpus_per_node = model_config.mapping.gpus_per_node
# Override the QKV projection. The number of input features
# is twice as big for EAGLE3 draft models.
@ -398,13 +389,8 @@ class Eagle3LlamaAttention(LlamaAttention):
tp_size * self.q_size + 2 * tp_size * self.kv_size,
bias=config.attention_bias,
dtype=config.torch_dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tp_size,
tensor_parallel_rank=tp_rank,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gpus_per_node=gpus_per_node,
pipeline_parallel_size=model_config.mapping.pp_size,
parallel_rank=model_config.mapping.rank),
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
weights_loading_config=WeightsLoadingConfig(
weight_mode=WeightMode.FUSED_QKV_LINEAR),
quant_config=model_config.get_quant_config(),
@ -486,15 +472,9 @@ class LlamaModel(DecoderModel):
config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=model_config.mapping.tp_rank,
tensor_parallel_size=model_config.mapping.tp_size,
tensor_parallel_mode=TensorParallelMode.COLUMN,
pipeline_parallel_size=model_config.mapping.pp_size,
parallel_rank=model_config.mapping.rank,
gather_output=True,
gpus_per_node=model_config.mapping.gpus_per_node,
),
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(


@ -8,7 +8,6 @@ from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..distributed import ParallelConfig
from ..model_config import ModelConfig
from ..models.modeling_utils import ModelConfig
from ..modules.attention import Attention
@ -133,11 +132,7 @@ class MixtralDecoderLayer(DecoderLayer):
self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
# TODO: add pipeline parallel config
self.parallel_config = ParallelConfig(
tensor_parallel_rank=model_config.mapping.tp_rank,
tensor_parallel_size=model_config.mapping.tp_size,
gpus_per_node=model_config.mapping.gpus_per_node)
self.mapping = model_config.mapping
self.layer_idx = layer_idx
def forward(


@ -26,11 +26,11 @@ from tqdm import tqdm
from ...logger import logger
from ..attention_backend.interface import AttentionMetadata
from ..distributed import ParallelConfig, TensorParallelMode
from ..model_config import ModelConfig
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding, LMHead
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import TensorParallelMode
from ..modules.logits_procesor import LogitsProcessor
from .modeling_llama import LlamaAttention
from .modeling_utils import duplicate_kv_weight, register_auto_model
@ -224,13 +224,9 @@ class MllamaForCausalLM(nn.Module):
text_config.vocab_size,
text_config.hidden_size,
dtype=text_config.torch_dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=config.mapping.tp_rank,
tensor_parallel_size=config.mapping.tp_size,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
gpus_per_node=config.mapping.gpus_per_node,
),
mapping=config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)
# use embedding weights in lm_head if tie word embedding is enabled
if text_config.tie_word_embeddings:


@ -4,8 +4,6 @@ import torch
from torch import nn
from transformers import PretrainedConfig
from tensorrt_llm._torch.distributed import ParallelConfig, TensorParallelMode
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
@ -15,6 +13,7 @@ from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, TensorParallelMode
from ..modules.rms_norm import RMSNorm
from ..modules.rotary_embedding import RotaryEmbedding
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@ -55,15 +54,9 @@ def _create_linear_from_configs(model_config: ModelConfig[PretrainedConfig],
config.hidden_size,
bias=False,
dtype=config.torch_dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=model_config.mapping.tp_rank,
tensor_parallel_size=model_config.mapping.tp_size,
tensor_parallel_mode=TensorParallelMode.COLUMN,
pipeline_parallel_size=model_config.mapping.pp_size,
parallel_rank=model_config.mapping.rank,
gather_output=True,
gpus_per_node=model_config.mapping.gpus_per_node,
),
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
quant_config=model_config.get_quant_config(),
skip_create_weights=model_config.skip_create_weights,
)


@ -8,13 +8,12 @@ from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..distributed import ParallelConfig, TensorParallelMode
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear
from ..modules.linear import Linear, TensorParallelMode
from ..modules.rms_norm import RMSNorm
from ..modules.rotary_embedding import RotaryEmbedding
from ..pipeline_interface import PipelineInterface
@ -147,13 +146,9 @@ class QwenModel(DecoderModel):
config.pretrained_config.vocab_size,
config.pretrained_config.hidden_size,
dtype=config.pretrained_config.torch_dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=config.mapping.tp_rank,
tensor_parallel_size=config.mapping.tp_size,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
gpus_per_node=config.mapping.gpus_per_node,
),
mapping=config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)
self.layers = nn.ModuleList([
QwenDecoderLayer(


@ -10,14 +10,13 @@ from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..distributed import ParallelConfig, TensorParallelMode
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import DefaultMoeRoutingMethod, FusedMoE
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear
from ..modules.linear import Linear, TensorParallelMode
from ..modules.rms_norm import RMSNorm
from ..modules.rotary_embedding import RotaryEmbedding
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@ -175,11 +174,6 @@ class QwenMoeDecoderLayer(DecoderLayer):
self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
# TODO: add pipeline parallel config
self.parallel_config = ParallelConfig(
tensor_parallel_rank=model_config.mapping.tp_rank,
tensor_parallel_size=model_config.mapping.tp_size,
gpus_per_node=model_config.mapping.gpus_per_node)
self.layer_idx = layer_idx
def forward(
@ -224,13 +218,9 @@ class QwenMoeModel(DecoderModel):
config.pretrained_config.vocab_size,
config.pretrained_config.hidden_size,
dtype=config.pretrained_config.torch_dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=config.mapping.tp_rank,
tensor_parallel_size=config.mapping.tp_size,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
gpus_per_node=config.mapping.gpus_per_node,
),
mapping=config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)
self.layers = nn.ModuleList([
QwenMoeDecoderLayer(


@ -11,14 +11,15 @@ from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_any_only
from tqdm import tqdm
from tensorrt_llm.mapping import Mapping
from ...logger import logger
from ..attention_backend import AttentionMetadata
from ..distributed import ParallelConfig, TensorParallelMode
from ..model_config import ModelConfig, TConfig
from ..modules.attention import Attention
from ..modules.embedding import Embedding, LMHead
from ..modules.fused_moe import FusedMoE
from ..modules.linear import Linear, WeightMode
from ..modules.linear import Linear, TensorParallelMode, WeightMode
from ..modules.logits_procesor import LogitsProcessor
from ..modules.rms_norm import RMSNorm
from ..pipeline_interface import PipelineInterface
@ -356,14 +357,13 @@ class DecoderModelForCausalLM(nn.Module,
vocab_size,
hidden_size,
dtype=config.pretrained_config.torch_dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=0,
tensor_parallel_size=1,
tensor_parallel_mode=None,
gather_output=False,
pipeline_parallel_size=config.mapping.pp_size,
parallel_rank=config.mapping.rank,
mapping=Mapping(
world_size=1,
tp_size=1,
rank=0,
),
tensor_parallel_mode=None,
gather_output=False,
)
else:
# TODO(zhenhuanc): Currently lm_head Linear will not accept QuantConfig
@ -372,15 +372,9 @@ class DecoderModelForCausalLM(nn.Module,
vocab_size,
hidden_size,
dtype=config.pretrained_config.torch_dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=config.mapping.tp_size,
tensor_parallel_rank=config.mapping.tp_rank,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
gpus_per_node=config.mapping.gpus_per_node,
pipeline_parallel_size=config.mapping.pp_size,
parallel_rank=config.mapping.rank,
),
mapping=config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)
# use embedding weights in lm_head if tie word embedding is enabled


@ -4,15 +4,17 @@ from typing import Optional
import torch
from torch import nn
from tensorrt_llm.mapping import Mapping
from ..attention_backend import (AttentionInputType, AttentionMetadata,
TrtllmAttention)
from ..attention_backend.interface import (PositionalEmbeddingParams,
PredefinedAttentionMask)
from ..attention_backend.utils import create_attention
from ..distributed import AllReduceParams, ParallelConfig, TensorParallelMode
from ..distributed import AllReduceParams
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from .linear import Linear, WeightMode, WeightsLoadingConfig
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
from .rms_norm import RMSNorm
from .rotary_embedding import RotaryEmbedding
@ -80,12 +82,18 @@ class Attention(nn.Module):
# tensor parallel
config = config or ModelConfig()
tp_size = config.mapping.tp_size
tp_rank = config.mapping.tp_rank
gpus_per_node = config.mapping.gpus_per_node
pp_size = config.mapping.pp_size
if config.mapping.enable_attention_dp:
tp_size = 1
tp_rank = 0
mapping = Mapping(
world_size=tp_size * pp_size,
tp_size=tp_size,
pp_size=pp_size,
rank=config.mapping.rank,
gpus_per_node=config.mapping.gpus_per_node,
enable_attention_dp=config.mapping.enable_attention_dp,
)
assert self.num_heads % tp_size == 0
self.num_heads = self.num_heads // tp_size
self.num_key_value_heads = (self.num_key_value_heads + tp_size -
@ -103,13 +111,8 @@ class Attention(nn.Module):
tp_size * self.q_size + 2 * tp_size * self.kv_size,
bias=bias,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tp_size,
tensor_parallel_rank=tp_rank,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gpus_per_node=gpus_per_node,
pipeline_parallel_size=config.mapping.pp_size,
parallel_rank=config.mapping.rank),
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
weights_loading_config=WeightsLoadingConfig(
weight_mode=WeightMode.FUSED_QKV_LINEAR),
quant_config=config.get_quant_config(),
@ -120,13 +123,8 @@ class Attention(nn.Module):
self.hidden_size,
bias=self.dense_bias,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tp_size,
tensor_parallel_rank=tp_rank,
tensor_parallel_mode=TensorParallelMode.ROW,
gpus_per_node=gpus_per_node,
pipeline_parallel_size=config.mapping.pp_size,
parallel_rank=config.mapping.rank),
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=config.get_quant_config(),
skip_create_weights=config.skip_create_weights,
)
@ -310,27 +308,17 @@ class MLA(nn.Module):
# tensor parallel
config = config or ModelConfig()
tp_size = config.mapping.tp_size
tp_rank = config.mapping.tp_rank
gpus_per_node = config.mapping.gpus_per_node
pp_size = config.mapping.pp_size
if config.mapping.enable_attention_dp:
tp_size = 1
tp_rank = 0
row_parallel_config = ParallelConfig(
tensor_parallel_rank=tp_rank,
tensor_parallel_size=tp_size,
tensor_parallel_mode=TensorParallelMode.ROW,
gpus_per_node=gpus_per_node,
pipeline_parallel_size=config.mapping.pp_size,
parallel_rank=config.mapping.rank,
)
col_parallel_config = ParallelConfig(
tensor_parallel_rank=tp_rank,
tensor_parallel_size=tp_size,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gpus_per_node=gpus_per_node,
pipeline_parallel_size=config.mapping.pp_size,
parallel_rank=config.mapping.rank,
mapping = Mapping(
world_size=tp_size * pp_size,
tp_size=tp_size,
pp_size=pp_size,
rank=config.mapping.rank,
gpus_per_node=config.mapping.gpus_per_node,
enable_attention_dp=config.mapping.enable_attention_dp,
)
assert self.num_heads % tp_size == 0
@ -362,7 +350,8 @@ class MLA(nn.Module):
(self.qk_nope_head_dim + self.qk_rope_head_dim),
bias=bias,
dtype=dtype,
parallel_config=col_parallel_config,
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=quant_config,
skip_create_weights=config.skip_create_weights)
else:
@ -381,7 +370,8 @@ class MLA(nn.Module):
(self.qk_nope_head_dim + self.qk_rope_head_dim),
bias=bias,
dtype=dtype,
parallel_config=col_parallel_config,
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=quant_config,
skip_create_weights=config.skip_create_weights)
self.q_b_proj = self.q_proj
@ -400,7 +390,8 @@ class MLA(nn.Module):
(self.qk_nope_head_dim + self.v_head_dim),
bias=bias,
dtype=dtype,
parallel_config=col_parallel_config,
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=quant_config,
skip_create_weights=config.skip_create_weights)
# This parameter will view into self.kv_b_proj.weight after loading weights.
@ -455,7 +446,8 @@ class MLA(nn.Module):
self.hidden_size,
bias=self.dense_bias,
dtype=dtype,
parallel_config=row_parallel_config,
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=quant_config,
skip_create_weights=config.skip_create_weights,
)
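
The Attention and MLA constructors above no longer hand config.mapping straight to their Linear layers: when attention data parallelism is enabled they derive a tp_size=1 Mapping of their own, which is also why this commit relaxes the Mapping.rank bound check for enable_attention_dp (see the rank setter change later in this commit). A hedged sketch of that derivation with hypothetical values:

from tensorrt_llm.mapping import Mapping

# Hypothetical 4-way TP, 1-way PP job with attention DP enabled.
cfg_mapping = Mapping(world_size=4, tp_size=4, rank=3, enable_attention_dp=True)

tp_size = 1 if cfg_mapping.enable_attention_dp else cfg_mapping.tp_size
attn_mapping = Mapping(
    world_size=tp_size * cfg_mapping.pp_size,  # 1: the projections run without TP
    tp_size=tp_size,
    pp_size=cfg_mapping.pp_size,
    rank=cfg_mapping.rank,                     # 3, only valid because the rank check is skipped
    gpus_per_node=cfg_mapping.gpus_per_node,
    enable_attention_dp=cfg_mapping.enable_attention_dp,
)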


@ -6,9 +6,10 @@ import torch.nn.functional as F
from torch.nn.parameter import Parameter
from tensorrt_llm.functional import AllReduceParams
from tensorrt_llm.mapping import Mapping
from ..distributed import ParallelConfig, TensorParallelMode, allgather
from .linear import Linear
from ..distributed import allgather
from .linear import Linear, TensorParallelMode
class LMHead(Linear):
@ -18,7 +19,7 @@ class LMHead(Linear):
num_embeddings (int): vocabulary size.
embedding_dim (int): size of hidden state.
dtype (Optional[torch.dtype]): type of the parameters.
parallel_config (Optional[ParallelConfig]): parallelism configuration.
mapping (Optional[Mapping]): parallelism configuration.
If not provided, the embedding is not parallelized.
"""
@ -27,17 +28,19 @@ class LMHead(Linear):
num_embeddings: int,
embedding_dim: int,
dtype: torch.dtype = None,
parallel_config: Optional[ParallelConfig] = None,
mapping: Optional[Mapping] = None,
tensor_parallel_mode: Optional[TensorParallelMode] = None,
gather_output: bool = False,
):
local_in_features = embedding_dim
local_out_features = num_embeddings
parallel_config = parallel_config or ParallelConfig()
tp_size = parallel_config.tensor_parallel_size
mapping = mapping or Mapping()
tp_size = mapping.tp_size
if parallel_config.tensor_parallel_mode == TensorParallelMode.ROW:
if tensor_parallel_mode == TensorParallelMode.ROW:
local_in_features = math.ceil(embedding_dim / tp_size)
self.padding_size = tp_size * local_in_features - embedding_dim
elif parallel_config.tensor_parallel_mode == TensorParallelMode.COLUMN:
elif tensor_parallel_mode == TensorParallelMode.COLUMN:
local_out_features = math.ceil(num_embeddings / tp_size)
self.padding_size = tp_size * local_out_features - num_embeddings
@ -46,10 +49,12 @@ class LMHead(Linear):
local_out_features * tp_size,
bias=False,
dtype=dtype,
parallel_config=parallel_config,
mapping=mapping,
tensor_parallel_mode=tensor_parallel_mode,
gather_output=gather_output,
)
if parallel_config.tensor_parallel_mode == TensorParallelMode.ROW:
if tensor_parallel_mode == TensorParallelMode.ROW:
if self.tp_rank == self.tp_size - 1:
local_in_features -= self.padding_size
self.in_features = local_in_features
@ -61,7 +66,7 @@ class LMHead(Linear):
@property
def vocab_size_padded(self) -> int:
if self.parallel_config.tensor_parallel_mode == TensorParallelMode.COLUMN:
if self.tp_mode == TensorParallelMode.COLUMN:
return self.out_features * self.tp_size
else:
return self.out_features
@ -73,8 +78,7 @@ class LMHead(Linear):
all_reduce_params: Optional[AllReduceParams] = None
) -> torch.Tensor:
output = super().forward(input, all_reduce_params=all_reduce_params)
if (self.tp_mode == TensorParallelMode.COLUMN
and self.parallel_config.gather_output
if (self.tp_mode == TensorParallelMode.COLUMN and self.gather_output
and self.padding_size > 0):
output = output[..., :-self.padding_size]
@ -82,7 +86,7 @@ class LMHead(Linear):
def load_weights(self, weights: List[Dict]):
original_weight = None
if self.parallel_config.tensor_parallel_mode == TensorParallelMode.COLUMN:
if self.tp_mode == TensorParallelMode.COLUMN:
if self.tp_rank == self.tp_size - 1 and self.padding_size > 0:
original_weight = self.weight.data.zero_()
self.weight.data = self.weight[:-self.padding_size, :]
@ -113,7 +117,7 @@ class Embedding(LMHead):
num_embeddings (int): vocabulary size.
embedding_dim (int): size of hidden state.
dtype (Optional[torch.dtype]): type of the parameters.
parallel_config (Optional[ParallelConfig]): parallelism configuration.
mapping (Optional[Mapping]): parallelism configuration.
If not provided, the embedding is not parallelized.
"""
@ -122,13 +126,17 @@ class Embedding(LMHead):
num_embeddings: int,
embedding_dim: int,
dtype: Optional[torch.dtype] = None,
parallel_config: Optional[ParallelConfig] = None,
mapping: Optional[Mapping] = None,
tensor_parallel_mode: Optional[TensorParallelMode] = None,
gather_output: bool = False,
):
super().__init__(
embedding_dim=embedding_dim,
num_embeddings=num_embeddings,
dtype=dtype,
parallel_config=parallel_config,
mapping=mapping,
tensor_parallel_mode=tensor_parallel_mode,
gather_output=gather_output,
)
if self.tp_size > 1:
slice_width = math.ceil(num_embeddings / self.tp_size)
@ -138,7 +146,7 @@ class Embedding(LMHead):
def forward(self, input):
if self.tp_size > 1:
if self.parallel_config.tensor_parallel_mode == TensorParallelMode.COLUMN:
if self.tp_mode == TensorParallelMode.COLUMN:
# Build the mask.
input, input_mask = get_masked_input_and_mask(
input,
@ -149,15 +157,15 @@ class Embedding(LMHead):
output = F.embedding(input, self.weight)
# Mask the output embedding.
if self.tp_size > 1:
if self.parallel_config.tensor_parallel_mode == TensorParallelMode.COLUMN:
if self.tp_mode == TensorParallelMode.COLUMN:
output.masked_fill_(input_mask, 0)
# Reduce across all the model parallel GPUs.
output = self.all_reduce(output)
elif self.parallel_config.tensor_parallel_mode == TensorParallelMode.ROW:
if self.parallel_config.gather_output:
elif self.tp_mode == TensorParallelMode.ROW:
if self.gather_output:
if self.tp_rank == self.tp_size - 1 and self.padding_size > 0:
output = F.pad(output, (0, self.padding_size))
output = allgather(output, self.parallel_config)
output = allgather(output, self.mapping)
if self.padding_size > 0:
output = output[..., :-self.padding_size]
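
A construction sketch matching the updated Embedding/LMHead signatures above (the vocabulary and hidden sizes and the 2-way TP mapping are illustrative):

import torch

from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.modules.embedding import Embedding, LMHead
from tensorrt_llm._torch.modules.linear import TensorParallelMode

mapping = Mapping(world_size=2, tp_size=2, rank=0)

# Vocabulary-parallel embedding: COLUMN mode shards num_embeddings across TP ranks.
embed = Embedding(
    num_embeddings=32000,
    embedding_dim=4096,
    dtype=torch.bfloat16,
    mapping=mapping,
    tensor_parallel_mode=TensorParallelMode.COLUMN,
)

# LM head that gathers the sharded logits back on every rank.
lm_head = LMHead(
    num_embeddings=32000,
    embedding_dim=4096,
    dtype=torch.bfloat16,
    mapping=mapping,
    tensor_parallel_mode=TensorParallelMode.COLUMN,
    gather_output=True,
)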


@ -10,7 +10,7 @@ from ..distributed import allgather, reducescatter
from ..model_config import ModelConfig
from ..utils import (EventType, Fp4QuantizedTensor, disable_fp4_allgather,
reswizzle_sf)
from .linear import ParallelConfig, TensorParallelMode, load_weight_shard
from .linear import TensorParallelMode, load_weight_shard
# The declarations aligns with moe_kernels.h
# pack inputs into int64, e.g. 4 x bf16 input values
@ -270,15 +270,10 @@ class FusedMoE(nn.Module):
self.use_dp = model_config.mapping.enable_attention_dp
# All ranks participate in allreduce regardless of EP/TP combination
self.parallel_config = ParallelConfig(
tensor_parallel_rank=model_config.mapping.tp_rank,
tensor_parallel_size=model_config.mapping.tp_size,
gpus_per_node=model_config.mapping.gpus_per_node,
pipeline_parallel_size=model_config.mapping.pp_size,
parallel_rank=model_config.mapping.rank)
self.parallel_size = self.parallel_config.tensor_parallel_size
self.mapping = model_config.mapping
self.parallel_size = self.mapping.tp_size
self.all_reduce = AllReduce(self.parallel_config)
self.all_reduce = AllReduce(self.mapping)
self.intermediate_size_per_partition = intermediate_size // self.tp_size
@ -510,7 +505,7 @@ class FusedMoE(nn.Module):
flatten_outputs = allgather(
torch.cat(flatten_inputs),
self.parallel_config,
self.mapping,
gather_dim=0,
).view(self.parallel_size, -1)
@ -532,9 +527,7 @@ class FusedMoE(nn.Module):
outputs = inputs
if self.parallel_size > 1:
if self.use_dp:
outputs = reducescatter(inputs,
self.parallel_config,
scatter_dim=0)
outputs = reducescatter(inputs, self.mapping, scatter_dim=0)
elif self.reduce_results:
outputs = self.all_reduce(inputs)
return outputs
@ -595,7 +588,7 @@ class FusedMoE(nn.Module):
):
x_sf, token_selected_experts, token_final_scales = self.all_gather(
[x_sf, token_selected_experts, token_final_scales])
x = allgather(x, self.parallel_config, gather_dim=0)
x = allgather(x, self.mapping, gather_dim=0)
token_selected_experts = token_selected_experts.flatten(
0, 1).contiguous()
token_final_scales = token_final_scales.flatten(0, 1).contiguous()
@ -701,7 +694,7 @@ class FusedMoE(nn.Module):
self.event_dict[EventType.MoeChunkingOverlap].wait()
outputs = torch.cat(outputs_list)
if self.use_dp:
rank = self.parallel_config.tensor_parallel_rank
rank = self.mapping.tp_rank
outputs = outputs[:all_rank_num_tokens[rank]]
return outputs


@ -5,13 +5,14 @@ import torch
import torch.nn.functional as F
from torch import nn
from tensorrt_llm._torch.peft.lora.layer import LoraLayer, LoraModuleType
from tensorrt_llm.mapping import Mapping
from ..custom_ops import IS_FLASHINFER_AVAIABLE
from ..distributed import AllReduceParams, ParallelConfig, TensorParallelMode
from ..distributed import AllReduceParams
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import Fp4QuantizedTensor
from .linear import Linear, WeightMode, WeightsLoadingConfig
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
def swiglu(x):
@ -44,30 +45,31 @@ class GatedMLP(nn.Module):
self.activation = activation
config = config or ModelConfig()
self.mapping = config.mapping
if overridden_tp_size is not None:
assert config.mapping.tp_size % overridden_tp_size == 0
tp_rank = config.mapping.tp_rank % overridden_tp_size
tp_size = overridden_tp_size
# "Misuse" pp_size here to perform all-reduce within smaller groups
pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
else:
tp_rank = config.mapping.tp_rank
tp_size = config.mapping.tp_size
pp_size = config.mapping.pp_size
gpus_per_node = config.mapping.gpus_per_node
mapping = Mapping(
world_size=tp_size * pp_size,
rank=self.mapping.rank,
gpus_per_node=self.mapping.gpus_per_node,
tp_size=tp_size,
pp_size=pp_size,
)
self.gate_up_proj = Linear(
self.hidden_size,
self.intermediate_size * 2,
bias=bias,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=tp_rank,
tensor_parallel_size=tp_size,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gpus_per_node=gpus_per_node,
pipeline_parallel_size=pp_size,
parallel_rank=config.mapping.rank),
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
weights_loading_config=WeightsLoadingConfig(
weight_mode=WeightMode.FUSED_GATE_UP_LINEAR),
quant_config=config.get_quant_config(),
@ -79,13 +81,8 @@ class GatedMLP(nn.Module):
self.hidden_size,
bias=bias,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=tp_rank,
tensor_parallel_size=tp_size,
tensor_parallel_mode=TensorParallelMode.ROW,
gpus_per_node=gpus_per_node,
pipeline_parallel_size=pp_size,
parallel_rank=config.mapping.rank),
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=config.get_quant_config(),
is_expert=is_expert,
skip_create_weights=config.skip_create_weights,
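
The overridden_tp_size branch above is the one subtle Mapping construction in this commit: it "misuses" pp_size so that world_size still covers every rank while each TP group (and therefore each all-reduce) spans only the overridden number of ranks, e.g. for DeepSeek's shared experts. An illustration with hypothetical values:

from tensorrt_llm.mapping import Mapping

# 8-way TP job whose shared-expert MLP should only be 2-way tensor parallel.
full = Mapping(world_size=8, tp_size=8, rank=5)
overridden_tp_size = 2

tp_size = overridden_tp_size
pp_size = full.pp_size * full.tp_size // overridden_tp_size  # 4 groups of 2 ranks

derived = Mapping(
    world_size=tp_size * pp_size,   # still 8, so every rank is accounted for
    rank=full.rank,
    gpus_per_node=full.gpus_per_node,
    tp_size=tp_size,
    pp_size=pp_size,
)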


@ -10,9 +10,9 @@ from torch.nn.parameter import Parameter
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams
from tensorrt_llm.mapping import Mapping
from ...models.modeling_utils import QuantConfig
from ..distributed import ParallelConfig, TensorParallelMode
from ..utils import Fp4QuantizedTensor
E2M1_MAX = 6.0
@ -33,6 +33,15 @@ class WeightsLoadingConfig:
ignore_tensor_parallel: bool = False
class TensorParallelMode(str, enum.Enum):
COLUMN = 'column'
ROW = 'row'
@classmethod
def split_dim(cls, mode):
return 1 if mode == cls.ROW else 0
def load_weight_shard(
weight,
tensor_parallel_size: int = 1,
@ -135,7 +144,9 @@ class Linear(nn.Module):
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
parallel_config: Optional[ParallelConfig] = None,
mapping: Optional[Mapping] = None,
tensor_parallel_mode: Optional[TensorParallelMode] = None,
gather_output: bool = False,
quant_config: Optional[QuantConfig] = None,
weights_loading_config: Optional[WeightsLoadingConfig] = None,
is_expert: bool = False,
@ -147,38 +158,37 @@ class Linear(nn.Module):
super().__init__()
self.has_bias = bias
self.dtype = dtype
self.parallel_config = parallel_config or ParallelConfig()
self.mapping = mapping or Mapping()
# could be modified later
self.quant_config = quant_config
self.weights_loading_config = weights_loading_config or WeightsLoadingConfig(
)
self.tp_size = self.parallel_config.tensor_parallel_size
self.tp_rank = self.parallel_config.tensor_parallel_rank
self.tp_mode = self.parallel_config.tensor_parallel_mode
self.tp_size = self.mapping.tp_size
self.tp_rank = self.mapping.tp_rank
self.tp_mode = tensor_parallel_mode
self.gather_output = gather_output
local_in_features = in_features
local_out_features = out_features
if self.parallel_config.tensor_parallel_mode == TensorParallelMode.ROW:
if self.tp_mode == TensorParallelMode.ROW:
assert in_features % self.tp_size == 0, (
f'in_features {in_features} must be divisible by tp_size {self.tp_size}'
)
local_in_features = in_features // self.tp_size
elif self.parallel_config.tensor_parallel_mode == TensorParallelMode.COLUMN:
elif self.tp_mode == TensorParallelMode.COLUMN:
assert out_features % self.tp_size == 0, (
f'out_features {out_features} must be divisible by tp_size {self.tp_size}'
)
local_out_features = out_features // self.tp_size
else:
assert self.parallel_config.tensor_parallel_mode is None, (
'unsupported tensor parallel mode: {self.parallel_config.tensor_parallel_mode}'
)
assert self.tp_mode is None, (
'unsupported tensor parallel mode: {self.tp_mode}')
self.in_features = local_in_features
self.out_features = local_out_features
self.all_reduce = AllReduce(
self.parallel_config) if not is_expert else None
self.all_reduce = AllReduce(self.mapping) if not is_expert else None
self._weights_created = False
self.is_expert = is_expert
self.use_custom_cublas_mm = use_custom_cublas_mm
@ -385,8 +395,8 @@ class Linear(nn.Module):
output = self.apply_linear(input, self.weight, bias)
elif self.tp_mode == TensorParallelMode.COLUMN:
output = self.apply_linear(input, self.weight, self.bias)
if self.parallel_config.gather_output:
output = allgather(output, self.parallel_config)
if self.gather_output:
output = allgather(output, self.mapping)
else:
output = self.apply_linear(input, self.weight, self.bias)


@ -5,9 +5,8 @@ import torch
from torch import nn
from ..attention_backend import AttentionMetadata
from ..distributed import ParallelConfig, TensorParallelMode
from ..model_config import ModelConfig
from .linear import Linear
from .linear import Linear, TensorParallelMode
@dataclass
@ -166,9 +165,9 @@ class Mamba2(nn.Module):
super().__init__()
config = config or ModelConfig()
self.mapping = config.mapping
tp_rank = config.mapping.tp_rank
tp_size = config.mapping.tp_size
gpus_per_node = config.mapping.gpus_per_node
d_inner = d_model * expand
nheads = d_inner // head_dim
@ -205,11 +204,8 @@ class Mamba2(nn.Module):
d_in_proj,
bias=bias,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=tp_rank,
tensor_parallel_size=tp_size,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gpus_per_node=gpus_per_node),
mapping=self.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=config.get_quant_config(),
)
@ -219,11 +215,8 @@ class Mamba2(nn.Module):
conv_dim,
bias=conv_bias,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=tp_rank,
tensor_parallel_size=tp_size,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gpus_per_node=gpus_per_node),
mapping=self.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=config.get_quant_config(),
skip_create_weights=config.skip_create_weights,
)
@ -258,11 +251,8 @@ class Mamba2(nn.Module):
d_model,
bias=bias,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=tp_rank,
tensor_parallel_size=tp_size,
tensor_parallel_mode=TensorParallelMode.ROW,
gpus_per_node=gpus_per_node),
mapping=self.mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=config.get_quant_config())
def forward(


@ -4,10 +4,9 @@ from typing import Optional
import torch
from torch import nn
from ..distributed import ParallelConfig, TensorParallelMode
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from .linear import Linear, WeightMode, WeightsLoadingConfig
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
class MLP(nn.Module):
@ -29,40 +28,25 @@ class MLP(nn.Module):
self.activation = activation
config = config or ModelConfig()
tp_rank = config.mapping.tp_rank
tp_size = config.mapping.tp_size
gpus_per_node = config.mapping.gpus_per_node
self.up_proj = Linear(
self.hidden_size,
self.intermediate_size,
bias=bias,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=tp_rank,
tensor_parallel_size=tp_size,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gpus_per_node=gpus_per_node,
pipeline_parallel_size=config.mapping.pp_size,
parallel_rank=config.mapping.rank),
weights_loading_config=WeightsLoadingConfig(
weight_mode=WeightMode.VANILLA),
quant_config=config.get_quant_config(),
skip_create_weights=config.skip_create_weights)
self.up_proj = Linear(self.hidden_size,
self.intermediate_size,
bias=bias,
dtype=dtype,
mapping=config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
weights_loading_config=WeightsLoadingConfig(
weight_mode=WeightMode.VANILLA),
quant_config=config.get_quant_config(),
skip_create_weights=config.skip_create_weights)
self.down_proj = Linear(
self.intermediate_size,
self.hidden_size,
bias=bias,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_rank=tp_rank,
tensor_parallel_size=tp_size,
tensor_parallel_mode=TensorParallelMode.ROW,
gpus_per_node=gpus_per_node,
pipeline_parallel_size=config.mapping.pp_size,
parallel_rank=config.mapping.rank),
quant_config=config.get_quant_config(),
skip_create_weights=config.skip_create_weights)
self.down_proj = Linear(self.intermediate_size,
self.hidden_size,
bias=bias,
dtype=dtype,
mapping=config.mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=config.get_quant_config(),
skip_create_weights=config.skip_create_weights)
self.up_lora = LoraLayer([LoraModuleType.MLP_H_TO_4H],
[self.intermediate_size])


@ -191,9 +191,9 @@ class Mapping(object):
self.attn_cp_size = attn_cp_size
self.auto_parallel = auto_parallel
self.world_size = world_size
self.enable_attention_dp = enable_attention_dp
self.rank = rank
self.gpus_per_node = gpus_per_node
self.enable_attention_dp = enable_attention_dp
self.pp_groups = []
self.cp_groups = []
self.tp_groups = []
@ -262,10 +262,13 @@ class Mapping(object):
@rank.setter
def rank(self, rank: int):
if not isinstance(rank, int) or rank < 0 or rank >= self.world_size:
raise ValueError(
f"Rank should be an integer between 0 and {self.world_size-1}, but got {rank}."
)
# TODO(qijun): skip check for enable_attention_dp temporarily, will support attention_dp_size
if not self.enable_attention_dp:
if not isinstance(rank,
int) or rank < 0 and rank >= self.world_size:
raise ValueError(
f"Rank should be an integer between 0 and {self.world_size-1}, but got {rank}."
)
self._rank = rank
@property


@ -28,9 +28,9 @@ from torch import nn
import tensorrt_llm
from tensorrt_llm._torch.compilation.backend import Backend
from tensorrt_llm._torch.distributed import ParallelConfig, TensorParallelMode
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm.mapping import Mapping
cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
@ -77,11 +77,10 @@ def row_linear_residual_norm_fusion_forward(
out_features=hidden_size,
bias=False,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_mode=TensorParallelMode.ROW,
),
tensor_parallel_mode=TensorParallelMode.ROW,
mapping=Mapping(world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank),
).cuda()
norm = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype).cuda()


@ -26,10 +26,10 @@ from utils.util import skip_pre_blackwell
import tensorrt_llm
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
AllReduceParams, DeepseekAllReduce,
ParallelConfig, TensorParallelMode)
from tensorrt_llm._torch.modules.linear import Linear
AllReduceParams, DeepseekAllReduce)
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm.mapping import Mapping
cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
@ -87,16 +87,16 @@ def row_linear_residual_norm_fusion_forward(
norm = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype).cuda()
allreduce = AllReduce(parallel_config=ParallelConfig(
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_mode=TensorParallelMode.ROW,
allreduce = AllReduce(mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
), ).cuda()
deepseek_allreduce = DeepseekAllReduce(parallel_config=ParallelConfig(
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_mode=TensorParallelMode.ROW,
deepseek_allreduce = DeepseekAllReduce(mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
), ).cuda()
scale = torch.tensor(1.0, dtype=torch.float32).cuda()
@ -268,10 +268,10 @@ def moe_residual_norm_fusion_forward(
norm_weight = torch.randn((hidden_size, ), dtype=dtype, device="cuda")
# Initialize DeepseekAllReduce and AllReduce
deepseek_allreduce = DeepseekAllReduce(parallel_config=ParallelConfig(
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_mode=TensorParallelMode.ROW,
deepseek_allreduce = DeepseekAllReduce(mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
)).cuda()
# Initialize RMSNorm
@ -283,11 +283,12 @@ def moe_residual_norm_fusion_forward(
out_features=hidden_size,
bias=False,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_mode=TensorParallelMode.ROW,
mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
),
tensor_parallel_mode=TensorParallelMode.ROW,
).cuda()
l0.load_weights([dict(weight=l0_weight)])
token_input_chunked = torch.chunk(token_input.clone(),
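
The communication ops follow the same migration: AllReduce and DeepseekAllReduce now take a Mapping instead of a ParallelConfig, and their call sites no longer pass a tensor-parallel mode at all. A minimal sketch with hypothetical sizes:

from tensorrt_llm._torch.distributed import AllReduce, DeepseekAllReduce
from tensorrt_llm.mapping import Mapping

tp_size, rank = 2, 0  # hypothetical; in a real run each process passes its own rank
mapping = Mapping(world_size=tp_size, tp_size=tp_size, rank=rank)

allreduce = AllReduce(mapping=mapping).cuda()
deepseek_allreduce = DeepseekAllReduce(mapping=mapping).cuda()

Both ops read tp_size and rank from the Mapping, so the tensor_parallel_mode field that ParallelConfig used to carry simply disappears from these constructors.
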

View File

@ -10,8 +10,9 @@ from mpi4py.futures import MPIPoolExecutor
from torch import nn
import tensorrt_llm
from tensorrt_llm._torch.distributed import ParallelConfig, TensorParallelMode
from tensorrt_llm._torch.modules.embedding import Embedding, LMHead
from tensorrt_llm._torch.modules.linear import TensorParallelMode
from tensorrt_llm.mapping import Mapping
cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
@ -44,10 +45,12 @@ def column_embedding_forward(x, vocab_size, hidden_size, dtype,
num_embeddings=vocab_size,
embedding_dim=hidden_size,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_mode=TensorParallelMode.COLUMN),
mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
),
tensor_parallel_mode=TensorParallelMode.COLUMN,
)
embedding.load_weights([dict(weight=weight)])
embedding.cuda()
@ -79,11 +82,13 @@ def row_embedding_forward(x, vocab_size, hidden_size, dtype,
num_embeddings=vocab_size,
embedding_dim=hidden_size,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_mode=TensorParallelMode.ROW,
gather_output=True),
mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
),
tensor_parallel_mode=TensorParallelMode.ROW,
gather_output=True,
)
embedding.load_weights([dict(weight=weight)])
embedding.cuda()
@ -115,11 +120,13 @@ def column_lm_head_forward(x, vocab_size, hidden_size, dtype,
num_embeddings=vocab_size,
embedding_dim=hidden_size,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True),
mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
),
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)
lm_head.load_weights([dict(weight=weight)])
lm_head.cuda()
@ -152,10 +159,12 @@ def row_lm_head_forward(x, vocab_size, hidden_size, dtype, tensor_parallel_size,
num_embeddings=vocab_size,
embedding_dim=hidden_size,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_mode=TensorParallelMode.ROW),
mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
),
tensor_parallel_mode=TensorParallelMode.ROW,
)
lm_head.load_weights([dict(weight=weight)])
lm_head.cuda()
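
Embedding and LMHead pick up the same keyword pair. A hedged sketch, assuming LMHead shares the Embedding keyword set as the test hunks above indicate, with illustrative vocabulary and hidden sizes:

import torch

from tensorrt_llm._torch.modules.embedding import Embedding, LMHead
from tensorrt_llm._torch.modules.linear import TensorParallelMode
from tensorrt_llm.mapping import Mapping

tp_size, rank = 2, 0  # hypothetical
mapping = Mapping(world_size=tp_size, tp_size=tp_size, rank=rank)

embedding = Embedding(num_embeddings=32000,  # illustrative vocab size
                      embedding_dim=4096,
                      dtype=torch.bfloat16,
                      mapping=mapping,
                      tensor_parallel_mode=TensorParallelMode.COLUMN)
lm_head = LMHead(num_embeddings=32000,
                 embedding_dim=4096,
                 dtype=torch.bfloat16,
                 mapping=mapping,
                 tensor_parallel_mode=TensorParallelMode.COLUMN,
                 gather_output=True)
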

View File

@ -10,9 +10,9 @@ from mpi4py.futures import MPIPoolExecutor
from torch import nn
import tensorrt_llm
from tensorrt_llm._torch.distributed import ParallelConfig, TensorParallelMode
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams
from tensorrt_llm.mapping import Mapping
cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
@ -52,23 +52,27 @@ def mlp_forward(x, hidden_size, dtype, tensor_parallel_size,
out_features=4 * hidden_size,
bias=False,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_mode=TensorParallelMode.COLUMN,
mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
),
tensor_parallel_mode=TensorParallelMode.COLUMN,
)
l0.load_weights([dict(weight=weights[0])])
l0.cuda()
l1 = Linear(in_features=4 * hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_mode=TensorParallelMode.ROW,
))
l1 = Linear(
in_features=4 * hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
),
tensor_parallel_mode=TensorParallelMode.ROW,
)
l1.load_weights([dict(weight=weights[1])])
l1.cuda()
@ -102,15 +106,19 @@ def column_linear_forward(x, hidden_size, dtype, tensor_parallel_size,
tensor_parallel_rank, weights):
x = x.cuda()
l0 = Linear(in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True))
l0 = Linear(
in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
),
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)
l0.load_weights([dict(weight=weights[0])])
l0.cuda()
@ -137,15 +145,18 @@ def row_linear_forward(x, hidden_size, dtype, tensor_parallel_size,
tensor_parallel_rank, weights):
x = x.cuda()
l0 = Linear(in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_mode=TensorParallelMode.ROW,
))
l0 = Linear(
in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
),
tensor_parallel_mode=TensorParallelMode.ROW,
)
l0.load_weights([dict(weight=weights[0])])
l0.cuda()
@ -183,15 +194,18 @@ def row_linear_norm_fusion_forward(x, hidden_size, dtype, tensor_parallel_size,
eps=eps,
)
l0 = Linear(in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_mode=TensorParallelMode.ROW,
))
l0 = Linear(
in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
mapping=Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=tensor_parallel_rank,
),
tensor_parallel_mode=TensorParallelMode.ROW,
)
l0.load_weights([dict(weight=weights[0])])
l0.cuda()
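
One detail worth noting from the column-parallel test above: gather_output=True asks the layer to collect the per-rank output shards so every rank sees the full result, which a standalone column-parallel layer needs but the fused column-then-row MLP leaves off. A short sketch (sizes illustrative):

import torch

from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm.mapping import Mapping

tp_size, rank = 2, 0  # hypothetical
mapping = Mapping(world_size=tp_size, tp_size=tp_size, rank=rank)

# Column-parallel layer used on its own: gather the shards back together.
col = Linear(in_features=1024,
             out_features=1024,
             bias=False,
             dtype=torch.bfloat16,
             mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             gather_output=True)
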

View File

@ -16,10 +16,10 @@ import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.compilation.backend import Backend
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
AllReduceParams, AllReduceStrategy,
ParallelConfig, TensorParallelMode,
userbuffers_allreduce_finalize)
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo
@ -123,11 +123,12 @@ def run_single_rank_ar_rms_norm(tensor_parallel_size, a, b, c, gamma):
ub0_tensor = create_userbuffers_tensor(c.size(), a.dtype)
hidden = torch.matmul(a_local, b_local, out=ub0_tensor)
parallel_config = ParallelConfig(
tensor_parallel_rank=rank,
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_mode=TensorParallelMode.COLUMN)
ar = AllReduce(parallel_config, strategy=AllReduceStrategy.UB)
mapping = Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=rank,
)
ar = AllReduce(mapping, strategy=AllReduceStrategy.UB)
ar_params = AllReduceParams(
strategy=AllReduceStrategy.UB,
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
@ -214,11 +215,12 @@ def run_single_rank_ar_rms_norm_fp8(tensor_parallel_size, a, b, c, gamma,
ub0_tensor = create_userbuffers_tensor(c.size(), a.dtype)
hidden = torch.matmul(a_local, b_local, out=ub0_tensor)
parallel_config = ParallelConfig(
tensor_parallel_rank=rank,
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_mode=TensorParallelMode.COLUMN)
ar = AllReduce(parallel_config, strategy=AllReduceStrategy.UB)
mapping = Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=rank,
)
ar = AllReduce(mapping, strategy=AllReduceStrategy.UB)
ar_params = AllReduceParams(
strategy=AllReduceStrategy.UB,
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8,
@ -318,55 +320,45 @@ class UBTestModel(nn.Module):
quant_config.layer_quant_mode
self.rank = rank
self.tp_size = tp_size
mapping = Mapping(
world_size=tp_size,
tp_size=tp_size,
rank=rank,
)
self.l0 = Linear(in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tp_size,
tensor_parallel_rank=rank,
tensor_parallel_mode=TensorParallelMode.ROW,
),
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=quant_config).cuda()
self.l1 = Linear(in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tp_size,
tensor_parallel_rank=rank,
tensor_parallel_mode=TensorParallelMode.COLUMN,
),
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=quant_config).cuda()
self.l2 = Linear(in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tp_size,
tensor_parallel_rank=rank,
tensor_parallel_mode=TensorParallelMode.ROW,
),
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=quant_config).cuda()
self.l3 = Linear(in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tp_size,
tensor_parallel_rank=rank,
tensor_parallel_mode=TensorParallelMode.COLUMN,
),
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=quant_config).cuda()
self.l4 = Linear(in_features=hidden_size,
out_features=hidden_size,
bias=False,
dtype=dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tp_size,
tensor_parallel_rank=rank,
tensor_parallel_mode=TensorParallelMode.ROW,
),
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=quant_config).cuda()
self.norm0 = RMSNorm(hidden_size=hidden_size, eps=eps,
dtype=dtype).cuda()
@ -608,11 +600,12 @@ def run_single_rank_ar_rms_norm_fp4(tensor_parallel_size, a, b, c, gamma):
ub0_tensor = create_userbuffers_tensor(c.size(), a.dtype)
hidden = torch.matmul(a_local, b_local, out=ub0_tensor)
parallel_config = ParallelConfig(
tensor_parallel_rank=rank,
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_mode=TensorParallelMode.COLUMN)
ar = AllReduce(parallel_config, strategy=AllReduceStrategy.UB)
mapping = Mapping(
world_size=tensor_parallel_size,
tp_size=tensor_parallel_size,
rank=rank,
)
ar = AllReduce(mapping, strategy=AllReduceStrategy.UB)
ar_params = AllReduceParams(
strategy=AllReduceStrategy.UB,
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
@ -694,18 +687,14 @@ class UBMMAddModel(nn.Module):
self.rank = rank
self.hidden_size = hidden_size
self.dtype = dtype
self.ar_0 = AllReduce(
ParallelConfig(tensor_parallel_size=tp_size,
tensor_parallel_rank=rank,
tensor_parallel_mode=TensorParallelMode.ROW)).cuda()
self.ar_1 = AllReduce(
ParallelConfig(tensor_parallel_size=tp_size,
tensor_parallel_rank=rank,
tensor_parallel_mode=TensorParallelMode.ROW)).cuda()
self.ar_2 = AllReduce(
ParallelConfig(tensor_parallel_size=tp_size,
tensor_parallel_rank=rank,
tensor_parallel_mode=TensorParallelMode.ROW)).cuda()
mapping = Mapping(
world_size=tp_size,
tp_size=tp_size,
rank=rank,
)
self.ar_0 = AllReduce(mapping).cuda()
self.ar_1 = AllReduce(mapping).cuda()
self.ar_2 = AllReduce(mapping).cuda()
self.norm0 = RMSNorm(hidden_size=hidden_size, eps=eps,
dtype=dtype).cuda()
self.norm1 = RMSNorm(hidden_size=hidden_size, eps=eps,
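
Finally, the user-buffer tests show that the strategy argument is unchanged and that a single Mapping can back several communication ops. A minimal sketch, assuming the positional mapping argument used above:

from tensorrt_llm._torch.distributed import AllReduce, AllReduceStrategy
from tensorrt_llm.mapping import Mapping

tp_size, rank = 2, 0  # hypothetical
mapping = Mapping(world_size=tp_size, tp_size=tp_size, rank=rank)

# One Mapping reused across ops, as UBMMAddModel does above.
ar_ub = AllReduce(mapping, strategy=AllReduceStrategy.UB)
ar_0 = AllReduce(mapping).cuda()
ar_1 = AllReduce(mapping).cuda()
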