/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <gtest/gtest.h>

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <nccl.h>

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <random>
#include <string>
#include <type_traits>
#include <vector>

#include "tensorrt_llm/kernels/allReduceFusionKernels.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/kernels/rmsnormKernels.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/multiDeviceUtils.h"

namespace mpi = tensorrt_llm::mpi;
namespace tr = tensorrt_llm::runtime;
using namespace tensorrt_llm::kernels;

// Threads per block for the simple element-wise helper kernels in this file.
constexpr int kElementwiseBlockSize = 128;

// Number of leading elements shown by the debug dumps below.
constexpr int kDebugDumpCount = 20;

// In-place element-wise add: data[i] = data[i] + residual[i] for i in [0, size).
// Expects a 1-D launch with at least `size` threads in total; threads past the end exit early.
template <typename DType>
__global__ void residual_add_kernel(DType* data, DType* residual, int size)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= size)
    {
        return;
    }
    data[idx] = data[idx] + residual[idx];
}

// Launches residual_add_kernel on `stream`. `data` and `residual` are device pointers holding
// at least `size` elements. A non-positive `size` is a no-op, because a zero-block grid is an
// invalid launch configuration.
template <typename DType>
void residual_add(DType* data, DType* residual, int size, cudaStream_t stream)
{
    if (size <= 0)
    {
        return;
    }
    int const blocks = (size + kElementwiseBlockSize - 1) / kElementwiseBlockSize;
    residual_add_kernel<<<blocks, kElementwiseBlockSize, 0, stream>>>(data, residual, size);
    TLLM_CUDA_CHECK(cudaGetLastError());
}

// Element-wise conversion out[i] = float(in[i]) for i in [0, size).
// Expects a 1-D launch with at least `size` threads in total; threads past the end exit early.
template <typename DType>
__global__ void cast_to_fp32_kernel(DType* in, float* out, int size)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= size)
    {
        return;
    }
    out[idx] = static_cast<float>(in[idx]);
}

// Launches cast_to_fp32_kernel on `stream`. A non-positive `size` is a no-op.
template <typename DType>
void cast_to_fp32(DType* in, float* out, int size, cudaStream_t stream)
{
    if (size <= 0)
    {
        return;
    }
    int const blocks = (size + kElementwiseBlockSize - 1) / kElementwiseBlockSize;
    cast_to_fp32_kernel<<<blocks, kElementwiseBlockSize, 0, stream>>>(in, out, size);
    TLLM_CUDA_CHECK(cudaGetLastError());
}

// On rank 0 only, prints up to the first kDebugDumpCount elements of the host buffer `_pa`,
// interpreted as T. The count is clamped to `size`, so short buffers are never over-read.
template <typename T>
void print(int rank, void* _pa, int size)
{
    if (rank != 0)
    {
        return;
    }
    auto pa = reinterpret_cast<T*>(_pa);
    int const n_show = std::min(size, kDebugDumpCount);
    printf("print: [");
    for (int n = 0; n < n_show; ++n)
    {
        printf("%f, ", static_cast<float>(pa[n]));
    }
    printf("...]\n");
}

// Compares `size` elements of two host buffers, interpreted as T, against each other.
//
// An element counts as differing when |a - b| > 1e-7. The call passes when the largest absolute
// difference is <= scale * max|b|, where `_pb` is the reference. On rank 0 a summary line is
// logged under the `cmp_info` tag. When the AR_DEBUG environment variable is set, rank 0 also
// dumps the leading elements and up to 20 elements with a relative error above 10%.
//
// Returns true when the tensors agree within the tolerance above.
template <typename T>
bool compare(int rank, void* _pa, void* _pb, int size, float scale, std::string const& cmp_info = "")
{
    auto pa = reinterpret_cast<T*>(_pa);
    auto pb = reinterpret_cast<T*>(_pb);
    float max_diff = 0.f, tot_diff = 0.f;
    float max_val = 0.f;
    int diff_cnt = 0;
    float const threshold = 1e-7f;
    static char* ar_debug = std::getenv("AR_DEBUG");
    if (ar_debug && rank == 0)
    {
        // Clamp to `size` so that tiny tensors are not read out of bounds.
        int const n_show = std::min(size, kDebugDumpCount);
        printf("TensorA: [");
        for (int n = 0; n < n_show; ++n)
        {
            printf("%f, ", static_cast<float>(pa[n]));
        }
        printf("...]\n");
        printf("TensorB: [");
        for (int n = 0; n < n_show; ++n)
        {
            printf("%f, ", static_cast<float>(pb[n]));
        }
        printf("...]\n");
    }
    int print_cnt = 0;
    for (int n = 0; n < size; ++n)
    {
        float va = static_cast<float>(pa[n]);
        float vb = static_cast<float>(pb[n]);
        // The tolerance scales with the reference *magnitude*. Using the signed max would make the
        // threshold ~0 whenever the reference values are mostly negative.
        max_val = std::max(max_val, std::abs(vb));
        float diff = std::abs(va - vb);
        if (diff > threshold)
        {
            max_diff = std::max(max_diff, diff);
            tot_diff += diff;
            ++diff_cnt;
        }
        if (rank == 0 && print_cnt < 20 && ar_debug && diff / (std::abs(vb) + 1e-7) > 0.1)
        {
            ++print_cnt;
            printf("idx %d, va %f, vb %f\n", n, va, vb);
        }
    }
    float diff_thres = max_val * scale;
    if (rank == 0)
    {
        TLLM_LOG_INFO("[%s] rank %d, max diff %f (diff threshold %f), avg diff %f, diff cnt %d/%d", cmp_info.c_str(),
            rank, max_diff, diff_thres, tot_diff / std::max(diff_cnt, 1), diff_cnt, size);
    }
    return max_diff <= diff_thres;
}

// Fills `size` elements of the host array `data` with uniform values drawn from [minv, maxv),
// converted to T1. Each instantiation keeps its own static seed, which starts at 20250227 and
// increments per call, so repeated runs produce the same sequence of fills.
template <typename T1, typename T2>
void random_fill(T1* data, int size, T2 minv, T2 maxv)
{
    static int rseed = 20250227;
    std::mt19937 gen(rseed++);
    std::uniform_real_distribution<float> dis(static_cast<float>(minv), static_cast<float>(maxv));
    for (int i = 0; i < size; ++i)
    {
        data[i] = static_cast<T1>(dis(gen));
    }
}

// Pairs a zero-initialised device allocation with an equally sized host staging buffer.
// Both are released in the destructor. The type is non-copyable, because copies would double-free.
struct CudaBuffer
{
    void* m_d_data;
    void* m_h_data;
    int m_size;

    // Initialisers are listed in declaration order (avoids -Wreorder).
    CudaBuffer(int size_in_bytes = 0)
        : m_d_data(nullptr)
        , m_h_data(nullptr)
        , m_size(size_in_bytes)
    {
        allocate(size_in_bytes);
    }

    CudaBuffer(CudaBuffer const&) = delete;
    CudaBuffer& operator=(CudaBuffer const&) = delete;

    // Allocates `size_in_bytes` on the device (zeroed) and on the host. Zero means "nothing yet".
    // A buffer may be allocated at most once.
    void allocate(int size_in_bytes)
    {
        if (size_in_bytes == 0)
        {
            return;
        }
        TLLM_CHECK(m_d_data == nullptr && m_h_data == nullptr);
        m_size = size_in_bytes;
        TLLM_CUDA_CHECK(cudaMalloc(&m_d_data, m_size));
        TLLM_CUDA_CHECK(cudaMemset(m_d_data, 0, m_size));
        m_h_data = malloc(m_size);
        TLLM_CHECK(m_h_data != nullptr);
    }

    // Device pointer, typed as T* (void* by default).
    template <typename T = void>
    T* device_data()
    {
        TLLM_CHECK(m_d_data != nullptr);
        return static_cast<T*>(m_d_data);
    }

    // Copies the device contents to the host staging buffer (blocking), then returns the host pointer.
    template <typename T = void>
    T* host_data()
    {
        TLLM_CHECK(m_h_data != nullptr);
        d2h();
        return static_cast<T*>(m_h_data);
    }

    // Fills the whole buffer, viewed as DType elements, with uniform values in [minv, maxv),
    // then uploads it to the device.
    template <typename DType, typename VType>
    void random(VType minv, VType maxv)
    {
        random_fill(reinterpret_cast<DType*>(m_h_data), static_cast<int>(m_size / sizeof(DType)), minv, maxv);
        h2d();
    }

    void h2d()
    {
        TLLM_CUDA_CHECK(cudaMemcpy(m_d_data, m_h_data, m_size, cudaMemcpyHostToDevice));
    }

    void d2h()
    {
        TLLM_CUDA_CHECK(cudaMemcpy(m_h_data, m_d_data, m_size, cudaMemcpyDeviceToHost));
    }

    ~CudaBuffer()
    {
        if (m_d_data)
        {
            TLLM_CUDA_CHECK(cudaFree(m_d_data));
        }
        if (m_h_data)
        {
            free(m_h_data);
        }
    }
};

// Multi-rank harness that benchmarks and verifies the fused all-reduce + residual add + RMS norm +
// FP4 quantization kernel (ar_fusion::allreduce_fusion_op) against the equivalent unfused sequence.
//
// Each MPI rank binds to CUDA device `rank` and joins one NCCL communicator. All buffers are
// sized for `max_token_num * hidden_dim` elements, and every benchmark or verify call must stay
// within that size.
template <typename DType>
class TestRunner
{
    static_assert(std::is_same_v<DType, half> || std::is_same_v<DType, __nv_bfloat16>);
    static constexpr ncclDataType_t kNCCLDataType = std::is_same_v<DType, half> ? ncclFloat16 : ncclBfloat16;
    static constexpr nvinfer1::DataType kTRTDataType
        = std::is_same_v<DType, half> ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kBF16;

public:
    TestRunner(int max_token_num, int hidden_dim)
        : m_mpi_comm(mpi::MpiComm::world())
    {
        m_message_size = max_token_num * hidden_dim;
        m_world_size = m_mpi_comm.getSize();
        m_rank = m_mpi_comm.getRank();
        TLLM_CUDA_CHECK(cudaSetDevice(m_rank));

        // Rank 0 creates the NCCL id and broadcasts it over MPI so that all ranks join the same communicator.
        ncclUniqueId id;
        if (m_rank == 0)
        {
            TLLM_NCCL_CHECK(ncclGetUniqueId(&id));
        }
        m_mpi_comm.bcast(&id, sizeof(id), mpi::MpiType::kBYTE, 0);
        TLLM_NCCL_CHECK(ncclCommInitRank(&m_nccl_comm, m_world_size, id, m_rank));

        m_allreduce_in.allocate(m_message_size * sizeof(DType));
        m_residual_in.allocate(m_message_size * sizeof(DType));
        m_residual_out.allocate(m_message_size * sizeof(DType));
        m_norm_out.allocate(m_message_size * sizeof(DType));
        m_quant_out.allocate(m_message_size * sizeof(DType));
        m_scale_out.allocate(m_message_size * sizeof(DType));
        m_rms_gamma.allocate(hidden_dim * sizeof(DType));
        m_scale_factor.allocate(sizeof(float));

        m_stream = std::make_shared<tr::CudaStream>();
        m_workspace
            = std::make_shared<ar_fusion::Workspace>(m_rank, m_world_size, max_token_num, hidden_dim, m_stream);

        m_params.nranks = m_world_size;
        m_params.rank = m_rank;
        m_params.dtype = kTRTDataType;
        m_params.workspace = m_workspace->get_workspace();
        m_params.allreduce_in = m_allreduce_in.device_data<DType>();
        m_params.residual_in = m_residual_in.device_data<DType>();
        m_params.residual_out = m_residual_out.device_data<DType>();
        m_params.norm_out = m_norm_out.device_data<DType>();
        m_params.quant_out = m_quant_out.device_data<DType>();
        m_params.scale_out = m_scale_out.device_data<DType>();
        m_params.rms_gamma = m_rms_gamma.device_data<DType>();
        m_params.scale_factor = m_scale_factor.device_data<float>();
        m_params.rms_eps = 1e-3;
        m_params.stream = m_stream->get();
    }

    // Regenerates the inputs: activations and residual in [-100, 100), gamma in [-1, 1),
    // and a constant scale factor of 5.0.
    void random_input()
    {
        m_allreduce_in.random<DType>(-100.f, 100.f);
        m_residual_in.random<DType>(-100.f, 100.f);
        m_rms_gamma.random<DType>(-1.f, 1.f);
        m_scale_factor.random<float>(5.f, 5.f);
    }

    // Runs `func` (a TestRunner member) `warmup` times untimed, then `iter` times between two
    // events recorded on m_stream. Returns the mean per-iteration latency in microseconds.
    // Fresh random inputs are generated first, and all ranks are barrier-synchronised before and after.
    template <typename Func>
    float benchmark(Func func, int warmup, int iter, int token_num, int hidden_dim)
    {
        TLLM_CHECK(iter > 0);
        TLLM_CHECK(token_num * hidden_dim <= m_message_size);
        m_params.size = token_num * hidden_dim;
        m_params.hidden_dim = hidden_dim;
        cudaEvent_t begin, end;
        TLLM_CUDA_CHECK(cudaEventCreate(&begin));
        TLLM_CUDA_CHECK(cudaEventCreate(&end));
        random_input();
        m_mpi_comm.barrier();
        for (int i = 0; i < warmup; ++i)
        {
            (this->*func)(token_num, hidden_dim);
        }
        TLLM_CUDA_CHECK(cudaEventRecord(begin, m_stream->get()));
        for (int i = 0; i < iter; ++i)
        {
            (this->*func)(token_num, hidden_dim);
        }
        TLLM_CUDA_CHECK(cudaEventRecord(end, m_stream->get()));
        TLLM_CUDA_CHECK(cudaEventSynchronize(end));
        float time = 0.f; // milliseconds, as reported by cudaEventElapsedTime
        TLLM_CUDA_CHECK(cudaEventElapsedTime(&time, begin, end));
        time /= iter;
        m_mpi_comm.barrier();
        TLLM_CUDA_CHECK(cudaEventDestroy(begin));
        TLLM_CUDA_CHECK(cudaEventDestroy(end));
        return time * 1000;
    }

    // Multiprocessor count of the current device, queried once and then cached.
    int get_sm_count()
    {
        static int sm_count = 0;
        if (sm_count == 0)
        {
            int device_id;
            TLLM_CUDA_CHECK(cudaGetDevice(&device_id));
            cudaDeviceProp device_prop;
            TLLM_CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));
            sm_count = device_prop.multiProcessorCount;
        }
        return sm_count;
    }

    // Recomputes the unfused pipeline on stream 0: NCCL all-reduce, residual add, then in-place RMS
    // norm. It compares the result with the fused kernel's norm output. The FP4 reference is
    // computed from m_norm_out (the fused kernel's own norm output), so the quant/scale checks
    // cover only the quantization stage.
    // NOTE(review): comparison results are logged only; a mismatch does not fail the test.
    void verify(int token_num, int hidden_dim)
    {
        int message_size = token_num * hidden_dim;
        TLLM_CHECK(message_size <= m_message_size);
        CudaBuffer ref_output(message_size * sizeof(DType)), ref_scale(message_size * sizeof(DType));
        TLLM_NCCL_CHECK(ncclAllReduce(m_allreduce_in.device_data(), ref_output.device_data(), message_size,
            kNCCLDataType, ncclSum, m_nccl_comm, 0));
        residual_add(ref_output.device_data<DType>(), m_residual_in.device_data<DType>(), message_size, 0);
        invokeGeneralRmsNorm<DType, int8_t>(ref_output.device_data<DType>(), ref_output.device_data<DType>(),
            m_rms_gamma.device_data<DType>(), nullptr, m_params.rms_eps, token_num, hidden_dim,
            tensorrt_llm::common::QuantMode(), 0);
        compare<DType>(m_rank, m_norm_out.host_data(), ref_output.host_data(), message_size, 1e-3, "norm out");
        // ref_output is reused as the FP4 output buffer; it is larger than the message_size / 2 bytes compared.
        invokeFP4Quantization(token_num, hidden_dim, m_norm_out.device_data<DType>(),
            m_scale_factor.device_data<float>(), ref_output.device_data<int64_t>(), ref_scale.device_data<int32_t>(),
            false, 128, 0);
        compare<int8_t>(m_rank, m_quant_out.host_data(), ref_output.host_data(), message_size / 2, 1e-3, "quant out");
        compare<int8_t>(m_rank, m_scale_out.host_data(), ref_scale.host_data(), message_size / 16, 1e-3, "scale out");
    }

    // Unfused stage 1: NCCL sum all-reduce of the input into m_residual_out.
    void run_nccl_allreduce(int token_num, int hidden_dim)
    {
        TLLM_NCCL_CHECK(ncclAllReduce(m_allreduce_in.device_data(), m_residual_out.device_data(),
            token_num * hidden_dim, kNCCLDataType, ncclSum, m_nccl_comm, m_stream->get()));
    }

    // Unfused stage 2: m_residual_out += m_residual_in.
    void run_residual_add(int token_num, int hidden_dim)
    {
        residual_add(m_residual_out.device_data<DType>(), m_residual_in.device_data<DType>(), token_num * hidden_dim,
            m_stream->get());
    }

    // Unfused stage 3: RMS norm (latency only; the result is not checked).
    // NOTE(review): buffers are passed as (m_residual_out, m_norm_out, ...); confirm this matches
    // invokeGeneralRmsNorm's (output, input) parameter order.
    void run_rms_norm(int token_num, int hidden_dim)
    {
        invokeGeneralRmsNorm<DType, int8_t>(m_residual_out.device_data<DType>(), m_norm_out.device_data<DType>(),
            m_rms_gamma.device_data<DType>(), nullptr, m_params.rms_eps, token_num, hidden_dim,
            tensorrt_llm::common::QuantMode(), m_stream->get());
    }

    // Unfused stage 4: FP4 quantization of m_norm_out into m_quant_out / m_scale_out.
    void run_fp4_quant(int token_num, int hidden_dim)
    {
        invokeFP4Quantization(token_num, hidden_dim, m_norm_out.device_data<DType>(),
            m_scale_factor.device_data<float>(), m_quant_out.device_data<int64_t>(),
            m_scale_out.device_data<int32_t>(), false, 128, m_stream->get());
    }

    // Fused path: one call with the parameters prepared in the constructor and in benchmark().
    void run_kernel(int token_num, int hidden_dim)
    {
        ar_fusion::allreduce_fusion_op(m_params);
    }

    ~TestRunner()
    {
        TLLM_NCCL_CHECK(ncclCommDestroy(m_nccl_comm));
    }

private:
    int m_rank;
    int m_world_size;
    int m_message_size;
    mpi::MpiComm const& m_mpi_comm;
    ncclComm_t m_nccl_comm;
    CudaBuffer m_allreduce_in;
    CudaBuffer m_residual_in;
    CudaBuffer m_residual_out;
    CudaBuffer m_norm_out;
    CudaBuffer m_quant_out;
    CudaBuffer m_scale_out;
    CudaBuffer m_rms_gamma;
    CudaBuffer m_scale_factor;
    std::shared_ptr<ar_fusion::Workspace> m_workspace;
    ar_fusion::AllReduceFusionParams m_params;
    std::shared_ptr<tr::CudaStream> m_stream;
};

// Sweeps token counts at hidden_dim 7168. For each count it benchmarks and verifies the fused
// kernel, then benchmarks each unfused stage and logs the speedup on rank 0. Skips unless the MPI
// world size is even.
TEST(Kernel, allReduceFusion)
{
    using Runner = TestRunner<half>;
    auto& comm = mpi::MpiComm::world();
    auto world_size = comm.getSize();
    auto rank = comm.getRank();
    if (world_size % 2)
    {
        TLLM_LOG_WARNING("world size is not a multiple of 2, return");
        return;
    }
    int warmup = 100, iter = 100;
    int hidden_dim = 7168;
    std::vector<int> candidate_token_num{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048};
    // Size the buffers for the largest candidate so that every sweep point fits.
    int const max_token_num = *std::max_element(candidate_token_num.begin(), candidate_token_num.end());
    Runner runner(max_token_num, hidden_dim);
    for (auto token_num : candidate_token_num)
    {
        auto latency = runner.benchmark(&Runner::run_kernel, warmup, iter, token_num, hidden_dim);
        runner.verify(token_num, hidden_dim);
        if (rank == 0)
        {
            TLLM_LOG_INFO("token_num %d, hidden_dim %d, latency %fus", token_num, hidden_dim, latency);
        }
        // Each stage log has exactly one %f, so pass only the latency (an int here would be UB).
        auto nccl_latency = runner.benchmark(&Runner::run_nccl_allreduce, warmup, iter, token_num, hidden_dim);
        if (rank == 0)
        {
            TLLM_LOG_INFO("nccl allreduce latency %fus", nccl_latency);
        }
        auto residual_latency = runner.benchmark(&Runner::run_residual_add, warmup, iter, token_num, hidden_dim);
        if (rank == 0)
        {
            TLLM_LOG_INFO("residual add latency %fus", residual_latency);
        }
        auto rms_latency = runner.benchmark(&Runner::run_rms_norm, warmup, iter, token_num, hidden_dim);
        if (rank == 0)
        {
            TLLM_LOG_INFO("rms norm latency %fus", rms_latency);
        }
        auto quant_latency = runner.benchmark(&Runner::run_fp4_quant, warmup, iter, token_num, hidden_dim);
        if (rank == 0)
        {
            TLLM_LOG_INFO("fp4 quant latency %fus", quant_latency);
            auto tot_latency = nccl_latency + residual_latency + rms_latency + quant_latency;
            TLLM_LOG_INFO("fusion kernel latency %fus, nccl + ops latency %fus, total speedup %fx", latency,
                tot_latency, tot_latency / latency);
        }
    }
}