diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
index 76368412c4..efbbac39e3 100644
--- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
@@ -226,6 +226,81 @@ def _register_fake():
         return (input.new_empty(output_shape, dtype=torch.uint8),
                 global_scale.new_empty(scale_shape, dtype=torch.uint8))
 
+    @torch.library.register_fake("trtllm::mxfp8_quantize")
+    def _(
+        input: torch.Tensor,
+        swizzled_layout: bool = True,
+        alignment: int = 32,
+    ):
+        SF_VEC_SIZE = 32
+
+        def pad_up(x, m: int):
+            return (x + m - 1) // m * m
+
+        m_val = 1
+        for d in input.shape[:-1]:
+            m_val = m_val * d
+
+        k = input.shape[-1]
+        padded_k = pad_up(k, alignment)
+
+        out_shape = list(input.shape)
+        out_shape[-1] = padded_k
+
+        # Output tensor: float8_e4m3fn, last dim padded to alignment
+        val_mxfp8 = input.new_empty(out_shape, dtype=torch.float8_e4m3fn)
+
+        # Scale tensor: 1D uint8, size depends on swizzled vs linear layout
+        cols = padded_k // SF_VEC_SIZE
+        if swizzled_layout:
+            sf_size = pad_up(m_val, 128) * pad_up(cols, 4)
+        else:
+            sf_size = m_val * cols
+
+        scale_fp8_sf = input.new_empty((sf_size, ), dtype=torch.uint8)
+        return val_mxfp8, scale_fp8_sf
+
+    @torch.library.register_fake("trtllm::mxe4m3_mxe2m1_block_scale_moe_runner")
+    def _(
+        routing_logits: Optional[torch.Tensor],
+        routing_bias: Optional[torch.Tensor],
+        hidden_states: torch.Tensor,
+        hidden_states_scale: torch.Tensor,
+        gemm1_weights: torch.Tensor,
+        gemm1_weights_scale: torch.Tensor,
+        gemm1_bias: Optional[torch.Tensor],
+        gemm1_alpha: Optional[torch.Tensor],
+        gemm1_beta: Optional[torch.Tensor],
+        gemm1_clamp_limit: Optional[torch.Tensor],
+        gemm2_weights: torch.Tensor,
+        gemm2_weights_scale: torch.Tensor,
+        gemm2_bias: Optional[torch.Tensor],
+        num_experts: int,
+        top_k: int,
+        n_group: Optional[int],
+        topk_group: Optional[int],
+        intermediate_size: int,
+        valid_hidden_size: Optional[int],
+        valid_intermediate_size: Optional[int],
+        local_expert_offset: int,
+        local_num_experts: int,
+        routed_scaling_factor: Optional[float],
+        routing_method_type: int,
+        act_type: int,
+        topk_weights: Optional[torch.Tensor] = None,
+        topk_ids: Optional[torch.Tensor] = None,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        num_tokens = hidden_states.shape[0]
+        hidden_size = hidden_states.shape[1]
+        out_hidden_size = valid_hidden_size if valid_hidden_size is not None else hidden_size
+
+        if output is not None:
+            return output
+
+        return hidden_states.new_empty((num_tokens, out_hidden_size),
+                                       dtype=torch.bfloat16)
+
     @torch.library.register_fake("trtllm::calculate_nvfp4_global_scale")
     def _(input: torch.Tensor, tokens_per_batch: Optional[torch.Tensor]):
         return input.new_empty((*input.shape[:-1], 1), dtype=torch.float32)
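
For reviewers who want to sanity-check the shape logic without a TensorRT-LLM build, here is a minimal sketch of the same `register_fake` pattern, assuming plain PyTorch >= 2.4. The `demo_lib` namespace, its schema, and the example shapes are hypothetical stand-ins; only the padding/swizzle arithmetic mirrors the `trtllm::mxfp8_quantize` fake added in this diff. Registering the fake alone is enough for shape/dtype propagation on meta tensors, which is what `torch.compile` and AOT tracing rely on.

```python
import torch

# Hypothetical demo op (NOT part of TensorRT-LLM); only the shape/padding
# arithmetic below mirrors the trtllm::mxfp8_quantize fake in this diff.
torch.library.define(
    "demo_lib::mxfp8_quantize",
    "(Tensor input, bool swizzled_layout=True, int alignment=32)"
    " -> (Tensor, Tensor)")


@torch.library.register_fake("demo_lib::mxfp8_quantize")
def _(input, swizzled_layout=True, alignment=32):
    SF_VEC_SIZE = 32  # one 8-bit scale per 32 elements along the last dim

    def pad_up(x, m):
        return (x + m - 1) // m * m

    # Flatten all leading dims into the row count M.
    m_val = 1
    for d in input.shape[:-1]:
        m_val *= d

    padded_k = pad_up(input.shape[-1], alignment)
    out_shape = list(input.shape)
    out_shape[-1] = padded_k

    cols = padded_k // SF_VEC_SIZE
    # The swizzled layout pads the scale grid to 128-row x 4-col tiles.
    sf_size = (pad_up(m_val, 128) * pad_up(cols, 4)
               if swizzled_layout else m_val * cols)
    return (input.new_empty(out_shape, dtype=torch.float8_e4m3fn),
            input.new_empty((sf_size, ), dtype=torch.uint8))


# No real kernel is registered, yet shapes propagate on the meta device:
x = torch.empty(8, 100, device="meta", dtype=torch.bfloat16)
val, sf = torch.ops.demo_lib.mxfp8_quantize(x)
print(val.shape, val.dtype)  # torch.Size([8, 128]) torch.float8_e4m3fn
print(sf.shape)              # torch.Size([512]) = pad_up(8, 128) * pad_up(4, 4)
```

The MoE runner fake follows the same contract: when a preallocated `output` is passed it is returned as-is, and otherwise a `(num_tokens, out_hidden_size)` bfloat16 tensor is allocated, so tracing sees the kernel's real output shape either way.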