from collections.abc import Callable
from typing import Optional

import torch
from torch import nn

from ..distributed import ParallelConfig, TensorParallelMode
from ..model_config import ModelConfig
from .linear import Linear, WeightMode, WeightsLoadingConfig


class MLP(nn.Module):
    """Feed-forward MLP block: up projection, activation, then down projection.

    The up projection is column-parallel and the down projection is
    row-parallel, following the usual tensor-parallel MLP layout.
    """

    def __init__(self,
                 *,
                 hidden_size: int,
                 intermediate_size: int,
                 bias: bool,
                 activation: Optional[Callable[[torch.Tensor],
                                               torch.Tensor]] = None,
                 dtype: Optional[torch.dtype] = None,
                 config: Optional[ModelConfig] = None):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.activation = activation

        config = config or ModelConfig()
        tp_rank = config.mapping.tp_rank
        tp_size = config.mapping.tp_size
        gpus_per_node = config.mapping.gpus_per_node
        # Column-parallel up projection: the intermediate dimension is sharded
        # across tensor-parallel ranks.
        self.up_proj = Linear(
            self.hidden_size,
            self.intermediate_size,
            bias=bias,
            dtype=dtype,
            parallel_config=ParallelConfig(
                tensor_parallel_rank=tp_rank,
                tensor_parallel_size=tp_size,
                tensor_parallel_mode=TensorParallelMode.COLUMN,
                gpus_per_node=gpus_per_node,
                pipeline_parallel_size=config.mapping.pp_size,
                parallel_rank=config.mapping.rank),
            weights_loading_config=WeightsLoadingConfig(
                weight_mode=WeightMode.VANILLA),
            quant_config=config.get_quant_config(),
            skip_create_weights=config.skip_create_weights,
        )
        # Row-parallel down projection: consumes the sharded intermediate
        # activations and restores the full hidden dimension.
        self.down_proj = Linear(
            self.intermediate_size,
            self.hidden_size,
            bias=bias,
            dtype=dtype,
            parallel_config=ParallelConfig(
                tensor_parallel_rank=tp_rank,
                tensor_parallel_size=tp_size,
                tensor_parallel_mode=TensorParallelMode.ROW,
                gpus_per_node=gpus_per_node,
                pipeline_parallel_size=config.mapping.pp_size,
                parallel_rank=config.mapping.rank),
            quant_config=config.get_quant_config(),
            skip_create_weights=config.skip_create_weights,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.activation(self.up_proj(x)))
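
# A minimal usage sketch (illustrative only, not part of the module): assuming a
# default-constructed ModelConfig (single rank, tp_size == 1), the block can be
# exercised directly, e.g. with a GELU activation. Shapes and dtype below are
# arbitrary examples.
#
#     mlp = MLP(hidden_size=1024,
#               intermediate_size=4096,
#               bias=False,
#               activation=torch.nn.functional.gelu,
#               dtype=torch.float16)
#     y = mlp(torch.randn(2, 8, 1024, dtype=torch.float16))  # -> [2, 8, 1024]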