TensorRT-LLMs/cpp/tensorrt_llm/thop/deepseekAllreduceFusionOp.cpp
Yukun He 5502a522d2
Fixing minor typo in allreduce kernel selection (#3912)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Co-authored-by: Kefeng-Duan <176893526+Kefeng-Duan@users.noreply.github.com>
2025-04-28 23:06:49 +08:00

272 lines
11 KiB
C++

/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/customAllReduceUtils.h"
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h"
#include "tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h"
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
#include "tensorrt_llm/kernels/internal_cutlass_kernels/include/fp4_gemm.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/thop/thUtils.h"
#if ENABLE_MULTI_DEVICE
#include <ATen/cuda/EmptyTensor.h>
#include <nccl.h>
#endif // ENABLE_MULTI_DEVICE
#include <nvml.h>
#include <torch/extension.h>
#include <cstddef>
#include <cstdint>
#include <vector>
namespace torch_ext
{
#if ENABLE_MULTI_DEVICE
using tensorrt_llm::kernels::AllReduceFusionOp;
namespace
{
class DeepseekAllreduceOp
{
public:
DeepseekAllreduceOp() {}
~DeepseekAllreduceOp() = default;
std::vector<torch::Tensor> run(torch::Tensor input, torch::optional<torch::Tensor> workspace,
torch::TensorList reduce_fusion_inputs, int64_t rank, int64_t nranks, double eps, int64_t fusion_op) noexcept
{
auto const fusion_op_type = static_cast<AllReduceFusionOp>(int8_t(fusion_op));
torch::Tensor residual_out;
torch::Tensor norm_out;
torch::Tensor quant_out;
torch::Tensor scale_out;
tensorrt_llm::kernels::ar_fusion::AllReduceFusionParams allreduce_fusion_params;
allreduce_fusion_params.quant_out = nullptr;
allreduce_fusion_params.scale_out = nullptr;
allreduce_fusion_params.residual_out = nullptr;
allreduce_fusion_params.norm_out = nullptr;
if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4
|| fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
{
TORCH_CHECK(reduce_fusion_inputs.size() == 3, "Pre-MLP fusion should have 3 inputs.");
int64_t sfVecSize = 16;
int64_t m = 1;
auto const& inputShape = input.sizes();
auto const& r = inputShape.size();
TORCH_CHECK(r >= 2, "Input should be >=2D tensor.");
for (size_t i = 0; i < r - 1; i++)
{
m *= inputShape[i];
}
auto const k = inputShape[r - 1];
TORCH_CHECK(k % sfVecSize == 0, "Input should be divisible by sfVecSize.");
std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
outputShape[r - 1] = k / 2;
quant_out = at::detail::empty_cuda(outputShape, FLOAT4_E2M1X2, input.device(), std::nullopt);
scale_out = at::detail::empty_cuda({tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize)},
SF_DTYPE, input.device(), std::nullopt);
residual_out = torch::empty_like(reduce_fusion_inputs[0]);
allreduce_fusion_params.quant_out = quant_out.mutable_data_ptr();
allreduce_fusion_params.scale_out = scale_out.mutable_data_ptr();
allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
{
norm_out = torch::empty_like(input);
allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
allreduce_fusion_params.pattern
= tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant;
}
else
{
allreduce_fusion_params.pattern
= tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kARResidualRMSNormFP4Quant;
}
}
else if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM)
{
norm_out = torch::empty_like(input);
residual_out = torch::empty_like(reduce_fusion_inputs[0]);
allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
allreduce_fusion_params.pattern
= tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kARResidualRMSNorm;
}
else
{
return std::vector<torch::Tensor>();
}
allreduce_fusion_params.nranks = static_cast<int>(nranks);
allreduce_fusion_params.rank = static_cast<int>(rank);
allreduce_fusion_params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
allreduce_fusion_params.size = static_cast<int>(input.numel());
allreduce_fusion_params.hidden_dim = static_cast<int>(input.size(-1));
allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.value().mutable_data_ptr());
allreduce_fusion_params.allreduce_in = input.data_ptr();
allreduce_fusion_params.residual_in = reduce_fusion_inputs[0].data_ptr();
allreduce_fusion_params.rms_gamma = reduce_fusion_inputs[1].data_ptr();
allreduce_fusion_params.rms_eps = static_cast<float>(eps);
allreduce_fusion_params.use_oneshot = input.size(0) <= tensorrt_llm::kernels::ar_fusion::kOneShotMaxToken;
if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4
|| fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
{
allreduce_fusion_params.scale_factor = static_cast<float*>(reduce_fusion_inputs[2].data_ptr());
}
else
{
allreduce_fusion_params.scale_factor = nullptr;
}
allreduce_fusion_params.stream = at::cuda::getCurrentCUDAStream(input.get_device());
tensorrt_llm::kernels::ar_fusion::allreduce_fusion_op(allreduce_fusion_params);
if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4)
{
return std::vector<torch::Tensor>({quant_out, scale_out, residual_out});
}
else if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM)
{
return std::vector<torch::Tensor>({norm_out, residual_out});
}
else if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
{
return std::vector<torch::Tensor>({norm_out, quant_out, scale_out, residual_out});
}
else
{
return std::vector<torch::Tensor>();
}
}
std::vector<torch::Tensor> run_moe_allreduce(torch::optional<torch::Tensor> workspace,
torch::TensorList reduce_fusion_inputs, int64_t rank, int64_t nranks, double eps, int64_t fusion_op) noexcept
{
auto allreduce_fusion_params = tensorrt_llm::kernels::ar_fusion::moe::MoeReductionAllReduceFusionParams();
allreduce_fusion_params.quant_out = nullptr;
allreduce_fusion_params.scale_out = nullptr;
allreduce_fusion_params.residual_out = nullptr;
allreduce_fusion_params.norm_out = nullptr;
allreduce_fusion_params.nranks = static_cast<int>(nranks);
allreduce_fusion_params.rank = static_cast<int>(rank);
allreduce_fusion_params.dtype
= tensorrt_llm::runtime::TorchUtils::dataType(reduce_fusion_inputs[5].scalar_type());
// size: num_token * hidden_dim
allreduce_fusion_params.size = static_cast<int>(reduce_fusion_inputs[5].numel());
allreduce_fusion_params.hidden_dim = static_cast<int>(reduce_fusion_inputs[4].size(-1));
// workspace: AR scratch space
allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.value().mutable_data_ptr());
allreduce_fusion_params.rms_gamma = reduce_fusion_inputs[1].data_ptr();
allreduce_fusion_params.rms_eps = static_cast<float>(eps);
allreduce_fusion_params.stream = at::cuda::getCurrentCUDAStream(reduce_fusion_inputs[1].get_device());
allreduce_fusion_params.residual_in = reduce_fusion_inputs[0].data_ptr();
// MOE Reduction specific params
// reduce_fusion_inputs[0]: residual
// reduce_fusion_inputs[1]: gamma
// reduce_fusion_inputs[2]: moe_reduction_device_num_experts
// reduce_fusion_inputs[3]: moe_reduction_scale_input [device_num_experts, m]
// reduce_fusion_inputs[4]: moe_reduction_active_experts_token_input [device_num_experts, m, 7168]
// reduce_fusion_inputs[5]: moe_reduction_token_input [m, 7168]
allreduce_fusion_params.allreduce_in = nullptr; // for safety, set nullptr
allreduce_fusion_params.moe_reduction_device_num_experts
= static_cast<int*>(reduce_fusion_inputs[2].data_ptr());
allreduce_fusion_params.moe_reduction_scale_input = static_cast<float*>(reduce_fusion_inputs[3].data_ptr());
allreduce_fusion_params.moe_reduction_active_experts_token_input = reduce_fusion_inputs[4].data_ptr();
allreduce_fusion_params.moe_reduction_token_input = reduce_fusion_inputs[5].data_ptr();
// output tensors
torch::Tensor norm_out = torch::empty_like(reduce_fusion_inputs[5]);
torch::Tensor residual_out = torch::empty_like(reduce_fusion_inputs[0]);
allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
tensorrt_llm::kernels::ar_fusion::moe::moereduction_allreduce_fusion_op(allreduce_fusion_params);
return std::vector<torch::Tensor>({norm_out, residual_out});
}
};
} // namespace
#endif // ENABLE_MULTI_DEVICE
//! Entry point bound to trtllm::deepseek_allreduce_fusion.
//!
//! Routes MOE_ALLREDUCE_RESIDUAL_RMS_NORM to DeepseekAllreduceOp::run_moe_allreduce (which does not use `input`)
//! and every other fusion_op to DeepseekAllreduceOp::run. In builds without ENABLE_MULTI_DEVICE the op is a no-op
//! that returns an empty list.
std::vector<torch::Tensor> deepseekAllreduceFusion(torch::Tensor input, torch::optional<torch::Tensor> workspace,
    torch::TensorList reduce_fusion_inputs, int64_t const rank, int64_t const nranks, double const eps,
    int64_t const fusion_op)
{
#if ENABLE_MULTI_DEVICE
    bool const is_moe_allreduce
        = static_cast<AllReduceFusionOp>(int8_t(fusion_op)) == AllReduceFusionOp::MOE_ALLREDUCE_RESIDUAL_RMS_NORM;
    DeepseekAllreduceOp op;
    if (!is_moe_allreduce)
    {
        return op.run(input, workspace, reduce_fusion_inputs, rank, nranks, eps, fusion_op);
    }
    return op.run_moe_allreduce(workspace, reduce_fusion_inputs, rank, nranks, eps, fusion_op);
#else
    return std::vector<torch::Tensor>();
#endif // ENABLE_MULTI_DEVICE
}
} // namespace torch_ext
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    // Schema for trtllm::deepseek_allreduce_fusion.
    //
    // reduce_fusion_inputs for RESIDUAL_RMS_NORM and the NVFP4 fusion ops:
    //   0: residual
    //   1: gamma (RMSNorm weight)
    //   2: scale_factor -- only for RESIDUAL_RMS_NORM_QUANT_NVFP4 and RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4
    //
    // reduce_fusion_inputs for MOE_ALLREDUCE_RESIDUAL_RMS_NORM (`input` is not used by this op):
    //   0: residual
    //   1: gamma (RMSNorm weight)
    //   2: moe_reduction_device_num_experts
    //   3: moe_reduction_scale_input
    //   4: moe_reduction_active_experts_token_input
    //   5: moe_reduction_token_input
    //
    // Returned tensors:
    //   RESIDUAL_RMS_NORM:                 [norm_out, residual_out]
    //   RESIDUAL_RMS_NORM_QUANT_NVFP4:     [quant_out, scale_out, residual_out]
    //   RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: [norm_out, quant_out, scale_out, residual_out]
    //   MOE_ALLREDUCE_RESIDUAL_RMS_NORM:   [norm_out, residual_out]
    //   any other fusion_op:               [] (empty list)
    m.def(
        "deepseek_allreduce_fusion(Tensor input, Tensor? workspace, Tensor[] reduce_fusion_inputs, "
        "int rank, int nranks, float eps, int fusion_op) -> Tensor[]");
}
// Registers the C++ entry point as the CUDA-dispatch-key kernel for trtllm::deepseek_allreduce_fusion,
// whose schema is declared in the TORCH_LIBRARY_FRAGMENT above.
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("deepseek_allreduce_fusion", &torch_ext::deepseekAllreduceFusion);
}