/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/customAllReduceUtils.h"
|
|
#include "tensorrt_llm/common/dataType.h"
|
|
#include "tensorrt_llm/common/opUtils.h"
|
|
#include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h"
|
|
#include "tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h"
|
|
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
|
|
#include "tensorrt_llm/kernels/internal_cutlass_kernels/include/fp4_gemm.h"
|
|
#include "tensorrt_llm/kernels/quantization.h"
|
|
#include "tensorrt_llm/runtime/torchUtils.h"
|
|
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
|
#include "tensorrt_llm/thop/thUtils.h"
|
|
|
|
#if ENABLE_MULTI_DEVICE
|
|
#include <ATen/cuda/EmptyTensor.h>
|
|
|
|
#include <nccl.h>
|
|
#endif // ENABLE_MULTI_DEVICE
|
|
#include <nvml.h>
|
|
#include <torch/extension.h>
|
|
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <vector>
|
|
|
|
namespace torch_ext
{

#if ENABLE_MULTI_DEVICE

using tensorrt_llm::kernels::AllReduceFusionOp;

namespace
{

class DeepseekAllreduceOp
{
public:
    DeepseekAllreduceOp() {}

    ~DeepseekAllreduceOp() = default;

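    // Fused allreduce + residual add + RMSNorm, optionally followed by NVFP4 quantization,
    // selected by fusion_op:
    //   RESIDUAL_RMS_NORM                 -> returns {norm_out, residual_out}
    //   RESIDUAL_RMS_NORM_QUANT_NVFP4     -> returns {quant_out, scale_out, residual_out}
    //   RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 -> returns {norm_out, quant_out, scale_out, residual_out}
    // Any other fusion_op returns an empty vector.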
    std::vector<torch::Tensor> run(torch::Tensor input, torch::optional<torch::Tensor> workspace,
        torch::TensorList reduce_fusion_inputs, int64_t rank, int64_t nranks, double eps, int64_t fusion_op) noexcept
    {
        auto const fusion_op_type = static_cast<AllReduceFusionOp>(int8_t(fusion_op));

        torch::Tensor residual_out;
        torch::Tensor norm_out;
        torch::Tensor quant_out;
        torch::Tensor scale_out;

        tensorrt_llm::kernels::ar_fusion::AllReduceFusionParams allreduce_fusion_params;

        allreduce_fusion_params.quant_out = nullptr;
        allreduce_fusion_params.scale_out = nullptr;
        allreduce_fusion_params.residual_out = nullptr;
        allreduce_fusion_params.norm_out = nullptr;

        if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4
            || fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
        {
            TORCH_CHECK(reduce_fusion_inputs.size() == 3, "Pre-MLP fusion should have 3 inputs.");

            int64_t sfVecSize = 16;
            int64_t m = 1;
            auto const& inputShape = input.sizes();
            auto const r = inputShape.size();
            TORCH_CHECK(r >= 2, "Input should be >=2D tensor.");
            for (size_t i = 0; i < r - 1; i++)
            {
                m *= inputShape[i];
            }
            auto const k = inputShape[r - 1];
            TORCH_CHECK(k % sfVecSize == 0, "Input's last dimension should be divisible by sfVecSize.");
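            // Two FP4 (E2M1) values are packed per output element, so the last dimension of the
            // quantized output is k / 2; one scale factor is produced per sfVecSize (16) elements
            // and stored in a swizzled layout (see computeFP4SwizzledLayoutSFSize).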
            std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
            outputShape[r - 1] = k / 2;

            quant_out = at::detail::empty_cuda(outputShape, FLOAT4_E2M1X2, input.device(), std::nullopt);
            scale_out = at::detail::empty_cuda({tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize)},
                SF_DTYPE, input.device(), std::nullopt);
            residual_out = torch::empty_like(reduce_fusion_inputs[0]);

            allreduce_fusion_params.quant_out = quant_out.mutable_data_ptr();
            allreduce_fusion_params.scale_out = scale_out.mutable_data_ptr();
            allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();

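            // The *_OUT_* variant additionally materializes the RMSNorm output; otherwise only the
            // quantized tensor, its scale factors, and the updated residual are produced.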
            if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
            {
                norm_out = torch::empty_like(input);
                allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
                allreduce_fusion_params.pattern
                    = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant;
            }
            else
            {
                allreduce_fusion_params.pattern
                    = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kARResidualRMSNormFP4Quant;
            }
        }
        else if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM)
        {
            norm_out = torch::empty_like(input);
            residual_out = torch::empty_like(reduce_fusion_inputs[0]);

            allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
            allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
            allreduce_fusion_params.pattern
                = tensorrt_llm::kernels::ar_fusion::AllReduceFusionPattern::kARResidualRMSNorm;
        }
        else
        {
            return std::vector<torch::Tensor>();
        }

        allreduce_fusion_params.nranks = static_cast<int>(nranks);
        allreduce_fusion_params.rank = static_cast<int>(rank);
        allreduce_fusion_params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
        allreduce_fusion_params.size = static_cast<int>(input.numel());
        allreduce_fusion_params.hidden_dim = static_cast<int>(input.size(-1));
        allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.value().mutable_data_ptr());
        allreduce_fusion_params.allreduce_in = input.data_ptr();
        allreduce_fusion_params.residual_in = reduce_fusion_inputs[0].data_ptr();
        allreduce_fusion_params.rms_gamma = reduce_fusion_inputs[1].data_ptr();
        allreduce_fusion_params.rms_eps = static_cast<float>(eps);
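        // Use the one-shot allreduce kernel only when the token count (first input dimension)
        // does not exceed kOneShotMaxToken.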
        allreduce_fusion_params.use_oneshot = input.size(0) <= tensorrt_llm::kernels::ar_fusion::kOneShotMaxToken;

        if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4
            || fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
        {
            allreduce_fusion_params.scale_factor = static_cast<float*>(reduce_fusion_inputs[2].data_ptr());
        }
        else
        {
            allreduce_fusion_params.scale_factor = nullptr;
        }

        allreduce_fusion_params.stream = at::cuda::getCurrentCUDAStream(input.get_device());

        tensorrt_llm::kernels::ar_fusion::allreduce_fusion_op(allreduce_fusion_params);

        if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4)
        {
            return std::vector<torch::Tensor>({quant_out, scale_out, residual_out});
        }
        else if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM)
        {
            return std::vector<torch::Tensor>({norm_out, residual_out});
        }
        else if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
        {
            return std::vector<torch::Tensor>({norm_out, quant_out, scale_out, residual_out});
        }
        else
        {
            return std::vector<torch::Tensor>();
        }
    }

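    // Fused MoE expert reduction + allreduce + residual add + RMSNorm. The expected layout of
    // reduce_fusion_inputs is documented inline below; returns {norm_out, residual_out}.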
    std::vector<torch::Tensor> run_moe_allreduce(torch::optional<torch::Tensor> workspace,
        torch::TensorList reduce_fusion_inputs, int64_t rank, int64_t nranks, double eps, int64_t fusion_op) noexcept
    {

        auto allreduce_fusion_params = tensorrt_llm::kernels::ar_fusion::moe::MoeReductionAllReduceFusionParams();

        allreduce_fusion_params.quant_out = nullptr;
        allreduce_fusion_params.scale_out = nullptr;
        allreduce_fusion_params.residual_out = nullptr;
        allreduce_fusion_params.norm_out = nullptr;

        allreduce_fusion_params.nranks = static_cast<int>(nranks);
        allreduce_fusion_params.rank = static_cast<int>(rank);
        allreduce_fusion_params.dtype
            = tensorrt_llm::runtime::TorchUtils::dataType(reduce_fusion_inputs[5].scalar_type());
        // size: num_token * hidden_dim
        allreduce_fusion_params.size = static_cast<int>(reduce_fusion_inputs[5].numel());
        allreduce_fusion_params.hidden_dim = static_cast<int>(reduce_fusion_inputs[4].size(-1));

        // workspace: AR scratch space
        allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.value().mutable_data_ptr());

        allreduce_fusion_params.rms_gamma = reduce_fusion_inputs[1].data_ptr();
        allreduce_fusion_params.rms_eps = static_cast<float>(eps);
        allreduce_fusion_params.stream = at::cuda::getCurrentCUDAStream(reduce_fusion_inputs[1].get_device());

        allreduce_fusion_params.residual_in = reduce_fusion_inputs[0].data_ptr();

        // MOE Reduction specific params
        // reduce_fusion_inputs[0]: residual
        // reduce_fusion_inputs[1]: gamma
        // reduce_fusion_inputs[2]: moe_reduction_device_num_experts
        // reduce_fusion_inputs[3]: moe_reduction_scale_input [device_num_experts, m]
        // reduce_fusion_inputs[4]: moe_reduction_active_experts_token_input [device_num_experts, m, 7168]
        // reduce_fusion_inputs[5]: moe_reduction_token_input [m, 7168]
        allreduce_fusion_params.allreduce_in = nullptr; // for safety, set nullptr
        allreduce_fusion_params.moe_reduction_device_num_experts
            = static_cast<int*>(reduce_fusion_inputs[2].data_ptr());
        allreduce_fusion_params.moe_reduction_scale_input = static_cast<float*>(reduce_fusion_inputs[3].data_ptr());
        allreduce_fusion_params.moe_reduction_active_experts_token_input = reduce_fusion_inputs[4].data_ptr();
        allreduce_fusion_params.moe_reduction_token_input = reduce_fusion_inputs[5].data_ptr();

        // output tensors
        torch::Tensor norm_out = torch::empty_like(reduce_fusion_inputs[5]);
        torch::Tensor residual_out = torch::empty_like(reduce_fusion_inputs[0]);

        allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
        allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();

        tensorrt_llm::kernels::ar_fusion::moe::moereduction_allreduce_fusion_op(allreduce_fusion_params);

        return std::vector<torch::Tensor>({norm_out, residual_out});
    }
};

} // namespace

#endif // ENABLE_MULTI_DEVICE

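// Entry point registered as trtllm::deepseek_allreduce_fusion. Dispatches to the MoE fused path
// for MOE_ALLREDUCE_RESIDUAL_RMS_NORM and to the generic fused allreduce otherwise; returns an
// empty vector when built without ENABLE_MULTI_DEVICE.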
std::vector<torch::Tensor> deepseekAllreduceFusion(torch::Tensor input, torch::optional<torch::Tensor> workspace,
    torch::TensorList reduce_fusion_inputs, int64_t const rank, int64_t const nranks, double const eps,
    int64_t const fusion_op)
{
#if ENABLE_MULTI_DEVICE
    DeepseekAllreduceOp op;
    auto fusion_op_type = static_cast<AllReduceFusionOp>(int8_t(fusion_op));
    if (fusion_op_type == AllReduceFusionOp::MOE_ALLREDUCE_RESIDUAL_RMS_NORM)
    {
        return op.run_moe_allreduce(workspace, reduce_fusion_inputs, rank, nranks, eps, fusion_op);
    }
    else
    {
        return op.run(input, workspace, reduce_fusion_inputs, rank, nranks, eps, fusion_op);
    }
#else
    return std::vector<torch::Tensor>();
#endif // ENABLE_MULTI_DEVICE
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    // reduce_fusion_inputs includes:
    // 0: residual
    // 1: gamma
    // 2: scale_factor, only for the NVFP4 fusion ops (RESIDUAL_RMS_NORM_QUANT_NVFP4 and
    //    RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
    // For MOE_ALLREDUCE_RESIDUAL_RMS_NORM, the six-tensor layout documented in run_moe_allreduce applies.
    m.def(
        "deepseek_allreduce_fusion(Tensor input, Tensor? workspace, Tensor[] reduce_fusion_inputs, "
        "int rank, int nranks, float eps, int fusion_op) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("deepseek_allreduce_fusion", &torch_ext::deepseekAllreduceFusion);
}
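
// A minimal usage sketch (assumption: the compiled extension has been loaded into the Python
// process, e.g. via torch.ops.load_library or by importing the TensorRT-LLM bindings); tensor
// names and shapes below are illustrative only:
//
//   outputs = torch.ops.trtllm.deepseek_allreduce_fusion(
//       hidden_states, workspace, [residual, gamma, scale_factor],
//       rank, world_size, eps, fusion_op)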