[None][feat] Support nvfp4 for gptoss (#8956)

Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
dongfengy 2026-01-04 21:57:44 +08:00 committed by GitHub
parent a4dcc6a711
commit afc533193d
9 changed files with 961 additions and 255 deletions

@ -34,15 +34,17 @@ using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::computeSelectedTileN;
std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch::Tensor> const& routing_logits,
torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
torch::optional<torch::Tensor> const& hidden_states_scale, torch::Tensor const& gemm1_weights,
torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
int64_t const routing_method_type, bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner,
int64_t const moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
torch::optional<torch::Tensor> const& topk_ids)
torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
int64_t const local_expert_offset, int64_t const local_num_experts,
std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type,
bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t const moeConfigIndex,
torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
{
TORCH_CHECK(dtype == btg::Dtype::E4m3 || dtype == btg::Dtype::E2m1, "dtype can only be e4m3 or e2m1.");
TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE");
@ -161,8 +163,13 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
args.gemm1_weights = gemm1_weights.data_ptr();
args.gemm1_weights_scale = gemm1_weights_scale.data_ptr();
args.gemm1_bias = gemm1_bias.has_value() ? gemm1_bias.value().data_ptr<float>() : nullptr;
args.gemm1_alpha = gemm1_alpha.has_value() ? gemm1_alpha.value().data_ptr<float>() : nullptr;
args.gemm1_beta = gemm1_beta.has_value() ? gemm1_beta.value().data_ptr<float>() : nullptr;
args.gemm1_clamp_limit = gemm1_clamp_limit.has_value() ? gemm1_clamp_limit.value().data_ptr<float>() : nullptr;
args.gemm2_weights = gemm2_weights.data_ptr();
args.gemm2_weights_scale = gemm2_weights_scale.data_ptr();
args.gemm2_bias = gemm2_bias.has_value() ? gemm2_bias.value().data_ptr<float>() : nullptr;
args.num_tokens = hidden_states.sizes()[0];
args.num_experts = num_experts;
if (dtype == btg::Dtype::E4m3)
@ -313,6 +320,38 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
TORCH_CHECK(intermediate_size % 16 == 0, "the second dimension of weights must be a multiple of 16.");
TORCH_CHECK(gemm1_weights_scale.sizes()[1] == 2 * intermediate_size, "gemm1_weights_scale has incorrect dim 1.");
if (gemm1_bias.has_value())
{
TORCH_CHECK(gemm1_bias.value().scalar_type() == at::ScalarType::Float, "gemm1_bias must be float, got %s.",
c10::toString(gemm1_bias.value().scalar_type()));
TORCH_CHECK(gemm1_bias.value().dim() == 2, "gemm1_bias must be 2D.");
TORCH_CHECK(gemm1_bias.value().sizes()[0] == local_num_experts, "gemm1_bias has incorrect dim 0.");
TORCH_CHECK(gemm1_bias.value().sizes()[1] == 2 * intermediate_size, "gemm1_bias has incorrect dim 1.");
}
if (gemm1_alpha.has_value())
{
TORCH_CHECK(gemm1_alpha.value().scalar_type() == at::ScalarType::Float, "gemm1_alpha must be float, got %s.",
c10::toString(gemm1_alpha.value().scalar_type()));
TORCH_CHECK(gemm1_alpha.value().dim() == 1, "gemm1_alpha must be 1D.");
TORCH_CHECK(gemm1_alpha.value().sizes()[0] == local_num_experts, "gemm1_alpha has incorrect dim 0.");
}
if (gemm1_beta.has_value())
{
TORCH_CHECK(gemm1_beta.value().scalar_type() == at::ScalarType::Float, "gemm1_beta must be float, got %s.",
c10::toString(gemm1_beta.value().scalar_type()));
TORCH_CHECK(gemm1_beta.value().dim() == 1, "gemm1_beta must be 1D.");
TORCH_CHECK(gemm1_beta.value().sizes()[0] == local_num_experts, "gemm1_beta has incorrect dim 0.");
}
if (gemm1_clamp_limit.has_value())
{
TORCH_CHECK(gemm1_clamp_limit.value().scalar_type() == at::ScalarType::Float,
"gemm1_clamp_limit must be float, got %s.", c10::toString(gemm1_clamp_limit.value().scalar_type()));
TORCH_CHECK(gemm1_clamp_limit.value().dim() == 1, "gemm1_clamp_limit must be 1D.");
TORCH_CHECK(
gemm1_clamp_limit.value().sizes()[0] == local_num_experts, "gemm1_clamp_limit has incorrect dim 0.");
}
TORCH_CHECK(gemm2_weights.scalar_type() == FLOAT4_E2M1X2, "gemm2_weights must be byte.");
TORCH_CHECK(gemm2_weights.dim() == 3, "gemm2_weights must be 3D.");
@ -322,6 +361,15 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
TORCH_CHECK(gemm2_weights_scale.scalar_type() == at::ScalarType::Float8_e4m3fn, "gemm2_weights_scale must be fp8.");
if (gemm2_bias.has_value())
{
TORCH_CHECK(gemm2_bias.value().scalar_type() == at::ScalarType::Float, "gemm2_bias must be float, got %s.",
c10::toString(gemm2_bias.value().scalar_type()));
TORCH_CHECK(gemm2_bias.value().dim() == 2, "gemm2_bias must be 2D.");
TORCH_CHECK(gemm2_bias.value().sizes()[0] == local_num_experts, "gemm2_bias has incorrect dim 0.");
TORCH_CHECK(gemm2_bias.value().sizes()[1] == args.hidden_size, "gemm2_bias has incorrect dim 1.");
}
TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D.");
TORCH_CHECK(gemm2_weights_scale.sizes()[0] == local_num_experts, "gemm2_weights_scale has incorrect dim 0.");
TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.hidden_size, "gemm2_weights_scale has incorrect dim 1.");
@ -440,14 +488,17 @@ public:
[[nodiscard]] std::vector<torch::Tensor> run(torch::optional<torch::Tensor> const& routing_logits,
torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
torch::Tensor const& hidden_states_scale, torch::Tensor const& gemm1_weights,
torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
int64_t const local_num_experts, std::optional<double> const routed_scaling_factor,
int64_t const routing_method_type, bool const do_finalize, std::vector<int64_t> moeConfigIndex,
torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
int64_t const local_expert_offset, int64_t const local_num_experts,
std::optional<double> const routed_scaling_factor, int64_t const routing_method_type, bool const do_finalize,
std::vector<int64_t> moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
torch::optional<torch::Tensor> const& topk_ids)
{
// moeConfigIndex corresponds to pair (tileN, config)
auto [tileN, config] = std::tie(moeConfigIndex[0], moeConfigIndex[1]);
@ -468,10 +519,11 @@ public:
}
return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, hidden_states_scale,
gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output1_scales_scalar,
output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config, topk_weights, topk_ids);
gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights,
gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar,
num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config,
topk_weights, topk_ids);
}
private:
@ -553,11 +605,11 @@ public:
}
return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states,
std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, gemm2_weights,
gemm2_weights_scale, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, num_experts,
top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config,
topk_weights, topk_ids);
std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, gemm2_weights, gemm2_weights_scale, std::nullopt, output1_scales_scalar,
output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids);
}
private:
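
For reference, the five new optional per-expert tensors are all float32: gemm1_bias is [local_num_experts, 2 * intermediate_size], gemm2_bias is [local_num_experts, hidden_size], and gemm1_alpha / gemm1_beta / gemm1_clamp_limit are each [local_num_experts], matching the TORCH_CHECKs above. Below is a minimal PyTorch sketch of the gated activation these parameters describe, mirroring the custom_swiglu reference that the unit tests further down validate the kernel against (per-expert scalars are shown as plain floats here):

import torch

def gptoss_swiglu(gemm1_out: torch.Tensor, alpha: float, beta: float, limit: float) -> torch.Tensor:
    # gemm1_out holds the concatenated [gate | linear] halves of the first GEMM.
    gate, value = gemm1_out.chunk(2, dim=-1)
    if limit != float("inf"):
        gate = gate.clamp(max=limit)
        value = value.clamp(min=-limit, max=limit)
    gate_act = gate * torch.sigmoid(gate * alpha)  # SiLU with a per-expert alpha
    return gate_act * (value + beta)               # per-expert additive beta on the linear half

With alpha = 1, beta = 0 and limit = inf this reduces to the standard SwiGLU, which is why the non-gpt-oss call site above passes std::nullopt for all five tensors.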

@ -176,8 +176,13 @@ class FP4BlockScaleMoEInputs:
hidden_states_scale: torch.Tensor
gemm1_weights: torch.Tensor
gemm1_weights_scale: torch.Tensor
gemm1_bias: torch.Tensor
gemm1_alpha: torch.Tensor
gemm1_beta: torch.Tensor
gemm1_clamp_limit: torch.Tensor
gemm2_weights: torch.Tensor
gemm2_weights_scale: torch.Tensor
gemm2_bias: torch.Tensor
output1_scale_scalar: torch.Tensor
output1_scale_gate_scalar: torch.Tensor
output2_scale_scalar: torch.Tensor
@ -235,14 +240,15 @@ class FP4BlockScaleMoERunner(TunableRunner):
return kernel_runner.run_moe(
args.routing_logits, args.routing_bias, args.hidden_states,
args.hidden_states_scale, args.gemm1_weights,
args.gemm1_weights_scale, args.gemm2_weights,
args.gemm2_weights_scale, args.output1_scale_scalar,
args.output1_scale_gate_scalar, args.output2_scale_scalar,
self.num_experts, self.top_k, self.n_group, self.topk_group,
self.intermediate_size, self.local_expert_offset,
self.local_num_experts, self.routed_scaling_factor,
self.routing_method_type, self.do_finalize, tactic,
args.topk_weights, args.topk_ids)
args.gemm1_weights_scale, args.gemm1_bias, args.gemm1_alpha,
args.gemm1_beta, args.gemm1_clamp_limit, args.gemm2_weights,
args.gemm2_weights_scale, args.gemm2_bias,
args.output1_scale_scalar, args.output1_scale_gate_scalar,
args.output2_scale_scalar, self.num_experts, self.top_k,
self.n_group, self.topk_group, self.intermediate_size,
self.local_expert_offset, self.local_num_experts,
self.routed_scaling_factor, self.routing_method_type,
self.do_finalize, tactic, args.topk_weights, args.topk_ids)
def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile,
@ -317,8 +323,8 @@ class FP4BlockScaleMoERunner(TunableRunner):
ROUTER_LOGITS_IDX = 0
CONSTRAINED_RL_DIM = 0
TOPK_WEIGHTS_IDX = 11
TOPK_IDS_IDX = 12
TOPK_WEIGHTS_IDX = 16
TOPK_IDS_IDX = 17
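# The constraint indices move from 11/12 to 16/17 because five optional tensors
# (gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_bias) now sit
# ahead of topk_weights / topk_ids in the runner's input list.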
constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX,
CONSTRAINED_RL_DIM,
@ -359,8 +365,13 @@ def fp4_block_scale_moe_runner(
hidden_states_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
gemm1_bias: torch.Tensor,
gemm1_alpha: torch.Tensor,
gemm1_beta: torch.Tensor,
gemm1_clamp_limit: torch.Tensor,
gemm2_weights: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
gemm2_bias: torch.Tensor,
output1_scale_scalar: torch.Tensor,
output1_scale_gate_scalar: torch.Tensor,
output2_scale_scalar: torch.Tensor,
@ -416,8 +427,13 @@ def fp4_block_scale_moe_runner(
hidden_states_scale,
gemm1_weights,
gemm1_weights_scale,
gemm1_bias,
gemm1_alpha,
gemm1_beta,
gemm1_clamp_limit,
gemm2_weights,
gemm2_weights_scale,
gemm2_bias,
output1_scale_scalar,
output1_scale_gate_scalar,
output2_scale_scalar,
@ -474,8 +490,13 @@ def _(routing_logits,
hidden_states_scale,
gemm1_weights,
gemm1_weights_scale,
gemm1_bias,
gemm1_alpha,
gemm1_beta,
gemm1_clamp_limit,
gemm2_weights,
gemm2_weights_scale,
gemm2_bias,
output1_scale_scalar,
output1_scale_gate_scalar,
output2_scale_scalar,

@ -34,8 +34,7 @@ from ..modules.rms_norm import RMSNorm
from ..speculative import SpecMetadata
from ..utils import Fp4QuantizedTensor
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import (DecoderModel, duplicate_kv_weight, filter_weights,
register_auto_model)
from .modeling_utils import DecoderModel, filter_weights, register_auto_model
# Use TinyGEMM when the number of tokens is not larger than this threshold
MIN_LATENCY_TINYGEMM_NUM_TOKENS = 128
@ -639,6 +638,15 @@ class GptOssForCausalLM(SpecDecOneEngineForCausalLM[Transformer, GptOssConfig]):
quant_config = self.model_config.quant_config
if quant_config.exclude_modules:
if quant_config.quant_algo == "NVFP4":
quant_config.exclude_modules = [
'block.*.attn.qkv',
'block.*.attn.out',
'block.*.mlp.gate',
'embedding',
'unembedding',
]
for i, module in enumerate(quant_config.exclude_modules):
names = module.split(".")
if names[-1] in params_map_reverse:
@ -653,13 +661,10 @@ class GptOssForCausalLM(SpecDecOneEngineForCausalLM[Transformer, GptOssConfig]):
module.create_weights()
def load_weights(self, weights: Dict):
is_ori_model = True
for k, v in weights.items():
if 'q_proj' in k:
is_ori_model = False
is_nvfp4 = self.model_config.quant_config.quant_mode.has_nvfp4()
if is_ori_model:
self.load_ori_weights(weights)
if is_nvfp4:
self.load_nvfp4_weights(weights)
else:
self.load_hf_weights(weights)
@ -811,176 +816,107 @@ class GptOssForCausalLM(SpecDecOneEngineForCausalLM[Transformer, GptOssConfig]):
if p is not None:
p.data.copy_(module_weights[n][:])
def load_ori_weights(self, weights: Dict):
head_dim = self.config.head_dim
num_q_head = self.config.num_attention_heads
num_kv_head = self.config.num_key_value_heads
def load_nvfp4_weights(self, weights: Dict):
num_expert = self.config.num_local_experts
enable_attention_dp = self.model_config.mapping.enable_attention_dp
tp_size = self.model_config.mapping.tp_size
for name, module in tqdm(list(self.named_modules()),
desc="Loading weights"):
if len(module._parameters) <= 0 or name.startswith("draft_model"):
continue
names = name.split(".")
module_weights = {}
if names[-1] in self.params_map:
names[-1] = self.params_map[names[-1]]
for k, v in self.hf_params_map.items():
name = name.replace(k, v)
names = name.split('.')
if names[-1] == "backend" and isinstance(module, MoE):
# Backend is under experts module (ConfigurableMoE wrapper)
name = '.'.join(names[:-1])
# Drop the first "model" prefix
if names[0] == 'model':
name = '.'.join(names[1:])
else:
name = '.'.join(names)
module_weights = filter_weights(name, weights)
if isinstance(module, MoE):
# [num_experts, intermediate_size * 2, hidden_size]
gate_up_proj = filter_weights(name.replace("experts", "mlp1"),
weights)
# [num_experts, intermediate_size, hidden_size]
down_proj = filter_weights(name.replace("experts", "mlp2"),
weights)
try:
# Official MXFP4 ckpt.
gate_up_weight = gate_up_proj['weight.blocks'].flatten(
-2, -1)
gate, up = gate_up_weight[:, ::2, :], gate_up_weight[:, 1::
2, :]
gate_up_weight = torch.cat([gate, up], dim=-2)
gate_up_bias = gate_up_proj['bias']
gate, up = gate_up_bias[:, ::2], gate_up_bias[:, 1::2]
gate_up_bias = torch.cat([gate, up], dim=-1)
moe_weights = {
'gate_up_proj': [
gate_up_weight[i, :, :].transpose(0, 1)
for i in range(num_expert)
],
'down_proj': [
down_proj['weight.blocks'].flatten(
-2, -1)[i, :, :].transpose(0, 1)
for i in range(num_expert)
],
'gate_up_proj.bias':
[gate_up_bias[i, :] for i in range(num_expert)],
'down_proj.bias':
[down_proj['bias'][i, :] for i in range(num_expert)]
}
except:
# For BF16 ckpt.
moe_weights = {
'gate_up_proj': [
gate_up_proj['weight'][i, :, :].transpose(0, 1).to(
self.model.dtype) for i in range(num_expert)
],
'down_proj': [
down_proj['weight'][i, :, :].transpose(0, 1).to(
self.model.dtype) for i in range(num_expert)
],
'gate_up_proj.bias':
[gate_up_proj['bias'][i, :] for i in range(num_expert)],
'down_proj.bias':
[down_proj['bias'][i, :] for i in range(num_expert)]
}
# Only for Official MXFP4 ckpt.
if 'weight.scales' in gate_up_proj:
gate_up_weight_scale = gate_up_proj['weight.scales']
gate, up = gate_up_weight_scale[:, ::
2, :], gate_up_weight_scale[:,
1::
2, :]
gate_up_weight_scale = torch.cat([gate, up], dim=-2)
assert getattr(module, "quant_config", None) is not None and \
module.quant_config.quant_mode.has_nvfp4()
gate_up = module_weights.get('gate_up_proj', None)
down = module_weights.get('down_proj', None)
gate_up_bias = module_weights.get('gate_up_proj_bias', None)
down_bias = module_weights.get('down_proj_bias', None)
def deinterleave(tensor):
g, u = tensor[..., ::2], tensor[..., 1::2]
return torch.cat([g, u], dim=-1)
gate_up = deinterleave(gate_up)
gate_up_bias = deinterleave(gate_up_bias)
# Only fp32 bias is supported for NVFP4 MoE.
if gate_up_bias.dtype != torch.float32:
gate_up_bias = gate_up_bias.to(torch.float32)
if down_bias.dtype != torch.float32:
down_bias = down_bias.to(torch.float32)
moe_weights = {}
if gate_up is not None:
moe_weights['gate_up_proj'] = [
gate_up[i, :, :] for i in range(num_expert)
]
if down is not None:
moe_weights['down_proj'] = [
down[i, :, :] for i in range(num_expert)
]
if gate_up_bias is not None:
moe_weights['gate_up_proj.bias'] = [
gate_up_bias[i, :] for i in range(num_expert)
]
if down_bias is not None:
moe_weights['down_proj.bias'] = [
down_bias[i, :] for i in range(num_expert)
]
# Per-expert block scales (transpose to expected layout)
if 'gate_up_proj_weight_scale' in module_weights:
gu_ws = module_weights['gate_up_proj_weight_scale']
gu_ws = deinterleave(gu_ws)
moe_weights['gate_up_proj_weight_scale'] = [
gate_up_weight_scale[i, :, :].transpose(0, 1)
for i in range(num_expert)
gu_ws[i, :, :] for i in range(num_expert)
]
if self.model_config.quant_config.quant_algo == 'W4A16_MXFP4':
for i in range(num_expert):
moe_weights[f"{i}.w1.weight_scale_inv"] = gate[
i, :, :]
moe_weights[f"{i}.w3.weight_scale_inv"] = up[
i, :, :]
if 'weight.scales' in down_proj:
if 'down_proj_weight_scale' in module_weights:
dp_ws = module_weights['down_proj_weight_scale']
moe_weights['down_proj_weight_scale'] = [
down_proj['weight.scales'][i, :, :].transpose(0, 1)
for i in range(num_expert)
dp_ws[i, :, :] for i in range(num_expert)
]
if self.model_config.quant_config.quant_algo == 'W4A16_MXFP4':
for i in range(num_expert):
moe_weights[f"{i}.w2.weight_scale_inv"] = down_proj[
'weight.scales'][i, :, :]
# Module-level globals for NVFP4 loaders
for src_key in [
'gate_up_proj_weight_scale_2',
'down_proj_weight_scale_2',
'gate_up_proj_input_scale',
'down_proj_input_scale',
]:
if src_key in module_weights:
moe_weights[src_key] = module_weights[src_key]
module.load_weights(weights=[moe_weights])
elif hasattr(module, "load_weights"):
# Load Attention module weights.
if 'qkv' in name:
q_weight = module_weights['weight'][:head_dim *
num_q_head, :]
k_weight = module_weights['weight'][head_dim *
num_q_head:head_dim *
(num_q_head +
num_kv_head), :]
v_weight = module_weights['weight'][-head_dim *
num_kv_head:, :]
q_bias = module_weights['bias'][:head_dim * num_q_head]
k_bias = module_weights['bias'][head_dim *
num_q_head:head_dim *
(num_q_head + num_kv_head)]
v_bias = module_weights['bias'][-head_dim * num_kv_head:]
# Handle KV weight duplication for GQA
tensors_need_duplication = ['weight', 'bias']
if module.quant_config.quant_mode.has_mxfp4():
tensors_need_duplication.append('weight_scale')
# Duplicate KV weights if needed
tensor_parallel_size = tp_size if not enable_attention_dp else 1
k_weight_dict = {'weight': k_weight, 'bias': k_bias}
v_weight_dict = {'weight': v_weight, 'bias': v_bias}
if 'weight_scale' in module_weights:
k_weight_dict['weight_scale'] = module_weights[
'weight_scale'][head_dim * num_q_head:head_dim *
(num_q_head + num_kv_head), :]
v_weight_dict['weight_scale'] = module_weights[
'weight_scale'][-head_dim * num_kv_head:, :]
k_weight_dict = {
k: (duplicate_kv_weight(
weight=v,
num_kv_heads=num_kv_head,
tensor_parallel_size=tensor_parallel_size)
if k in tensors_need_duplication else v)
for k, v in k_weight_dict.items()
}
v_weight_dict = {
k: (duplicate_kv_weight(
weight=v,
num_kv_heads=num_kv_head,
tensor_parallel_size=tensor_parallel_size)
if k in tensors_need_duplication else v)
for k, v in v_weight_dict.items()
}
qkv_weights = [{
'weight': q_weight,
'bias': q_bias
}, k_weight_dict, v_weight_dict]
module.load_weights(weights=qkv_weights)
# For qkv_proj
q_weight_bias = filter_weights(
name.replace('qkv_proj', 'q_proj'), weights)
k_weight_bias = filter_weights(
name.replace('qkv_proj', 'k_proj'), weights)
v_weight_bias = filter_weights(
name.replace('qkv_proj', 'v_proj'), weights)
module.load_weights(
weights=[q_weight_bias, k_weight_bias, v_weight_bias])
else:
# Dense & gate & sinks
# For o_proj, sinks.
module.load_weights(weights=[module_weights])
else:
# Load LN weights.
if names[-1].endswith("layernorm") and names[-3] == "block":
# skip loading weights for the fused norms
# Load four LN weights (attn.norm, mlp.norm, input_layernorm, post_attention_layernorm).
if 'next_layer_layernorm' in name:
continue
for n, p in module._parameters.items():
if p is not None:
p.data.copy_(module_weights[n.replace(
"weight", "scale")][:])
p.data.copy_(module_weights[n][:])
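
To make the deinterleave step above concrete, a tiny sketch of the layout change it performs: the gpt-oss checkpoint stores the gate and up projections interleaved along the trailing gate/up dimension, while the fused MoE expects the two halves as contiguous blocks.

import torch

interleaved = torch.tensor([0., 10., 1., 11., 2., 12.])   # gate = [0, 1, 2], up = [10, 11, 12]
gate, up = interleaved[..., ::2], interleaved[..., 1::2]
deinterleaved = torch.cat([gate, up], dim=-1)
assert torch.equal(deinterleaved, torch.tensor([0., 1., 2., 10., 11., 12.]))

The MXFP4 loader already used this slicing for the weights; the NVFP4 path applies it to the weights, the biases, and the per-expert block scales.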

@ -20,10 +20,10 @@ from .interface import AlltoallMethodType, MoE, MoEWeightLoadingMode
# isort: off
from .quantization import (
DeepSeekFP8BlockScalesFusedMoEMethod, NVFP4TRTLLMGenFusedMoEMethod,
UnquantizedFusedMoEMethod, W4A8MXFP4FP8TRTLLMGenFusedMoEMethod,
W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, W4A8NVFP4FP8TRTLLMGenFusedMoEMethod,
W4A16MXFP4TRTLLMGenFusedMoEMethod)
DeepSeekFP8BlockScalesFusedMoEMethod, NVFP4TRTLLMGenFusedMoEBaseMethod,
NVFP4TRTLLMGenFusedMoEMethod, UnquantizedFusedMoEMethod,
W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod,
W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, W4A16MXFP4TRTLLMGenFusedMoEMethod)
# isort: on
from .routing import BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod
@ -214,7 +214,7 @@ class TRTLLMGenFusedMoE(MoE):
or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."
if self.bias or self.swiglu_alpha is not None or self.swiglu_beta is not None or self.swiglu_limit is not None:
assert self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports mxfp4 quantization with bias, swiglu_alpha, swiglu_beta and swiglu_limit."
assert self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE supports bias/swiglu only for nvfp4 and mxfp4 variants."
def _get_quant_method(self):
if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
@ -222,7 +222,9 @@ class TRTLLMGenFusedMoE(MoE):
if self.quant_config.layer_quant_mode.has_fp8_block_scales():
return DeepSeekFP8BlockScalesFusedMoEMethod()
elif self.quant_config.layer_quant_mode.has_nvfp4():
return NVFP4TRTLLMGenFusedMoEMethod()
return NVFP4TRTLLMGenFusedMoEMethod(
) if self.swiglu_alpha is not None else NVFP4TRTLLMGenFusedMoEBaseMethod(
)
elif self.quant_config.layer_quant_mode.has_w4a16_mxfp4():
return W4A16MXFP4TRTLLMGenFusedMoEMethod()
elif self.quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8():
@ -324,6 +326,11 @@ class TRTLLMGenFusedMoE(MoE):
self,
'fc31_act_scale') and self.fc31_act_scale is not None:
x = x * self.fc31_act_scale
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
if pad_size > 0:
x = torch.nn.functional.pad(x, (0, pad_size))
x_row = x.shape[0]
x, x_sf = torch.ops.trtllm.fp4_quantize(
x, self.fc31_input_scale, self.scaling_vector_size, False,
@ -446,6 +453,8 @@ class TRTLLMGenFusedMoE(MoE):
topk_ids=token_selected_experts,
)
elif self.has_nvfp4:
intermediate_size_per_partition_padded = self.w3_w1_weight.shape[
-2] // 2
outputs = torch.ops.trtllm.fp4_block_scale_moe_runner(
router_logits,
@ -454,8 +463,13 @@ class TRTLLMGenFusedMoE(MoE):
x_sf.view(torch.float8_e4m3fn),
self.w3_w1_weight,
self.w3_w1_weight_scale.view(torch.float8_e4m3fn),
self.w3_w1_bias if self.bias else None,
self.swiglu_alpha,
self.swiglu_beta,
self.swiglu_limit,
self.w2_weight,
self.w2_weight_scale.view(torch.float8_e4m3fn),
self.w2_bias if self.bias else None,
self.fc31_scale_c.data,
self.fc31_alpha.data,
self.fc2_alpha.data,
@ -463,7 +477,7 @@ class TRTLLMGenFusedMoE(MoE):
top_k,
n_group,
topk_group,
self.intermediate_size_per_partition,
intermediate_size_per_partition_padded,
self.slot_start,
self.expert_size_per_partition,
routed_scaling_factor,
@ -478,6 +492,11 @@ class TRTLLMGenFusedMoE(MoE):
return outputs
else:
final_hidden_states = outputs[0]
# Slice output if it was padded
if final_hidden_states.shape[1] > self.hidden_size:
final_hidden_states = final_hidden_states[:, :self.
hidden_size].contiguous(
)
elif self.has_w4a16_mxfp4:
assert x.dtype == torch.bfloat16
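
The hidden-size padding in the NVFP4 path above, sketched with illustrative numbers: for gpt-oss, hidden_size 2880 is not a multiple of 256, so the weights are padded to 3072; the activation is zero-padded to match before fp4_quantize, and the kernel output is sliced back afterwards. The * 2 reflects that the packed weight stores two FP4 values per element of its last dimension.

import torch
import torch.nn.functional as F

hidden_size, padded_hidden = 2880, 3072              # illustrative gpt-oss-like sizes
x = torch.randn(4, hidden_size)

w3_w1_last_dim = padded_hidden // 2                  # packed FP4: two values per element
pad_size = w3_w1_last_dim * 2 - x.shape[-1]
if pad_size > 0:
    x = F.pad(x, (0, pad_size))                      # zero-pad the hidden dim before quantization
assert x.shape[-1] == padded_hidden

moe_out = torch.randn(4, padded_hidden)              # stand-in for the MoE kernel output
if moe_out.shape[1] > hidden_size:
    moe_out = moe_out[:, :hidden_size].contiguous()  # drop the padded columns again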

@ -218,13 +218,11 @@ class FusedMoEMethodBase(ABC):
# bias
if module.bias:
# The shape might be padded so we use weight shape[:2]
if w3_w1_bias_shape is None:
w3_w1_bias_shape = (
module.expert_size_per_partition,
module.expand_intermediate_size_per_partition)
w3_w1_bias_shape = w3_w1_weight_shape[:2]
if w2_bias_shape is None:
w2_bias_shape = (module.expert_size_per_partition,
module.hidden_size)
w2_bias_shape = w2_weight_shape[:2]
bias_dtype = bias_dtype or module.dtype
w3_w1_bias = nn.Parameter(torch.empty(w3_w1_bias_shape,
dtype=bias_dtype),
@ -1731,7 +1729,8 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
weight_vec_size,
block_scales_dtype,
block_scales_vec_size,
scaling_vector_size=16):
scaling_vector_size=16,
bias_dtype: Optional[torch.dtype] = None):
module.scaling_vector_size = scaling_vector_size
@ -1780,7 +1779,8 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
w3_w1_weight_shape=w3_w1_weight_shape,
w2_weight_shape=w2_weight_shape,
w3_w1_bias_shape=w3_w1_bias_shape,
w2_bias_shape=w2_bias_shape)
w2_bias_shape=w2_bias_shape,
bias_dtype=bias_dtype)
self.setup_quant_scales(module)
@ -2300,7 +2300,7 @@ class NVFP4CuteDslFusedMoEMethod(NVFP4CutlassFusedMoEMethod):
dst_w3_w1_weight_scale.copy_(w3_w1_weight_scale_interleaved)
class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
class NVFP4TRTLLMGenFusedMoEBaseMethod(NVFP4FusedMoEMethod):
weight_dtype = float4_sf_dtype
block_scales_dtype = torch.float8_e4m3fn
@ -2308,12 +2308,18 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
# This assumes the same input shape always results in the same permute indices
_cache_permute_indices: Dict[torch.Size, torch.Tensor] = {}
def create_weights(self, module: torch.nn.Module):
def create_weights(self,
module: torch.nn.Module,
bias_dtype: Optional[torch.dtype] = None):
weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
block_scales_vec_size = 1
super().create_weights(module, self.weight_dtype, weight_vec_size,
self.block_scales_dtype, block_scales_vec_size)
super().create_weights(module,
self.weight_dtype,
weight_vec_size,
self.block_scales_dtype,
block_scales_vec_size,
bias_dtype=bias_dtype)
fc31_scale_c = nn.Parameter(torch.ones(module.expert_size_per_partition,
dtype=torch.float32),
@ -2565,7 +2571,340 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
})
class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEMethod):
class NVFP4TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEBaseMethod):
weight_alignment = 32
input_hidden_alignment = 32
def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int,
block_scales_vec_size: int):
def round_up(x, alignment):
return (x + alignment - 1) // alignment * alignment
# Compute padded sizes
intermediate_size_per_partition_padded = round_up(
module.intermediate_size_per_partition, self.weight_alignment)
w3_w1_hidden_size_padded = round_up(module.hidden_size,
self.input_hidden_alignment)
w2_hidden_size_padded = round_up(module.hidden_size,
self.weight_alignment)
# Divide by 16 because we use int64 to pack 16 fp4 values
w3_w1_weight_shape = (module.expert_size_per_partition,
intermediate_size_per_partition_padded *
module.intermediate_size_expand_ratio,
w3_w1_hidden_size_padded // weight_vec_size)
w2_weight_shape = (module.expert_size_per_partition,
w2_hidden_size_padded,
intermediate_size_per_partition_padded //
weight_vec_size)
w3_w1_weight_scale_shape = (module.expert_size_per_partition,
intermediate_size_per_partition_padded *
module.intermediate_size_expand_ratio,
w3_w1_hidden_size_padded //
module.scaling_vector_size //
block_scales_vec_size)
w2_weight_scale_shape = (module.expert_size_per_partition,
w2_hidden_size_padded,
intermediate_size_per_partition_padded //
module.scaling_vector_size //
block_scales_vec_size)
if module.bias:
w3_w1_bias_shape = (module.expert_size_per_partition,
intermediate_size_per_partition_padded *
module.intermediate_size_expand_ratio)
w2_bias_shape = (module.expert_size_per_partition,
w2_hidden_size_padded)
else:
w3_w1_bias_shape = None
w2_bias_shape = None
return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape,
w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape)
def create_weights(self, module: torch.nn.Module):
# Here we only enable padding for hidden_size > 1024 since there are small unit tests that expect no padding.
if module.hidden_size > 1024 and module.hidden_size % 256 != 0:
self.weight_alignment = 256
# For now, keep the input alignment the same as the weight alignment; in practice it may need to differ.
# See the comment in MXFP4WeightTRTLLMGenFusedMoEMethod for more details.
self.input_hidden_alignment = 256
super().create_weights(module, bias_dtype=torch.float32)
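
A worked example of the padding decision above, assuming the gpt-oss 20B sizes used in the unit tests (hidden_size = intermediate_size = 2880): 2880 is larger than 1024 and not a multiple of 256, so both alignments are raised to 256 and the padded dimensions become 3072.

def round_up(x: int, alignment: int) -> int:
    # Same rounding as in get_weights_shapes above.
    return (x + alignment - 1) // alignment * alignment

assert round_up(2880, 256) == 3072   # padded hidden / intermediate size when alignment is raised
assert round_up(2880, 32) == 2880    # with the default alignment of 32 no padding is needed

The bias and weight-scale shapes are derived from these padded sizes, which is why the FusedMoEMethodBase change above now takes the bias shapes from the (possibly padded) weight shapes.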
def setup_quant_scales(self, module: torch.nn.Module):
module.quant_scales = tuple()
def load_expert_w3_w1_weight(self, module: torch.nn.Module,
w1_weight: torch.Tensor,
w3_weight: torch.Tensor,
dst_w3_w1_weight: torch.Tensor):
device = torch.device(f"cuda:{torch.cuda.current_device()}")
dst_on_gpu = dst_w3_w1_weight.device.type == "cuda"
dst_w3_w1_weight_gpu = dst_w3_w1_weight if dst_on_gpu else dst_w3_w1_weight.cuda(
)
alignment = _get_weight_alignment(self.weight_alignment,
module.scaling_vector_size,
module.tp_size, w1_weight.shape[0])
if len(w1_weight.shape) == 2:
# Pad weights
# An alignment factor of 2 is already satisfied because two FP4 values are packed into each uint8.
assert w1_weight.dtype == torch.uint8
w1_weight = maybe_pad_for_mxfp4(w1_weight,
self.input_hidden_alignment // 2,
alignment)
assert w3_weight.dtype == torch.uint8
w3_weight = maybe_pad_for_mxfp4(w3_weight,
self.input_hidden_alignment // 2,
alignment)
else:
# Pad bias, TRTLLM backend expects float32 bias.
assert len(w1_weight.shape) == 1
assert len(w3_weight.shape) == 1
w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment).float()
w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment).float()
w1_weight_shard = load_weight_shard(w1_weight,
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)
w3_weight_shard = load_weight_shard(w3_weight,
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)
# FIXME: this depends on the kernel internals
epilogue_tile_m = 128
# Keep weights in device buffer
dst_w3_weight, dst_w1_weight = dst_w3_w1_weight_gpu.chunk(2, dim=0)
dst_w3_weight.copy_(w3_weight_shard.view(dst_w3_weight.dtype))
dst_w1_weight.copy_(w1_weight_shard.view(dst_w1_weight.dtype))
# Get permute indices
permute_indices = trtllmgen_maybe_get_cached_w3_w1_permute_indices(
dst_w3_w1_weight_gpu, self._cache_permute_indices, epilogue_tile_m)
# Shuffle the weight according to permute indices
processed_w31_weight_shard = torch.ops.trtllm.shuffle_matrix(
dst_w3_w1_weight_gpu,
permute_indices.to(dst_w3_w1_weight_gpu.device))
# Copy the result into device buffer
dst_w3_w1_weight_gpu.copy_(processed_w31_weight_shard.view(
dst_w3_w1_weight_gpu.dtype),
non_blocking=dst_on_gpu)
if not dst_on_gpu:
dst_w3_w1_weight.copy_(dst_w3_w1_weight_gpu)
def load_expert_w2_weight(self, module: torch.nn.Module,
w2_weight: torch.Tensor,
dst_w2_weight: torch.Tensor):
device = torch.device(f"cuda:{torch.cuda.current_device()}")
dst_on_gpu = dst_w2_weight.device.type == "cuda"
dst_w2_weight_gpu = dst_w2_weight if dst_on_gpu else dst_w2_weight.cuda(
)
shard_w2_weight_dim = 2 * w2_weight.shape[1] if len(
w2_weight.shape) == 2 else w2_weight.shape[0]
alignment = _get_weight_alignment(self.weight_alignment,
module.scaling_vector_size,
module.tp_size, shard_w2_weight_dim)
if len(w2_weight.shape) == 2:
assert w2_weight.dtype == torch.uint8
w2_weight = maybe_pad_for_mxfp4(w2_weight, alignment // 2,
self.weight_alignment)
else:
assert len(w2_weight.shape) == 1
w2_weight = maybe_pad_for_mxfp4(w2_weight, self.weight_alignment)
# Divide bias by tp_size as we shard along the hidden dimension.
# The bias is applied at each TP rank before the final accumulation.
w2_weight /= module.tp_size
w2_weight_shard = load_weight_shard(w2_weight,
module.tp_size,
module.tp_rank,
TensorParallelMode.ROW,
device=device)
# FIXME: this depends on the kernel internals
epilogue_tile_m = 128
# Keep weights in device buffer
dst_w2_weight_gpu.copy_(w2_weight_shard.view(dst_w2_weight_gpu.dtype),
non_blocking=dst_on_gpu)
# Get permuted indices
permute_indices = trtllmgen_maybe_get_cached_w2_permute_indices(
dst_w2_weight_gpu, self._cache_permute_indices, epilogue_tile_m)
# Shuffle the weight according to permute indices
processed_w2_weight = torch.ops.trtllm.shuffle_matrix(
dst_w2_weight_gpu, permute_indices.to(dst_w2_weight_gpu.device))
# Copy the result into device buffer
dst_w2_weight_gpu.copy_(processed_w2_weight.view(
dst_w2_weight_gpu.dtype),
non_blocking=dst_on_gpu)
if not dst_on_gpu:
dst_w2_weight.copy_(dst_w2_weight_gpu)
def load_expert_w3_w1_weight_scale_nvfp4(
self,
module: torch.nn.Module,
w1_weight_scale: torch.Tensor,
w3_weight_scale: torch.Tensor,
dst_w3_w1_weight_scale: torch.Tensor,
num_elts_per_sf: int = 16):
device = torch.device(f"cuda:{torch.cuda.current_device()}")
dst_on_gpu = dst_w3_w1_weight_scale.device.type == "cuda"
dst_w3_w1_weight_scale_gpu = dst_w3_w1_weight_scale if dst_on_gpu else dst_w3_w1_weight_scale.cuda(
)
alignment = _get_weight_alignment(self.weight_alignment,
module.scaling_vector_size,
module.tp_size,
w3_weight_scale.shape[0])
w1_weight_scale = maybe_pad_for_mxfp4(
w1_weight_scale,
self.input_hidden_alignment // module.scaling_vector_size,
alignment)
w3_weight_scale = maybe_pad_for_mxfp4(
w3_weight_scale,
self.input_hidden_alignment // module.scaling_vector_size,
alignment)
w1_weight_scale = load_weight_shard(w1_weight_scale,
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)
w3_weight_scale = load_weight_shard(w3_weight_scale,
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)
# Keep weights in device buffer
dst_w3_weight_scale, dst_w1_weight_scale = dst_w3_w1_weight_scale_gpu.chunk(
2, dim=0)
dst_w3_weight_scale.copy_(
w3_weight_scale.view(dst_w3_weight_scale.dtype))
dst_w1_weight_scale.copy_(
w1_weight_scale.view(dst_w1_weight_scale.dtype))
orig_shape = dst_w3_w1_weight_scale_gpu.shape
# trtllm-gen specific block-scale preprocessing logic
epilogue_tile_m = 128 # FIXME
# Get permute indices
permute_indices = trtllmgen_maybe_get_cached_w3_w1_permute_indices(
dst_w3_w1_weight_scale_gpu.view(float4_sf_dtype),
self._cache_permute_indices,
epilogue_tile_m,
num_elts_per_sf=num_elts_per_sf)
# Shuffle the weight according to permute indices
w3_w1_weight_scale = torch.ops.trtllm.shuffle_matrix(
dst_w3_w1_weight_scale_gpu.view(float4_sf_dtype), permute_indices)
# Assert should only be removed during debugging
assert w3_w1_weight_scale.is_cuda, "w3_w1_weight_scale.is_cuda should be true or suffer from slow speed"
# Interleave the weight.
processed_w3_w1_weight_scale = torch.ops.trtllm.block_scale_interleave(
w3_w1_weight_scale.view(float4_sf_dtype).reshape(orig_shape))
# Copy the result into device buffer
dst_w3_w1_weight_scale_gpu.copy_(
processed_w3_w1_weight_scale.view(
self.block_scales_dtype).reshape(orig_shape))
if not dst_on_gpu:
dst_w3_w1_weight_scale.copy_(dst_w3_w1_weight_scale_gpu)
def load_expert_w2_weight_scale_nvfp4(self,
module: torch.nn.Module,
w2_weight_scale: torch.Tensor,
dst_w2_weight_scale: torch.Tensor,
num_elts_per_sf: int = 16):
device = torch.device(f"cuda:{torch.cuda.current_device()}")
dst_on_gpu = dst_w2_weight_scale.device.type == "cuda"
dst_w2_weight_scale_gpu = dst_w2_weight_scale if dst_on_gpu else dst_w2_weight_scale.cuda(
)
alignment = _get_weight_alignment(self.weight_alignment,
module.scaling_vector_size,
module.tp_size,
w2_weight_scale.shape[-1])
w2_weight_scale = maybe_pad_for_mxfp4(
w2_weight_scale, alignment // module.scaling_vector_size,
self.weight_alignment)
w2_weight_scale = load_weight_shard(w2_weight_scale,
module.tp_size,
module.tp_rank,
TensorParallelMode.ROW,
device=device)
# Keep weights in device buffer
dst_w2_weight_scale_gpu.copy_(
w2_weight_scale.view(dst_w2_weight_scale_gpu.dtype))
orig_shape = dst_w2_weight_scale_gpu.shape
# trtllm-gen specific block-scale preprocessing logic
epilogue_tile_m = 128 # FIXME: read from kernel
# Assert should only be removed during debugging
assert dst_w2_weight_scale_gpu.is_cuda, "dst_w2_weight_scale.is_cuda should be true or suffer from slow speed"
# Get permute indices
permute_indices = trtllmgen_maybe_get_cached_w2_permute_indices(
dst_w2_weight_scale_gpu.view(float4_sf_dtype),
self._cache_permute_indices,
epilogue_tile_m,
num_elts_per_sf=num_elts_per_sf)
# Shuffle the weight according to permute indices
w_shuffled = torch.ops.trtllm.shuffle_matrix(
dst_w2_weight_scale_gpu.view(dtype=float4_sf_dtype),
permute_indices)
# Interleave the weight.
processed_w2_weight_scale = torch.ops.trtllm.block_scale_interleave(
w_shuffled)
# Copy the result into device buffer
dst_w2_weight_scale_gpu.copy_(
processed_w2_weight_scale.view(
self.block_scales_dtype).reshape(orig_shape))
if not dst_on_gpu:
dst_w2_weight_scale.copy_(dst_w2_weight_scale_gpu)
def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
super().load_quant_scales(module, weights)
# Normalize biases to account for the global scale factors,
# matching the kernel's expectation (similar to test_moe.py logic).
if module.w3_w1_bias is not None:
# gemm1_bias * gemm1_scales_global * hidden_states_scale_global
module.w3_w1_bias.data.div_((module.fc31_alpha.data).view(-1, 1))
if module.w2_bias is not None:
# gemm2_bias * c_global_sf * gemm2_scales_global
module.w2_bias.data.div_((module.fc2_alpha.data).view(-1, 1))
if module.swiglu_beta is not None:
module.swiglu_beta.data.div_((module.fc31_alpha.data))
if module.swiglu_limit is not None:
module.swiglu_limit.data.div_((module.fc31_alpha.data))
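
A hedged illustration of why the in-place division above matches the multiplication performed in the kernel unit test (test_moe.py), assuming fc31_alpha holds 1 / (hidden_states_scale_global * gemm1_scales_global) per expert and fc2_alpha the analogous product for GEMM2:

import torch

x_global, w_global = torch.tensor(2.0), torch.tensor(0.5)   # hypothetical global scales
fc31_alpha = 1.0 / (x_global * w_global)
bias = torch.tensor([3.0])
# Dividing by fc31_alpha is the same as multiplying by the two global scales,
# i.e. the gemm1_bias * gemm1_scales_global * hidden_states_scale_global form noted above.
assert torch.allclose(bias / fc31_alpha, bias * x_global * w_global)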
class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEBaseMethod):
def create_weights(self, module: torch.nn.Module):
weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4

@ -270,6 +270,13 @@ GPT-OSS/20B-MXFP4:
- quant_algo: W4A16_MXFP4
kv_cache_quant_algo: FP8
accuracy: 85.0
GPT-OSS/20B-NVFP4:
- accuracy: 85.0
- quant_algo: NVFP4
accuracy: 85.0
- quant_algo: NVFP4
kv_cache_quant_algo: FP8
accuracy: 85.0
LGAI-EXAONE/EXAONE-4.0-32B:
- accuracy: 88.36
ByteDance-Seed/Seed-OSS-36B-Instruct:

@ -4270,6 +4270,44 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
task.evaluate(llm,
extra_evaluator_kwargs=self.extra_evaluator_kwargs)
@pytest.mark.skip_less_device(2)
@pytest.mark.skip_blackwell
@pytest.mark.parametrize(
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [
(2, 1, 1, False, True, True),
(2, 1, 2, False, True, True),
(2, 1, 2, True, True, True),
],
ids=["tp2", "ep2", "dp2"])
def test_w4_2gpus_nvfp4(self, tp_size, pp_size, ep_size, attention_dp,
cuda_graph, overlap_scheduler, mocker):
pytest.skip("Models not uploaded to CI")
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None)
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4,
dtype="auto")
llm = LLM("./nvfp4ckpt",
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
max_seq_len=8192,
**pytorch_config,
enable_attention_dp=attention_dp,
moe_config=MoeConfig(backend="TRTLLM"))
with llm:
model_name = "GPT-OSS/20B-NVFP4"
task = GSM8K(model_name)
mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
{"scores_filter": "exact_match,flexible-extract"})
task.evaluate(llm,
extra_evaluator_kwargs=self.extra_evaluator_kwargs)
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize(
"kv_cache_dtype",

@ -11,6 +11,7 @@ import cloudpickle
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from _torch.helpers import (calc_woq_tolerence, per_block_cast_to_fp8,
per_block_cast_to_fp8_e8m0,
per_token_cast_to_fp8_e8m0)
@ -1398,6 +1399,53 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
and moe_backend in ["TRTLLM", "CUTLASS"] else "0"
})
run_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion)
@skip_pre_blackwell
@pytest.mark.parametrize("hidden_size, intermediate_size", [(2880, 2880)])
@pytest.mark.parametrize("swiglu_alpha", [1, 0.1], ids=lambda v: f"alpha{v}")
@pytest.mark.parametrize("swiglu_beta", [0, 1], ids=lambda v: f"beta{v}")
@pytest.mark.parametrize("swiglu_limit", [float("inf"), 1],
ids=lambda v: f"limit{v}")
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
ids=lambda x: ""
if x == 0 else "enable_configurable_moe")
def test_fused_moe_nvfp4_gptoss_style(hidden_size, intermediate_size,
swiglu_alpha, swiglu_beta, swiglu_limit,
enable_configurable_moe, mocker):
mocker.patch.dict(os.environ, {
"ENABLE_CONFIGURABLE_MOE":
"1" if enable_configurable_moe == 1 else "0"
})
run_fused_moe_nvfp4(dtype=torch.bfloat16,
moe_backend="TRTLLM",
finalize_fusion=False,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_experts=32,
top_k=4,
seq_len=256,
gptoss_style=True,
swiglu_alpha=swiglu_alpha,
swiglu_beta=swiglu_beta,
swiglu_limit=swiglu_limit)
def run_fused_moe_nvfp4(dtype,
moe_backend,
finalize_fusion,
hidden_size=512,
intermediate_size=512,
num_experts=8,
top_k=2,
seq_len=4,
gptoss_style=False,
swiglu_alpha=None,
swiglu_beta=None,
swiglu_limit=None):
if moe_backend == "TRTLLM":
if dtype == torch.float16:
pytest.skip("TRTLLM NVFP4 MoE backend does not support float16 yet")
@ -1424,11 +1472,11 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
with torch.device(f"cuda:{mapping.rank}"):
SCALING_VECTOR_SIZE = 16
SEQ_LEN = 4
HIDDEN_SIZE = 512
INTERMEDIATE_SIZE = 512
NUM_EXPERTS = 8
TOP_K = 2
SEQ_LEN = seq_len
HIDDEN_SIZE = hidden_size
INTERMEDIATE_SIZE = intermediate_size
NUM_EXPERTS = num_experts
TOP_K = top_k
routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
@ -1455,24 +1503,38 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
device="cuda") * 0.05
w3_sf_global = (448 * 6) / w3_weight.abs().max().float()
if gptoss_style:
w1_bias = torch.randn(INTERMEDIATE_SIZE,
device='cuda',
dtype=torch.float)
w2_bias = torch.randn(HIDDEN_SIZE,
device='cuda',
dtype=torch.float)
w3_bias = torch.randn(INTERMEDIATE_SIZE,
device='cuda',
dtype=torch.float)
weights[f"{expert_id}.w1.bias"] = w1_bias
weights[f"{expert_id}.w2.bias"] = w2_bias
weights[f"{expert_id}.w3.bias"] = w3_bias
w3_w1_global = min(
w1_sf_global,
w3_sf_global) # w3 global and w1 global must be the same
w1_weight_nvfp4, w1_sf_block = torch.ops.trtllm.fp4_quantize(
w1_weight, w3_w1_global, SCALING_VECTOR_SIZE, False)
w1_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
w1_sf_block.cpu().view(INTERMEDIATE_SIZE, -1))
w1_weight_nvfp4, w1_sf_block_unswizzled = torch.ops.trtllm.fp4_quantize(
w1_weight, w3_w1_global, SCALING_VECTOR_SIZE, False, False)
w1_sf_block_unswizzled = w1_sf_block_unswizzled.view(
INTERMEDIATE_SIZE, -1)
w2_weight_nvfp4, w2_sf_block = torch.ops.trtllm.fp4_quantize(
w2_weight, w2_sf_global, SCALING_VECTOR_SIZE, False)
w2_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
w2_sf_block.cpu().view(HIDDEN_SIZE, -1))
w2_weight_nvfp4, w2_sf_block_unswizzled = torch.ops.trtllm.fp4_quantize(
w2_weight, w2_sf_global, SCALING_VECTOR_SIZE, False, False)
w2_sf_block_unswizzled = w2_sf_block_unswizzled.view(
HIDDEN_SIZE, -1)
w3_weight_nvfp4, w3_sf_block = torch.ops.trtllm.fp4_quantize(
w3_weight, w3_w1_global, SCALING_VECTOR_SIZE, False)
w3_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
w3_sf_block.cpu().view(INTERMEDIATE_SIZE, -1))
w3_weight_nvfp4, w3_sf_block_unswizzled = torch.ops.trtllm.fp4_quantize(
w3_weight, w3_w1_global, SCALING_VECTOR_SIZE, False, False)
w3_sf_block_unswizzled = w3_sf_block_unswizzled.view(
INTERMEDIATE_SIZE, -1)
w1_input_scale = x_sf_global.cuda()
w2_input_scale = x_sf_global.cuda()
@ -1497,6 +1559,23 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global
weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global
swiglu_alpha_tensor = None
swiglu_beta_tensor = None
swiglu_limit_tensor = None
if gptoss_style:
swiglu_alpha_tensor = torch.full((NUM_EXPERTS, ),
swiglu_alpha,
device='cuda',
dtype=torch.float)
swiglu_beta_tensor = torch.full((NUM_EXPERTS, ),
swiglu_beta,
device='cuda',
dtype=torch.float)
swiglu_limit_tensor = torch.full((NUM_EXPERTS, ),
swiglu_limit,
device='cuda',
dtype=torch.float)
quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4)
# Create pretrained_config with necessary parameters
@ -1514,6 +1593,10 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
quant_config=quant_config,
moe_backend=moe_backend,
moe_disable_finalize_fusion=not finalize_fusion),
bias=gptoss_style,
swiglu_alpha=swiglu_alpha_tensor,
swiglu_beta=swiglu_beta_tensor,
swiglu_limit=swiglu_limit_tensor,
)
fused_moe.load_weights([weights])
fused_moe.post_load_weights()
@ -1526,7 +1609,11 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE,
dtype=dtype,
model_config=ModelConfig(quant_config=quant_config))
model_config=ModelConfig(quant_config=quant_config),
bias=gptoss_style,
swiglu_alpha=swiglu_alpha,
swiglu_beta=swiglu_beta,
swiglu_limit=swiglu_limit)
ref_fused_moe.load_weights([weights])
ref_fused_moe.cuda()
@ -1534,11 +1621,33 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
with torch.inference_mode():
ref_output = ref_fused_moe.forward(x, router_logits)
with torch.inference_mode(), autotune():
fused_moe.forward(x, router_logits)
if not gptoss_style:
with torch.inference_mode(), autotune():
fused_moe.forward(x, router_logits)
else:
# We skip autotune for gptoss style to reduce memory usage since the input shape is already quite large.
with torch.inference_mode():
fused_moe.forward(x, router_logits)
output = fused_moe.forward(x, router_logits)
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.15)
if gptoss_style:
rtol = 0.1
atol = 0.1
percent = 0.95
else:
rtol = 1e-2
atol = 0.15
percent = None
if gptoss_style:
check_accuracy(output,
ref_output,
rtol=rtol,
atol=atol,
percent=percent)
else:
torch.testing.assert_close(output, ref_output, rtol=rtol, atol=atol)
if not test_all_kernels:
return
@ -1551,10 +1660,17 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
for tactic in all_tactics:
with AutoTuner.get().replay(tactic), torch.inference_mode():
output = fused_moe.forward(x, router_logits)
torch.testing.assert_close(output,
ref_output,
rtol=1e-2,
atol=0.15)
if gptoss_style:
check_accuracy(output,
ref_output,
rtol=rtol,
atol=atol,
percent=percent)
else:
torch.testing.assert_close(output,
ref_output,
rtol=rtol,
atol=atol)
@skip_pre_blackwell
@ -2690,7 +2806,10 @@ class RefGatedMLPFusedMoE(nn.Module):
dtype: Optional[torch.dtype] = None,
model_config: ModelConfig = ModelConfig(),
use_cute_dsl_blockscaling_mm: bool = False,
bias=False):
bias=False,
swiglu_alpha: Optional[float] = None,
swiglu_beta: Optional[float] = None,
swiglu_limit: Optional[float] = None):
super().__init__()
self.num_experts = num_experts
self.routing_method = routing_method
@ -2701,6 +2820,19 @@ class RefGatedMLPFusedMoE(nn.Module):
self.dtype = dtype
self.quant_config = model_config.quant_config
def custom_swiglu(x):
gate, value = x.chunk(2, dim=-1)
if swiglu_limit is not None and swiglu_limit != float("inf"):
gate = gate.clamp(max=swiglu_limit)
value = value.clamp(min=-swiglu_limit, max=swiglu_limit)
alpha = swiglu_alpha if swiglu_alpha is not None else 1.0
gate_act = gate * torch.sigmoid(gate * alpha)
beta = swiglu_beta if swiglu_beta is not None else 0.0
return gate_act * (value + beta)
self.experts = nn.ModuleList([
GatedMLP(
hidden_size=self.hidden_size,
@ -2709,6 +2841,8 @@ class RefGatedMLPFusedMoE(nn.Module):
dtype=self.dtype,
config=model_config,
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
activation=custom_swiglu
if swiglu_alpha is not None else F.silu,
) for _ in range(self.num_experts)
])

@ -547,11 +547,23 @@ def run_moe_reference_fp4(args):
args.gemm2_weights, args.gemm2_scales, 1 / args.gemm2_scales_global,
sf_vec_size).cuda()
args_dequant = moe_args_dequant(
args.num_tokens, args.num_experts, args.hidden_size,
args.intermediate_size, args.top_k, args.padding, hidden_states_dequant,
args.expert_logits, gemm1_weights_dequant, gemm2_weights_dequant,
args.permute_info, args.use_routing_scales_on_input)
args_dequant = moe_args_dequant(args.num_tokens,
args.num_experts,
args.hidden_size,
args.intermediate_size,
args.top_k,
args.padding,
hidden_states_dequant,
args.expert_logits,
gemm1_weights_dequant,
gemm2_weights_dequant,
args.permute_info,
args.use_routing_scales_on_input,
gemm1_bias=args.gemm1_bias,
gemm1_alpha=args.gemm1_alpha,
gemm1_beta=args.gemm1_beta,
gemm1_clamp_limit=args.gemm1_clamp_limit,
gemm2_bias=args.gemm2_bias)
return run_moe_dequant(args_dequant, "fp4"), args_dequant
@ -1157,6 +1169,44 @@ class TestMoeFp4:
use_autotune=False,
use_topk_as_input=use_topk_as_input)
@pytest.mark.parametrize("num_tokens", [1])
@pytest.mark.parametrize("hidden_size", [512])
@pytest.mark.parametrize("intermediate_size", [512])
@pytest.mark.parametrize(
"routing_info",
[
pytest.param(
{
"num_experts": 128,
"top_k": 4,
"n_groups": None,
"top_k_groups": None,
"routed_scaling": None,
"has_routing_bias": False,
"routing_method_type": RoutingMethodType.Renormalize
},
id="RoutingGPTOSS")
],
)
@pytest.mark.parametrize("swiglu_alpha", [1, 0.1],
ids=lambda v: f"alpha{v}")
@pytest.mark.parametrize("swiglu_beta", [0, 1], ids=lambda v: f"beta{v}")
@pytest.mark.parametrize("swiglu_limit", [float("inf"), 1],
ids=lambda v: f"limit{v}")
def test_gptoss_style_nvfp4(self, num_tokens, hidden_size,
intermediate_size, routing_info, swiglu_alpha,
swiglu_beta, swiglu_limit):
self.run_moe_fp4_test(num_tokens,
hidden_size,
intermediate_size,
routing_info,
use_autotune=False,
gptoss_style=True,
swiglu_alpha=swiglu_alpha,
swiglu_beta=swiglu_beta,
swiglu_limit=swiglu_limit)
@pytest.mark.parametrize("num_tokens", [1])
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [1024])
@ -1219,9 +1269,17 @@ class TestMoeFp4:
use_autotune=True,
use_topk_as_input=True)
def run_moe_fp4_test(self, num_tokens: int, hidden_size: int,
intermediate_size: int, routing_info: dict,
use_autotune: bool, use_topk_as_input: bool) -> None:
def run_moe_fp4_test(self,
num_tokens: int,
hidden_size: int,
intermediate_size: int,
routing_info: dict,
use_autotune: bool,
use_topk_as_input: bool = False,
gptoss_style: bool = False,
swiglu_alpha: float = None,
swiglu_beta: float = None,
swiglu_limit: float = None) -> None:
torch.random.manual_seed(0)
@ -1289,6 +1347,39 @@ class TestMoeFp4:
device='cuda',
dtype=torch.bfloat16)
gemm1_bias = None
gemm2_bias = None
swiglu_alpha_tensor = None
swiglu_beta_tensor = None
swiglu_limit_tensor = None
if gptoss_style:
gemm1_bias = 50 * torch.randn(num_experts,
2 * intermediate_size,
device='cuda',
dtype=torch.float)
gemm2_bias = 50 * torch.randn(
num_experts, hidden_size, device='cuda', dtype=torch.float)
# Previously zeroed while the nvfp4 kernel lacked bias support; kept here for reference.
#gemm1_bias[:] = 0
#gemm2_bias[:] = 0
swiglu_alpha_tensor = torch.full((num_experts, ),
swiglu_alpha,
device='cuda',
dtype=torch.float)
swiglu_beta_tensor = torch.full((num_experts, ),
swiglu_beta,
device='cuda',
dtype=torch.float)
swiglu_limit_tensor = torch.full((num_experts, ),
swiglu_limit,
device='cuda',
dtype=torch.float)
use_ue8m0 = False
# Quantize hidden states. Produces scales for activations in 128x4 layout for ref impl.
hidden_states_fp4_bytes, hidden_states_scale_fp4_bytes, hidden_states_scale_global = quant_fp4(
@ -1343,14 +1434,29 @@ class TestMoeFp4:
permute_info, scores = routing_reference_renormalize_naive(
expert_logits, top_k, padding)
args = moe_args(num_tokens, num_experts, hidden_size, intermediate_size,
top_k, padding, hidden_states_fp4_bytes,
args = moe_args(num_tokens,
num_experts,
hidden_size,
intermediate_size,
top_k,
padding,
hidden_states_fp4_bytes,
hidden_states_scale_fp4_bytes,
hidden_states_scale_global, scores,
gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes,
gemm1_scales_global, gemm2_weights_fp4_bytes,
gemm2_scales_fp4_bytes, gemm2_scales_global,
permute_info, False)
hidden_states_scale_global,
scores,
gemm1_weights_fp4_bytes,
gemm1_scales_fp4_bytes,
gemm1_scales_global,
gemm2_weights_fp4_bytes,
gemm2_scales_fp4_bytes,
gemm2_scales_global,
permute_info,
False,
gemm1_bias=gemm1_bias,
gemm1_alpha=swiglu_alpha_tensor,
gemm1_beta=swiglu_beta_tensor,
gemm1_clamp_limit=swiglu_limit_tensor,
gemm2_bias=gemm2_bias)
#
# Run the reference implementations
#
@ -1364,12 +1470,17 @@ class TestMoeFp4:
# Reorder rows of W1 and scales for fused gated activation
gemm1_weights_fp4_interleaved = []
gemm1_scales_fp4_interleaved = []
gemm1_bias_interleaved = []
for i in range(num_experts):
gemm1_weights_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
gemm1_scales_fp4_interleaved.append(
reorder_rows_for_gated_act_gemm(
gemm1_scales_linear_fp4[i].clone()))
if gemm1_bias is not None:
gemm1_bias_interleaved.append(
reorder_rows_for_gated_act_gemm(
gemm1_bias[i].clone().reshape(-1, 1)))
# Stack weights and scales for all experts
gemm1_weights_fp4_interleaved = torch.stack(
@ -1384,8 +1495,10 @@ class TestMoeFp4:
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_fp4_shuffled = []
gemm1_scales_fp4_shuffled = []
gemm1_bias_shuffled = []
gemm2_weights_fp4_shuffled = []
gemm2_scales_fp4_shuffled = []
gemm2_bias_shuffled = []
for i in range(num_experts):
gemm1_weights_fp4_shuffled.append(
shuffle_matrix_a(
@ -1395,6 +1508,10 @@ class TestMoeFp4:
shuffle_matrix_sf_a(
gemm1_scales_fp4_interleaved[i].view(torch.uint8),
epilogue_tile_m))
if gemm1_bias is not None:
gemm1_bias_shuffled.append(
shuffle_matrix_a(gemm1_bias_interleaved[i],
epilogue_tile_m))
gemm2_weights_fp4_shuffled.append(
shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),
@ -1403,6 +1520,10 @@ class TestMoeFp4:
shuffle_matrix_sf_a(
gemm2_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m))
if gemm2_bias is not None:
gemm2_bias_shuffled.append(
shuffle_matrix_a(gemm2_bias[i].clone().reshape(-1, 1),
epilogue_tile_m))
# Stack weights for all experts
gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
@ -1415,10 +1536,35 @@ class TestMoeFp4:
torch.float8_e4m3fn).reshape(num_experts, hidden_size,
intermediate_size // 16)
if gemm1_bias is not None:
gemm1_bias_shuffled = torch.stack(gemm1_bias_shuffled).reshape(
num_experts, -1)
else:
gemm1_bias_shuffled = None
if gemm2_bias is not None:
gemm2_bias_shuffled = torch.stack(gemm2_bias_shuffled).reshape(
num_experts, -1)
else:
gemm2_bias_shuffled = None
#
# Run the TRT-LLM kernel
#
if gptoss_style:
# NOTE: correct the beta and clamp to account for the global scale factor
# Check cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/GemmGatedActOptions.h
# for more details
swiglu_beta_tensor = swiglu_beta_tensor * args.gemm1_scales_global * args.hidden_states_scale_global
swiglu_limit_tensor = swiglu_limit_tensor * args.gemm1_scales_global * args.hidden_states_scale_global
# Check cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h
# for more details
gemm1_bias_shuffled = gemm1_bias_shuffled * args.gemm1_scales_global[:,
None] * args.hidden_states_scale_global
gemm2_bias_shuffled = gemm2_bias_shuffled * args_dequant.c_global_sf * args.gemm2_scales_global[:,
None]
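# Why these corrections are needed: roughly speaking, the quantized GEMM1 accumulates
# (x * hidden_states_scale_global) . (w * gemm1_scales_global), so the accumulator carries the
# product of the two global scales. Additive terms (bias, swiglu beta) and the clamp threshold
# must be pre-multiplied by the same product to live in those scaled units; GEMM2's bias is
# scaled by c_global_sf * gemm2_scales_global for the same reason (c_global_sf is the fc2
# input scale noted below).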
# c_global_sf: fc2_input_scale
scale_c_fc1 = args_dequant.c_global_sf * (
1.0 / args.gemm1_scales_global) * (1.0 /
@ -1449,8 +1595,13 @@ class TestMoeFp4:
hidden_states_scale_linear_fp4,
gemm1_weights_fp4_shuffled,
gemm1_scales_fp4_shuffled,
gemm1_bias_shuffled, # Bias
swiglu_alpha_tensor, # Alpha
swiglu_beta_tensor, # Beta
swiglu_limit_tensor, # Limit
gemm2_weights_fp4_shuffled,
gemm2_scales_fp4_shuffled,
gemm2_bias_shuffled, # Bias
scale_c_fc1,
scale_gate_fc1,
scale_c_fc2,
@ -1469,11 +1620,20 @@ class TestMoeFp4:
torch.cuda.synchronize()
output_dequant_actual = output[0].to(torch.float)
if gptoss_style:
atol = 0.2
rtol = 0.2
percent = 0.85
else:
atol = 0.1
rtol = 0.85
percent = 0.925
check_accuracy(output_dequant_reference,
output_dequant_actual,
atol=0.1,
rtol=0.85,
percent=0.925)
atol=atol,
rtol=rtol,
percent=percent)
def run_moe_fp8_fp4_test(self, num_tokens: int, hidden_size: int,
intermediate_size: int, routing_info: dict,