mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

[None][feat] AutoDeploy: Add FP8 MOE for Nemotron (#8599)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: nvchenghaoz <211069071+nvchenghaoz@users.noreply.github.com>
Co-authored-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
Co-authored-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>

This commit is contained in:
parent 95be56e56b
commit a6d20f6f9b
@@ -116,6 +116,9 @@ transforms:
   fuse_moe:
     stage: post_load_fusion
     enabled: true
+  fuse_fp8_moe:
+    stage: post_load_fusion
+    enabled: true
   fuse_allreduce_residual_rmsnorm:
     stage: post_load_fusion
     # TODO (lucaslie): add backend selection as part of configurable inference optimizers
@@ -228,6 +228,8 @@ def torch_quant_fp8_moe(
     w1_weight_scale: List[torch.Tensor],
     w2_weight_scale: List[torch.Tensor],
     w3_weight_scale: List[torch.Tensor],
+    mlp_style: str = "gated_mlp",  # "gated_mlp" (default) or "mlp"
+    act_fn: str = "silu",  # silu or relu2
 ) -> torch.Tensor:
     """
     FP8 MoE op using quantized linear operations.
@@ -239,40 +241,91 @@ def torch_quant_fp8_moe(
         x: Input tensor of shape (B, H) or (B, S, H).
         selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
         routing_weights: Tensor of normalized routing weights.
         w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops.
+        w1_weight:
+            List of per-expert weight tensors:
+              • mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection.
+              • mlp_style=="mlp": W_up with shape (I, H) — up projection.
+        w2_weight:
+            List of per-expert weight tensors:
+              • gated_mlp: W2 with shape (H, I) — down projection.
+              • mlp: W_down with shape (H, I) — down projection.
+        w3_weight:
+            List of per-expert weight tensors:
+              • gated_mlp: W3 with shape (I, H) — "up" (second) projection in gated MLP.
+              • mlp: pass an empty list []; ignored.
         w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors for the corresponding ops.
         w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors for the corresponding ops.
+
+        mlp_style:
+            Selects the per-expert MLP computation:
+              • "gated_mlp" (default, Mixtral/DeepSeek-style):
+                    y = W2( act(W1 x) * (W3 x) )
+              • "mlp" (NemotronH-style 2-layer MLP):
+                    y = W_down( act(W_up x) )
+        act_fn:
+            Elementwise activation applied inside the expert MLP.
+            Supported: "silu" (default), "relu2" (ReLU then square).
     """

-    def make_fp8_mlp(i):
-        def mlp(inp):
-            gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
-                inp,
-                w1_weight[i],
-                bias=None,
-                input_scale=w1_input_scale[i],
-                weight_scale=w1_weight_scale[i],
-            )
-            up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
-                inp,
-                w3_weight[i],
-                bias=None,
-                input_scale=w3_input_scale[i],
-                weight_scale=w3_weight_scale[i],
-            )
-            prod = F.silu(gate_out) * up_out
-            return torch.ops.auto_deploy.torch_quant_fp8_linear(
-                prod,
-                w2_weight[i],
-                bias=None,
-                input_scale=w2_input_scale[i],
-                weight_scale=w2_weight_scale[i],
-            )
-
-        return mlp
-
-    mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
+    act_fn = _resolve_activation(act_fn)
+    style = mlp_style.lower()
+
+    if style == "gated_mlp":
+
+        def make_fp8_mlp(i):
+            def mlp(inp):
+                gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    inp,
+                    w1_weight[i],
+                    bias=None,
+                    input_scale=w1_input_scale[i],
+                    weight_scale=w1_weight_scale[i],
+                )
+                up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    inp,
+                    w3_weight[i],
+                    bias=None,
+                    input_scale=w3_input_scale[i],
+                    weight_scale=w3_weight_scale[i],
+                )
+                prod = act_fn(gate_out) * up_out
+                return torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    prod,
+                    w2_weight[i],
+                    bias=None,
+                    input_scale=w2_input_scale[i],
+                    weight_scale=w2_weight_scale[i],
+                )
+
+            return mlp
+
+        mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
+
+    elif style == "mlp":
+
+        def make_fp8_mlp(i):
+            def mlp(inp):
+                up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    inp,
+                    w1_weight[i],
+                    bias=None,
+                    input_scale=w1_input_scale[i],
+                    weight_scale=w1_weight_scale[i],
+                )
+                return torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    act_fn(up_out),
+                    w2_weight[i],
+                    bias=None,
+                    input_scale=w2_input_scale[i],
+                    weight_scale=w2_weight_scale[i],
+                )
+
+            return mlp
+
+        mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
+
+    else:
+        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
+
     return _template_moe(x, selected_experts, routing_weights, mlps)
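For reference, the "mlp" style with act_fn="relu2" documented above reduces per expert to y = W_down(relu(W_up x)^2). The sketch below is illustrative only (plain PyTorch, made-up shapes, ignoring FP8 rounding and scales); the op above runs the same math through torch_quant_fp8_linear.

# Illustrative only: dequantized per-expert math for mlp_style="mlp", act_fn="relu2".
import torch
import torch.nn.functional as F

def mlp_style_expert_reference(x, w_up, w_down):
    """y = W_down( relu(W_up x)^2 ) in full precision."""
    up = F.linear(x, w_up)           # (tokens, I)
    act = torch.square(F.relu(up))   # "relu2": ReLU then square
    return F.linear(act, w_down)     # (tokens, H)

x = torch.randn(4, 16)
w_up = torch.randn(32, 16)    # (I, H), the w1_weight layout described above
w_down = torch.randn(16, 32)  # (H, I), the w2_weight layout described above
print(mlp_style_expert_reference(x, w_up, w_down).shape)  # torch.Size([4, 16])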
@@ -290,6 +343,8 @@ def torch_quant_fp8_moe_fake(
     w1_weight_scale: List[torch.Tensor],
     w2_weight_scale: List[torch.Tensor],
     w3_weight_scale: List[torch.Tensor],
+    mlp_style: str = "gated_mlp",
+    act_fn: str = "silu",
 ) -> torch.Tensor:
     return torch.empty_like(x)
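The _fake variants only describe the output shape and dtype so the graph can be traced and compiled without running the real kernel. A minimal sketch of the same custom-op plus fake-kernel registration pattern used throughout this file (the op name "mylib::example_identity" is made up for illustration):

# Sketch of the custom-op / fake-kernel pattern; "mylib::example_identity" is hypothetical.
import torch

@torch.library.custom_op("mylib::example_identity", mutates_args=())
def example_identity(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

@example_identity.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Only shape/dtype matter during tracing; no real compute happens here.
    return torch.empty_like(x)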
@@ -311,6 +366,8 @@ def torch_quant_nvfp4_moe(
     w1_alpha: List[torch.Tensor],
     w2_alpha: List[torch.Tensor],
     w3_alpha: List[torch.Tensor],
+    mlp_style: str = "gated_mlp",  # "gated_mlp" (default) or "mlp"
+    act_fn: str = "silu",  # silu or relu2
 ) -> torch.Tensor:
     """
     FP4 MoE op using quantized linear operations.
@@ -322,45 +379,101 @@ def torch_quant_nvfp4_moe(
         x: Input tensor of shape (B, H) or (B, S, H).
         selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
         routing_weights: Tensor of normalized routing weights.
         w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops.
+        w1_weight:
+            List of per-expert weight tensors:
+              • mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection.
+              • mlp_style=="mlp": W_up with shape (I, H) — up projection.
+        w2_weight:
+            List of per-expert weight tensors:
+              • gated_mlp: W2 with shape (H, I) — down projection.
+              • mlp: W_down with shape (H, I) — down projection.
+        w3_weight:
+            List of per-expert weight tensors:
+              • gated_mlp: W3 with shape (I, H) — "up" (second) projection in gated MLP.
+              • mlp: pass an empty list []; ignored.
         w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors.
         w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors.
         w1_alpha, w2_alpha, w3_alpha: Lists of alpha scale tensors for FP4 quantization.
+        mlp_style:
+            Selects the per-expert MLP computation:
+              • "gated_mlp" (default, Mixtral/DeepSeek-style):
+                    y = W2( act(W1 x) * (W3 x) )
+              • "mlp" (NemotronH-style 2-layer MLP):
+                    y = W_down( act(W_up x) )
+        act_fn:
+            Elementwise activation applied inside the expert MLP.
+            Supported: "silu" (default), "relu2" (ReLU then square).
     """

-    def make_fp4_mlp(i):
-        def mlp(inp):
-            if inp.shape[0] == 0:
-                return torch.zeros_like(inp)
-            gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
-                inp,
-                w1_weight[i],
-                bias=None,
-                input_scale=w1_input_scale[i],
-                weight_scale=w1_weight_scale[i],
-                alpha=w1_alpha[i],
-            )
-            up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
-                inp,
-                w3_weight[i],
-                bias=None,
-                input_scale=w3_input_scale[i],
-                weight_scale=w3_weight_scale[i],
-                alpha=w3_alpha[i],
-            )
-            prod = F.silu(gate_out) * up_out
-            return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
-                prod,
-                w2_weight[i],
-                bias=None,
-                input_scale=w2_input_scale[i],
-                weight_scale=w2_weight_scale[i],
-                alpha=w2_alpha[i],
-            )
-
-        return mlp
-
-    mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
+    act_fn = _resolve_activation(act_fn)
+    style = mlp_style.lower()
+
+    if style == "gated_mlp":
+
+        def make_fp4_mlp(i):
+            def mlp(inp):
+                if inp.shape[0] == 0:
+                    return torch.zeros_like(inp)
+                gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    inp,
+                    w1_weight[i],
+                    bias=None,
+                    input_scale=w1_input_scale[i],
+                    weight_scale=w1_weight_scale[i],
+                    alpha=w1_alpha[i],
+                )
+                up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    inp,
+                    w3_weight[i],
+                    bias=None,
+                    input_scale=w3_input_scale[i],
+                    weight_scale=w3_weight_scale[i],
+                    alpha=w3_alpha[i],
+                )
+                prod = act_fn(gate_out) * up_out
+                return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    prod,
+                    w2_weight[i],
+                    bias=None,
+                    input_scale=w2_input_scale[i],
+                    weight_scale=w2_weight_scale[i],
+                    alpha=w2_alpha[i],
+                )
+
+            return mlp
+
+        mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
+
+    elif style == "mlp":
+
+        def make_fp4_mlp(i):
+            def mlp(inp):
+                if inp.shape[0] == 0:
+                    return torch.zeros_like(inp)
+                up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    inp,
+                    w1_weight[i],
+                    bias=None,
+                    input_scale=w1_input_scale[i],
+                    weight_scale=w1_weight_scale[i],
+                    alpha=w1_alpha[i],
+                )
+                return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    act_fn(up_out),
+                    w2_weight[i],
+                    bias=None,
+                    input_scale=w2_input_scale[i],
+                    weight_scale=w2_weight_scale[i],
+                    alpha=w2_alpha[i],
+                )
+
+            return mlp
+
+        mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
+
+    else:
+        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
+
     return _template_moe(x, selected_experts, routing_weights, mlps)
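For the default gated_mlp style, the per-expert formula documented above is y = W2(act(W1 x) * (W3 x)). A plain-PyTorch sketch of that formula with made-up shapes (illustrative only, ignoring FP4/FP8 rounding, scales, and alpha):

# Illustrative only: dequantized per-expert math for mlp_style="gated_mlp", act_fn="silu".
import torch
import torch.nn.functional as F

def gated_mlp_expert_reference(x, w1, w2, w3):
    """y = W2( silu(W1 x) * (W3 x) ) in full precision."""
    gate = F.silu(F.linear(x, w1))   # (tokens, I)
    up = F.linear(x, w3)             # (tokens, I)
    return F.linear(gate * up, w2)   # (tokens, H)

x = torch.randn(4, 16)
w1 = torch.randn(32, 16)  # (I, H)
w3 = torch.randn(32, 16)  # (I, H)
w2 = torch.randn(16, 32)  # (H, I)
print(gated_mlp_expert_reference(x, w1, w2, w3).shape)  # torch.Size([4, 16])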
@@ -381,6 +494,8 @@ def torch_quant_nvfp4_moe_fake(
     w1_alpha: List[torch.Tensor],
     w2_alpha: List[torch.Tensor],
     w3_alpha: List[torch.Tensor],
+    mlp_style: str = "gated_mlp",
+    act_fn: str = "silu",
 ) -> torch.Tensor:
     return torch.empty_like(x)
@@ -159,6 +159,130 @@ def fused_mlp_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)


+@triton.jit
+def fused_mlp_moe_kernel_w8a8(
+    # Pointers to matrices (A in FP8, B in FP8)
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    topk_weights_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    num_tokens_post_padded_ptr,
+    # Matrix dimensions
+    N,
+    K,
+    EM,
+    num_valid_tokens,
+    # Strides
+    stride_am,
+    stride_ak,
+    stride_be,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    # Scale pointers
+    a_scale_ptr,
+    b_scale_ptr,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    MUL_ROUTED_WEIGHT: tl.constexpr,
+    top_k: tl.constexpr,
+    compute_type: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+        return
+
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    token_id_mask = offs_token_id < EM
+    offs_token = tl.load(
+        sorted_token_ids_ptr + offs_token_id, mask=token_id_mask, other=num_valid_tokens
+    )
+    token_mask = offs_token < num_valid_tokens
+
+    # Clamp offs_token to valid range to avoid out-of-bounds pointer arithmetic
+    # Padding tokens have value >= num_valid_tokens and will be masked out
+    # Clamp to last valid token instead of 0 to avoid cache/memory issues
+    max_valid_token = num_valid_tokens - 1
+    offs_token_clamped = tl.where(token_mask, offs_token, max_valid_token)
+
+    # Expert id for this block (one expert per M-tile)
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    if off_experts == -1:
+        _write_zeros_to_output(
+            c_ptr,
+            stride_cm,
+            stride_cn,
+            pid_n,
+            N,
+            offs_token_clamped,
+            token_mask,
+            BLOCK_SIZE_M,
+            BLOCK_SIZE_N,
+            compute_type,
+        )
+        return
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (
+        offs_token_clamped[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )
+    b_ptrs = (
+        b_ptr
+        + off_experts * stride_be
+        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    )
+
+    # Load tensor-wise scales before loop
+    a_scale = tl.load(a_scale_ptr)
+    b_scale = tl.load(b_scale_ptr + off_experts)
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(
+            b_ptrs,
+            mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
+            other=0.0,
+        )
+        # Use acc= for FP8 fast accumulation (matches vLLM)
+        accumulator = tl.dot(a, b, acc=accumulator)
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    # Apply scales after K-loop
+    accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token_clamped, mask=token_mask, other=0)
+        accumulator = accumulator * moe_weight[:, None]
+
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token_clamped[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
 def _default_kernel_config(M: int, E: int, N: int, K: int, top_k: int) -> dict:
     if M <= E:
         return {
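Because a_scale and b_scale are tensor-wise in the kernel above, they can be hoisted out of the K-loop and applied once to the FP32 accumulator. A small illustrative check of that identity in plain PyTorch (not part of the diff):

# Illustrative: (a_q @ b_q) * (a_scale * b_scale) equals dequantizing first and then
# multiplying, so applying the scales once after the K-loop is mathematically equivalent.
import torch

a_scale, b_scale = 0.02, 0.05
a_q = torch.randint(-100, 100, (8, 16)).float()  # stands in for already-quantized A
b_q = torch.randint(-100, 100, (16, 4)).float()  # stands in for already-quantized B

scaled_after = (a_q @ b_q) * (a_scale * b_scale)
dequant_first = (a_q * a_scale) @ (b_q * b_scale)
print(torch.allclose(scaled_after, dequant_first))  # True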
@@ -245,12 +369,15 @@ def _invoke_kernel(
     topk_weights: torch.Tensor | None,
     sorted_token_ids: torch.Tensor,
     expert_ids: torch.Tensor,
-    num_tokens_post_padded: torch.Tensor,  # Changed to tensor for CUDA graph compatibility
+    num_tokens_post_padded: torch.Tensor,
     mul_routed_weight: bool,
     top_k: int,
     config: dict,
     compute_type,
+    a_scale: torch.Tensor | None = None,
+    b_scale: torch.Tensor | None = None,
 ):
+    """Unified kernel launcher for both unquantized and FP8 W8A8 MoE kernels."""
     assert B.ndim == 3 and C.ndim == 3
     EM = sorted_token_ids.numel()
     if EM == 0:
@@ -268,7 +395,7 @@ def _invoke_kernel(
     )

     num_tokens = A.size(0) * top_k
-    fused_mlp_moe_kernel[_grid](
+    common_args = [
         A,
         B,
         C,
@@ -287,112 +414,96 @@ def _invoke_kernel(
         B.stride(1),
         C.stride(1),
         C.stride(2),
-        BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
-        BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
-        BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
-        GROUP_SIZE_M=config["GROUP_SIZE_M"],
-        MUL_ROUTED_WEIGHT=mul_routed_weight,
-        top_k=top_k,
-        compute_type=compute_type,
-        num_warps=config["num_warps"],
-        num_stages=config["num_stages"],
-    )
+    ]
+
+    common_kwargs = {
+        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
+        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
+        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
+        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
+        "MUL_ROUTED_WEIGHT": mul_routed_weight,
+        "top_k": top_k,
+        "compute_type": compute_type,
+        "num_warps": config["num_warps"],
+        "num_stages": config["num_stages"],
+    }
+
+    if a_scale is not None and b_scale is not None:
+        # FP8 W8A8 path
+        fused_mlp_moe_kernel_w8a8[_grid](*common_args, a_scale, b_scale, **common_kwargs)
+    else:
+        # Unquantized path
+        fused_mlp_moe_kernel[_grid](*common_args, **common_kwargs)


-def fused_mlp_relu2_unquantized(
-    hidden_states: torch.Tensor,  # [M, H]
-    w_up: torch.Tensor,  # [E, I, H]
-    w_down: torch.Tensor,  # [E, H, I]
-    topk_ids: torch.Tensor,  # [M, top_k]
-    topk_weights: torch.Tensor,  # [M, top_k]
-    *,
+def _get_compute_type(dtype: torch.dtype):
+    """Get Triton compute type from torch dtype."""
+    if dtype == torch.bfloat16:
+        return tl.bfloat16
+    elif dtype == torch.float16:
+        return tl.float16
+    elif dtype == torch.float32:
+        return tl.float32
+    else:
+        raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+def _fused_moe_mlp_relu2(
+    hidden_states: torch.Tensor,
+    w_up: torch.Tensor,
+    w_down: torch.Tensor,
+    topk_ids: torch.Tensor,
+    topk_weights: torch.Tensor,
     apply_router_weight_on_input: bool = False,
 ) -> torch.Tensor:
-    """
-    Fast unquantized MoE MLP with ReLU^2 activation between two per-expert GEMMs.
-
-    Requirements:
-    - w_up: (E, I, H) with last dim contiguous
-    - w_down: (E, H, I) with last dim contiguous
-    - hidden_states: (M, H), topk_ids/topk_weights: (M, top_k)
-    """
-    assert hidden_states.ndim == 2
-    assert w_up.ndim == 3 and w_down.ndim == 3
-    assert topk_ids.ndim == 2 and topk_weights.ndim == 2
+    """Fused MoE 2-layer MLP with ReLU^2 activation using Triton."""
     M, H = hidden_states.shape
-    E, inter_size, H_up = w_up.shape
-    E2, H_down, inter_size2 = w_down.shape
-    assert E == E2 and H == H_up and H == H_down and inter_size == inter_size2
+    E, inter_size, _ = w_up.shape
     top_k = topk_ids.shape[1]

-    # Ensure memory layout compatible with kernel expectations
-    A = hidden_states.contiguous()
-    B1 = w_up.contiguous()  # (E, I, H)
-    B2 = w_down.contiguous()  # (E, H, I)
-
     # Kernel config (use a single BLOCK_SIZE_M for both GEMMs)
     config = _default_kernel_config(M, E, inter_size, H, top_k)

     # Token routing packing (group-by-expert, pad to BLOCK_SIZE_M)
     sorted_token_ids, expert_ids, num_tokens_post_padded = _pack_routed_tokens(
-        topk_ids,
-        M,
-        E,
-        top_k,
-        config["BLOCK_SIZE_M"],
+        topk_ids, M, E, top_k, config["BLOCK_SIZE_M"]
     )

     # Workspaces
-    cache1 = A.new_empty((M, top_k, inter_size))
-    cache2 = A.new_empty((M * top_k, inter_size))
-    cache3 = A.new_empty((M, top_k, H))
-
-    # Compute type
-    if A.dtype == torch.bfloat16:
-        compute_type = tl.bfloat16
-    elif A.dtype == torch.float16:
-        compute_type = tl.float16
-    elif A.dtype == torch.float32:
-        compute_type = tl.float32
-    else:
-        raise ValueError(f"Unsupported dtype for hidden_states: {A.dtype}")
-
-    # GEMM 1: X @ W_up^T → cache1 (no routing weights here)
+    cache1 = hidden_states.new_empty((M, top_k, inter_size))
+    cache3 = hidden_states.new_empty((M, top_k, H))
+    compute_type = _get_compute_type(hidden_states.dtype)
+
+    # GEMM 1: hidden @ w_up^T
     _invoke_kernel(
-        A,
-        B1,
+        hidden_states.contiguous(),
+        w_up.contiguous(),
         cache1,
         None,
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
-        mul_routed_weight=False,
-        top_k=top_k,
-        config=config,
-        compute_type=compute_type,
+        False,
+        top_k,
+        config,
+        compute_type,
     )

-    # Activation (ReLU^2) without gating/multiplication
-    cache2 = torch.square(F.relu(cache1.view(-1, inter_size)))
+    # Activation: ReLU^2
+    act = torch.square(F.relu(cache1.view(-1, inter_size)))

-    # GEMM 2: Act(cache1) @ W_down^T → cache3 (apply routing weights)
+    # GEMM 2: act @ w_down^T
     _invoke_kernel(
-        cache2,
-        B2,
+        act,
+        w_down.contiguous(),
         cache3,
         topk_weights,
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
-        mul_routed_weight=not apply_router_weight_on_input,
-        top_k=1,  # ensure offs_token maps to flattened rows (m*top_k + n)
-        config=config,
-        compute_type=compute_type,
+        not apply_router_weight_on_input,
+        1,
+        config,
+        compute_type,
     )

-    # Sum across top-k per token
-    out = cache3.sum(dim=1)
-    return out
+    return cache3.sum(dim=1)


 @torch.library.custom_op("auto_deploy::triton_moe_fused", mutates_args=())
@@ -403,36 +514,13 @@ def triton_fused_moe(
     w1_stacked_weight: torch.Tensor,
     w2_stacked_weight: torch.Tensor,
 ) -> torch.Tensor:
-    """
-    Triton implementation of the Fused MOE ops for Nemotron-6 models
-
-    Each expert has two weight matrices and squared ReLU activation between them.
-    """
+    """Triton unquantized MoE with 2-layer MLP and ReLU^2 activation."""
     x_shape = x.shape
-    hidden_size = x_shape[-1]
-    x2d = x.view(-1, hidden_size)
-
-    routing_weights = routing_weights.to(torch.float32)
-    selected_experts = selected_experts.to(torch.int32)
-
-    # Expect selected_experts/routing_weights to be [M, top_k]
-    topk_ids = selected_experts.contiguous()
-    topk_weights = routing_weights.contiguous()
-    assert topk_ids.dim() == 2 and topk_weights.dim() == 2, (
-        f"Expected 2D routing tensors, got {topk_ids.shape} and {topk_weights.shape}"
-    )
-    assert topk_ids.shape[0] == x2d.shape[0], (
-        f"Token count mismatch: tokens={x2d.shape[0]} ids={topk_ids.shape[0]}"
-    )
-
-    out2d = fused_mlp_relu2_unquantized(
-        x2d,
-        w1_stacked_weight,
-        w2_stacked_weight,
-        topk_ids,
-        topk_weights,
-        apply_router_weight_on_input=False,
-    )
+    x2d = x.view(-1, x_shape[-1])
+    topk_ids = selected_experts.to(torch.int32).contiguous()
+    topk_weights = routing_weights.to(torch.float32).contiguous()
+
+    out2d = _fused_moe_mlp_relu2(x2d, w1_stacked_weight, w2_stacked_weight, topk_ids, topk_weights)
     return out2d.view(x_shape)
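What the fused path computes, stated in plain PyTorch: each token is routed to its top-k experts, each expert runs W_down(relu(W_up x)^2), and the expert outputs are summed with the routing weights. The loop below is an illustrative reference only (made-up shapes), not code from this commit:

# Illustrative dense reference for the routed ReLU^2 MLP MoE.
import torch
import torch.nn.functional as F

def moe_mlp_relu2_reference(x, w_up, w_down, topk_ids, topk_weights):
    M, H = x.shape
    out = torch.zeros_like(x)
    for m in range(M):
        for k in range(topk_ids.shape[1]):
            e = int(topk_ids[m, k])
            up = F.linear(x[m], w_up[e])            # (I,)
            act = torch.square(F.relu(up))          # ReLU^2
            out[m] += topk_weights[m, k] * F.linear(act, w_down[e])
    return out

x = torch.randn(3, 8)
w_up = torch.randn(4, 16, 8)    # [E, I, H]
w_down = torch.randn(4, 8, 16)  # [E, H, I]
topk_ids = torch.randint(0, 4, (3, 2))
topk_weights = torch.full((3, 2), 0.5)
print(moe_mlp_relu2_reference(x, w_up, w_down, topk_ids, topk_weights).shape)  # torch.Size([3, 8])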
@@ -445,3 +533,120 @@ def triton_fused_moe(
     w2_stacked_weight: torch.Tensor,
 ) -> torch.Tensor:
     return torch.empty_like(x)
+
+
+def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+    """Quantize tensor to FP8 with clamping (matches torch_quant_fp8_linear)."""
+    FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
+    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
+    return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
+
+
+@torch.library.custom_op("auto_deploy::triton_quant_fp8_moe", mutates_args=())
+def triton_quant_fp8_moe(
+    x: torch.Tensor,
+    selected_experts: torch.Tensor,
+    routing_weights: torch.Tensor,
+    w1_weight: torch.Tensor,  # [E, I, H] stacked FP8 weights
+    w2_weight: torch.Tensor,  # [E, H, I] stacked FP8 weights
+    w3_weight: torch.Tensor,  # unused for mlp style
+    w1_input_scale: torch.Tensor,  # [E] stacked input scales
+    w2_input_scale: torch.Tensor,  # [E] stacked input scales
+    w3_input_scale: torch.Tensor,  # unused
+    w1_weight_scale: torch.Tensor,  # [E] stacked weight scales
+    w2_weight_scale: torch.Tensor,  # [E] stacked weight scales
+    w3_weight_scale: torch.Tensor,  # unused
+    mlp_style: str = "gated_mlp",
+    act_fn: str = "silu",
+) -> torch.Tensor:
+    """Triton FP8 W8A8 MoE with 2-layer MLP and ReLU^2 activation."""
+    if mlp_style != "mlp":
+        raise NotImplementedError("triton_quant_fp8_moe currently supports mlp_style=='mlp' only")
+
+    x_shape = x.shape
+    x2d = x.view(-1, x_shape[-1])
+    topk_ids = selected_experts.to(torch.int32).contiguous()
+    topk_weights = routing_weights.to(torch.float32).contiguous()
+
+    # Weights are already stacked [E, ...] - just ensure contiguous and extract scales
+    w1_q = w1_weight.contiguous()
+    w2_q = w2_weight.contiguous()
+    a1_scale = w1_input_scale[0].to(torch.float32).reshape(1).contiguous()
+    a2_scale = w2_input_scale[0].to(torch.float32).reshape(1).contiguous()
+    b1_scale = w1_weight_scale.to(torch.float32).contiguous()
+    b2_scale = w2_weight_scale.to(torch.float32).contiguous()
+
+    # Setup
+    M, H = x2d.shape
+    E, inter_size, _ = w1_q.shape
+    top_k = topk_ids.shape[1]
+    config = _default_kernel_config(M, E, inter_size, H, top_k)
+    sorted_token_ids, expert_ids, num_tokens_post_padded = _pack_routed_tokens(
+        topk_ids, M, E, top_k, config["BLOCK_SIZE_M"]
+    )
+    compute_type = _get_compute_type(x2d.dtype)
+
+    # Quantize input and allocate caches
+    x_a8 = _quantize_fp8(x2d, a1_scale)
+    cache1 = x2d.new_empty((M, top_k, inter_size))
+    cache3 = x2d.new_empty((M, top_k, H))
+
+    # GEMM 1: FP8 input @ FP8 w_up^T → BF16
+    _invoke_kernel(
+        x_a8,
+        w1_q,
+        cache1,
+        None,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        False,
+        top_k,
+        config,
+        compute_type,
+        a_scale=a1_scale,
+        b_scale=b1_scale,
+    )
+
+    # Activation: ReLU^2, then quantize
+    act = torch.square(F.relu(cache1.view(-1, inter_size)))
+    act_a8 = _quantize_fp8(act, a2_scale)
+
+    # GEMM 2: FP8 activation @ FP8 w_down^T → BF16
+    _invoke_kernel(
+        act_a8,
+        w2_q,
+        cache3,
+        topk_weights,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        True,
+        1,
+        config,
+        compute_type,
+        a_scale=a2_scale,
+        b_scale=b2_scale,
+    )
+
+    return cache3.sum(dim=1).view(x_shape)
+
+
+@triton_quant_fp8_moe.register_fake
+def triton_quant_fp8_moe(
+    x: torch.Tensor,
+    selected_experts: torch.Tensor,
+    routing_weights: torch.Tensor,
+    w1_weight: torch.Tensor,
+    w2_weight: torch.Tensor,
+    w3_weight: torch.Tensor,
+    w1_input_scale: torch.Tensor,
+    w2_input_scale: torch.Tensor,
+    w3_input_scale: torch.Tensor,
+    w1_weight_scale: torch.Tensor,
+    w2_weight_scale: torch.Tensor,
+    w3_weight_scale: torch.Tensor,
+    mlp_style: str = "gated_mlp",
+    act_fn: str = "silu",
+) -> torch.Tensor:
+    return torch.empty_like(x)
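The _quantize_fp8 helper is the usual scale-divide, clamp, cast step. A small illustrative round trip (not from this commit) showing how a tensor-wise scale is picked from the max magnitude and how much precision the FP8 e4m3 cast costs:

# Illustrative only: quantize with a tensor-wise scale, then dequantize and compare.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
x = torch.randn(4, 16) * 0.1
scale = x.abs().max() / FP8_MAX
x_q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
x_dq = x_q.to(torch.float32) * scale
print((x - x_dq).abs().max())  # small but nonzero quantization error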
@@ -50,7 +50,6 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
         fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0)
         new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}"
         # Triton fused MoE op supports mlp only.
-
         replacement_op = torch.ops.auto_deploy.triton_moe_fused

     else:
@@ -567,6 +566,176 @@ class MatchNVFP4MoePattern(MatchMoePattern):
         return ["input_scale", "weight_scale", "alpha"]


+def _stack_fp8_moe_weights(gm: GraphModule) -> int:
+    """
+    Stack per-expert FP8 weights and scales by materializing stacked tensors as parameters.
+    This is fast because we directly stack the tensor values (not graph nodes).
+    Similar to _insert_fused_moe_ops but for quantized MoE.
+    """
+    fused_key_counter = 0
+    graph = gm.graph
+
+    for node in graph.nodes:
+        if not is_op(node, torch.ops.auto_deploy.torch_quant_fp8_moe):
+            continue
+
+        # Extract weight and scale lists from args
+        try:
+            (
+                hidden_states,
+                selected_experts,
+                routing_weights,
+                w1_list,
+                w2_list,
+                w3_list,
+                w1_input_scale,
+                w2_input_scale,
+                w3_input_scale,
+                w1_weight_scale,
+                w2_weight_scale,
+                w3_weight_scale,
+            ) = extract_op_args(
+                node,
+                "x",
+                "selected_experts",
+                "routing_weights",
+                "w1_weight",
+                "w2_weight",
+                "w3_weight",
+                "w1_input_scale",
+                "w2_input_scale",
+                "w3_input_scale",
+                "w1_weight_scale",
+                "w2_weight_scale",
+                "w3_weight_scale",
+            )
+        except Exception:
+            continue
+
+        # Helper to get parameter or buffer
+        def get_param_or_buffer(target):
+            """Get parameter or buffer by target name."""
+            try:
+                return gm.get_parameter(target)
+            except AttributeError:
+                # It's a buffer, not a parameter
+                parts = target.rsplit(".", 1)
+                if len(parts) == 2:
+                    mod = gm.get_submodule(parts[0])
+                    return getattr(mod, parts[1])
+                else:
+                    return getattr(gm, target)
+
+        # Stack the actual tensor values (fast, like in quantize_moe.py)
+        w1_stacked = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0)
+        w2_stacked = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0)
+        w3_stacked = (
+            torch.stack([gm.get_parameter(n.target) for n in w3_list], dim=0)
+            if w3_list
+            else torch.empty(0, device=w1_stacked.device, dtype=w1_stacked.dtype)
+        )
+
+        # Scales are buffers, not parameters
+        w1_input_scale_stacked = torch.stack(
+            [get_param_or_buffer(n.target) for n in w1_input_scale], dim=0
+        )
+        w2_input_scale_stacked = torch.stack(
+            [get_param_or_buffer(n.target) for n in w2_input_scale], dim=0
+        )
+        w3_input_scale_stacked = (
+            torch.stack([get_param_or_buffer(n.target) for n in w3_input_scale], dim=0)
+            if w3_input_scale
+            else torch.empty(
+                0, device=w1_input_scale_stacked.device, dtype=w1_input_scale_stacked.dtype
+            )
+        )
+
+        w1_weight_scale_stacked = torch.stack(
+            [get_param_or_buffer(n.target) for n in w1_weight_scale], dim=0
+        )
+        w2_weight_scale_stacked = torch.stack(
+            [get_param_or_buffer(n.target) for n in w2_weight_scale], dim=0
+        )
+        w3_weight_scale_stacked = (
+            torch.stack([get_param_or_buffer(n.target) for n in w3_weight_scale], dim=0)
+            if w3_weight_scale
+            else torch.empty(
+                0, device=w1_weight_scale_stacked.device, dtype=w1_weight_scale_stacked.dtype
+            )
+        )
+
+        # Register stacked tensors as new parameters
+        new_key_w1 = f"quant_moe_w1_stacked_{fused_key_counter}"
+        new_key_w2 = f"quant_moe_w2_stacked_{fused_key_counter}"
+        new_key_w3 = f"quant_moe_w3_stacked_{fused_key_counter}"
+        new_key_w1_input_scale = f"quant_moe_w1_input_scale_stacked_{fused_key_counter}"
+        new_key_w2_input_scale = f"quant_moe_w2_input_scale_stacked_{fused_key_counter}"
+        new_key_w3_input_scale = f"quant_moe_w3_input_scale_stacked_{fused_key_counter}"
+        new_key_w1_weight_scale = f"quant_moe_w1_weight_scale_stacked_{fused_key_counter}"
+        new_key_w2_weight_scale = f"quant_moe_w2_weight_scale_stacked_{fused_key_counter}"
+        new_key_w3_weight_scale = f"quant_moe_w3_weight_scale_stacked_{fused_key_counter}"
+
+        fused_key_counter += 1
+
+        # Register as parameters (not buffers, to match the original per-expert params)
+        gm.register_parameter(new_key_w1, torch.nn.Parameter(w1_stacked, requires_grad=False))
+        gm.register_parameter(new_key_w2, torch.nn.Parameter(w2_stacked, requires_grad=False))
+        gm.register_parameter(new_key_w3, torch.nn.Parameter(w3_stacked, requires_grad=False))
+        gm.register_parameter(
+            new_key_w1_input_scale, torch.nn.Parameter(w1_input_scale_stacked, requires_grad=False)
+        )
+        gm.register_parameter(
+            new_key_w2_input_scale, torch.nn.Parameter(w2_input_scale_stacked, requires_grad=False)
+        )
+        gm.register_parameter(
+            new_key_w3_input_scale, torch.nn.Parameter(w3_input_scale_stacked, requires_grad=False)
+        )
+        gm.register_parameter(
+            new_key_w1_weight_scale,
+            torch.nn.Parameter(w1_weight_scale_stacked, requires_grad=False),
+        )
+        gm.register_parameter(
+            new_key_w2_weight_scale,
+            torch.nn.Parameter(w2_weight_scale_stacked, requires_grad=False),
+        )
+        gm.register_parameter(
+            new_key_w3_weight_scale,
+            torch.nn.Parameter(w3_weight_scale_stacked, requires_grad=False),
+        )
+
+        # Create new node with get_attr for stacked parameters
+        with graph.inserting_before(node):
+            new_node = graph.call_function(
+                torch.ops.auto_deploy.triton_quant_fp8_moe,
+                args=(
+                    hidden_states,
+                    selected_experts,
+                    routing_weights,
+                    graph.get_attr(new_key_w1),
+                    graph.get_attr(new_key_w2),
+                    graph.get_attr(new_key_w3),
+                    graph.get_attr(new_key_w1_input_scale),
+                    graph.get_attr(new_key_w2_input_scale),
+                    graph.get_attr(new_key_w3_input_scale),
+                    graph.get_attr(new_key_w1_weight_scale),
+                    graph.get_attr(new_key_w2_weight_scale),
+                    graph.get_attr(new_key_w3_weight_scale),
+                ),
+                kwargs=node.kwargs,
+            )
+
+        node.replace_all_uses_with(new_node)
+        graph.erase_node(node)
+
+    # Clean up after processing all nodes
+    # eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules
+    # will remove the parameters/buffers that are no longer referenced
+    gm.graph.eliminate_dead_code()
+    gm.delete_all_unused_submodules()
+
+    return fused_key_counter
+
+
 @TransformRegistry.register("fuse_moe")
 class FuseMoe(BaseTransform):
     """
@@ -588,3 +757,29 @@ class FuseMoe(BaseTransform):
             skipped=False, num_matches=fused_key_counter, is_clean=False, has_valid_shapes=False
         )
         return gm, info
+
+
+@TransformRegistry.register("fuse_fp8_moe")
+class FuseFP8Moe(BaseTransform):
+    """
+    Stack per-expert FP8 MoE weights and scales to avoid runtime stacking overhead.
+    This runs after weights are loaded, similar to FuseMoe for unquantized MoE.
+    """
+
+    def _apply(
+        self,
+        gm: GraphModule,
+        cm: CachedSequenceInterface,
+        factory: ModelFactory,
+        shared_config: SharedConfig,
+    ) -> Tuple[GraphModule, TransformInfo]:
+        with cuda_memory_tracker():
+            fused_key_counter = _stack_fp8_moe_weights(gm)
+
+        info = TransformInfo(
+            skipped=(fused_key_counter == 0),
+            num_matches=fused_key_counter,
+            is_clean=False,
+            has_valid_shapes=False,
+        )
+        return gm, info
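The core idea of the transform above is simply to replace the N per-expert parameters with one stacked [E, ...] parameter at load time, so the runtime op indexes into a single tensor instead of re-stacking every forward pass. A minimal sketch of that shape change with illustrative sizes (in the transform the stacked tensors keep their original FP8/float32 dtypes):

# Illustrative only: E per-expert (I, H) weights become one (E, I, H) parameter,
# and E scalar scales become one length-E vector.
import torch

E, I, H = 4, 32, 16
per_expert_w = [torch.randn(I, H) for _ in range(E)]
per_expert_scale = [torch.tensor(0.01) for _ in range(E)]

w_stacked = torch.nn.Parameter(torch.stack(per_expert_w, dim=0), requires_grad=False)
scale_stacked = torch.nn.Parameter(torch.stack(per_expert_scale, dim=0), requires_grad=False)
print(w_stacked.shape, scale_stacked.shape)  # torch.Size([4, 32, 16]) torch.Size([4])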
@@ -16,11 +16,6 @@ from .quantization import (
     Quantization,
 )

-quantized_moe_op_map = {
-    "FP8": torch.ops.auto_deploy.torch_quant_fp8_moe,
-    "NVFP4": torch.ops.auto_deploy.torch_quant_nvfp4_moe,
-}
-

 def _quantize_moe_node(
     gm: GraphModule,
@@ -92,11 +87,33 @@ def _quantize_moe_node(
         s1, s2, s3 = collect_scales(idx)
         args.extend([s1, s2, s3])

+    # Extract mlp_style and act_fn from the original node
+    # These can be in args[6:] or in kwargs
+    mlp_style = "gated_mlp"  # default
+    act_fn = "silu"  # default
+
+    if len(node.args) > 6:
+        mlp_style = node.args[6]
+    elif "mlp_style" in node.kwargs:
+        mlp_style = node.kwargs["mlp_style"]
+
+    if len(node.args) > 7:
+        act_fn = node.args[7]
+    elif "act_fn" in node.kwargs:
+        act_fn = node.kwargs["act_fn"]
+
+    # Prepare kwargs for the quantized op
+    kwargs = {
+        "mlp_style": mlp_style,
+        "act_fn": act_fn,
+    }
+
     # Replace the current node with the quantized version
     with gm.graph.inserting_after(node):
         new_node = gm.graph.call_function(
             quantized_op,
             args=tuple(args),
+            kwargs=kwargs,
         )
     node.replace_all_uses_with(new_node)
     gm.graph.erase_node(node)
@@ -140,9 +140,6 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
     @pytest.mark.skip_less_device_memory(32000)
     @pytest.mark.parametrize("enable_chunked_prefill", [False, True])
     def test_auto_dtype(self, enable_chunked_prefill):
-        if enable_chunked_prefill:
-            pytest.skip(
-                "see https://github.com/NVIDIA/TensorRT-LLM/issues/8272")
         kwargs = self.get_default_kwargs(enable_chunked_prefill)
         sampling_params = self.get_default_sampling_params()
         with AutoDeployLLM(model=self.MODEL_PATH,
@@ -152,3 +149,49 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
             task.evaluate(llm, sampling_params=sampling_params)
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
+
+
+class TestNemotronMOE(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "nvidia/Nemotron-MOE"
+    MODEL_PATH = f"{llm_models_root()}/Nemotron-MOE/"
+
+    def get_default_kwargs(self):
+        return {
+            "skip_tokenizer_init": False,
+            "trust_remote_code": True,
+            # SSMs do not support cache reuse.
+            "kv_cache_config": {
+                "enable_block_reuse": False
+            },
+            # Keep max_batch_size as in the PyTorch test to avoid OOM
+            "max_batch_size": 128,
+            # Model context length is 8K
+            "max_seq_len": 8192,
+            # Set explicitly to match default build_config behavior
+            "max_num_tokens": 8192,
+            "skip_loading_weights": False,
+            "compile_backend": "torch-cudagraph",
+            "free_mem_ratio": 0.7,
+            "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
+        }
+
+    def get_default_sampling_params(self):
+        eos_id = -1
+        beam_width = 1
+        return SamplingParams(end_id=eos_id,
+                              pad_id=eos_id,
+                              n=beam_width,
+                              use_beam_search=beam_width > 1)
+
+    @pytest.mark.skip_less_device_memory(32000)
+    def test_auto_dtype(self):
+        pytest.skip("Nemotron-MOE is not in CI yet")
+        kwargs = self.get_default_kwargs()
+        sampling_params = self.get_default_sampling_params()
+        with AutoDeployLLM(model=self.MODEL_PATH,
+                           tokenizer=self.MODEL_PATH,
+                           **kwargs) as llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=sampling_params)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
@@ -1,5 +1,6 @@
 import pytest
 import torch
+from utils.util import skip_pre_hopper

 import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401
 from tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.load_moe_align import moe_align_block_size
@@ -215,3 +216,147 @@ def test_moe_align_kernel_groups_tokens_by_expert_and_block_padding():

     ref_counts_all = torch.bincount(ref_sorted_used.cpu().to(torch.int64), minlength=T + 1)
     assert torch.all(ref_counts_all == counts_all)
+
+
+@skip_pre_hopper
+def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe():
+    """Test triton_quant_fp8_moe against torch_quant_fp8_moe reference."""
+    torch.manual_seed(0)
+
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is required for triton_quant_fp8_moe test")
+    device = "cuda"
+    dtype = torch.bfloat16
+
+    M = 32  # tokens
+    HIDDEN_SIZE = 16  # Must be multiple of 16 for FP8 linear
+    INTERMEDIATE_SIZE = 32  # Must be multiple of 16 for FP8 linear
+    E = 4  # experts
+    top_k = 2
+
+    # Use small normalized values to avoid FP8 range issues
+    x = torch.randn(M, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
+
+    # Create BF16 weights for each expert (normalized to small values)
+    w_up_list = [
+        torch.randn(INTERMEDIATE_SIZE, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
+        for _ in range(E)
+    ]
+    w_down_list = [
+        torch.randn(HIDDEN_SIZE, INTERMEDIATE_SIZE, device=device, dtype=dtype) * 0.1
+        for _ in range(E)
+    ]
+
+    # Stack weights [E, ...]
+    w_up_stacked = torch.stack(w_up_list, dim=0).contiguous()  # [E, I, H]
+    w_down_stacked = torch.stack(w_down_list, dim=0).contiguous()  # [E, H, I]
+
+    # Quantize weights to FP8 with per-expert scales
+    FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
+    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
+
+    # Per-expert weight scales (use max absolute value per expert)
+    w1_weight_scale = torch.tensor(
+        [w_up_stacked[e].abs().max().item() / FP8_MAX for e in range(E)],
+        device=device,
+        dtype=torch.float32,
+    )
+    w2_weight_scale = torch.tensor(
+        [w_down_stacked[e].abs().max().item() / FP8_MAX for e in range(E)],
+        device=device,
+        dtype=torch.float32,
+    )
+
+    # Quantize weights and stack
+    w1_fp8_list = [
+        (w_up_stacked[e] / w1_weight_scale[e]).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
+        for e in range(E)
+    ]
+    w2_fp8_list = [
+        (w_down_stacked[e] / w2_weight_scale[e]).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
+        for e in range(E)
+    ]
+    w1_fp8_stacked = torch.stack(w1_fp8_list).contiguous()
+    w2_fp8_stacked = torch.stack(w2_fp8_list).contiguous()
+
+    # Input scales (tensor-wise, replicated per expert for interface compatibility)
+    x_scale = x.abs().max().item() / FP8_MAX
+    w1_input_scale_tensor = torch.full((E,), x_scale, device=device, dtype=torch.float32)
+
+    # Compute intermediate activation scale by simulating first GEMM + ReLU^2
+    # This ensures w2_input_scale matches the actual activation magnitude
+    with torch.no_grad():
+        # Simulate the first GEMM: quantize input, do FP8 matmul, apply ReLU^2
+        x_q = (x / w1_input_scale_tensor[0]).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
+        # Dequantize and compute output for a sample
+        x_dq = x_q[:8].to(torch.float32) * w1_input_scale_tensor[0].item()
+        w1_dq = w1_fp8_stacked[0].to(torch.float32) * w1_weight_scale[0].item()
+        sample_out = torch.nn.functional.linear(x_dq.to(dtype), w1_dq.to(dtype))
+        sample_act = torch.square(torch.nn.functional.relu(sample_out))
+        intermediate_scale = sample_act.abs().max().item() / FP8_MAX
+        # Ensure scale is not too small
+        intermediate_scale = max(intermediate_scale, 1e-6)
+
+    w2_input_scale_tensor = torch.full((E,), intermediate_scale, device=device, dtype=torch.float32)
+
+    # Convert scales to lists for torch_quant_fp8_moe reference
+    w1_input_scale_list = [w1_input_scale_tensor[0].clone() for _ in range(E)]
+    w2_input_scale_list = [w2_input_scale_tensor[0].clone() for _ in range(E)]
+    w1_weight_scale_list = [w1_weight_scale[e].clone() for e in range(E)]
+    w2_weight_scale_list = [w2_weight_scale[e].clone() for e in range(E)]
+
+    # Dummy w3 tensors (unused for mlp style)
+    w3_fp8_list = [torch.empty((1, 1), device=device, dtype=torch.float8_e4m3fn) for _ in range(E)]
+    w3_fp8_stacked = torch.stack(w3_fp8_list).contiguous()
+    w3_input_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)]
+    w3_input_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32)
+    w3_weight_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)]
+    w3_weight_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32)
+
+    # Create controlled routing to ensure even token distribution across experts
+    selected_experts = torch.zeros((M, top_k), dtype=torch.int64, device=device)
+    for i in range(M):
+        # Distribute tokens evenly: token i goes to experts (i % E) and ((i+1) % E)
+        selected_experts[i, 0] = i % E
+        selected_experts[i, 1] = (i + 1) % E
+
+    # Create equal routing weights
+    routing_weights = torch.ones((M, top_k), device=device, dtype=torch.float32) / top_k
+
+    # Triton FP8 quantized MoE (uses stacked tensors)
+    out_triton = torch.ops.auto_deploy.triton_quant_fp8_moe(
+        x,
+        selected_experts.to(torch.int32),
+        routing_weights,
+        w1_fp8_stacked,
+        w2_fp8_stacked,
+        w3_fp8_stacked,
+        w1_input_scale_tensor,
+        w2_input_scale_tensor,
+        w3_input_scale_tensor,
+        w1_weight_scale,
+        w2_weight_scale,
+        w3_weight_scale_tensor,
+        mlp_style="mlp",
+        act_fn="relu2",
+    )
+
+    # Reference: Torch quantized FP8 MoE (uses lists of tensors and scales)
+    out_torch = torch.ops.auto_deploy.torch_quant_fp8_moe(
+        x,
+        selected_experts,
+        routing_weights,
+        w1_weight=w1_fp8_list,
+        w2_weight=w2_fp8_list,
+        w3_weight=w3_fp8_list,
+        w1_input_scale=w1_input_scale_list,
+        w2_input_scale=w2_input_scale_list,
+        w3_input_scale=w3_input_scale_list,
+        w1_weight_scale=w1_weight_scale_list,
+        w2_weight_scale=w2_weight_scale_list,
+        w3_weight_scale=w3_weight_scale_list,
+        mlp_style="mlp",
+        act_fn="relu2",
+    )
+
+    torch.testing.assert_close(out_triton, out_torch, rtol=1e-2, atol=1e-2)