TensorRT-LLMs/cpp/tests/unit_tests/kernels/allReduce/allReduceFusionTest.cu

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <nccl.h>
#include <cstdint>
#include <functional>
#include <iostream>
#include <random>
#include <vector>
#include "tensorrt_llm/kernels/allReduceFusionKernels.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/kernels/rmsnormKernels.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/multiDeviceUtils.h"
namespace mpi = tensorrt_llm::mpi;
namespace tr = tensorrt_llm::runtime;
using namespace tensorrt_llm::kernels;
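// Element-wise residual addition, used to build the unfused baseline/reference path.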
template <typename DType>
__global__ void residual_add_kernel(DType* data, DType* residual, int size)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= size)
return;
data[idx] = data[idx] + residual[idx];
}
template <typename DType>
void residual_add(DType* data, DType* residual, int size, cudaStream_t stream)
{
residual_add_kernel<<<(size + 127) / 128, 128, 0, stream>>>(data, residual, size);
}
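// Casts a buffer element-wise to fp32 (debug helper; not exercised by the test below).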
template <typename DType>
__global__ void cast_to_fp32_kernel(DType* in, float* out, int size)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= size)
return;
out[idx] = static_cast<float>(in[idx]);
}
template <typename DType>
void cast_to_fp32(DType* in, float* out, int size, cudaStream_t stream)
{
cast_to_fp32_kernel<<<(size + 127) / 128, 128, 0, stream>>>(in, out, size);
}
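// Debug helper: rank 0 prints the first 20 elements of a buffer.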
template <typename T>
void print(int rank, void* _pa, int size)
{
auto pa = reinterpret_cast<T*>(_pa);
if (rank == 0)
{
printf("print: [");
for (int n = 0; n < 20; ++n)
{
float v = static_cast<float>(pa[n]);
printf("%f, ", v);
}
printf("...]\n");
}
}
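// Element-wise comparison against a reference buffer: reports max/avg absolute
// difference and returns whether the max difference stays within max_val * scale.
// Set the AR_DEBUG environment variable to dump leading elements and mismatches.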
template <typename T>
bool compare(int rank, void* _pa, void* _pb, int size, float scale, std::string const& cmp_info = "")
{
auto pa = reinterpret_cast<T*>(_pa);
auto pb = reinterpret_cast<T*>(_pb);
float max_diff = 0.f, tot_diff = 0.f;
float max_val = 0.f;
int diff_cnt = 0;
float threshold = 1e-7;
static char* ar_debug = std::getenv("AR_DEBUG");
if (ar_debug && rank == 0)
{
printf("TensorA: [");
for (int n = 0; n < 20; ++n)
{
float v = static_cast<float>(pa[n]);
printf("%f, ", v);
}
printf("...]\n");
printf("TensorB: [");
for (int n = 0; n < 20; ++n)
{
float v = static_cast<float>(pb[n]);
printf("%f, ", v);
}
printf("...]\n");
}
int print_cnt = 0;
for (int n = 0; n < size; ++n)
{
float va = static_cast<float>(pa[n]);
float vb = static_cast<float>(pb[n]);
max_val = std::max(max_val, vb);
float diff = std::abs(va - vb);
if (diff > threshold)
{
max_diff = std::max(max_diff, diff);
tot_diff += diff;
++diff_cnt;
}
if (rank == 0 && print_cnt < 20 && ar_debug && diff / (std::abs(vb) + 1e-7) > 0.1)
{
++print_cnt;
printf("idx %d, va %f, vb %f\n", n, va, vb);
}
}
float diff_thres = max_val * scale;
if (rank == 0)
{
TLLM_LOG_INFO("[%s] rank %d, max diff %f (diff threshold %f), avg diff %f, diff cnt %d/%d", cmp_info.c_str(),
rank, max_diff, diff_thres, tot_diff / std::max(diff_cnt, 1), diff_cnt, size);
}
return max_diff <= diff_thres;
}
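// Fills a host buffer with uniform random values in [minv, maxv); each call advances
// the seed so successive buffers differ.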
template <typename T1, typename T2>
void random_fill(T1* data, int size, T2 minv, T2 maxv)
{
static int rseed = 20250227;
std::mt19937 gen(rseed++);
std::uniform_real_distribution<float> dis(static_cast<float>(minv), static_cast<float>(maxv));
for (int i = 0; i < size; ++i)
{
data[i] = static_cast<T1>(dis(gen));
}
}
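// Minimal RAII wrapper around a paired device/host allocation with explicit
// host<->device copies; host_data() always syncs from the device first.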
struct CudaBuffer
{
void* m_d_data;
void* m_h_data;
int m_size;
CudaBuffer(int size_in_bytes = 0)
: m_d_data(nullptr)
, m_h_data(nullptr)
, m_size(size_in_bytes)
{
allocate(size_in_bytes);
}
void allocate(int size_in_bytes)
{
if (size_in_bytes == 0)
return;
TLLM_CHECK(m_d_data == nullptr && m_h_data == nullptr);
m_size = size_in_bytes;
TLLM_CUDA_CHECK(cudaMalloc(&m_d_data, m_size));
TLLM_CUDA_CHECK(cudaMemset(m_d_data, 0, m_size));
m_h_data = malloc(m_size);
}
template <typename T = void>
T* device_data()
{
TLLM_CHECK(m_d_data != nullptr);
return reinterpret_cast<T*>(m_d_data);
}
template <typename T = void>
T* host_data()
{
TLLM_CHECK(m_h_data != nullptr);
d2h();
return reinterpret_cast<T*>(m_h_data);
}
template <typename DType, typename VType>
void random(VType minv, VType maxv)
{
random_fill(reinterpret_cast<DType*>(m_h_data), m_size / sizeof(DType), minv, maxv);
h2d();
}
void h2d()
{
TLLM_CUDA_CHECK(cudaMemcpy(m_d_data, m_h_data, m_size, cudaMemcpyHostToDevice));
}
void d2h()
{
TLLM_CUDA_CHECK(cudaMemcpy(m_h_data, m_d_data, m_size, cudaMemcpyDeviceToHost));
}
~CudaBuffer()
{
if (m_d_data)
{
TLLM_CUDA_CHECK(cudaFree(m_d_data));
}
if (m_h_data)
{
free(m_h_data);
}
}
};
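// Drives the allreduce fusion test: sets up MPI/NCCL (one rank per GPU), allocates the
// input/output buffers and the fusion workspace, and fills AllReduceFusionParams.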
template <typename DType>
class TestRunner
{
static_assert(std::is_same_v<DType, half> || std::is_same_v<DType, __nv_bfloat16>);
static constexpr ncclDataType_t kNCCLDataType = std::is_same_v<DType, half> ? ncclFloat16 : ncclBfloat16;
static constexpr nvinfer1::DataType kTRTDataType
= std::is_same_v<DType, half> ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kBF16;
public:
TestRunner(int max_token_num, int hidden_dim)
: m_mpi_comm(mpi::MpiComm::world())
{
m_message_size = max_token_num * hidden_dim;
m_world_size = m_mpi_comm.getSize();
m_rank = m_mpi_comm.getRank();
TLLM_CUDA_CHECK(cudaSetDevice(m_rank));
ncclUniqueId id;
if (m_rank == 0)
{
TLLM_NCCL_CHECK(ncclGetUniqueId(&id));
}
m_mpi_comm.bcast(&id, sizeof(id), mpi::MpiType::kBYTE, 0);
TLLM_NCCL_CHECK(ncclCommInitRank(&m_nccl_comm, m_world_size, id, m_rank));
m_allreduce_in.allocate(m_message_size * sizeof(DType));
m_residual_in.allocate(m_message_size * sizeof(DType));
m_residual_out.allocate(m_message_size * sizeof(DType));
m_norm_out.allocate(m_message_size * sizeof(DType));
m_quant_out.allocate(m_message_size * sizeof(DType));
m_scale_out.allocate(m_message_size * sizeof(DType));
m_rms_gamma.allocate(hidden_dim * sizeof(DType));
m_scale_factor.allocate(sizeof(float));
m_stream = std::make_shared<tr::CudaStream>();
m_workspace = std::make_shared<ar_fusion::Workspace>(m_rank, m_world_size, max_token_num, hidden_dim, m_stream);
m_params.nranks = m_world_size;
m_params.rank = m_rank;
m_params.dtype = kTRTDataType;
m_params.workspace = m_workspace->get_workspace();
m_params.allreduce_in = m_allreduce_in.device_data();
m_params.residual_in = m_residual_in.device_data();
m_params.residual_out = m_residual_out.device_data();
m_params.norm_out = m_norm_out.device_data();
m_params.quant_out = m_quant_out.device_data();
m_params.scale_out = m_scale_out.device_data();
m_params.rms_gamma = m_rms_gamma.device_data();
m_params.scale_factor = m_scale_factor.device_data<float>();
m_params.rms_eps = 1e-3;
m_params.stream = m_stream->get();
}
void random_input()
{
m_allreduce_in.random<DType>(-100.f, 100.f);
m_residual_in.random<DType>(-100.f, 100.f);
m_rms_gamma.random<DType>(-1.f, 1.f);
m_scale_factor.random<float>(5.f, 5.f);
}
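// Re-randomizes the inputs, runs `func` warmup + iter times on the runner's stream, and
// returns the average latency per iteration in microseconds, measured with CUDA events.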
template <typename Func>
float benchmark(Func func, int warmup, int iter, int token_num, int hidden_dim)
{
m_params.size = token_num * hidden_dim;
m_params.hidden_dim = hidden_dim;
cudaEvent_t begin, end;
cudaEventCreate(&begin);
cudaEventCreate(&end);
random_input();
m_mpi_comm.barrier();
for (int i = 0; i < warmup; ++i)
{
(this->*func)(token_num, hidden_dim);
}
cudaEventRecord(begin, m_stream->get());
for (int i = 0; i < iter; ++i)
{
(this->*func)(token_num, hidden_dim);
}
cudaEventRecord(end, m_stream->get());
cudaEventSynchronize(end);
float time;
cudaEventElapsedTime(&time, begin, end);
time /= iter;
m_mpi_comm.barrier();
cudaEventDestroy(begin);
cudaEventDestroy(end);
return time * 1000;
}
int get_sm_count()
{
static int sm_count = 0;
if (sm_count == 0)
{
int device_id;
TLLM_CUDA_CHECK(cudaGetDevice(&device_id));
cudaDeviceProp device_prop;
cudaGetDeviceProperties(&device_prop, device_id);
sm_count = device_prop.multiProcessorCount;
}
return sm_count;
}
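// Builds the unfused reference: NCCL allreduce -> residual add -> RMSNorm -> FP4
// quantization, then checks the fused kernel's norm, quant, and scale outputs against it.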
void verify(int token_num, int hidden_dim)
{
int message_size = token_num * hidden_dim;
CudaBuffer ref_output(message_size * sizeof(DType)), ref_scale(message_size * sizeof(DType));
TLLM_NCCL_CHECK(ncclAllReduce(m_allreduce_in.device_data(), ref_output.device_data(), message_size,
kNCCLDataType, ncclSum, m_nccl_comm, 0));
residual_add(ref_output.device_data<DType>(), m_residual_in.device_data<DType>(), message_size, 0);
invokeGeneralRmsNorm<DType, int8_t>(ref_output.device_data<DType>(), ref_output.device_data<DType>(),
m_rms_gamma.device_data<DType>(), nullptr, m_params.rms_eps, token_num, hidden_dim,
tensorrt_llm::common::QuantMode(), 0);
EXPECT_TRUE(compare<DType>(m_rank, m_norm_out.host_data(), ref_output.host_data(), message_size, 1e-3, "norm out"));
invokeFP4Quantization(token_num, hidden_dim, m_norm_out.device_data<DType>(),
m_scale_factor.device_data<float>(), ref_output.device_data<int64_t>(), ref_scale.device_data<int32_t>(),
false, 128, 0);
EXPECT_TRUE(
compare<int8_t>(m_rank, m_quant_out.host_data(), ref_output.host_data(), message_size / 2, 1e-3, "quant out"));
EXPECT_TRUE(
compare<int8_t>(m_rank, m_scale_out.host_data(), ref_scale.host_data(), message_size / 16, 1e-3, "scale out"));
}
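// The individual unfused ops below are benchmarked separately so their summed latency
// can be compared against the fused kernel.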
void run_nccl_allreduce(int token_num, int hidden_dim)
{
TLLM_NCCL_CHECK(ncclAllReduce(m_allreduce_in.device_data(), m_residual_out.device_data(),
token_num * hidden_dim, kNCCLDataType, ncclSum, m_nccl_comm, m_stream->get()));
}
void run_residual_add(int token_num, int hidden_dim)
{
residual_add(m_residual_out.device_data<DType>(), m_residual_in.device_data<DType>(), token_num * hidden_dim,
m_stream->get());
}
void run_rms_norm(int token_num, int hidden_dim)
{
invokeGeneralRmsNorm<DType, int8_t>(m_residual_out.device_data<DType>(), m_norm_out.device_data<DType>(),
m_rms_gamma.device_data<DType>(), nullptr, m_params.rms_eps, token_num, hidden_dim,
tensorrt_llm::common::QuantMode(), m_stream->get());
}
void run_fp4_quant(int token_num, int hidden_dim)
{
invokeFP4Quantization(token_num, hidden_dim, m_norm_out.device_data<DType>(),
m_scale_factor.device_data<float>(), m_quant_out.device_data<int64_t>(), m_scale_out.device_data<int32_t>(),
false, 128, m_stream->get());
}
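// The fused allreduce + residual add + RMSNorm + FP4 quantization kernel under test.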
void run_kernel(int token_num, int hidden_dim)
{
ar_fusion::allreduce_fusion_op(m_params);
}
~TestRunner()
{
TLLM_NCCL_CHECK(ncclCommDestroy(m_nccl_comm));
}
private:
int m_rank;
int m_world_size;
int m_message_size;
mpi::MpiComm const& m_mpi_comm;
ncclComm_t m_nccl_comm;
CudaBuffer m_allreduce_in;
CudaBuffer m_residual_in;
CudaBuffer m_residual_out;
CudaBuffer m_norm_out;
CudaBuffer m_quant_out;
CudaBuffer m_scale_out;
CudaBuffer m_rms_gamma;
CudaBuffer m_scale_factor;
std::shared_ptr<ar_fusion::Workspace> m_workspace;
ar_fusion::AllReduceFusionParams m_params;
std::shared_ptr<tr::CudaStream> m_stream;
};
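// Sweeps token counts at a fixed hidden dimension, verifies the fused kernel against the
// unfused reference, and logs its speedup over NCCL plus the separate ops. Typically
// launched with one MPI rank per GPU (e.g. via mpirun); an even world size is required,
// otherwise the test is skipped.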
TEST(Kernel, allReduceFusion)
{
auto& comm = mpi::MpiComm::world();
auto world_size = comm.getSize();
auto rank = comm.getRank();
if (world_size % 2)
{
TLLM_LOG_WARNING("world size is not a multiple of 2, return");
return;
}
int warmup = 100, iter = 100;
int hidden_dim = 7168;
std::vector<int> candidate_token_num{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048};
int max_token_num = 2048;
TestRunner<half> runner(max_token_num, hidden_dim);
for (auto token_num : candidate_token_num)
{
auto latency = runner.benchmark(&TestRunner<half>::run_kernel, warmup, iter, token_num, hidden_dim);
runner.verify(token_num, hidden_dim);
if (rank == 0)
{
TLLM_LOG_INFO("token_num %d, hidden_dim %d, latency %fus", token_num, hidden_dim, latency);
}
auto nccl_latency
= runner.benchmark(&TestRunner<half>::run_nccl_allreduce, warmup, iter, token_num, hidden_dim);
if (rank == 0)
{
TLLM_LOG_INFO("nccl allreduce latency %fus", token_num, hidden_dim, nccl_latency);
}
auto residual_latency
= runner.benchmark(&TestRunner<half>::run_residual_add, warmup, iter, token_num, hidden_dim);
if (rank == 0)
{
TLLM_LOG_INFO("residual add latency %fus", token_num, hidden_dim, residual_latency);
}
auto rms_latency = runner.benchmark(&TestRunner<half>::run_rms_norm, warmup, iter, token_num, hidden_dim);
if (rank == 0)
{
TLLM_LOG_INFO("rms norm latency %fus", token_num, hidden_dim, rms_latency);
}
auto quant_latency = runner.benchmark(&TestRunner<half>::run_fp4_quant, warmup, iter, token_num, hidden_dim);
if (rank == 0)
{
TLLM_LOG_INFO("fp4 quant latency %fus", token_num, hidden_dim, quant_latency);
auto tot_latency = nccl_latency + residual_latency + rms_latency + quant_latency;
TLLM_LOG_INFO("fusion kernel latency %fus, nccl + ops latency %fus, total speedup %fx", latency,
tot_latency, tot_latency / latency);
}
}
}