/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/customAllReduceUtils.h"
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/allReduceFusionKernels.h"
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
#include "tensorrt_llm/kernels/internal_cutlass_kernels/include/fp4_gemm.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/thop/thUtils.h"

#if ENABLE_MULTI_DEVICE
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/EmptyTensor.h>
#endif // ENABLE_MULTI_DEVICE

#include <torch/extension.h>

#include <cstdint>
#include <optional>
#include <vector>

namespace torch_ext
{
#if ENABLE_MULTI_DEVICE

using tensorrt_llm::kernels::AllReduceFusionOp;

namespace
{

class DeepseekAllreduceOp
{
public:
    DeepseekAllreduceOp() {}

    ~DeepseekAllreduceOp() = default;

    std::vector<torch::Tensor> run(torch::Tensor input, torch::optional<torch::Tensor> workspace,
        torch::TensorList reduce_fusion_inputs, int64_t rank, int64_t nranks, double eps, int64_t fusion_op) noexcept
    {
        auto const fusion_op_type = static_cast<AllReduceFusionOp>(int8_t(fusion_op));
        torch::Tensor residual_out;
        torch::Tensor norm_out;
        torch::Tensor quant_out;
        torch::Tensor scale_out;

        tensorrt_llm::kernels::ar_fusion::AllReduceFusionParams allreduce_fusion_params;
        allreduce_fusion_params.quant_out = nullptr;
        allreduce_fusion_params.scale_out = nullptr;
        allreduce_fusion_params.residual_out = nullptr;
        allreduce_fusion_params.norm_out = nullptr;

        if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4
            || fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_AND_QUANT_NVFP4)
        {
            TORCH_CHECK(reduce_fusion_inputs.size() == 3, "Pre-MLP fusion should have 3 inputs.");
            int64_t sfVecSize = 16;
            int64_t m = 1;
            auto const& inputShape = input.sizes();
            auto const r = inputShape.size();
            TORCH_CHECK(r >= 2, "Input should be a >=2D tensor.");
            for (size_t i = 0; i < r - 1; i++)
            {
                m *= inputShape[i];
            }
            auto const k = inputShape[r - 1];
            TORCH_CHECK(k % sfVecSize == 0, "Input's last dimension should be divisible by sfVecSize.");
            std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
            outputShape[r - 1] = k / 2;

            // Packed FP4 output (two e2m1 values per byte) plus the per-block scale factors.
            quant_out = at::detail::empty_cuda(outputShape, FLOAT4_E2M1X2, input.device(), std::nullopt);
            scale_out = at::detail::empty_cuda(
                {tensorrt_llm::computeSFSize(m, k / sfVecSize)}, SF_DTYPE, input.device(), std::nullopt);
            residual_out = torch::empty_like(reduce_fusion_inputs[0]);

            allreduce_fusion_params.quant_out = quant_out.mutable_data_ptr();
            allreduce_fusion_params.scale_out = scale_out.mutable_data_ptr();
            allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();

            if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_AND_QUANT_NVFP4)
            {
                norm_out = torch::empty_like(input);
                allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
            }
        }
        else if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM)
        {
            norm_out = torch::empty_like(input);
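            // RESIDUAL_RMS_NORM: fused allreduce + residual add + RMSNorm with no FP4 quantization;
            // only norm_out and residual_out are produced (see the return paths below).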
            residual_out = torch::empty_like(reduce_fusion_inputs[0]);
            allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
            allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
        }
        else
        {
            return std::vector<torch::Tensor>();
        }

        allreduce_fusion_params.nranks = static_cast<int>(nranks);
        allreduce_fusion_params.rank = static_cast<int>(rank);
        allreduce_fusion_params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
        allreduce_fusion_params.size = static_cast<int>(input.numel());
        allreduce_fusion_params.hidden_dim = static_cast<int>(input.size(-1));
        allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.value().mutable_data_ptr());
        allreduce_fusion_params.allreduce_in = input.data_ptr();
        allreduce_fusion_params.residual_in = reduce_fusion_inputs[0].data_ptr();
        allreduce_fusion_params.rms_gamma = reduce_fusion_inputs[1].data_ptr();
        allreduce_fusion_params.rms_eps = static_cast<float>(eps);

        if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4
            || fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_AND_QUANT_NVFP4)
        {
            allreduce_fusion_params.scale_factor = static_cast<float*>(reduce_fusion_inputs[2].data_ptr());
        }
        else
        {
            allreduce_fusion_params.scale_factor = nullptr;
        }

        allreduce_fusion_params.stream = at::cuda::getCurrentCUDAStream(input.get_device());
        tensorrt_llm::kernels::ar_fusion::allreduce_fusion_op(allreduce_fusion_params);

        if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4)
        {
            return std::vector<torch::Tensor>({quant_out, scale_out, residual_out});
        }
        else if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM)
        {
            return std::vector<torch::Tensor>({norm_out, residual_out});
        }
        else if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_AND_QUANT_NVFP4)
        {
            return std::vector<torch::Tensor>({norm_out, quant_out, scale_out, residual_out});
        }
        else
        {
            return std::vector<torch::Tensor>();
        }
    }

    std::vector<torch::Tensor> run_moe_allreduce(torch::optional<torch::Tensor> workspace,
        torch::TensorList reduce_fusion_inputs, int64_t rank, int64_t nranks, double eps, int64_t fusion_op) noexcept
    {
        auto const fusion_op_type = static_cast<AllReduceFusionOp>(int8_t(fusion_op));
        TORCH_CHECK(fusion_op_type == AllReduceFusionOp::MOE_ALLREDUCE_RESIDUAL_RMS_NORM,
            "Only MOE_ALLREDUCE_RESIDUAL_RMS_NORM is supported.");

        auto allreduce_fusion_params = tensorrt_llm::kernels::ar_fusion::MoeReductionAllReduceFusionParams();

        allreduce_fusion_params.quant_out = nullptr;
        allreduce_fusion_params.scale_out = nullptr;
        allreduce_fusion_params.residual_out = nullptr;
        allreduce_fusion_params.norm_out = nullptr;

        allreduce_fusion_params.nranks = static_cast<int>(nranks);
        allreduce_fusion_params.rank = static_cast<int>(rank);
        allreduce_fusion_params.dtype
            = tensorrt_llm::runtime::TorchUtils::dataType(reduce_fusion_inputs[5].scalar_type());
        // size: num_token * hidden_dim
        allreduce_fusion_params.size = static_cast<int>(reduce_fusion_inputs[5].numel());
        allreduce_fusion_params.hidden_dim = static_cast<int>(reduce_fusion_inputs[4].size(-1));

        // workspace: allreduce scratch space
        allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.value().mutable_data_ptr());
        allreduce_fusion_params.rms_gamma = reduce_fusion_inputs[1].data_ptr();
        allreduce_fusion_params.rms_eps = static_cast<float>(eps);
        allreduce_fusion_params.stream = at::cuda::getCurrentCUDAStream(reduce_fusion_inputs[1].get_device());
        allreduce_fusion_params.residual_in = reduce_fusion_inputs[0].data_ptr();

        // MoE reduction specific params
        // reduce_fusion_inputs[0]: residual
        // reduce_fusion_inputs[1]: gamma
        // reduce_fusion_inputs[2]: moe_reduction_device_num_experts
        // reduce_fusion_inputs[3]: moe_reduction_scale_input [device_num_experts, m]
        // reduce_fusion_inputs[4]: moe_reduction_active_experts_token_input [device_num_experts, m, 7168]
        // reduce_fusion_inputs[5]: moe_reduction_token_input [m, 7168]
        allreduce_fusion_params.allreduce_in = nullptr; // unused on the MoE path; set to nullptr for safety
        allreduce_fusion_params.moe_reduction_device_num_experts
            = static_cast<int*>(reduce_fusion_inputs[2].data_ptr());
        allreduce_fusion_params.moe_reduction_scale_input = static_cast<float*>(reduce_fusion_inputs[3].data_ptr());
        allreduce_fusion_params.moe_reduction_active_experts_token_input = reduce_fusion_inputs[4].data_ptr();
        allreduce_fusion_params.moe_reduction_token_input = reduce_fusion_inputs[5].data_ptr();

        // output tensors
        torch::Tensor norm_out = torch::empty_like(reduce_fusion_inputs[5]);
        torch::Tensor residual_out = torch::empty_like(reduce_fusion_inputs[0]);
        allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
        allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();

        tensorrt_llm::kernels::ar_fusion::moereduction_allreduce_fusion_op(allreduce_fusion_params);

        return std::vector<torch::Tensor>({norm_out, residual_out});
    }
};

} // namespace

#endif // ENABLE_MULTI_DEVICE

std::vector<torch::Tensor> deepseekAllreduceFusion(torch::Tensor input, torch::optional<torch::Tensor> workspace,
    torch::TensorList reduce_fusion_inputs, int64_t const rank, int64_t const nranks, double const eps,
    int64_t const fusion_op)
{
#if ENABLE_MULTI_DEVICE
    DeepseekAllreduceOp op;
    auto fusion_op_type = static_cast<AllReduceFusionOp>(int8_t(fusion_op));
    if (fusion_op_type == AllReduceFusionOp::MOE_ALLREDUCE_RESIDUAL_RMS_NORM)
    {
        return op.run_moe_allreduce(workspace, reduce_fusion_inputs, rank, nranks, eps, fusion_op);
    }
    else
    {
        return op.run(input, workspace, reduce_fusion_inputs, rank, nranks, eps, fusion_op);
    }
#else
    return std::vector<torch::Tensor>();
#endif // ENABLE_MULTI_DEVICE
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    // reduce_fusion_inputs includes
    // 0: residual
    // 1: gamma
    // 2: scale_factor (only for the NVFP4 quant fusion ops RESIDUAL_RMS_NORM_QUANT_NVFP4 /
    //    RESIDUAL_RMS_NORM_AND_QUANT_NVFP4)
    //
    // for moe allreduce
    // 0: residual
    // 1: gamma
    // 2: moe_reduction_device_num_experts [1]
    // 3: moe_reduction_scale_input [global_num_experts, m]
    // 4: moe_reduction_active_experts_token_input [device_num_experts, m, 7168]
    // 5: moe_reduction_token_input [m, 7168]
    m.def(
        "deepseek_allreduce_fusion(Tensor input, Tensor? workspace, Tensor[] reduce_fusion_inputs, "
        "int rank, int nranks, float eps, int fusion_op) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("deepseek_allreduce_fusion", &torch_ext::deepseekAllreduceFusion);
}
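
// Usage sketch for the non-MoE path (illustrative only; the shapes, eps value, and workspace
// setup below are assumptions about the caller, not something this file verifies):
//
//   // hidden: this rank's partial sum [m, hidden_dim]; residual [m, hidden_dim]; gamma [hidden_dim];
//   // workspace: the allreduce fusion scratch buffer prepared by the caller.
//   auto outs = torch_ext::deepseekAllreduceFusion(hidden, workspace, {residual, gamma},
//       rank, nranks, /*eps=*/1e-6,
//       static_cast<int64_t>(tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM));
//   // outs[0] = norm_out, outs[1] = residual_out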