Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[None][feat] Support nvfp4 for gptoss (#8956)
Signed-off-by: Dongfeng Yu <dongfengy@nvidia.com>
This commit is contained in:
parent: a4dcc6a711
commit: afc533193d
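In short, this change threads five new optional per-expert tensors (gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_bias) through the NVFP4 block-scale MoE path so that the gpt-oss style clamped SwiGLU can run under NVFP4 quantization. As a plain-PyTorch sketch of the activation those parameters control, mirroring the custom_swiglu reference added in the unit tests below (the alpha/beta/limit values here are illustrative, not model constants):

import torch

def gptoss_style_swiglu(x: torch.Tensor, alpha: float, beta: float, limit: float) -> torch.Tensor:
    # Split the fused gate/up projection output, clamp, apply the scaled
    # sigmoid gate, and offset the linear branch by beta.
    gate, value = x.chunk(2, dim=-1)
    if limit != float("inf"):
        gate = gate.clamp(max=limit)
        value = value.clamp(min=-limit, max=limit)
    return gate * torch.sigmoid(gate * alpha) * (value + beta)

y = gptoss_style_swiglu(torch.randn(4, 8), alpha=0.1, beta=1.0, limit=1.0)  # shape (4, 4)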
@ -34,15 +34,17 @@ using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::computeSelectedTileN;
std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch::Tensor> const& routing_logits,
|
||||
torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
|
||||
torch::optional<torch::Tensor> const& hidden_states_scale, torch::Tensor const& gemm1_weights,
|
||||
torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
|
||||
torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
|
||||
torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
|
||||
int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
|
||||
std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
|
||||
int64_t const local_num_experts, std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
|
||||
int64_t const routing_method_type, bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner,
|
||||
int64_t const moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
|
||||
torch::optional<torch::Tensor> const& topk_ids)
|
||||
torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
|
||||
std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
|
||||
std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
|
||||
torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
|
||||
torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
|
||||
torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
|
||||
std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
|
||||
int64_t const local_expert_offset, int64_t const local_num_experts,
|
||||
std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type,
|
||||
bool const do_finalize, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t const moeConfigIndex,
|
||||
torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
|
||||
{
|
||||
TORCH_CHECK(dtype == btg::Dtype::E4m3 || dtype == btg::Dtype::E2m1, "dtype can only be e4m3 or e2m1.");
|
||||
TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by FP4 block scale MOE");
|
||||
@ -161,8 +163,13 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
args.gemm1_weights = gemm1_weights.data_ptr();
|
||||
args.gemm1_weights_scale = gemm1_weights_scale.data_ptr();
|
||||
args.gemm1_bias = gemm1_bias.has_value() ? gemm1_bias.value().data_ptr<float>() : nullptr;
|
||||
args.gemm1_alpha = gemm1_alpha.has_value() ? gemm1_alpha.value().data_ptr<float>() : nullptr;
|
||||
args.gemm1_beta = gemm1_beta.has_value() ? gemm1_beta.value().data_ptr<float>() : nullptr;
|
||||
args.gemm1_clamp_limit = gemm1_clamp_limit.has_value() ? gemm1_clamp_limit.value().data_ptr<float>() : nullptr;
|
||||
args.gemm2_weights = gemm2_weights.data_ptr();
|
||||
args.gemm2_weights_scale = gemm2_weights_scale.data_ptr();
|
||||
args.gemm2_bias = gemm2_bias.has_value() ? gemm2_bias.value().data_ptr<float>() : nullptr;
|
||||
args.num_tokens = hidden_states.sizes()[0];
|
||||
args.num_experts = num_experts;
|
||||
if (dtype == btg::Dtype::E4m3)
|
||||
@ -313,6 +320,38 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
|
||||
TORCH_CHECK(intermediate_size % 16 == 0, "the second dimension of weights must be a multiple of 16.");
|
||||
TORCH_CHECK(gemm1_weights_scale.sizes()[1] == 2 * intermediate_size, "gemm1_weights_scale has incorrect dim 1.");
if (gemm1_bias.has_value())
|
||||
{
|
||||
TORCH_CHECK(gemm1_bias.value().scalar_type() == at::ScalarType::Float, "gemm1_bias must be float, got %s.",
|
||||
c10::toString(gemm1_bias.value().scalar_type()));
|
||||
TORCH_CHECK(gemm1_bias.value().dim() == 2, "gemm1_bias must be 2D.");
|
||||
TORCH_CHECK(gemm1_bias.value().sizes()[0] == local_num_experts, "gemm1_bias has incorrect dim 0.");
|
||||
TORCH_CHECK(gemm1_bias.value().sizes()[1] == 2 * intermediate_size, "gemm1_bias has incorrect dim 1.");
|
||||
}
if (gemm1_alpha.has_value())
|
||||
{
|
||||
TORCH_CHECK(gemm1_alpha.value().scalar_type() == at::ScalarType::Float, "gemm1_alpha must be float, got %s.",
|
||||
c10::toString(gemm1_alpha.value().scalar_type()));
|
||||
TORCH_CHECK(gemm1_alpha.value().dim() == 1, "gemm1_alpha must be 1D.");
|
||||
TORCH_CHECK(gemm1_alpha.value().sizes()[0] == local_num_experts, "gemm1_alpha has incorrect dim 0.");
|
||||
}
|
||||
if (gemm1_beta.has_value())
|
||||
{
|
||||
TORCH_CHECK(gemm1_beta.value().scalar_type() == at::ScalarType::Float, "gemm1_beta must be float, got %s.",
|
||||
c10::toString(gemm1_beta.value().scalar_type()));
|
||||
TORCH_CHECK(gemm1_beta.value().dim() == 1, "gemm1_beta must be 1D.");
|
||||
TORCH_CHECK(gemm1_beta.value().sizes()[0] == local_num_experts, "gemm1_beta has incorrect dim 0.");
|
||||
}
|
||||
if (gemm1_clamp_limit.has_value())
|
||||
{
|
||||
TORCH_CHECK(gemm1_clamp_limit.value().scalar_type() == at::ScalarType::Float,
|
||||
"gemm1_clamp_limit must be float, got %s.", c10::toString(gemm1_clamp_limit.value().scalar_type()));
|
||||
TORCH_CHECK(gemm1_clamp_limit.value().dim() == 1, "gemm1_clamp_limit must be 1D.");
|
||||
TORCH_CHECK(
|
||||
gemm1_clamp_limit.value().sizes()[0] == local_num_experts, "gemm1_clamp_limit has incorrect dim 0.");
|
||||
}
TORCH_CHECK(gemm2_weights.scalar_type() == FLOAT4_E2M1X2, "gemm2_weights must be byte.");
TORCH_CHECK(gemm2_weights.dim() == 3, "gemm2_weights must be 3D.");
|
||||
@ -322,6 +361,15 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
TORCH_CHECK(gemm2_weights_scale.scalar_type() == at::ScalarType::Float8_e4m3fn, "gemm2_weights_scale must be fp8.");
|
||||
|
||||
if (gemm2_bias.has_value())
|
||||
{
|
||||
TORCH_CHECK(gemm2_bias.value().scalar_type() == at::ScalarType::Float, "gemm2_bias must be float, got %s.",
|
||||
c10::toString(gemm2_bias.value().scalar_type()));
|
||||
TORCH_CHECK(gemm2_bias.value().dim() == 2, "gemm2_bias must be 2D.");
|
||||
TORCH_CHECK(gemm2_bias.value().sizes()[0] == local_num_experts, "gemm2_bias has incorrect dim 0.");
|
||||
TORCH_CHECK(gemm2_bias.value().sizes()[1] == args.hidden_size, "gemm2_bias has incorrect dim 1.");
|
||||
}
|
||||
|
||||
TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D.");
|
||||
TORCH_CHECK(gemm2_weights_scale.sizes()[0] == local_num_experts, "gemm2_weights_scale has incorrect dim 0.");
|
||||
TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.hidden_size, "gemm2_weights_scale has incorrect dim 1.");
|
||||
@ -440,14 +488,17 @@ public:
|
||||
[[nodiscard]] std::vector<torch::Tensor> run(torch::optional<torch::Tensor> const& routing_logits,
|
||||
torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
|
||||
torch::Tensor const& hidden_states_scale, torch::Tensor const& gemm1_weights,
|
||||
torch::Tensor const& gemm1_weights_scale, torch::Tensor const& gemm2_weights,
|
||||
torch::Tensor const& gemm2_weights_scale, torch::Tensor const& output1_scales_scalar,
|
||||
torch::Tensor const& output1_scales_gate_scalar, torch::Tensor const& output2_scales_scalar,
|
||||
int64_t const num_experts, int64_t const top_k, std::optional<int64_t> const n_group,
|
||||
std::optional<int64_t> const topk_group, int64_t const intermediate_size, int64_t const local_expert_offset,
|
||||
int64_t const local_num_experts, std::optional<double> const routed_scaling_factor,
|
||||
int64_t const routing_method_type, bool const do_finalize, std::vector<int64_t> moeConfigIndex,
|
||||
torch::optional<torch::Tensor> const& topk_weights, torch::optional<torch::Tensor> const& topk_ids)
|
||||
torch::Tensor const& gemm1_weights_scale, std::optional<torch::Tensor> const& gemm1_bias,
|
||||
std::optional<torch::Tensor> const& gemm1_alpha, std::optional<torch::Tensor> const& gemm1_beta,
|
||||
std::optional<torch::Tensor> const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights,
|
||||
torch::Tensor const& gemm2_weights_scale, std::optional<torch::Tensor> const& gemm2_bias,
|
||||
torch::Tensor const& output1_scales_scalar, torch::Tensor const& output1_scales_gate_scalar,
|
||||
torch::Tensor const& output2_scales_scalar, int64_t const num_experts, int64_t const top_k,
|
||||
std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
|
||||
int64_t const local_expert_offset, int64_t const local_num_experts,
|
||||
std::optional<double> const routed_scaling_factor, int64_t const routing_method_type, bool const do_finalize,
|
||||
std::vector<int64_t> moeConfigIndex, torch::optional<torch::Tensor> const& topk_weights,
|
||||
torch::optional<torch::Tensor> const& topk_ids)
|
||||
{
|
||||
// moeConfigIndex corresponds to pair (tileN, config)
|
||||
auto [tileN, config] = std::tie(moeConfigIndex[0], moeConfigIndex[1]);
|
||||
@ -468,10 +519,11 @@ public:
|
||||
}
|
||||
|
||||
return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, hidden_states_scale,
|
||||
gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output1_scales_scalar,
|
||||
output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
|
||||
intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
|
||||
routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config, topk_weights, topk_ids);
|
||||
gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights,
|
||||
gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar,
|
||||
num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
|
||||
routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeElt, *mRunners[tileN], config,
|
||||
topk_weights, topk_ids);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -553,11 +605,11 @@ public:
|
||||
}
|
||||
|
||||
return run_fp4_block_scale_moe_runner(routing_logits, routing_bias, hidden_states,
|
||||
std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, gemm2_weights,
|
||||
gemm2_weights_scale, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, num_experts,
|
||||
top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
|
||||
routed_scaling_factor, tileN, routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config,
|
||||
topk_weights, topk_ids);
|
||||
std::nullopt /*hidden_states_scale*/, gemm1_weights, gemm1_weights_scale, std::nullopt, std::nullopt,
|
||||
std::nullopt, std::nullopt, gemm2_weights, gemm2_weights_scale, std::nullopt, output1_scales_scalar,
|
||||
output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
|
||||
intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
|
||||
routing_method_type, do_finalize, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids);
|
||||
}
|
||||
|
||||
private:
@ -176,8 +176,13 @@ class FP4BlockScaleMoEInputs:
|
||||
hidden_states_scale: torch.Tensor
|
||||
gemm1_weights: torch.Tensor
|
||||
gemm1_weights_scale: torch.Tensor
|
||||
gemm1_bias: torch.Tensor
|
||||
gemm1_alpha: torch.Tensor
|
||||
gemm1_beta: torch.Tensor
|
||||
gemm1_clamp_limit: torch.Tensor
|
||||
gemm2_weights: torch.Tensor
|
||||
gemm2_weights_scale: torch.Tensor
|
||||
gemm2_bias: torch.Tensor
|
||||
output1_scale_scalar: torch.Tensor
|
||||
output1_scale_gate_scalar: torch.Tensor
|
||||
output2_scale_scalar: torch.Tensor
|
||||
@ -235,14 +240,15 @@ class FP4BlockScaleMoERunner(TunableRunner):
|
||||
return kernel_runner.run_moe(
|
||||
args.routing_logits, args.routing_bias, args.hidden_states,
|
||||
args.hidden_states_scale, args.gemm1_weights,
|
||||
args.gemm1_weights_scale, args.gemm2_weights,
|
||||
args.gemm2_weights_scale, args.output1_scale_scalar,
|
||||
args.output1_scale_gate_scalar, args.output2_scale_scalar,
|
||||
self.num_experts, self.top_k, self.n_group, self.topk_group,
|
||||
self.intermediate_size, self.local_expert_offset,
|
||||
self.local_num_experts, self.routed_scaling_factor,
|
||||
self.routing_method_type, self.do_finalize, tactic,
|
||||
args.topk_weights, args.topk_ids)
|
||||
args.gemm1_weights_scale, args.gemm1_bias, args.gemm1_alpha,
|
||||
args.gemm1_beta, args.gemm1_clamp_limit, args.gemm2_weights,
|
||||
args.gemm2_weights_scale, args.gemm2_bias,
|
||||
args.output1_scale_scalar, args.output1_scale_gate_scalar,
|
||||
args.output2_scale_scalar, self.num_experts, self.top_k,
|
||||
self.n_group, self.topk_group, self.intermediate_size,
|
||||
self.local_expert_offset, self.local_num_experts,
|
||||
self.routed_scaling_factor, self.routing_method_type,
|
||||
self.do_finalize, tactic, args.topk_weights, args.topk_ids)
|
||||
|
||||
def get_valid_tactics(self, inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
@ -317,8 +323,8 @@ class FP4BlockScaleMoERunner(TunableRunner):
|
||||
|
||||
ROUTER_LOGITS_IDX = 0
|
||||
CONSTRAINED_RL_DIM = 0
|
||||
TOPK_WEIGHTS_IDX = 11
|
||||
TOPK_IDS_IDX = 12
|
||||
TOPK_WEIGHTS_IDX = 16
|
||||
TOPK_IDS_IDX = 17
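The constants move from 11/12 to 16/17 because five tensors (gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_bias) now sit ahead of topk_weights in the input list; a quick sanity sketch, with the leading routing/hidden-state fields taken from the run_moe call above:

inputs = [
    "routing_logits", "routing_bias", "hidden_states", "hidden_states_scale",
    "gemm1_weights", "gemm1_weights_scale", "gemm1_bias", "gemm1_alpha",
    "gemm1_beta", "gemm1_clamp_limit", "gemm2_weights", "gemm2_weights_scale",
    "gemm2_bias", "output1_scale_scalar", "output1_scale_gate_scalar",
    "output2_scale_scalar", "topk_weights", "topk_ids",
]
assert inputs.index("topk_weights") == 16 and inputs.index("topk_ids") == 17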
|
||||
|
||||
constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX,
|
||||
CONSTRAINED_RL_DIM,
|
||||
@ -359,8 +365,13 @@ def fp4_block_scale_moe_runner(
|
||||
hidden_states_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm1_weights_scale: torch.Tensor,
|
||||
gemm1_bias: torch.Tensor,
|
||||
gemm1_alpha: torch.Tensor,
|
||||
gemm1_beta: torch.Tensor,
|
||||
gemm1_clamp_limit: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
gemm2_weights_scale: torch.Tensor,
|
||||
gemm2_bias: torch.Tensor,
|
||||
output1_scale_scalar: torch.Tensor,
|
||||
output1_scale_gate_scalar: torch.Tensor,
|
||||
output2_scale_scalar: torch.Tensor,
|
||||
@ -416,8 +427,13 @@ def fp4_block_scale_moe_runner(
|
||||
hidden_states_scale,
|
||||
gemm1_weights,
|
||||
gemm1_weights_scale,
|
||||
gemm1_bias,
|
||||
gemm1_alpha,
|
||||
gemm1_beta,
|
||||
gemm1_clamp_limit,
|
||||
gemm2_weights,
|
||||
gemm2_weights_scale,
|
||||
gemm2_bias,
|
||||
output1_scale_scalar,
|
||||
output1_scale_gate_scalar,
|
||||
output2_scale_scalar,
|
||||
@ -474,8 +490,13 @@ def _(routing_logits,
|
||||
hidden_states_scale,
|
||||
gemm1_weights,
|
||||
gemm1_weights_scale,
|
||||
gemm1_bias,
|
||||
gemm1_alpha,
|
||||
gemm1_beta,
|
||||
gemm1_clamp_limit,
|
||||
gemm2_weights,
|
||||
gemm2_weights_scale,
|
||||
gemm2_bias,
|
||||
output1_scale_scalar,
|
||||
output1_scale_gate_scalar,
|
||||
output2_scale_scalar,
|
||||
|
||||
@ -34,8 +34,7 @@ from ..modules.rms_norm import RMSNorm
|
||||
from ..speculative import SpecMetadata
|
||||
from ..utils import Fp4QuantizedTensor
|
||||
from .modeling_speculative import SpecDecOneEngineForCausalLM
|
||||
from .modeling_utils import (DecoderModel, duplicate_kv_weight, filter_weights,
|
||||
register_auto_model)
|
||||
from .modeling_utils import DecoderModel, filter_weights, register_auto_model
|
||||
|
||||
# Use TinyGEMM when the number of tokens is not larger than this threshold
|
||||
MIN_LATENCY_TINYGEMM_NUM_TOKENS = 128
|
||||
@ -639,6 +638,15 @@ class GptOssForCausalLM(SpecDecOneEngineForCausalLM[Transformer, GptOssConfig]):
|
||||
|
||||
quant_config = self.model_config.quant_config
|
||||
if quant_config.exclude_modules:
|
||||
if quant_config.quant_algo == "NVFP4":
|
||||
quant_config.exclude_modules = [
|
||||
'block.*.attn.qkv',
|
||||
'block.*.attn.out',
|
||||
'block.*.mlp.gate',
|
||||
'embedding',
|
||||
'unembedding',
|
||||
]
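# With this list, only the MoE expert projections stay NVFP4-quantized; the
# attention qkv/out projections, the router gate, and the (un)embeddings are
# excluded and presumably remain in higher precision.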
|
||||
|
||||
for i, module in enumerate(quant_config.exclude_modules):
|
||||
names = module.split(".")
|
||||
if names[-1] in params_map_reverse:
|
||||
@ -653,13 +661,10 @@ class GptOssForCausalLM(SpecDecOneEngineForCausalLM[Transformer, GptOssConfig]):
|
||||
module.create_weights()
|
||||
|
||||
def load_weights(self, weights: Dict):
|
||||
is_ori_model = True
|
||||
for k, v in weights.items():
|
||||
if 'q_proj' in k:
|
||||
is_ori_model = False
|
||||
is_nvfp4 = self.model_config.quant_config.quant_mode.has_nvfp4()
|
||||
|
||||
if is_ori_model:
|
||||
self.load_ori_weights(weights)
|
||||
if is_nvfp4:
|
||||
self.load_nvfp4_weights(weights)
|
||||
else:
|
||||
self.load_hf_weights(weights)
|
||||
|
||||
@ -811,176 +816,107 @@ class GptOssForCausalLM(SpecDecOneEngineForCausalLM[Transformer, GptOssConfig]):
|
||||
if p is not None:
|
||||
p.data.copy_(module_weights[n][:])
|
||||
|
||||
def load_ori_weights(self, weights: Dict):
|
||||
head_dim = self.config.head_dim
|
||||
num_q_head = self.config.num_attention_heads
|
||||
num_kv_head = self.config.num_key_value_heads
|
||||
def load_nvfp4_weights(self, weights: Dict):
|
||||
num_expert = self.config.num_local_experts
|
||||
enable_attention_dp = self.model_config.mapping.enable_attention_dp
|
||||
tp_size = self.model_config.mapping.tp_size
|
||||
|
||||
for name, module in tqdm(list(self.named_modules()),
|
||||
desc="Loading weights"):
|
||||
if len(module._parameters) <= 0 or name.startswith("draft_model"):
|
||||
continue
|
||||
names = name.split(".")
|
||||
|
||||
module_weights = {}
|
||||
if names[-1] in self.params_map:
|
||||
names[-1] = self.params_map[names[-1]]
|
||||
for k, v in self.hf_params_map.items():
|
||||
name = name.replace(k, v)
|
||||
|
||||
names = name.split('.')
|
||||
if names[-1] == "backend" and isinstance(module, MoE):
|
||||
# Backend is under experts module (ConfigurableMoE wrapper)
|
||||
name = '.'.join(names[:-1])
|
||||
|
||||
# Drop the first "model" prefix
|
||||
if names[0] == 'model':
|
||||
name = '.'.join(names[1:])
|
||||
else:
|
||||
name = '.'.join(names)
|
||||
module_weights = filter_weights(name, weights)
|
||||
|
||||
if isinstance(module, MoE):
|
||||
# [num_experts, intermediate_size * 2, hidden_size]
|
||||
gate_up_proj = filter_weights(name.replace("experts", "mlp1"),
|
||||
weights)
|
||||
# [num_experts, intermediate_size, hidden_size]
|
||||
down_proj = filter_weights(name.replace("experts", "mlp2"),
|
||||
weights)
|
||||
try:
|
||||
# Official MXFP4 ckpt.
|
||||
gate_up_weight = gate_up_proj['weight.blocks'].flatten(
|
||||
-2, -1)
|
||||
gate, up = gate_up_weight[:, ::2, :], gate_up_weight[:, 1::
|
||||
2, :]
|
||||
gate_up_weight = torch.cat([gate, up], dim=-2)
|
||||
gate_up_bias = gate_up_proj['bias']
|
||||
gate, up = gate_up_bias[:, ::2], gate_up_bias[:, 1::2]
|
||||
gate_up_bias = torch.cat([gate, up], dim=-1)
|
||||
moe_weights = {
|
||||
'gate_up_proj': [
|
||||
gate_up_weight[i, :, :].transpose(0, 1)
|
||||
for i in range(num_expert)
|
||||
],
|
||||
'down_proj': [
|
||||
down_proj['weight.blocks'].flatten(
|
||||
-2, -1)[i, :, :].transpose(0, 1)
|
||||
for i in range(num_expert)
|
||||
],
|
||||
'gate_up_proj.bias':
|
||||
[gate_up_bias[i, :] for i in range(num_expert)],
|
||||
'down_proj.bias':
|
||||
[down_proj['bias'][i, :] for i in range(num_expert)]
|
||||
}
|
||||
except:
|
||||
# For BF16 ckpt.
|
||||
moe_weights = {
|
||||
'gate_up_proj': [
|
||||
gate_up_proj['weight'][i, :, :].transpose(0, 1).to(
|
||||
self.model.dtype) for i in range(num_expert)
|
||||
],
|
||||
'down_proj': [
|
||||
down_proj['weight'][i, :, :].transpose(0, 1).to(
|
||||
self.model.dtype) for i in range(num_expert)
|
||||
],
|
||||
'gate_up_proj.bias':
|
||||
[gate_up_proj['bias'][i, :] for i in range(num_expert)],
|
||||
'down_proj.bias':
|
||||
[down_proj['bias'][i, :] for i in range(num_expert)]
|
||||
}
|
||||
# Only for Official MXFP4 ckpt.
|
||||
if 'weight.scales' in gate_up_proj:
|
||||
gate_up_weight_scale = gate_up_proj['weight.scales']
|
||||
gate, up = gate_up_weight_scale[:, ::
|
||||
2, :], gate_up_weight_scale[:,
|
||||
1::
|
||||
2, :]
|
||||
gate_up_weight_scale = torch.cat([gate, up], dim=-2)
|
||||
assert getattr(module, "quant_config", None) is not None and \
|
||||
module.quant_config.quant_mode.has_nvfp4()
|
||||
gate_up = module_weights.get('gate_up_proj', None)
|
||||
down = module_weights.get('down_proj', None)
|
||||
gate_up_bias = module_weights.get('gate_up_proj_bias', None)
|
||||
down_bias = module_weights.get('down_proj_bias', None)
|
||||
|
||||
def deinterleave(tensor):
|
||||
g, u = tensor[..., ::2], tensor[..., 1::2]
|
||||
return torch.cat([g, u], dim=-1)
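# Layout example: a last-dim pattern [g0, u0, g1, u1, ...] becomes
# [g0, g1, ..., u0, u1, ...], i.e. the checkpoint's interleaved gate/up
# columns are regrouped into the concatenated gate-then-up layout that the
# fused gate_up_proj weights expect.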
|
||||
|
||||
gate_up = deinterleave(gate_up)
|
||||
gate_up_bias = deinterleave(gate_up_bias)
|
||||
|
||||
# Only fp32 bias is supported for NVFP4 MoE.
|
||||
if gate_up_bias.dtype != torch.float32:
|
||||
gate_up_bias = gate_up_bias.to(torch.float32)
|
||||
if down_bias.dtype != torch.float32:
|
||||
down_bias = down_bias.to(torch.float32)
|
||||
|
||||
moe_weights = {}
|
||||
if gate_up is not None:
|
||||
moe_weights['gate_up_proj'] = [
|
||||
gate_up[i, :, :] for i in range(num_expert)
|
||||
]
|
||||
if down is not None:
|
||||
moe_weights['down_proj'] = [
|
||||
down[i, :, :] for i in range(num_expert)
|
||||
]
|
||||
if gate_up_bias is not None:
|
||||
moe_weights['gate_up_proj.bias'] = [
|
||||
gate_up_bias[i, :] for i in range(num_expert)
|
||||
]
|
||||
if down_bias is not None:
|
||||
moe_weights['down_proj.bias'] = [
|
||||
down_bias[i, :] for i in range(num_expert)
|
||||
]
|
||||
|
||||
# Per-expert block scales (transpose to expected layout)
|
||||
if 'gate_up_proj_weight_scale' in module_weights:
|
||||
gu_ws = module_weights['gate_up_proj_weight_scale']
|
||||
gu_ws = deinterleave(gu_ws)
|
||||
moe_weights['gate_up_proj_weight_scale'] = [
|
||||
gate_up_weight_scale[i, :, :].transpose(0, 1)
|
||||
for i in range(num_expert)
|
||||
gu_ws[i, :, :] for i in range(num_expert)
|
||||
]
|
||||
|
||||
if self.model_config.quant_config.quant_algo == 'W4A16_MXFP4':
|
||||
for i in range(num_expert):
|
||||
moe_weights[f"{i}.w1.weight_scale_inv"] = gate[
|
||||
i, :, :]
|
||||
moe_weights[f"{i}.w3.weight_scale_inv"] = up[
|
||||
i, :, :]
|
||||
|
||||
if 'weight.scales' in down_proj:
|
||||
if 'down_proj_weight_scale' in module_weights:
|
||||
dp_ws = module_weights['down_proj_weight_scale']
|
||||
moe_weights['down_proj_weight_scale'] = [
|
||||
down_proj['weight.scales'][i, :, :].transpose(0, 1)
|
||||
for i in range(num_expert)
|
||||
dp_ws[i, :, :] for i in range(num_expert)
|
||||
]
|
||||
|
||||
if self.model_config.quant_config.quant_algo == 'W4A16_MXFP4':
|
||||
for i in range(num_expert):
|
||||
moe_weights[f"{i}.w2.weight_scale_inv"] = down_proj[
|
||||
'weight.scales'][i, :, :]
|
||||
# Module-level globals for NVFP4 loaders
|
||||
for src_key in [
|
||||
'gate_up_proj_weight_scale_2',
|
||||
'down_proj_weight_scale_2',
|
||||
'gate_up_proj_input_scale',
|
||||
'down_proj_input_scale',
|
||||
]:
|
||||
if src_key in module_weights:
|
||||
moe_weights[src_key] = module_weights[src_key]
|
||||
|
||||
module.load_weights(weights=[moe_weights])
|
||||
elif hasattr(module, "load_weights"):
|
||||
# Load Attention module weights.
|
||||
if 'qkv' in name:
|
||||
q_weight = module_weights['weight'][:head_dim *
|
||||
num_q_head, :]
|
||||
k_weight = module_weights['weight'][head_dim *
|
||||
num_q_head:head_dim *
|
||||
(num_q_head +
|
||||
num_kv_head), :]
|
||||
v_weight = module_weights['weight'][-head_dim *
|
||||
num_kv_head:, :]
|
||||
q_bias = module_weights['bias'][:head_dim * num_q_head]
|
||||
k_bias = module_weights['bias'][head_dim *
|
||||
num_q_head:head_dim *
|
||||
(num_q_head + num_kv_head)]
|
||||
v_bias = module_weights['bias'][-head_dim * num_kv_head:]
|
||||
|
||||
# Handle KV weight duplication for GQA
|
||||
tensors_need_duplication = ['weight', 'bias']
|
||||
if module.quant_config.quant_mode.has_mxfp4():
|
||||
tensors_need_duplication.append('weight_scale')
|
||||
|
||||
# Duplicate KV weights if needed
|
||||
tensor_parallel_size = tp_size if not enable_attention_dp else 1
|
||||
|
||||
k_weight_dict = {'weight': k_weight, 'bias': k_bias}
|
||||
v_weight_dict = {'weight': v_weight, 'bias': v_bias}
|
||||
|
||||
if 'weight_scale' in module_weights:
|
||||
k_weight_dict['weight_scale'] = module_weights[
|
||||
'weight_scale'][head_dim * num_q_head:head_dim *
|
||||
(num_q_head + num_kv_head), :]
|
||||
v_weight_dict['weight_scale'] = module_weights[
|
||||
'weight_scale'][-head_dim * num_kv_head:, :]
|
||||
|
||||
k_weight_dict = {
|
||||
k: (duplicate_kv_weight(
|
||||
weight=v,
|
||||
num_kv_heads=num_kv_head,
|
||||
tensor_parallel_size=tensor_parallel_size)
|
||||
if k in tensors_need_duplication else v)
|
||||
for k, v in k_weight_dict.items()
|
||||
}
|
||||
|
||||
v_weight_dict = {
|
||||
k: (duplicate_kv_weight(
|
||||
weight=v,
|
||||
num_kv_heads=num_kv_head,
|
||||
tensor_parallel_size=tensor_parallel_size)
|
||||
if k in tensors_need_duplication else v)
|
||||
for k, v in v_weight_dict.items()
|
||||
}
|
||||
|
||||
qkv_weights = [{
|
||||
'weight': q_weight,
|
||||
'bias': q_bias
|
||||
}, k_weight_dict, v_weight_dict]
|
||||
module.load_weights(weights=qkv_weights)
|
||||
# For qkv_proj
|
||||
q_weight_bias = filter_weights(
|
||||
name.replace('qkv_proj', 'q_proj'), weights)
|
||||
k_weight_bias = filter_weights(
|
||||
name.replace('qkv_proj', 'k_proj'), weights)
|
||||
v_weight_bias = filter_weights(
|
||||
name.replace('qkv_proj', 'v_proj'), weights)
|
||||
module.load_weights(
|
||||
weights=[q_weight_bias, k_weight_bias, v_weight_bias])
|
||||
else:
|
||||
# Dense & gate & sinks
|
||||
# For o_proj, sinks.
|
||||
module.load_weights(weights=[module_weights])
|
||||
else:
|
||||
# Load LN weights.
|
||||
if names[-1].endswith("layernorm") and names[-3] == "block":
|
||||
# skip loading weights for the fused norms
|
||||
# Load four LN weights (attn.norm, mlp.norm, input_layernorm, post_attention_layernorm).
|
||||
if 'next_layer_layernorm' in name:
|
||||
continue
|
||||
|
||||
for n, p in module._parameters.items():
|
||||
if p is not None:
|
||||
p.data.copy_(module_weights[n.replace(
|
||||
"weight", "scale")][:])
|
||||
p.data.copy_(module_weights[n][:])
|
||||
|
||||
@ -20,10 +20,10 @@ from .interface import AlltoallMethodType, MoE, MoEWeightLoadingMode
|
||||
|
||||
# isort: off
|
||||
from .quantization import (
|
||||
DeepSeekFP8BlockScalesFusedMoEMethod, NVFP4TRTLLMGenFusedMoEMethod,
|
||||
UnquantizedFusedMoEMethod, W4A8MXFP4FP8TRTLLMGenFusedMoEMethod,
|
||||
W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, W4A8NVFP4FP8TRTLLMGenFusedMoEMethod,
|
||||
W4A16MXFP4TRTLLMGenFusedMoEMethod)
|
||||
DeepSeekFP8BlockScalesFusedMoEMethod, NVFP4TRTLLMGenFusedMoEBaseMethod,
|
||||
NVFP4TRTLLMGenFusedMoEMethod, UnquantizedFusedMoEMethod,
|
||||
W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod,
|
||||
W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, W4A16MXFP4TRTLLMGenFusedMoEMethod)
|
||||
# isort: on
|
||||
from .routing import BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod
|
||||
|
||||
@ -214,7 +214,7 @@ class TRTLLMGenFusedMoE(MoE):
|
||||
or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."
|
||||
|
||||
if self.bias or self.swiglu_alpha is not None or self.swiglu_beta is not None or self.swiglu_limit is not None:
|
||||
assert self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports mxfp4 quantization with bias, swiglu_alpha, swiglu_beta and swiglu_limit."
|
||||
assert self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE supports bias/swiglu only for nvfp4 and mxfp4 variants."
|
||||
|
||||
def _get_quant_method(self):
|
||||
if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
|
||||
@ -222,7 +222,9 @@ class TRTLLMGenFusedMoE(MoE):
|
||||
if self.quant_config.layer_quant_mode.has_fp8_block_scales():
|
||||
return DeepSeekFP8BlockScalesFusedMoEMethod()
|
||||
elif self.quant_config.layer_quant_mode.has_nvfp4():
|
||||
return NVFP4TRTLLMGenFusedMoEMethod()
|
||||
return NVFP4TRTLLMGenFusedMoEMethod(
|
||||
) if self.swiglu_alpha is not None else NVFP4TRTLLMGenFusedMoEBaseMethod(
|
||||
)
|
||||
elif self.quant_config.layer_quant_mode.has_w4a16_mxfp4():
|
||||
return W4A16MXFP4TRTLLMGenFusedMoEMethod()
|
||||
elif self.quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8():
|
||||
@ -324,6 +326,11 @@ class TRTLLMGenFusedMoE(MoE):
|
||||
self,
|
||||
'fc31_act_scale') and self.fc31_act_scale is not None:
|
||||
x = x * self.fc31_act_scale
|
||||
|
||||
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
|
||||
if pad_size > 0:
|
||||
x = torch.nn.functional.pad(x, (0, pad_size))
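# The activation is zero-padded up to the hidden size the NVFP4 weights were
# padded to at load time (possibly 256-aligned); the extra output columns are
# sliced off again below once the MoE kernel has run.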
|
||||
|
||||
x_row = x.shape[0]
|
||||
x, x_sf = torch.ops.trtllm.fp4_quantize(
|
||||
x, self.fc31_input_scale, self.scaling_vector_size, False,
|
||||
@ -446,6 +453,8 @@ class TRTLLMGenFusedMoE(MoE):
|
||||
topk_ids=token_selected_experts,
|
||||
)
|
||||
elif self.has_nvfp4:
|
||||
intermediate_size_per_partition_padded = self.w3_w1_weight.shape[
|
||||
-2] // 2
|
||||
|
||||
outputs = torch.ops.trtllm.fp4_block_scale_moe_runner(
|
||||
router_logits,
|
||||
@ -454,8 +463,13 @@ class TRTLLMGenFusedMoE(MoE):
|
||||
x_sf.view(torch.float8_e4m3fn),
|
||||
self.w3_w1_weight,
|
||||
self.w3_w1_weight_scale.view(torch.float8_e4m3fn),
|
||||
self.w3_w1_bias if self.bias else None,
|
||||
self.swiglu_alpha,
|
||||
self.swiglu_beta,
|
||||
self.swiglu_limit,
|
||||
self.w2_weight,
|
||||
self.w2_weight_scale.view(torch.float8_e4m3fn),
|
||||
self.w2_bias if self.bias else None,
|
||||
self.fc31_scale_c.data,
|
||||
self.fc31_alpha.data,
|
||||
self.fc2_alpha.data,
|
||||
@ -463,7 +477,7 @@ class TRTLLMGenFusedMoE(MoE):
|
||||
top_k,
|
||||
n_group,
|
||||
topk_group,
|
||||
self.intermediate_size_per_partition,
|
||||
intermediate_size_per_partition_padded,
|
||||
self.slot_start,
|
||||
self.expert_size_per_partition,
|
||||
routed_scaling_factor,
|
||||
@ -478,6 +492,11 @@ class TRTLLMGenFusedMoE(MoE):
|
||||
return outputs
|
||||
else:
|
||||
final_hidden_states = outputs[0]
|
||||
# Slice output if it was padded
|
||||
if final_hidden_states.shape[1] > self.hidden_size:
|
||||
final_hidden_states = final_hidden_states[:, :self.
|
||||
hidden_size].contiguous(
|
||||
)
|
||||
elif self.has_w4a16_mxfp4:
|
||||
assert x.dtype == torch.bfloat16
@ -218,13 +218,11 @@ class FusedMoEMethodBase(ABC):
|
||||
|
||||
# bias
|
||||
if module.bias:
|
||||
# The shape might be padded so we use weight shape[:2]
|
||||
if w3_w1_bias_shape is None:
|
||||
w3_w1_bias_shape = (
|
||||
module.expert_size_per_partition,
|
||||
module.expand_intermediate_size_per_partition)
|
||||
w3_w1_bias_shape = w3_w1_weight_shape[:2]
|
||||
if w2_bias_shape is None:
|
||||
w2_bias_shape = (module.expert_size_per_partition,
|
||||
module.hidden_size)
|
||||
w2_bias_shape = w2_weight_shape[:2]
|
||||
bias_dtype = bias_dtype or module.dtype
|
||||
w3_w1_bias = nn.Parameter(torch.empty(w3_w1_bias_shape,
|
||||
dtype=bias_dtype),
|
||||
@ -1731,7 +1729,8 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
|
||||
weight_vec_size,
|
||||
block_scales_dtype,
|
||||
block_scales_vec_size,
|
||||
scaling_vector_size=16):
|
||||
scaling_vector_size=16,
|
||||
bias_dtype: Optional[torch.dtype] = None):
|
||||
|
||||
module.scaling_vector_size = scaling_vector_size
|
||||
|
||||
@ -1780,7 +1779,8 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
|
||||
w3_w1_weight_shape=w3_w1_weight_shape,
|
||||
w2_weight_shape=w2_weight_shape,
|
||||
w3_w1_bias_shape=w3_w1_bias_shape,
|
||||
w2_bias_shape=w2_bias_shape)
|
||||
w2_bias_shape=w2_bias_shape,
|
||||
bias_dtype=bias_dtype)
|
||||
|
||||
self.setup_quant_scales(module)
|
||||
|
||||
@ -2300,7 +2300,7 @@ class NVFP4CuteDslFusedMoEMethod(NVFP4CutlassFusedMoEMethod):
|
||||
dst_w3_w1_weight_scale.copy_(w3_w1_weight_scale_interleaved)
|
||||
|
||||
|
||||
class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
|
||||
class NVFP4TRTLLMGenFusedMoEBaseMethod(NVFP4FusedMoEMethod):
|
||||
weight_dtype = float4_sf_dtype
|
||||
block_scales_dtype = torch.float8_e4m3fn
|
||||
|
||||
@ -2308,12 +2308,18 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
|
||||
# This assumes the same input shape always results in the same permute indices
|
||||
_cache_permute_indices: Dict[torch.Size, torch.Tensor] = {}
|
||||
|
||||
def create_weights(self, module: torch.nn.Module):
|
||||
def create_weights(self,
|
||||
module: torch.nn.Module,
|
||||
bias_dtype: Optional[torch.dtype] = None):
|
||||
weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
|
||||
block_scales_vec_size = 1
|
||||
|
||||
super().create_weights(module, self.weight_dtype, weight_vec_size,
|
||||
self.block_scales_dtype, block_scales_vec_size)
|
||||
super().create_weights(module,
|
||||
self.weight_dtype,
|
||||
weight_vec_size,
|
||||
self.block_scales_dtype,
|
||||
block_scales_vec_size,
|
||||
bias_dtype=bias_dtype)
|
||||
|
||||
fc31_scale_c = nn.Parameter(torch.ones(module.expert_size_per_partition,
|
||||
dtype=torch.float32),
|
||||
@ -2565,7 +2571,340 @@ class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
|
||||
})
|
||||
|
||||
|
||||
class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEMethod):
|
||||
class NVFP4TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEBaseMethod):
|
||||
weight_alignment = 32
|
||||
input_hidden_alignment = 32
|
||||
|
||||
def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int,
|
||||
block_scales_vec_size: int):
|
||||
|
||||
def round_up(x, alignment):
|
||||
return (x + alignment - 1) // alignment * alignment
|
||||
|
||||
# Compute padded sizes
|
||||
intermediate_size_per_partition_padded = round_up(
|
||||
module.intermediate_size_per_partition, self.weight_alignment)
|
||||
w3_w1_hidden_size_padded = round_up(module.hidden_size,
|
||||
self.input_hidden_alignment)
|
||||
w2_hidden_size_padded = round_up(module.hidden_size,
|
||||
self.weight_alignment)
|
||||
|
||||
# Divide by 16 because we use int64 to pack 16 fp4 values
|
||||
w3_w1_weight_shape = (module.expert_size_per_partition,
|
||||
intermediate_size_per_partition_padded *
|
||||
module.intermediate_size_expand_ratio,
|
||||
w3_w1_hidden_size_padded // weight_vec_size)
|
||||
w2_weight_shape = (module.expert_size_per_partition,
|
||||
w2_hidden_size_padded,
|
||||
intermediate_size_per_partition_padded //
|
||||
weight_vec_size)
|
||||
|
||||
w3_w1_weight_scale_shape = (module.expert_size_per_partition,
|
||||
intermediate_size_per_partition_padded *
|
||||
module.intermediate_size_expand_ratio,
|
||||
w3_w1_hidden_size_padded //
|
||||
module.scaling_vector_size //
|
||||
block_scales_vec_size)
|
||||
w2_weight_scale_shape = (module.expert_size_per_partition,
|
||||
w2_hidden_size_padded,
|
||||
intermediate_size_per_partition_padded //
|
||||
module.scaling_vector_size //
|
||||
block_scales_vec_size)
|
||||
|
||||
if module.bias:
|
||||
w3_w1_bias_shape = (module.expert_size_per_partition,
|
||||
intermediate_size_per_partition_padded *
|
||||
module.intermediate_size_expand_ratio)
|
||||
w2_bias_shape = (module.expert_size_per_partition,
|
||||
w2_hidden_size_padded)
|
||||
else:
|
||||
w3_w1_bias_shape = None
|
||||
w2_bias_shape = None
|
||||
|
||||
return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape,
|
||||
w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape)
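For intuition, the padded sizes above are plain round-ups to the current alignment; a standalone sketch with illustrative numbers (2880 matches the gpt-oss style unit test below):

def round_up(x: int, alignment: int) -> int:
    return (x + alignment - 1) // alignment * alignment

# Default weight alignment is 32; create_weights() below bumps it to 256 when
# hidden_size > 1024 and hidden_size is not already a multiple of 256.
assert round_up(2880, 32) == 2880
assert round_up(2880, 256) == 3072    # padded hidden size under the 256 alignment
assert round_up(2880, 256) % 16 == 0  # still divisible by the 16-element scale blocks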
|
||||
|
||||
def create_weights(self, module: torch.nn.Module):
|
||||
# Here we only enable padding for hidden_size > 1024 since there are small unit tests that expect no padding.
|
||||
if module.hidden_size > 1024 and module.hidden_size % 256 != 0:
|
||||
self.weight_alignment = 256
|
||||
# For now let's keep input alignment same as weight alignment. There are practical reasons that this might be a different value.
|
||||
# See the comment in MXFP4WeightTRTLLMGenFusedMoEMethod for more details.
|
||||
self.input_hidden_alignment = 256
|
||||
|
||||
super().create_weights(module, bias_dtype=torch.float32)
|
||||
|
||||
def setup_quant_scales(self, module: torch.nn.Module):
|
||||
module.quant_scales = tuple()
|
||||
|
||||
def load_expert_w3_w1_weight(self, module: torch.nn.Module,
|
||||
w1_weight: torch.Tensor,
|
||||
w3_weight: torch.Tensor,
|
||||
dst_w3_w1_weight: torch.Tensor):
|
||||
device = torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
dst_on_gpu = dst_w3_w1_weight.device.type == "cuda"
|
||||
dst_w3_w1_weight_gpu = dst_w3_w1_weight if dst_on_gpu else dst_w3_w1_weight.cuda(
|
||||
)
|
||||
|
||||
alignment = _get_weight_alignment(self.weight_alignment,
|
||||
module.scaling_vector_size,
|
||||
module.tp_size, w1_weight.shape[0])
|
||||
if len(w1_weight.shape) == 2:
|
||||
# Pad weights
|
||||
# We already satisfy an alignment factor of 2, since two MXFP4 values are packed into one uint8.
|
||||
assert w1_weight.dtype == torch.uint8
|
||||
w1_weight = maybe_pad_for_mxfp4(w1_weight,
|
||||
self.input_hidden_alignment // 2,
|
||||
alignment)
|
||||
assert w3_weight.dtype == torch.uint8
|
||||
w3_weight = maybe_pad_for_mxfp4(w3_weight,
|
||||
self.input_hidden_alignment // 2,
|
||||
alignment)
|
||||
else:
|
||||
# Pad bias, TRTLLM backend expects float32 bias.
|
||||
assert len(w1_weight.shape) == 1
|
||||
assert len(w3_weight.shape) == 1
|
||||
w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment).float()
|
||||
w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment).float()
|
||||
|
||||
w1_weight_shard = load_weight_shard(w1_weight,
|
||||
module.tp_size,
|
||||
module.tp_rank,
|
||||
TensorParallelMode.COLUMN,
|
||||
device=device)
|
||||
w3_weight_shard = load_weight_shard(w3_weight,
|
||||
module.tp_size,
|
||||
module.tp_rank,
|
||||
TensorParallelMode.COLUMN,
|
||||
device=device)
|
||||
|
||||
# FIXME: this depends on the kernel internals
|
||||
epilogue_tile_m = 128
|
||||
|
||||
# Keep weights in device buffer
|
||||
dst_w3_weight, dst_w1_weight = dst_w3_w1_weight_gpu.chunk(2, dim=0)
|
||||
dst_w3_weight.copy_(w3_weight_shard.view(dst_w3_weight.dtype))
|
||||
dst_w1_weight.copy_(w1_weight_shard.view(dst_w1_weight.dtype))
|
||||
|
||||
# Get permute indices
|
||||
permute_indices = trtllmgen_maybe_get_cached_w3_w1_permute_indices(
|
||||
dst_w3_w1_weight_gpu, self._cache_permute_indices, epilogue_tile_m)
|
||||
|
||||
# Shuffle the weight according to permute indices
|
||||
processed_w31_weight_shard = torch.ops.trtllm.shuffle_matrix(
|
||||
dst_w3_w1_weight_gpu,
|
||||
permute_indices.to(dst_w3_w1_weight_gpu.device))
|
||||
|
||||
# Copy the result into device buffer
|
||||
dst_w3_w1_weight_gpu.copy_(processed_w31_weight_shard.view(
|
||||
dst_w3_w1_weight_gpu.dtype),
|
||||
non_blocking=dst_on_gpu)
|
||||
if not dst_on_gpu:
|
||||
dst_w3_w1_weight.copy_(dst_w3_w1_weight_gpu)
|
||||
|
||||
def load_expert_w2_weight(self, module: torch.nn.Module,
|
||||
w2_weight: torch.Tensor,
|
||||
dst_w2_weight: torch.Tensor):
|
||||
device = torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
dst_on_gpu = dst_w2_weight.device.type == "cuda"
|
||||
dst_w2_weight_gpu = dst_w2_weight if dst_on_gpu else dst_w2_weight.cuda(
|
||||
)
|
||||
|
||||
shard_w2_weight_dim = 2 * w2_weight.shape[1] if len(
|
||||
w2_weight.shape) == 2 else w2_weight.shape[0]
|
||||
alignment = _get_weight_alignment(self.weight_alignment,
|
||||
module.scaling_vector_size,
|
||||
module.tp_size, shard_w2_weight_dim)
|
||||
if len(w2_weight.shape) == 2:
|
||||
assert w2_weight.dtype == torch.uint8
|
||||
w2_weight = maybe_pad_for_mxfp4(w2_weight, alignment // 2,
|
||||
self.weight_alignment)
|
||||
else:
|
||||
assert len(w2_weight.shape) == 1
|
||||
w2_weight = maybe_pad_for_mxfp4(w2_weight, self.weight_alignment)
|
||||
|
||||
# Divide bias by tp_size as we shard along the hidden dimension.
|
||||
# The bias is applied at each TP rank before the final accumulation.
|
||||
w2_weight /= module.tp_size
|
||||
|
||||
w2_weight_shard = load_weight_shard(w2_weight,
|
||||
module.tp_size,
|
||||
module.tp_rank,
|
||||
TensorParallelMode.ROW,
|
||||
device=device)
|
||||
|
||||
# FIXME: this depends on the kernel internals
|
||||
epilogue_tile_m = 128
|
||||
|
||||
# Keep weights in device buffer
|
||||
dst_w2_weight_gpu.copy_(w2_weight_shard.view(dst_w2_weight_gpu.dtype),
|
||||
non_blocking=dst_on_gpu)
|
||||
# Get permuted indices
|
||||
permute_indices = trtllmgen_maybe_get_cached_w2_permute_indices(
|
||||
dst_w2_weight_gpu, self._cache_permute_indices, epilogue_tile_m)
|
||||
|
||||
# Shuffle the weight according to permute indices
|
||||
processed_w2_weight = torch.ops.trtllm.shuffle_matrix(
|
||||
dst_w2_weight_gpu, permute_indices.to(dst_w2_weight_gpu.device))
|
||||
|
||||
# Copy the result into device buffer
|
||||
dst_w2_weight_gpu.copy_(processed_w2_weight.view(
|
||||
dst_w2_weight_gpu.dtype),
|
||||
non_blocking=dst_on_gpu)
|
||||
|
||||
if not dst_on_gpu:
|
||||
dst_w2_weight.copy_(dst_w2_weight_gpu)
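A toy illustration of the bias convention noted above: each rank adds the pre-divided bias to its partial row-parallel output, so the all-reduce sum restores the full bias exactly once (the all-reduce is simulated here with a stacked sum):

import torch

tp_size = 2
partial = [torch.randn(4) for _ in range(tp_size)]   # per-rank partial GEMM outputs
bias = torch.randn(4)
with_bias = [p + bias / tp_size for p in partial]    # what each rank computes locally
reduced = torch.stack(with_bias).sum(dim=0)          # simulated all-reduce
torch.testing.assert_close(reduced, torch.stack(partial).sum(dim=0) + bias)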
|
||||
|
||||
def load_expert_w3_w1_weight_scale_nvfp4(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
w1_weight_scale: torch.Tensor,
|
||||
w3_weight_scale: torch.Tensor,
|
||||
dst_w3_w1_weight_scale: torch.Tensor,
|
||||
num_elts_per_sf: int = 16):
|
||||
device = torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
dst_on_gpu = dst_w3_w1_weight_scale.device.type == "cuda"
|
||||
dst_w3_w1_weight_scale_gpu = dst_w3_w1_weight_scale if dst_on_gpu else dst_w3_w1_weight_scale.cuda(
|
||||
)
|
||||
|
||||
alignment = _get_weight_alignment(self.weight_alignment,
|
||||
module.scaling_vector_size,
|
||||
module.tp_size,
|
||||
w3_weight_scale.shape[0])
|
||||
w1_weight_scale = maybe_pad_for_mxfp4(
|
||||
w1_weight_scale,
|
||||
self.input_hidden_alignment // module.scaling_vector_size,
|
||||
alignment)
|
||||
w3_weight_scale = maybe_pad_for_mxfp4(
|
||||
w3_weight_scale,
|
||||
self.input_hidden_alignment // module.scaling_vector_size,
|
||||
alignment)
|
||||
|
||||
w1_weight_scale = load_weight_shard(w1_weight_scale,
|
||||
module.tp_size,
|
||||
module.tp_rank,
|
||||
TensorParallelMode.COLUMN,
|
||||
device=device)
|
||||
w3_weight_scale = load_weight_shard(w3_weight_scale,
|
||||
module.tp_size,
|
||||
module.tp_rank,
|
||||
TensorParallelMode.COLUMN,
|
||||
device=device)
|
||||
# Keep weights in device buffer
|
||||
dst_w3_weight_scale, dst_w1_weight_scale = dst_w3_w1_weight_scale_gpu.chunk(
|
||||
2, dim=0)
|
||||
dst_w3_weight_scale.copy_(
|
||||
w3_weight_scale.view(dst_w3_weight_scale.dtype))
|
||||
dst_w1_weight_scale.copy_(
|
||||
w1_weight_scale.view(dst_w1_weight_scale.dtype))
|
||||
|
||||
orig_shape = dst_w3_w1_weight_scale_gpu.shape
|
||||
|
||||
# trtllm-gen specific block scale preprocessing logic
|
||||
epilogue_tile_m = 128 # FIXME
|
||||
|
||||
# Get permute indices
|
||||
permute_indices = trtllmgen_maybe_get_cached_w3_w1_permute_indices(
|
||||
dst_w3_w1_weight_scale_gpu.view(float4_sf_dtype),
|
||||
self._cache_permute_indices,
|
||||
epilogue_tile_m,
|
||||
num_elts_per_sf=num_elts_per_sf)
|
||||
|
||||
# Shuffle the weight according to permute indices
|
||||
w3_w1_weight_scale = torch.ops.trtllm.shuffle_matrix(
|
||||
dst_w3_w1_weight_scale_gpu.view(float4_sf_dtype), permute_indices)
|
||||
|
||||
# Assert should only be removed during debugging
|
||||
assert w3_w1_weight_scale.is_cuda, "w3_w1_weight_scale.is_cuda should be true or suffer from slow speed"
|
||||
# Interleave the weight.
|
||||
processed_w3_w1_weight_scale = torch.ops.trtllm.block_scale_interleave(
|
||||
w3_w1_weight_scale.view(float4_sf_dtype).reshape(orig_shape))
|
||||
# Copy the result into device buffer
|
||||
dst_w3_w1_weight_scale_gpu.copy_(
|
||||
processed_w3_w1_weight_scale.view(
|
||||
self.block_scales_dtype).reshape(orig_shape))
|
||||
|
||||
if not dst_on_gpu:
|
||||
dst_w3_w1_weight_scale.copy_(dst_w3_w1_weight_scale_gpu)
|
||||
|
||||
def load_expert_w2_weight_scale_nvfp4(self,
|
||||
module: torch.nn.Module,
|
||||
w2_weight_scale: torch.Tensor,
|
||||
dst_w2_weight_scale: torch.Tensor,
|
||||
num_elts_per_sf: int = 16):
|
||||
device = torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
dst_on_gpu = dst_w2_weight_scale.device.type == "cuda"
|
||||
dst_w2_weight_scale_gpu = dst_w2_weight_scale if dst_on_gpu else dst_w2_weight_scale.cuda(
|
||||
)
|
||||
|
||||
alignment = _get_weight_alignment(self.weight_alignment,
|
||||
module.scaling_vector_size,
|
||||
module.tp_size,
|
||||
w2_weight_scale.shape[-1])
|
||||
w2_weight_scale = maybe_pad_for_mxfp4(
|
||||
w2_weight_scale, alignment // module.scaling_vector_size,
|
||||
self.weight_alignment)
|
||||
|
||||
w2_weight_scale = load_weight_shard(w2_weight_scale,
|
||||
module.tp_size,
|
||||
module.tp_rank,
|
||||
TensorParallelMode.ROW,
|
||||
device=device)
|
||||
# Keep weights in device buffer
|
||||
dst_w2_weight_scale_gpu.copy_(
|
||||
w2_weight_scale.view(dst_w2_weight_scale_gpu.dtype))
|
||||
|
||||
orig_shape = dst_w2_weight_scale_gpu.shape
|
||||
|
||||
# trtllm-gen specific block scale preprocessing logic
|
||||
epilogue_tile_m = 128 # FIXME: read from kernel
|
||||
|
||||
# Assert should only be removed during debugging
|
||||
assert dst_w2_weight_scale_gpu.is_cuda, "dst_w2_weight_scale.is_cuda should be true or suffer from slow speed"
|
||||
|
||||
# Get permute indices
|
||||
permute_indices = trtllmgen_maybe_get_cached_w2_permute_indices(
|
||||
dst_w2_weight_scale_gpu.view(float4_sf_dtype),
|
||||
self._cache_permute_indices,
|
||||
epilogue_tile_m,
|
||||
num_elts_per_sf=num_elts_per_sf)
|
||||
|
||||
# Shuffle the weight according to permute indices
|
||||
w_shuffled = torch.ops.trtllm.shuffle_matrix(
|
||||
dst_w2_weight_scale_gpu.view(dtype=float4_sf_dtype),
|
||||
permute_indices)
|
||||
# Interleave the weight.
|
||||
processed_w2_weight_scale = torch.ops.trtllm.block_scale_interleave(
|
||||
w_shuffled)
|
||||
# Copy the result into device buffer
|
||||
dst_w2_weight_scale_gpu.copy_(
|
||||
processed_w2_weight_scale.view(
|
||||
self.block_scales_dtype).reshape(orig_shape))
|
||||
|
||||
if not dst_on_gpu:
|
||||
dst_w2_weight_scale.copy_(dst_w2_weight_scale_gpu)
|
||||
|
||||
def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
|
||||
super().load_quant_scales(module, weights)
|
||||
|
||||
# Normalize biases to account for the global scale factors,
|
||||
# matching the kernel's expectation (similar to test_moe.py logic).
|
||||
if module.w3_w1_bias is not None:
|
||||
# gemm1_bias * gemm1_scales_global * hidden_states_scale_global
|
||||
module.w3_w1_bias.data.div_((module.fc31_alpha.data).view(-1, 1))
|
||||
|
||||
if module.w2_bias is not None:
|
||||
# gemm2_bias * c_global_sf * gemm2_scales_global
|
||||
module.w2_bias.data.div_((module.fc2_alpha.data).view(-1, 1))
|
||||
|
||||
if module.swiglu_beta is not None:
|
||||
module.swiglu_beta.data.div_((module.fc31_alpha.data))
|
||||
|
||||
if module.swiglu_limit is not None:
|
||||
module.swiglu_limit.data.div_((module.fc31_alpha.data))
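A minimal sketch of why these divisions are needed, assuming the kernel folds the global scale in after adding the bias (see test_moe.py for the authoritative convention); the same reasoning applies to swiglu_beta and swiglu_limit, which act on the pre-scaled accumulator:

import torch

alpha = torch.tensor(0.25)                 # stands in for fc31_alpha / fc2_alpha
acc, bias = torch.tensor(3.0), torch.tensor(1.5)
kernel_out = alpha * (acc + bias / alpha)  # kernel adds the pre-divided bias
reference = alpha * acc + bias             # unscaled-domain result we want
torch.testing.assert_close(kernel_out, reference)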
|
||||
|
||||
|
||||
class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEBaseMethod):
|
||||
|
||||
def create_weights(self, module: torch.nn.Module):
|
||||
weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
|
||||
|
||||
@ -270,6 +270,13 @@ GPT-OSS/20B-MXFP4:
|
||||
- quant_algo: W4A16_MXFP4
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 85.0
|
||||
GPT-OSS/20B-NVFP4:
|
||||
- accuracy: 85.0
|
||||
- quant_algo: NVFP4
|
||||
accuracy: 85.0
|
||||
- quant_algo: NVFP4
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 85.0
|
||||
LGAI-EXAONE/EXAONE-4.0-32B:
|
||||
- accuracy: 88.36
|
||||
ByteDance-Seed/Seed-OSS-36B-Instruct:
@ -4270,6 +4270,44 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
|
||||
task.evaluate(llm,
|
||||
extra_evaluator_kwargs=self.extra_evaluator_kwargs)
|
||||
|
||||
@pytest.mark.skip_less_device(2)
|
||||
@pytest.mark.skip_blackwell
|
||||
@pytest.mark.parametrize(
|
||||
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [
|
||||
(2, 1, 1, False, True, True),
|
||||
(2, 1, 2, False, True, True),
|
||||
(2, 1, 2, True, True, True),
|
||||
],
|
||||
ids=["tp2", "ep2", "dp2"])
|
||||
def test_w4_2gpus_nvfp4(self, tp_size, pp_size, ep_size, attention_dp,
|
||||
cuda_graph, overlap_scheduler, mocker):
|
||||
pytest.skip("Models not uploaded to CI")
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
cuda_graph_config=CudaGraphConfig() if cuda_graph else None)
|
||||
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4,
|
||||
dtype="auto")
|
||||
|
||||
llm = LLM("./nvfp4ckpt",
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_seq_len=8192,
|
||||
**pytorch_config,
|
||||
enable_attention_dp=attention_dp,
|
||||
moe_config=MoeConfig(backend="TRTLLM"))
|
||||
|
||||
with llm:
|
||||
model_name = "GPT-OSS/20B-NVFP4"
|
||||
task = GSM8K(model_name)
|
||||
mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
|
||||
mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
|
||||
{"scores_filter": "exact_match,flexible-extract"})
|
||||
task.evaluate(llm,
|
||||
extra_evaluator_kwargs=self.extra_evaluator_kwargs)
|
||||
|
||||
@pytest.mark.skip_less_device(4)
|
||||
@pytest.mark.parametrize(
|
||||
"kv_cache_dtype",
|
||||
|
||||
@ -11,6 +11,7 @@ import cloudpickle
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from _torch.helpers import (calc_woq_tolerence, per_block_cast_to_fp8,
|
||||
per_block_cast_to_fp8_e8m0,
|
||||
per_token_cast_to_fp8_e8m0)
|
||||
@ -1398,6 +1399,53 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
|
||||
and moe_backend in ["TRTLLM", "CUTLASS"] else "0"
|
||||
})
|
||||
|
||||
run_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion)
|
||||
|
||||
|
||||
@skip_pre_blackwell
|
||||
@pytest.mark.parametrize("hidden_size, intermediate_size", [(2880, 2880)])
|
||||
@pytest.mark.parametrize("swiglu_alpha", [1, 0.1], ids=lambda v: f"alpha{v}")
|
||||
@pytest.mark.parametrize("swiglu_beta", [0, 1], ids=lambda v: f"beta{v}")
|
||||
@pytest.mark.parametrize("swiglu_limit", [float("inf"), 1],
|
||||
ids=lambda v: f"limit{v}")
|
||||
@pytest.mark.parametrize("enable_configurable_moe", [0, 1],
|
||||
ids=lambda x: ""
|
||||
if x == 0 else "enable_configurable_moe")
|
||||
def test_fused_moe_nvfp4_gptoss_style(hidden_size, intermediate_size,
|
||||
swiglu_alpha, swiglu_beta, swiglu_limit,
|
||||
enable_configurable_moe, mocker):
|
||||
mocker.patch.dict(os.environ, {
|
||||
"ENABLE_CONFIGURABLE_MOE":
|
||||
"1" if enable_configurable_moe == 1 else "0"
|
||||
})
|
||||
|
||||
run_fused_moe_nvfp4(dtype=torch.bfloat16,
|
||||
moe_backend="TRTLLM",
|
||||
finalize_fusion=False,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_experts=32,
|
||||
top_k=4,
|
||||
seq_len=256,
|
||||
gptoss_style=True,
|
||||
swiglu_alpha=swiglu_alpha,
|
||||
swiglu_beta=swiglu_beta,
|
||||
swiglu_limit=swiglu_limit)
|
||||
|
||||
|
||||
def run_fused_moe_nvfp4(dtype,
|
||||
moe_backend,
|
||||
finalize_fusion,
|
||||
hidden_size=512,
|
||||
intermediate_size=512,
|
||||
num_experts=8,
|
||||
top_k=2,
|
||||
seq_len=4,
|
||||
gptoss_style=False,
|
||||
swiglu_alpha=None,
|
||||
swiglu_beta=None,
|
||||
swiglu_limit=None):
|
||||
|
||||
if moe_backend == "TRTLLM":
|
||||
if dtype == torch.float16:
|
||||
pytest.skip("TRTLLM NVFP4 MoE backend does not support float16 yet")
|
||||
@ -1424,11 +1472,11 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
|
||||
with torch.device(f"cuda:{mapping.rank}"):
|
||||
SCALING_VECTOR_SIZE = 16
|
||||
|
||||
SEQ_LEN = 4
|
||||
HIDDEN_SIZE = 512
|
||||
INTERMEDIATE_SIZE = 512
|
||||
NUM_EXPERTS = 8
|
||||
TOP_K = 2
|
||||
SEQ_LEN = seq_len
|
||||
HIDDEN_SIZE = hidden_size
|
||||
INTERMEDIATE_SIZE = intermediate_size
|
||||
NUM_EXPERTS = num_experts
|
||||
TOP_K = top_k
|
||||
routing_method = RenormalizeMoeRoutingMethod(top_k=TOP_K)
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
@ -1455,24 +1503,38 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
|
||||
device="cuda") * 0.05
|
||||
w3_sf_global = (448 * 6) / w3_weight.abs().max().float()
|
||||
|
||||
if gptoss_style:
|
||||
w1_bias = torch.randn(INTERMEDIATE_SIZE,
|
||||
device='cuda',
|
||||
dtype=torch.float)
|
||||
w2_bias = torch.randn(HIDDEN_SIZE,
|
||||
device='cuda',
|
||||
dtype=torch.float)
|
||||
w3_bias = torch.randn(INTERMEDIATE_SIZE,
|
||||
device='cuda',
|
||||
dtype=torch.float)
|
||||
weights[f"{expert_id}.w1.bias"] = w1_bias
|
||||
weights[f"{expert_id}.w2.bias"] = w2_bias
|
||||
weights[f"{expert_id}.w3.bias"] = w3_bias
|
||||
|
||||
w3_w1_global = min(
|
||||
w1_sf_global,
|
||||
w3_sf_global) # w3 global and w1 global must be the same
|
||||
|
||||
w1_weight_nvfp4, w1_sf_block = torch.ops.trtllm.fp4_quantize(
|
||||
w1_weight, w3_w1_global, SCALING_VECTOR_SIZE, False)
|
||||
w1_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
|
||||
w1_sf_block.cpu().view(INTERMEDIATE_SIZE, -1))
|
||||
w1_weight_nvfp4, w1_sf_block_unswizzled = torch.ops.trtllm.fp4_quantize(
|
||||
w1_weight, w3_w1_global, SCALING_VECTOR_SIZE, False, False)
|
||||
w1_sf_block_unswizzled = w1_sf_block_unswizzled.view(
|
||||
INTERMEDIATE_SIZE, -1)
|
||||
|
||||
w2_weight_nvfp4, w2_sf_block = torch.ops.trtllm.fp4_quantize(
|
||||
w2_weight, w2_sf_global, SCALING_VECTOR_SIZE, False)
|
||||
w2_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
|
||||
w2_sf_block.cpu().view(HIDDEN_SIZE, -1))
|
||||
w2_weight_nvfp4, w2_sf_block_unswizzled = torch.ops.trtllm.fp4_quantize(
|
||||
w2_weight, w2_sf_global, SCALING_VECTOR_SIZE, False, False)
|
||||
w2_sf_block_unswizzled = w2_sf_block_unswizzled.view(
|
||||
HIDDEN_SIZE, -1)
|
||||
|
||||
w3_weight_nvfp4, w3_sf_block = torch.ops.trtllm.fp4_quantize(
|
||||
w3_weight, w3_w1_global, SCALING_VECTOR_SIZE, False)
|
||||
w3_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
|
||||
w3_sf_block.cpu().view(INTERMEDIATE_SIZE, -1))
|
||||
w3_weight_nvfp4, w3_sf_block_unswizzled = torch.ops.trtllm.fp4_quantize(
|
||||
w3_weight, w3_w1_global, SCALING_VECTOR_SIZE, False, False)
|
||||
w3_sf_block_unswizzled = w3_sf_block_unswizzled.view(
|
||||
INTERMEDIATE_SIZE, -1)
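# Note: the extra trailing False passed to fp4_quantize appears to request
# linear (unswizzled) block scales directly, replacing the previous
# block_scale_interleave_reverse round-trip through the CPU.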
|
||||
|
||||
w1_input_scale = x_sf_global.cuda()
|
||||
w2_input_scale = x_sf_global.cuda()
|
||||
@ -1497,6 +1559,23 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
|
||||
weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global
|
||||
weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global
|
||||
|
||||
swiglu_alpha_tensor = None
|
||||
swiglu_beta_tensor = None
|
||||
swiglu_limit_tensor = None
|
||||
if gptoss_style:
|
||||
swiglu_alpha_tensor = torch.full((NUM_EXPERTS, ),
|
||||
swiglu_alpha,
|
||||
device='cuda',
|
||||
dtype=torch.float)
|
||||
swiglu_beta_tensor = torch.full((NUM_EXPERTS, ),
|
||||
swiglu_beta,
|
||||
device='cuda',
|
||||
dtype=torch.float)
|
||||
swiglu_limit_tensor = torch.full((NUM_EXPERTS, ),
|
||||
swiglu_limit,
|
||||
device='cuda',
|
||||
dtype=torch.float)
|
||||
|
||||
quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4)
|
||||
|
||||
# Create pretrained_config with necessary parameters
|
||||
@ -1514,6 +1593,10 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
|
||||
quant_config=quant_config,
|
||||
moe_backend=moe_backend,
|
||||
moe_disable_finalize_fusion=not finalize_fusion),
|
||||
bias=gptoss_style,
|
||||
swiglu_alpha=swiglu_alpha_tensor,
|
||||
swiglu_beta=swiglu_beta_tensor,
|
||||
swiglu_limit=swiglu_limit_tensor,
|
||||
)
|
||||
fused_moe.load_weights([weights])
|
||||
fused_moe.post_load_weights()
|
||||
@ -1526,7 +1609,11 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
    hidden_size=HIDDEN_SIZE,
    intermediate_size=INTERMEDIATE_SIZE,
    dtype=dtype,
    model_config=ModelConfig(quant_config=quant_config))
    model_config=ModelConfig(quant_config=quant_config),
    bias=gptoss_style,
    swiglu_alpha=swiglu_alpha,
    swiglu_beta=swiglu_beta,
    swiglu_limit=swiglu_limit)
ref_fused_moe.load_weights([weights])
ref_fused_moe.cuda()
@ -1534,11 +1621,33 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
with torch.inference_mode():
    ref_output = ref_fused_moe.forward(x, router_logits)

with torch.inference_mode(), autotune():
    fused_moe.forward(x, router_logits)
if not gptoss_style:
    with torch.inference_mode(), autotune():
        fused_moe.forward(x, router_logits)
else:
    # We skip autotune for gptoss style to reduce memory usage since the input shape is already quite large.
    with torch.inference_mode():
        fused_moe.forward(x, router_logits)

output = fused_moe.forward(x, router_logits)
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.15)

if gptoss_style:
    rtol = 0.1
    atol = 0.1
    percent = 0.95
else:
    rtol = 1e-2
    atol = 0.15
    percent = None

if gptoss_style:
    check_accuracy(output,
                   ref_output,
                   rtol=rtol,
                   atol=atol,
                   percent=percent)
else:
    torch.testing.assert_close(output, ref_output, rtol=rtol, atol=atol)
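
The gpt-oss style path (large biases plus a clamped SwiGLU in NVFP4) switches from plain assert_close to check_accuracy with rtol=atol=0.1 and percent=0.95, i.e. a check that only requires a given fraction of elements to be within tolerance. A minimal sketch of such a percent-based check, assuming that semantics rather than quoting the actual helper in this file:

# Hedged sketch of a percent-based closeness check; the real check_accuracy
# helper may differ in details.
import torch

def percent_allclose(actual, expected, rtol, atol, percent):
    # Elementwise |a - e| <= atol + rtol * |e|, mirroring torch.isclose.
    close = torch.isclose(actual, expected, rtol=rtol, atol=atol)
    ratio = close.float().mean().item()
    assert ratio >= percent, (
        f"only {ratio:.3f} of elements within tolerance, expected >= {percent}")
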
if not test_all_kernels:
    return

@ -1551,10 +1660,17 @@ def test_fused_moe_nvfp4(dtype, moe_backend, finalize_fusion,
for tactic in all_tactics:
    with AutoTuner.get().replay(tactic), torch.inference_mode():
        output = fused_moe.forward(x, router_logits)
    torch.testing.assert_close(output,
                               ref_output,
                               rtol=1e-2,
                               atol=0.15)
    if gptoss_style:
        check_accuracy(output,
                       ref_output,
                       rtol=rtol,
                       atol=atol,
                       percent=percent)
    else:
        torch.testing.assert_close(output,
                                   ref_output,
                                   rtol=rtol,
                                   atol=atol)
@skip_pre_blackwell

@ -2690,7 +2806,10 @@ class RefGatedMLPFusedMoE(nn.Module):
    dtype: Optional[torch.dtype] = None,
    model_config: ModelConfig = ModelConfig(),
    use_cute_dsl_blockscaling_mm: bool = False,
    bias=False):
    bias=False,
    swiglu_alpha: Optional[float] = None,
    swiglu_beta: Optional[float] = None,
    swiglu_limit: Optional[float] = None):
super().__init__()
self.num_experts = num_experts
self.routing_method = routing_method

@ -2701,6 +2820,19 @@ class RefGatedMLPFusedMoE(nn.Module):
self.dtype = dtype
self.quant_config = model_config.quant_config

def custom_swiglu(x):
    gate, value = x.chunk(2, dim=-1)
    if swiglu_limit is not None and swiglu_limit != float("inf"):
        gate = gate.clamp(max=swiglu_limit)
        value = value.clamp(min=-swiglu_limit, max=swiglu_limit)

    alpha = swiglu_alpha if swiglu_alpha is not None else 1.0
    gate_act = gate * torch.sigmoid(gate * alpha)

    beta = swiglu_beta if swiglu_beta is not None else 0.0

    return gate_act * (value + beta)

self.experts = nn.ModuleList([
    GatedMLP(
        hidden_size=self.hidden_size,

@ -2709,6 +2841,8 @@ class RefGatedMLPFusedMoE(nn.Module):
        dtype=self.dtype,
        config=model_config,
        use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
        activation=custom_swiglu
        if swiglu_alpha is not None else F.silu,
    ) for _ in range(self.num_experts)
])
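
The reference path implements the gpt-oss style activation directly: the gate/value halves are clamped to swiglu_limit, the gate goes through a sigmoid with slope swiglu_alpha, and swiglu_beta shifts the value branch. A self-contained version of the same math, handy for quick numerical checks outside the module (function name and shapes below are illustrative only):

# Standalone clamped SwiGLU mirroring the custom_swiglu closure above.
import torch

def gptoss_swiglu(x, alpha=1.0, beta=0.0, limit=float("inf")):
    gate, value = x.chunk(2, dim=-1)
    if limit != float("inf"):
        gate = gate.clamp(max=limit)
        value = value.clamp(min=-limit, max=limit)
    return gate * torch.sigmoid(gate * alpha) * (value + beta)

# With alpha=1, beta=0, limit=inf this reduces to the usual SwiGLU.
x = torch.randn(2, 8)
assert torch.allclose(gptoss_swiglu(x),
                      torch.nn.functional.silu(x[..., :4]) * x[..., 4:])
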
@ -547,11 +547,23 @@ def run_moe_reference_fp4(args):
    args.gemm2_weights, args.gemm2_scales, 1 / args.gemm2_scales_global,
    sf_vec_size).cuda()

args_dequant = moe_args_dequant(
    args.num_tokens, args.num_experts, args.hidden_size,
    args.intermediate_size, args.top_k, args.padding, hidden_states_dequant,
    args.expert_logits, gemm1_weights_dequant, gemm2_weights_dequant,
    args.permute_info, args.use_routing_scales_on_input)
args_dequant = moe_args_dequant(args.num_tokens,
                                args.num_experts,
                                args.hidden_size,
                                args.intermediate_size,
                                args.top_k,
                                args.padding,
                                hidden_states_dequant,
                                args.expert_logits,
                                gemm1_weights_dequant,
                                gemm2_weights_dequant,
                                args.permute_info,
                                args.use_routing_scales_on_input,
                                gemm1_bias=args.gemm1_bias,
                                gemm1_alpha=args.gemm1_alpha,
                                gemm1_beta=args.gemm1_beta,
                                gemm1_clamp_limit=args.gemm1_clamp_limit,
                                gemm2_bias=args.gemm2_bias)

return run_moe_dequant(args_dequant, "fp4"), args_dequant
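
The reference dequant path now forwards the gpt-oss extras (bias, alpha, beta, clamp limit) into moe_args_dequant. The container definition itself is not shown in this diff; a plausible sketch, assuming it is a plain argument holder with the new fields defaulting to None so existing call sites keep working, would be:

# Hedged sketch only: the real moe_args_dequant is defined elsewhere in this
# test file and may use a different structure or extra fields.
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class MoeArgsDequantSketch:
    num_tokens: int
    num_experts: int
    hidden_size: int
    intermediate_size: int
    top_k: int
    padding: int
    hidden_states: torch.Tensor
    expert_logits: torch.Tensor
    gemm1_weights: torch.Tensor
    gemm2_weights: torch.Tensor
    permute_info: dict
    use_routing_scales_on_input: bool
    # New gpt-oss style extras, all optional.
    gemm1_bias: Optional[torch.Tensor] = None
    gemm1_alpha: Optional[torch.Tensor] = None
    gemm1_beta: Optional[torch.Tensor] = None
    gemm1_clamp_limit: Optional[torch.Tensor] = None
    gemm2_bias: Optional[torch.Tensor] = None
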
@ -1157,6 +1169,44 @@ class TestMoeFp4:
    use_autotune=False,
    use_topk_as_input=use_topk_as_input)

@pytest.mark.parametrize("num_tokens", [1])
@pytest.mark.parametrize("hidden_size", [512])
@pytest.mark.parametrize("intermediate_size", [512])
@pytest.mark.parametrize(
    "routing_info",
    [
        pytest.param(
            {
                "num_experts": 128,
                "top_k": 4,
                "n_groups": None,
                "top_k_groups": None,
                "routed_scaling": None,
                "has_routing_bias": False,
                "routing_method_type": RoutingMethodType.Renormalize
            },
            id="RoutingGPTOSS")
    ],
)
@pytest.mark.parametrize("swiglu_alpha", [1, 0.1],
                         ids=lambda v: f"alpha{v}")
@pytest.mark.parametrize("swiglu_beta", [0, 1], ids=lambda v: f"beta{v}")
@pytest.mark.parametrize("swiglu_limit", [float("inf"), 1],
                         ids=lambda v: f"limit{v}")
def test_gptoss_style_nvfp4(self, num_tokens, hidden_size,
                            intermediate_size, routing_info, swiglu_alpha,
                            swiglu_beta, swiglu_limit):

    self.run_moe_fp4_test(num_tokens,
                          hidden_size,
                          intermediate_size,
                          routing_info,
                          use_autotune=False,
                          gptoss_style=True,
                          swiglu_alpha=swiglu_alpha,
                          swiglu_beta=swiglu_beta,
                          swiglu_limit=swiglu_limit)
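
The parametrize stack above expands to 2 x 2 x 2 = 8 gpt-oss cases (alpha in {1, 0.1}, beta in {0, 1}, limit in {inf, 1}) over the single RoutingGPTOSS routing configuration. For debugging outside pytest, one expanded case can be driven directly; the snippet below assumes the surrounding test module has been imported under the hypothetical alias `test_moe`:

# Hypothetical direct invocation of one expanded case (import alias assumed;
# requires the same Blackwell-class GPU the test itself needs).
import test_moe

routing_info = {
    "num_experts": 128,
    "top_k": 4,
    "n_groups": None,
    "top_k_groups": None,
    "routed_scaling": None,
    "has_routing_bias": False,
    "routing_method_type": test_moe.RoutingMethodType.Renormalize,
}
test_moe.TestMoeFp4().run_moe_fp4_test(num_tokens=1,
                                       hidden_size=512,
                                       intermediate_size=512,
                                       routing_info=routing_info,
                                       use_autotune=False,
                                       gptoss_style=True,
                                       swiglu_alpha=0.1,
                                       swiglu_beta=1,
                                       swiglu_limit=1)
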
@pytest.mark.parametrize("num_tokens", [1])
|
||||
@pytest.mark.parametrize("hidden_size", [1024])
|
||||
@pytest.mark.parametrize("intermediate_size", [1024])
|
||||
@ -1219,9 +1269,17 @@ class TestMoeFp4:
|
||||
use_autotune=True,
|
||||
use_topk_as_input=True)
|
||||
|
||||
def run_moe_fp4_test(self, num_tokens: int, hidden_size: int,
|
||||
intermediate_size: int, routing_info: dict,
|
||||
use_autotune: bool, use_topk_as_input: bool) -> None:
|
||||
def run_moe_fp4_test(self,
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
routing_info: dict,
|
||||
use_autotune: bool,
|
||||
use_topk_as_input: bool = False,
|
||||
gptoss_style: bool = False,
|
||||
swiglu_alpha: float = None,
|
||||
swiglu_beta: float = None,
|
||||
swiglu_limit: float = None) -> None:
|
||||
|
||||
torch.random.manual_seed(0)
|
||||
|
||||
@ -1289,6 +1347,39 @@ class TestMoeFp4:
    device='cuda',
    dtype=torch.bfloat16)

gemm1_bias = None
gemm2_bias = None
swiglu_alpha_tensor = None
swiglu_beta_tensor = None
swiglu_limit_tensor = None

if gptoss_style:
    gemm1_bias = 50 * torch.randn(num_experts,
                                  2 * intermediate_size,
                                  device='cuda',
                                  dtype=torch.float)
    gemm2_bias = 50 * torch.randn(
        num_experts, hidden_size, device='cuda', dtype=torch.float)

    # waived due to missing kernel support for bias in nvfp4
    #gemm1_bias[:] = 0
    #gemm2_bias[:] = 0

    swiglu_alpha_tensor = torch.full((num_experts, ),
                                     swiglu_alpha,
                                     device='cuda',
                                     dtype=torch.float)

    swiglu_beta_tensor = torch.full((num_experts, ),
                                    swiglu_beta,
                                    device='cuda',
                                    dtype=torch.float)

    swiglu_limit_tensor = torch.full((num_experts, ),
                                     swiglu_limit,
                                     device='cuda',
                                     dtype=torch.float)

use_ue8m0 = False
# Quantize hidden states. Produces scales for activations in 128x4 layout for ref impl.
hidden_states_fp4_bytes, hidden_states_scale_fp4_bytes, hidden_states_scale_global = quant_fp4(
@ -1343,14 +1434,29 @@ class TestMoeFp4:
    permute_info, scores = routing_reference_renormalize_naive(
        expert_logits, top_k, padding)

args = moe_args(num_tokens, num_experts, hidden_size, intermediate_size,
                top_k, padding, hidden_states_fp4_bytes,
args = moe_args(num_tokens,
                num_experts,
                hidden_size,
                intermediate_size,
                top_k,
                padding,
                hidden_states_fp4_bytes,
                hidden_states_scale_fp4_bytes,
                hidden_states_scale_global, scores,
                gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes,
                gemm1_scales_global, gemm2_weights_fp4_bytes,
                gemm2_scales_fp4_bytes, gemm2_scales_global,
                permute_info, False)
                hidden_states_scale_global,
                scores,
                gemm1_weights_fp4_bytes,
                gemm1_scales_fp4_bytes,
                gemm1_scales_global,
                gemm2_weights_fp4_bytes,
                gemm2_scales_fp4_bytes,
                gemm2_scales_global,
                permute_info,
                False,
                gemm1_bias=gemm1_bias,
                gemm1_alpha=swiglu_alpha_tensor,
                gemm1_beta=swiglu_beta_tensor,
                gemm1_clamp_limit=swiglu_limit_tensor,
                gemm2_bias=gemm2_bias)
#
# Run the reference implementations
#
@ -1364,12 +1470,17 @@ class TestMoeFp4:
# Reorder rows of W1 and scales for fused gated activation
gemm1_weights_fp4_interleaved = []
gemm1_scales_fp4_interleaved = []
gemm1_bias_interleaved = []
for i in range(num_experts):
    gemm1_weights_fp4_interleaved.append(
        reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone()))
    gemm1_scales_fp4_interleaved.append(
        reorder_rows_for_gated_act_gemm(
            gemm1_scales_linear_fp4[i].clone()))
    if gemm1_bias is not None:
        gemm1_bias_interleaved.append(
            reorder_rows_for_gated_act_gemm(
                gemm1_bias[i].clone().reshape(-1, 1)))

# Stack weights and scales for all experts
gemm1_weights_fp4_interleaved = torch.stack(
@ -1384,8 +1495,10 @@ class TestMoeFp4:
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_fp4_shuffled = []
gemm1_scales_fp4_shuffled = []
gemm1_bias_shuffled = []
gemm2_weights_fp4_shuffled = []
gemm2_scales_fp4_shuffled = []
gemm2_bias_shuffled = []
for i in range(num_experts):
    gemm1_weights_fp4_shuffled.append(
        shuffle_matrix_a(

@ -1395,6 +1508,10 @@ class TestMoeFp4:
    shuffle_matrix_sf_a(
        gemm1_scales_fp4_interleaved[i].view(torch.uint8),
        epilogue_tile_m))
    if gemm1_bias is not None:
        gemm1_bias_shuffled.append(
            shuffle_matrix_a(gemm1_bias_interleaved[i],
                             epilogue_tile_m))

    gemm2_weights_fp4_shuffled.append(
        shuffle_matrix_a(gemm2_weights_fp4[i].view(torch.uint8),

@ -1403,6 +1520,10 @@ class TestMoeFp4:
        shuffle_matrix_sf_a(
            gemm2_scales_linear_fp4[i].view(torch.uint8),
            epilogue_tile_m))
    if gemm2_bias is not None:
        gemm2_bias_shuffled.append(
            shuffle_matrix_a(gemm2_bias[i].clone().reshape(-1, 1),
                             epilogue_tile_m))
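
The per-expert biases are handled like one extra weight column: gemm1_bias is reshaped to a (2 * intermediate_size, 1) column, row-interleaved for the fused gated GEMM, and shuffled with shuffle_matrix_a so its rows match the transposed-MMA output layout; gemm2_bias gets the same shuffle without the gated interleave. A shape-level sketch of that pipeline, reusing the reorder/shuffle helpers already imported by this test (function name below is illustrative):

# Shape-level sketch: push a bias vector through the same layout transforms
# as the corresponding weight matrix.
def prepare_gemm1_bias(bias_row, epilogue_tile_m):
    # bias_row: (2 * intermediate_size,) for one expert
    col = bias_row.clone().reshape(-1, 1)
    col = reorder_rows_for_gated_act_gemm(col)     # match the fused gate/up row order
    return shuffle_matrix_a(col, epilogue_tile_m)  # match the transposed MMA epilogue
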
# Stack weights for all experts
gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)

@ -1415,10 +1536,35 @@ class TestMoeFp4:
    torch.float8_e4m3fn).reshape(num_experts, hidden_size,
                                 intermediate_size // 16)

if gemm1_bias is not None:
    gemm1_bias_shuffled = torch.stack(gemm1_bias_shuffled).reshape(
        num_experts, -1)
else:
    gemm1_bias_shuffled = None

if gemm2_bias is not None:
    gemm2_bias_shuffled = torch.stack(gemm2_bias_shuffled).reshape(
        num_experts, -1)
else:
    gemm2_bias_shuffled = None

#
# Run the TRT-LLM kernel
#

if gptoss_style:
    # NOTE: correct the beta and clamp to account for the global scale factor
    # Check cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/GemmGatedActOptions.h
    # for more details
    swiglu_beta_tensor = swiglu_beta_tensor * args.gemm1_scales_global * args.hidden_states_scale_global
    swiglu_limit_tensor = swiglu_limit_tensor * args.gemm1_scales_global * args.hidden_states_scale_global
    # Check cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h
    # for more details
    gemm1_bias_shuffled = gemm1_bias_shuffled * args.gemm1_scales_global[:,
                                                                         None] * args.hidden_states_scale_global
    gemm2_bias_shuffled = gemm2_bias_shuffled * args_dequant.c_global_sf * args.gemm2_scales_global[:,
                                                                                                    None]
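
The corrections above follow from where the kernel applies the activation: the GEMM1 accumulator holds the real-valued pre-activation multiplied by gemm1_scales_global * hidden_states_scale_global, so any constant that is compared against or added to it (the clamp limit, the beta offset) has to be scaled by the same factor, and each bias picks up the scale of the accumulator it is added to. A small numeric sanity check of that reasoning, with made-up scale values:

# Hedged numeric check of the scale-correction logic, with invented numbers.
import torch

x_real = torch.tensor([3.0])        # real-valued pre-activation entry
s_act, s_w = 0.5, 0.25              # hidden_states / gemm1 global scales (made up)
x_scaled = x_real * s_act * s_w     # what the quantized-domain accumulator holds

limit_real = 2.0
limit_scaled = limit_real * s_act * s_w  # same correction as swiglu_limit_tensor

# Clamping in the scaled domain with the scaled limit matches clamping in the
# real domain and then rescaling.
assert torch.allclose(x_scaled.clamp(max=limit_scaled),
                      x_real.clamp(max=limit_real) * s_act * s_w)
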
# c_global_sf: fc2_input_scale
scale_c_fc1 = args_dequant.c_global_sf * (
    1.0 / args.gemm1_scales_global) * (1.0 /

@ -1449,8 +1595,13 @@ class TestMoeFp4:
    hidden_states_scale_linear_fp4,
    gemm1_weights_fp4_shuffled,
    gemm1_scales_fp4_shuffled,
    gemm1_bias_shuffled,  # Bias
    swiglu_alpha_tensor,  # Alpha
    swiglu_beta_tensor,  # Beta
    swiglu_limit_tensor,  # Limit
    gemm2_weights_fp4_shuffled,
    gemm2_scales_fp4_shuffled,
    gemm2_bias_shuffled,  # Bias
    scale_c_fc1,
    scale_gate_fc1,
    scale_c_fc2,
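
For reference, the new per-expert arguments threaded into the kernel call carry the following shapes and dtypes, as constructed earlier in this test (the comment block below only restates what the code above builds):

# Shapes/dtypes of the gpt-oss style arguments as built in this test:
#   gemm1_bias_shuffled  : (num_experts, 2 * intermediate_size), float32, shuffled layout
#   swiglu_alpha_tensor  : (num_experts,), float32
#   swiglu_beta_tensor   : (num_experts,), float32, rescaled by the global scales above
#   swiglu_limit_tensor  : (num_experts,), float32, rescaled by the global scales above
#   gemm2_bias_shuffled  : (num_experts, hidden_size), float32, shuffled layout
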
@ -1469,11 +1620,20 @@ class TestMoeFp4:
torch.cuda.synchronize()
output_dequant_actual = output[0].to(torch.float)

if gptoss_style:
    atol = 0.2
    rtol = 0.2
    percent = 0.85
else:
    atol = 0.1
    rtol = 0.85
    percent = 0.925

check_accuracy(output_dequant_reference,
               output_dequant_actual,
               atol=0.1,
               rtol=0.85,
               percent=0.925)
               atol=atol,
               rtol=rtol,
               percent=percent)

def run_moe_fp8_fp4_test(self, num_tokens: int, hidden_size: int,
                         intermediate_size: int, routing_info: dict,