[None][feat] AutoDeploy: Add FP8 MOE for Nemotron (#8599)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>
Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Chenghao Zhang committed 2025-10-25 12:26:45 -07:00 (committed by GitHub)
commit a6d20f6f9b (parent 95be56e56b)
7 changed files with 899 additions and 176 deletions


@@ -116,6 +116,9 @@ transforms:
fuse_moe:
stage: post_load_fusion
enabled: true
fuse_fp8_moe:
stage: post_load_fusion
enabled: true
fuse_allreduce_residual_rmsnorm:
stage: post_load_fusion
# TODO (lucaslie): add backend selection as part of configurable inference optimizers


@@ -228,6 +228,8 @@ def torch_quant_fp8_moe(
w1_weight_scale: List[torch.Tensor],
w2_weight_scale: List[torch.Tensor],
w3_weight_scale: List[torch.Tensor],
mlp_style: str = "gated_mlp", # "gated_mlp" (default) or "mlp"
act_fn: str = "silu", # silu or relu2
) -> torch.Tensor:
"""
FP8 MoE op using quantized linear operations.
@@ -239,40 +241,91 @@ def torch_quant_fp8_moe(
x: Input tensor of shape (B, H) or (B, S, H).
selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
routing_weights: Tensor of normalized routing weights.
w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops.
w1_weight:
List of per-expert weight tensors:
mlp_style=="gated_mlp": W1 with shape (I, H) "gate" projection.
mlp_style=="mlp": W_up with shape (I, H) up projection.
w2_weight:
List of per-expert weight tensors:
gated_mlp: W2 with shape (H, I) down projection.
mlp: W_down with shape (H, I) down projection.
w3_weight:
List of per-expert weight tensors:
gated_mlp: W3 with shape (I, H) "up" (second) projection in gated MLP.
mlp: pass an empty list []; ignored.
w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors for the corresponding ops.
w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors for the corresponding ops.
mlp_style:
Selects the per-expert MLP computation:
"gated_mlp" (default, Mixtral/DeepSeek-style):
y = W2( act(W1 x) * (W3 x) )
"mlp" (NemotronH-style 2-layer MLP):
y = W_down( act(W_up x) )
act_fn:
Elementwise activation applied inside the expert MLP.
Supported: "silu" (default), "relu2" (ReLU then square).
"""
def make_fp8_mlp(i):
def mlp(inp):
gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
inp,
w1_weight[i],
bias=None,
input_scale=w1_input_scale[i],
weight_scale=w1_weight_scale[i],
)
up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
inp,
w3_weight[i],
bias=None,
input_scale=w3_input_scale[i],
weight_scale=w3_weight_scale[i],
)
prod = F.silu(gate_out) * up_out
return torch.ops.auto_deploy.torch_quant_fp8_linear(
prod,
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
weight_scale=w2_weight_scale[i],
)
act_fn = _resolve_activation(act_fn)
style = mlp_style.lower()
return mlp
if style == "gated_mlp":
def make_fp8_mlp(i):
def mlp(inp):
gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
inp,
w1_weight[i],
bias=None,
input_scale=w1_input_scale[i],
weight_scale=w1_weight_scale[i],
)
up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
inp,
w3_weight[i],
bias=None,
input_scale=w3_input_scale[i],
weight_scale=w3_weight_scale[i],
)
prod = act_fn(gate_out) * up_out
return torch.ops.auto_deploy.torch_quant_fp8_linear(
prod,
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
weight_scale=w2_weight_scale[i],
)
return mlp
mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
elif style == "mlp":
def make_fp8_mlp(i):
def mlp(inp):
up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
inp,
w1_weight[i],
bias=None,
input_scale=w1_input_scale[i],
weight_scale=w1_weight_scale[i],
)
return torch.ops.auto_deploy.torch_quant_fp8_linear(
act_fn(up_out),
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
weight_scale=w2_weight_scale[i],
)
return mlp
mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
return _template_moe(x, selected_experts, routing_weights, mlps)
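
The docstring above defines the two mlp_style variants algebraically. As a minimal unquantized sketch (the helper names expert_gated_mlp and expert_mlp are illustrative, not part of this PR), the per-expert computations are:

import torch
import torch.nn.functional as F

def expert_gated_mlp(x, w1, w2, w3, act=F.silu):
    # "gated_mlp" (Mixtral/DeepSeek-style): y = W2( act(W1 x) * (W3 x) )
    return F.linear(act(F.linear(x, w1)) * F.linear(x, w3), w2)

def expert_mlp(x, w_up, w_down, act=lambda t: torch.square(F.relu(t))):
    # "mlp" (NemotronH-style): y = W_down( act(W_up x) ); act here is relu2
    return F.linear(act(F.linear(x, w_up)), w_down)

The quantized op applies the same per-expert structure, with each F.linear replaced by torch_quant_fp8_linear and its scales.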
@@ -290,6 +343,8 @@ def torch_quant_fp8_moe_fake(
w1_weight_scale: List[torch.Tensor],
w2_weight_scale: List[torch.Tensor],
w3_weight_scale: List[torch.Tensor],
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
return torch.empty_like(x)
@@ -311,6 +366,8 @@ def torch_quant_nvfp4_moe(
w1_alpha: List[torch.Tensor],
w2_alpha: List[torch.Tensor],
w3_alpha: List[torch.Tensor],
mlp_style: str = "gated_mlp", # "gated_mlp" (default) or "mlp"
act_fn: str = "silu", # silu or relu2
) -> torch.Tensor:
"""
FP4 MoE op using quantized linear operations.
@@ -322,45 +379,101 @@ def torch_quant_nvfp4_moe(
x: Input tensor of shape (B, H) or (B, S, H).
selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
routing_weights: Tensor of normalized routing weights.
w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops.
w1_weight:
List of per-expert weight tensors:
mlp_style=="gated_mlp": W1 with shape (I, H) "gate" projection.
mlp_style=="mlp": W_up with shape (I, H) up projection.
w2_weight:
List of per-expert weight tensors:
gated_mlp: W2 with shape (H, I) down projection.
mlp: W_down with shape (H, I) down projection.
w3_weight:
List of per-expert weight tensors:
gated_mlp: W3 with shape (I, H) "up" (second) projection in gated MLP.
mlp: pass an empty list []; ignored.
w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors.
w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors.
w1_alpha, w2_alpha, w3_alpha: Lists of alpha scale tensors for FP4 quantization.
mlp_style:
Selects the per-expert MLP computation:
"gated_mlp" (default, Mixtral/DeepSeek-style):
y = W2( act(W1 x) * (W3 x) )
"mlp" (NemotronH-style 2-layer MLP):
y = W_down( act(W_up x) )
act_fn:
Elementwise activation applied inside the expert MLP.
Supported: "silu" (default), "relu2" (ReLU then square).
"""
def make_fp4_mlp(i):
def mlp(inp):
if inp.shape[0] == 0:
return torch.zeros_like(inp)
gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
inp,
w1_weight[i],
bias=None,
input_scale=w1_input_scale[i],
weight_scale=w1_weight_scale[i],
alpha=w1_alpha[i],
)
up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
inp,
w3_weight[i],
bias=None,
input_scale=w3_input_scale[i],
weight_scale=w3_weight_scale[i],
alpha=w3_alpha[i],
)
prod = F.silu(gate_out) * up_out
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
prod,
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
weight_scale=w2_weight_scale[i],
alpha=w2_alpha[i],
)
act_fn = _resolve_activation(act_fn)
style = mlp_style.lower()
return mlp
if style == "gated_mlp":
def make_fp4_mlp(i):
def mlp(inp):
if inp.shape[0] == 0:
return torch.zeros_like(inp)
gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
inp,
w1_weight[i],
bias=None,
input_scale=w1_input_scale[i],
weight_scale=w1_weight_scale[i],
alpha=w1_alpha[i],
)
up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
inp,
w3_weight[i],
bias=None,
input_scale=w3_input_scale[i],
weight_scale=w3_weight_scale[i],
alpha=w3_alpha[i],
)
prod = act_fn(gate_out) * up_out
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
prod,
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
weight_scale=w2_weight_scale[i],
alpha=w2_alpha[i],
)
return mlp
mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
elif style == "mlp":
def make_fp4_mlp(i):
def mlp(inp):
if inp.shape[0] == 0:
return torch.zeros_like(inp)
up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
inp,
w1_weight[i],
bias=None,
input_scale=w1_input_scale[i],
weight_scale=w1_weight_scale[i],
alpha=w1_alpha[i],
)
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
act_fn(up_out),
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
weight_scale=w2_weight_scale[i],
alpha=w2_alpha[i],
)
return mlp
mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
return _template_moe(x, selected_experts, routing_weights, mlps)
@@ -381,6 +494,8 @@ def torch_quant_nvfp4_moe_fake(
w1_alpha: List[torch.Tensor],
w2_alpha: List[torch.Tensor],
w3_alpha: List[torch.Tensor],
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
return torch.empty_like(x)
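
Both quantized MoE ops are paired with *_fake variants that only report the output metadata, which is what lets export and compile trace the graph without running the kernels; the new mlp_style and act_fn arguments therefore have to be mirrored on the fake signatures. A generic sketch of the pattern, assuming an illustrative op name (example::scaled_matmul is not part of this PR):

import torch

@torch.library.custom_op("example::scaled_matmul", mutates_args=())
def scaled_matmul(x: torch.Tensor, w: torch.Tensor, scale: float) -> torch.Tensor:
    return (x @ w.t()) * scale

@scaled_matmul.register_fake
def _(x: torch.Tensor, w: torch.Tensor, scale: float) -> torch.Tensor:
    # Tracing only needs shape/dtype/device, never the real computation.
    return x.new_empty(x.shape[:-1] + (w.shape[0],))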


@@ -159,6 +159,130 @@ def fused_mlp_moe_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit
def fused_mlp_moe_kernel_w8a8(
# Pointers to matrices (A in FP8, B in FP8)
a_ptr,
b_ptr,
c_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# Strides
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
# Scale pointers
a_scale_ptr,
b_scale_ptr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
token_id_mask = offs_token_id < EM
offs_token = tl.load(
sorted_token_ids_ptr + offs_token_id, mask=token_id_mask, other=num_valid_tokens
)
token_mask = offs_token < num_valid_tokens
# Clamp offs_token to valid range to avoid out-of-bounds pointer arithmetic
# Padding tokens have value >= num_valid_tokens and will be masked out
# Clamp to last valid token instead of 0 to avoid cache/memory issues
max_valid_token = num_valid_tokens - 1
offs_token_clamped = tl.where(token_mask, offs_token, max_valid_token)
# Expert id for this block (one expert per M-tile)
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
if off_experts == -1:
_write_zeros_to_output(
c_ptr,
stride_cm,
stride_cn,
pid_n,
N,
offs_token_clamped,
token_mask,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
compute_type,
)
return
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_token_clamped[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
# Load tensor-wise scales before loop
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(
b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0,
)
# Use acc= for FP8 fast accumulation (matches vLLM)
accumulator = tl.dot(a, b, acc=accumulator)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# Apply scales after K-loop
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token_clamped, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token_clamped[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
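
The W8A8 kernel accumulates FP8 x FP8 products in FP32 and applies the tensor-wise scales once after the K-loop. A small hedged sketch of the same dequantization algebra in plain PyTorch (no Triton, tensor names are illustrative) shows why a single a_scale * b_scale multiply recovers the real-valued product:

import torch

# If A_q ~ A / a_scale and B_q ~ B / b_scale, then (A_q @ B_q) * a_scale * b_scale ~ A @ B.
A = torch.randn(64, 128) * 0.1
B = torch.randn(128, 32) * 0.1
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

a_scale = A.abs().max() / FP8_MAX
b_scale = B.abs().max() / FP8_MAX
A_q = (A / a_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
B_q = (B / b_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

# Emulate the kernel: accumulate in higher precision, rescale once at the end.
approx = (A_q.to(torch.float32) @ B_q.to(torch.float32)) * a_scale * b_scale
print((approx - A @ B).abs().max())  # small quantization error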
def _default_kernel_config(M: int, E: int, N: int, K: int, top_k: int) -> dict:
if M <= E:
return {
@@ -245,12 +369,15 @@ def _invoke_kernel(
topk_weights: torch.Tensor | None,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor, # Changed to tensor for CUDA graph compatibility
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: dict,
compute_type,
a_scale: torch.Tensor | None = None,
b_scale: torch.Tensor | None = None,
):
"""Unified kernel launcher for both unquantized and FP8 W8A8 MoE kernels."""
assert B.ndim == 3 and C.ndim == 3
EM = sorted_token_ids.numel()
if EM == 0:
@@ -268,7 +395,7 @@
)
num_tokens = A.size(0) * top_k
fused_mlp_moe_kernel[_grid](
common_args = [
A,
B,
C,
@@ -287,112 +414,96 @@ def _invoke_kernel(
B.stride(1),
C.stride(1),
C.stride(2),
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
GROUP_SIZE_M=config["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
num_warps=config["num_warps"],
num_stages=config["num_stages"],
)
]
common_kwargs = {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
"MUL_ROUTED_WEIGHT": mul_routed_weight,
"top_k": top_k,
"compute_type": compute_type,
"num_warps": config["num_warps"],
"num_stages": config["num_stages"],
}
if a_scale is not None and b_scale is not None:
# FP8 W8A8 path
fused_mlp_moe_kernel_w8a8[_grid](*common_args, a_scale, b_scale, **common_kwargs)
else:
# Unquantized path
fused_mlp_moe_kernel[_grid](*common_args, **common_kwargs)
def fused_mlp_relu2_unquantized(
hidden_states: torch.Tensor, # [M, H]
w_up: torch.Tensor, # [E, I, H]
w_down: torch.Tensor, # [E, H, I]
topk_ids: torch.Tensor, # [M, top_k]
topk_weights: torch.Tensor, # [M, top_k]
*,
def _get_compute_type(dtype: torch.dtype):
"""Get Triton compute type from torch dtype."""
if dtype == torch.bfloat16:
return tl.bfloat16
elif dtype == torch.float16:
return tl.float16
elif dtype == torch.float32:
return tl.float32
else:
raise ValueError(f"Unsupported dtype: {dtype}")
def _fused_moe_mlp_relu2(
hidden_states: torch.Tensor,
w_up: torch.Tensor,
w_down: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
Fast unquantized MoE MLP with ReLU^2 activation between two per-expert GEMMs.
Requirements:
- w_up: (E, I, H) with last dim contiguous
- w_down: (E, H, I) with last dim contiguous
- hidden_states: (M, H), topk_ids/topk_weights: (M, top_k)
"""
assert hidden_states.ndim == 2
assert w_up.ndim == 3 and w_down.ndim == 3
assert topk_ids.ndim == 2 and topk_weights.ndim == 2
"""Fused MoE 2-layer MLP with ReLU^2 activation using Triton."""
M, H = hidden_states.shape
E, inter_size, H_up = w_up.shape
E2, H_down, inter_size2 = w_down.shape
assert E == E2 and H == H_up and H == H_down and inter_size == inter_size2
E, inter_size, _ = w_up.shape
top_k = topk_ids.shape[1]
# Ensure memory layout compatible with kernel expectations
A = hidden_states.contiguous()
B1 = w_up.contiguous() # (E, I, H)
B2 = w_down.contiguous() # (E, H, I)
# Kernel config (use a single BLOCK_SIZE_M for both GEMMs)
config = _default_kernel_config(M, E, inter_size, H, top_k)
# Token routing packing (group-by-expert, pad to BLOCK_SIZE_M)
sorted_token_ids, expert_ids, num_tokens_post_padded = _pack_routed_tokens(
topk_ids,
M,
E,
top_k,
config["BLOCK_SIZE_M"],
topk_ids, M, E, top_k, config["BLOCK_SIZE_M"]
)
# Workspaces
cache1 = A.new_empty((M, top_k, inter_size))
cache2 = A.new_empty((M * top_k, inter_size))
cache3 = A.new_empty((M, top_k, H))
cache1 = hidden_states.new_empty((M, top_k, inter_size))
cache3 = hidden_states.new_empty((M, top_k, H))
compute_type = _get_compute_type(hidden_states.dtype)
# Compute type
if A.dtype == torch.bfloat16:
compute_type = tl.bfloat16
elif A.dtype == torch.float16:
compute_type = tl.float16
elif A.dtype == torch.float32:
compute_type = tl.float32
else:
raise ValueError(f"Unsupported dtype for hidden_states: {A.dtype}")
# GEMM 1: X @ W_up^T → cache1 (no routing weights here)
# GEMM 1: hidden @ w_up^T
_invoke_kernel(
A,
B1,
hidden_states.contiguous(),
w_up.contiguous(),
cache1,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
mul_routed_weight=False,
top_k=top_k,
config=config,
compute_type=compute_type,
False,
top_k,
config,
compute_type,
)
# Activation (ReLU^2) without gating/multiplication
cache2 = torch.square(F.relu(cache1.view(-1, inter_size)))
# Activation: ReLU^2
act = torch.square(F.relu(cache1.view(-1, inter_size)))
# GEMM 2: Act(cache1) @ W_down^T → cache3 (apply routing weights)
# GEMM 2: act @ w_down^T
_invoke_kernel(
cache2,
B2,
act,
w_down.contiguous(),
cache3,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
mul_routed_weight=not apply_router_weight_on_input,
top_k=1, # ensure offs_token maps to flattened rows (m*top_k + n)
config=config,
compute_type=compute_type,
not apply_router_weight_on_input,
1,
config,
compute_type,
)
# Sum across top-k per token
out = cache3.sum(dim=1)
return out
return cache3.sum(dim=1)
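
For intuition, the routed computation that _fused_moe_mlp_relu2 implements with two grouped GEMMs can be written as a slow per-expert loop. This is a hedged, unfused reference sketch only (moe_mlp_relu2_reference is not part of this PR) and not the code path used at runtime:

import torch
import torch.nn.functional as F

def moe_mlp_relu2_reference(x, w_up, w_down, topk_ids, topk_weights):
    # x: [M, H], w_up: [E, I, H], w_down: [E, H, I], topk_*: [M, top_k]
    out = torch.zeros(x.shape[0], w_down.shape[1], dtype=x.dtype, device=x.device)
    for e in range(w_up.shape[0]):
        token_idx, slot_idx = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = torch.square(F.relu(F.linear(x[token_idx], w_up[e])))  # ReLU^2
        y = F.linear(h, w_down[e])
        w = topk_weights[token_idx, slot_idx].unsqueeze(-1).to(y.dtype)
        out.index_add_(0, token_idx, y * w)  # weight and sum over top-k
    return out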
@torch.library.custom_op("auto_deploy::triton_moe_fused", mutates_args=())
@@ -403,36 +514,13 @@ def triton_fused_moe(
w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
"""
Triton implementation of the Fused MOE ops for Nemotron-6 models
Each expert has two weight matrices and squared ReLU activation between them.
"""
"""Triton unquantized MoE with 2-layer MLP and ReLU^2 activation."""
x_shape = x.shape
hidden_size = x_shape[-1]
x2d = x.view(-1, hidden_size)
x2d = x.view(-1, x_shape[-1])
topk_ids = selected_experts.to(torch.int32).contiguous()
topk_weights = routing_weights.to(torch.float32).contiguous()
routing_weights = routing_weights.to(torch.float32)
selected_experts = selected_experts.to(torch.int32)
# Expect selected_experts/routing_weights to be [M, top_k]
topk_ids = selected_experts.contiguous()
topk_weights = routing_weights.contiguous()
assert topk_ids.dim() == 2 and topk_weights.dim() == 2, (
f"Expected 2D routing tensors, got {topk_ids.shape} and {topk_weights.shape}"
)
assert topk_ids.shape[0] == x2d.shape[0], (
f"Token count mismatch: tokens={x2d.shape[0]} ids={topk_ids.shape[0]}"
)
out2d = fused_mlp_relu2_unquantized(
x2d,
w1_stacked_weight,
w2_stacked_weight,
topk_ids,
topk_weights,
apply_router_weight_on_input=False,
)
out2d = _fused_moe_mlp_relu2(x2d, w1_stacked_weight, w2_stacked_weight, topk_ids, topk_weights)
return out2d.view(x_shape)
@@ -445,3 +533,120 @@ def triton_fused_moe(
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(x)
def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""Quantize tensor to FP8 with clamping (matches torch_quant_fp8_linear)."""
FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
@torch.library.custom_op("auto_deploy::triton_quant_fp8_moe", mutates_args=())
def triton_quant_fp8_moe(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w1_weight: torch.Tensor, # [E, I, H] stacked FP8 weights
w2_weight: torch.Tensor, # [E, H, I] stacked FP8 weights
w3_weight: torch.Tensor, # unused for mlp style
w1_input_scale: torch.Tensor, # [E] stacked input scales
w2_input_scale: torch.Tensor, # [E] stacked input scales
w3_input_scale: torch.Tensor, # unused
w1_weight_scale: torch.Tensor, # [E] stacked weight scales
w2_weight_scale: torch.Tensor, # [E] stacked weight scales
w3_weight_scale: torch.Tensor, # unused
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
"""Triton FP8 W8A8 MoE with 2-layer MLP and ReLU^2 activation."""
if mlp_style != "mlp":
raise NotImplementedError("triton_quant_fp8_moe currently supports mlp_style=='mlp' only")
x_shape = x.shape
x2d = x.view(-1, x_shape[-1])
topk_ids = selected_experts.to(torch.int32).contiguous()
topk_weights = routing_weights.to(torch.float32).contiguous()
# Weights are already stacked [E, ...] - just ensure contiguous and extract scales
w1_q = w1_weight.contiguous()
w2_q = w2_weight.contiguous()
a1_scale = w1_input_scale[0].to(torch.float32).reshape(1).contiguous()
a2_scale = w2_input_scale[0].to(torch.float32).reshape(1).contiguous()
b1_scale = w1_weight_scale.to(torch.float32).contiguous()
b2_scale = w2_weight_scale.to(torch.float32).contiguous()
# Setup
M, H = x2d.shape
E, inter_size, _ = w1_q.shape
top_k = topk_ids.shape[1]
config = _default_kernel_config(M, E, inter_size, H, top_k)
sorted_token_ids, expert_ids, num_tokens_post_padded = _pack_routed_tokens(
topk_ids, M, E, top_k, config["BLOCK_SIZE_M"]
)
compute_type = _get_compute_type(x2d.dtype)
# Quantize input and allocate caches
x_a8 = _quantize_fp8(x2d, a1_scale)
cache1 = x2d.new_empty((M, top_k, inter_size))
cache3 = x2d.new_empty((M, top_k, H))
# GEMM 1: FP8 input @ FP8 w_up^T → BF16
_invoke_kernel(
x_a8,
w1_q,
cache1,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
top_k,
config,
compute_type,
a_scale=a1_scale,
b_scale=b1_scale,
)
# Activation: ReLU^2, then quantize
act = torch.square(F.relu(cache1.view(-1, inter_size)))
act_a8 = _quantize_fp8(act, a2_scale)
# GEMM 2: FP8 activation @ FP8 w_down^T → BF16
_invoke_kernel(
act_a8,
w2_q,
cache3,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type,
a_scale=a2_scale,
b_scale=b2_scale,
)
return cache3.sum(dim=1).view(x_shape)
@triton_quant_fp8_moe.register_fake
def triton_quant_fp8_moe(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w1_weight: torch.Tensor,
w2_weight: torch.Tensor,
w3_weight: torch.Tensor,
w1_input_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
w3_input_scale: torch.Tensor,
w1_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w3_weight_scale: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
return torch.empty_like(x)


@@ -50,7 +50,6 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0)
new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}"
# Triton fused MoE op supports mlp only.
replacement_op = torch.ops.auto_deploy.triton_moe_fused
else:
@@ -567,6 +566,176 @@ class MatchNVFP4MoePattern(MatchMoePattern):
return ["input_scale", "weight_scale", "alpha"]
def _stack_fp8_moe_weights(gm: GraphModule) -> int:
"""
Stack per-expert FP8 weights and scales by materializing stacked tensors as parameters.
This is fast because we directly stack the tensor values (not graph nodes).
Similar to _insert_fused_moe_ops but for quantized MoE.
"""
fused_key_counter = 0
graph = gm.graph
for node in graph.nodes:
if not is_op(node, torch.ops.auto_deploy.torch_quant_fp8_moe):
continue
# Extract weight and scale lists from args
try:
(
hidden_states,
selected_experts,
routing_weights,
w1_list,
w2_list,
w3_list,
w1_input_scale,
w2_input_scale,
w3_input_scale,
w1_weight_scale,
w2_weight_scale,
w3_weight_scale,
) = extract_op_args(
node,
"x",
"selected_experts",
"routing_weights",
"w1_weight",
"w2_weight",
"w3_weight",
"w1_input_scale",
"w2_input_scale",
"w3_input_scale",
"w1_weight_scale",
"w2_weight_scale",
"w3_weight_scale",
)
except Exception:
continue
# Helper to get parameter or buffer
def get_param_or_buffer(target):
"""Get parameter or buffer by target name."""
try:
return gm.get_parameter(target)
except AttributeError:
# It's a buffer, not a parameter
parts = target.rsplit(".", 1)
if len(parts) == 2:
mod = gm.get_submodule(parts[0])
return getattr(mod, parts[1])
else:
return getattr(gm, target)
# Stack the actual tensor values (fast, like in quantize_moe.py)
w1_stacked = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0)
w2_stacked = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0)
w3_stacked = (
torch.stack([gm.get_parameter(n.target) for n in w3_list], dim=0)
if w3_list
else torch.empty(0, device=w1_stacked.device, dtype=w1_stacked.dtype)
)
# Scales are buffers, not parameters
w1_input_scale_stacked = torch.stack(
[get_param_or_buffer(n.target) for n in w1_input_scale], dim=0
)
w2_input_scale_stacked = torch.stack(
[get_param_or_buffer(n.target) for n in w2_input_scale], dim=0
)
w3_input_scale_stacked = (
torch.stack([get_param_or_buffer(n.target) for n in w3_input_scale], dim=0)
if w3_input_scale
else torch.empty(
0, device=w1_input_scale_stacked.device, dtype=w1_input_scale_stacked.dtype
)
)
w1_weight_scale_stacked = torch.stack(
[get_param_or_buffer(n.target) for n in w1_weight_scale], dim=0
)
w2_weight_scale_stacked = torch.stack(
[get_param_or_buffer(n.target) for n in w2_weight_scale], dim=0
)
w3_weight_scale_stacked = (
torch.stack([get_param_or_buffer(n.target) for n in w3_weight_scale], dim=0)
if w3_weight_scale
else torch.empty(
0, device=w1_weight_scale_stacked.device, dtype=w1_weight_scale_stacked.dtype
)
)
# Register stacked tensors as new parameters
new_key_w1 = f"quant_moe_w1_stacked_{fused_key_counter}"
new_key_w2 = f"quant_moe_w2_stacked_{fused_key_counter}"
new_key_w3 = f"quant_moe_w3_stacked_{fused_key_counter}"
new_key_w1_input_scale = f"quant_moe_w1_input_scale_stacked_{fused_key_counter}"
new_key_w2_input_scale = f"quant_moe_w2_input_scale_stacked_{fused_key_counter}"
new_key_w3_input_scale = f"quant_moe_w3_input_scale_stacked_{fused_key_counter}"
new_key_w1_weight_scale = f"quant_moe_w1_weight_scale_stacked_{fused_key_counter}"
new_key_w2_weight_scale = f"quant_moe_w2_weight_scale_stacked_{fused_key_counter}"
new_key_w3_weight_scale = f"quant_moe_w3_weight_scale_stacked_{fused_key_counter}"
fused_key_counter += 1
# Register as parameters (not buffers, to match the original per-expert params)
gm.register_parameter(new_key_w1, torch.nn.Parameter(w1_stacked, requires_grad=False))
gm.register_parameter(new_key_w2, torch.nn.Parameter(w2_stacked, requires_grad=False))
gm.register_parameter(new_key_w3, torch.nn.Parameter(w3_stacked, requires_grad=False))
gm.register_parameter(
new_key_w1_input_scale, torch.nn.Parameter(w1_input_scale_stacked, requires_grad=False)
)
gm.register_parameter(
new_key_w2_input_scale, torch.nn.Parameter(w2_input_scale_stacked, requires_grad=False)
)
gm.register_parameter(
new_key_w3_input_scale, torch.nn.Parameter(w3_input_scale_stacked, requires_grad=False)
)
gm.register_parameter(
new_key_w1_weight_scale,
torch.nn.Parameter(w1_weight_scale_stacked, requires_grad=False),
)
gm.register_parameter(
new_key_w2_weight_scale,
torch.nn.Parameter(w2_weight_scale_stacked, requires_grad=False),
)
gm.register_parameter(
new_key_w3_weight_scale,
torch.nn.Parameter(w3_weight_scale_stacked, requires_grad=False),
)
# Create new node with get_attr for stacked parameters
with graph.inserting_before(node):
new_node = graph.call_function(
torch.ops.auto_deploy.triton_quant_fp8_moe,
args=(
hidden_states,
selected_experts,
routing_weights,
graph.get_attr(new_key_w1),
graph.get_attr(new_key_w2),
graph.get_attr(new_key_w3),
graph.get_attr(new_key_w1_input_scale),
graph.get_attr(new_key_w2_input_scale),
graph.get_attr(new_key_w3_input_scale),
graph.get_attr(new_key_w1_weight_scale),
graph.get_attr(new_key_w2_weight_scale),
graph.get_attr(new_key_w3_weight_scale),
),
kwargs=node.kwargs,
)
node.replace_all_uses_with(new_node)
graph.erase_node(node)
# Clean up after processing all nodes
# eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules
# will remove the parameters/buffers that are no longer referenced
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
return fused_key_counter
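
A side note on the get_param_or_buffer helper inside _stack_fp8_moe_weights: the stock nn.Module accessors can express the same parameter-or-buffer lookup. This is only an alternative sketch under that assumption, not the code used above:

import torch
from torch import nn

def get_param_or_buffer(module: nn.Module, target: str) -> torch.Tensor:
    try:
        return module.get_parameter(target)
    except AttributeError:
        # Not a parameter; fall back to the registered buffer of the same name.
        return module.get_buffer(target)

m = nn.Linear(4, 4)
m.register_buffer("input_scale", torch.tensor(1.0))
print(get_param_or_buffer(m, "weight").shape, get_param_or_buffer(m, "input_scale"))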
@TransformRegistry.register("fuse_moe")
class FuseMoe(BaseTransform):
"""
@@ -588,3 +757,29 @@ class FuseMoe(BaseTransform):
skipped=False, num_matches=fused_key_counter, is_clean=False, has_valid_shapes=False
)
return gm, info
@TransformRegistry.register("fuse_fp8_moe")
class FuseFP8Moe(BaseTransform):
"""
Stack per-expert FP8 MoE weights and scales to avoid runtime stacking overhead.
This runs after weights are loaded, similar to FuseMoe for unquantized MoE.
"""
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
with cuda_memory_tracker():
fused_key_counter = _stack_fp8_moe_weights(gm)
info = TransformInfo(
skipped=(fused_key_counter == 0),
num_matches=fused_key_counter,
is_clean=False,
has_valid_shapes=False,
)
return gm, info


@@ -16,11 +16,6 @@ from .quantization import (
Quantization,
)
quantized_moe_op_map = {
"FP8": torch.ops.auto_deploy.torch_quant_fp8_moe,
"NVFP4": torch.ops.auto_deploy.torch_quant_nvfp4_moe,
}
def _quantize_moe_node(
gm: GraphModule,
@@ -92,11 +87,33 @@ def _quantize_moe_node(
s1, s2, s3 = collect_scales(idx)
args.extend([s1, s2, s3])
# Extract mlp_style and act_fn from the original node
# These can be in args[6:] or in kwargs
mlp_style = "gated_mlp" # default
act_fn = "silu" # default
if len(node.args) > 6:
mlp_style = node.args[6]
elif "mlp_style" in node.kwargs:
mlp_style = node.kwargs["mlp_style"]
if len(node.args) > 7:
act_fn = node.args[7]
elif "act_fn" in node.kwargs:
act_fn = node.kwargs["act_fn"]
# Prepare kwargs for the quantized op
kwargs = {
"mlp_style": mlp_style,
"act_fn": act_fn,
}
# Replace the current node with the quantized version
with gm.graph.inserting_after(node):
new_node = gm.graph.call_function(
quantized_op,
args=tuple(args),
kwargs=kwargs,
)
node.replace_all_uses_with(new_node)
gm.graph.erase_node(node)
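
The positional-or-keyword extraction above (args[6] / args[7] with a kwargs fallback) is a recurring pattern when rewriting fx nodes. A hedged equivalent helper, purely as a sketch (get_node_arg is not a function in this PR):

from torch.fx import Node

def get_node_arg(node: Node, index: int, name: str, default):
    """Return an op argument that may be passed positionally or as a keyword."""
    if len(node.args) > index:
        return node.args[index]
    return node.kwargs.get(name, default)

# e.g. mlp_style = get_node_arg(node, 6, "mlp_style", "gated_mlp")
#      act_fn    = get_node_arg(node, 7, "act_fn", "silu")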


@@ -140,9 +140,6 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
def test_auto_dtype(self, enable_chunked_prefill):
if enable_chunked_prefill:
pytest.skip(
"see https://github.com/NVIDIA/TensorRT-LLM/issues/8272")
kwargs = self.get_default_kwargs(enable_chunked_prefill)
sampling_params = self.get_default_sampling_params()
with AutoDeployLLM(model=self.MODEL_PATH,
@@ -152,3 +149,49 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
task.evaluate(llm, sampling_params=sampling_params)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
class TestNemotronMOE(LlmapiAccuracyTestHarness):
MODEL_NAME = "nvidia/Nemotron-MOE"
MODEL_PATH = f"{llm_models_root()}/Nemotron-MOE/"
def get_default_kwargs(self):
return {
"skip_tokenizer_init": False,
"trust_remote_code": True,
# SSMs do not support cache reuse.
"kv_cache_config": {
"enable_block_reuse": False
},
# Keep max_batch_size as in the PyTorch test to avoid OOM
"max_batch_size": 128,
# Model context length is 8K
"max_seq_len": 8192,
# Set explicitly to match default build_config behavior
"max_num_tokens": 8192,
"skip_loading_weights": False,
"compile_backend": "torch-cudagraph",
"free_mem_ratio": 0.7,
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
}
def get_default_sampling_params(self):
eos_id = -1
beam_width = 1
return SamplingParams(end_id=eos_id,
pad_id=eos_id,
n=beam_width,
use_beam_search=beam_width > 1)
@pytest.mark.skip_less_device_memory(32000)
def test_auto_dtype(self):
pytest.skip("Nemotron-MOE is not in CI yet")
kwargs = self.get_default_kwargs()
sampling_params = self.get_default_sampling_params()
with AutoDeployLLM(model=self.MODEL_PATH,
tokenizer=self.MODEL_PATH,
**kwargs) as llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm, sampling_params=sampling_params)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)


@@ -1,5 +1,6 @@
import pytest
import torch
from utils.util import skip_pre_hopper
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
from tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.load_moe_align import moe_align_block_size
@@ -215,3 +216,147 @@ def test_moe_align_kernel_groups_tokens_by_expert_and_block_padding():
ref_counts_all = torch.bincount(ref_sorted_used.cpu().to(torch.int64), minlength=T + 1)
assert torch.all(ref_counts_all == counts_all)
@skip_pre_hopper
def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe():
"""Test triton_quant_fp8_moe against torch_quant_fp8_moe reference."""
torch.manual_seed(0)
if not torch.cuda.is_available():
pytest.skip("CUDA is required for triton_quant_fp8_moe test")
device = "cuda"
dtype = torch.bfloat16
M = 32 # tokens
HIDDEN_SIZE = 16 # Must be multiple of 16 for FP8 linear
INTERMEDIATE_SIZE = 32 # Must be multiple of 16 for FP8 linear
E = 4 # experts
top_k = 2
# Use small normalized values to avoid FP8 range issues
x = torch.randn(M, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
# Create BF16 weights for each expert (normalized to small values)
w_up_list = [
torch.randn(INTERMEDIATE_SIZE, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
for _ in range(E)
]
w_down_list = [
torch.randn(HIDDEN_SIZE, INTERMEDIATE_SIZE, device=device, dtype=dtype) * 0.1
for _ in range(E)
]
# Stack weights [E, ...]
w_up_stacked = torch.stack(w_up_list, dim=0).contiguous() # [E, I, H]
w_down_stacked = torch.stack(w_down_list, dim=0).contiguous() # [E, H, I]
# Quantize weights to FP8 with per-expert scales
FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
# Per-expert weight scales (use max absolute value per expert)
w1_weight_scale = torch.tensor(
[w_up_stacked[e].abs().max().item() / FP8_MAX for e in range(E)],
device=device,
dtype=torch.float32,
)
w2_weight_scale = torch.tensor(
[w_down_stacked[e].abs().max().item() / FP8_MAX for e in range(E)],
device=device,
dtype=torch.float32,
)
# Quantize weights and stack
w1_fp8_list = [
(w_up_stacked[e] / w1_weight_scale[e]).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
for e in range(E)
]
w2_fp8_list = [
(w_down_stacked[e] / w2_weight_scale[e]).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
for e in range(E)
]
w1_fp8_stacked = torch.stack(w1_fp8_list).contiguous()
w2_fp8_stacked = torch.stack(w2_fp8_list).contiguous()
# Input scales (tensor-wise, replicated per expert for interface compatibility)
x_scale = x.abs().max().item() / FP8_MAX
w1_input_scale_tensor = torch.full((E,), x_scale, device=device, dtype=torch.float32)
# Compute intermediate activation scale by simulating first GEMM + ReLU^2
# This ensures w2_input_scale matches the actual activation magnitude
with torch.no_grad():
# Simulate the first GEMM: quantize input, do FP8 matmul, apply ReLU^2
x_q = (x / w1_input_scale_tensor[0]).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
# Dequantize and compute output for a sample
x_dq = x_q[:8].to(torch.float32) * w1_input_scale_tensor[0].item()
w1_dq = w1_fp8_stacked[0].to(torch.float32) * w1_weight_scale[0].item()
sample_out = torch.nn.functional.linear(x_dq.to(dtype), w1_dq.to(dtype))
sample_act = torch.square(torch.nn.functional.relu(sample_out))
intermediate_scale = sample_act.abs().max().item() / FP8_MAX
# Ensure scale is not too small
intermediate_scale = max(intermediate_scale, 1e-6)
w2_input_scale_tensor = torch.full((E,), intermediate_scale, device=device, dtype=torch.float32)
# Convert scales to lists for torch_quant_fp8_moe reference
w1_input_scale_list = [w1_input_scale_tensor[0].clone() for _ in range(E)]
w2_input_scale_list = [w2_input_scale_tensor[0].clone() for _ in range(E)]
w1_weight_scale_list = [w1_weight_scale[e].clone() for e in range(E)]
w2_weight_scale_list = [w2_weight_scale[e].clone() for e in range(E)]
# Dummy w3 tensors (unused for mlp style)
w3_fp8_list = [torch.empty((1, 1), device=device, dtype=torch.float8_e4m3fn) for _ in range(E)]
w3_fp8_stacked = torch.stack(w3_fp8_list).contiguous()
w3_input_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)]
w3_input_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32)
w3_weight_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)]
w3_weight_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32)
# Create controlled routing to ensure even token distribution across experts
selected_experts = torch.zeros((M, top_k), dtype=torch.int64, device=device)
for i in range(M):
# Distribute tokens evenly: token i goes to experts (i % E) and ((i+1) % E)
selected_experts[i, 0] = i % E
selected_experts[i, 1] = (i + 1) % E
# Create equal routing weights
routing_weights = torch.ones((M, top_k), device=device, dtype=torch.float32) / top_k
# Triton FP8 quantized MoE (uses stacked tensors)
out_triton = torch.ops.auto_deploy.triton_quant_fp8_moe(
x,
selected_experts.to(torch.int32),
routing_weights,
w1_fp8_stacked,
w2_fp8_stacked,
w3_fp8_stacked,
w1_input_scale_tensor,
w2_input_scale_tensor,
w3_input_scale_tensor,
w1_weight_scale,
w2_weight_scale,
w3_weight_scale_tensor,
mlp_style="mlp",
act_fn="relu2",
)
# Reference: Torch quantized FP8 MoE (uses lists of tensors and scales)
out_torch = torch.ops.auto_deploy.torch_quant_fp8_moe(
x,
selected_experts,
routing_weights,
w1_weight=w1_fp8_list,
w2_weight=w2_fp8_list,
w3_weight=w3_fp8_list,
w1_input_scale=w1_input_scale_list,
w2_input_scale=w2_input_scale_list,
w3_input_scale=w3_input_scale_list,
w1_weight_scale=w1_weight_scale_list,
w2_weight_scale=w2_weight_scale_list,
w3_weight_scale=w3_weight_scale_list,
mlp_style="mlp",
act_fn="relu2",
)
torch.testing.assert_close(out_triton, out_torch, rtol=1e-2, atol=1e-2)