/* * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "tensorrt_llm/kernels/quantization.h" #include "tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h" #include "tensorrt_llm/runtime/torchUtils.h" #include "tensorrt_llm/thop/thUtils.h" #include #include #include #include #include #include namespace torch_ext { namespace btg = batchedGemm::trtllm::gen; using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::RoutingMethodType; using MoeRunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner; using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::computeSelectedTileN; torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional const& routing_logits, torch::optional const& routing_bias, torch::Tensor const& hidden_states, std::optional const& hidden_states_scale, torch::Tensor const& gemm1_weights, torch::Tensor const& gemm1_weights_scale, std::optional const& gemm1_bias, std::optional const& gemm1_alpha, std::optional const& gemm1_beta, std::optional const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights, torch::Tensor const& gemm2_weights_scale, std::optional const& gemm2_bias, std::optional const& output1_scale_scalar, std::optional const& output1_scale_gate_scalar, std::optional const& output2_scale_scalar, int64_t const num_experts, int64_t const top_k, std::optional const n_group, std::optional const topk_group, int64_t const intermediate_size, std::optional const valid_hidden_size, std::optional const valid_intermediate_size, int64_t const local_expert_offset, int64_t const local_num_experts, std::optional const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type, btg::Dtype const dtype, MoeRunnerType& moe_runner, int64_t moeConfigIndex, torch::optional const& topk_weights, torch::optional const& topk_ids, torch::optional const& out_tensor) { TORCH_CHECK(tensorrt_llm::common::isSM100Family(), "Only SM100f is supported by MXFP4 block scale MOE"); TORCH_CHECK(tile_tokens_dim == 8 || tile_tokens_dim == 16 || tile_tokens_dim == 32 || tile_tokens_dim == 64 || tile_tokens_dim == 128 || tile_tokens_dim == 256, "tile_tokens_dim must be 8, 16, 32, 64, 128, 256"); if (topk_ids.has_value() && topk_weights.has_value()) { TORCH_CHECK(topk_ids.value().scalar_type() == at::ScalarType::Int, "topk_ids must be int."); TORCH_CHECK(topk_weights.value().scalar_type() == at::ScalarType::BFloat16, "topk_weights must be bfloat."); TORCH_CHECK(topk_ids.value().dim() == 2, "topk_ids must be 2D."); TORCH_CHECK(topk_weights.value().dim() == 2, "topk_weights must be 2D."); TORCH_CHECK( topk_ids.value().sizes()[0] == hidden_states.sizes()[0], "topk_ids dim0 must match hidden_states dim0."); TORCH_CHECK(topk_ids.value().sizes()[1] == top_k, "topk_ids dim1 must match top_k."); TORCH_CHECK(topk_weights.value().sizes()[0] == hidden_states.sizes()[0], "topk_weights dim0 must match hidden_states dim0."); TORCH_CHECK(topk_weights.value().sizes()[1] == top_k, "topk_weights dim1 must match top_k."); } else if (routing_logits.has_value()) { if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::Float, "routing_logits must be float"); } else { TORCH_CHECK( routing_logits.value().scalar_type() == at::ScalarType::BFloat16, "routing_logits must be bfloat16"); } TORCH_CHECK(routing_logits.value().dim() == 2, "routing_logits must be 2D."); TORCH_CHECK(routing_logits.value().sizes()[1] == num_experts, "routing_logits dim1 must match num_experts."); } else { TORCH_CHECK(false, "routing_logits or (topk_ids and topk_weights) must be provided."); } if (topk_ids.has_value() && topk_weights.has_value() && routing_logits.has_value()) { TLLM_LOG_WARNING( "When logits and (topk_ids and topk_weights) are both provided, we only use (topk_ids and topk_weights)."); } if (topk_ids.has_value()) { TORCH_CHECK(topk_ids.value().sizes()[0] == hidden_states.sizes()[0], "topk_ids and hidden_states must have the same number of tokens."); } else { TORCH_CHECK(routing_logits.value().sizes()[0] == hidden_states.sizes()[0], "routing_logits and hidden_states must have the same number of tokens."); } if (routing_bias.has_value()) { TORCH_CHECK(routing_bias.value().scalar_type() == at::ScalarType::BFloat16, "routing_bias must be bfloat16."); TORCH_CHECK(routing_bias.value().dim() == 1, "routing_bias must be 1D."); TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape."); } if (n_group.has_value() && n_group.value() != 0) { TORCH_CHECK(static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3, "Routing kernel with groups implies DeepSeekV3 routing method."); TORCH_CHECK(topk_group.has_value(), "if n_group is given, topk_group must be given"); TORCH_CHECK(num_experts % n_group.value() == 0, "num_experts must be divisible by n_group"); TORCH_CHECK(top_k <= 8 && top_k > 0, "Current routing kernel (with groups) only supports top_k<=8 && top_k>0."); TORCH_CHECK(topk_group.value() <= 4 && topk_group.value() > 0, "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0."); TORCH_CHECK(topk_group.value() <= n_group.value(), "n_group must not be smaller than topk_group."); // This check ensures we have enough experts in the selected groups to handle the top_k routing TORCH_CHECK(top_k < (topk_group.value() * num_experts / n_group.value()), "top_k must be less than total number of experts in selected groups"); } else if (static_cast(routing_method_type) == RoutingMethodType::Renormalize || static_cast(routing_method_type) == RoutingMethodType::RenormalizeNaive) { TORCH_CHECK(top_k <= 10 && top_k > 0, "Current routing kernel (no groups, renormalize) only supports top_k<=10 && top_k>0."); } TORCH_CHECK(num_experts % 4 == 0, "Routing kernel expects that num_experts must be divisible by 4"); TORCH_CHECK(num_experts > top_k, "num_experts must be greater than top_k"); // If both routing inputs are provided, they must be on the same device if (routing_logits.has_value() && topk_ids.has_value()) { TORCH_CHECK( routing_logits->device() == topk_ids->device(), "routing_logits and topk_ids must be on the same device"); } tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::MoERunnerArgs args; tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::MoEWorkspace workspace; // setup args args.mDtypeElt = dtype; args.routing_logits = routing_logits.has_value() ? routing_logits.value().data_ptr() : nullptr; args.routing_bias = routing_bias.has_value() ? routing_bias.value().data_ptr() : nullptr; args.hidden_states = hidden_states.data_ptr(); args.hidden_states_scale = hidden_states_scale.has_value() ? hidden_states_scale.value().data_ptr() : nullptr; args.topk_weights = topk_weights.has_value() ? topk_weights.value().data_ptr() : nullptr; args.topk_ids = topk_ids.has_value() ? static_cast(topk_ids.value().data_ptr()) : nullptr; args.gemm1_weights = gemm1_weights.data_ptr(); args.gemm1_weights_scale = gemm1_weights_scale.data_ptr(); args.gemm2_weights = gemm2_weights.data_ptr(); args.gemm2_weights_scale = gemm2_weights_scale.data_ptr(); args.gemm1_bias = gemm1_bias.has_value() ? gemm1_bias.value().data_ptr() : nullptr; args.gemm1_alpha = gemm1_alpha.has_value() ? gemm1_alpha.value().data_ptr() : nullptr; args.gemm1_beta = gemm1_beta.has_value() ? gemm1_beta.value().data_ptr() : nullptr; args.gemm1_clamp_limit = gemm1_clamp_limit.has_value() ? gemm1_clamp_limit.value().data_ptr() : nullptr; args.gemm2_bias = gemm2_bias.has_value() ? gemm2_bias.value().data_ptr() : nullptr; args.output1_scales_scalar = output1_scale_scalar.has_value() ? output1_scale_scalar.value().data_ptr() : nullptr; args.output1_scales_gate_scalar = output1_scale_gate_scalar.has_value() ? output1_scale_gate_scalar.value().data_ptr() : nullptr; args.output2_scales_scalar = output2_scale_scalar.has_value() ? output2_scale_scalar.value().data_ptr() : nullptr; args.num_tokens = hidden_states.sizes()[0]; args.num_experts = num_experts; // Hidden dimension input of MoE block. It might be padded. args.hidden_size = hidden_states.sizes()[1]; args.output_hidden_size = valid_hidden_size.has_value() ? tensorrt_llm::common::roundUp(valid_hidden_size.value(), 128) : args.hidden_size; // Unpadded size. args.valid_hidden_size = valid_hidden_size; args.top_k = top_k; args.n_group = n_group.value_or(0); args.topk_group = topk_group.value_or(0); args.local_expert_offset = local_expert_offset; args.local_num_experts = local_num_experts; args.routed_scaling_factor = routed_scaling_factor.value_or(1.0); args.intermediate_size = intermediate_size; args.valid_intermediate_size = valid_intermediate_size; // allocate workspace for routing kernel if (routing_logits.has_value() && topk_ids.has_value()) { TORCH_CHECK(routing_logits.value().device() == topk_ids.value().device(), "routing_logits and topk_ids must be on the same device"); } auto routing_device = routing_logits.has_value() ? routing_logits.value().device() : topk_ids.value().device(); at::Tensor num_tokens_per_expert = at::detail::empty_cuda({num_experts}, at::ScalarType::Int, routing_device, std::nullopt); int32_t max_num_padded_tokens = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::getMaxPermutedPaddedCount( args.num_tokens, top_k, num_experts, tile_tokens_dim); int32_t max_num_padded_tokens_gemm1 = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::maybeGetMinTokenCount( max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt)); int32_t max_num_padded_tokens_gemm2 = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::maybeGetMinTokenCount( max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut)); at::Tensor total_num_padded_tokens = at::empty({}, at::TensorOptions().device(routing_device).dtype(at::ScalarType::Int)); at::Tensor expanded_idx_to_permuted_idx = at::detail::empty_cuda({args.num_tokens * args.top_k}, at::ScalarType::Int, routing_device, std::nullopt); at::Tensor permuted_idx_to_token_idx = at::detail::empty_cuda({max_num_padded_tokens}, at::ScalarType::Int, routing_device, std::nullopt); at::Tensor expert_weights = at::detail::empty_cuda({args.num_tokens, args.top_k}, at::ScalarType::BFloat16, routing_device, std::nullopt); at::Tensor expert_indexes = at::detail::empty_cuda({args.num_tokens, args.top_k}, at::ScalarType::Int, routing_device, std::nullopt); int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2)); at::Tensor expert_count_histogram = at::detail::empty_cuda({size_of_expert_count_histogram}, at::ScalarType::Int, routing_device, std::nullopt); // Set the optional pointer to the expert weights and expert ids void* expert_weights_ptr = args.topk_weights ? args.topk_weights : expert_weights.data_ptr(); int32_t const sf_block_size = 32; // allocate workspace for activation/gemm/finalize kernels auto const gemm1_output_type = dtype == btg::Dtype::Bfloat16 ? at::ScalarType::BFloat16 : at::ScalarType::Float8_e4m3fn; at::Tensor gemm1_output = at::detail::empty_cuda( {max_num_padded_tokens_gemm1, intermediate_size}, gemm1_output_type, routing_device, std::nullopt); std::optional gemm1_output_scale; if (dtype == btg::Dtype::MxE4m3) { int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens_gemm1, intermediate_size / sf_block_size); gemm1_output_scale = at::detail::empty_cuda({sf_size}, SF_DTYPE, routing_device, std::nullopt); } at::Tensor gemm2_output = at::detail::empty_cuda( {max_num_padded_tokens_gemm2, args.hidden_size}, at::ScalarType::BFloat16, routing_device, std::nullopt); int32_t max_num_ctas = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::getMaxNumCtasInBatchDim( args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim); at::Tensor cta_idx_xy_to_batch_idx = at::detail::empty_cuda({max_num_ctas}, at::ScalarType::Int, routing_device, std::nullopt); at::Tensor cta_idx_xy_to_mn_limit = at::detail::empty_cuda({max_num_ctas}, at::ScalarType::Int, routing_device, std::nullopt); at::Tensor num_non_exiting_ctas = at::empty({}, at::TensorOptions().device(routing_device).dtype(at::ScalarType::Int)); // FIXME: check shape TORCH_CHECK(dtype == btg::Dtype::MxE4m3 || dtype == btg::Dtype::Bfloat16 || dtype == btg::Dtype::E4m3, "dtype must be MxE4m3 or Bfloat16 or E4m3."); if (dtype == btg::Dtype::MxE4m3) { TORCH_CHECK(hidden_states_scale.has_value(), "hidden_states_scale must be provided for MxE4m3."); } else { TORCH_CHECK( !hidden_states_scale.has_value(), "hidden_states_scale must not be provided for Bfloat16 and E4m3."); } // // TopK routing // tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim); auto const& stream = at::cuda::getCurrentCUDAStream( routing_logits.has_value() ? routing_logits.value().get_device() : topk_ids.value().get_device()); routing_runner.run(args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor, expert_indexes.data_ptr(), expert_count_histogram.data_ptr(), total_num_padded_tokens.data_ptr(), expanded_idx_to_permuted_idx.data_ptr(), nullptr, /*permuted_idx_to_expanded_idx.data_ptr(),*/ permuted_idx_to_token_idx.data_ptr(), expert_weights_ptr, args.topk_ids, num_tokens_per_expert.data_ptr(), cta_idx_xy_to_batch_idx.data_ptr(), cta_idx_xy_to_mn_limit.data_ptr(), num_non_exiting_ctas.data_ptr(), args.mDtypeElt, false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */, static_cast(routing_method_type), stream); // // FC13 (gemm1) + FC2 (gemm2) // if (dtype == btg::Dtype::MxE4m3 || dtype == btg::Dtype::E4m3) { TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn, "hidden_states must be Float8_e4m3fn, got %s.", c10::toString(hidden_states.scalar_type())); } else { TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::BFloat16, "hidden_states must be BFloat16, got %s.", c10::toString(hidden_states.scalar_type())); } if (dtype == btg::Dtype::MxE4m3) { TORCH_CHECK(hidden_states_scale->scalar_type() == SF_DTYPE, "hidden_states_scale must be UInt8, got %s.", c10::toString(hidden_states_scale->scalar_type())); TORCH_CHECK(hidden_states_scale->dim() == 1, "hidden_states_scale must be 1D."); TORCH_CHECK(hidden_states_scale->sizes()[0] == tensorrt_llm::computeLinearLayoutSFSize(args.num_tokens, args.hidden_size / sf_block_size), "hidden_states_scale has incorrect size"); } TORCH_CHECK(gemm1_weights.scalar_type() == FLOAT4_E2M1X2, "gemm1_weights must be byte, got %s.", c10::toString(gemm1_weights.scalar_type())); TORCH_CHECK(gemm1_weights.dim() == 3, "gemm1_weights must be 3D."); TORCH_CHECK(gemm1_weights.sizes()[1] % 2 == 0, "the second dimension of weights must be even."); TORCH_CHECK(2 * intermediate_size == gemm1_weights.sizes()[1], "intermediate_size has incorrect dim 1 (lhs = ", 2 * intermediate_size, ", rhs = ", gemm1_weights.sizes()[1], ")"); // The actual shape of the weights[2] is 2 times larger than and hidden_states[1] // due to the fact that 2 e2m1 are packed into 1 byte for FP4 weights. TORCH_CHECK(gemm1_weights.sizes()[2] * 2 == hidden_states.sizes()[1], "the third dimension of weights must be equal to hidden_size (lhs = ", gemm1_weights.sizes()[2] * 2, ", rhs = ", hidden_states.sizes()[1], ")"); TORCH_CHECK(gemm1_weights_scale.scalar_type() == SF_DTYPE, "gemm1_weights_scale must be UInt8, got %s.", c10::toString(gemm1_weights_scale.scalar_type())); TORCH_CHECK(gemm1_weights_scale.dim() == 3, "gemm1_weights_scale must be 3D."); TORCH_CHECK(gemm1_weights_scale.sizes()[0] == local_num_experts, "gemm1_weights_scale has incorrect dim 0."); TORCH_CHECK(intermediate_size % sf_block_size == 0, "the second dimension of weights must be a multiple of 32."); TORCH_CHECK(gemm1_weights_scale.sizes()[1] == 2 * intermediate_size, "gemm1_weights_scale has incorrect dim 1."); TORCH_CHECK(gemm1_weights_scale.sizes()[2] == args.hidden_size / sf_block_size, "gemm1_weights_scale has incorrect dim 2 (lhs = ", gemm1_weights_scale.sizes()[2], ", rhs = ", args.hidden_size / sf_block_size, ")"); if (gemm1_bias.has_value()) { TORCH_CHECK(gemm1_bias.value().scalar_type() == at::ScalarType::Float, "gemm1_bias must be float, got %s.", c10::toString(gemm1_bias.value().scalar_type())); TORCH_CHECK(gemm1_bias.value().dim() == 2, "gemm1_bias must be 2D."); TORCH_CHECK(gemm1_bias.value().sizes()[0] == local_num_experts, "gemm1_bias has incorrect dim 0."); TORCH_CHECK(gemm1_bias.value().sizes()[1] == 2 * intermediate_size, "gemm1_bias has incorrect dim 1."); } if (gemm1_alpha.has_value()) { TORCH_CHECK(gemm1_alpha.value().scalar_type() == at::ScalarType::Float, "gemm1_alpha must be float, got %s.", c10::toString(gemm1_alpha.value().scalar_type())); TORCH_CHECK(gemm1_alpha.value().dim() == 1, "gemm1_alpha must be 1D."); TORCH_CHECK(gemm1_alpha.value().sizes()[0] == local_num_experts, "gemm1_alpha has incorrect dim 0."); } if (gemm1_beta.has_value()) { TORCH_CHECK(gemm1_beta.value().scalar_type() == at::ScalarType::Float, "gemm1_beta must be float, got %s.", c10::toString(gemm1_beta.value().scalar_type())); TORCH_CHECK(gemm1_beta.value().dim() == 1, "gemm1_beta must be 1D."); TORCH_CHECK(gemm1_beta.value().sizes()[0] == local_num_experts, "gemm1_beta has incorrect dim 0."); } if (gemm1_clamp_limit.has_value()) { TORCH_CHECK(gemm1_clamp_limit.value().scalar_type() == at::ScalarType::Float, "gemm1_clamp_limit must be float, got %s.", c10::toString(gemm1_clamp_limit.value().scalar_type())); TORCH_CHECK(gemm1_clamp_limit.value().dim() == 1, "gemm1_clamp_limit must be 1D."); TORCH_CHECK( gemm1_clamp_limit.value().sizes()[0] == local_num_experts, "gemm1_clamp_limit has incorrect dim 0."); } TORCH_CHECK(gemm2_weights.scalar_type() == FLOAT4_E2M1X2, "gemm2_weights must be byte, got %s.", c10::toString(gemm2_weights.scalar_type())); TORCH_CHECK(gemm2_weights.dim() == 3, "gemm2_weights must be 3D."); // / 2 to compensate for the fact that we pack 2 e2m1 into 1 byte. TORCH_CHECK(gemm2_weights.sizes()[2] == intermediate_size / 2, "the third dimension of weights must be equal to intermediate_size."); TORCH_CHECK(gemm2_weights_scale.scalar_type() == SF_DTYPE, "gemm2_weights_scale must be UInt8, got %s.", c10::toString(gemm2_weights_scale.scalar_type())); TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D."); TORCH_CHECK(gemm2_weights_scale.sizes()[0] == local_num_experts, "gemm2_weights_scale has incorrect dim 0."); TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.output_hidden_size, "gemm2_weights_scale has incorrect dim 1. (lhs = ", gemm2_weights_scale.sizes()[1], ", rhs = ", args.output_hidden_size, ")"); TORCH_CHECK(gemm2_weights_scale.sizes()[2] == intermediate_size / sf_block_size, "gemm2_weights_scale has incorrect dim 2."); if (gemm2_bias.has_value()) { TORCH_CHECK(gemm2_bias.value().scalar_type() == at::ScalarType::Float, "gemm2_bias must be float, got %s.", c10::toString(gemm2_bias.value().scalar_type())); TORCH_CHECK(gemm2_bias.value().dim() == 2, "gemm2_bias must be 2D."); TORCH_CHECK(gemm2_bias.value().sizes()[0] == local_num_experts, "gemm2_bias has incorrect dim 0."); TORCH_CHECK(gemm2_bias.value().sizes()[1] == args.output_hidden_size, "gemm2_bias has incorrect dim 1 (lhs = ", gemm2_bias.value().sizes()[1], ", rhs = ", args.output_hidden_size, ")"); } if (dtype == btg::Dtype::E4m3) { TORCH_CHECK(output1_scale_scalar.has_value(), "output1_scale_scalar must be provided for MxE4m3."); TORCH_CHECK(output1_scale_gate_scalar.has_value(), "output1_scale_gate_scalar must be provided for MxE4m3."); TORCH_CHECK(output2_scale_scalar.has_value(), "output2_scale_scalar must be provided for MxE4m3."); TORCH_CHECK( output1_scale_scalar->scalar_type() == at::ScalarType::Float, "output1_scales_scalar must be float."); TORCH_CHECK(output1_scale_scalar->dim() == 1, "output1_scales_scalar must be 1D."); TORCH_CHECK( output1_scale_scalar->sizes()[0] == local_num_experts, "output1_scales_scalar has incorrect dim 0."); TORCH_CHECK(output1_scale_gate_scalar->scalar_type() == at::ScalarType::Float, "output1_scales_gate_scalar must be float."); TORCH_CHECK(output1_scale_gate_scalar->dim() == 1, "output1_scales_gate_scalar must be 1D."); TORCH_CHECK(output1_scale_gate_scalar->sizes()[0] == local_num_experts, "output1_scales_gate_scalar has incorrect dim 0."); TORCH_CHECK( output2_scale_scalar->scalar_type() == at::ScalarType::Float, "output2_scales_scalar must be float."); TORCH_CHECK(output2_scale_scalar->dim() == 1, "output2_scales_scalar must be 1D."); TORCH_CHECK( output2_scale_scalar->sizes()[0] == local_num_experts, "output2_scales_scalar has incorrect dim 0."); } // allocate or use provided output at::Tensor output; if (out_tensor.has_value()) { TORCH_CHECK(out_tensor->scalar_type() == at::ScalarType::BFloat16, "out_tensor must be bfloat16."); TORCH_CHECK(out_tensor->dim() == 2, "out_tensor must be 2D."); TORCH_CHECK(out_tensor->sizes()[0] == args.num_tokens && out_tensor->sizes()[1] == args.valid_hidden_size.value_or(args.hidden_size), "out_tensor has incorrect shape."); TORCH_CHECK(out_tensor->device() == hidden_states.device(), "out_tensor must be on the same device as inputs."); output = out_tensor.value(); } else { output = at::detail::empty_cuda({args.num_tokens, args.valid_hidden_size.value_or(args.hidden_size)}, at::ScalarType::BFloat16, hidden_states.device(), std::nullopt); } // setup workspace workspace.total_num_padded_tokens = total_num_padded_tokens.data_ptr(); workspace.total_max_padded_tokens = std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2); workspace.ProjUpTileN = tile_tokens_dim; workspace.routing_expert_indexes = expert_indexes.data_ptr(); workspace.permuted_idx_size = total_num_padded_tokens.data_ptr(); workspace.expanded_idx_to_permuted_idx = expanded_idx_to_permuted_idx.data_ptr(); // Needed by permute/finalize kernels workspace.permuted_idx_to_token_idx = permuted_idx_to_token_idx.data_ptr(); // Needed by permuteGemm1 kernel workspace.expert_weights = expert_weights_ptr; // Consumed by finalize kernel workspace.cta_idx_xy_to_batch_idx = cta_idx_xy_to_batch_idx.data_ptr(); workspace.cta_idx_xy_to_mn_limit = cta_idx_xy_to_mn_limit.data_ptr(); workspace.num_non_exiting_ctas = num_non_exiting_ctas.data_ptr(); // gemm1 intermediate ws workspace.gemm1_output = gemm1_output.data_ptr(); workspace.gemm1_output_scale = gemm1_output_scale.has_value() ? reinterpret_cast(gemm1_output_scale->data_ptr()) : nullptr; // gemm2 intermediate ws workspace.gemm2_output = gemm2_output.data_ptr(); workspace.gemm2_output_scale = nullptr; args.output = output.data_ptr(); args.output_scale = nullptr; auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex); at::Tensor workspace_fc1 = at::detail::empty_cuda( {std::get<0>(workspace_sizes)}, at::ScalarType::Char, hidden_states.device(), std::nullopt); at::Tensor workspace_fc2 = at::detail::empty_cuda( {std::get<1>(workspace_sizes)}, at::ScalarType::Char, hidden_states.device(), std::nullopt); workspace.bmm1_workspace = workspace_fc1.data_ptr(); workspace.bmm2_workspace = workspace_fc2.data_ptr(); auto const& moe_stream = at::cuda::getCurrentCUDAStream(hidden_states.get_device()); moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex); return output; } // Wrapped the TRTLLM-Gen kernel runner in a Torch custom class to allow // use with the torch workflow autotuner class. class Bf16MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder { public: explicit Bf16MxE2m1BlockScaleMoeRunner(int64_t actType) // Update this as new cubins come in : mSupportedTileN{8, 16, 32, 64} { for (int tileN : mSupportedTileN) { mRunners.emplace(tileN, std::make_unique(mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, tileN, static_cast(actType))); } } [[nodiscard]] std::vector> getValidConfigs(int64_t topK, int64_t hiddenSize, int64_t intermediateSize, int64_t numLocalExperts, int64_t numTokens, int64_t validHiddenSize, int64_t validIntermediateSize) const { // returns (tileN, config) std::vector> tactics; for (auto& [tileN, runner] : mRunners) { auto chosen = computeSelectedTileN(mSupportedTileN, numTokens, topK, numLocalExperts); if (chosen.find(tileN) == chosen.end()) { continue; } auto config_indices_per_runner = runner->getValidConfigIndices( topK, hiddenSize, intermediateSize, numLocalExperts, numTokens, validHiddenSize, validIntermediateSize); for (auto cfg : config_indices_per_runner) { tactics.push_back({tileN, cfg}); } } return tactics; } // BF16 run does not use hidden_states_scale [[nodiscard]] torch::Tensor run(torch::optional const& routing_logits, std::optional const& routing_bias, torch::Tensor const& hidden_states, torch::Tensor const& gemm1_weights, torch::Tensor const& gemm1_weights_scale, std::optional const& gemm1_bias, std::optional const& gemm1_alpha, std::optional const& gemm1_beta, std::optional const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights, torch::Tensor const& gemm2_weights_scale, std::optional const& gemm2_bias, int64_t num_experts, int64_t top_k, std::optional const n_group, std::optional const topk_group, int64_t intermediate_size, std::optional const valid_hidden_size, std::optional const valid_intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, std::optional routed_scaling_factor, int64_t routing_method_type, std::vector moeConfigIndex, torch::optional const& topk_weights, torch::optional const& topk_ids) { // moeConfigIndex corresponds to pair (tileN, config) auto [tileN, config] = std::tie(moeConfigIndex[0], moeConfigIndex[1]); // Autotuner has requested a default or 'fallback' config index if (tileN == -1 || config == -1) { auto const num_tokens = hidden_states.sizes()[0]; auto const hidden_size = hidden_states.sizes()[1]; float const avg_tokens_per_expert = static_cast(num_tokens * top_k) / local_num_experts; tileN = std::clamp(nextPowerOfTwo(avg_tokens_per_expert), mSupportedTileN.front(), mSupportedTileN.back()); config = mRunners[tileN]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, local_num_experts, num_tokens, valid_hidden_size.value_or(hidden_size), valid_intermediate_size.value_or(intermediate_size)); } return dtype_mxe2m1_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, std::nullopt, gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, gemm2_bias, std::nullopt, std::nullopt, std::nullopt, num_experts, top_k, n_group, topk_group, intermediate_size, valid_hidden_size, valid_intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN, routing_method_type, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids, /*out_tensor=*/torch::nullopt); // TODO: Support user-provided output } private: using RunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner; std::vector const mSupportedTileN; std::unordered_map> mRunners; btg::Dtype mDtypeAct{btg::Dtype::Bfloat16}; btg::Dtype mDtypeWeights{btg::Dtype::MxE2m1}; bool mUseDeepSeekFp8{false}; }; class MxE4m3MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder { public: explicit MxE4m3MxE2m1BlockScaleMoeRunner(int64_t actType, bool isMxFp8) // Update this as new cubins come in : mSupportedTileN{isMxFp8 ? std::vector{8, 16, 32, 64, 128, 256} : std::vector{8, 16, 32, 64}} , mDtypeAct(isMxFp8 ? btg::Dtype::MxE4m3 : btg::Dtype::E4m3) { for (int tileN : mSupportedTileN) { mRunners.emplace(tileN, std::make_unique(mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, tileN, static_cast(actType))); } } [[nodiscard]] std::vector> getValidConfigs(int64_t topK, int64_t hiddenSize, int64_t intermediateSize, int64_t numLocalExperts, int64_t numTokens, int64_t validHiddenSize, int64_t validIntermediateSize) const { // returns (tileN, config) std::vector> tactics; for (auto& [tileN, runner] : mRunners) { auto chosen = computeSelectedTileN(mSupportedTileN, numTokens, topK, numLocalExperts); if (chosen.find(tileN) == chosen.end()) { continue; } auto config_indices_per_runner = runner->getValidConfigIndices( topK, hiddenSize, intermediateSize, numLocalExperts, numTokens, validHiddenSize, validIntermediateSize); for (auto cfg : config_indices_per_runner) { tactics.push_back({tileN, cfg}); } } return tactics; } [[nodiscard]] torch::Tensor run(torch::optional const& routing_logits, std::optional const& routing_bias, torch::Tensor const& hidden_states, std::optional const& hidden_states_scale, torch::Tensor const& gemm1_weights, torch::Tensor const& gemm1_weights_scale, std::optional const& gemm1_bias, std::optional const& gemm1_alpha, std::optional const& gemm1_beta, std::optional const& gemm1_clamp_limit, torch::Tensor const& gemm2_weights, torch::Tensor const& gemm2_weights_scale, std::optional const& gemm2_bias, std::optional const& output1_scale_scalar, std::optional const& output1_scale_gate_scalar, std::optional const& output2_scale_scalar, int64_t num_experts, int64_t top_k, std::optional const n_group, std::optional const topk_group, int64_t intermediate_size, std::optional const valid_hidden_size, std::optional const valid_intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, std::optional routed_scaling_factor, int64_t routing_method_type, std::vector tile_config_pair, torch::optional const& topk_weights, torch::optional const& topk_ids, torch::optional const& output) { // tile_config_pair corresponds to pair (tileN, config) auto [tileN, config] = std::tie(tile_config_pair[0], tile_config_pair[1]); // Autotuner has requested a default or 'fallback' config index if (tileN == -1 || config == -1) { auto const num_tokens = hidden_states.sizes()[0]; auto const hidden_size = hidden_states.sizes()[1]; float const avg_tokens_per_expert = static_cast(num_tokens * top_k) / local_num_experts; tileN = std::clamp(nextPowerOfTwo(avg_tokens_per_expert), mSupportedTileN.front(), mSupportedTileN.back()); config = mRunners[tileN]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, local_num_experts, num_tokens, valid_hidden_size.value_or(hidden_size), valid_intermediate_size.value_or(intermediate_size)); } return dtype_mxe2m1_block_scale_moe_runner(routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, gemm2_bias, output1_scale_scalar, output1_scale_gate_scalar, output2_scale_scalar, num_experts, top_k, n_group, topk_group, intermediate_size, valid_hidden_size, valid_intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN, routing_method_type, mDtypeAct, *mRunners[tileN], config, topk_weights, topk_ids, output); } private: using RunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner; std::vector const mSupportedTileN; std::unordered_map> mRunners; btg::Dtype mDtypeAct{btg::Dtype::MxE4m3}; btg::Dtype mDtypeWeights{btg::Dtype::MxE2m1}; bool mUseDeepSeekFp8{false}; }; } // namespace torch_ext // Accepts CUDA tensor only TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.class_("Bf16MxE2m1BlockScaleMoERunner") .def(torch::init()) .def("get_valid_configs", &torch_ext::Bf16MxE2m1BlockScaleMoeRunner::getValidConfigs) .def("run_moe", &torch_ext::Bf16MxE2m1BlockScaleMoeRunner::run); m.class_("MxE4m3MxE2m1BlockScaleMoERunner") .def(torch::init()) .def("get_valid_configs", &torch_ext::MxE4m3MxE2m1BlockScaleMoeRunner::getValidConfigs) .def("run_moe", &torch_ext::MxE4m3MxE2m1BlockScaleMoeRunner::run); }