# tensorrt_llm/_torch/modules/mlp.py

from collections.abc import Callable
from typing import Optional

import torch
from torch import nn

from tensorrt_llm.mapping import Mapping

from ..model_config import ModelConfig
from ..peft.lora.layer import LoraLayer, LoraModuleType
from ..utils import Fp4QuantizedTensor, relu2
from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig


class MLP(nn.Module):

    def __init__(
        self,
        *,
        hidden_size: int,
        intermediate_size: int,
        bias: bool,
        activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        dtype: Optional[torch.dtype] = None,
        config: Optional[ModelConfig] = None,
        layer_idx: Optional[int] = None,
        reduce_output: bool = True,
        overridden_tp_size: Optional[int] = None,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.activation = activation

        config = config or ModelConfig()
        self.mapping = config.mapping
        if overridden_tp_size is not None:
            assert config.mapping.tp_size % overridden_tp_size == 0
            tp_size = overridden_tp_size
            # "Misuse" pp_size here to perform the all-reduce within smaller groups.
            pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
            mapping = Mapping(
                world_size=tp_size * pp_size,
                rank=self.mapping.rank,
                gpus_per_node=self.mapping.gpus_per_node,
                tp_size=tp_size,
                pp_size=pp_size,
            )
        else:
            mapping = config.mapping
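
        # For example (values are illustrative): with tp_size=8, pp_size=1 and
        # overridden_tp_size=2, the derived Mapping has tp_size=2 and pp_size=4,
        # so the MLP's all-reduce spans groups of 2 ranks while the overall
        # world size of 8 is unchanged.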

        self.up_lora = LoraLayer(
            [LoraModuleType.MLP_H_TO_4H],
            [self.intermediate_size // config.mapping.tp_size])

        self.up_proj = Linear(
            self.hidden_size,
            self.intermediate_size,
            bias=bias,
            dtype=dtype,
            mapping=mapping,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            weights_loading_config=WeightsLoadingConfig(
                weight_mode=WeightMode.VANILLA),
            quant_config=config.get_quant_config(),
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            lora=self.up_lora,
            allreduce_strategy=config.allreduce_strategy,
            force_dynamic_quantization=config.force_dynamic_quantization)

        self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
                                   [self.hidden_size])
        self.down_proj = Linear(
            self.intermediate_size,
            self.hidden_size,
            bias=bias,
            dtype=dtype,
            mapping=mapping,
            tensor_parallel_mode=TensorParallelMode.ROW,
            quant_config=config.get_quant_config(),
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            lora=self.down_lora,
            allreduce_strategy=config.allreduce_strategy,
            force_dynamic_quantization=config.force_dynamic_quantization,
            reduce_output=reduce_output,
        )

        self._use_fused_relu2_quant = False

    def create_weights(self):
        self.up_proj.create_weights()
        self.down_proj.create_weights()
        has_nvfp4 = hasattr(self.down_proj,
                            'has_nvfp4') and self.down_proj.has_nvfp4
        has_kernel = hasattr(torch.ops.trtllm, 'fused_relu2_quantize')
        has_scale = hasattr(self.down_proj, 'input_scale')
        is_relu2 = self.activation is relu2
        self._use_fused_relu2_quant = (has_nvfp4 and has_kernel and has_scale
                                       and is_relu2)
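
    # create_weights() gates the fused path on four conditions: the down
    # projection is NVFP4-quantized, the fused kernel is registered, a static
    # input scale is available, and the configured activation is relu2. Only
    # then does forward() replace the separate elementwise activation and
    # quantization passes with a single fused_relu2_quantize call.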

    def forward(
        self,
        x: torch.Tensor,
        lora_params: Optional[dict] = None,
    ) -> torch.Tensor:
        if lora_params is not None:
            return self.forward_lora(x, lora_params=lora_params)

        x_up = self.up_proj(x)
        if self._use_fused_relu2_quant:
            # Fused path: relu2 + NVFP4 quantization in one kernel, handing
            # down_proj a pre-quantized Fp4QuantizedTensor.
            x_act = self._fused_relu2_quant(x_up)
        else:
            x_act = self.activation(x_up)
        x_down = self.down_proj(x_act)
        return x_down

    def _fused_relu2_quant(self, x: torch.Tensor) -> Fp4QuantizedTensor:
        # Collapse leading dims so the kernel sees a contiguous 2-D
        # [num_tokens, hidden] tensor.
        x_flat = x.view(-1, x.shape[-1])
        if not x_flat.is_contiguous():
            x_flat = x_flat.contiguous()
        # Fall back to bf16 if the input is not already half precision.
        if x_flat.dtype not in (torch.float16, torch.bfloat16):
            x_flat = x_flat.to(torch.bfloat16)
        # Apply relu2 and quantize to NVFP4 in a single kernel; 16 is the
        # NVFP4 scaling-factor block size.
        fp4_tensor, sf_tensor = torch.ops.trtllm.fused_relu2_quantize(
            x_flat, self.down_proj.input_scale, 16)
        return Fp4QuantizedTensor(
            fp4_tensor=fp4_tensor,
            scaling_factor=sf_tensor,
            is_sf_swizzled=True,
        )
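
    # Conceptually, the fused kernel replaces two elementwise passes over the
    # up-projection output. A rough unfused sketch (the helper `nvfp4_quantize`
    # is hypothetical and stands in for a standalone quantization op; relu2
    # computes relu(x) ** 2):
    #
    #     x_act = relu2(x_flat)
    #     fp4_tensor, sf_tensor = nvfp4_quantize(x_act,
    #                                            self.down_proj.input_scale)
    #
    # Fusing both steps avoids writing and re-reading the bf16 activation
    # between the activation and quantization stages.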

    def forward_lora(
        self,
        x: torch.Tensor,
        lora_params: Optional[dict] = None,
    ) -> torch.Tensor:
        assert lora_params is not None

        x_up = self.up_proj(x)

        assert self.layer_idx is not None, "layer_idx is required for lora"
        x_up_lora = self.up_lora(x, lora_params, self.layer_idx)
        if x_up_lora is not None:
            x_up = x_up + x_up_lora

        x_act = self.activation(x_up)

        x_down = self.down_proj(x_act,
                                lora_params=lora_params,
                                layer_idx=self.layer_idx)

        return x_down
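

# Example usage (illustrative sketch; the sizes and the default ModelConfig()
# are assumptions, and running it requires a CUDA-enabled TensorRT-LLM build):
#
#     mlp = MLP(hidden_size=4096,
#               intermediate_size=11008,
#               bias=False,
#               activation=relu2,
#               dtype=torch.bfloat16,
#               config=ModelConfig())
#     mlp.create_weights()
#     out = mlp(torch.randn(8, 4096, dtype=torch.bfloat16, device='cuda'))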