feat: support packed weights in vanilla moe (#4719)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Author: Yuxian Qiu
Date: 2025-05-29 06:24:24 +08:00 (committed by GitHub)
Parent: 6cf1e4d0a9
Commit: bf691b3d28


@@ -11,6 +11,7 @@ from torch import nn
from tensorrt_llm._mnnvl_utils import MnnvlMoe, MoEAlltoallInfo
from tensorrt_llm._utils import get_sm_version, logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.quantization.utils import fp4_utils
from tensorrt_llm.quantization.utils.fp4_utils import (
get_reorder_rows_for_gated_act_gemm_row_indices,
get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices,
@@ -314,6 +315,7 @@ class VanillaMoE(nn.ModuleList):
VANILLA,
apply_router_weight_on_input: bool = False,
enable_alltoall: bool = False,
pack_weights: bool = False,
):
from ..distributed import AllReduce
@@ -323,6 +325,7 @@ class VanillaMoE(nn.ModuleList):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.weight_loading_mode = weight_loading_mode
self.pack_weights = pack_weights
self.dtype = dtype
self.reduce_results = reduce_results
@@ -360,6 +363,7 @@ class VanillaMoE(nn.ModuleList):
self.expert_end = min(
self.expert_start + self.expert_size_per_partition,
self.num_experts)
self.expert_size_per_partition = self.expert_end - self.expert_start
max_num_tokens = model_config.max_num_tokens
# The maximum number of tokens in MoE is multiplied by the DP size when attention DP is enabled
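The added recomputation of expert_size_per_partition matters on the last EP rank, where the clamped range can be smaller than the nominal partition size. A minimal standalone sketch with hypothetical sizes (the exact expert_start formula is outside this hunk):

    num_experts = 10
    ep_size = 4
    expert_size_per_partition = -(-num_experts // ep_size)   # ceil(10 / 4) = 3
    for ep_rank in range(ep_size):
        expert_start = ep_rank * expert_size_per_partition
        expert_end = min(expert_start + expert_size_per_partition, num_experts)
        # rank 3 owns experts [9, 10): only 1 local expert instead of 3
        print(ep_rank, expert_start, expert_end, expert_end - expert_start)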
@@ -376,10 +380,9 @@ class VanillaMoE(nn.ModuleList):
# If True, the router weight will be multiplied on the input rather than at the end of FC2
self.apply_router_weight_on_input = apply_router_weight_on_input
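For context on apply_router_weight_on_input: scaling the input of a purely linear map is the same as scaling its output, which is why the router weight can be folded into the input; with gated MLPs the two placements are not identical. An illustrative sketch (hypothetical shapes, not from this diff):

    import torch

    x = torch.randn(4, 8)                      # tokens x hidden
    scale = torch.rand(4, 1)                   # per-token router weight
    fc = torch.nn.Linear(8, 8, bias=False)
    out_on_input = fc(x * scale)               # apply_router_weight_on_input=True
    out_on_output = fc(x) * scale              # default: scale applied after the expert
    torch.testing.assert_close(out_on_input, out_on_output)  # equal only because fc is linear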
def create_weights(self):
if self._weights_created:
return
def create_experts(self, module_list: nn.ModuleList = None):
if module_list is None:
module_list = self
model_config = copy.copy(self.model_config)
model_config.mapping = Mapping(
world_size=self.mapping.moe_tp_size,
@@ -390,7 +393,7 @@ class VanillaMoE(nn.ModuleList):
model_config.skip_create_weights_in_init = False
for expert_idx in range(self.num_experts):
if self.expert_start <= expert_idx < self.expert_end:
self[expert_idx] = GatedMLP(
module_list[expert_idx] = GatedMLP(
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
bias=False,
@@ -400,24 +403,334 @@ class VanillaMoE(nn.ModuleList):
)
else:
# use identity as placeholder for unused experts
self[expert_idx] = nn.Identity()
module_list[expert_idx] = nn.Identity()
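create_experts now takes an optional target list so that, when pack_weights is enabled, temporary experts can be built outside of self and discarded after packing (see load_weights below). A minimal sketch of the pattern with a hypothetical TinyMoE class:

    from torch import nn

    class TinyMoE(nn.ModuleList):
        def __init__(self, num_experts: int):
            super().__init__([nn.Identity() for _ in range(num_experts)])
            self.num_experts = num_experts

        def create_experts(self, module_list: nn.ModuleList = None):
            if module_list is None:
                module_list = self              # default: populate the module itself
            for i in range(self.num_experts):
                module_list[i] = nn.Linear(4, 4, bias=False)
            return module_list

    moe = TinyMoE(2)
    scratch = moe.create_experts(nn.ModuleList([nn.Identity()] * 2))
    assert isinstance(scratch[0], nn.Linear) and isinstance(moe[0], nn.Identity)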
def create_weights(self):
if self._weights_created:
return
self._weights_created = True
if not self.pack_weights:
self.create_experts()
return
self.has_any_quant = False
self.has_fp8_qdq = False
self.has_fp8_block_scales = False
self.has_nvfp4 = False
gate_up_proj_shape = (
self.expert_size_per_partition,
self.intermediate_size_per_partition * 2,
self.hidden_size,
)
down_proj_shape = (
self.expert_size_per_partition,
self.hidden_size,
self.intermediate_size_per_partition,
)
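With packed weights, each projection becomes a single 3-D Parameter indexed by local expert instead of one tensor per GatedMLP. For hypothetical sizes (8 local experts, hidden_size=4096, intermediate_size_per_partition=1024) the fused buffers would be:

    # Hypothetical sizes, illustrative only.
    expert_size_per_partition = 8
    hidden_size = 4096
    intermediate_size_per_partition = 1024
    gate_up_proj_shape = (expert_size_per_partition,
                          intermediate_size_per_partition * 2,
                          hidden_size)                       # (8, 2048, 4096): w1 and w3 fused
    down_proj_shape = (expert_size_per_partition,
                       hidden_size,
                       intermediate_size_per_partition)      # (8, 4096, 1024)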
if self.quant_config and self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True):
self.has_any_quant = True
qc = self.quant_config
if qc.layer_quant_mode.has_fp8_qdq():
self.has_fp8_qdq = True
self.gate_up_proj_weight = nn.Parameter(
torch.empty(
gate_up_proj_shape,
dtype=torch.float8_e4m3fn,
),
requires_grad=False,
)
self.gate_up_proj_weight_scale = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
self.gate_up_proj_input_scale = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
self.gate_up_proj_inv_input_scale = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
self.down_proj_weight = nn.Parameter(
torch.empty(
down_proj_shape,
dtype=torch.float8_e4m3fn,
),
requires_grad=False,
)
self.down_proj_weight_scale = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
self.down_proj_input_scale = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
self.down_proj_inv_input_scale = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
elif qc.layer_quant_mode.has_fp8_block_scales():
self.has_fp8_block_scales = True
self.gate_up_proj_weight = nn.Parameter(
torch.empty(
gate_up_proj_shape,
dtype=torch.float8_e4m3fn,
),
requires_grad=False,
)
gate_up_proj_scale_shape = (
self.expert_size_per_partition,
math.ceil(self.intermediate_size_per_partition * 2 / 128),
math.ceil(self.hidden_size / 128),
)
self.gate_up_proj_weight_scale = nn.Parameter(
torch.empty(
gate_up_proj_scale_shape,
dtype=torch.float32,
),
requires_grad=False,
)
# Not really used for Gemm now.
# Only used to quantize output of FP8 attention.
self.gate_up_proj_input_scale = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
self.gate_up_proj_inv_input_scale = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
self.down_proj_weight = nn.Parameter(
torch.empty(
down_proj_shape,
dtype=torch.float8_e4m3fn,
),
requires_grad=False,
)
down_proj_scale_shape = (
self.expert_size_per_partition,
math.ceil(self.hidden_size / 128),
math.ceil(self.intermediate_size_per_partition / 128),
)
self.down_proj_weight_scale = nn.Parameter(
torch.empty(
down_proj_scale_shape,
dtype=torch.float32,
),
requires_grad=False,
)
# Not really used for Gemm now.
# Only used to quantize output of FP8 attention.
self.down_proj_input_scale = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
self.down_proj_inv_input_scale = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
elif qc.layer_quant_mode.has_nvfp4():
self.has_nvfp4 = True
self.scaling_vector_size = 16
assert self.hidden_size % self.scaling_vector_size == 0, f"hidden_size {self.hidden_size} must be divisible by scaling_vector_size {self.scaling_vector_size}"
# Quantized weights
self.gate_up_proj_weight = nn.Parameter(
torch.empty(
[
self.expert_size_per_partition,
self.intermediate_size_per_partition * 2,
self.hidden_size // 2,
],
dtype=fp4_utils.float4_e2m1x2,
),
requires_grad=False,
)
# FP8 per-block scaling factors. dtype must be aligned with SF_DTYPE
# Padding is required. See computeSFSize in quantization.h
nrows = fp4_utils.pad_up(
self.intermediate_size_per_partition * 2, 128)
ncols = fp4_utils.pad_up(
self.hidden_size // self.scaling_vector_size, 4)
self.gate_up_proj_weight_scale = nn.Parameter(
torch.empty(
[self.expert_size_per_partition, nrows * ncols],
dtype=fp4_utils.float4_sf_dtype,
),
requires_grad=False,
)
# FP32 per-tensor global scaling factor = 448*6/amax_input
self.gate_up_proj_input_scale = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
self.gate_up_proj_inv_input_scale = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
# (amax_input*amax_weight) / (448*6*448*6)
self.gate_up_proj_alpha = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
assert self.intermediate_size_per_partition % self.scaling_vector_size == 0, f"intermediate_size_per_partition {self.intermediate_size_per_partition} must be divisible by scaling_vector_size {self.scaling_vector_size}"
# Quantized weights
self.down_proj_weight = nn.Parameter(
torch.empty(
[
self.expert_size_per_partition,
self.hidden_size,
self.intermediate_size_per_partition // 2,
],
dtype=fp4_utils.float4_e2m1x2,
),
requires_grad=False,
)
# FP8 per-block scaling factors. dtype must be aligned with SF_DTYPE
# Padding is required. See computeSFSize in quantization.h
nrows = fp4_utils.pad_up(self.hidden_size, 128)
ncols = fp4_utils.pad_up(
self.intermediate_size_per_partition //
self.scaling_vector_size, 4)
self.down_proj_weight_scale = nn.Parameter(
torch.empty(
[self.expert_size_per_partition, nrows * ncols],
dtype=fp4_utils.float4_sf_dtype,
),
requires_grad=False,
)
# FP32 per-tensor global scaling factor = 448*6/amax_input
self.down_proj_input_scale = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
self.down_proj_inv_input_scale = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
# (amax_input*amax_weight) / (448*6*448*6)
self.down_proj_alpha = nn.Parameter(
torch.empty(
self.expert_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
else:
raise ValueError(f'unsupported quant mode: {qc.quant_mode}')
else:
self.gate_up_proj_weight = nn.Parameter(
torch.empty(gate_up_proj_shape, dtype=self.dtype),
requires_grad=False,
)
self.down_proj_weight = nn.Parameter(
torch.empty(down_proj_shape, dtype=self.dtype),
requires_grad=False,
)
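The scale-buffer shapes above follow from the block sizes: one FP32 scale per 128x128 block for FP8 block scaling, and one FP8 scale per 16-element group (with row/column padding) for NVFP4. A small worked example with the same hypothetical sizes; pad_up below is assumed to round up to a multiple, mirroring fp4_utils.pad_up:

    import math

    H, I_tp, E_local = 4096, 1024, 8

    def pad_up(x, y):
        return ((x + y - 1) // y) * y

    # FP8 128x128 block scales: one float32 per block.
    gate_up_scale = (E_local, math.ceil(I_tp * 2 / 128), math.ceil(H / 128))  # (8, 16, 32)
    down_scale = (E_local, math.ceil(H / 128), math.ceil(I_tp / 128))         # (8, 32, 8)

    # NVFP4: one scale per 16-element group, rows padded to 128, cols to 4.
    sv = 16
    gate_up_sf = (E_local, pad_up(I_tp * 2, 128) * pad_up(H // sv, 4))        # (8, 2048 * 256)
    down_sf = (E_local, pad_up(H, 128) * pad_up(I_tp // sv, 4))               # (8, 4096 * 64)
    print(gate_up_scale, down_scale, gate_up_sf, down_sf)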
def pack_params(self, experts, module_name: str, weight_name: str):
weights = []
for expert_idx in range(self.expert_start, self.expert_end):
weights.append(
getattr(getattr(experts[expert_idx], module_name), weight_name))
packed_weight = torch._utils._flatten_dense_tensors(weights)
weights_data = torch._utils._unflatten_dense_tensors(
packed_weight, weights)
for weight, data in zip(weights, weights_data):
weight.data = data
packed_weight = packed_weight.view(len(weights), *weights_data[0].shape)
getattr(self, f"{module_name}_{weight_name}").data = packed_weight
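pack_params relies on torch._utils._flatten_dense_tensors to copy the per-expert tensors into one contiguous buffer, then re-points each original Parameter at a view of that buffer, so the packed 3-D weight and the per-expert weights share storage. A minimal standalone sketch:

    import torch
    from torch import nn

    # Illustrative only: pack two same-shaped "expert" weights into one buffer.
    w = [nn.Parameter(torch.full((2, 3), float(i)), requires_grad=False) for i in range(2)]
    packed = torch._utils._flatten_dense_tensors(w)            # 1-D contiguous copy
    views = torch._utils._unflatten_dense_tensors(packed, w)   # views into `packed`
    for param, view in zip(w, views):
        param.data = view                                      # re-point each Parameter
    packed = packed.view(len(w), *views[0].shape)              # (2, 2, 3) packed weight
    w[0].data.fill_(7.0)
    assert packed[0].eq(7.0).all()  # expert weight and packed buffer share storage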
def load_weights(self, weights: List[Dict]):
from ..models.modeling_utils import filter_weights
assert self._weights_created
assert len(weights) == 1
weights = weights[0]
if self.pack_weights:
experts = nn.ModuleList([None] * self.num_experts)
self.create_experts(experts)
experts.to("cuda")
else:
experts = self
for expert_idx in range(self.expert_start, self.expert_end):
self[expert_idx].gate_up_proj.load_weights([
experts[expert_idx].gate_up_proj.load_weights([
filter_weights(f"{expert_idx}.w1", weights),
filter_weights(f"{expert_idx}.w3", weights),
])
self[expert_idx].down_proj.load_weights([
experts[expert_idx].down_proj.load_weights([
filter_weights(f"{expert_idx}.w2", weights),
])
if self.pack_weights:
for module_name in ["gate_up_proj", "down_proj"]:
for weight_name, _ in getattr(experts[self.expert_start],
module_name).named_parameters():
self.pack_params(experts, module_name, weight_name)
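The packing loop discovers which tensors to fuse by iterating named_parameters() on one loaded expert's submodule, so it automatically covers whatever quantization scales were created for the active quant mode. A tiny illustration of the discovery step (hypothetical module):

    from torch import nn

    # For an unquantized projection only "weight" is packed; quantized experts also
    # expose the *_scale parameters created above, and each gets packed the same way.
    proj = nn.Linear(8, 16, bias=False)
    print([name for name, _ in proj.named_parameters()])  # ['weight']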
def reducescatter_or_allreduce(
self,
inputs,
@@ -436,6 +749,31 @@ class VanillaMoE(nn.ModuleList):
outputs = self.all_reduce(inputs)
return outputs
def run_experts(
self,
input: torch.Tensor,
expanded_inputs: torch.Tensor,
expanded_scales: torch.Tensor,
sorted_experts: torch.Tensor,
batch_indices: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros(
input.shape,
dtype=input.dtype,
device=input.device,
)
for expert_idx in range(self.expert_start, self.expert_end):
expert_mask = sorted_experts == expert_idx
if not torch.any(expert_mask):
continue
expanded_input = expanded_inputs[expert_mask]
batch_idx = batch_indices[expert_mask]
expanded_scale = expanded_scales[expert_mask]
output = self[expert_idx](expanded_input)
final_hidden_states[batch_idx] += output * expanded_scale
return final_hidden_states
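run_experts walks only the local expert range and scatters each expert's output back by batch index, weighted by the router scale. A condensed standalone sketch with hypothetical linear experts and hand-picked routing:

    import torch
    from torch import nn

    torch.manual_seed(0)
    hidden, num_tokens, num_experts = 8, 5, 3
    experts = nn.ModuleList([nn.Linear(hidden, hidden, bias=False) for _ in range(num_experts)])
    x = torch.randn(num_tokens, hidden)
    sorted_experts = torch.tensor([0, 0, 1, 2, 2])   # expert id per expanded token
    batch_indices = torch.tensor([3, 1, 0, 4, 2])    # original token for each expansion
    expanded_inputs = x[batch_indices]
    expanded_scales = torch.rand(num_tokens, 1)      # router weight per expansion
    out = torch.zeros_like(x)
    for expert_idx in range(num_experts):
        mask = sorted_experts == expert_idx
        if not torch.any(mask):
            continue
        out[batch_indices[mask]] += experts[expert_idx](expanded_inputs[mask]) * expanded_scales[mask]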
def forward(
self,
x: torch.Tensor,
@@ -457,20 +795,25 @@ class VanillaMoE(nn.ModuleList):
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
final_hidden_states = torch.zeros(x.shape,
dtype=x.dtype,
device=x.device)
expert_masks = ((token_selected_experts >= self.expert_start)
& (token_selected_experts < self.expert_end))
local_selected_experts = token_selected_experts[expert_masks]
sort_indices = torch.argsort(local_selected_experts)
sorted_experts = local_selected_experts[sort_indices]
for expert_idx in range(self.expert_start, self.expert_end):
if not torch.any(token_selected_experts == expert_idx):
continue
batch_idx, nth_expert = torch.where(
token_selected_experts == expert_idx)
expert_inputs = x[batch_idx]
batch_indices, nth_experts = torch.where(expert_masks)
batch_indices = batch_indices[sort_indices]
nth_experts = nth_experts[sort_indices]
expanded_inputs = x[batch_indices]
expanded_scales = token_final_scales[batch_indices, nth_experts, None]
output = self[expert_idx](expert_inputs)
final_hidden_states[batch_idx] += output * token_final_scales[
batch_idx, nth_expert, None]
final_hidden_states = self.run_experts(
x,
expanded_inputs,
expanded_scales,
sorted_experts,
batch_indices,
)
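The reworked forward builds the expanded token list once: it masks out experts that are not local, sorts the surviving (token, expert) pairs by expert id, and gathers inputs and router scales in that order. A small worked example with hypothetical top-2 routing for a rank that owns experts [2, 4):

    import torch

    token_selected_experts = torch.tensor([[0, 2], [3, 1], [2, 3], [1, 0]])
    token_final_scales = torch.rand(4, 2)
    expert_start, expert_end = 2, 4
    expert_masks = ((token_selected_experts >= expert_start)
                    & (token_selected_experts < expert_end))
    local_selected_experts = token_selected_experts[expert_masks]   # [2, 3, 2, 3]
    sort_indices = torch.argsort(local_selected_experts)
    sorted_experts = local_selected_experts[sort_indices]           # [2, 2, 3, 3]
    batch_indices, nth_experts = torch.where(expert_masks)          # tokens 0, 1, 2, 2
    batch_indices = batch_indices[sort_indices]
    nth_experts = nth_experts[sort_indices]
    expanded_scales = token_final_scales[batch_indices, nth_experts, None]
    print(sorted_experts.tolist(), batch_indices.tolist())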
final_hidden_states = self.reducescatter_or_allreduce(
final_hidden_states,