diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 40c3d85d80..0cb4f8889e 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -116,6 +116,9 @@ transforms: fuse_moe: stage: post_load_fusion enabled: true + fuse_fp8_moe: + stage: post_load_fusion + enabled: true fuse_allreduce_residual_rmsnorm: stage: post_load_fusion # TODO (lucaslie): add backend selection as part of configurable inference optimizers diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py index 006db22103..63519cf5a5 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py @@ -228,6 +228,8 @@ def torch_quant_fp8_moe( w1_weight_scale: List[torch.Tensor], w2_weight_scale: List[torch.Tensor], w3_weight_scale: List[torch.Tensor], + mlp_style: str = "gated_mlp", # "gated_mlp" (default) or "mlp" + act_fn: str = "silu", # silu or relu2 ) -> torch.Tensor: """ FP8 MoE op using quantized linear operations. @@ -239,40 +241,91 @@ def torch_quant_fp8_moe( x: Input tensor of shape (B, H) or (B, S, H). selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices. routing_weights: Tensor of normalized routing weights. - w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops. + w1_weight: + List of per-expert weight tensors: + • mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection. + • mlp_style=="mlp": W_up with shape (I, H) — up projection. + w2_weight: + List of per-expert weight tensors: + • gated_mlp: W2 with shape (H, I) — down projection. + • mlp: W_down with shape (H, I) — down projection. + w3_weight: + List of per-expert weight tensors: + • gated_mlp: W3 with shape (I, H) — "up" (second) projection in gated MLP. + • mlp: pass an empty list []; ignored. w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors for the corresponding ops. w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors for the corresponding ops. - + mlp_style: + Selects the per-expert MLP computation: + • "gated_mlp" (default, Mixtral/DeepSeek-style): + y = W2( act(W1 x) * (W3 x) ) + • "mlp" (NemotronH-style 2-layer MLP): + y = W_down( act(W_up x) ) + act_fn: + Elementwise activation applied inside the expert MLP. + Supported: "silu" (default), "relu2" (ReLU then square). 
""" - def make_fp8_mlp(i): - def mlp(inp): - gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear( - inp, - w1_weight[i], - bias=None, - input_scale=w1_input_scale[i], - weight_scale=w1_weight_scale[i], - ) - up_out = torch.ops.auto_deploy.torch_quant_fp8_linear( - inp, - w3_weight[i], - bias=None, - input_scale=w3_input_scale[i], - weight_scale=w3_weight_scale[i], - ) - prod = F.silu(gate_out) * up_out - return torch.ops.auto_deploy.torch_quant_fp8_linear( - prod, - w2_weight[i], - bias=None, - input_scale=w2_input_scale[i], - weight_scale=w2_weight_scale[i], - ) + act_fn = _resolve_activation(act_fn) + style = mlp_style.lower() - return mlp + if style == "gated_mlp": + + def make_fp8_mlp(i): + def mlp(inp): + gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear( + inp, + w1_weight[i], + bias=None, + input_scale=w1_input_scale[i], + weight_scale=w1_weight_scale[i], + ) + up_out = torch.ops.auto_deploy.torch_quant_fp8_linear( + inp, + w3_weight[i], + bias=None, + input_scale=w3_input_scale[i], + weight_scale=w3_weight_scale[i], + ) + prod = act_fn(gate_out) * up_out + return torch.ops.auto_deploy.torch_quant_fp8_linear( + prod, + w2_weight[i], + bias=None, + input_scale=w2_input_scale[i], + weight_scale=w2_weight_scale[i], + ) + + return mlp + + mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))] + + elif style == "mlp": + + def make_fp8_mlp(i): + def mlp(inp): + up_out = torch.ops.auto_deploy.torch_quant_fp8_linear( + inp, + w1_weight[i], + bias=None, + input_scale=w1_input_scale[i], + weight_scale=w1_weight_scale[i], + ) + return torch.ops.auto_deploy.torch_quant_fp8_linear( + act_fn(up_out), + w2_weight[i], + bias=None, + input_scale=w2_input_scale[i], + weight_scale=w2_weight_scale[i], + ) + + return mlp + + mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))] + + else: + raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") - mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))] return _template_moe(x, selected_experts, routing_weights, mlps) @@ -290,6 +343,8 @@ def torch_quant_fp8_moe_fake( w1_weight_scale: List[torch.Tensor], w2_weight_scale: List[torch.Tensor], w3_weight_scale: List[torch.Tensor], + mlp_style: str = "gated_mlp", + act_fn: str = "silu", ) -> torch.Tensor: return torch.empty_like(x) @@ -311,6 +366,8 @@ def torch_quant_nvfp4_moe( w1_alpha: List[torch.Tensor], w2_alpha: List[torch.Tensor], w3_alpha: List[torch.Tensor], + mlp_style: str = "gated_mlp", # "gated_mlp" (default) or "mlp" + act_fn: str = "silu", # silu or relu2 ) -> torch.Tensor: """ FP4 MoE op using quantized linear operations. @@ -322,45 +379,101 @@ def torch_quant_nvfp4_moe( x: Input tensor of shape (B, H) or (B, S, H). selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices. routing_weights: Tensor of normalized routing weights. - w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops. + w1_weight: + List of per-expert weight tensors: + • mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection. + • mlp_style=="mlp": W_up with shape (I, H) — up projection. + w2_weight: + List of per-expert weight tensors: + • gated_mlp: W2 with shape (H, I) — down projection. + • mlp: W_down with shape (H, I) — down projection. + w3_weight: + List of per-expert weight tensors: + • gated_mlp: W3 with shape (I, H) — "up" (second) projection in gated MLP. + • mlp: pass an empty list []; ignored. w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors. 
w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors. w1_alpha, w2_alpha, w3_alpha: Lists of alpha scale tensors for FP4 quantization. + mlp_style: + Selects the per-expert MLP computation: + • "gated_mlp" (default, Mixtral/DeepSeek-style): + y = W2( act(W1 x) * (W3 x) ) + • "mlp" (NemotronH-style 2-layer MLP): + y = W_down( act(W_up x) ) + act_fn: + Elementwise activation applied inside the expert MLP. + Supported: "silu" (default), "relu2" (ReLU then square). """ - def make_fp4_mlp(i): - def mlp(inp): - if inp.shape[0] == 0: - return torch.zeros_like(inp) - gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear( - inp, - w1_weight[i], - bias=None, - input_scale=w1_input_scale[i], - weight_scale=w1_weight_scale[i], - alpha=w1_alpha[i], - ) - up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear( - inp, - w3_weight[i], - bias=None, - input_scale=w3_input_scale[i], - weight_scale=w3_weight_scale[i], - alpha=w3_alpha[i], - ) - prod = F.silu(gate_out) * up_out - return torch.ops.auto_deploy.torch_quant_nvfp4_linear( - prod, - w2_weight[i], - bias=None, - input_scale=w2_input_scale[i], - weight_scale=w2_weight_scale[i], - alpha=w2_alpha[i], - ) + act_fn = _resolve_activation(act_fn) + style = mlp_style.lower() - return mlp + if style == "gated_mlp": + + def make_fp4_mlp(i): + def mlp(inp): + if inp.shape[0] == 0: + return torch.zeros_like(inp) + gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear( + inp, + w1_weight[i], + bias=None, + input_scale=w1_input_scale[i], + weight_scale=w1_weight_scale[i], + alpha=w1_alpha[i], + ) + up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear( + inp, + w3_weight[i], + bias=None, + input_scale=w3_input_scale[i], + weight_scale=w3_weight_scale[i], + alpha=w3_alpha[i], + ) + prod = act_fn(gate_out) * up_out + return torch.ops.auto_deploy.torch_quant_nvfp4_linear( + prod, + w2_weight[i], + bias=None, + input_scale=w2_input_scale[i], + weight_scale=w2_weight_scale[i], + alpha=w2_alpha[i], + ) + + return mlp + + mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))] + + elif style == "mlp": + + def make_fp4_mlp(i): + def mlp(inp): + if inp.shape[0] == 0: + return torch.zeros_like(inp) + up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear( + inp, + w1_weight[i], + bias=None, + input_scale=w1_input_scale[i], + weight_scale=w1_weight_scale[i], + alpha=w1_alpha[i], + ) + return torch.ops.auto_deploy.torch_quant_nvfp4_linear( + act_fn(up_out), + w2_weight[i], + bias=None, + input_scale=w2_input_scale[i], + weight_scale=w2_weight_scale[i], + alpha=w2_alpha[i], + ) + + return mlp + + mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))] + + else: + raise ValueError(f"Unknown mlp_style '{mlp_style}'. 
Use 'gated_mlp' or 'mlp'.") - mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))] return _template_moe(x, selected_experts, routing_weights, mlps) @@ -381,6 +494,8 @@ def torch_quant_nvfp4_moe_fake( w1_alpha: List[torch.Tensor], w2_alpha: List[torch.Tensor], w3_alpha: List[torch.Tensor], + mlp_style: str = "gated_mlp", + act_fn: str = "silu", ) -> torch.Tensor: return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py index 625e588a13..ff18cc5736 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py @@ -159,6 +159,130 @@ def fused_mlp_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +@triton.jit +def fused_mlp_moe_kernel_w8a8( + # Pointers to matrices (A in FP8, B in FP8) + a_ptr, + b_ptr, + c_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # Strides + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Scale pointers + a_scale_ptr, + b_scale_ptr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + token_id_mask = offs_token_id < EM + offs_token = tl.load( + sorted_token_ids_ptr + offs_token_id, mask=token_id_mask, other=num_valid_tokens + ) + token_mask = offs_token < num_valid_tokens + + # Clamp offs_token to valid range to avoid out-of-bounds pointer arithmetic + # Padding tokens have value >= num_valid_tokens and will be masked out + # Clamp to last valid token instead of 0 to avoid cache/memory issues + max_valid_token = num_valid_tokens - 1 + offs_token_clamped = tl.where(token_mask, offs_token, max_valid_token) + + # Expert id for this block (one expert per M-tile) + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + if off_experts == -1: + _write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token_clamped, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, + ) + return + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token_clamped[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) + + # Load tensor-wise scales before loop + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, + 
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0, + ) + # Use acc= for FP8 fast accumulation (matches vLLM) + accumulator = tl.dot(a, b, acc=accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # Apply scales after K-loop + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token_clamped, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token_clamped[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + def _default_kernel_config(M: int, E: int, N: int, K: int, top_k: int) -> dict: if M <= E: return { @@ -245,12 +369,15 @@ def _invoke_kernel( topk_weights: torch.Tensor | None, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, # Changed to tensor for CUDA graph compatibility + num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, top_k: int, config: dict, compute_type, + a_scale: torch.Tensor | None = None, + b_scale: torch.Tensor | None = None, ): + """Unified kernel launcher for both unquantized and FP8 W8A8 MoE kernels.""" assert B.ndim == 3 and C.ndim == 3 EM = sorted_token_ids.numel() if EM == 0: @@ -268,7 +395,7 @@ def _invoke_kernel( ) num_tokens = A.size(0) * top_k - fused_mlp_moe_kernel[_grid]( + common_args = [ A, B, C, @@ -287,112 +414,96 @@ def _invoke_kernel( B.stride(1), C.stride(1), C.stride(2), - BLOCK_SIZE_M=config["BLOCK_SIZE_M"], - BLOCK_SIZE_N=config["BLOCK_SIZE_N"], - BLOCK_SIZE_K=config["BLOCK_SIZE_K"], - GROUP_SIZE_M=config["GROUP_SIZE_M"], - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - num_warps=config["num_warps"], - num_stages=config["num_stages"], - ) + ] + + common_kwargs = { + "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": config["GROUP_SIZE_M"], + "MUL_ROUTED_WEIGHT": mul_routed_weight, + "top_k": top_k, + "compute_type": compute_type, + "num_warps": config["num_warps"], + "num_stages": config["num_stages"], + } + + if a_scale is not None and b_scale is not None: + # FP8 W8A8 path + fused_mlp_moe_kernel_w8a8[_grid](*common_args, a_scale, b_scale, **common_kwargs) + else: + # Unquantized path + fused_mlp_moe_kernel[_grid](*common_args, **common_kwargs) -def fused_mlp_relu2_unquantized( - hidden_states: torch.Tensor, # [M, H] - w_up: torch.Tensor, # [E, I, H] - w_down: torch.Tensor, # [E, H, I] - topk_ids: torch.Tensor, # [M, top_k] - topk_weights: torch.Tensor, # [M, top_k] - *, +def _get_compute_type(dtype: torch.dtype): + """Get Triton compute type from torch dtype.""" + if dtype == torch.bfloat16: + return tl.bfloat16 + elif dtype == torch.float16: + return tl.float16 + elif dtype == torch.float32: + return tl.float32 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + +def _fused_moe_mlp_relu2( + hidden_states: torch.Tensor, + w_up: torch.Tensor, + w_down: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, apply_router_weight_on_input: bool = False, ) -> torch.Tensor: - """ - Fast unquantized MoE MLP with ReLU^2 activation between two per-expert GEMMs. 
- - Requirements: - - w_up: (E, I, H) with last dim contiguous - - w_down: (E, H, I) with last dim contiguous - - hidden_states: (M, H), topk_ids/topk_weights: (M, top_k) - """ - assert hidden_states.ndim == 2 - assert w_up.ndim == 3 and w_down.ndim == 3 - assert topk_ids.ndim == 2 and topk_weights.ndim == 2 + """Fused MoE 2-layer MLP with ReLU^2 activation using Triton.""" M, H = hidden_states.shape - E, inter_size, H_up = w_up.shape - E2, H_down, inter_size2 = w_down.shape - assert E == E2 and H == H_up and H == H_down and inter_size == inter_size2 + E, inter_size, _ = w_up.shape top_k = topk_ids.shape[1] - # Ensure memory layout compatible with kernel expectations - A = hidden_states.contiguous() - B1 = w_up.contiguous() # (E, I, H) - B2 = w_down.contiguous() # (E, H, I) - - # Kernel config (use a single BLOCK_SIZE_M for both GEMMs) config = _default_kernel_config(M, E, inter_size, H, top_k) - - # Token routing packing (group-by-expert, pad to BLOCK_SIZE_M) sorted_token_ids, expert_ids, num_tokens_post_padded = _pack_routed_tokens( - topk_ids, - M, - E, - top_k, - config["BLOCK_SIZE_M"], + topk_ids, M, E, top_k, config["BLOCK_SIZE_M"] ) - # Workspaces - cache1 = A.new_empty((M, top_k, inter_size)) - cache2 = A.new_empty((M * top_k, inter_size)) - cache3 = A.new_empty((M, top_k, H)) + cache1 = hidden_states.new_empty((M, top_k, inter_size)) + cache3 = hidden_states.new_empty((M, top_k, H)) + compute_type = _get_compute_type(hidden_states.dtype) - # Compute type - if A.dtype == torch.bfloat16: - compute_type = tl.bfloat16 - elif A.dtype == torch.float16: - compute_type = tl.float16 - elif A.dtype == torch.float32: - compute_type = tl.float32 - else: - raise ValueError(f"Unsupported dtype for hidden_states: {A.dtype}") - - # GEMM 1: X @ W_up^T → cache1 (no routing weights here) + # GEMM 1: hidden @ w_up^T _invoke_kernel( - A, - B1, + hidden_states.contiguous(), + w_up.contiguous(), cache1, None, sorted_token_ids, expert_ids, num_tokens_post_padded, - mul_routed_weight=False, - top_k=top_k, - config=config, - compute_type=compute_type, + False, + top_k, + config, + compute_type, ) - # Activation (ReLU^2) without gating/multiplication - cache2 = torch.square(F.relu(cache1.view(-1, inter_size))) + # Activation: ReLU^2 + act = torch.square(F.relu(cache1.view(-1, inter_size))) - # GEMM 2: Act(cache1) @ W_down^T → cache3 (apply routing weights) + # GEMM 2: act @ w_down^T _invoke_kernel( - cache2, - B2, + act, + w_down.contiguous(), cache3, topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, - mul_routed_weight=not apply_router_weight_on_input, - top_k=1, # ensure offs_token maps to flattened rows (m*top_k + n) - config=config, - compute_type=compute_type, + not apply_router_weight_on_input, + 1, + config, + compute_type, ) - # Sum across top-k per token - out = cache3.sum(dim=1) - return out + return cache3.sum(dim=1) @torch.library.custom_op("auto_deploy::triton_moe_fused", mutates_args=()) @@ -403,36 +514,13 @@ def triton_fused_moe( w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, ) -> torch.Tensor: - """ - Triton implementation of the Fused MOE ops for Nemotron-6 models - - Each expert has two weight matrices and squared ReLU activation between them. 
- """ + """Triton unquantized MoE with 2-layer MLP and ReLU^2 activation.""" x_shape = x.shape - hidden_size = x_shape[-1] - x2d = x.view(-1, hidden_size) + x2d = x.view(-1, x_shape[-1]) + topk_ids = selected_experts.to(torch.int32).contiguous() + topk_weights = routing_weights.to(torch.float32).contiguous() - routing_weights = routing_weights.to(torch.float32) - selected_experts = selected_experts.to(torch.int32) - - # Expect selected_experts/routing_weights to be [M, top_k] - topk_ids = selected_experts.contiguous() - topk_weights = routing_weights.contiguous() - assert topk_ids.dim() == 2 and topk_weights.dim() == 2, ( - f"Expected 2D routing tensors, got {topk_ids.shape} and {topk_weights.shape}" - ) - assert topk_ids.shape[0] == x2d.shape[0], ( - f"Token count mismatch: tokens={x2d.shape[0]} ids={topk_ids.shape[0]}" - ) - - out2d = fused_mlp_relu2_unquantized( - x2d, - w1_stacked_weight, - w2_stacked_weight, - topk_ids, - topk_weights, - apply_router_weight_on_input=False, - ) + out2d = _fused_moe_mlp_relu2(x2d, w1_stacked_weight, w2_stacked_weight, topk_ids, topk_weights) return out2d.view(x_shape) @@ -445,3 +533,120 @@ def triton_fused_moe( w2_stacked_weight: torch.Tensor, ) -> torch.Tensor: return torch.empty_like(x) + + +def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Quantize tensor to FP8 with clamping (matches torch_quant_fp8_linear).""" + FP8_MIN = torch.finfo(torch.float8_e4m3fn).min + FP8_MAX = torch.finfo(torch.float8_e4m3fn).max + return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) + + +@torch.library.custom_op("auto_deploy::triton_quant_fp8_moe", mutates_args=()) +def triton_quant_fp8_moe( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_weight: torch.Tensor, # [E, I, H] stacked FP8 weights + w2_weight: torch.Tensor, # [E, H, I] stacked FP8 weights + w3_weight: torch.Tensor, # unused for mlp style + w1_input_scale: torch.Tensor, # [E] stacked input scales + w2_input_scale: torch.Tensor, # [E] stacked input scales + w3_input_scale: torch.Tensor, # unused + w1_weight_scale: torch.Tensor, # [E] stacked weight scales + w2_weight_scale: torch.Tensor, # [E] stacked weight scales + w3_weight_scale: torch.Tensor, # unused + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + """Triton FP8 W8A8 MoE with 2-layer MLP and ReLU^2 activation.""" + if mlp_style != "mlp": + raise NotImplementedError("triton_quant_fp8_moe currently supports mlp_style=='mlp' only") + + x_shape = x.shape + x2d = x.view(-1, x_shape[-1]) + topk_ids = selected_experts.to(torch.int32).contiguous() + topk_weights = routing_weights.to(torch.float32).contiguous() + + # Weights are already stacked [E, ...] 
- just ensure contiguous and extract scales + w1_q = w1_weight.contiguous() + w2_q = w2_weight.contiguous() + a1_scale = w1_input_scale[0].to(torch.float32).reshape(1).contiguous() + a2_scale = w2_input_scale[0].to(torch.float32).reshape(1).contiguous() + b1_scale = w1_weight_scale.to(torch.float32).contiguous() + b2_scale = w2_weight_scale.to(torch.float32).contiguous() + + # Setup + M, H = x2d.shape + E, inter_size, _ = w1_q.shape + top_k = topk_ids.shape[1] + config = _default_kernel_config(M, E, inter_size, H, top_k) + sorted_token_ids, expert_ids, num_tokens_post_padded = _pack_routed_tokens( + topk_ids, M, E, top_k, config["BLOCK_SIZE_M"] + ) + compute_type = _get_compute_type(x2d.dtype) + + # Quantize input and allocate caches + x_a8 = _quantize_fp8(x2d, a1_scale) + cache1 = x2d.new_empty((M, top_k, inter_size)) + cache3 = x2d.new_empty((M, top_k, H)) + + # GEMM 1: FP8 input @ FP8 w_up^T → BF16 + _invoke_kernel( + x_a8, + w1_q, + cache1, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + top_k, + config, + compute_type, + a_scale=a1_scale, + b_scale=b1_scale, + ) + + # Activation: ReLU^2, then quantize + act = torch.square(F.relu(cache1.view(-1, inter_size))) + act_a8 = _quantize_fp8(act, a2_scale) + + # GEMM 2: FP8 activation @ FP8 w_down^T → BF16 + _invoke_kernel( + act_a8, + w2_q, + cache3, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type, + a_scale=a2_scale, + b_scale=b2_scale, + ) + + return cache3.sum(dim=1).view(x_shape) + + +@triton_quant_fp8_moe.register_fake +def triton_quant_fp8_moe( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_weight: torch.Tensor, + w2_weight: torch.Tensor, + w3_weight: torch.Tensor, + w1_input_scale: torch.Tensor, + w2_input_scale: torch.Tensor, + w3_input_scale: torch.Tensor, + w1_weight_scale: torch.Tensor, + w2_weight_scale: torch.Tensor, + w3_weight_scale: torch.Tensor, + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index 97fbc57108..a972d4d77b 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -50,7 +50,6 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int: fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0) new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}" # Triton fused MoE op supports mlp only. - replacement_op = torch.ops.auto_deploy.triton_moe_fused else: @@ -567,6 +566,176 @@ class MatchNVFP4MoePattern(MatchMoePattern): return ["input_scale", "weight_scale", "alpha"] +def _stack_fp8_moe_weights(gm: GraphModule) -> int: + """ + Stack per-expert FP8 weights and scales by materializing stacked tensors as parameters. + This is fast because we directly stack the tensor values (not graph nodes). + Similar to _insert_fused_moe_ops but for quantized MoE. 
+ """ + fused_key_counter = 0 + graph = gm.graph + + for node in graph.nodes: + if not is_op(node, torch.ops.auto_deploy.torch_quant_fp8_moe): + continue + + # Extract weight and scale lists from args + try: + ( + hidden_states, + selected_experts, + routing_weights, + w1_list, + w2_list, + w3_list, + w1_input_scale, + w2_input_scale, + w3_input_scale, + w1_weight_scale, + w2_weight_scale, + w3_weight_scale, + ) = extract_op_args( + node, + "x", + "selected_experts", + "routing_weights", + "w1_weight", + "w2_weight", + "w3_weight", + "w1_input_scale", + "w2_input_scale", + "w3_input_scale", + "w1_weight_scale", + "w2_weight_scale", + "w3_weight_scale", + ) + except Exception: + continue + + # Helper to get parameter or buffer + def get_param_or_buffer(target): + """Get parameter or buffer by target name.""" + try: + return gm.get_parameter(target) + except AttributeError: + # It's a buffer, not a parameter + parts = target.rsplit(".", 1) + if len(parts) == 2: + mod = gm.get_submodule(parts[0]) + return getattr(mod, parts[1]) + else: + return getattr(gm, target) + + # Stack the actual tensor values (fast, like in quantize_moe.py) + w1_stacked = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0) + w2_stacked = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0) + w3_stacked = ( + torch.stack([gm.get_parameter(n.target) for n in w3_list], dim=0) + if w3_list + else torch.empty(0, device=w1_stacked.device, dtype=w1_stacked.dtype) + ) + + # Scales are buffers, not parameters + w1_input_scale_stacked = torch.stack( + [get_param_or_buffer(n.target) for n in w1_input_scale], dim=0 + ) + w2_input_scale_stacked = torch.stack( + [get_param_or_buffer(n.target) for n in w2_input_scale], dim=0 + ) + w3_input_scale_stacked = ( + torch.stack([get_param_or_buffer(n.target) for n in w3_input_scale], dim=0) + if w3_input_scale + else torch.empty( + 0, device=w1_input_scale_stacked.device, dtype=w1_input_scale_stacked.dtype + ) + ) + + w1_weight_scale_stacked = torch.stack( + [get_param_or_buffer(n.target) for n in w1_weight_scale], dim=0 + ) + w2_weight_scale_stacked = torch.stack( + [get_param_or_buffer(n.target) for n in w2_weight_scale], dim=0 + ) + w3_weight_scale_stacked = ( + torch.stack([get_param_or_buffer(n.target) for n in w3_weight_scale], dim=0) + if w3_weight_scale + else torch.empty( + 0, device=w1_weight_scale_stacked.device, dtype=w1_weight_scale_stacked.dtype + ) + ) + + # Register stacked tensors as new parameters + new_key_w1 = f"quant_moe_w1_stacked_{fused_key_counter}" + new_key_w2 = f"quant_moe_w2_stacked_{fused_key_counter}" + new_key_w3 = f"quant_moe_w3_stacked_{fused_key_counter}" + new_key_w1_input_scale = f"quant_moe_w1_input_scale_stacked_{fused_key_counter}" + new_key_w2_input_scale = f"quant_moe_w2_input_scale_stacked_{fused_key_counter}" + new_key_w3_input_scale = f"quant_moe_w3_input_scale_stacked_{fused_key_counter}" + new_key_w1_weight_scale = f"quant_moe_w1_weight_scale_stacked_{fused_key_counter}" + new_key_w2_weight_scale = f"quant_moe_w2_weight_scale_stacked_{fused_key_counter}" + new_key_w3_weight_scale = f"quant_moe_w3_weight_scale_stacked_{fused_key_counter}" + + fused_key_counter += 1 + + # Register as parameters (not buffers, to match the original per-expert params) + gm.register_parameter(new_key_w1, torch.nn.Parameter(w1_stacked, requires_grad=False)) + gm.register_parameter(new_key_w2, torch.nn.Parameter(w2_stacked, requires_grad=False)) + gm.register_parameter(new_key_w3, torch.nn.Parameter(w3_stacked, requires_grad=False)) + 
gm.register_parameter( + new_key_w1_input_scale, torch.nn.Parameter(w1_input_scale_stacked, requires_grad=False) + ) + gm.register_parameter( + new_key_w2_input_scale, torch.nn.Parameter(w2_input_scale_stacked, requires_grad=False) + ) + gm.register_parameter( + new_key_w3_input_scale, torch.nn.Parameter(w3_input_scale_stacked, requires_grad=False) + ) + gm.register_parameter( + new_key_w1_weight_scale, + torch.nn.Parameter(w1_weight_scale_stacked, requires_grad=False), + ) + gm.register_parameter( + new_key_w2_weight_scale, + torch.nn.Parameter(w2_weight_scale_stacked, requires_grad=False), + ) + gm.register_parameter( + new_key_w3_weight_scale, + torch.nn.Parameter(w3_weight_scale_stacked, requires_grad=False), + ) + + # Create new node with get_attr for stacked parameters + with graph.inserting_before(node): + new_node = graph.call_function( + torch.ops.auto_deploy.triton_quant_fp8_moe, + args=( + hidden_states, + selected_experts, + routing_weights, + graph.get_attr(new_key_w1), + graph.get_attr(new_key_w2), + graph.get_attr(new_key_w3), + graph.get_attr(new_key_w1_input_scale), + graph.get_attr(new_key_w2_input_scale), + graph.get_attr(new_key_w3_input_scale), + graph.get_attr(new_key_w1_weight_scale), + graph.get_attr(new_key_w2_weight_scale), + graph.get_attr(new_key_w3_weight_scale), + ), + kwargs=node.kwargs, + ) + + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + # Clean up after processing all nodes + # eliminate_dead_code will remove unused get_attr nodes, then delete_all_unused_submodules + # will remove the parameters/buffers that are no longer referenced + gm.graph.eliminate_dead_code() + gm.delete_all_unused_submodules() + + return fused_key_counter + + @TransformRegistry.register("fuse_moe") class FuseMoe(BaseTransform): """ @@ -588,3 +757,29 @@ class FuseMoe(BaseTransform): skipped=False, num_matches=fused_key_counter, is_clean=False, has_valid_shapes=False ) return gm, info + + +@TransformRegistry.register("fuse_fp8_moe") +class FuseFP8Moe(BaseTransform): + """ + Stack per-expert FP8 MoE weights and scales to avoid runtime stacking overhead. + This runs after weights are loaded, similar to FuseMoe for unquantized MoE. 
+ """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + with cuda_memory_tracker(): + fused_key_counter = _stack_fp8_moe_weights(gm) + + info = TransformInfo( + skipped=(fused_key_counter == 0), + num_matches=fused_key_counter, + is_clean=False, + has_valid_shapes=False, + ) + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py index d25ad7c270..a881c72fd7 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py @@ -16,11 +16,6 @@ from .quantization import ( Quantization, ) -quantized_moe_op_map = { - "FP8": torch.ops.auto_deploy.torch_quant_fp8_moe, - "NVFP4": torch.ops.auto_deploy.torch_quant_nvfp4_moe, -} - def _quantize_moe_node( gm: GraphModule, @@ -92,11 +87,33 @@ def _quantize_moe_node( s1, s2, s3 = collect_scales(idx) args.extend([s1, s2, s3]) + # Extract mlp_style and act_fn from the original node + # These can be in args[6:] or in kwargs + mlp_style = "gated_mlp" # default + act_fn = "silu" # default + + if len(node.args) > 6: + mlp_style = node.args[6] + elif "mlp_style" in node.kwargs: + mlp_style = node.kwargs["mlp_style"] + + if len(node.args) > 7: + act_fn = node.args[7] + elif "act_fn" in node.kwargs: + act_fn = node.kwargs["act_fn"] + + # Prepare kwargs for the quantized op + kwargs = { + "mlp_style": mlp_style, + "act_fn": act_fn, + } + # Replace the current node with the quantized version with gm.graph.inserting_after(node): new_node = gm.graph.call_function( quantized_op, args=tuple(args), + kwargs=kwargs, ) node.replace_all_uses_with(new_node) gm.graph.erase_node(node) diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index c2d53a64e6..0d8649ac92 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -140,9 +140,6 @@ class TestNemotronH(LlmapiAccuracyTestHarness): @pytest.mark.skip_less_device_memory(32000) @pytest.mark.parametrize("enable_chunked_prefill", [False, True]) def test_auto_dtype(self, enable_chunked_prefill): - if enable_chunked_prefill: - pytest.skip( - "see https://github.com/NVIDIA/TensorRT-LLM/issues/8272") kwargs = self.get_default_kwargs(enable_chunked_prefill) sampling_params = self.get_default_sampling_params() with AutoDeployLLM(model=self.MODEL_PATH, @@ -152,3 +149,49 @@ class TestNemotronH(LlmapiAccuracyTestHarness): task.evaluate(llm, sampling_params=sampling_params) task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + + +class TestNemotronMOE(LlmapiAccuracyTestHarness): + MODEL_NAME = "nvidia/Nemotron-MOE" + MODEL_PATH = f"{llm_models_root()}/Nemotron-MOE/" + + def get_default_kwargs(self): + return { + "skip_tokenizer_init": False, + "trust_remote_code": True, + # SSMs do not support cache reuse. 
+ "kv_cache_config": { + "enable_block_reuse": False + }, + # Keep max_batch_size as in the PyTorch test to avoid OOM + "max_batch_size": 128, + # Model context length is 8K + "max_seq_len": 8192, + # Set explicitly to match default build_config behavior + "max_num_tokens": 8192, + "skip_loading_weights": False, + "compile_backend": "torch-cudagraph", + "free_mem_ratio": 0.7, + "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128], + } + + def get_default_sampling_params(self): + eos_id = -1 + beam_width = 1 + return SamplingParams(end_id=eos_id, + pad_id=eos_id, + n=beam_width, + use_beam_search=beam_width > 1) + + @pytest.mark.skip_less_device_memory(32000) + def test_auto_dtype(self): + pytest.skip("Nemotron-MOE is not in CI yet") + kwargs = self.get_default_kwargs() + sampling_params = self.get_default_sampling_params() + with AutoDeployLLM(model=self.MODEL_PATH, + tokenizer=self.MODEL_PATH, + **kwargs) as llm: + task = MMLU(self.MODEL_NAME) + task.evaluate(llm, sampling_params=sampling_params) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py index cfbba5bae2..64207513c0 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py @@ -1,5 +1,6 @@ import pytest import torch +from utils.util import skip_pre_hopper import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 from tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.load_moe_align import moe_align_block_size @@ -215,3 +216,147 @@ def test_moe_align_kernel_groups_tokens_by_expert_and_block_padding(): ref_counts_all = torch.bincount(ref_sorted_used.cpu().to(torch.int64), minlength=T + 1) assert torch.all(ref_counts_all == counts_all) + + +@skip_pre_hopper +def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(): + """Test triton_quant_fp8_moe against torch_quant_fp8_moe reference.""" + torch.manual_seed(0) + + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for triton_quant_fp8_moe test") + device = "cuda" + dtype = torch.bfloat16 + + M = 32 # tokens + HIDDEN_SIZE = 16 # Must be multiple of 16 for FP8 linear + INTERMEDIATE_SIZE = 32 # Must be multiple of 16 for FP8 linear + E = 4 # experts + top_k = 2 + + # Use small normalized values to avoid FP8 range issues + x = torch.randn(M, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 + + # Create BF16 weights for each expert (normalized to small values) + w_up_list = [ + torch.randn(INTERMEDIATE_SIZE, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 + for _ in range(E) + ] + w_down_list = [ + torch.randn(HIDDEN_SIZE, INTERMEDIATE_SIZE, device=device, dtype=dtype) * 0.1 + for _ in range(E) + ] + + # Stack weights [E, ...] 
+ w_up_stacked = torch.stack(w_up_list, dim=0).contiguous() # [E, I, H] + w_down_stacked = torch.stack(w_down_list, dim=0).contiguous() # [E, H, I] + + # Quantize weights to FP8 with per-expert scales + FP8_MIN = torch.finfo(torch.float8_e4m3fn).min + FP8_MAX = torch.finfo(torch.float8_e4m3fn).max + + # Per-expert weight scales (use max absolute value per expert) + w1_weight_scale = torch.tensor( + [w_up_stacked[e].abs().max().item() / FP8_MAX for e in range(E)], + device=device, + dtype=torch.float32, + ) + w2_weight_scale = torch.tensor( + [w_down_stacked[e].abs().max().item() / FP8_MAX for e in range(E)], + device=device, + dtype=torch.float32, + ) + + # Quantize weights and stack + w1_fp8_list = [ + (w_up_stacked[e] / w1_weight_scale[e]).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) + for e in range(E) + ] + w2_fp8_list = [ + (w_down_stacked[e] / w2_weight_scale[e]).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) + for e in range(E) + ] + w1_fp8_stacked = torch.stack(w1_fp8_list).contiguous() + w2_fp8_stacked = torch.stack(w2_fp8_list).contiguous() + + # Input scales (tensor-wise, replicated per expert for interface compatibility) + x_scale = x.abs().max().item() / FP8_MAX + w1_input_scale_tensor = torch.full((E,), x_scale, device=device, dtype=torch.float32) + + # Compute intermediate activation scale by simulating first GEMM + ReLU^2 + # This ensures w2_input_scale matches the actual activation magnitude + with torch.no_grad(): + # Simulate the first GEMM: quantize input, do FP8 matmul, apply ReLU^2 + x_q = (x / w1_input_scale_tensor[0]).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) + # Dequantize and compute output for a sample + x_dq = x_q[:8].to(torch.float32) * w1_input_scale_tensor[0].item() + w1_dq = w1_fp8_stacked[0].to(torch.float32) * w1_weight_scale[0].item() + sample_out = torch.nn.functional.linear(x_dq.to(dtype), w1_dq.to(dtype)) + sample_act = torch.square(torch.nn.functional.relu(sample_out)) + intermediate_scale = sample_act.abs().max().item() / FP8_MAX + # Ensure scale is not too small + intermediate_scale = max(intermediate_scale, 1e-6) + + w2_input_scale_tensor = torch.full((E,), intermediate_scale, device=device, dtype=torch.float32) + + # Convert scales to lists for torch_quant_fp8_moe reference + w1_input_scale_list = [w1_input_scale_tensor[0].clone() for _ in range(E)] + w2_input_scale_list = [w2_input_scale_tensor[0].clone() for _ in range(E)] + w1_weight_scale_list = [w1_weight_scale[e].clone() for e in range(E)] + w2_weight_scale_list = [w2_weight_scale[e].clone() for e in range(E)] + + # Dummy w3 tensors (unused for mlp style) + w3_fp8_list = [torch.empty((1, 1), device=device, dtype=torch.float8_e4m3fn) for _ in range(E)] + w3_fp8_stacked = torch.stack(w3_fp8_list).contiguous() + w3_input_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)] + w3_input_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32) + w3_weight_scale_list = [torch.ones((), device=device, dtype=torch.float32) for _ in range(E)] + w3_weight_scale_tensor = torch.ones((E,), device=device, dtype=torch.float32) + + # Create controlled routing to ensure even token distribution across experts + selected_experts = torch.zeros((M, top_k), dtype=torch.int64, device=device) + for i in range(M): + # Distribute tokens evenly: token i goes to experts (i % E) and ((i+1) % E) + selected_experts[i, 0] = i % E + selected_experts[i, 1] = (i + 1) % E + + # Create equal routing weights + routing_weights = torch.ones((M, top_k), device=device, 
dtype=torch.float32) / top_k + + # Triton FP8 quantized MoE (uses stacked tensors) + out_triton = torch.ops.auto_deploy.triton_quant_fp8_moe( + x, + selected_experts.to(torch.int32), + routing_weights, + w1_fp8_stacked, + w2_fp8_stacked, + w3_fp8_stacked, + w1_input_scale_tensor, + w2_input_scale_tensor, + w3_input_scale_tensor, + w1_weight_scale, + w2_weight_scale, + w3_weight_scale_tensor, + mlp_style="mlp", + act_fn="relu2", + ) + + # Reference: Torch quantized FP8 MoE (uses lists of tensors and scales) + out_torch = torch.ops.auto_deploy.torch_quant_fp8_moe( + x, + selected_experts, + routing_weights, + w1_weight=w1_fp8_list, + w2_weight=w2_fp8_list, + w3_weight=w3_fp8_list, + w1_input_scale=w1_input_scale_list, + w2_input_scale=w2_input_scale_list, + w3_input_scale=w3_input_scale_list, + w1_weight_scale=w1_weight_scale_list, + w2_weight_scale=w2_weight_scale_list, + w3_weight_scale=w3_weight_scale_list, + mlp_style="mlp", + act_fn="relu2", + ) + + torch.testing.assert_close(out_triton, out_torch, rtol=1e-2, atol=1e-2)
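
Supplementary reference sketches (illustrative only, not part of the patch):

The two mlp_style variants accepted by the extended torch_quant_fp8_moe / torch_quant_nvfp4_moe ops compute the per-expert functions documented in their docstrings above. A minimal unquantized PyTorch sketch of both variants (expert_mlp_reference is a hypothetical helper introduced here for illustration):

import torch
import torch.nn.functional as F

def expert_mlp_reference(x, w1, w2, w3=None, mlp_style="gated_mlp", act="silu"):
    # Elementwise activation: "silu" or "relu2" (ReLU then square), mirroring the ops' docstrings.
    act_fn = F.silu if act == "silu" else (lambda t: torch.square(F.relu(t)))
    if mlp_style == "gated_mlp":
        # Mixtral/DeepSeek-style gated MLP: y = W2( act(W1 x) * (W3 x) )
        return F.linear(act_fn(F.linear(x, w1)) * F.linear(x, w3), w2)
    # NemotronH-style 2-layer MLP: y = W_down( act(W_up x) ); w3 is unused here.
    return F.linear(act_fn(F.linear(x, w1)), w2)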
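
The new code paths call _resolve_activation(act_fn), which is defined outside this hunk. A plausible minimal sketch, assuming it simply maps the supported activation names to elementwise callables:

from typing import Callable

import torch
import torch.nn.functional as F

def _resolve_activation(name: str) -> Callable[[torch.Tensor], torch.Tensor]:
    # Sketch only; the real helper lives in torch_moe.py and may differ in detail.
    name = name.lower()
    if name == "silu":
        return F.silu
    if name == "relu2":
        return lambda t: torch.square(F.relu(t))  # ReLU followed by squaring
    raise ValueError(f"Unsupported activation '{name}'. Use 'silu' or 'relu2'.")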
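
fused_mlp_moe_kernel_w8a8 uses tensor-wise scaling: activations and weights are quantized to FP8 (see _quantize_fp8), products are accumulated in FP32, and both scales are applied once after the K-loop. A dequantized single-GEMM reference for checking that behavior (w8a8_gemm_reference is illustrative, not an API in the tree):

import torch

def w8a8_gemm_reference(x: torch.Tensor, w_fp8: torch.Tensor,
                        a_scale: torch.Tensor, b_scale: torch.Tensor) -> torch.Tensor:
    # y ~= (quant(x) @ w_fp8^T) * a_scale * b_scale, matching the kernel's post-loop rescale.
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_q = (x / a_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    acc = x_q.to(torch.float32) @ w_fp8.to(torch.float32).t()  # FP32 accumulation
    return (acc * a_scale * b_scale).to(x.dtype)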