/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <nccl.h>
#include <set>
#include <string>
#include <tuple>
#include <vector>

#include "common.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_runner.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"

#include <gtest/gtest.h>

#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_fill.h"

using namespace cutlass;
using namespace nvinfer1;
using namespace tensorrt_llm::mpi;
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::kernels::cutlass_kernels;

///////////////////////////
// CLI Args
///////////////////////////
struct Options
{
    bool help = false;
    int seed = 0;

    // problem shape
    int M = 13;
    int N = 4096;
    int K = 8192;
    int K_tp; // K per GPU

    int rank = 0;
    int tp = 1;
    std::set<int> tp_group;

    float alpha = 1.0f;
    float beta = 0.0f;
    int iterations = 100;

    bool valid() const
    {
        return M > 0 && N >= 16 && K_tp >= 16 && tp > 0 && iterations > 0;
    }

    // Parses the command line
    void parse(int argc, char** args)
    {
        cutlass::CommandLine cmd(argc, const_cast<char const**>(args));

        if (cmd.check_cmd_line_flag("help"))
        {
            help = true;
        }

        cmd.get_cmd_line_argument("m", M);
        cmd.get_cmd_line_argument("n", N);
        cmd.get_cmd_line_argument("k", K);
        cmd.get_cmd_line_argument("alpha", alpha);
        cmd.get_cmd_line_argument("beta", beta);
        cmd.get_cmd_line_argument("iterations", iterations);

        rank = COMM_SESSION.getRank();
        tp = COMM_SESSION.getSize();
        for (int rank = 0; rank < tp; ++rank)
        {
            tp_group.insert(rank);
        }

        assert(K % tp == 0);
        K_tp = K / tp;

        // Seed should be the same across all ranks.
        srand(time(NULL));
        seed = static_cast<int>(rand());
        COMM_SESSION.bcastValue(seed, 0);
#if 1
        printf("rank: %d, m: %d, n: %d, k: %d, tp: %d, seed: %d\n", rank, M, N, K, tp, seed);
#endif
    }
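    // Worked example of the split above: with --k=8192 and tp=4 ranks, K_tp = 8192 / 4
    // = 2048, so each rank computes a local M x N x 2048 partial GEMM and the fused
    // all-reduce sums the partial D tensors across ranks (checked against NCCL in verify()).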
    /// Prints the usage statement.
    std::ostream& print_usage(std::ostream& out) const
    {
        out << "\n"
               "Options:\n"
               "  --help                      If specified, displays this usage statement.\n"
               "  --m=<int>                   GEMM M dimension (LLM batch size)\n"
               "  --n=<int>                   GEMM N dimension (needs to be >= 16)\n"
               "  --k=<int>                   GEMM K dimension (needs to be >= 16 * nranks)\n"
               "  --alpha=<f32>               GEMM alpha parameter\n"
               "  --beta=<f32>                GEMM beta parameter\n"
               "  --iterations=<int>          Number of profiling iterations to perform.\n"
               "\n"
               "Examples:\n"
               "\n"
               "$ mpirun -np 8 ./test/gemmAllReduceTest --m=8192 --n=8192 --k=8192 --iterations=1000\n";
        return out;
    }

    /// Compute performance in GFLOP/s. Only the local K_tp slice counts as work for this rank.
    double gflops(double runtime_s) const
    {
        // Two flops per multiply-add
        uint64_t flop = uint64_t(2) * M * N * K_tp;
        double gflop = double(flop) / double(1.0e9);
        return gflop / runtime_s;
    }

    double effective_bandwidth(double runtime_s, size_t bytes_a, size_t bytes_b, size_t bytes_c, size_t bytes_d) const
    {
        static double const kBytesPerGiB = double(1ull << 30);

        double bytes_in = (double) (M) * (double) (K_tp) * (double) (bytes_a)         // A
            + (double) (N) * (double) (K_tp) * (double) (bytes_b)                     // B
            + (beta != 0.f ? (double) (M) * (double) (N) * (double) (bytes_c) : 0.f); // C
        double bytes_out = (double) (M) * (double) (N) * (double) (bytes_d);          // D

        double gb_total = (bytes_in + bytes_out) / kBytesPerGiB;
        return gb_total / runtime_s;
    }

    double effective_allreduce_bandwidth(double runtime_s, size_t bytes_d)
    {
        static double const kBytesPerGiB = double(1ull << 30);
        double bytes = (double) (M) * (double) (N) * (double) (bytes_d);
        double gb_total = bytes / kBytesPerGiB;
        return gb_total / runtime_s;
    }
};

struct Result
{
    double avg_runtime_us;
    double gflops;
    double eff_bw;
    double eff_AR_bw;
    bool passed;

    Result(double avg_runtime_us = 0, double gflops = 0, double eff_bw = 0, double eff_AR_bw = 0)
        : avg_runtime_us(avg_runtime_us)
        , gflops(gflops)
        , eff_bw(eff_bw)
        , eff_AR_bw(eff_AR_bw)
        , passed(false)
    {
    }
};

///////////////////////////
// NCCL types
///////////////////////////
template <typename T>
struct ToType
{
};

template <>
struct ToType<cutlass::bfloat16_t>
{
    ncclDataType_t nccl_value = ncclBfloat16;
    char const* str_value = "bf16";
};

template <>
struct ToType<cutlass::half_t>
{
    ncclDataType_t nccl_value = ncclFloat16;
    char const* str_value = "fp16";
};

template <>
struct ToType<cutlass::float_e4m3_t>
{
    char const* str_value = "fp8_e4m3";
};

class NcclCommunicator
{
public:
    static NcclCommunicator& instance()
    {
        static NcclCommunicator communicator;
        return communicator;
    }

    ncclComm_t comm;

private:
    NcclCommunicator()
    {
        int rank = COMM_SESSION.getRank();
        int world_size = COMM_SESSION.getSize();
        ncclUniqueId id;
        if (rank == 0)
            ncclGetUniqueId(&id);
        COMM_SESSION.bcastValue(id, 0);
        TLLM_NCCL_CHECK(ncclCommInitRank(&comm, world_size, id, rank));
    }
};

/////////////////////////////////////
// Gemm+AR Functional test fixture
/////////////////////////////////////
static Options options;

template <typename _ElementA, typename _ElementB, typename _ElementC, typename _ElementD>
struct TestConfig
{
    using ElementA = _ElementA;
    using ElementB = _ElementB;
    using ElementC = _ElementC;
    using ElementD = _ElementD;
};

template <typename T>
class GemmAllReduceFixture : public ::testing::Test
{
protected:
    using ElementA = typename T::ElementA;
    using ElementB = typename T::ElementB;
    using ElementC = typename T::ElementC;
    using ElementD = typename T::ElementD;

    static bool isMultiGpu()
    {
        return COMM_SESSION.getSize() > 1;
    }

    void SetUp() override
    {
        using GemmTraits = GemmTypes<ElementA, ElementB, ElementC, ElementD>;
        _gemm = std::make_shared<GemmAllReduceImplRunner<GemmTraits>>();

        auto const M = options.M;
        auto const N = options.N;
        auto const K_tp = options.K_tp;

        _A.reset(M * K_tp);
        _B.reset(N * K_tp);
        _C.reset(M * N);

        if (isMultiGpu())
        {
            _D_nvls.reset(M * N, options.tp_group);
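            // _D_nvls is an NVLS (NVLink SHARP) multicast allocation shared by tp_group:
            // the same M x N output is reachable through this rank's unicast pointer and
            // through a multicast pointer; both are handed to the kernel via argD() in run().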
            // Create workspace for max problem size
            GemmAllReduceImplInterface::LaunchConfig launch_config = {GemmAllReduceImpl::NVLS_2SHOT,
                MainloopScheduleType::PINGPONG, TileShape::TileShape_128x16x128, ClusterShape::ClusterShape_1x1x1};

            GemmAllReduceImplInterface::ProblemArgs max_problem;
            max_problem.argProblemShape(M, N, K_tp, 1)
                .argRanks(options.rank, options.tp_group)
                .argLaunchConfig(launch_config);

            _workspace = _gemm->getPersistentWorkspace(max_problem);
            _workspace->allocate();
        }
        else
        {
            _D.reset(M * N);
        }
        _D_ref.reset(M * N);

        initialize_block(_A, options.seed + options.rank + 2024);
        initialize_block(_B, options.seed + options.rank);
        initialize_block(_C, options.seed);
    }

    void TearDown() override
    {
        if (isMultiGpu())
        {
            _workspace->free();
            _D_nvls.free();
        }
    }

    void run(cudaStream_t stream = NULL)
    {
        // Test
        GemmAllReduceImplInterface::ProblemArgs args;
        args.argProblemShape(options.M, options.N, options.K_tp, 1)
            .argA(_A.get())
            .argB(_B.get())
            .argC(_C.get())
            .argAlpha(options.alpha)
            .argBeta(options.beta)
            .argRanks(options.rank, options.tp_group);

        if (isMultiGpu())
        {
            args.argD(_D_nvls.getUnicastPointer(), _D_nvls.getMulticastPointer()).argWorkspace(_workspace.get());
        }
        else
        {
            args.argD(_D.get());
        }

        Result result;
        result.passed = true;

        // Ensure all configs pass
        auto launch_configs = _gemm->getSupportedLaunchConfigs();
        for (auto launch_config : launch_configs)
        {
            args.argLaunchConfig(launch_config);
            _gemm->run(args, stream);
            TLLM_CUDA_CHECK(cudaStreamSynchronize(stream));

            bool passed = verify(stream);
            if (!passed)
            {
                std::cout << "config failed: " << launch_config.str() << std::endl;
            }
            result.passed &= passed;
        }

        // Benchmark: for each config, run `warmup` untimed iterations, then barrier all
        // ranks and time `iterations` back-to-back runs. The fastest config is reported.
        tensorrt_llm::testing::GpuTimer timer;
        int const warmup = 20;
        float best_elapsed_us = std::numeric_limits<float>::max();
        GemmAllReduceImplInterface::LaunchConfig best_launch_config;

        for (auto launch_config : launch_configs)
        {
            args.argLaunchConfig(launch_config);

            for (int i = 0; i < options.iterations + warmup; ++i)
            {
                if (i == warmup)
                {
                    // Synchronize ranks before starting the timer.
                    TLLM_CUDA_CHECK(cudaStreamSynchronize(stream));
                    COMM_SESSION.barrier();
                    timer.start(stream);
                }
                _gemm->run(args, stream);
            }
            timer.stop();

            float elapsed_us = timer.elapsed_millis() * 1000.f;
            if (options.rank == 0)
            {
                double avg_runtime_us = double(elapsed_us) / double(options.iterations);
                std::cout << launch_config.str() << std::endl;
                std::cout << "  Avg runtime: " << avg_runtime_us << " us" << std::endl;
            }

            if (elapsed_us < best_elapsed_us)
            {
                best_elapsed_us = elapsed_us;
                best_launch_config = launch_config;
            }
        }

        result.avg_runtime_us = double(best_elapsed_us) / double(options.iterations);
        double avg_runtime_s = (double) (result.avg_runtime_us / 1000000.0);
        result.gflops = options.gflops(avg_runtime_s);
        result.eff_bw = options.effective_bandwidth(
            avg_runtime_s, sizeof(ElementA), sizeof(ElementB), sizeof(ElementC), sizeof(ElementD));
        result.eff_AR_bw = options.effective_allreduce_bandwidth(avg_runtime_s, sizeof(ElementD));

        if (options.rank == 0)
        {
            std::cout << std::endl;
            std::cout << "  Precision: " << ToType<ElementA>{}.str_value << "x" << ToType<ElementB>{}.str_value << "="
                      << ToType<ElementD>{}.str_value << std::endl;
            std::cout << "  Problem Size: " << options.M << 'x' << options.N << 'x' << options.K << std::endl;
            std::cout << "  Local Problem Size: " << options.M << 'x' << options.N << 'x' << options.K_tp << std::endl;
            std::cout << "  " << best_launch_config.str() << std::endl;
            std::cout << "  Verify: " << (result.passed ? "Pass" : "Fail") << std::endl;
            std::cout << "  Avg runtime: " << result.avg_runtime_us << " us" << std::endl;
            std::cout << "  GFLOPS: " << result.gflops << std::endl;
            std::cout << "  Effective GEMM bandwidth: " << result.eff_bw << " GB/s" << std::endl;
            std::cout << "  Effective AR bandwidth: " << result.eff_AR_bw << " GB/s" << std::endl;
        }
        ASSERT_TRUE(result.passed);
    }

private:
    template <typename ElementT>
    void print_tensor(std::string name, ElementT* data, int const H, int const W)
    {
        std::vector<ElementT> host(H * W);
        cutlass::device_memory::copy_to_host(host.data(), data, H * W);
        auto host_tensor = cute::make_tensor(host.data(), cute::make_shape(H, W), cute::make_stride(W, _1{}));
        cute::print_tensor(host_tensor);
    }

    template <typename ElementT>
    auto find_relative_differences(ElementT const* d_ptr_A, ElementT const* d_ptr_B, size_t capacity, ElementT epsilon,
        ElementT nonzero_floor, size_t max_count = 5)
    {
        std::vector<ElementT> h_ptr_A(capacity);
        std::vector<ElementT> h_ptr_B(capacity);
        cutlass::device_memory::copy_to_host(h_ptr_A.data(), d_ptr_A, capacity);
        cutlass::device_memory::copy_to_host(h_ptr_B.data(), d_ptr_B, capacity);

        std::vector<std::tuple<ElementT, ElementT, size_t>> differences;
        for (size_t i = 0; i < capacity; ++i)
        {
            auto a = h_ptr_A[i];
            auto b = h_ptr_B[i];
            if (!cutlass::relatively_equal(a, b, epsilon, nonzero_floor))
            {
                differences.push_back(std::make_tuple(a, b, i));
                if (differences.size() >= max_count)
                {
                    break;
                }
            }
        }
        return differences;
    }

    bool verify(cudaStream_t stream)
    {
        using LayoutA = cutlass::layout::RowMajor;
        using LayoutB = cutlass::layout::ColumnMajor;
        using LayoutC = cutlass::layout::RowMajor;
        using LayoutD = cutlass::layout::RowMajor;
        using ElementScalar = float;
        using ElementAccumulator = float;

        auto const M = options.M;
        auto const N = options.N;
        auto const K = options.K_tp;

        cutlass::TensorRef ref_A(_A.get(), LayoutA::packed({M, K}));
        cutlass::TensorRef ref_B(_B.get(), LayoutB::packed({K, N}));
        cutlass::TensorRef ref_C(_C.get(), LayoutC::packed({M, N}));
        cutlass::TensorRef ref_D(_D_ref.get(), LayoutD::packed({M, N}));

        // Reference device GEMM implementation type
        using DeviceGemmReference = cutlass::reference::device::Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC,
            LayoutC, ElementScalar, ElementAccumulator>;

        // Create instantiation for device reference gemm kernel
        DeviceGemmReference gemm_reference;

        // Launch device reference gemm kernel
        gemm_reference(
            {M, N, K}, ElementAccumulator(options.alpha), ref_A, ref_B, ElementAccumulator(options.beta), ref_C, ref_D);
        TLLM_CUDA_CHECK(cudaDeviceSynchronize());

        // AllReduce across ranks
        ncclComm_t comm = NcclCommunicator::instance().comm;
        auto dtype = ToType<ElementD>{}.nccl_value;
        TLLM_NCCL_CHECK(ncclAllReduce(_D_ref.get(), _D_ref.get(), _D_ref.size(), dtype, ncclSum, comm, stream));
        TLLM_CUDA_CHECK(cudaStreamSynchronize(stream));

        // Compare results
        const ElementC epsilon(1e-2f);
        const ElementC nonzero_floor(1e-4f);

        auto D_ptr = isMultiGpu() ? _D_nvls.getUnicastPointer() : _D.get();
        int local_failed = 0;
        if (!cutlass::reference::device::BlockCompareRelativelyEqual(
                _D_ref.get(), D_ptr, _D_ref.size(), epsilon, nonzero_floor))
        {
            if (options.rank == 0)
            {
#if 1
                auto differences
                    = find_relative_differences(_D_ref.get(), D_ptr, _D_ref.size(), epsilon, nonzero_floor);
                std::cerr << "Differences:" << std::endl;
                for (auto [exp, act, pos] : differences)
                {
                    std::cerr << "expected: " << std::setprecision(3) << std::setw(5) << exp
                              << ", actual: " << std::setprecision(3) << std::setw(5) << act << ", at pos: " << pos
                              << std::endl;
                }
#endif
#if 0
                print_tensor("Actual", D_ptr, M, N);
                print_tensor("Ref   ", _D_ref.get(), M, N);
#endif
            }
            local_failed = 1;
        }

        // Aggregate results - if 1 rank fails, then all ranks fail.
        int global_failed;
        COMM_SESSION.allreduce(&local_failed, &global_failed, 1, MpiType::kINT32, MpiOp::SUM);
        return global_failed == 0;
    }
    // Fill a device block with uniform random values; the range narrows with element
    // width (e.g. [-2, 2] for <= 8-bit, [-5, 5] for 16-bit types) to keep partial sums
    // and the cross-rank reduction well inside the type's representable range.
    template <typename Element>
    static void initialize_block(cutlass::DeviceAllocation<Element>& block, int seed)
    {
        double scope_max, scope_min;
        int bits_input = cutlass::sizeof_bits<Element>::value;
        int bits_output = cutlass::sizeof_bits<Element>::value;

        if (bits_input == 1)
        {
            scope_max = 2;
            scope_min = 0;
        }
        else if (bits_input <= 8)
        {
            scope_max = 2;
            scope_min = -2;
        }
        else if (bits_output == 16)
        {
            scope_max = 5;
            scope_min = -5;
        }
        else
        {
            scope_max = 8;
            scope_min = -8;
        }

        using Real = typename cutlass::RealType<Element>::Type;
        cutlass::reference::device::BlockFillRandomUniform(
            block.get(), block.size(), seed, static_cast<Real>(scope_max), static_cast<Real>(scope_min), 0);
    }

    cutlass::DeviceAllocation<ElementA> _A;
    cutlass::DeviceAllocation<ElementB> _B;
    cutlass::DeviceAllocation<ElementC> _C;
    cutlass::DeviceAllocation<ElementD> _D;
    cutlass::DeviceAllocation<ElementD> _D_ref;
    DeviceAllocationNvls<ElementD> _D_nvls;
    std::shared_ptr<PersistentWorkspaceInterface> _workspace;
    std::shared_ptr<GemmAllReduceImplInterface> _gemm;
};

// Element types for the two tested configs (A, B, C, D).
using MyTypes = testing::Types<TestConfig<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t>,
    TestConfig<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t>>;

TYPED_TEST_SUITE(GemmAllReduceFixture, MyTypes);

/////////////////////////////////////////////////////////////////////
// ATTENTION: run test with mpi `mpirun -np <num_ranks> ./gemmAllReduceTest'
/////////////////////////////////////////////////////////////////////
TYPED_TEST(GemmAllReduceFixture, RunnerTest)
{
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    this->run(stream);
    cudaStreamDestroy(stream);
}

int main(int argc, char** argv)
{
    ::testing::InitGoogleTest(&argc, argv);

    bool notSupported = false;
    // CUDA 12 minimum required
    if (__CUDACC_VER_MAJOR__ < 12)
    {
        std::cerr << "This example requires CUDA Toolkit version 12 or later.\n";
        notSupported = true;
    }

    TLLM_CUDA_CHECK(cudaSetDevice(COMM_SESSION.getRank()));
    cudaDeviceProp props;
    TLLM_CUDA_CHECK(cudaGetDeviceProperties(&props, COMM_SESSION.getRank()));
    if (props.major < 9)
    {
        std::cerr << "This example requires a device with compute capability 90 or higher.\n";
        notSupported = true;
    }

    if (!ipcNvlsSupported())
    {
        std::cerr << "NVLS not supported on this system.\n";
        notSupported = true;
    }

    if (notSupported)
    {
        return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
    }

    options.parse(argc, argv);

    if (options.help)
    {
        options.print_usage(std::cout) << "\n";
        return EXIT_SUCCESS;
    }

    if (!options.valid())
    {
        std::cerr << "Invalid arguments."
                  << "\n";
        return EXIT_FAILURE;
    }

    return RUN_ALL_TESTS();
}
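/////////////////////////////////////////////////////////////////////
// For reference, the runner call sequence this test exercises, reduced to its
// essentials (a sketch built only from calls made above, not a standalone program):
//
//   GemmAllReduceImplInterface::ProblemArgs args;
//   args.argProblemShape(M, N, K_tp, 1)
//       .argA(A).argB(B).argC(C)
//       .argD(D_unicast, D_multicast) // multi-GPU: NVLS pointers; else argD(D)
//       .argAlpha(alpha).argBeta(beta)
//       .argRanks(rank, tp_group)
//       .argLaunchConfig(config)
//       .argWorkspace(workspace.get());
//   gemm->run(args, stream);
/////////////////////////////////////////////////////////////////////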