Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-05 02:31:33 +08:00)
[None][chore] NVFP4 MoE - Move weights transformation to fusion phase… (#10803)
Signed-off-by: Tal Cherckez <tcherckez@nvl72070-T11.cm.cluster>
Signed-off-by: Tal Cherckez <tcherckez@nvl72039-T03.cm.cluster>
Signed-off-by: Tal Cherckez <tcherckez@nvl72098-T11.cm.cluster>
Signed-off-by: tcherckez-nvidia <127761168+tcherckez-nvidia@users.noreply.github.com>
Co-authored-by: Tal Cherckez <tcherckez@nvl72070-T11.cm.cluster>
Co-authored-by: Tal Cherckez <tcherckez@nvl72039-T03.cm.cluster>
Co-authored-by: Tal Cherckez <tcherckez@nvl72098-T11.cm.cluster>
This commit is contained in:
parent 0243abee22
commit 128d4ac5be
@@ -13,16 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import math

import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.quant import (
    TRTLLM_NVFP4_COLUMN_SIZE,
    TRTLLM_NVFP4_ROW_SIZE,
    TRTLLM_NVFP4_SCALING_VECTOR_SIZE,
)
from tensorrt_llm._torch.auto_deploy.custom_ops.quant import TRTLLM_NVFP4_SCALING_VECTOR_SIZE
from tensorrt_llm._torch.utils import ActivationType


@@ -262,14 +255,6 @@ def trtllm_quant_nvfp4_moe_fused(
        mlp_style: "gated_mlp" or "mlp"
        act_fn: "silu" for gated_mlp, "relu2" for mlp
    """
    NVFP4_BLOCK_SIZE = TRTLLM_NVFP4_SCALING_VECTOR_SIZE
    FP4_PER_UINT8 = 2

    _, fc1_inter_size, _ = fc1_expert_weights_fp4.shape
    n_experts, hidden_size, inter_size = fc2_expert_weights_fp4.shape

    # Convert the inter_size from number of uint8 elements to number of FP4 elements.
    inter_size *= FP4_PER_UINT8

    # Validate block scale tensors are 3D (padding requirements handled below)
    assert fc1_weight_blockscale_fp8.ndim == 3, "fc1_weight_blockscale_fp8 must be 3D"
@@ -280,7 +265,7 @@ def trtllm_quant_nvfp4_moe_fused(

    if x.dtype in (torch.float16, torch.bfloat16):
        x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize(
            x, fc1_act_global_scale, NVFP4_BLOCK_SIZE
            x, fc1_act_global_scale, TRTLLM_NVFP4_SCALING_VECTOR_SIZE
        )
        output_dtype = x.dtype
    else:
@@ -288,66 +273,6 @@ def trtllm_quant_nvfp4_moe_fused(
        input_blockscale = None
        output_dtype = x.dtype

    # Pad inter_size to be divisible by TRTLLM_NVFP4_ROW_SIZE
    inter_size_padded = math.ceil(inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE
    fc1_inter_size_padded = (
        math.ceil(fc1_inter_size / TRTLLM_NVFP4_ROW_SIZE) * TRTLLM_NVFP4_ROW_SIZE
    )
    hidden_size_padded = (
        math.ceil(hidden_size / TRTLLM_NVFP4_COLUMN_SIZE) * TRTLLM_NVFP4_COLUMN_SIZE
    )

    inter_size_needs_padding = (is_gated_mlp and fc1_inter_size_padded != fc1_inter_size) or (
        not is_gated_mlp and inter_size_padded != inter_size
    )
    hidden_size_needs_padding = hidden_size % TRTLLM_NVFP4_COLUMN_SIZE != 0

    hidden_blocks_padded = hidden_size_padded // NVFP4_BLOCK_SIZE
    inter_blocks_padded = inter_size_padded // NVFP4_BLOCK_SIZE

    if inter_size_needs_padding or hidden_size_needs_padding:
        # Pad fc1_expert_weights_fp4: [E, I, H/2] or [E, 2*I, H/2]
        fc1_padded = fc1_expert_weights_fp4.new_zeros(
            fc1_expert_weights_fp4.size(0),
            fc1_inter_size_padded,
            hidden_size_padded // FP4_PER_UINT8,
        )
        fc1_padded[:, :fc1_inter_size, : hidden_size // FP4_PER_UINT8] = fc1_expert_weights_fp4
        fc1_expert_weights_fp4 = fc1_padded

        # Block scales may already be padded, so check actual size
        fc1_bs_size1 = fc1_weight_blockscale_fp8.size(1)
        fc1_bs_size2 = fc1_weight_blockscale_fp8.size(2)
        if fc1_bs_size1 < fc1_inter_size_padded or fc1_bs_size2 < hidden_blocks_padded:
            fc1_blockscale_fp8_padded = fc1_weight_blockscale_fp8.new_zeros(
                n_experts, fc1_inter_size_padded, hidden_blocks_padded
            )
            fc1_blockscale_fp8_padded[:, :fc1_bs_size1, :fc1_bs_size2] = fc1_weight_blockscale_fp8
            fc1_weight_blockscale_fp8 = fc1_blockscale_fp8_padded

        fc2_padded = fc2_expert_weights_fp4.new_zeros(
            n_experts, hidden_size_padded, inter_size_padded // FP4_PER_UINT8
        )

        assert inter_size % NVFP4_BLOCK_SIZE == 0, (
            f"inter_size {inter_size} must be divisible by {NVFP4_BLOCK_SIZE}"
        )

        fc2_padded[:, :hidden_size, : inter_size // FP4_PER_UINT8] = fc2_expert_weights_fp4
        fc2_expert_weights_fp4 = fc2_padded

        # Pad fc2_weight_blockscale_fp8: [E, H, I/16]
        # Block scales may already be padded, so check actual size
        fc2_bs_size1 = fc2_weight_blockscale_fp8.size(1)
        fc2_bs_size2 = fc2_weight_blockscale_fp8.size(2)
        if fc2_bs_size1 < hidden_size_padded or fc2_bs_size2 < inter_blocks_padded:
            fc2_blockscale_fp8_padded = fc2_weight_blockscale_fp8.new_zeros(
                n_experts, hidden_size_padded, inter_blocks_padded
            )
            fc2_blockscale_fp8_padded[:, :fc2_bs_size1, :fc2_bs_size2] = fc2_weight_blockscale_fp8
            fc2_weight_blockscale_fp8 = fc2_blockscale_fp8_padded
        # else: block scales already have correct padded size, use as-is

    # quant_scales is described by this code:
    # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015
    quant_scales = [
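The padding logic removed here is plain round-up-to-a-multiple arithmetic (inter_size to TRTLLM_NVFP4_ROW_SIZE = 128, hidden_size to TRTLLM_NVFP4_COLUMN_SIZE = 4); with this change the same rounding is applied once at fusion time instead of on every op call. A minimal sketch of that arithmetic (the helper name is hypothetical):

import math

def round_up(value: int, multiple: int) -> int:
    # Round value up to the nearest multiple, as in the removed padding code above.
    return math.ceil(value / multiple) * multiple

assert round_up(384, 128) == 384  # already aligned, no padding needed
assert round_up(300, 128) == 384  # rounded up to the next row boundary
assert round_up(511, 4) == 512    # hidden size rounded to the column multiple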
@@ -14,6 +14,7 @@ TRTLLM_FP4_OP_AVAILABLE = True
TRTLLM_NVFP4_SCALING_VECTOR_SIZE = 16
TRTLLM_NVFP4_ROW_SIZE = 128
TRTLLM_NVFP4_COLUMN_SIZE = 4
TRTLLM_NVFP4_PACKING_FACTOR = 2


@torch.library.custom_op("auto_deploy::torch_quant_fn", mutates_args=())
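The new TRTLLM_NVFP4_PACKING_FACTOR constant records that two FP4 values are packed into each uint8 byte. A short sketch of how a packed byte count maps to FP4 elements and block-scale counts under these constants (the example size is illustrative):

TRTLLM_NVFP4_SCALING_VECTOR_SIZE = 16  # FP4 elements covered by one FP8 block scale
TRTLLM_NVFP4_PACKING_FACTOR = 2        # FP4 values stored per uint8 byte

packed_bytes = 256  # last-dimension size of a packed NVFP4 weight tensor
fp4_elements = packed_bytes * TRTLLM_NVFP4_PACKING_FACTOR        # 512 FP4 values
scale_blocks = fp4_elements // TRTLLM_NVFP4_SCALING_VECTOR_SIZE  # 32 block scales
assert (fp4_elements, scale_blocks) == (512, 32)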
@@ -1,3 +1,4 @@
import math
from collections import defaultdict
from functools import partial
from typing import Dict, List, Literal, Optional, Tuple, Type
@@ -8,6 +9,7 @@ from torch.fx import GraphModule, Node

from tensorrt_llm._torch.utils import ActivationType

from ...custom_ops.quant import TRTLLM_NVFP4_PACKING_FACTOR, TRTLLM_NVFP4_SCALING_VECTOR_SIZE
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils._graph import delete_all_unused_submodules, eliminate_dead_code
@@ -1627,14 +1629,13 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
        "w3_weight",
        "w1_input_scale",
        "w2_input_scale",
        "w3_input_scale",
        "w1_weight_scale",
        "w2_weight_scale",
        "w3_weight_scale",
        "w1_alpha",
        "w2_alpha",
        "w3_alpha",
        "is_gated_mlp",
        "act_fn",
    )

    def _stack(param_list, dim=0, device=None, dtype=None):
@@ -1648,13 +1649,10 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
    def _prepare_args_cutlass_format_nvfp4():
        if is_gated_mlp:
            # For gated MLP, concatenate w1 and w3 as [w3, w1]
            fc1_expert_weights = torch.cat(
                [w3_stacked, w1_stacked], dim=1
            ).contiguous()  # [E, 2*I, H]
            fc1_act_scale = torch.cat(
                [w3_input_scale_stacked, w1_input_scale_stacked], dim=1
            ).contiguous()
            fc1_alpha_stacked = torch.cat([w3_alpha_stacked, w1_alpha_stacked], dim=1).contiguous()
            fc1_expert_weights = torch.cat([w3_stacked, w1_stacked], dim=1).contiguous()
            # Expect w3 input scale and alpha to be the same as w1
            fc1_act_scale = w1_input_scale_stacked
            fc1_alpha_stacked = w1_alpha_stacked
            fc1_weight_blockscale_fp8_stacked = torch.cat(
                [w3_weight_blockscale_fp8_stacked, w1_weight_blockscale_fp8_stacked], dim=1
            ).contiguous()
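For the gated path, the fused fc1 weight carries both projections per expert, concatenated as [w3, w1] along the intermediate dimension as the cutlass-format preparation above expects. A shape-only sketch with dummy sizes (all names and numbers illustrative):

import torch

num_experts, inter_size, hidden_packed = 3, 256, 512 // 2  # 2 FP4 values per uint8 byte
w1_stacked = torch.zeros(num_experts, inter_size, hidden_packed, dtype=torch.uint8)
w3_stacked = torch.zeros(num_experts, inter_size, hidden_packed, dtype=torch.uint8)

# [w3, w1] along dim=1 yields the fused fc1 layout [E, 2*I, H/2]
fc1_expert_weights = torch.cat([w3_stacked, w1_stacked], dim=1).contiguous()
assert fc1_expert_weights.shape == (num_experts, 2 * inter_size, hidden_packed)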
@@ -1666,6 +1664,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:

        fc2_expert_weights = w2_stacked
        fc2_act_scale = w2_input_scale_stacked
        fc2_weight_blockscale_fp8_stacked = w2_weight_blockscale_fp8_stacked

        new_key_fc1_expert_weights = f"nvfp4_moe_w3_w1_stacked_{fused_key_counter}"
        new_key_fc2_expert_weights = f"nvfp4_moe_w2_stacked_{fused_key_counter}"
@@ -1681,13 +1680,57 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
        new_key_fc1_alpha = f"nvfp4_moe_w1_alpha_stacked_{fused_key_counter}"
        new_key_fc2_alpha = f"nvfp4_moe_w2_alpha_stacked_{fused_key_counter}"

        # Pad fc1_expert_weights to match the already padded scales
        fc1_pad_size = fc1_weight_blockscale_fp8_stacked.shape[1] - fc1_expert_weights.shape[1]
        if fc1_pad_size > 0:
            fc1_expert_weights = torch.nn.functional.pad(
                fc1_expert_weights, (0, 0, 0, fc1_pad_size), mode="constant", value=0
            )
            # Need to update fc2 scales and weights to match the padded size of fc1,
            # as they share the same intermediate dimension.
            target_intermediate = fc1_weight_blockscale_fp8_stacked.shape[1]
            TRTLLM_NVFP4_SCALING_VECTOR_NUM_ELEMENTS = TRTLLM_NVFP4_SCALING_VECTOR_SIZE
            TRTLLM_NVFP4_SCALING_BYTES_SIZE = (
                TRTLLM_NVFP4_SCALING_VECTOR_NUM_ELEMENTS // TRTLLM_NVFP4_PACKING_FACTOR
            )
            target_n_blocks = target_intermediate // TRTLLM_NVFP4_SCALING_VECTOR_NUM_ELEMENTS
            padded_target_n_blocks = (
                math.ceil(target_n_blocks / TRTLLM_NVFP4_SCALING_BYTES_SIZE)
                * TRTLLM_NVFP4_SCALING_BYTES_SIZE
            )
            fc2_blocks_pad = padded_target_n_blocks - fc2_weight_blockscale_fp8_stacked.shape[2]

            if fc2_blocks_pad > 0:
                # unswizzle fc2 scales
                fc2_blockscale_shape = list(fc2_weight_blockscale_fp8_stacked.shape)
                fc2_blockscale_shape[2] = padded_target_n_blocks
                fc2_weight_blockscale_fp8_stacked = torch.ops.trtllm.block_scale_interleave_reverse(
                    fc2_weight_blockscale_fp8_stacked.view(torch.uint8)
                )
                fc2_weight_blockscale_fp8_stacked = torch.nn.functional.pad(
                    fc2_weight_blockscale_fp8_stacked, (0, fc2_blocks_pad), mode="constant", value=0
                )
                fc2_weight_blockscale_fp8_stacked = (
                    torch.ops.trtllm.block_scale_interleave(fc2_weight_blockscale_fp8_stacked)
                    .view(torch.float8_e4m3fn)
                    .reshape(fc2_blockscale_shape)
                )
            fc2_expert_weights = torch.nn.functional.pad(
                fc2_expert_weights,
                (0, fc1_pad_size // TRTLLM_NVFP4_PACKING_FACTOR, 0, 0),
                mode="constant",
                value=0,
            ).view(torch.uint8)

        # FP4 weights are already packed as uint8, don't convert dtype
        _register_parameter(gm, new_key_fc1_expert_weights, fc1_expert_weights)
        _register_parameter(gm, new_key_fc2_expert_weights, fc2_expert_weights)
        _register_parameter(
            gm, new_key_fc1_weight_blockscale_fp8, fc1_weight_blockscale_fp8_stacked
        )
        _register_parameter(gm, new_key_fc2_weight_blockscale_fp8, w2_weight_blockscale_fp8_stacked)
        _register_parameter(
            gm, new_key_fc2_weight_blockscale_fp8, fc2_weight_blockscale_fp8_stacked
        )
        _register_parameter(gm, new_key_fc1_act_scale, fc1_act_scale)
        _register_parameter(gm, new_key_fc2_act_scale, fc2_act_scale)
        _register_parameter(gm, new_key_fc1_alpha, fc1_alpha_stacked)
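The block-count arithmetic above keeps the last dimension of the fc2 block scales aligned to the interleave granularity; the padding itself is applied in unswizzled layout, between block_scale_interleave_reverse and block_scale_interleave. A worked numeric example using the quant.py constants (the sizes are illustrative):

import math

SCALING_VECTOR_SIZE = 16                                    # FP4 elements per block scale
PACKING_FACTOR = 2                                          # FP4 values per uint8 byte
SCALING_BYTES_SIZE = SCALING_VECTOR_SIZE // PACKING_FACTOR  # 8

# Say fc1 padding grew the shared intermediate dimension to 384 FP4 elements,
# while the fc2 block scales were built for an unpadded intermediate of 368.
target_intermediate = 384
target_n_blocks = target_intermediate // SCALING_VECTOR_SIZE                                   # 24
padded_target_n_blocks = math.ceil(target_n_blocks / SCALING_BYTES_SIZE) * SCALING_BYTES_SIZE  # 24
existing_n_blocks = 368 // SCALING_VECTOR_SIZE                                                 # 23
fc2_blocks_pad = padded_target_n_blocks - existing_n_blocks                                    # pad by 1 block
assert (target_n_blocks, padded_target_n_blocks, fc2_blocks_pad) == (24, 24, 1)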
@@ -1707,7 +1750,11 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
            graph.get_attr(new_key_fc1_alpha),
            graph.get_attr(new_key_fc2_alpha),
        )
        return args
        kwargs = {
            "is_gated_mlp": is_gated_mlp,
            "act_fn": act_fn,
        }
        return args, kwargs

    fused_key_counter = 0
    graph = gm.graph
@@ -1727,14 +1774,13 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
            w3_list,
            w1_input_scale,
            w2_input_scale,
            w3_input_scale,
            w1_weight_scale,
            w2_weight_scale,
            w3_weight_scale,
            w1_alpha,
            w2_alpha,
            w3_alpha,
            is_gated_mlp,
            act_fn,
        ) = _extract_op_args(node)

        # Stack the actual tensor values (fast, like in quantize_moe.py)
@@ -1746,7 +1792,6 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
        # Scales are buffers, not parameters
        w1_input_scale_stacked = _stack(w1_input_scale, dim=0)
        w2_input_scale_stacked = _stack(w2_input_scale, dim=0)
        w3_input_scale_stacked = _stack(w3_input_scale, dim=0, device=device, dtype=dtype)

        # Use .view() not .to() to reinterpret bytes as float8, not value conversion
        w1_weight_blockscale_fp8_stacked = _stack(w1_weight_scale, dim=0).view(torch.float8_e4m3fn)
@@ -1757,9 +1802,8 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:

        w1_alpha_stacked = _stack(w1_alpha, dim=0)
        w2_alpha_stacked = _stack(w2_alpha, dim=0)
        w3_alpha_stacked = _stack(w3_alpha, dim=0, device=device, dtype=dtype)

        args = _prepare_args_cutlass_format_nvfp4()
        args, kwargs = _prepare_args_cutlass_format_nvfp4()

        fused_key_counter += 1

@@ -1768,7 +1812,7 @@ def _stack_nvfp4_moe_weights(gm: GraphModule) -> int:
        new_node = graph.call_function(
            replacement_op,
            args,
            kwargs=node.kwargs,
            kwargs=kwargs,
        )

        node.replace_all_uses_with(new_node)

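The fused node is now built with an explicit kwargs dict instead of forwarding node.kwargs. For readers less familiar with torch.fx, a minimal, self-contained sketch of this build-rewire-erase pattern on a toy graph (toy module and ops, not the TensorRT-LLM transform itself):

import torch
from torch.fx import symbolic_trace

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.add(x, x)

gm = symbolic_trace(Toy())
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.add:
        with gm.graph.inserting_after(node):
            # Build the replacement call with explicit args/kwargs, then rewire all uses.
            new_node = gm.graph.call_function(torch.mul, node.args, kwargs={})
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)
gm.graph.lint()
gm.recompile()
assert torch.equal(gm(torch.ones(2)), torch.ones(2) * torch.ones(2))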
@@ -180,6 +180,9 @@ nvidia/Nemotron-MOE:
  - quant_algo: FP8
    kv_cache_quant_algo: FP8
    accuracy: 86.884
  - quant_algo: NVFP4
    kv_cache_quant_algo: FP8
    accuracy: 63.268
nvidia/Llama-3.1-Nemotron-Nano-8B-v1:
  - accuracy: 37.15
  - quant_algo: FP8

@@ -284,6 +284,9 @@ nvidia/Nemotron-MOE:
  - quant_algo: FP8
    kv_cache_quant_algo: FP8
    accuracy: 73.879
  - quant_algo: NVFP4
    kv_cache_quant_algo: FP8
    accuracy: 70.879
microsoft/Phi-4-mini-instruct:
  - accuracy: 68.98
  - quant_algo: FP8

@@ -14,6 +14,7 @@
# limitations under the License.

import pytest
from defs.conftest import skip_pre_blackwell
from test_common.llm_data import hf_id_to_local_model_dir, llm_models_root

from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
@@ -142,7 +143,7 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
    MODEL_NAME = "nvidia/Nemotron-MOE"
    MODEL_PATH_BF16 = f"{llm_models_root()}/Nemotron-Nano-3-30B-A3.5B-dev-1024"
    MODEL_PATH_FP8 = f"{llm_models_root()}/Nemotron-Nano-3-30B-A3.5B-FP8-KVFP8-dev"
    MODEL_PATH_NVFP4 = f"{llm_models_root()}/Nemotron-3-Nano-30B-A3B-NVFP4"
    MODEL_PATH_NVFP4 = f"{llm_models_root()}/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4"

    def get_default_kwargs(self):
        return {
@@ -220,11 +221,13 @@ class TestNemotronMOE(LlmapiAccuracyTestHarness):
        task = GSM8K(self.MODEL_NAME)
        task.evaluate(llm)

    @pytest.mark.skip(reason="NVFP4 model is not in the CI yet")
    def test_nvfp4(self):
    @skip_pre_blackwell
    @pytest.mark.parametrize("world_size", [1, 2, 4])
    def test_nvfp4(self, world_size):
        kwargs = self.get_default_kwargs()
        with AutoDeployLLM(model=self.MODEL_PATH_NVFP4,
                           tokenizer=self.MODEL_PATH_NVFP4,
                           world_size=world_size,
                           **kwargs) as llm:
            # Manually set quant_config for NVFP4 model to get the accuracy threshold
            llm.args.quant_config.quant_algo = QuantAlgo.NVFP4

@@ -218,3 +218,6 @@ l0_dgx_b200:
  - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_bf16
  - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[4]
  - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_fp8[8]
  - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_nvfp4[1]
  - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_nvfp4[2]
  - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_nvfp4[4]
@@ -10,6 +10,7 @@ from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale
from tensorrt_llm._torch.utils import ActivationType


class BlockSparseTop2MLP(nn.Module):
@@ -416,3 +417,236 @@ def test_fuse_moe_cleanup():
    assert mem_after <= mem_before, (
        f"CUDA memory increased after fusion: before={mem_before} after={mem_after}"
    )


class MoEOpModelNVFP4(nn.Module):
    """MoE model using torch_quant_nvfp4_moe op for testing fusion to trtllm_quant_nvfp4_moe_fused.

    This model creates weights with 3D block scales that are compatible with
    the trtllm fused MoE kernel.
    """

    def __init__(
        self,
        hidden_size=512,  # Already aligned to all requirements (16, 128, etc.)
        intermediate_size=256,  # Already aligned - no padding needed
        num_experts=3,
        top_k=2,
        dtype=torch.bfloat16,
        is_gated_mlp=True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_experts = num_experts
        self.top_k = top_k
        self.dtype = dtype
        self.is_gated_mlp = is_gated_mlp

        # Constants for NVFP4 layout
        NVFP4_BLOCK_SIZE = 16
        NVFP4_PACK_FACTOR = 2
        FLOAT8_E4M3_MAX = 448.0
        FLOAT4_E2M1_MAX = 6.0

        self.gate = nn.Linear(hidden_size, num_experts, dtype=dtype)

        # Create sample input for scale computation
        sample_input = torch.randn(2, hidden_size, dtype=dtype, device="cuda") * 0.01
        inp_scale = fp4_global_scale(sample_input)

        # Per-expert quantized weights and scales
        self.w1_weight = nn.ParameterList()
        self.w2_weight = nn.ParameterList()
        self.w3_weight = nn.ParameterList() if is_gated_mlp else None

        for i in range(num_experts):
            w1_fp32 = torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype) * 0.01
            w2_fp32 = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype) * 0.01

            # Compute global scales
            w1_amax = torch.abs(w1_fp32).max().to(torch.float32)
            w2_amax = torch.abs(w2_fp32).max().to(torch.float32)

            scale_1 = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
            scale_2 = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax

            # Quantize weights (non-swizzled layout)
            w1_fp4, w1_bs = torch.ops.trtllm.fp4_quantize(w1_fp32, scale_1, NVFP4_BLOCK_SIZE, False)
            w2_fp4, w2_bs = torch.ops.trtllm.fp4_quantize(w2_fp32, scale_2, NVFP4_BLOCK_SIZE, False)

            # fp4_quantize pads block scales but not weights - infer padded dims from block scale size
            _, w1_k_packed = w1_fp4.shape
            _, w2_k_packed = w2_fp4.shape
            w1_k_padded = w1_k_packed * NVFP4_PACK_FACTOR  # Convert from uint8 to FP4 element count
            w2_k_padded = w2_k_packed * NVFP4_PACK_FACTOR

            # Calculate padded N dimension from block scale tensor size
            w1_n_padded = w1_bs.numel() // (w1_k_padded // NVFP4_BLOCK_SIZE)
            w2_n_padded = w2_bs.numel() // (w2_k_padded // NVFP4_BLOCK_SIZE)

            # Reshape block scales to 3D format [N_padded, K/block]
            w1_bs_3d = w1_bs.view(w1_n_padded, w1_k_padded // NVFP4_BLOCK_SIZE)
            w2_bs_3d = w2_bs.view(w2_n_padded, w2_k_padded // NVFP4_BLOCK_SIZE)

            self.w1_weight.append(nn.Parameter(w1_fp4, requires_grad=False))
            self.w2_weight.append(nn.Parameter(w2_fp4, requires_grad=False))

            self.register_buffer(f"w1_input_scale_{i}", inp_scale)
            self.register_buffer(f"w2_input_scale_{i}", inp_scale)
            self.register_buffer(f"w1_weight_scale_{i}", w1_bs_3d.contiguous())
            self.register_buffer(f"w2_weight_scale_{i}", w2_bs_3d.contiguous())
            self.register_buffer(f"w1_alpha_{i}", 1.0 / (inp_scale * scale_1))
            self.register_buffer(f"w2_alpha_{i}", 1.0 / (inp_scale * scale_2))

            if is_gated_mlp:
                w3_fp32 = (
                    torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype) * 0.01
                )
                w3_amax = torch.abs(w3_fp32).max().to(torch.float32)
                scale_3 = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w3_amax
                w3_fp4, w3_bs = torch.ops.trtllm.fp4_quantize(
                    w3_fp32, scale_3, NVFP4_BLOCK_SIZE, False
                )

                # Infer padded dimensions for w3
                _, w3_k_packed = w3_fp4.shape
                w3_k_padded = w3_k_packed * NVFP4_PACK_FACTOR
                w3_n_padded = w3_bs.numel() // (w3_k_padded // NVFP4_BLOCK_SIZE)
                w3_bs_3d = w3_bs.view(w3_n_padded, w3_k_padded // NVFP4_BLOCK_SIZE)

                self.w3_weight.append(nn.Parameter(w3_fp4, requires_grad=False))
                self.register_buffer(f"w3_input_scale_{i}", inp_scale)
                self.register_buffer(f"w3_weight_scale_{i}", w3_bs_3d.contiguous())
                self.register_buffer(f"w3_alpha_{i}", 1.0 / (inp_scale * scale_3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        router_logits = self.gate(x)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(x.dtype)

        w1_list = list(self.w1_weight)
        w2_list = list(self.w2_weight)
        w3_list = list(self.w3_weight) if self.is_gated_mlp else []

        w1_input_scale = [getattr(self, f"w1_input_scale_{i}") for i in range(self.num_experts)]
        w2_input_scale = [getattr(self, f"w2_input_scale_{i}") for i in range(self.num_experts)]
        w3_input_scale = (
            [getattr(self, f"w3_input_scale_{i}") for i in range(self.num_experts)]
            if self.is_gated_mlp
            else []
        )
        w1_weight_scale = [getattr(self, f"w1_weight_scale_{i}") for i in range(self.num_experts)]
        w2_weight_scale = [getattr(self, f"w2_weight_scale_{i}") for i in range(self.num_experts)]
        w3_weight_scale = (
            [getattr(self, f"w3_weight_scale_{i}") for i in range(self.num_experts)]
            if self.is_gated_mlp
            else []
        )
        w1_alpha = [getattr(self, f"w1_alpha_{i}") for i in range(self.num_experts)]
        w2_alpha = [getattr(self, f"w2_alpha_{i}") for i in range(self.num_experts)]
        w3_alpha = (
            [getattr(self, f"w3_alpha_{i}") for i in range(self.num_experts)]
            if self.is_gated_mlp
            else []
        )

        out = torch.ops.auto_deploy.torch_quant_nvfp4_moe(
            x,
            selected_experts,
            routing_weights,
            w1_list,
            w2_list,
            w3_list,
            w1_input_scale,
            w2_input_scale,
            w3_input_scale,
            w1_weight_scale,
            w2_weight_scale,
            w3_weight_scale,
            w1_alpha,
            w2_alpha,
            w3_alpha,
            is_gated_mlp=self.is_gated_mlp,
            act_fn=ActivationType.Silu if self.is_gated_mlp else ActivationType.Relu2,
        )
        return out

    def get_input(self, device, dtype=torch.bfloat16):
        return torch.randn(2, self.hidden_size, device=device, dtype=dtype) * 0.01


@pytest.mark.skipif(
    not fp4_compatible() or not trtllm_ops_available(),
    reason="Requires FP4 + TRTLLM support",
)
@pytest.mark.parametrize("is_gated_mlp", [True, False], ids=["gated_mlp", "mlp"])
@pytest.mark.parametrize(
    "hidden_size,intermediate_size",
    [
        (512, 256),  # Standard aligned dimensions
        (1024, 512),  # Larger aligned dimensions
        (768, 384),  # Common transformer dimensions (divisible by 16)
        (512, 128),  # Smaller intermediate
        (256, 256),  # Equal dimensions
    ],
    ids=["512x256", "1024x512", "768x384", "512x128", "256x256"],
)
def test_nvfp4_moe_fusion(is_gated_mlp, hidden_size, intermediate_size):
    """Test that torch_quant_nvfp4_moe fuses to trtllm_quant_nvfp4_moe_fused.

    Note: This test uses swizzled block scales that are compatible with the fused trtllm kernel.
    The non-fused op (torch_quant_nvfp4_moe) uses a different internal path that expects
    non-swizzled scales, so we don't compare outputs between non-fused and fused.
    Instead, we verify the fusion transformation works correctly and produces valid output.

    Tests both gated MLP (with w3) and non-gated MLP (without w3) variants with various
    hidden_size and intermediate_size configurations.
    """
    device = "cuda"
    dtype = torch.bfloat16
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)

    model = MoEOpModelNVFP4(
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        dtype=dtype,
        is_gated_mlp=is_gated_mlp,
    ).to(device=device)
    x = model.get_input(device=device, dtype=dtype)

    # Export to GraphModule
    gm = torch_export_to_gm(model, args=(x,), clone=True)

    # Verify non-fused op is present before fusion
    has_nonfused = any(
        is_op(n, torch.ops.auto_deploy.torch_quant_nvfp4_moe) for n in gm.graph.nodes
    )
    assert has_nonfused, "Expected torch_quant_nvfp4_moe op before fusion"

    # Apply NVFP4 MoE fusion
    gm_transformed = InferenceOptimizer(
        None,
        {
            "fuse_nvfp4_moe": {
                "stage": "post_load_fusion",
            },
        },
    )(None, gm)

    # Verify fused op is present after fusion
    has_fused = any(
        is_op(n, torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused)
        for n in gm_transformed.graph.nodes
    )
    assert has_fused, "Expected trtllm_quant_nvfp4_moe_fused op after fusion"

    # Run fused graph to verify it produces valid output (not NaN/Inf)
    with torch.inference_mode():
        fused_output = gm_transformed(x)

    assert not torch.isnan(fused_output).any(), "Fused output contains NaN"
    assert not torch.isinf(fused_output).any(), "Fused output contains Inf"
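The test model's scale arithmetic follows the formulas in its __init__: a global scale of FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax per tensor, and alpha = 1 / (input_scale * weight_scale). A small numeric sketch of that relationship; the activation scale in the test actually comes from fp4_global_scale, and is recomputed here with the same formula purely for illustration:

FLOAT8_E4M3_MAX = 448.0  # largest finite FP8 (e4m3) magnitude, range of the block scales
FLOAT4_E2M1_MAX = 6.0    # largest FP4 (e2m1) magnitude

def global_scale(amax: float) -> float:
    # Same formula as scale_1 / scale_2 / scale_3 in MoEOpModelNVFP4 above.
    return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax

weight_scale = global_scale(0.02)           # 134400.0 for a weight amax of 0.02
input_scale = global_scale(0.01)            # 268800.0 for an activation amax of 0.01
alpha = 1.0 / (input_scale * weight_scale)  # combined rescale factor, as in the w*_alpha_* buffers
assert round(weight_scale) == 134400 and round(input_scale) == 268800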