Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
feat: support packed weights in vanilla moe (#4719)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
parent 6cf1e4d0a9
commit bf691b3d28
@@ -11,6 +11,7 @@ from torch import nn
 from tensorrt_llm._mnnvl_utils import MnnvlMoe, MoEAlltoallInfo
 from tensorrt_llm._utils import get_sm_version, logger
 from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.quantization.utils import fp4_utils
 from tensorrt_llm.quantization.utils.fp4_utils import (
     get_reorder_rows_for_gated_act_gemm_row_indices,
     get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices,
@@ -314,6 +315,7 @@ class VanillaMoE(nn.ModuleList):
         VANILLA,
         apply_router_weight_on_input: bool = False,
         enable_alltoall: bool = False,
+        pack_weights: bool = False,
     ):
         from ..distributed import AllReduce

@@ -323,6 +325,7 @@ class VanillaMoE(nn.ModuleList):
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.weight_loading_mode = weight_loading_mode
+        self.pack_weights = pack_weights

         self.dtype = dtype
         self.reduce_results = reduce_results
@@ -360,6 +363,7 @@ class VanillaMoE(nn.ModuleList):
         self.expert_end = min(
             self.expert_start + self.expert_size_per_partition,
             self.num_experts)
+        self.expert_size_per_partition = self.expert_end - self.expert_start

         max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
@@ -376,10 +380,9 @@ class VanillaMoE(nn.ModuleList):
         # If True, the router weight will be multiplied on the input rather than at the end of FC2
         self.apply_router_weight_on_input = apply_router_weight_on_input

-    def create_weights(self):
-        if self._weights_created:
-            return
-
+    def create_experts(self, module_list: nn.ModuleList = None):
+        if module_list is None:
+            module_list = self
         model_config = copy.copy(self.model_config)
         model_config.mapping = Mapping(
             world_size=self.mapping.moe_tp_size,
@@ -390,7 +393,7 @@ class VanillaMoE(nn.ModuleList):
         model_config.skip_create_weights_in_init = False
         for expert_idx in range(self.num_experts):
             if self.expert_start <= expert_idx < self.expert_end:
-                self[expert_idx] = GatedMLP(
+                module_list[expert_idx] = GatedMLP(
                     hidden_size=self.hidden_size,
                     intermediate_size=self.intermediate_size,
                     bias=False,
@ -400,24 +403,334 @@ class VanillaMoE(nn.ModuleList):
|
||||
)
|
||||
else:
|
||||
# use identity as placeholder for unused experts
|
||||
self[expert_idx] = nn.Identity()
|
||||
module_list[expert_idx] = nn.Identity()
|
||||
|
||||
def create_weights(self):
|
||||
if self._weights_created:
|
||||
return
|
||||
self._weights_created = True
|
||||
|
||||
if not self.pack_weights:
|
||||
self.create_experts()
|
||||
return
|
||||
|
||||
self.has_any_quant = False
|
||||
self.has_fp8_qdq = False
|
||||
self.has_fp8_block_scales = False
|
||||
self.has_nvfp4 = False
|
||||
gate_up_proj_shape = (
|
||||
self.expert_size_per_partition,
|
||||
self.intermediate_size_per_partition * 2,
|
||||
self.hidden_size,
|
||||
)
|
||||
down_proj_shape = (
|
||||
self.expert_size_per_partition,
|
||||
self.hidden_size,
|
||||
self.intermediate_size_per_partition,
|
||||
)
|
||||
if self.quant_config and self.quant_config.layer_quant_mode.has_any_quant(
|
||||
exclude_kv_cache=True):
|
||||
self.has_any_quant = True
|
||||
qc = self.quant_config
|
||||
if qc.layer_quant_mode.has_fp8_qdq():
|
||||
self.has_fp8_qdq = True
|
||||
|
||||
self.gate_up_proj_weight = nn.Parameter(
|
||||
torch.empty(
|
||||
gate_up_proj_shape,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.gate_up_proj_weight_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.gate_up_proj_input_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.gate_up_proj_inv_input_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
self.down_proj_weight = nn.Parameter(
|
||||
torch.empty(
|
||||
down_proj_shape,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.down_proj_weight_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.down_proj_input_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.down_proj_inv_input_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
elif qc.layer_quant_mode.has_fp8_block_scales():
|
||||
self.has_fp8_block_scales = True
|
||||
|
||||
self.gate_up_proj_weight = nn.Parameter(
|
||||
torch.empty(
|
||||
gate_up_proj_shape,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
gate_up_proj_scale_shape = (
|
||||
self.expert_size_per_partition,
|
||||
math.ceil(self.intermediate_size_per_partition * 2 / 128),
|
||||
math.ceil(self.hidden_size / 128),
|
||||
)
|
||||
self.gate_up_proj_weight_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
gate_up_proj_scale_shape,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
# Not really used for Gemm now.
|
||||
# Only used to quantize output of FP8 attention.
|
||||
self.gate_up_proj_input_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.gate_up_proj_inv_input_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
self.down_proj_weight = nn.Parameter(
|
||||
torch.empty(
|
||||
down_proj_shape,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
down_proj_scale_shape = (
|
||||
self.expert_size_per_partition,
|
||||
math.ceil(self.hidden_size / 128),
|
||||
math.ceil(self.intermediate_size_per_partition / 128),
|
||||
)
|
||||
self.down_proj_weight_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
down_proj_scale_shape,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
# Not really used for Gemm now.
|
||||
# Only used to quantize output of FP8 attention.
|
||||
self.down_proj_input_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.down_proj_inv_input_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
elif qc.layer_quant_mode.has_nvfp4():
|
||||
self.has_nvfp4 = True
|
||||
self.scaling_vector_size = 16
|
||||
|
||||
assert self.hidden_size % self.scaling_vector_size == 0, f"hidden_size {self.hidden_size} must be divisible by scaling_vector_size {self.scaling_vector_size}"
|
||||
|
||||
# Quantized weights
|
||||
self.gate_up_proj_weight = nn.Parameter(
|
||||
torch.empty(
|
||||
[
|
||||
self.expert_size_per_partition,
|
||||
self.intermediate_size_per_partition * 2,
|
||||
self.hidden_size // 2,
|
||||
],
|
||||
dtype=fp4_utils.float4_e2m1x2,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# FP8 per-block scaling factors. dtype must be aligned with SF_DTYPE
|
||||
# Padding is required. See computeSFSize in quantization.h
|
||||
nrows = fp4_utils.pad_up(
|
||||
self.intermediate_size_per_partition * 2, 128)
|
||||
ncols = fp4_utils.pad_up(
|
||||
self.hidden_size // self.scaling_vector_size, 4)
|
||||
self.gate_up_proj_weight_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
[self.expert_size_per_partition, nrows * ncols],
|
||||
dtype=fp4_utils.float4_sf_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# FP32 per-tensor global scaling factor = 448*6/amax_input
|
||||
self.gate_up_proj_input_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.gate_up_proj_inv_input_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# (amax_input*amax_weight) / (448*6*448*6)
|
||||
self.gate_up_proj_alpha = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
assert self.intermediate_size_per_partition % self.scaling_vector_size == 0, f"intermediate_size_per_partition {self.intermediate_size_per_partition} must be divisible by scaling_vector_size {self.scaling_vector_size}"
|
||||
|
||||
# Quantized weights
|
||||
self.down_proj_weight = nn.Parameter(
|
||||
torch.empty(
|
||||
[
|
||||
self.expert_size_per_partition,
|
||||
self.hidden_size,
|
||||
self.intermediate_size_per_partition // 2,
|
||||
],
|
||||
dtype=fp4_utils.float4_e2m1x2,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# FP8 per-block scaling factors. dtype must be aligned with SF_DTYPE
|
||||
# Padding is required. See computeSFSize in quantization.h
|
||||
nrows = fp4_utils.pad_up(self.hidden_size, 128)
|
||||
ncols = fp4_utils.pad_up(
|
||||
self.intermediate_size_per_partition //
|
||||
self.scaling_vector_size, 4)
|
||||
self.down_proj_weight_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
[self.expert_size_per_partition, nrows * ncols],
|
||||
dtype=fp4_utils.float4_sf_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# FP32 per-tensor global scaling factor = 448*6/amax_input
|
||||
self.down_proj_input_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.down_proj_inv_input_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# (amax_input*amax_weight) / (448*6*448*6)
|
||||
self.down_proj_alpha = nn.Parameter(
|
||||
torch.empty(
|
||||
self.expert_size_per_partition,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'unsupported quant mode: {qc.quant_mode}')
|
||||
else:
|
||||
self.gate_up_proj_weight = nn.Parameter(
|
||||
torch.empty(gate_up_proj_shape, dtype=self.dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.down_proj_weight = nn.Parameter(
|
||||
torch.empty(down_proj_shape, dtype=self.dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def pack_params(self, experts, module_name: str, weight_name: str):
|
||||
weights = []
|
||||
for expert_idx in range(self.expert_start, self.expert_end):
|
||||
weights.append(
|
||||
getattr(getattr(experts[expert_idx], module_name), weight_name))
|
||||
packed_weight = torch._utils._flatten_dense_tensors(weights)
|
||||
weights_data = torch._utils._unflatten_dense_tensors(
|
||||
packed_weight, weights)
|
||||
for weight, data in zip(weights, weights_data):
|
||||
weight.data = data
|
||||
packed_weight = packed_weight.view(len(weights), *weights_data[0].shape)
|
||||
getattr(self, f"{module_name}_{weight_name}").data = packed_weight
|
||||
|
||||
def load_weights(self, weights: List[Dict]):
|
||||
from ..models.modeling_utils import filter_weights
|
||||
|
||||
assert self._weights_created
|
||||
assert len(weights) == 1
|
||||
weights = weights[0]
|
||||
|
||||
if self.pack_weights:
|
||||
experts = nn.ModuleList([None] * self.num_experts)
|
||||
self.create_experts(experts)
|
||||
experts.to("cuda")
|
||||
else:
|
||||
experts = self
|
||||
|
||||
for expert_idx in range(self.expert_start, self.expert_end):
|
||||
self[expert_idx].gate_up_proj.load_weights([
|
||||
experts[expert_idx].gate_up_proj.load_weights([
|
||||
filter_weights(f"{expert_idx}.w1", weights),
|
||||
filter_weights(f"{expert_idx}.w3", weights),
|
||||
])
|
||||
self[expert_idx].down_proj.load_weights([
|
||||
experts[expert_idx].down_proj.load_weights([
|
||||
filter_weights(f"{expert_idx}.w2", weights),
|
||||
])
|
||||
|
||||
if self.pack_weights:
|
||||
for module_name in ["gate_up_proj", "down_proj"]:
|
||||
for weight_name, _ in getattr(experts[self.expert_start],
|
||||
module_name).named_parameters():
|
||||
self.pack_params(experts, module_name, weight_name)
|
||||
|
||||
def reducescatter_or_allreduce(
|
||||
self,
|
||||
inputs,
|
||||
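Note on the hunk above: pack_params relies on two private PyTorch helpers, torch._utils._flatten_dense_tensors (which copies the per-expert parameters into one contiguous buffer) and torch._utils._unflatten_dense_tensors (which returns views into that buffer with the original shapes). Rebinding each expert parameter's .data to its view makes the experts and the packed gate_up_proj_weight / down_proj_weight share storage. A minimal standalone sketch of that aliasing behavior, not part of this commit; the toy tensors w0 / w1 and their shapes are made up:

# Standalone sketch of the storage sharing that pack_params depends on.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Two toy per-expert weights standing in for gate_up_proj parameters.
w0 = torch.randn(4, 8)
w1 = torch.randn(4, 8)

packed = _flatten_dense_tensors([w0, w1])            # one contiguous 1-D buffer
views = _unflatten_dense_tensors(packed, [w0, w1])   # views into that buffer

# Rebind the originals onto the shared storage, as pack_params does.
w0.data, w1.data = views

# The stacked [num_experts, *weight_shape] view is what the packed parameter
# is pointed at; writes through it are visible through each expert's weight.
packed_3d = packed.view(2, 4, 8)
packed_3d[0].zero_()
assert torch.all(w0 == 0)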
@@ -436,6 +749,31 @@ class VanillaMoE(nn.ModuleList):
             outputs = self.all_reduce(inputs)
         return outputs

+    def run_experts(
+        self,
+        input: torch.Tensor,
+        expanded_inputs: torch.Tensor,
+        expanded_scales: torch.Tensor,
+        sorted_experts: torch.Tensor,
+        batch_indices: torch.Tensor,
+    ) -> torch.Tensor:
+        final_hidden_states = torch.zeros(
+            input.shape,
+            dtype=input.dtype,
+            device=input.device,
+        )
+        for expert_idx in range(self.expert_start, self.expert_end):
+            expert_mask = sorted_experts == expert_idx
+            if not torch.any(expert_mask):
+                continue
+            expanded_input = expanded_inputs[expert_mask]
+            batch_idx = batch_indices[expert_mask]
+            expanded_scale = expanded_scales[expert_mask]
+
+            output = self[expert_idx](expanded_input)
+            final_hidden_states[batch_idx] += output * expanded_scale
+        return final_hidden_states
+
     def forward(
         self,
         x: torch.Tensor,
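Between run_experts above and the rewritten forward below, the routing bookkeeping is easier to see in isolation. The toy below is a standalone illustration with made-up sizes, not the module's code: it builds the same sorted_experts / batch_indices / expanded_scales tensors from a top-k routing decision and consumes them the way run_experts does, with identity functions standing in for the GatedMLP experts:

# Toy illustration of the routing tensors fed into run_experts.
import torch

num_tokens, hidden, num_experts, top_k = 5, 4, 4, 2
expert_start, expert_end = 0, 2  # experts owned by this rank

x = torch.randn(num_tokens, hidden)
token_selected_experts = torch.topk(torch.rand(num_tokens, num_experts),
                                    top_k).indices       # [num_tokens, top_k]
token_final_scales = torch.rand(num_tokens, top_k)        # router weights

# Mask of (token, slot) pairs that hit a locally owned expert, then sort so
# each expert's tokens are contiguous.
expert_masks = ((token_selected_experts >= expert_start)
                & (token_selected_experts < expert_end))
local_selected_experts = token_selected_experts[expert_masks]
sort_indices = torch.argsort(local_selected_experts)
sorted_experts = local_selected_experts[sort_indices]

batch_indices, nth_experts = torch.where(expert_masks)
batch_indices = batch_indices[sort_indices]
nth_experts = nth_experts[sort_indices]
expanded_inputs = x[batch_indices]
expanded_scales = token_final_scales[batch_indices, nth_experts, None]

# Consume the routing data the way run_experts does (identity "experts").
final_hidden_states = torch.zeros_like(x)
for expert_idx in range(expert_start, expert_end):
    expert_mask = sorted_experts == expert_idx
    if not torch.any(expert_mask):
        continue
    output = expanded_inputs[expert_mask]  # stand-in for self[expert_idx](...)
    final_hidden_states[batch_indices[expert_mask]] += (
        output * expanded_scales[expert_mask])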
@@ -457,20 +795,25 @@ class VanillaMoE(nn.ModuleList):
                           dim=0,
                           sizes=None if use_dp_padding else all_rank_num_tokens)

-        final_hidden_states = torch.zeros(x.shape,
-                                          dtype=x.dtype,
-                                          device=x.device)
+        expert_masks = ((token_selected_experts >= self.expert_start)
+                        & (token_selected_experts < self.expert_end))
+        local_selected_experts = token_selected_experts[expert_masks]
+        sort_indices = torch.argsort(local_selected_experts)
+        sorted_experts = local_selected_experts[sort_indices]

-        for expert_idx in range(self.expert_start, self.expert_end):
-            if not torch.any(token_selected_experts == expert_idx):
-                continue
-            batch_idx, nth_expert = torch.where(
-                token_selected_experts == expert_idx)
-            expert_inputs = x[batch_idx]
+        batch_indices, nth_experts = torch.where(expert_masks)
+        batch_indices = batch_indices[sort_indices]
+        nth_experts = nth_experts[sort_indices]
+        expanded_inputs = x[batch_indices]
+        expanded_scales = token_final_scales[batch_indices, nth_experts, None]

-            output = self[expert_idx](expert_inputs)
-            final_hidden_states[batch_idx] += output * token_final_scales[
-                batch_idx, nth_expert, None]
+        final_hidden_states = self.run_experts(
+            x,
+            expanded_inputs,
+            expanded_scales,
+            sorted_experts,
+            batch_indices,
+        )

         final_hidden_states = self.reducescatter_or_allreduce(
             final_hidden_states,
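Stepping back to the packed create_weights hunk: the scale-tensor shapes follow simple block arithmetic. The FP8 block-scale path stores one FP32 scale per 128x128 weight block (hence the math.ceil(dim / 128) factors), and the NVFP4 path stores one FP8 scale per 16-element vector, padded to multiples of 128 rows and 4 columns for the swizzled layout (see computeSFSize in quantization.h, as the comments note). A back-of-the-envelope check with assumed example sizes, where pad_up is assumed to round up to the next multiple:

# Rough sizing check for the packed scale tensors (example sizes assumed).
import math


def pad_up(x: int, m: int) -> int:
    # Assumed to mirror fp4_utils.pad_up: round x up to a multiple of m.
    return (x + m - 1) // m * m


hidden_size = 4096
intermediate_size_per_partition = 1536
scaling_vector_size = 16  # NVFP4 scaling block along the K dimension

# FP8 block scales: one FP32 scale per 128x128 block of gate_up_proj weight.
gate_up_fp8_scale_shape = (
    math.ceil(intermediate_size_per_partition * 2 / 128),  # 24 block rows
    math.ceil(hidden_size / 128),                          # 32 block cols
)

# NVFP4: one scale per 16-element vector, padded for the swizzled layout.
nrows = pad_up(intermediate_size_per_partition * 2, 128)   # 3072
ncols = pad_up(hidden_size // scaling_vector_size, 4)      # 256
gate_up_nvfp4_scale_elems = nrows * ncols                  # 786432 per expert

print(gate_up_fp8_scale_shape, gate_up_nvfp4_scale_elems)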