Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
refactor: remove ParallelConfig in tensorrt_llm._torch.distributed module (#3370)
* remove tensorrt_llm._torch.distributed.ParallelConfig
* fix ci
* fix ci
* clean
* fix embedding test
* fix
* fix comments
* polish
* fix ci
* rebase

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
Co-authored-by: hlu1 <14827759+hlu1@users.noreply.github.com>
This commit is contained in: commit d167cbd5bb (parent cf9ceea890)
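The change below swaps the PyTorch backend's private ParallelConfig dataclass for the shared tensorrt_llm.mapping.Mapping, passed to modules together with explicit tensor_parallel_mode / gather_output keyword arguments. A minimal sketch of the new calling convention, using only constructor arguments that appear in this diff (sizes and ranks are illustrative, not from the commit):

    # Illustrative only: a 2-way tensor-parallel Mapping for rank 0.
    from tensorrt_llm.mapping import Mapping

    mapping = Mapping(world_size=2, tp_size=2, rank=0)
    # Old call sites (removed below):
    #   parallel_config=ParallelConfig(tensor_parallel_rank=mapping.tp_rank,
    #                                  tensor_parallel_size=mapping.tp_size,
    #                                  tensor_parallel_mode=TensorParallelMode.COLUMN)
    # New call sites:
    #   mapping=mapping, tensor_parallel_mode=TensorParallelMode.COLUMN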
@@ -7,7 +7,6 @@ from transformers import OPTConfig
 from transformers.activations import ACT2FN

 from tensorrt_llm._torch.attention_backend import AttentionMetadata
-from tensorrt_llm._torch.distributed import ParallelConfig, TensorParallelMode
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
                                                        DecoderModelForCausalLM,
@@ -16,7 +15,7 @@ from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
 from tensorrt_llm._torch.modules.attention import Attention
 from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
 from tensorrt_llm._torch.modules.embedding import Embedding
-from tensorrt_llm._torch.modules.linear import Linear
+from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode


 class LayerNorm(nn.LayerNorm):
@@ -70,10 +69,8 @@ class OPTDecoderLayer(DecoderLayer):
             config.ffn_dim,
             bias=config.enable_bias,
             dtype=config.torch_dtype,
-            parallel_config=ParallelConfig(
-                tensor_parallel_rank=model_config.mapping.tp_rank,
-                tensor_parallel_size=model_config.mapping.tp_size,
-                tensor_parallel_mode=TensorParallelMode.COLUMN),
+            mapping=model_config.mapping,
+            tensor_parallel_mode=TensorParallelMode.COLUMN,
             quant_config=model_config.get_quant_config(),
         )
         self.fc2 = Linear(
@@ -81,10 +78,8 @@ class OPTDecoderLayer(DecoderLayer):
             config.hidden_size,
             bias=config.enable_bias,
             dtype=config.torch_dtype,
-            parallel_config=ParallelConfig(
-                tensor_parallel_rank=model_config.mapping.tp_rank,
-                tensor_parallel_size=model_config.mapping.tp_size,
-                tensor_parallel_mode=TensorParallelMode.ROW),
+            mapping=model_config.mapping,
+            tensor_parallel_mode=TensorParallelMode.ROW,
             quant_config=model_config.get_quant_config(),
         )
         self.final_layer_norm = LayerNorm(
@@ -150,10 +145,8 @@ class OPTModel(DecoderModel):
             config.vocab_size,
             config.word_embed_proj_dim,
             dtype=config.torch_dtype,
-            parallel_config=ParallelConfig(
-                tensor_parallel_rank=model_config.mapping.tp_rank,
-                tensor_parallel_size=model_config.mapping.tp_size,
-                tensor_parallel_mode=TensorParallelMode.COLUMN))
+            mapping=model_config.mapping,
+            tensor_parallel_mode=TensorParallelMode.COLUMN)
         self.embed_positions = nn.Embedding(config.max_position_embeddings + 2,
                                             config.hidden_size,
                                             dtype=config.torch_dtype)
@@ -9,7 +9,6 @@ from tensorrt_llm.functional import AttentionMaskType
 from tensorrt_llm.models.modeling_utils import QuantConfig

 from ..distributed import allgather
-from ..modules.linear import ParallelConfig
 from .flashinfer import FlashInferAttentionMetadata, PlanParams
 from .interface import AttentionBackend, AttentionMask, PredefinedAttentionMask
 from .vanilla import VanillaAttention
@@ -440,12 +439,10 @@ class StarAttention(AttentionBackend[StarAttentionMetadata]):
             out_tmp = output
             lse = lse.unsqueeze(-1) / np.log2(np.e)  # [b * s, nheads, 1]
             if metadata.mapping.cp_size != 1:
-                parallel_cfg = ParallelConfig(
-                    tensor_parallel_size=metadata.mapping.cp_size,
-                    tensor_parallel_rank=metadata.mapping.cp_rank,
-                    pipeline_parallel_size=metadata.mapping.pp_size)
-                output_tensor = allgather(output, parallel_cfg, gather_dim=0)
-                lse_tensor = allgather(lse, parallel_cfg, gather_dim=0)
+                output_tensor = allgather(output,
+                                          metadata.mapping,
+                                          gather_dim=0)
+                lse_tensor = allgather(lse, metadata.mapping, gather_dim=0)
                 output_tensor = output_tensor.to(torch.float32)
             else:
                 lse_tensor = lse
@@ -4,18 +4,19 @@ from .common import ReduceOp, get_rank_world_size, is_ompi

 # use trtllm distributed ops to improve TP performance if possible
 try:
+    from ....mapping import Mapping
     from ...distributed import AllReduce, allgather
-    from ...modules.linear import AllReduceFusionOp, AllReduceParams, ParallelConfig
+    from ...modules.linear import AllReduceFusionOp, AllReduceParams

     def trtllm_allgather(tensor, dim):
         rank, world_size = get_rank_world_size()
-        p_config = ParallelConfig(tensor_parallel_size=world_size, tensor_parallel_rank=rank)
+        p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
         return allgather(tensor, p_config, gather_dim=dim)

     def trtllm_allreduce(tensor, op, all_reduce_params=None):
         rank, world_size = get_rank_world_size()
         assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
-        p_config = ParallelConfig(tensor_parallel_size=world_size, tensor_parallel_rank=rank)
+        p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
         torch_op = AllReduce(p_config)
         return torch_op(tensor, all_reduce_params=all_reduce_params)

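The helpers above now build a TP-only Mapping that spans the whole world, taking over the role of the old ParallelConfig(tensor_parallel_size=world_size, tensor_parallel_rank=rank). A small sketch of that equivalence (values illustrative; the tp_rank equality assumes Mapping's usual layout with TP as the innermost dimension):

    from tensorrt_llm.mapping import Mapping

    world_size, rank = 4, 1  # e.g. from get_rank_world_size()
    p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
    # One TP group covering every rank: the TP rank equals the global rank.
    assert p_config.tp_size == world_size and p_config.tp_rank == rank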
@@ -1,8 +1,6 @@
 import atexit
-import enum
 import os
 import threading
-from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import torch
@@ -17,37 +15,6 @@ from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
 _thread_local = threading.local()


-class TensorParallelMode(str, enum.Enum):
-    COLUMN = 'column'
-    ROW = 'row'
-
-    @classmethod
-    def split_dim(cls, mode):
-        return 1 if mode == cls.ROW else 0
-
-
-@dataclass(kw_only=True)
-class ParallelConfig:
-    tensor_parallel_size: int = 1
-    tensor_parallel_rank: int = 0
-    gpus_per_node: int = 8
-    tensor_parallel_mode: Optional[TensorParallelMode] = None
-    gather_output: bool = False
-    # pipeline parallel parameter in case we have multiple parallel groups
-    # default to TP-only mode if not specified for backward compatibility
-    # TODO Remove redundant fields. Keep only parallel_rank, tp_size, pp_size in constructor
-    # and infer tp_rank, pp_rank, etc. automatically.
-    pipeline_parallel_size: int = 1
-    parallel_rank: Optional[int] = None
-
-    def __post_init__(self):
-        self.parallel_size = self.tensor_parallel_size * self.pipeline_parallel_size
-        if self.pipeline_parallel_size > 1:
-            assert self.parallel_rank is not None, "parallel_rank must be specified for PP mode"
-        else:
-            self.parallel_rank = self.tensor_parallel_rank
-
-
 def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
     if not hasattr(_thread_local, 'allreduce_workspaces'):
         _thread_local.allreduce_workspaces = {}
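The deleted dataclass carried no information a Mapping does not already hold; the Mapping(...) conversion blocks removed further down in this file spell out the correspondence. A minimal sketch with illustrative values:

    from tensorrt_llm.mapping import Mapping

    tp_size, pp_size, rank, gpus_per_node = 4, 2, 5, 8
    # Old: ParallelConfig(tensor_parallel_size=tp_size,
    #                     pipeline_parallel_size=pp_size,
    #                     parallel_rank=rank,
    #                     gpus_per_node=gpus_per_node)
    # New: the same facts expressed once on the shared Mapping type.
    mapping = Mapping(world_size=tp_size * pp_size, tp_size=tp_size,
                      pp_size=pp_size, rank=rank, gpus_per_node=gpus_per_node)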
@@ -77,7 +44,7 @@ def get_deepseek_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
 def allreduce(
     input: torch.Tensor,
     workspace: Optional[torch.LongTensor],
-    parallel_config: ParallelConfig,
+    mapping: Mapping,
     strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
     config: AllReduceConfig = AllReduceConfig(0),
     all_reduce_params: Optional[AllReduceParams] = None
@@ -98,7 +65,7 @@

     Args:
         input (Tensor): The input tensor.
-        parallel_config (ParallelConfig): The parallel config.
+        mapping (Mapping): The parallel mapping.
         strategy (AllReduceStrategy): NCCL delegates all-reduce to NCCL while ONESHOT and TWOSHOT are custom latency-optimal algorithms.
             AUTO chooses amongst the three based on a message-size heuristic.
         config (AllReduceConfig): The config for custom allreduce kernels.
@@ -106,19 +73,10 @@
     Returns:
         The reduced tensor and an optional intermediate tensor if fused.
     '''
-    if parallel_config.tensor_parallel_size == 1 or (
-            all_reduce_params is not None
-            and all_reduce_params.enable_allreduce == False):
+    if mapping.tp_size == 1 or (all_reduce_params is not None and
+                                all_reduce_params.enable_allreduce == False):
         return input

-    mapping = Mapping(
-        world_size=parallel_config.parallel_size,
-        tp_size=parallel_config.tensor_parallel_size,
-        pp_size=parallel_config.pipeline_parallel_size,
-        rank=parallel_config.parallel_rank,
-        gpus_per_node=parallel_config.gpus_per_node,
-    )
-
     if all_reduce_params is None:
         all_reduce_params = AllReduceParams()
     is_fused = all_reduce_params.fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM or \
@@ -164,7 +122,7 @@ def userbuffers_allreduce_finalize(


 def allgather(input: torch.Tensor,
-              parallel_config: ParallelConfig,
+              mapping: Mapping,
              gather_dim: int = -1) -> torch.Tensor:
     '''
     Add an operation that performs a collective all-gather.
@@ -184,22 +142,14 @@

     Args:
         input (Tensor): The input tensor.
-        parallel_config (ParallelConfig): The parallel config.
+        mapping (Mapping): The parallel mapping.
         gather_dim (int): Gather along given dimension. By default -1.
     Returns:
         The gathered tensor.
     '''
-    if parallel_config.tensor_parallel_size == 1:
+    if mapping.tp_size == 1:
         return input

-    mapping = Mapping(
-        world_size=parallel_config.parallel_size,
-        tp_size=parallel_config.tensor_parallel_size,
-        pp_size=parallel_config.pipeline_parallel_size,
-        rank=parallel_config.parallel_rank,
-        gpus_per_node=parallel_config.gpus_per_node,
-    )
-
     output = torch.ops.trtllm.allgather(
         input,
         mapping.tp_group,
@@ -211,26 +161,17 @@
     output = torch.movedim(output, 0, gather_dim)
     input_shape = input.size()
     output = output.reshape(input_shape[:gather_dim] +
-                            (parallel_config.tensor_parallel_size *
-                             input_shape[gather_dim], ) +
+                            (mapping.tp_size * input_shape[gather_dim], ) +
                            input_shape[gather_dim + 1:])
     return output


 def reducescatter(input: torch.Tensor,
-                  parallel_config: ParallelConfig,
+                  mapping: Mapping,
                   scatter_dim: int = -1) -> torch.Tensor:
-    if parallel_config.tensor_parallel_size == 1:
+    if mapping.tp_size == 1:
         return input

-    mapping = Mapping(
-        world_size=parallel_config.parallel_size,
-        tp_size=parallel_config.tensor_parallel_size,
-        pp_size=parallel_config.pipeline_parallel_size,
-        rank=parallel_config.parallel_rank,
-        gpus_per_node=parallel_config.gpus_per_node,
-    )
-
     output = torch.ops.trtllm.reducescatter(
         input,
         mapping.tp_group,
@@ -242,8 +183,7 @@ def reducescatter(input: torch.Tensor,
     output = torch.movedim(output, 0, scatter_dim)
     input_shape = input.size()
     output = output.reshape(input_shape[:scatter_dim] +
-                            (input_shape[scatter_dim] //
-                             parallel_config.tensor_parallel_size, ) +
+                            (input_shape[scatter_dim] // mapping.tp_size, ) +
                            input_shape[scatter_dim + 1:])
     return output

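A hedged usage sketch for the reworked collectives; the tp_size == 1 short-circuit is exactly the behavior shown above, everything else is illustrative:

    import torch

    from tensorrt_llm._torch.distributed import allgather
    from tensorrt_llm.mapping import Mapping

    x = torch.ones(4, 8)
    single = Mapping(world_size=1, tp_size=1, rank=0)
    # With tp_size == 1, allgather (and reducescatter) return the input as-is.
    assert allgather(x, single, gather_dim=0) is x
    # On a real TP > 1 run: allgather(x, model_config.mapping, gather_dim=0)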
@@ -251,25 +191,14 @@
 class AllReduce(nn.Module):

     def __init__(self,
-                 parallel_config: ParallelConfig,
+                 mapping: Mapping,
                  strategy: AllReduceStrategy = AllReduceStrategy.AUTO):
         super().__init__()

-        self.parallel_config = parallel_config
-        self.tp_size = self.parallel_config.tensor_parallel_size
-        self.rank = self.parallel_config.parallel_rank
-        self.gpus_per_node = self.parallel_config.gpus_per_node
-
+        self.mapping = mapping
         self.workspace = None
         self.strategy = strategy
-        if self.tp_size > 1:
-            mapping = Mapping(
-                world_size=self.parallel_config.parallel_size,
-                tp_size=self.tp_size,
-                pp_size=self.parallel_config.pipeline_parallel_size,
-                rank=self.rank,
-                gpus_per_node=self.gpus_per_node,
-            )
+        if self.mapping.tp_size > 1:
             if self.strategy != AllReduceStrategy.UB:
                 self.workspace = get_allreduce_workspace(mapping)

@@ -281,7 +210,7 @@
     ) -> torch.Tensor:
         output = allreduce(input,
                            self.workspace,
-                           self.parallel_config,
+                           self.mapping,
                            all_reduce_params=all_reduce_params,
                            strategy=self.strategy)
         return output
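A minimal construction sketch for the new AllReduce interface (a single-rank Mapping keeps the example standalone, since no workspace is allocated when tp_size == 1; real call sites pass model_config.mapping):

    from tensorrt_llm._torch.distributed import AllReduce
    from tensorrt_llm.mapping import Mapping

    all_reduce = AllReduce(Mapping(world_size=1, tp_size=1, rank=0))
    # forward() delegates to allreduce(input, self.workspace, self.mapping, ...)
    # and is a no-op whenever mapping.tp_size == 1.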
@@ -289,22 +218,11 @@

 class DeepseekAllReduce(nn.Module):

-    def __init__(self, parallel_config: ParallelConfig):
+    def __init__(self, mapping: Mapping):
         super().__init__()
-        self.parallel_config = parallel_config
-        self.tp_size = self.parallel_config.tensor_parallel_size
-        self.tp_rank = self.parallel_config.tensor_parallel_rank
-        self.gpus_per_node = self.parallel_config.gpus_per_node
-        self.rank = self.parallel_config.parallel_rank
+        self.mapping = mapping
         self.workspace = None
-        if self.tp_size > 1:
-            mapping = Mapping(
-                world_size=self.parallel_config.parallel_size,
-                tp_size=self.tp_size,
-                pp_size=self.parallel_config.pipeline_parallel_size,
-                rank=self.rank,
-                gpus_per_node=self.gpus_per_node,
-            )
+        if self.mapping.tp_size > 1:
             self.workspace = get_deepseek_allreduce_workspace(mapping)

     def forward(
@@ -331,8 +249,8 @@ class DeepseekAllReduce(nn.Module):
             input=hidden_states,
             workspace=self.workspace,
             reduce_fusion_inputs=reduce_fusion_inputs,
-            rank=self.parallel_config.tensor_parallel_rank,
-            nranks=self.parallel_config.tensor_parallel_size,
+            rank=self.mapping.tp_rank,
+            nranks=self.mapping.tp_size,
             eps=eps,
             fusion_op=fusion_op,
         )
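The fused kernel's rank arguments now come straight off the stored Mapping. A small sketch of how TP coordinates fall out of a Mapping (the tp_rank value assumes the usual layout with TP as the innermost dimension):

    from tensorrt_llm.mapping import Mapping

    m = Mapping(world_size=4, tp_size=2, pp_size=2, rank=3)
    # Rank 3 sits in the second PP stage as TP rank 1 of its 2-way group.
    print(m.tp_rank, m.tp_size)  # expected: 1 2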
@@ -15,7 +15,7 @@ from tensorrt_llm.llmapi.utils import enable_llm_debug
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
 from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
-                           DeepseekAllReduce, ParallelConfig, allgather)
+                           DeepseekAllReduce, allgather)
 from ..model_config import ModelConfig
 from ..models.modeling_utils import MissingLayer, ModelConfig, support_pp
 from ..modules.attention import MLA
@@ -306,13 +306,8 @@ class Deepseekv3MoE(nn.Module):
             overridden_tp_size=shared_tp_size,
             is_expert=True)

-        self.parallel_config = ParallelConfig(
-            tensor_parallel_rank=model_config.mapping.tp_rank,
-            tensor_parallel_size=model_config.mapping.tp_size,
-            gpus_per_node=model_config.mapping.gpus_per_node,
-            pipeline_parallel_size=model_config.mapping.pp_size,
-            parallel_rank=model_config.mapping.rank)
-        self.all_reduce = AllReduce(self.parallel_config)
+        self.mapping = model_config.mapping
+        self.all_reduce = AllReduce(self.mapping)
         self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
         self.event_dict = {
             key: torch.cuda.Event()
@@ -321,14 +316,14 @@

     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, min_latency_mode):
-        if self.use_dp and self.parallel_config.tensor_parallel_size > 1:
+        if self.use_dp and self.mapping.tp_size > 1:
             max_num_token = max(all_rank_num_tokens)
             hidden_states = torch.nn.functional.pad(
                 hidden_states,
                 (0, 0, 0, max_num_token - hidden_states.shape[0]))
             if disable_fp4_allgather():
                 hidden_states = allgather(hidden_states,
-                                          self.parallel_config,
+                                          self.mapping,
                                           gather_dim=0)
         router_logits = self.gate(hidden_states)

@@ -383,7 +378,7 @@
         assert shared_output.size() == routed_output.size(
         ), f'unmatched tensor shape'
         final_hidden_states = shared_output + routed_output
-        if not self.use_dp and self.parallel_config.tensor_parallel_size > 1:
+        if not self.use_dp and self.mapping.tp_size > 1:
             final_hidden_states = self.all_reduce(
                 final_hidden_states, all_reduce_params=final_all_reduce_params)

@@ -470,14 +465,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
         self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                                 eps=config.rms_norm_eps,
                                                 dtype=config.torch_dtype)
-        self.parallel_config = ParallelConfig(
-            tensor_parallel_rank=model_config.mapping.tp_rank,
-            tensor_parallel_size=model_config.mapping.tp_size,
-            gpus_per_node=model_config.mapping.gpus_per_node,
-            pipeline_parallel_size=model_config.mapping.pp_size,
-            parallel_rank=model_config.mapping.rank)
+        self.mapping = model_config.mapping
         self.layer_idx = layer_idx
-        self.all_reduce = AllReduce(self.parallel_config)
+        self.all_reduce = AllReduce(self.mapping)
         self.next_layer_layernorm: RMSNorm = None

         self.deepseek_allreduce_disabled = os.environ.get(
@@ -486,7 +476,7 @@ class DeepseekV3DecoderLayer(DecoderLayer):
             self.deepseek_allreduce_disabled = True

         if not self.deepseek_allreduce_disabled:
-            self.deepseek_allreduce = DeepseekAllReduce(self.parallel_config)
+            self.deepseek_allreduce = DeepseekAllReduce(self.mapping)

     def forward(
         self,
@@ -511,9 +501,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.PRE_MOE_FUSION or self.fusion_config.
-                PRE_MLP_FUSION or self.parallel_config.tensor_parallel_size == 1
-                or self.enable_attention_dp)),
+                self.fusion_config.PRE_MOE_FUSION
+                or self.fusion_config.PRE_MLP_FUSION
+                or self.mapping.tp_size == 1 or self.enable_attention_dp)),
             **kwargs,
         )

@@ -586,9 +576,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
                 hidden_states_fp4,
                 all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
                 final_all_reduce_params=AllReduceParams(enable_allreduce=not (
-                    self.fusion_config.POST_MOE_FUSION or self.fusion_config.
-                    POST_MLP_FUSION or self.parallel_config.tensor_parallel_size
-                    == 1 or self.enable_attention_dp)),
+                    self.fusion_config.POST_MOE_FUSION
+                    or self.fusion_config.POST_MLP_FUSION
+                    or self.mapping.tp_size == 1 or self.enable_attention_dp)),
                 min_latency_mode=min_latency_mode,
             )
         else:
@@ -596,9 +586,9 @@ class DeepseekV3DecoderLayer(DecoderLayer):
                 hidden_states,
                 all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
                 final_all_reduce_params=AllReduceParams(enable_allreduce=not (
-                    self.fusion_config.POST_MOE_FUSION or self.fusion_config.
-                    POST_MLP_FUSION or self.parallel_config.tensor_parallel_size
-                    == 1 or self.enable_attention_dp)),
+                    self.fusion_config.POST_MOE_FUSION
+                    or self.fusion_config.POST_MLP_FUSION
+                    or self.mapping.tp_size == 1 or self.enable_attention_dp)),
                 min_latency_mode=min_latency_mode,
             )

@@ -729,8 +719,8 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.PRE_MOE_FUSION or self.parallel_config.
-                tensor_parallel_size == 1 or self.enable_attention_dp)),
+                self.fusion_config.PRE_MOE_FUSION or self.mapping.tp_size == 1
+                or self.enable_attention_dp)),
             **kwargs,
         )

@@ -762,8 +752,8 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
             hidden_states,
             all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
             final_all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.POST_MOE_FUSION or self.parallel_config.
-                tensor_parallel_size == 1 or self.enable_attention_dp)),
+                self.fusion_config.POST_MOE_FUSION or self.mapping.tp_size == 1
+                or self.enable_attention_dp)),
         )

         if self.fusion_config.POST_MOE_FUSION:
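Every AllReduceParams change above follows one pattern: the layer-level all-reduce is disabled whenever a fusion path will perform the reduction, TP is off, or attention data parallelism handles it. A sketch of the predicate with plain booleans standing in for the config flags:

    post_moe_fusion = True   # e.g. self.fusion_config.POST_MOE_FUSION
    tp_size = 8              # e.g. self.mapping.tp_size
    attention_dp = False     # e.g. self.enable_attention_dp
    enable_allreduce = not (post_moe_fusion or tp_size == 1 or attention_dp)
    assert enable_allreduce is False  # the fusion path owns the reduction here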
@@ -6,8 +6,7 @@ from torch import nn
 from transformers import Llama4Config, LlamaConfig

 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
-                                             AllReduceParams, DeepseekAllReduce,
-                                             ParallelConfig, TensorParallelMode)
+                                             AllReduceParams, DeepseekAllReduce)
 from tensorrt_llm.functional import PositionEmbeddingType

 from ..attention_backend import AttentionMetadata
@@ -19,7 +18,8 @@ from ..modules.embedding import Embedding
 from ..modules.fused_moe import (FusedMoE, Llama4RenormalizeMoeRoutingMethod,
                                  MoEWeightLoadingMode)
 from ..modules.gated_mlp import GatedMLP
-from ..modules.linear import Linear, WeightMode, WeightsLoadingConfig
+from ..modules.linear import (Linear, TensorParallelMode, WeightMode,
+                              WeightsLoadingConfig)
 from ..modules.rms_norm import RMSNorm
 from ..modules.rotary_embedding import RotaryEmbedding
 from ..speculative import Eagle3SpecMetadata, SpecMetadata
@@ -132,11 +132,8 @@ class Llama4MoE(nn.Module):
             dtype=config.torch_dtype,
             quant_config=None)

-        self.parallel_config = ParallelConfig(
-            tensor_parallel_rank=model_config.mapping.tp_rank,
-            tensor_parallel_size=model_config.mapping.tp_size,
-            gpus_per_node=model_config.mapping.gpus_per_node)
-        self.all_reduce = AllReduce(self.parallel_config)
+        self.mapping = model_config.mapping
+        self.all_reduce = AllReduce(self.mapping)
         self.moe_event = [torch.cuda.Event(), torch.cuda.Event()]
         self.aux_stream = aux_stream

@@ -177,7 +174,7 @@
         assert shared_output.size() == routed_output.size(
         ), f'unmatched tensor shape'
         final_hidden_states = shared_output + routed_output
-        if self.parallel_config.tensor_parallel_size > 1:
+        if self.mapping.tp_size > 1:
             final_hidden_states = self.all_reduce(
                 final_hidden_states, all_reduce_params=final_all_reduce_params)

@@ -268,14 +265,11 @@ class LlamaDecoderLayer(DecoderLayer):
                                         eps=config.rms_norm_eps,
                                         dtype=config.torch_dtype)

-        self.parallel_config = ParallelConfig(
-            tensor_parallel_rank=model_config.mapping.tp_rank,
-            tensor_parallel_size=model_config.mapping.tp_size,
-            gpus_per_node=model_config.mapping.gpus_per_node)
-        self.all_reduce = AllReduce(self.parallel_config)
+        self.mapping = model_config.mapping
+        self.all_reduce = AllReduce(self.mapping)
         self.next_layer_layernorm: RMSNorm = None

-        self.deepseek_allreduce = DeepseekAllReduce(self.parallel_config)
+        self.deepseek_allreduce = DeepseekAllReduce(self.mapping)

     def forward(
         self,
@@ -307,9 +301,9 @@ class LlamaDecoderLayer(DecoderLayer):
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
-            all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.PRE_MOE_FUSION
-                or self.parallel_config.tensor_parallel_size == 1)),
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not (self.fusion_config.PRE_MOE_FUSION
+                                      or self.mapping.tp_size == 1)),
             **kwargs,
         )

@@ -332,9 +326,8 @@ class LlamaDecoderLayer(DecoderLayer):
                 hidden_states,
                 all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
                 final_all_reduce_params=AllReduceParams(enable_allreduce=not (
-                    self.fusion_config.POST_MOE_FUSION
-                    or self.fusion_config.POST_MLP_FUSION
-                    or self.parallel_config.tensor_parallel_size == 1)),
+                    self.fusion_config.POST_MOE_FUSION or self.fusion_config.
+                    POST_MLP_FUSION or self.mapping.tp_size == 1)),
                 min_latency_mode=min_latency_mode,
             )
         if spec_metadata is not None:
@@ -388,8 +381,6 @@ class Eagle3LlamaAttention(LlamaAttention):
         config = model_config.pretrained_config

         tp_size = model_config.mapping.tp_size
-        tp_rank = model_config.mapping.tp_rank
-        gpus_per_node = model_config.mapping.gpus_per_node

         # Override the QKV projection. The number of input features
         # is twice as big for EAGLE3 draft models.
@@ -398,13 +389,8 @@ class Eagle3LlamaAttention(LlamaAttention):
             tp_size * self.q_size + 2 * tp_size * self.kv_size,
             bias=config.attention_bias,
             dtype=config.torch_dtype,
-            parallel_config=ParallelConfig(
-                tensor_parallel_size=tp_size,
-                tensor_parallel_rank=tp_rank,
-                tensor_parallel_mode=TensorParallelMode.COLUMN,
-                gpus_per_node=gpus_per_node,
-                pipeline_parallel_size=model_config.mapping.pp_size,
-                parallel_rank=model_config.mapping.rank),
+            mapping=model_config.mapping,
+            tensor_parallel_mode=TensorParallelMode.COLUMN,
             weights_loading_config=WeightsLoadingConfig(
                 weight_mode=WeightMode.FUSED_QKV_LINEAR),
             quant_config=model_config.get_quant_config(),
@@ -486,15 +472,9 @@ class LlamaModel(DecoderModel):
             config.vocab_size,
             config.hidden_size,
             dtype=config.torch_dtype,
-            parallel_config=ParallelConfig(
-                tensor_parallel_rank=model_config.mapping.tp_rank,
-                tensor_parallel_size=model_config.mapping.tp_size,
-                tensor_parallel_mode=TensorParallelMode.COLUMN,
-                pipeline_parallel_size=model_config.mapping.pp_size,
-                parallel_rank=model_config.mapping.rank,
-                gather_output=True,
-                gpus_per_node=model_config.mapping.gpus_per_node,
-            ),
+            mapping=model_config.mapping,
+            tensor_parallel_mode=TensorParallelMode.COLUMN,
+            gather_output=True,
         )
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
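A hedged sketch of the post-refactor Linear construction used throughout these models (a single-rank Mapping keeps the example standalone; real call sites pass model_config.mapping, and sizes here are illustrative):

    from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
    from tensorrt_llm.mapping import Mapping

    proj = Linear(
        4096, 4096,
        bias=False,
        mapping=Mapping(world_size=1, tp_size=1, rank=0),
        tensor_parallel_mode=TensorParallelMode.COLUMN,
        gather_output=True,
    )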
@@ -8,7 +8,6 @@ from tensorrt_llm.functional import PositionEmbeddingType

 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
-from ..distributed import ParallelConfig
 from ..model_config import ModelConfig
 from ..models.modeling_utils import ModelConfig
 from ..modules.attention import Attention
@@ -133,11 +132,7 @@ class MixtralDecoderLayer(DecoderLayer):
         self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                                 eps=config.rms_norm_eps,
                                                 dtype=config.torch_dtype)
-        # TODO: add pipeline parallel config
-        self.parallel_config = ParallelConfig(
-            tensor_parallel_rank=model_config.mapping.tp_rank,
-            tensor_parallel_size=model_config.mapping.tp_size,
-            gpus_per_node=model_config.mapping.gpus_per_node)
+        self.mapping = model_config.mapping
         self.layer_idx = layer_idx

     def forward(
@@ -26,11 +26,11 @@ from tqdm import tqdm

 from ...logger import logger
 from ..attention_backend.interface import AttentionMetadata
-from ..distributed import ParallelConfig, TensorParallelMode
 from ..model_config import ModelConfig
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding, LMHead
 from ..modules.gated_mlp import GatedMLP
+from ..modules.linear import TensorParallelMode
 from ..modules.logits_procesor import LogitsProcessor
 from .modeling_llama import LlamaAttention
 from .modeling_utils import duplicate_kv_weight, register_auto_model
@@ -224,13 +224,9 @@ class MllamaForCausalLM(nn.Module):
             text_config.vocab_size,
             text_config.hidden_size,
             dtype=text_config.torch_dtype,
-            parallel_config=ParallelConfig(
-                tensor_parallel_rank=config.mapping.tp_rank,
-                tensor_parallel_size=config.mapping.tp_size,
-                tensor_parallel_mode=TensorParallelMode.COLUMN,
-                gather_output=True,
-                gpus_per_node=config.mapping.gpus_per_node,
-            ),
+            mapping=config.mapping,
+            tensor_parallel_mode=TensorParallelMode.COLUMN,
+            gather_output=True,
         )
         # use embedding weights in lm_head if tie word embedding is enabled
         if text_config.tie_word_embeddings:
@@ -4,8 +4,6 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig

-from tensorrt_llm._torch.distributed import ParallelConfig, TensorParallelMode
-from tensorrt_llm._torch.modules.linear import Linear
 from tensorrt_llm.functional import PositionEmbeddingType

 from ..attention_backend import AttentionMetadata
@@ -15,6 +13,7 @@ from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.gated_mlp import GatedMLP
+from ..modules.linear import Linear, TensorParallelMode
 from ..modules.rms_norm import RMSNorm
 from ..modules.rotary_embedding import RotaryEmbedding
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -55,15 +54,9 @@ def _create_linear_from_configs(model_config: ModelConfig[PretrainedConfig],
         config.hidden_size,
         bias=False,
         dtype=config.torch_dtype,
-        parallel_config=ParallelConfig(
-            tensor_parallel_rank=model_config.mapping.tp_rank,
-            tensor_parallel_size=model_config.mapping.tp_size,
-            tensor_parallel_mode=TensorParallelMode.COLUMN,
-            pipeline_parallel_size=model_config.mapping.pp_size,
-            parallel_rank=model_config.mapping.rank,
-            gather_output=True,
-            gpus_per_node=model_config.mapping.gpus_per_node,
-        ),
+        mapping=model_config.mapping,
+        tensor_parallel_mode=TensorParallelMode.COLUMN,
+        gather_output=True,
         quant_config=model_config.get_quant_config(),
         skip_create_weights=model_config.skip_create_weights,
     )
@@ -8,13 +8,12 @@ from tensorrt_llm.functional import PositionEmbeddingType

 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
-from ..distributed import ParallelConfig, TensorParallelMode
 from ..model_config import ModelConfig
 from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.gated_mlp import GatedMLP
-from ..modules.linear import Linear
+from ..modules.linear import Linear, TensorParallelMode
 from ..modules.rms_norm import RMSNorm
 from ..modules.rotary_embedding import RotaryEmbedding
 from ..pipeline_interface import PipelineInterface
@@ -147,13 +146,9 @@ class QwenModel(DecoderModel):
             config.pretrained_config.vocab_size,
             config.pretrained_config.hidden_size,
             dtype=config.pretrained_config.torch_dtype,
-            parallel_config=ParallelConfig(
-                tensor_parallel_rank=config.mapping.tp_rank,
-                tensor_parallel_size=config.mapping.tp_size,
-                tensor_parallel_mode=TensorParallelMode.COLUMN,
-                gather_output=True,
-                gpus_per_node=config.mapping.gpus_per_node,
-            ),
+            mapping=config.mapping,
+            tensor_parallel_mode=TensorParallelMode.COLUMN,
+            gather_output=True,
         )
         self.layers = nn.ModuleList([
             QwenDecoderLayer(
@@ -10,14 +10,13 @@ from tensorrt_llm.functional import PositionEmbeddingType

 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
-from ..distributed import ParallelConfig, TensorParallelMode
 from ..model_config import ModelConfig
 from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import DefaultMoeRoutingMethod, FusedMoE
 from ..modules.gated_mlp import GatedMLP
-from ..modules.linear import Linear
+from ..modules.linear import Linear, TensorParallelMode
 from ..modules.rms_norm import RMSNorm
 from ..modules.rotary_embedding import RotaryEmbedding
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -175,11 +174,6 @@ class QwenMoeDecoderLayer(DecoderLayer):
         self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                                 eps=config.rms_norm_eps,
                                                 dtype=config.torch_dtype)
-        # TODO: add pipeline parallel config
-        self.parallel_config = ParallelConfig(
-            tensor_parallel_rank=model_config.mapping.tp_rank,
-            tensor_parallel_size=model_config.mapping.tp_size,
-            gpus_per_node=model_config.mapping.gpus_per_node)
         self.layer_idx = layer_idx

     def forward(
@@ -224,13 +218,9 @@ class QwenMoeModel(DecoderModel):
             config.pretrained_config.vocab_size,
             config.pretrained_config.hidden_size,
             dtype=config.pretrained_config.torch_dtype,
-            parallel_config=ParallelConfig(
-                tensor_parallel_rank=config.mapping.tp_rank,
-                tensor_parallel_size=config.mapping.tp_size,
-                tensor_parallel_mode=TensorParallelMode.COLUMN,
-                gather_output=True,
-                gpus_per_node=config.mapping.gpus_per_node,
-            ),
+            mapping=config.mapping,
+            tensor_parallel_mode=TensorParallelMode.COLUMN,
+            gather_output=True,
         )
         self.layers = nn.ModuleList([
             QwenMoeDecoderLayer(
@@ -11,14 +11,15 @@ from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_any_only
 from tqdm import tqdm

+from tensorrt_llm.mapping import Mapping
+
 from ...logger import logger
 from ..attention_backend import AttentionMetadata
-from ..distributed import ParallelConfig, TensorParallelMode
 from ..model_config import ModelConfig, TConfig
 from ..modules.attention import Attention
 from ..modules.embedding import Embedding, LMHead
 from ..modules.fused_moe import FusedMoE
-from ..modules.linear import Linear, WeightMode
+from ..modules.linear import Linear, TensorParallelMode, WeightMode
 from ..modules.logits_procesor import LogitsProcessor
 from ..modules.rms_norm import RMSNorm
 from ..pipeline_interface import PipelineInterface
@@ -356,14 +357,13 @@ class DecoderModelForCausalLM(nn.Module,
                 vocab_size,
                 hidden_size,
                 dtype=config.pretrained_config.torch_dtype,
-                parallel_config=ParallelConfig(
-                    tensor_parallel_rank=0,
-                    tensor_parallel_size=1,
-                    tensor_parallel_mode=None,
-                    gather_output=False,
-                    pipeline_parallel_size=config.mapping.pp_size,
-                    parallel_rank=config.mapping.rank,
+                mapping=Mapping(
+                    world_size=1,
+                    tp_size=1,
+                    rank=0,
                 ),
+                tensor_parallel_mode=None,
+                gather_output=False,
             )
         else:
             # TODO(zhenhuanc): Currently lm_head Linear will not accept QuantConfig
@@ -372,15 +372,9 @@ class DecoderModelForCausalLM(nn.Module,
                 vocab_size,
                 hidden_size,
                 dtype=config.pretrained_config.torch_dtype,
-                parallel_config=ParallelConfig(
-                    tensor_parallel_size=config.mapping.tp_size,
-                    tensor_parallel_rank=config.mapping.tp_rank,
-                    tensor_parallel_mode=TensorParallelMode.COLUMN,
-                    gather_output=True,
-                    gpus_per_node=config.mapping.gpus_per_node,
-                    pipeline_parallel_size=config.mapping.pp_size,
-                    parallel_rank=config.mapping.rank,
-                ),
+                mapping=config.mapping,
+                tensor_parallel_mode=TensorParallelMode.COLUMN,
+                gather_output=True,
             )

         # use embedding weights in lm_head if tie word embedding is enabled
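The two lm_head branches above now differ only in the Mapping they pass: one a placeholder single-rank Mapping with tensor parallelism disabled, the other the model's real mapping with COLUMN parallelism and gather_output=True. The placeholder, taken directly from the diff:

    from tensorrt_llm.mapping import Mapping

    placeholder = Mapping(world_size=1, tp_size=1, rank=0)
    # The other branch instead passes: mapping=config.mapping,
    # tensor_parallel_mode=TensorParallelMode.COLUMN, gather_output=True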
@@ -4,15 +4,17 @@ from typing import Optional
 import torch
 from torch import nn

+from tensorrt_llm.mapping import Mapping
+
 from ..attention_backend import (AttentionInputType, AttentionMetadata,
                                  TrtllmAttention)
 from ..attention_backend.interface import (PositionalEmbeddingParams,
                                            PredefinedAttentionMask)
 from ..attention_backend.utils import create_attention
-from ..distributed import AllReduceParams, ParallelConfig, TensorParallelMode
+from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
-from .linear import Linear, WeightMode, WeightsLoadingConfig
+from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .rms_norm import RMSNorm
 from .rotary_embedding import RotaryEmbedding

@@ -80,12 +82,18 @@ class Attention(nn.Module):
         # tensor parallel
         config = config or ModelConfig()
         tp_size = config.mapping.tp_size
-        tp_rank = config.mapping.tp_rank
-        gpus_per_node = config.mapping.gpus_per_node
+        pp_size = config.mapping.pp_size
         if config.mapping.enable_attention_dp:
             tp_size = 1
-            tp_rank = 0

+        mapping = Mapping(
+            world_size=tp_size * pp_size,
+            tp_size=tp_size,
+            pp_size=pp_size,
+            rank=config.mapping.rank,
+            gpus_per_node=config.mapping.gpus_per_node,
+            enable_attention_dp=config.mapping.enable_attention_dp,
+        )
         assert self.num_heads % tp_size == 0
         self.num_heads = self.num_heads // tp_size
         self.num_key_value_heads = (self.num_key_value_heads + tp_size -
@@ -103,13 +111,8 @@
             tp_size * self.q_size + 2 * tp_size * self.kv_size,
             bias=bias,
             dtype=dtype,
-            parallel_config=ParallelConfig(
-                tensor_parallel_size=tp_size,
-                tensor_parallel_rank=tp_rank,
-                tensor_parallel_mode=TensorParallelMode.COLUMN,
-                gpus_per_node=gpus_per_node,
-                pipeline_parallel_size=config.mapping.pp_size,
-                parallel_rank=config.mapping.rank),
+            mapping=mapping,
+            tensor_parallel_mode=TensorParallelMode.COLUMN,
             weights_loading_config=WeightsLoadingConfig(
                 weight_mode=WeightMode.FUSED_QKV_LINEAR),
             quant_config=config.get_quant_config(),
@@ -120,13 +123,8 @@
             self.hidden_size,
             bias=self.dense_bias,
             dtype=dtype,
-            parallel_config=ParallelConfig(
-                tensor_parallel_size=tp_size,
-                tensor_parallel_rank=tp_rank,
-                tensor_parallel_mode=TensorParallelMode.ROW,
-                gpus_per_node=gpus_per_node,
-                pipeline_parallel_size=config.mapping.pp_size,
-                parallel_rank=config.mapping.rank),
+            mapping=mapping,
+            tensor_parallel_mode=TensorParallelMode.ROW,
             quant_config=config.get_quant_config(),
             skip_create_weights=config.skip_create_weights,
         )
@@ -310,27 +308,17 @@ class MLA(nn.Module):
         # tensor parallel
         config = config or ModelConfig()
         tp_size = config.mapping.tp_size
-        tp_rank = config.mapping.tp_rank
-        gpus_per_node = config.mapping.gpus_per_node
+        pp_size = config.mapping.pp_size
         if config.mapping.enable_attention_dp:
             tp_size = 1
-            tp_rank = 0

-        row_parallel_config = ParallelConfig(
-            tensor_parallel_rank=tp_rank,
-            tensor_parallel_size=tp_size,
-            tensor_parallel_mode=TensorParallelMode.ROW,
-            gpus_per_node=gpus_per_node,
-            pipeline_parallel_size=config.mapping.pp_size,
-            parallel_rank=config.mapping.rank,
-        )
-        col_parallel_config = ParallelConfig(
-            tensor_parallel_rank=tp_rank,
-            tensor_parallel_size=tp_size,
-            tensor_parallel_mode=TensorParallelMode.COLUMN,
-            gpus_per_node=gpus_per_node,
-            pipeline_parallel_size=config.mapping.pp_size,
-            parallel_rank=config.mapping.rank,
+        mapping = Mapping(
+            world_size=tp_size * pp_size,
+            tp_size=tp_size,
+            pp_size=pp_size,
+            rank=config.mapping.rank,
+            gpus_per_node=config.mapping.gpus_per_node,
+            enable_attention_dp=config.mapping.enable_attention_dp,
         )

         assert self.num_heads % tp_size == 0
@@ -362,7 +350,8 @@
                              (self.qk_nope_head_dim + self.qk_rope_head_dim),
                              bias=bias,
                              dtype=dtype,
-                             parallel_config=col_parallel_config,
+                             mapping=mapping,
+                             tensor_parallel_mode=TensorParallelMode.COLUMN,
                              quant_config=quant_config,
                              skip_create_weights=config.skip_create_weights)
         else:
@@ -381,7 +370,8 @@
                              (self.qk_nope_head_dim + self.qk_rope_head_dim),
                              bias=bias,
                              dtype=dtype,
-                             parallel_config=col_parallel_config,
+                             mapping=mapping,
+                             tensor_parallel_mode=TensorParallelMode.COLUMN,
                              quant_config=quant_config,
                              skip_create_weights=config.skip_create_weights)
             self.q_b_proj = self.q_proj
@@ -400,7 +390,8 @@
                              (self.qk_nope_head_dim + self.v_head_dim),
                              bias=bias,
                              dtype=dtype,
-                             parallel_config=col_parallel_config,
+                             mapping=mapping,
+                             tensor_parallel_mode=TensorParallelMode.COLUMN,
                              quant_config=quant_config,
                              skip_create_weights=config.skip_create_weights)
         # This parameter will view into self.kv_b_proj.weight after loading weights.
@@ -455,7 +446,8 @@
                              self.hidden_size,
                              bias=self.dense_bias,
                              dtype=dtype,
-                             parallel_config=row_parallel_config,
+                             mapping=mapping,
+                             tensor_parallel_mode=TensorParallelMode.ROW,
                              quant_config=quant_config,
                              skip_create_weights=config.skip_create_weights,
         )
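The Mapping built inside Attention and MLA above collapses TP when attention data parallelism is enabled. A standalone sketch with the DP case already applied (tp_size forced to 1; all values illustrative):

    from tensorrt_llm.mapping import Mapping

    tp_size, pp_size = 1, 1  # tp_size collapsed to 1 under attention DP
    mapping = Mapping(
        world_size=tp_size * pp_size,
        tp_size=tp_size,
        pp_size=pp_size,
        rank=0,
        gpus_per_node=8,
        enable_attention_dp=True,
    )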
@@ -6,9 +6,10 @@ import torch.nn.functional as F
 from torch.nn.parameter import Parameter

 from tensorrt_llm.functional import AllReduceParams
+from tensorrt_llm.mapping import Mapping

-from ..distributed import ParallelConfig, TensorParallelMode, allgather
-from .linear import Linear
+from ..distributed import allgather
+from .linear import Linear, TensorParallelMode


 class LMHead(Linear):
@@ -18,7 +19,7 @@ class LMHead(Linear):
         num_embeddings (int): vocabulary size.
         embedding_dim (int): size of hidden state.
         dtype (Optional[torch.dtype]): type of the parameters.
-        parallel_config (Optional[ParallelConfig]): parallelism configuration.
+        mapping (Optional[Mapping]): parallelism configuration.
             If not provided, the embedding is not parallelized.
     """

@@ -27,17 +28,19 @@
         num_embeddings: int,
         embedding_dim: int,
         dtype: torch.dtype = None,
-        parallel_config: Optional[ParallelConfig] = None,
+        mapping: Optional[Mapping] = None,
+        tensor_parallel_mode: Optional[TensorParallelMode] = None,
+        gather_output: bool = False,
     ):
         local_in_features = embedding_dim
         local_out_features = num_embeddings
-        parallel_config = parallel_config or ParallelConfig()
-        tp_size = parallel_config.tensor_parallel_size
+        mapping = mapping or Mapping()
+        tp_size = mapping.tp_size

-        if parallel_config.tensor_parallel_mode == TensorParallelMode.ROW:
+        if tensor_parallel_mode == TensorParallelMode.ROW:
             local_in_features = math.ceil(embedding_dim / tp_size)
             self.padding_size = tp_size * local_in_features - embedding_dim
-        elif parallel_config.tensor_parallel_mode == TensorParallelMode.COLUMN:
+        elif tensor_parallel_mode == TensorParallelMode.COLUMN:
             local_out_features = math.ceil(num_embeddings / tp_size)
             self.padding_size = tp_size * local_out_features - num_embeddings

@@ -46,10 +49,12 @@
             local_out_features * tp_size,
             bias=False,
             dtype=dtype,
-            parallel_config=parallel_config,
+            mapping=mapping,
+            tensor_parallel_mode=tensor_parallel_mode,
+            gather_output=gather_output,
         )

-        if parallel_config.tensor_parallel_mode == TensorParallelMode.ROW:
+        if tensor_parallel_mode == TensorParallelMode.ROW:
             if self.tp_rank == self.tp_size - 1:
                 local_in_features -= self.padding_size
             self.in_features = local_in_features
@@ -61,7 +66,7 @@

     @property
     def vocab_size_padded(self) -> int:
-        if self.parallel_config.tensor_parallel_mode == TensorParallelMode.COLUMN:
+        if self.tp_mode == TensorParallelMode.COLUMN:
             return self.out_features * self.tp_size
         else:
             return self.out_features
@@ -73,8 +78,7 @@
         all_reduce_params: Optional[AllReduceParams] = None
     ) -> torch.Tensor:
         output = super().forward(input, all_reduce_params=all_reduce_params)
-        if (self.tp_mode == TensorParallelMode.COLUMN
-                and self.parallel_config.gather_output
+        if (self.tp_mode == TensorParallelMode.COLUMN and self.gather_output
                 and self.padding_size > 0):
             output = output[..., :-self.padding_size]

@@ -82,7 +86,7 @@

     def load_weights(self, weights: List[Dict]):
         original_weight = None
-        if self.parallel_config.tensor_parallel_mode == TensorParallelMode.COLUMN:
+        if self.tp_mode == TensorParallelMode.COLUMN:
             if self.tp_rank == self.tp_size - 1 and self.padding_size > 0:
                 original_weight = self.weight.data.zero_()
                 self.weight.data = self.weight[:-self.padding_size, :]
@@ -113,7 +117,7 @@ class Embedding(LMHead):
         num_embeddings (int): vocabulary size.
         embedding_dim (int): size of hidden state.
         dtype (Optional[torch.dtype]): type of the parameters.
-        parallel_config (Optional[ParallelConfig]): parallelism configuration.
+        mapping (Optional[Mapping]): parallelism configuration.
             If not provided, the embedding is not parallelized.
     """

@@ -122,13 +126,17 @@
         num_embeddings: int,
         embedding_dim: int,
         dtype: Optional[torch.dtype] = None,
-        parallel_config: Optional[ParallelConfig] = None,
+        mapping: Optional[Mapping] = None,
+        tensor_parallel_mode: Optional[TensorParallelMode] = None,
+        gather_output: bool = False,
     ):
         super().__init__(
             embedding_dim=embedding_dim,
             num_embeddings=num_embeddings,
             dtype=dtype,
-            parallel_config=parallel_config,
+            mapping=mapping,
+            tensor_parallel_mode=tensor_parallel_mode,
+            gather_output=gather_output,
         )
         if self.tp_size > 1:
             slice_width = math.ceil(num_embeddings / self.tp_size)
@@ -138,7 +146,7 @@

     def forward(self, input):
         if self.tp_size > 1:
-            if self.parallel_config.tensor_parallel_mode == TensorParallelMode.COLUMN:
+            if self.tp_mode == TensorParallelMode.COLUMN:
                 # Build the mask.
                 input, input_mask = get_masked_input_and_mask(
                     input,
@@ -149,15 +157,15 @@
         output = F.embedding(input, self.weight)
         # Mask the output embedding.
         if self.tp_size > 1:
-            if self.parallel_config.tensor_parallel_mode == TensorParallelMode.COLUMN:
+            if self.tp_mode == TensorParallelMode.COLUMN:
                 output.masked_fill_(input_mask, 0)
                 # Reduce across all the model parallel GPUs.
                 output = self.all_reduce(output)
-            elif self.parallel_config.tensor_parallel_mode == TensorParallelMode.ROW:
-                if self.parallel_config.gather_output:
+            elif self.tp_mode == TensorParallelMode.ROW:
+                if self.gather_output:
                     if self.tp_rank == self.tp_size - 1 and self.padding_size > 0:
                         output = F.pad(output, (0, self.padding_size))
-                    output = allgather(output, self.parallel_config)
+                    output = allgather(output, self.mapping)
                     if self.padding_size > 0:
                         output = output[..., :-self.padding_size]

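A hedged construction sketch for the reworked Embedding signature (keywords as in the diff; a single-rank Mapping keeps the example standalone, though weight allocation may still expect a GPU):

    import torch

    from tensorrt_llm._torch.modules.embedding import Embedding
    from tensorrt_llm._torch.modules.linear import TensorParallelMode
    from tensorrt_llm.mapping import Mapping

    embed = Embedding(
        num_embeddings=32000, embedding_dim=4096,  # illustrative sizes
        dtype=torch.float16,
        mapping=Mapping(world_size=1, tp_size=1, rank=0),
        tensor_parallel_mode=TensorParallelMode.COLUMN,
    )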
@@ -10,7 +10,7 @@ from ..distributed import allgather, reducescatter
 from ..model_config import ModelConfig
 from ..utils import (EventType, Fp4QuantizedTensor, disable_fp4_allgather,
                      reswizzle_sf)
-from .linear import ParallelConfig, TensorParallelMode, load_weight_shard
+from .linear import TensorParallelMode, load_weight_shard

 # The declarations aligns with moe_kernels.h
 # pack inputs into int64, e.g. 4 x bf16 input values
@@ -270,15 +270,10 @@ class FusedMoE(nn.Module):
         self.use_dp = model_config.mapping.enable_attention_dp

         # All ranks participate in allreduce regardless of EP/TP combination
-        self.parallel_config = ParallelConfig(
-            tensor_parallel_rank=model_config.mapping.tp_rank,
-            tensor_parallel_size=model_config.mapping.tp_size,
-            gpus_per_node=model_config.mapping.gpus_per_node,
-            pipeline_parallel_size=model_config.mapping.pp_size,
-            parallel_rank=model_config.mapping.rank)
-        self.parallel_size = self.parallel_config.tensor_parallel_size
+        self.mapping = model_config.mapping
+        self.parallel_size = self.mapping.tp_size

-        self.all_reduce = AllReduce(self.parallel_config)
+        self.all_reduce = AllReduce(self.mapping)

         self.intermediate_size_per_partition = intermediate_size // self.tp_size

@@ -510,7 +505,7 @@

         flatten_outputs = allgather(
             torch.cat(flatten_inputs),
-            self.parallel_config,
+            self.mapping,
             gather_dim=0,
         ).view(self.parallel_size, -1)

@@ -532,9 +527,7 @@
         outputs = inputs
         if self.parallel_size > 1:
             if self.use_dp:
-                outputs = reducescatter(inputs,
-                                        self.parallel_config,
-                                        scatter_dim=0)
+                outputs = reducescatter(inputs, self.mapping, scatter_dim=0)
             elif self.reduce_results:
                 outputs = self.all_reduce(inputs)
         return outputs
@@ -595,7 +588,7 @@
         ):
             x_sf, token_selected_experts, token_final_scales = self.all_gather(
                 [x_sf, token_selected_experts, token_final_scales])
-            x = allgather(x, self.parallel_config, gather_dim=0)
+            x = allgather(x, self.mapping, gather_dim=0)
             token_selected_experts = token_selected_experts.flatten(
                 0, 1).contiguous()
             token_final_scales = token_final_scales.flatten(0, 1).contiguous()
@@ -701,7 +694,7 @@
             self.event_dict[EventType.MoeChunkingOverlap].wait()
         outputs = torch.cat(outputs_list)
         if self.use_dp:
-            rank = self.parallel_config.tensor_parallel_rank
+            rank = self.mapping.tp_rank
             outputs = outputs[:all_rank_num_tokens[rank]]
         return outputs

@@ -5,13 +5,14 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from tensorrt_llm._torch.peft.lora.layer import LoraLayer, LoraModuleType
+from tensorrt_llm.mapping import Mapping

 from ..custom_ops import IS_FLASHINFER_AVAIABLE
-from ..distributed import AllReduceParams, ParallelConfig, TensorParallelMode
+from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
+from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import Fp4QuantizedTensor
-from .linear import Linear, WeightMode, WeightsLoadingConfig
+from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig


 def swiglu(x):
@@ -44,30 +45,31 @@ class GatedMLP(nn.Module):
         self.activation = activation

         config = config or ModelConfig()
+        self.mapping = config.mapping
         if overridden_tp_size is not None:
             assert config.mapping.tp_size % overridden_tp_size == 0
             tp_rank = config.mapping.tp_rank % overridden_tp_size
             tp_size = overridden_tp_size
             # "Misuse" pp_size here to perform all-reduce within smaller groups
             pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
         else:
             tp_rank = config.mapping.tp_rank
             tp_size = config.mapping.tp_size
             pp_size = config.mapping.pp_size
-        gpus_per_node = config.mapping.gpus_per_node
+
+        mapping = Mapping(
+            world_size=tp_size * pp_size,
+            rank=self.mapping.rank,
+            gpus_per_node=self.mapping.gpus_per_node,
+            tp_size=tp_size,
+            pp_size=pp_size,
+        )

         self.gate_up_proj = Linear(
             self.hidden_size,
             self.intermediate_size * 2,
             bias=bias,
             dtype=dtype,
-            parallel_config=ParallelConfig(
-                tensor_parallel_rank=tp_rank,
-                tensor_parallel_size=tp_size,
-                tensor_parallel_mode=TensorParallelMode.COLUMN,
-                gpus_per_node=gpus_per_node,
-                pipeline_parallel_size=pp_size,
-                parallel_rank=config.mapping.rank),
+            mapping=mapping,
+            tensor_parallel_mode=TensorParallelMode.COLUMN,
             weights_loading_config=WeightsLoadingConfig(
                 weight_mode=WeightMode.FUSED_GATE_UP_LINEAR),
             quant_config=config.get_quant_config(),
@@ -79,13 +81,8 @@
             self.hidden_size,
             bias=bias,
             dtype=dtype,
-            parallel_config=ParallelConfig(
-                tensor_parallel_rank=tp_rank,
-                tensor_parallel_size=tp_size,
-                tensor_parallel_mode=TensorParallelMode.ROW,
-                gpus_per_node=gpus_per_node,
-                pipeline_parallel_size=pp_size,
-                parallel_rank=config.mapping.rank),
+            mapping=mapping,
+            tensor_parallel_mode=TensorParallelMode.ROW,
             quant_config=config.get_quant_config(),
             is_expert=is_expert,
             skip_create_weights=config.skip_create_weights,
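The overridden_tp_size branch above shrinks the TP group and "misuses" pp_size to absorb the leftover factor, so the derived Mapping performs all-reduce within smaller groups while the world size stays fixed. The arithmetic, with illustrative values:

    tp, pp, override = 8, 1, 2           # original TP/PP and the override
    tp_size = override                   # 2-way TP inside each small group
    pp_size = pp * tp // override        # 4 disjoint groups along the "pp" axis
    assert tp_size * pp_size == tp * pp  # world size unchanged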
@ -10,9 +10,9 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
|
||||
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
from ...models.modeling_utils import QuantConfig
|
||||
from ..distributed import ParallelConfig, TensorParallelMode
|
||||
from ..utils import Fp4QuantizedTensor
|
||||
|
||||
E2M1_MAX = 6.0
|
||||
@ -33,6 +33,15 @@ class WeightsLoadingConfig:
|
||||
ignore_tensor_parallel: bool = False
|
||||
|
||||
|
||||
class TensorParallelMode(str, enum.Enum):
|
||||
COLUMN = 'column'
|
||||
ROW = 'row'
|
||||
|
||||
@classmethod
|
||||
def split_dim(cls, mode):
|
||||
return 1 if mode == cls.ROW else 0
|
||||
|
||||
|
def load_weight_shard(
    weight,
    tensor_parallel_size: int = 1,
@ -135,7 +144,9 @@ class Linear(nn.Module):
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        parallel_config: Optional[ParallelConfig] = None,
        mapping: Optional[Mapping] = None,
        tensor_parallel_mode: Optional[TensorParallelMode] = None,
        gather_output: bool = False,
        quant_config: Optional[QuantConfig] = None,
        weights_loading_config: Optional[WeightsLoadingConfig] = None,
        is_expert: bool = False,
@ -147,38 +158,37 @@ class Linear(nn.Module):
        super().__init__()
        self.has_bias = bias
        self.dtype = dtype
        self.parallel_config = parallel_config or ParallelConfig()
        self.mapping = mapping or Mapping()
        # could be modified later
        self.quant_config = quant_config
        self.weights_loading_config = weights_loading_config or WeightsLoadingConfig(
        )
        self.tp_size = self.parallel_config.tensor_parallel_size
        self.tp_rank = self.parallel_config.tensor_parallel_rank
        self.tp_mode = self.parallel_config.tensor_parallel_mode
        self.tp_size = self.mapping.tp_size
        self.tp_rank = self.mapping.tp_rank
        self.tp_mode = tensor_parallel_mode
        self.gather_output = gather_output

        local_in_features = in_features
        local_out_features = out_features

        if self.parallel_config.tensor_parallel_mode == TensorParallelMode.ROW:
        if self.tp_mode == TensorParallelMode.ROW:
            assert in_features % self.tp_size == 0, (
                f'in_features {in_features} must be divisible by tp_size {self.tp_size}'
            )
            local_in_features = in_features // self.tp_size
        elif self.parallel_config.tensor_parallel_mode == TensorParallelMode.COLUMN:
        elif self.tp_mode == TensorParallelMode.COLUMN:
            assert out_features % self.tp_size == 0, (
                f'out_features {out_features} must be divisible by tp_size {self.tp_size}'
            )
            local_out_features = out_features // self.tp_size
        else:
            assert self.parallel_config.tensor_parallel_mode is None, (
                'unsupported tensor parallel mode: {self.parallel_config.tensor_parallel_mode}'
            )
            assert self.tp_mode is None, (
                f'unsupported tensor parallel mode: {self.tp_mode}')

        self.in_features = local_in_features
        self.out_features = local_out_features

        self.all_reduce = AllReduce(
            self.parallel_config) if not is_expert else None
        self.all_reduce = AllReduce(self.mapping) if not is_expert else None
        self._weights_created = False
        self.is_expert = is_expert
        self.use_custom_cublas_mm = use_custom_cublas_mm
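For callers, the migration is mechanical: the per-layer rank/size/mode bookkeeping of `ParallelConfig` collapses into the `Mapping` a model already owns, plus an explicit `tensor_parallel_mode`. A minimal sketch of the new constructor shape, using placeholder sizes and a single-rank `Mapping` (real sharded execution launches one process per rank under MPI; this only illustrates the local shape bookkeeping from the `__init__` above):

import torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode

mapping = Mapping(world_size=2, tp_size=2, rank=0)
proj = Linear(
    1024,                      # in_features (placeholder)
    4096,                      # out_features (placeholder)
    bias=False,
    dtype=torch.float16,
    mapping=mapping,
    tensor_parallel_mode=TensorParallelMode.COLUMN,
)
# Per the __init__ above, each rank holds a 1/tp_size slice of out_features:
assert proj.out_features == 4096 // mapping.tp_size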
@ -385,8 +395,8 @@ class Linear(nn.Module):
            output = self.apply_linear(input, self.weight, bias)
        elif self.tp_mode == TensorParallelMode.COLUMN:
            output = self.apply_linear(input, self.weight, self.bias)
            if self.parallel_config.gather_output:
                output = allgather(output, self.parallel_config)
            if self.gather_output:
                output = allgather(output, self.mapping)
        else:
            output = self.apply_linear(input, self.weight, self.bias)
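The collective in each branch follows from how the matmul was split: COLUMN shards the output dimension, so ranks hold disjoint slices that the optional all-gather concatenates; ROW shards the input dimension, so each rank produces a partial sum that an all-reduce must add. A single-process torch check of both identities:

import torch

torch.manual_seed(0)
x = torch.randn(3, 8)               # (tokens, in_features)
w = torch.randn(4, 8)               # (out_features, in_features)
full = x @ w.T                      # reference output, shape (3, 4)

# COLUMN: split out_features; concatenation stands in for all-gather.
w0, w1 = torch.chunk(w, 2, dim=0)
col = torch.cat([x @ w0.T, x @ w1.T], dim=-1)

# ROW: split in_features (and x to match); addition stands in for all-reduce.
wa, wb = torch.chunk(w, 2, dim=1)
xa, xb = torch.chunk(x, 2, dim=1)
row = xa @ wa.T + xb @ wb.T

torch.testing.assert_close(col, full)
torch.testing.assert_close(row, full)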
@ -5,9 +5,8 @@ import torch
from torch import nn

from ..attention_backend import AttentionMetadata
from ..distributed import ParallelConfig, TensorParallelMode
from ..model_config import ModelConfig
from .linear import Linear
from .linear import Linear, TensorParallelMode


@dataclass
@ -166,9 +165,9 @@ class Mamba2(nn.Module):
        super().__init__()

        config = config or ModelConfig()
        self.mapping = config.mapping
        tp_rank = config.mapping.tp_rank
        tp_size = config.mapping.tp_size
        gpus_per_node = config.mapping.gpus_per_node

        d_inner = d_model * expand
        nheads = d_inner // head_dim
@ -205,11 +204,8 @@ class Mamba2(nn.Module):
            d_in_proj,
            bias=bias,
            dtype=dtype,
            parallel_config=ParallelConfig(
                tensor_parallel_rank=tp_rank,
                tensor_parallel_size=tp_size,
                tensor_parallel_mode=TensorParallelMode.COLUMN,
                gpus_per_node=gpus_per_node),
            mapping=self.mapping,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            quant_config=config.get_quant_config(),
        )

@ -219,11 +215,8 @@ class Mamba2(nn.Module):
            conv_dim,
            bias=conv_bias,
            dtype=dtype,
            parallel_config=ParallelConfig(
                tensor_parallel_rank=tp_rank,
                tensor_parallel_size=tp_size,
                tensor_parallel_mode=TensorParallelMode.COLUMN,
                gpus_per_node=gpus_per_node),
            mapping=self.mapping,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            quant_config=config.get_quant_config(),
            skip_create_weights=config.skip_create_weights,
        )
@ -258,11 +251,8 @@ class Mamba2(nn.Module):
            d_model,
            bias=bias,
            dtype=dtype,
            parallel_config=ParallelConfig(
                tensor_parallel_rank=tp_rank,
                tensor_parallel_size=tp_size,
                tensor_parallel_mode=TensorParallelMode.ROW,
                gpus_per_node=gpus_per_node),
            mapping=self.mapping,
            tensor_parallel_mode=TensorParallelMode.ROW,
            quant_config=config.get_quant_config())

    def forward(
@ -4,10 +4,9 @@ from typing import Optional
import torch
from torch import nn

from ..distributed import ParallelConfig, TensorParallelMode
from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from .linear import Linear, WeightMode, WeightsLoadingConfig
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig


class MLP(nn.Module):
@ -29,40 +28,25 @@ class MLP(nn.Module):
        self.activation = activation

        config = config or ModelConfig()
        tp_rank = config.mapping.tp_rank
        tp_size = config.mapping.tp_size
        gpus_per_node = config.mapping.gpus_per_node
        self.up_proj = Linear(
            self.hidden_size,
            self.intermediate_size,
            bias=bias,
            dtype=dtype,
            parallel_config=ParallelConfig(
                tensor_parallel_rank=tp_rank,
                tensor_parallel_size=tp_size,
                tensor_parallel_mode=TensorParallelMode.COLUMN,
                gpus_per_node=gpus_per_node,
                pipeline_parallel_size=config.mapping.pp_size,
                parallel_rank=config.mapping.rank),
            weights_loading_config=WeightsLoadingConfig(
                weight_mode=WeightMode.VANILLA),
            quant_config=config.get_quant_config(),
            skip_create_weights=config.skip_create_weights)
        self.up_proj = Linear(self.hidden_size,
                              self.intermediate_size,
                              bias=bias,
                              dtype=dtype,
                              mapping=config.mapping,
                              tensor_parallel_mode=TensorParallelMode.COLUMN,
                              weights_loading_config=WeightsLoadingConfig(
                                  weight_mode=WeightMode.VANILLA),
                              quant_config=config.get_quant_config(),
                              skip_create_weights=config.skip_create_weights)

        self.down_proj = Linear(
            self.intermediate_size,
            self.hidden_size,
            bias=bias,
            dtype=dtype,
            parallel_config=ParallelConfig(
                tensor_parallel_rank=tp_rank,
                tensor_parallel_size=tp_size,
                tensor_parallel_mode=TensorParallelMode.ROW,
                gpus_per_node=gpus_per_node,
                pipeline_parallel_size=config.mapping.pp_size,
                parallel_rank=config.mapping.rank),
            quant_config=config.get_quant_config(),
            skip_create_weights=config.skip_create_weights)
        self.down_proj = Linear(self.intermediate_size,
                                self.hidden_size,
                                bias=bias,
                                dtype=dtype,
                                mapping=config.mapping,
                                tensor_parallel_mode=TensorParallelMode.ROW,
                                quant_config=config.get_quant_config(),
                                skip_create_weights=config.skip_create_weights)

        self.up_lora = LoraLayer([LoraModuleType.MLP_H_TO_4H],
                                 [self.intermediate_size])

@ -191,9 +191,9 @@ class Mapping(object):
        self.attn_cp_size = attn_cp_size
        self.auto_parallel = auto_parallel
        self.world_size = world_size
        self.enable_attention_dp = enable_attention_dp
        self.rank = rank
        self.gpus_per_node = gpus_per_node
        self.enable_attention_dp = enable_attention_dp
        self.pp_groups = []
        self.cp_groups = []
        self.tp_groups = []
@ -262,10 +262,13 @@ class Mapping(object):

    @rank.setter
    def rank(self, rank: int):
        if not isinstance(rank, int) or rank < 0 or rank >= self.world_size:
            raise ValueError(
                f"Rank should be an integer between 0 and {self.world_size-1}, but got {rank}."
            )
        # TODO(qijun): skip check for enable_attention_dp temporarily, will support attention_dp_size
        if not self.enable_attention_dp:
            if not isinstance(rank,
                              int) or rank < 0 or rank >= self.world_size:
                raise ValueError(
                    f"Rank should be an integer between 0 and {self.world_size-1}, but got {rank}."
                )
        self._rank = rank

    @property
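The ordering change above is load-bearing: `self.enable_attention_dp` must be assigned before `self.rank = rank`, because `rank` is a property whose setter (second hunk) now consults that flag before validating the value. A quick sketch of the resulting behavior, with illustrative values:

from tensorrt_llm.mapping import Mapping

m = Mapping(world_size=4, tp_size=4, rank=3)       # valid: 0 <= 3 < 4
try:
    Mapping(world_size=4, tp_size=4, rank=7)       # out of range
except ValueError as err:
    print(err)  # Rank should be an integer between 0 and 3, but got 7.
# With enable_attention_dp=True the bounds check is skipped entirely,
# per the TODO above.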
@ -28,9 +28,9 @@ from torch import nn

import tensorrt_llm
from tensorrt_llm._torch.compilation.backend import Backend
from tensorrt_llm._torch.distributed import ParallelConfig, TensorParallelMode
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm.mapping import Mapping

cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
@ -77,11 +77,10 @@ def row_linear_residual_norm_fusion_forward(
        out_features=hidden_size,
        bias=False,
        dtype=dtype,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tensor_parallel_size,
            tensor_parallel_rank=tensor_parallel_rank,
            tensor_parallel_mode=TensorParallelMode.ROW,
        ),
        tensor_parallel_mode=TensorParallelMode.ROW,
        mapping=Mapping(world_size=tensor_parallel_size,
                        tp_size=tensor_parallel_size,
                        rank=tensor_parallel_rank),
    ).cuda()
    norm = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype).cuda()

@ -26,10 +26,10 @@ from utils.util import skip_pre_blackwell

import tensorrt_llm
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                             AllReduceParams, DeepseekAllReduce,
                                             ParallelConfig, TensorParallelMode)
from tensorrt_llm._torch.modules.linear import Linear
                                             AllReduceParams, DeepseekAllReduce)
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm.mapping import Mapping

cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
@ -87,16 +87,16 @@ def row_linear_residual_norm_fusion_forward(

    norm = RMSNorm(hidden_size=hidden_size, eps=eps, dtype=dtype).cuda()

    allreduce = AllReduce(parallel_config=ParallelConfig(
        tensor_parallel_size=tensor_parallel_size,
        tensor_parallel_rank=tensor_parallel_rank,
        tensor_parallel_mode=TensorParallelMode.ROW,
    allreduce = AllReduce(mapping=Mapping(
        world_size=tensor_parallel_size,
        tp_size=tensor_parallel_size,
        rank=tensor_parallel_rank,
    ), ).cuda()

    deepseek_allreduce = DeepseekAllReduce(parallel_config=ParallelConfig(
        tensor_parallel_size=tensor_parallel_size,
        tensor_parallel_rank=tensor_parallel_rank,
        tensor_parallel_mode=TensorParallelMode.ROW,
    deepseek_allreduce = DeepseekAllReduce(mapping=Mapping(
        world_size=tensor_parallel_size,
        tp_size=tensor_parallel_size,
        rank=tensor_parallel_rank,
    ), ).cuda()

    scale = torch.tensor(1.0, dtype=torch.float32).cuda()
@ -268,10 +268,10 @@ def moe_residual_norm_fusion_forward(
    norm_weight = torch.randn((hidden_size, ), dtype=dtype, device="cuda")

    # Initialize DeepseekAllReduce and AllReduce
    deepseek_allreduce = DeepseekAllReduce(parallel_config=ParallelConfig(
        tensor_parallel_size=tensor_parallel_size,
        tensor_parallel_rank=tensor_parallel_rank,
        tensor_parallel_mode=TensorParallelMode.ROW,
    deepseek_allreduce = DeepseekAllReduce(mapping=Mapping(
        world_size=tensor_parallel_size,
        tp_size=tensor_parallel_size,
        rank=tensor_parallel_rank,
    )).cuda()

    # Initialize RMSNorm
@ -283,11 +283,12 @@ def moe_residual_norm_fusion_forward(
        out_features=hidden_size,
        bias=False,
        dtype=dtype,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tensor_parallel_size,
            tensor_parallel_rank=tensor_parallel_rank,
            tensor_parallel_mode=TensorParallelMode.ROW,
        mapping=Mapping(
            world_size=tensor_parallel_size,
            tp_size=tensor_parallel_size,
            rank=tensor_parallel_rank,
        ),
        tensor_parallel_mode=TensorParallelMode.ROW,
    ).cuda()
    l0.load_weights([dict(weight=l0_weight)])
    token_input_chunked = torch.chunk(token_input.clone(),
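The collectives follow the same pattern as `Linear`: `AllReduce` and `DeepseekAllReduce` now take the `Mapping` directly, and the old `tensor_parallel_mode` field simply disappears, since an all-reduce only needs to know its communication group, not how the producing layer was sharded. A sketch of the new construction, mirroring the test above (single-rank values shown for illustration; real runs launch one process per rank and need CUDA):

from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.distributed import AllReduce, DeepseekAllReduce

tp_size, rank = 2, 0   # illustrative
mapping = Mapping(world_size=tp_size, tp_size=tp_size, rank=rank)

# Both modules derive their communication group from the Mapping alone.
allreduce = AllReduce(mapping=mapping).cuda()
deepseek_allreduce = DeepseekAllReduce(mapping=mapping).cuda()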
@ -10,8 +10,9 @@ from mpi4py.futures import MPIPoolExecutor
from torch import nn

import tensorrt_llm
from tensorrt_llm._torch.distributed import ParallelConfig, TensorParallelMode
from tensorrt_llm._torch.modules.embedding import Embedding, LMHead
from tensorrt_llm._torch.modules.linear import TensorParallelMode
from tensorrt_llm.mapping import Mapping

cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
@ -44,10 +45,12 @@ def column_embedding_forward(x, vocab_size, hidden_size, dtype,
        num_embeddings=vocab_size,
        embedding_dim=hidden_size,
        dtype=dtype,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tensor_parallel_size,
            tensor_parallel_rank=tensor_parallel_rank,
            tensor_parallel_mode=TensorParallelMode.COLUMN),
        mapping=Mapping(
            world_size=tensor_parallel_size,
            tp_size=tensor_parallel_size,
            rank=tensor_parallel_rank,
        ),
        tensor_parallel_mode=TensorParallelMode.COLUMN,
    )
    embedding.load_weights([dict(weight=weight)])
    embedding.cuda()
@ -79,11 +82,13 @@ def row_embedding_forward(x, vocab_size, hidden_size, dtype,
        num_embeddings=vocab_size,
        embedding_dim=hidden_size,
        dtype=dtype,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tensor_parallel_size,
            tensor_parallel_rank=tensor_parallel_rank,
            tensor_parallel_mode=TensorParallelMode.ROW,
            gather_output=True),
        mapping=Mapping(
            world_size=tensor_parallel_size,
            tp_size=tensor_parallel_size,
            rank=tensor_parallel_rank,
        ),
        tensor_parallel_mode=TensorParallelMode.ROW,
        gather_output=True,
    )
    embedding.load_weights([dict(weight=weight)])
    embedding.cuda()
@ -115,11 +120,13 @@ def column_lm_head_forward(x, vocab_size, hidden_size, dtype,
        num_embeddings=vocab_size,
        embedding_dim=hidden_size,
        dtype=dtype,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tensor_parallel_size,
            tensor_parallel_rank=tensor_parallel_rank,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            gather_output=True),
        mapping=Mapping(
            world_size=tensor_parallel_size,
            tp_size=tensor_parallel_size,
            rank=tensor_parallel_rank,
        ),
        tensor_parallel_mode=TensorParallelMode.COLUMN,
        gather_output=True,
    )
    lm_head.load_weights([dict(weight=weight)])
    lm_head.cuda()
@ -152,10 +159,12 @@ def row_lm_head_forward(x, vocab_size, hidden_size, dtype, tensor_parallel_size,
        num_embeddings=vocab_size,
        embedding_dim=hidden_size,
        dtype=dtype,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tensor_parallel_size,
            tensor_parallel_rank=tensor_parallel_rank,
            tensor_parallel_mode=TensorParallelMode.ROW),
        mapping=Mapping(
            world_size=tensor_parallel_size,
            tp_size=tensor_parallel_size,
            rank=tensor_parallel_rank,
        ),
        tensor_parallel_mode=TensorParallelMode.ROW,
    )
    lm_head.load_weights([dict(weight=weight)])
    lm_head.cuda()
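The four embedding tests cover both sharding axes of the `(vocab_size, hidden_size)` table: COLUMN splits the vocabulary, so each rank owns a contiguous token-id range, while ROW splits the hidden dimension, which is why those variants pass `gather_output=True` to reassemble full-width activations. The per-rank shapes, assuming the `split_dim` convention from linear.py above (sizes are illustrative):

import torch

vocab_size, hidden_size, tp_size = 32000, 4096, 4
weight = torch.empty(vocab_size, hidden_size)

col_shard = torch.chunk(weight, tp_size, dim=0)[0]  # COLUMN: vocab range
print(col_shard.shape)   # torch.Size([8000, 4096])

row_shard = torch.chunk(weight, tp_size, dim=1)[0]  # ROW: hidden slice
print(row_shard.shape)   # torch.Size([32000, 1024])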
@ -10,9 +10,9 @@ from mpi4py.futures import MPIPoolExecutor
from torch import nn

import tensorrt_llm
from tensorrt_llm._torch.distributed import ParallelConfig, TensorParallelMode
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams
from tensorrt_llm.mapping import Mapping

cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
@ -52,23 +52,27 @@ def mlp_forward(x, hidden_size, dtype, tensor_parallel_size,
        out_features=4 * hidden_size,
        bias=False,
        dtype=dtype,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tensor_parallel_size,
            tensor_parallel_rank=tensor_parallel_rank,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
        mapping=Mapping(
            world_size=tensor_parallel_size,
            tp_size=tensor_parallel_size,
            rank=tensor_parallel_rank,
        ),
        tensor_parallel_mode=TensorParallelMode.COLUMN,
    )
    l0.load_weights([dict(weight=weights[0])])
    l0.cuda()
    l1 = Linear(in_features=4 * hidden_size,
                out_features=hidden_size,
                bias=False,
                dtype=dtype,
                parallel_config=ParallelConfig(
                    tensor_parallel_size=tensor_parallel_size,
                    tensor_parallel_rank=tensor_parallel_rank,
                    tensor_parallel_mode=TensorParallelMode.ROW,
                ))
    l1 = Linear(
        in_features=4 * hidden_size,
        out_features=hidden_size,
        bias=False,
        dtype=dtype,
        mapping=Mapping(
            world_size=tensor_parallel_size,
            tp_size=tensor_parallel_size,
            rank=tensor_parallel_rank,
        ),
        tensor_parallel_mode=TensorParallelMode.ROW,
    )
    l1.load_weights([dict(weight=weights[1])])
    l1.cuda()

@ -102,15 +106,19 @@ def column_linear_forward(x, hidden_size, dtype, tensor_parallel_size,
                          tensor_parallel_rank, weights):

    x = x.cuda()
    l0 = Linear(in_features=hidden_size,
                out_features=hidden_size,
                bias=False,
                dtype=dtype,
                parallel_config=ParallelConfig(
                    tensor_parallel_size=tensor_parallel_size,
                    tensor_parallel_rank=tensor_parallel_rank,
                    tensor_parallel_mode=TensorParallelMode.COLUMN,
                    gather_output=True))
    l0 = Linear(
        in_features=hidden_size,
        out_features=hidden_size,
        bias=False,
        dtype=dtype,
        mapping=Mapping(
            world_size=tensor_parallel_size,
            tp_size=tensor_parallel_size,
            rank=tensor_parallel_rank,
        ),
        tensor_parallel_mode=TensorParallelMode.COLUMN,
        gather_output=True,
    )
    l0.load_weights([dict(weight=weights[0])])
    l0.cuda()

@ -137,15 +145,18 @@ def row_linear_forward(x, hidden_size, dtype, tensor_parallel_size,
                       tensor_parallel_rank, weights):

    x = x.cuda()
    l0 = Linear(in_features=hidden_size,
                out_features=hidden_size,
                bias=False,
                dtype=dtype,
                parallel_config=ParallelConfig(
                    tensor_parallel_size=tensor_parallel_size,
                    tensor_parallel_rank=tensor_parallel_rank,
                    tensor_parallel_mode=TensorParallelMode.ROW,
                ))
    l0 = Linear(
        in_features=hidden_size,
        out_features=hidden_size,
        bias=False,
        dtype=dtype,
        mapping=Mapping(
            world_size=tensor_parallel_size,
            tp_size=tensor_parallel_size,
            rank=tensor_parallel_rank,
        ),
        tensor_parallel_mode=TensorParallelMode.ROW,
    )
    l0.load_weights([dict(weight=weights[0])])
    l0.cuda()

@ -183,15 +194,18 @@ def row_linear_norm_fusion_forward(x, hidden_size, dtype, tensor_parallel_size,
        eps=eps,
    )

    l0 = Linear(in_features=hidden_size,
                out_features=hidden_size,
                bias=False,
                dtype=dtype,
                parallel_config=ParallelConfig(
                    tensor_parallel_size=tensor_parallel_size,
                    tensor_parallel_rank=tensor_parallel_rank,
                    tensor_parallel_mode=TensorParallelMode.ROW,
                ))
    l0 = Linear(
        in_features=hidden_size,
        out_features=hidden_size,
        bias=False,
        dtype=dtype,
        mapping=Mapping(
            world_size=tensor_parallel_size,
            tp_size=tensor_parallel_size,
            rank=tensor_parallel_rank,
        ),
        tensor_parallel_mode=TensorParallelMode.ROW,
    )
    l0.load_weights([dict(weight=weights[0])])
    l0.cuda()

@ -16,10 +16,10 @@ import tensorrt_llm.quantization.utils.fp4_utils
from tensorrt_llm._torch.compilation.backend import Backend
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                             AllReduceParams, AllReduceStrategy,
                                             ParallelConfig, TensorParallelMode,
                                             userbuffers_allreduce_finalize)
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo

@ -123,11 +123,12 @@ def run_single_rank_ar_rms_norm(tensor_parallel_size, a, b, c, gamma):

    ub0_tensor = create_userbuffers_tensor(c.size(), a.dtype)
    hidden = torch.matmul(a_local, b_local, out=ub0_tensor)
    parallel_config = ParallelConfig(
        tensor_parallel_rank=rank,
        tensor_parallel_size=tensor_parallel_size,
        tensor_parallel_mode=TensorParallelMode.COLUMN)
    ar = AllReduce(parallel_config, strategy=AllReduceStrategy.UB)
    mapping = Mapping(
        world_size=tensor_parallel_size,
        tp_size=tensor_parallel_size,
        rank=rank,
    )
    ar = AllReduce(mapping, strategy=AllReduceStrategy.UB)
    ar_params = AllReduceParams(
        strategy=AllReduceStrategy.UB,
        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
@ -214,11 +215,12 @@ def run_single_rank_ar_rms_norm_fp8(tensor_parallel_size, a, b, c, gamma,

    ub0_tensor = create_userbuffers_tensor(c.size(), a.dtype)
    hidden = torch.matmul(a_local, b_local, out=ub0_tensor)
    parallel_config = ParallelConfig(
        tensor_parallel_rank=rank,
        tensor_parallel_size=tensor_parallel_size,
        tensor_parallel_mode=TensorParallelMode.COLUMN)
    ar = AllReduce(parallel_config, strategy=AllReduceStrategy.UB)
    mapping = Mapping(
        world_size=tensor_parallel_size,
        tp_size=tensor_parallel_size,
        rank=rank,
    )
    ar = AllReduce(mapping, strategy=AllReduceStrategy.UB)
    ar_params = AllReduceParams(
        strategy=AllReduceStrategy.UB,
        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8,
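The userbuffers path is configured the same way after the refactor: the `Mapping` fixes the group, `strategy=AllReduceStrategy.UB` selects the kernel, and the fusion is requested per call via `AllReduceParams`. A hedged sketch of the pattern these two functions exercise; the tensor arguments and the commented forward-call shape are assumptions inferred from the surrounding test, not an authoritative API reference:

import torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                             AllReduceParams, AllReduceStrategy)

tp_size, rank = 2, 0                   # illustrative
hidden_size, eps = 1024, 1e-5
mapping = Mapping(world_size=tp_size, tp_size=tp_size, rank=rank)

ar = AllReduce(mapping, strategy=AllReduceStrategy.UB)
ar_params = AllReduceParams(
    strategy=AllReduceStrategy.UB,
    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
    residual=torch.zeros(8, hidden_size),   # assumed: fused residual input
    norm_weight=torch.ones(hidden_size),    # assumed: RMSNorm gamma
    eps=eps,
)
# Assumed call shape: the fused op returns the normed output and the
# updated residual. Requires a userbuffers-backed input on a real run:
# output, residual_out = ar(hidden, all_reduce_params=ar_params)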
@ -318,55 +320,45 @@ class UBTestModel(nn.Module):
            quant_config.layer_quant_mode
        self.rank = rank
        self.tp_size = tp_size
        mapping = Mapping(
            world_size=tp_size,
            tp_size=tp_size,
            rank=rank,
        )
        self.l0 = Linear(in_features=hidden_size,
                         out_features=hidden_size,
                         bias=False,
                         dtype=dtype,
                         parallel_config=ParallelConfig(
                             tensor_parallel_size=tp_size,
                             tensor_parallel_rank=rank,
                             tensor_parallel_mode=TensorParallelMode.ROW,
                         ),
                         mapping=mapping,
                         tensor_parallel_mode=TensorParallelMode.ROW,
                         quant_config=quant_config).cuda()
        self.l1 = Linear(in_features=hidden_size,
                         out_features=hidden_size,
                         bias=False,
                         dtype=dtype,
                         parallel_config=ParallelConfig(
                             tensor_parallel_size=tp_size,
                             tensor_parallel_rank=rank,
                             tensor_parallel_mode=TensorParallelMode.COLUMN,
                         ),
                         mapping=mapping,
                         tensor_parallel_mode=TensorParallelMode.COLUMN,
                         quant_config=quant_config).cuda()
        self.l2 = Linear(in_features=hidden_size,
                         out_features=hidden_size,
                         bias=False,
                         dtype=dtype,
                         parallel_config=ParallelConfig(
                             tensor_parallel_size=tp_size,
                             tensor_parallel_rank=rank,
                             tensor_parallel_mode=TensorParallelMode.ROW,
                         ),
                         mapping=mapping,
                         tensor_parallel_mode=TensorParallelMode.ROW,
                         quant_config=quant_config).cuda()
        self.l3 = Linear(in_features=hidden_size,
                         out_features=hidden_size,
                         bias=False,
                         dtype=dtype,
                         parallel_config=ParallelConfig(
                             tensor_parallel_size=tp_size,
                             tensor_parallel_rank=rank,
                             tensor_parallel_mode=TensorParallelMode.COLUMN,
                         ),
                         mapping=mapping,
                         tensor_parallel_mode=TensorParallelMode.COLUMN,
                         quant_config=quant_config).cuda()
        self.l4 = Linear(in_features=hidden_size,
                         out_features=hidden_size,
                         bias=False,
                         dtype=dtype,
                         parallel_config=ParallelConfig(
                             tensor_parallel_size=tp_size,
                             tensor_parallel_rank=rank,
                             tensor_parallel_mode=TensorParallelMode.ROW,
                         ),
                         mapping=mapping,
                         tensor_parallel_mode=TensorParallelMode.ROW,
                         quant_config=quant_config).cuda()
        self.norm0 = RMSNorm(hidden_size=hidden_size, eps=eps,
                             dtype=dtype).cuda()
@ -608,11 +600,12 @@ def run_single_rank_ar_rms_norm_fp4(tensor_parallel_size, a, b, c, gamma):
|
||||
|
||||
ub0_tensor = create_userbuffers_tensor(c.size(), a.dtype)
|
||||
hidden = torch.matmul(a_local, b_local, out=ub0_tensor)
|
||||
parallel_config = ParallelConfig(
|
||||
tensor_parallel_rank=rank,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
tensor_parallel_mode=TensorParallelMode.COLUMN)
|
||||
ar = AllReduce(parallel_config, strategy=AllReduceStrategy.UB)
|
||||
mapping = Mapping(
|
||||
world_size=tensor_parallel_size,
|
||||
tp_size=tensor_parallel_size,
|
||||
rank=rank,
|
||||
)
|
||||
ar = AllReduce(mapping, strategy=AllReduceStrategy.UB)
|
||||
ar_params = AllReduceParams(
|
||||
strategy=AllReduceStrategy.UB,
|
||||
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
|
||||
@ -694,18 +687,14 @@ class UBMMAddModel(nn.Module):
|
||||
self.rank = rank
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.ar_0 = AllReduce(
|
||||
ParallelConfig(tensor_parallel_size=tp_size,
|
||||
tensor_parallel_rank=rank,
|
||||
tensor_parallel_mode=TensorParallelMode.ROW)).cuda()
|
||||
self.ar_1 = AllReduce(
|
||||
ParallelConfig(tensor_parallel_size=tp_size,
|
||||
tensor_parallel_rank=rank,
|
||||
tensor_parallel_mode=TensorParallelMode.ROW)).cuda()
|
||||
self.ar_2 = AllReduce(
|
||||
ParallelConfig(tensor_parallel_size=tp_size,
|
||||
tensor_parallel_rank=rank,
|
||||
tensor_parallel_mode=TensorParallelMode.ROW)).cuda()
|
||||
mapping = Mapping(
|
||||
world_size=tp_size,
|
||||
tp_size=tp_size,
|
||||
rank=rank,
|
||||
)
|
||||
self.ar_0 = AllReduce(mapping).cuda()
|
||||
self.ar_1 = AllReduce(mapping).cuda()
|
||||
self.ar_2 = AllReduce(mapping).cuda()
|
||||
self.norm0 = RMSNorm(hidden_size=hidden_size, eps=eps,
|
||||
dtype=dtype).cuda()
|
||||
self.norm1 = RMSNorm(hidden_size=hidden_size, eps=eps,
|
||||
|
||||
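A side benefit visible in `UBTestModel` and `UBMMAddModel`: `Mapping` is a plain description of the rank layout with no per-layer state, so one instance can be built once and shared by every `Linear` and `AllReduce` in a model, where the old code restated size, rank, and mode in each `ParallelConfig`. A compact sketch (constructor signatures as they appear in this diff; a real run still launches one MPI process per rank):

import torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.distributed import AllReduce
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode

tp_size, rank = 2, 0   # illustrative single-rank view
mapping = Mapping(world_size=tp_size, tp_size=tp_size, rank=rank)

# One Mapping feeds every parallel module in the model.
proj = Linear(1024, 1024, bias=False, dtype=torch.bfloat16,
              mapping=mapping, tensor_parallel_mode=TensorParallelMode.ROW)
ar_0, ar_1, ar_2 = (AllReduce(mapping) for _ in range(3))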