TensorRT-LLM/tensorrt_llm/_torch/modules/mlp.py

from collections.abc import Callable
from typing import Optional

import torch
from torch import nn

from ..distributed import ParallelConfig, TensorParallelMode
from ..model_config import ModelConfig
from .linear import Linear, WeightMode, WeightsLoadingConfig


class MLP(nn.Module):
    """Feed-forward block: up_proj -> activation -> down_proj.

    Under tensor parallelism the up projection is column-sharded and the
    down projection is row-sharded, so the intermediate activation stays
    local to each rank and only the final output needs a reduction.
    """

    def __init__(self,
                 *,
                 hidden_size: int,
                 intermediate_size: int,
                 bias: bool,
                 activation: Optional[Callable[[torch.Tensor],
                                               torch.Tensor]] = None,
                 dtype: Optional[torch.dtype] = None,
                 config: Optional[ModelConfig] = None):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.activation = activation

        config = config or ModelConfig()
        tp_rank = config.mapping.tp_rank
        tp_size = config.mapping.tp_size
        gpus_per_node = config.mapping.gpus_per_node

        # Column-parallel up projection: hidden_size -> intermediate_size.
        self.up_proj = Linear(
            self.hidden_size,
            self.intermediate_size,
            bias=bias,
            dtype=dtype,
            parallel_config=ParallelConfig(
                tensor_parallel_rank=tp_rank,
                tensor_parallel_size=tp_size,
                tensor_parallel_mode=TensorParallelMode.COLUMN,
                gpus_per_node=gpus_per_node,
                pipeline_parallel_size=config.mapping.pp_size,
                parallel_rank=config.mapping.rank),
            weights_loading_config=WeightsLoadingConfig(
                weight_mode=WeightMode.VANILLA),
            quant_config=config.get_quant_config(),
            skip_create_weights=config.skip_create_weights,
        )

        # Row-parallel down projection: intermediate_size -> hidden_size.
        self.down_proj = Linear(
            self.intermediate_size,
            self.hidden_size,
            bias=bias,
            dtype=dtype,
            parallel_config=ParallelConfig(
                tensor_parallel_rank=tp_rank,
                tensor_parallel_size=tp_size,
                tensor_parallel_mode=TensorParallelMode.ROW,
                gpus_per_node=gpus_per_node,
                pipeline_parallel_size=config.mapping.pp_size,
                parallel_rank=config.mapping.rank),
            quant_config=config.get_quant_config(),
            skip_create_weights=config.skip_create_weights,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.activation(self.up_proj(x)))
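

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): shows how the MLP block
# might be instantiated and called. It assumes a default single-rank
# ModelConfig() (tp_size == 1, no quantization), a CUDA device, and that the
# unquantized Linear path behaves like a plain torch linear layer; these are
# assumptions, not guarantees from the surrounding library.
if __name__ == "__main__":
    mlp = MLP(
        hidden_size=4096,
        intermediate_size=11008,
        bias=False,
        activation=torch.nn.functional.silu,
        dtype=torch.bfloat16,
    )  # config is omitted, so it falls back to ModelConfig()
    x = torch.randn(2, 8, 4096, dtype=torch.bfloat16, device="cuda")
    y = mlp.to("cuda")(x)
    print(y.shape)  # same shape as x: [batch, seq_len, hidden_size]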