TensorRT-LLMs/cpp/tensorrt_llm/thop/deepseekAllreduceFusionOp.cpp
Zongfei Jing c7548ad72c
perf: Add optimizations for deepseek in min latency mode (#3093)
* Add optimizations for deepseek min latency

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>

* Fix compile error

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>

* Update internal cutlass kernel libs

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>

* Format code

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>

* Resolve conflicts

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>

---------

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
2025-04-02 09:05:24 +08:00

273 lines
11 KiB
C++

/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/customAllReduceUtils.h"
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/allReduceFusionKernels.h"
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
#include "tensorrt_llm/kernels/internal_cutlass_kernels/include/fp4_gemm.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/thop/thUtils.h"
#if ENABLE_MULTI_DEVICE
#include <ATen/cuda/EmptyTensor.h>
#include <nccl.h>
#endif // ENABLE_MULTI_DEVICE
#include <nvml.h>
#include <torch/extension.h>
#include <cstddef>
#include <cstdint>
#include <vector>
namespace torch_ext
{
#if ENABLE_MULTI_DEVICE
using tensorrt_llm::kernels::AllReduceFusionOp;
namespace
{
class DeepseekAllreduceOp
{
public:
DeepseekAllreduceOp() {}
~DeepseekAllreduceOp() = default;
std::vector<torch::Tensor> run(torch::Tensor input, torch::optional<torch::Tensor> workspace,
torch::TensorList reduce_fusion_inputs, int64_t rank, int64_t nranks, double eps, int64_t fusion_op) noexcept
{
auto const fusion_op_type = static_cast<AllReduceFusionOp>(int8_t(fusion_op));
torch::Tensor residual_out;
torch::Tensor norm_out;
torch::Tensor quant_out;
torch::Tensor scale_out;
tensorrt_llm::kernels::ar_fusion::AllReduceFusionParams allreduce_fusion_params;
allreduce_fusion_params.quant_out = nullptr;
allreduce_fusion_params.scale_out = nullptr;
allreduce_fusion_params.residual_out = nullptr;
allreduce_fusion_params.norm_out = nullptr;
if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4
|| fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_AND_QUANT_NVFP4)
{
TORCH_CHECK(reduce_fusion_inputs.size() == 3, "Pre-MLP fusion should have 3 inputs.");
int64_t sfVecSize = 16;
int64_t m = 1;
auto const& inputShape = input.sizes();
auto const& r = inputShape.size();
TORCH_CHECK(r >= 2, "Input should be >=2D tensor.");
for (size_t i = 0; i < r - 1; i++)
{
m *= inputShape[i];
}
auto const k = inputShape[r - 1];
TORCH_CHECK(k % sfVecSize == 0, "Input should be divisible by sfVecSize.");
std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
outputShape[r - 1] = k / 2;
quant_out = at::detail::empty_cuda(outputShape, FLOAT4_E2M1X2, input.device(), std::nullopt);
scale_out = at::detail::empty_cuda(
{tensorrt_llm::computeSFSize(m, k / sfVecSize)}, SF_DTYPE, input.device(), std::nullopt);
residual_out = torch::empty_like(reduce_fusion_inputs[0]);
allreduce_fusion_params.quant_out = quant_out.mutable_data_ptr();
allreduce_fusion_params.scale_out = scale_out.mutable_data_ptr();
allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_AND_QUANT_NVFP4)
{
norm_out = torch::empty_like(input);
allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
}
}
else if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM)
{
norm_out = torch::empty_like(input);
residual_out = torch::empty_like(reduce_fusion_inputs[0]);
allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
}
else
{
return std::vector<torch::Tensor>();
}
allreduce_fusion_params.nranks = static_cast<int>(nranks);
allreduce_fusion_params.rank = static_cast<int>(rank);
allreduce_fusion_params.dtype = tensorrt_llm::runtime::TorchUtils::dataType(input.scalar_type());
allreduce_fusion_params.size = static_cast<int>(input.numel());
allreduce_fusion_params.hidden_dim = static_cast<int>(input.size(-1));
allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.value().mutable_data_ptr());
allreduce_fusion_params.allreduce_in = input.data_ptr();
allreduce_fusion_params.residual_in = reduce_fusion_inputs[0].data_ptr();
allreduce_fusion_params.rms_gamma = reduce_fusion_inputs[1].data_ptr();
allreduce_fusion_params.rms_eps = static_cast<float>(eps);
if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4
|| fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_AND_QUANT_NVFP4)
{
allreduce_fusion_params.scale_factor = static_cast<float*>(reduce_fusion_inputs[2].data_ptr());
}
else
{
allreduce_fusion_params.scale_factor = nullptr;
}
allreduce_fusion_params.stream = at::cuda::getCurrentCUDAStream(input.get_device());
tensorrt_llm::kernels::ar_fusion::allreduce_fusion_op(allreduce_fusion_params);
if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4)
{
return std::vector<torch::Tensor>({quant_out, scale_out, residual_out});
}
else if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM)
{
return std::vector<torch::Tensor>({norm_out, residual_out});
}
else if (fusion_op_type == AllReduceFusionOp::RESIDUAL_RMS_NORM_AND_QUANT_NVFP4)
{
return std::vector<torch::Tensor>({norm_out, quant_out, scale_out, residual_out});
}
else
{
return std::vector<torch::Tensor>();
}
}
std::vector<torch::Tensor> run_moe_allreduce(torch::optional<torch::Tensor> workspace,
torch::TensorList reduce_fusion_inputs, int64_t rank, int64_t nranks, double eps, int64_t fusion_op) noexcept
{
auto const fusion_op_type = static_cast<AllReduceFusionOp>(int8_t(fusion_op));
TORCH_CHECK(fusion_op_type == AllReduceFusionOp::MOE_ALLREDUCE_RESIDUAL_RMS_NORM,
"Only support MOE_ALLREDUCE_RESIDUAL_RMS_NORM");
auto allreduce_fusion_params = tensorrt_llm::kernels::ar_fusion::MoeReductionAllReduceFusionParams();
allreduce_fusion_params.quant_out = nullptr;
allreduce_fusion_params.scale_out = nullptr;
allreduce_fusion_params.residual_out = nullptr;
allreduce_fusion_params.norm_out = nullptr;
allreduce_fusion_params.nranks = static_cast<int>(nranks);
allreduce_fusion_params.rank = static_cast<int>(rank);
allreduce_fusion_params.dtype
= tensorrt_llm::runtime::TorchUtils::dataType(reduce_fusion_inputs[5].scalar_type());
// size: num_token * hidden_dim
allreduce_fusion_params.size = static_cast<int>(reduce_fusion_inputs[5].numel());
allreduce_fusion_params.hidden_dim = static_cast<int>(reduce_fusion_inputs[4].size(-1));
// workspace: AR scratch space
allreduce_fusion_params.workspace = reinterpret_cast<void**>(workspace.value().mutable_data_ptr());
allreduce_fusion_params.rms_gamma = reduce_fusion_inputs[1].data_ptr();
allreduce_fusion_params.rms_eps = static_cast<float>(eps);
allreduce_fusion_params.stream = at::cuda::getCurrentCUDAStream(reduce_fusion_inputs[1].get_device());
allreduce_fusion_params.residual_in = reduce_fusion_inputs[0].data_ptr();
// MOE Reduction specific params
// reduce_fusion_inputs[0]: residual
// reduce_fusion_inputs[1]: gamma
// reduce_fusion_inputs[2]: moe_reduction_device_num_experts
// reduce_fusion_inputs[3]: moe_reduction_scale_input [device_num_experts, m]
// reduce_fusion_inputs[4]: moe_reduction_active_experts_token_input [device_num_experts, m, 7168]
// reduce_fusion_inputs[5]: moe_reduction_token_input [m, 7168]
allreduce_fusion_params.allreduce_in = nullptr; // for safety, set nullptr
allreduce_fusion_params.moe_reduction_device_num_experts
= static_cast<int*>(reduce_fusion_inputs[2].data_ptr());
allreduce_fusion_params.moe_reduction_scale_input = static_cast<float*>(reduce_fusion_inputs[3].data_ptr());
allreduce_fusion_params.moe_reduction_active_experts_token_input = reduce_fusion_inputs[4].data_ptr();
allreduce_fusion_params.moe_reduction_token_input = reduce_fusion_inputs[5].data_ptr();
// output tensors
torch::Tensor norm_out = torch::empty_like(reduce_fusion_inputs[5]);
torch::Tensor residual_out = torch::empty_like(reduce_fusion_inputs[0]);
allreduce_fusion_params.norm_out = norm_out.mutable_data_ptr();
allreduce_fusion_params.residual_out = residual_out.mutable_data_ptr();
tensorrt_llm::kernels::ar_fusion::moereduction_allreduce_fusion_op(allreduce_fusion_params);
return std::vector<torch::Tensor>({norm_out, residual_out});
}
};
} // namespace
#endif // ENABLE_MULTI_DEVICE
/// Entry point bound to the `deepseek_allreduce_fusion` torch op.
///
/// With ENABLE_MULTI_DEVICE, decodes `fusion_op` and forwards to `DeepseekAllreduceOp::run_moe_allreduce` for
/// MOE_ALLREDUCE_RESIDUAL_RMS_NORM (which is not given `input`) or to `DeepseekAllreduceOp::run` for every other
/// value. Without ENABLE_MULTI_DEVICE, no work is done and an empty vector is returned.
///
/// @param input                 Tensor to all-reduce; only forwarded on the non-MoE path.
/// @param workspace             Optional workspace tensor, forwarded unchanged.
/// @param reduce_fusion_inputs  Extra per-fusion-op tensors, forwarded unchanged.
/// @param rank                  Rank of this process.
/// @param nranks                Number of ranks.
/// @param eps                   RMS norm epsilon.
/// @param fusion_op             Integer value of an `AllReduceFusionOp` enumerator.
/// @return The output tensors produced by the selected method.
std::vector<torch::Tensor> deepseekAllreduceFusion(torch::Tensor input, torch::optional<torch::Tensor> workspace,
    torch::TensorList reduce_fusion_inputs, int64_t const rank, int64_t const nranks, double const eps,
    int64_t const fusion_op)
{
#if ENABLE_MULTI_DEVICE
    bool const isMoeFusion
        = static_cast<AllReduceFusionOp>(int8_t(fusion_op)) == AllReduceFusionOp::MOE_ALLREDUCE_RESIDUAL_RMS_NORM;

    DeepseekAllreduceOp allreduceOp;
    if (isMoeFusion)
    {
        // The MoE path takes its token tensors from reduce_fusion_inputs rather than from `input`.
        return allreduceOp.run_moe_allreduce(workspace, reduce_fusion_inputs, rank, nranks, eps, fusion_op);
    }
    return allreduceOp.run(input, workspace, reduce_fusion_inputs, rank, nranks, eps, fusion_op);
#else
    return std::vector<torch::Tensor>();
#endif // ENABLE_MULTI_DEVICE
}
} // namespace torch_ext
// Declares the schema of the `deepseek_allreduce_fusion` op in the `trtllm` library.
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
// reduce_fusion_inputs for the non-MoE fusion ops:
// 0: residual
// 1: gamma
// 2: scale_factor: only when fusion_op is RESIDUAL_RMS_NORM_QUANT_NVFP4 or
//    RESIDUAL_RMS_NORM_AND_QUANT_NVFP4
//
// reduce_fusion_inputs for MoE allreduce (fusion_op == MOE_ALLREDUCE_RESIDUAL_RMS_NORM;
// the `input` argument is not forwarded to that path):
// 0: residual
// 1: gamma
// 2: moe_reduction_device_num_experts [1]
// 3: moe_reduction_scale_input [global_num_experts, m]
//    NOTE(review): the comment in run_moe_allreduce gives this shape as [device_num_experts, m];
//    confirm which one is correct.
// 4: moe_reduction_active_experts_token_input [device_num_experts, m, 7168]
// 5: moe_reduction_token_input [m, 7168]
m.def(
"deepseek_allreduce_fusion(Tensor input, Tensor? workspace, Tensor[] reduce_fusion_inputs, "
"int rank, int nranks, float eps, int fusion_op) -> Tensor[]");
}
// Registers torch_ext::deepseekAllreduceFusion as the implementation of
// `deepseek_allreduce_fusion` for the CUDA dispatch key.
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("deepseek_allreduce_fusion", &torch_ext::deepseekAllreduceFusion);
}