TensorRT-LLMs/cpp/tests/unit_tests/kernels/allReduce/gemmAllReduceTest.cu
Yuan Tong 4b6c19737b
feat: support add internal cutlass kernels as subproject (#3658)
Signed-off-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
2025-05-06 11:35:07 +08:00

908 lines
32 KiB
Plaintext

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <memory>
#include <nccl.h>
#include <vector>
#include "allreduce_gemm_runner.h"
#include "common.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <NvInferRuntime.h>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/reference/host/gett.hpp"
#include "cutlass/util/reference/host/tensor_fill.h"
using namespace cutlass;
using namespace nvinfer1;
using namespace tensorrt_llm::mpi;
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels::cutlass_kernels;
///////////////////////////
// CLI Args
///////////////////////////
struct Options
{
bool help = false;
bool verify = true;
int seed = 0;
// problem shape
int M = 13;
int N = 4096;
int K = 8192;
int K_tp; // K per GPU
int rank = 0;
int tp = 1;
std::set<int> tp_group;
float alpha = 1.0f;
float beta = 0.0f;
int iterations = 100;
bool valid() const
{
return M > 0 && N >= 16 && K_tp >= 16 && tp > 0 && iterations > 0;
}
// Parses the command line
void parse(int argc, char** args)
{
cutlass::CommandLine cmd(argc, const_cast<char const**>(args));
if (cmd.check_cmd_line_flag("help"))
{
help = true;
}
if (cmd.check_cmd_line_flag("skip_check"))
{
verify = false;
}
cmd.get_cmd_line_argument("m", M);
cmd.get_cmd_line_argument("n", N);
cmd.get_cmd_line_argument("k", K);
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
rank = COMM_SESSION.getRank();
tp = COMM_SESSION.getSize();
for (int i = 0; i < tp; ++i)
{
tp_group.insert(i);
}
assert(K % tp == 0);
K_tp = K / tp;
// Should be same across all ranks
srand(time(NULL));
seed = static_cast<int>(rand());
COMM_SESSION.bcastValue(seed, 0);
#if 1
printf("rank: %d, m: %d, n: %d, k: %d, tp: %d, seed: %d\n", this->rank, M, N, K, tp, seed);
#endif
}
/// Prints the usage statement.
std::ostream& print_usage(std::ostream& out) const
{
out << "\n"
"Options:\n"
" --help If specified, displays this usage statement.\n"
" --m=<int> GEMM M dimension (LLM batch size)\n"
" --n=<int> GEMM N dimension (needs to be >= 16)\n"
" --k=<int> GEMM K dimension (needs to be >= 16 * nranks)\n"
" --alpha=<float> GEMM alpha parameter\n"
" --beta=<float> GEMM beta parameter\n"
" --iterations=<int> Number of profiling iterations to perform.\n"
" --skip_check Skips verification (verification is slow for large shapes)"
"\n"
"Examples:\n"
"\n"
"$ mpirun -np 8 ./test/gemmAllReduceTest --m=8192 --n=8192 --k=8192 --iterations=1000\n";
return out;
}
/// Compute performance in GFLOP/s
double tflops(double runtime_s) const
{
// Two flops per multiply-add
uint64_t flop = uint64_t(2) * M * N * K_tp;
double gflop = double(flop) / double(1.0e9);
double tflops = gflop / 1e3;
return tflops / runtime_s;
}
double effective_bandwidth(double runtime_s, size_t bytes_a, size_t bytes_b, size_t bytes_c, size_t bytes_d) const
{
static double const kBytesPerGiB = double(1ull << 30);
double bytes_in = (double) (M) * (double) (K_tp) * (double) (bytes_a) + // A
(double) (N) * (double) (K_tp) * (double) (bytes_b) + // B
(beta != 0.f ? (double) (M) * (double) (N) * (double) (bytes_c) : 0.f); // C
double bytes_out = (double) (M) * (double) (N) * (double) (bytes_d); // D
double gb_total = (bytes_in + bytes_out) / kBytesPerGiB;
return gb_total / runtime_s;
}
double effective_allreduce_bandwidth(double runtime_s, size_t bytes_d)
{
static double const kBytesPerGiB = double(1ull << 30);
double bytes = (double) (M) * (double) (N) * (double) (bytes_d);
double gb_total = bytes / kBytesPerGiB;
return gb_total / runtime_s;
}
};
struct Result
{
double avg_runtime_us;
double avg_runtime_AR_us;
double tflops;
double eff_bw;
double eff_AR_bw;
bool passed;
GemmAllReduceImplInterface::LaunchConfig best_config;
Result(double avg_runtime_us = 0, double avg_runtime_AR_us = 0, double tflops = 0, double eff_bw = 0,
double eff_AR_bw = 0)
: avg_runtime_us(avg_runtime_us)
, avg_runtime_AR_us(avg_runtime_AR_us)
, tflops(tflops)
, eff_bw(eff_bw)
, eff_AR_bw(eff_AR_bw)
, passed(false)
{
}
};
///////////////////////////
// NCCL types
///////////////////////////
template <typename CutlassType>
struct ToType
{
};
template <>
struct ToType<cutlass::bfloat16_t>
{
ncclDataType_t nccl_value = ncclBfloat16;
char const* str_value = "bf16";
};
template <>
struct ToType<cutlass::half_t>
{
ncclDataType_t nccl_value = ncclFloat16;
char const* str_value = "fp16";
};
template <>
struct ToType<cutlass::float_e4m3_t>
{
ncclDataType_t nccl_value = ncclFloat8e4m3;
char const* str_value = "fp8_e4m3";
};
template <>
struct ToType<cutlass::float_e2m1_t>
{
char const* str_value = "fp4_e2m1";
};
class NcclCommunicator
{
public:
static NcclCommunicator& instance()
{
static NcclCommunicator communicator;
return communicator;
}
ncclComm_t comm;
private:
NcclCommunicator()
{
int rank = COMM_SESSION.getRank();
int world_size = COMM_SESSION.getSize();
ncclUniqueId id;
if (rank == 0)
ncclGetUniqueId(&id);
COMM_SESSION.bcastValue(id, 0);
TLLM_NCCL_CHECK(ncclCommInitRank(&comm, world_size, id, rank));
}
};
// Required for subbyte reference GEMM (i.e FP4)
template <typename T>
auto make_iterator(T* ptr)
{
using namespace cute;
if constexpr (cute::is_subbyte_v<T>)
{
return subbyte_iterator<T>(ptr);
}
else
{
return ptr;
}
}
/////////////////////////////////////
// Gemm+AR Functional test fixture
/////////////////////////////////////
static Options options;
template <typename _ElementA, typename _ElementB, typename _ElementC, typename _ElementD, typename _ElementSFA = void,
typename _ElementSFB = void>
struct TestConfig
{
using ElementA = _ElementA;
using ElementB = _ElementB;
using ElementC = _ElementC;
using ElementD = _ElementD;
using ElementSFB = _ElementSFA;
using ElementSFA = _ElementSFB;
};
template <typename T>
class GemmAllReduceFixture : public ::testing::Test
{
protected:
using ElementA = typename T::ElementA;
using ElementB = typename T::ElementB;
using ElementC = typename T::ElementC;
using ElementD = typename T::ElementD;
using ElementSFA = typename T::ElementSFA;
using ElementSFB = typename T::ElementSFB;
static_assert(std::is_same_v<ElementA, ElementB> && "A & B types must be same");
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::gemm::TagToStrideB_t<LayoutB>;
using StrideC = cutlass::gemm::TagToStrideC_t<LayoutC>;
using StrideD = cutlass::gemm::TagToStrideC_t<LayoutD>;
// Only currently supported for FP4 GEMM
static constexpr bool IsInputScalingNeeded = std::is_same_v<ElementSFA, cutlass::float_ue4m3_t>;
static constexpr bool IsFP4 = std::is_same_v<ElementA, cutlass::float_e2m1_t>;
static bool isMultiGpu()
{
return COMM_SESSION.getSize() > 1;
}
static void SetUpTestSuite()
{
// Blackwell skip FP4 GEMMs
if (getSMVersion() >= 100 && !IsFP4)
{
GTEST_SKIP() << "Skipping non-FP4 GEMM";
}
// Hopper skip FP4 GEMMs
else if (getSMVersion() < 100 && IsFP4)
{
GTEST_SKIP() << "Skipping FP4 GEMM";
}
}
void SetUp() override
{
using GemmTraits = GemmTypes<ElementA, ElementB, ElementC, ElementD, ElementSFA, ElementSFB, LayoutA, LayoutB,
LayoutC, LayoutD>;
_gemm = std::make_shared<GemmAllReduceImplRunner<GemmTraits>>();
auto const M = options.M;
auto const N = options.N;
auto const K_tp = options.K_tp;
_A.reset(cutlass::make_Coord(M * K_tp));
_B.reset(cutlass::make_Coord(N * K_tp));
_C.reset(cutlass::make_Coord(M * N));
if (isMultiGpu())
{
_D_nvls.reset(M * N, options.tp_group);
// Create workspace for max problem size
GemmAllReduceImplInterface::LaunchConfig launch_config = _gemm->getSupportedLaunchConfigs()[0];
GemmAllReduceImplInterface::ProblemArgs max_problem;
max_problem.argProblemShape(M, N, K_tp, 1)
.argRanks(options.rank, options.tp_group)
.argLaunchConfig(launch_config);
_workspace = _gemm->getPersistentWorkspace(max_problem);
_workspace->allocate();
}
else
{
_D.reset(cutlass::make_Coord(M * N));
}
_D_ref.reset(cutlass::make_Coord(M * N));
_alpha_vec.resize(cutlass::make_Coord(1));
if constexpr (IsInputScalingNeeded)
{
auto [layout_SFA, layout_SFB] = getLayoutSF_AB(M, N, K_tp);
auto size_SFA = size(filter_zeros(layout_SFA));
auto size_SFB = size(filter_zeros(layout_SFB));
_SFA.reset(cutlass::make_Coord(size_SFA));
_SFB.reset(cutlass::make_Coord(size_SFB));
}
initializeTensor(_A.host_view(), options.seed + options.rank + 2024);
initializeTensor(_B.host_view(), options.seed + options.rank);
initializeTensor(_C.host_view(), options.seed);
_A.sync_device();
_B.sync_device();
_C.sync_device();
if constexpr (IsInputScalingNeeded)
{
initializeTensor(_SFA.host_view(), options.seed + options.rank + 2023);
initializeTensor(_SFB.host_view(), options.seed + options.rank + 2022);
_SFA.sync_device();
_SFB.sync_device();
}
_alpha_vec.host_data()[0] = options.alpha;
_alpha_vec.sync_device();
}
void TearDown() override
{
if (isMultiGpu())
{
_workspace->free();
_D_nvls.free();
}
}
/*
* Benchmarks each config.
* Benchmarks no fusion for comparison.
*/
void bench(cudaStream_t stream)
{
GemmAllReduceImplInterface::ProblemArgs args = get_arguments();
int const warmup = 20;
auto sweep_configs = [&]()
{
Result result;
tensorrt_llm::testing::GpuTimer timer;
float best_elapsed_us = std::numeric_limits<float>::max();
GemmAllReduceImplInterface::LaunchConfig best_launch_config;
auto launch_configs = _gemm->getSupportedLaunchConfigs();
for (auto launch_config : launch_configs)
{
args.argLaunchConfig(launch_config);
for (int i = 0; i < options.iterations + warmup; ++i)
{
if (i == warmup)
{
// Synchronize ranks
TLLM_CUDA_CHECK(cudaStreamSynchronize(stream));
COMM_SESSION.barrier();
timer.start(stream);
}
_gemm->run(args, stream);
}
timer.stop();
float elapsed_us = timer.elapsed_millis() * 1000.f;
if (options.rank == 0)
{
double avg_runtime_us = double(elapsed_us) / double(options.iterations);
std::cout << launch_config.str() << std::endl;
std::cout << " Avg runtime: " << avg_runtime_us << " us" << std::endl;
}
if (elapsed_us < best_elapsed_us)
{
best_elapsed_us = elapsed_us;
best_launch_config = launch_config;
}
}
result.best_config = best_launch_config;
result.avg_runtime_us = double(best_elapsed_us) / double(options.iterations);
double avg_runtime_s = (double) (result.avg_runtime_us / 1000000.0);
result.tflops = options.tflops(avg_runtime_s);
result.eff_bw = options.effective_bandwidth(
avg_runtime_s, sizeof(ElementA), sizeof(ElementB), sizeof(ElementC), sizeof(ElementD));
result.eff_AR_bw = options.effective_allreduce_bandwidth(avg_runtime_s, sizeof(ElementD));
return result;
};
// Benchmark each config.
auto result = sweep_configs();
// set to single device
args.argRanks(0, {0});
// Benchmark GEMM with no fusion.
auto result_no_fusion = sweep_configs();
result_no_fusion.eff_AR_bw = 0;
result_no_fusion.avg_runtime_AR_us = 0;
// Benchmark AR with no fusion
if (isMultiGpu())
{
tensorrt_llm::testing::GpuTimer timer;
for (int i = 0; i < options.iterations + warmup; ++i)
{
if (i == warmup)
{
// Synchronize ranks
TLLM_CUDA_CHECK(cudaStreamSynchronize(stream));
COMM_SESSION.barrier();
timer.start(stream);
}
ncclComm_t comm = NcclCommunicator::instance().comm;
auto dtype = ToType<ElementD>{}.nccl_value;
TLLM_NCCL_CHECK(ncclAllReduce(
_D_ref.device_data(), _D_ref.device_data(), _D_ref.size(), dtype, ncclSum, comm, stream));
}
timer.stop();
float elapsed_us = timer.elapsed_millis() * 1000.f;
result_no_fusion.avg_runtime_AR_us = double(elapsed_us) / double(options.iterations);
double avg_runtime_AR_s = (double) (result_no_fusion.avg_runtime_AR_us / 1000000.0);
result_no_fusion.eff_AR_bw = options.effective_allreduce_bandwidth(avg_runtime_AR_s, sizeof(ElementD));
}
if (options.rank == 0)
{
std::cout << std::endl;
std::cout << " Precision: " << ToType<ElementA>{}.str_value << "x" << ToType<ElementB>{}.str_value << "="
<< ToType<ElementD>{}.str_value << std::endl;
std::cout << " Problem Size: " << options.M << 'x' << options.N << 'x' << options.K << std::endl;
std::cout << " Local Problem Size: " << options.M << 'x' << options.N << 'x' << options.K_tp << std::endl;
std::cout << "\n GEMM->AR\n" << std::endl;
std::cout << " " << result.best_config.str() << std::endl;
std::cout << " GEMM runtime: " << result_no_fusion.avg_runtime_us << " us" << std::endl;
std::cout << " GEMM TFLOPS: " << result_no_fusion.tflops << std::endl;
std::cout << " GEMM effective bandwidth: " << result_no_fusion.eff_bw << " GB/s" << std::endl;
std::cout << " AR runtime: " << result_no_fusion.avg_runtime_AR_us << " us" << std::endl;
std::cout << " AR algo bandwidth: " << result_no_fusion.eff_AR_bw << " GB/s" << std::endl;
std::cout << "\n GEMM+AR fusion\n" << std::endl;
std::cout << " " << result.best_config.str() << std::endl;
std::cout << " GEMM runtime: " << result.avg_runtime_us << " us" << std::endl;
std::cout << " GEMM TFLOPS: " << result.tflops << std::endl;
std::cout << " GEMM effective bandwidth: " << result.eff_bw << " GB/s" << std::endl;
std::cout << " AR algo bandwidth: " << result.eff_AR_bw << " GB/s" << std::endl;
float speedup
= (result_no_fusion.avg_runtime_us + result_no_fusion.avg_runtime_AR_us) / result.avg_runtime_us;
std::cout << "\n Speedup: " << speedup << std::endl;
std::cout << std::endl;
};
}
/**
* Run each config to ensure each one passes numerical check.
*/
void run(cudaStream_t stream = NULL)
{
GemmAllReduceImplInterface::ProblemArgs args = get_arguments();
Result result;
result.passed = true;
// Ensure all configs pass
auto launch_configs = _gemm->getSupportedLaunchConfigs();
for (auto launch_config : launch_configs)
{
args.argLaunchConfig(launch_config);
_gemm->run(args, stream);
TLLM_CUDA_CHECK(cudaStreamSynchronize(stream));
bool passed = verify(stream);
std::cout << launch_config.str() << std::endl;
std::cout << " Verify: " << (passed ? "Pass" : "Fail") << std::endl;
result.passed &= passed;
}
ASSERT_TRUE(result.passed);
}
private:
GemmAllReduceImplInterface::ProblemArgs get_arguments()
{
GemmAllReduceImplInterface::ProblemArgs args;
args.argProblemShape(options.M, options.N, options.K_tp, 1)
.argA(_A.device_data())
.argB(_B.device_data())
.argC(_C.device_data())
.argAlphaPtr(_alpha_vec.device_data())
.argBeta(options.beta)
.argRanks(options.rank, options.tp_group)
.argWorkspace(_workspace.get());
if constexpr (IsInputScalingNeeded)
{
args.argAScale(_SFA.device_data());
args.argBScale(_SFB.device_data());
}
if (isMultiGpu())
{
args.argD(
_D_nvls.getUnicastPointer(), _D_nvls.getMulticastPointer(), (void**) _D_nvls.getIpcUnicastPointers());
}
else
{
args.argD(_D.device_data());
}
return args;
}
template <typename ElementT>
void print_tensor(std::string name, ElementT* data, int const H, int const W)
{
std::vector<ElementT> host(H * W);
cutlass::device_memory::copy_to_host(host.data(), data, H * W);
auto host_tensor = cute::make_tensor(host.data(), cute::make_shape(H, W), cute::make_stride(W, _1{}));
cute::print_tensor(host_tensor);
}
template <typename ElementT>
auto findRelativeDifferences(ElementT const* d_ptr_A, ElementT const* d_ptr_B, size_t capacity, ElementT epsilon,
ElementT nonzero_floor, size_t max_count = 5)
{
std::vector<ElementT> h_ptr_A(capacity);
std::vector<ElementT> h_ptr_B(capacity);
cutlass::device_memory::copy_to_host(h_ptr_A.data(), d_ptr_A, capacity);
cutlass::device_memory::copy_to_host(h_ptr_B.data(), d_ptr_B, capacity);
std::vector<std::tuple<ElementT, ElementT, size_t>> differences;
for (size_t i = 0; i < capacity; ++i)
{
auto a = h_ptr_A[i];
auto b = h_ptr_B[i];
if (!cutlass::relatively_equal(a, b, epsilon, nonzero_floor))
{
differences.push_back(std::make_tuple(a, b, i));
if (differences.size() >= max_count)
{
break;
}
}
}
return differences;
}
template <typename ElementT>
bool compareRelativelyEqual(ElementT* expected, ElementT* actual, size_t size, float epsilon, float nonzero_floor)
{
ElementT eps = static_cast<ElementT>(epsilon);
ElementT floor = static_cast<ElementT>(nonzero_floor);
if (!cutlass::reference::device::BlockCompareRelativelyEqual(expected, actual, size, eps, floor))
{
if (options.rank == 0)
{
#if 1
auto differences = findRelativeDifferences(expected, actual, size, eps, floor);
std::cerr << "Differences:" << std::endl;
for (auto [exp, act, pos] : differences)
{
std::cerr << "expected: " << std::setprecision(3) << std::setw(5) << float(exp)
<< ", actual: " << std::setprecision(3) << std::setw(5) << float(act)
<< ", at pos: " << pos << std::endl;
}
#endif
#if 0
// print_tensor("Actual", D_actual, M, N);
// print_tensor("Ref ", D_expect, M, N);
#endif
}
return false;
}
return true;
}
bool verify(cudaStream_t stream)
{
auto const M = options.M;
auto const N = options.N;
auto const K = options.K_tp;
// Prepare arguments for reference GEMM
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
auto layout_A = cute::make_layout(cute::make_shape(M, K, 1), stride_A);
auto tensor_A = cute::make_tensor(make_iterator(_A.host_data()), layout_A);
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, 1));
auto layout_B = cute::make_layout(cute::make_shape(N, K, 1), stride_B);
auto tensor_B = cute::make_tensor(make_iterator(_B.host_data()), layout_B);
auto get_mainloop_params = [&]()
{
if constexpr (IsInputScalingNeeded)
{
auto [layout_SFA, layout_SFB] = getLayoutSF_AB(M, N, K);
auto tensor_SFA = cute::make_tensor(_SFA.host_data(), layout_SFA);
auto tensor_SFB = cute::make_tensor(_SFB.host_data(), layout_SFB);
return cutlass::reference::host::GettMainloopParams<float, decltype(tensor_A), decltype(tensor_B),
decltype(tensor_SFA), decltype(tensor_SFB)>{tensor_A, tensor_SFA, tensor_B, tensor_SFB};
}
else
{
return cutlass::reference::host::GettMainloopParams<float, decltype(tensor_A), decltype(tensor_B)>{
tensor_A, tensor_B};
}
};
auto mainloop_params = get_mainloop_params();
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, 1));
auto layout_C = make_layout(make_shape(M, N, 1), stride_C);
auto tensor_C = cute::make_tensor(make_iterator(_C.host_data()), layout_C);
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, 1));
auto layout_D = make_layout(make_shape(M, N, 1), stride_D);
auto tensor_D = cute::make_tensor(make_iterator(_D_ref.host_data()), layout_D);
cutlass::reference::host::GettEpilogueParams<float, float, float, float, decltype(tensor_C), decltype(tensor_D)>
epilogue_params{};
epilogue_params.C = tensor_C;
epilogue_params.D = tensor_D;
epilogue_params.alpha = options.alpha;
epilogue_params.beta = options.beta;
// Run reference gemm
cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);
// Reference is run on host, so copy results to device
_D_ref.sync_device();
// run reference allreduce
ncclComm_t comm = NcclCommunicator::instance().comm;
auto dtype = ToType<ElementD>{}.nccl_value;
TLLM_NCCL_CHECK(
ncclAllReduce(_D_ref.device_data(), _D_ref.device_data(), _D_ref.size(), dtype, ncclSum, comm, stream));
TLLM_CUDA_CHECK(cudaStreamSynchronize(stream));
// Compare results
float const epsilon(0.1f);
float const nonzero_floor(std::numeric_limits<float>::min());
int local_passed = 1;
ElementD* ptr = isMultiGpu() ? _D_nvls.getUnicastPointer() : _D.device_data();
ElementD* ptr_ref = _D_ref.device_data();
// Compare D output
local_passed &= compareRelativelyEqual(ptr_ref, ptr, M * N, epsilon, nonzero_floor);
// Aggregate results - if 1 rank fails, then all ranks fail.
int ranks_passed = 0;
COMM_SESSION.allreduce(&local_passed, &ranks_passed, 1, MpiType::kINT32, MpiOp::SUM);
return ranks_passed == options.tp;
}
template <typename Element, typename Layout>
bool initializeTensor(cutlass::TensorView<Element, Layout> view, uint64_t seed)
{
double scope_max, scope_min;
constexpr int bits_input = cutlass::sizeof_bits<Element>::value;
if constexpr (bits_input == 1)
{
scope_max = 2;
scope_min = 0;
}
else if constexpr (bits_input <= 6)
{
scope_max = 2;
scope_min = -2;
}
else if constexpr (bits_input <= 8)
{
if constexpr (cute::is_same_v<Element, cutlass::float_ue8m0_t>)
{
scope_max = 4;
scope_min = 1;
}
else
{
scope_max = 1;
scope_min = -1;
}
}
else
{
scope_max = 4;
scope_min = -4;
}
cutlass::reference::host::TensorFillRandomUniform(view, seed, scope_max, scope_min, 0);
return true;
}
// Return scale-factor A & B tensor layouts
auto getLayoutSF_AB(int M, int N, int K)
{
switch (getSMVersion())
{
case 100: // blackwell
{
// Unfortunately have to construct mainloop in order to extract SFA/SFB layouts
using MainloopElementA = cute::tuple<ElementA, ElementSFA>;
using MainloopElementB = cute::tuple<ElementB, ElementSFB>;
constexpr static int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
constexpr static int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<cutlass::arch::Sm100,
cutlass::arch::OpClassBlockScaledTensorOp, MainloopElementA, LayoutA, AlignmentA, MainloopElementB,
LayoutB, AlignmentB, float, Shape<_128, _128, _128>, Shape<_1, _1, _1>,
cutlass::gemm::collective::StageCount<1>,
cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100>::CollectiveOp;
using Sm1xxBlkScaledConfig = typename CollectiveMainloop::Sm1xxBlkScaledConfig;
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1));
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1));
return std::make_tuple(layout_SFA, layout_SFB);
}
case 90: // hopper
TLLM_THROW("A/B tensor scaling not supported on Sm90 yet");
default: TLLM_THROW("SM version not supported");
}
}
cutlass::HostTensor<ElementA, cutlass::layout::PackedVectorLayout> _A;
cutlass::HostTensor<ElementB, cutlass::layout::PackedVectorLayout> _B;
cutlass::HostTensor<ElementC, cutlass::layout::PackedVectorLayout> _C;
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> _D;
cutlass::HostTensor<ElementD, cutlass::layout::PackedVectorLayout> _D_ref;
cutlass::HostTensor<float, cutlass::layout::PackedVectorLayout> _alpha_vec;
// Requires conditional because cannot have HostTensor with type void
// which is the case when we have no scale-factors.
typename std::conditional<IsInputScalingNeeded,
cutlass::HostTensor<ElementSFA, cutlass::layout::PackedVectorLayout>, void*>::type _SFA;
typename std::conditional<IsInputScalingNeeded,
cutlass::HostTensor<ElementSFB, cutlass::layout::PackedVectorLayout>, void*>::type _SFB;
DeviceAllocationNvls<ElementD> _D_nvls;
std::shared_ptr<PersistentWorkspaceInterface> _workspace;
std::shared_ptr<GemmAllReduceImplInterface> _gemm;
};
using MyTypes = testing::Types<
// fp4xfp4=fp16
TestConfig<cutlass::float_e2m1_t, cutlass::float_e2m1_t, cutlass::half_t, cutlass::half_t, cutlass::float_ue4m3_t,
cutlass::float_ue4m3_t>,
// fp8xfp8=fp16
TestConfig<cutlass::float_e4m3_t, cutlass::float_e4m3_t, cutlass::half_t, cutlass::half_t>,
// fp16xfp16=fp16
>;
// TestConfig<cutlass::half_t, cutlass::half_t, cutlass::half_t, cutlass::half_t>>;
TYPED_TEST_SUITE(GemmAllReduceFixture, MyTypes);
/////////////////////////////////////////////////////////////////////
// ATTENTION: run test with mpi `mpi -np <NP> ./gemmAllReduceTest'
/////////////////////////////////////////////////////////////////////
TYPED_TEST(GemmAllReduceFixture, RunnerTest)
{
cudaStream_t stream;
cudaStreamCreate(&stream);
if (!options.verify)
{
TLLM_LOG_WARNING("Skipping verify - return success");
}
else
{
this->run(stream);
}
this->bench(stream);
cudaStreamDestroy(stream);
}
int main(int argc, char** argv)
{
::testing::InitGoogleTest(&argc, argv);
bool notSupported = false;
// CUDA 12 minimum required
if (__CUDACC_VER_MAJOR__ < 12)
{
std::cerr << "This example requires CUDA Toolkit version 12 or later.\n";
notSupported = true;
}
TLLM_CUDA_CHECK(cudaSetDevice(COMM_SESSION.getRank()));
cudaDeviceProp props;
TLLM_CUDA_CHECK(cudaGetDeviceProperties(&props, COMM_SESSION.getRank()));
if (props.major < 9)
{
std::cerr << "This example requires a device with compute capability 90 or higher.\n";
notSupported = true;
}
if (!ipcNvlsSupported())
{
std::cerr << "NVLS not supported on this system.\n";
notSupported = true;
}
if (notSupported)
{
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
}
options.parse(argc, argv);
if (options.help)
{
options.print_usage(std::cout) << "\n";
return EXIT_SUCCESS;
}
if (!options.valid())
{
std::cerr << "Invalid arguments."
<< "\n";
return EXIT_FAILURE;
}
// Ensure only 1 rank prints
::testing::TestEventListeners& listeners = ::testing::UnitTest::GetInstance()->listeners();
if (COMM_SESSION.getRank() != 0)
{
delete listeners.Release(listeners.default_result_printer());
}
return RUN_ALL_TESTS();
}