Fix GEMM+AR fusion on Blackwell (#5563)

Signed-off-by: xsimmons <xsimmons@nvidia.com>
xavier-nvidia 2025-07-08 17:48:47 -07:00 committed by GitHub
parent a79b73f577
commit b6013da198
32 changed files with 950 additions and 482 deletions


@ -45,6 +45,7 @@ option(ENABLE_MULTI_DEVICE
option(ENABLE_UCX "Enable building with UCX (Uniform Communication X) support"
ON)
option(NVRTC_DYNAMIC_LINKING "Link against the dynamic NVRTC libraries" OFF)
option(ENABLE_NVSHMEM "Enable building with NVSHMEM support" OFF)
option(USING_OSS_CUTLASS_LOW_LATENCY_GEMM
"Using open sourced Cutlass low latency gemm kernel" ON)
option(USING_OSS_CUTLASS_FP4_GEMM "Using open sourced Cutlass fp4 gemm kernel"
@ -54,6 +55,8 @@ option(USING_OSS_CUTLASS_MOE_GEMM "Using open sourced Cutlass moe gemm kernel"
option(USING_OSS_CUTLASS_ALLREDUCE_GEMM
"Using open sourced Cutlass AR gemm kernel" ON)
message(STATUS "ENABLE_NVSHMEM is ${ENABLE_NVSHMEM}")
if(NVTX_DISABLE)
add_compile_definitions("NVTX_DISABLE")
message(STATUS "NVTX is disabled")
@ -171,6 +174,7 @@ message(STATUS "CUDA library status:")
message(STATUS " version: ${CUDAToolkit_VERSION}")
message(STATUS " libraries: ${CUDAToolkit_LIBRARY_DIR}")
message(STATUS " include path: ${CUDAToolkit_INCLUDE_DIRS}")
message(STATUS "CUDA_NVML_LIB: ${CUDA_NVML_LIB}")
# Prevent CMake from creating a response file for CUDA compiler, so clangd can
# pick up on the includes
@ -262,9 +266,21 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_SYSTEM=cmake_oss ")
# note: cmake expr generation $<BOOL:${ENABLE_MULTI_DEVICE}> is a build time
# evaluation so hard to debug at cmake time
if(ENABLE_MULTI_DEVICE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_MULTI_DEVICE=1")
# Add target definitions for both C++ and CUDA
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_MULTI_DEVICE=1>
$<$<COMPILE_LANGUAGE:CUDA>:ENABLE_MULTI_DEVICE=1>)
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_MULTI_DEVICE=0")
# Add target definitions for both C++ and CUDA
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_MULTI_DEVICE=0>
$<$<COMPILE_LANGUAGE:CUDA>:ENABLE_MULTI_DEVICE=0>)
endif()
if(ENABLE_NVSHMEM)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_NVSHMEM=1>
$<$<COMPILE_LANGUAGE:CUDA>:ENABLE_NVSHMEM=1>)
else()
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_NVSHMEM=0>
$<$<COMPILE_LANGUAGE:CUDA>:ENABLE_NVSHMEM=0>)
endif()
# Fix linking issue with TRT 10, the detailed description about `--mcmodel` can
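For context, a minimal sketch (illustrative code, not a file in this change) of how these 0/1-valued definitions are consumed once they are visible to both compilers; because the macro always carries a value for C++ and CUDA alike, #if works uniformly in .cpp and .cu translation units:

#if ENABLE_NVSHMEM
#include <nvshmem/nvshmem.h>
#include <nvshmem/nvshmemx.h>
#endif

// Hypothetical helper, shown only to illustrate the guards.
void initOptionalBackends()
{
#if ENABLE_MULTI_DEVICE
    // multi-device-only setup would go here
#endif
#if ENABLE_NVSHMEM
    nvshmem_init(); // compiled only when configured with -DENABLE_NVSHMEM=ON
#endif
}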


@ -72,6 +72,12 @@ if(ENABLE_MULTI_DEVICE)
include_directories(${MPI_C_INCLUDE_DIRS})
endif()
if(ENABLE_NVSHMEM)
# Add hints for aarch64
find_package(NVSHMEM REQUIRED HINTS /usr/lib/sbsa-linux-gnu/cmake/nvshmem/)
include_directories(/usr/include/nvshmem/)
endif()
if(NOT WIN32)
set(DECODER_SHARED_TARGET_0 decoder_attention_0)
set(DECODER_SHARED_TARGET_1 decoder_attention_1)
@ -231,7 +237,10 @@ if(ENABLE_MULTI_DEVICE)
set(TRTLLM_LINK_LIBS ${TRTLLM_LINK_LIBS} ${MPI_C_LIBRARIES} ${NCCL_LIB})
endif()
message("TRTLLM_LINK_LIBS: ${TRTLLM_LINK_LIBS}")
if(ENABLE_NVSHMEM)
set(TRTLLM_LINK_LIBS ${TRTLLM_LINK_LIBS} nvshmem::nvshmem_host
nvshmem::nvshmem_device)
endif()
if(NOT WIN32) # Unix-like compilers
set(UNDEFINED_FLAG "-Wl,--no-undefined")


@ -332,12 +332,12 @@ enum class ClusterShape
ClusterShape_1x2x1,
ClusterShape_2x2x1,
ClusterShape_1x4x1,
ClusterShape_4x1x1,
ClusterShape_4x2x1,
ClusterShape_2x4x1,
ClusterShape_4x4x1,
ClusterShape_1x8x1,
ClusterShape_8x1x1,
ClusterShape_4x1x1
ClusterShape_8x1x1
};
static auto get_cluster_shape_name(ClusterShape Shape_MNK)


@ -22,6 +22,8 @@
#include "cutlass/barrier.h"
#include <cuda/atomic>
namespace cutlass
{
@ -43,7 +45,7 @@ __forceinline__ __device__ uint32_t atomicCAS_system_acq(uint32_t* p, uint32_t c
} // namespace detail
template <class Sync, bool SafeBetweenPhases, bool UseMembarGPU>
template <class Sync, bool SafeBetweenPhases>
struct MulticastSystemBarrier : public GenericBarrier<Sync>
{
@ -57,8 +59,8 @@ struct MulticastSystemBarrier : public GenericBarrier<Sync>
protected:
/// Reduce into flag, with release pattern (int specialization)
CUTLASS_DEVICE
static void red_release(T* mc_ptr, int val)
template <cuda::thread_scope Scope>
CUTLASS_DEVICE static void red_release(T* mc_ptr, int val)
{
#if defined(CUTE_ARCH_MULTIMEM_SM90_ENABLED)
// atomic reduction to all replicas
@ -66,14 +68,18 @@ protected:
// See
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-multimem-ld-reduce-multimem-st-multimem-red
// for multimem PTX doc
if constexpr (UseMembarGPU)
if constexpr (Scope == cuda::thread_scope::thread_scope_device)
{
asm volatile("multimem.red.release.gpu.global.add.u32 [%0], %1;" ::"l"(mc_ptr), "r"(val) : "memory");
}
else
else if constexpr (Scope == cuda::thread_scope::thread_scope_system)
{
asm volatile("multimem.red.release.sys.global.add.u32 [%0], %1;" ::"l"(mc_ptr), "r"(val) : "memory");
}
else
{
CUTE_INVALID_CONTROL_PATH("Invalid thread scope for MulticastSystemBarrier.");
}
// Need a fence between MC and UC access to the same memory:
// - fence.proxy instructions establish an ordering between memory accesses that may happen through different
@ -128,8 +134,8 @@ public:
Sync::sync();
}
CUTLASS_DEVICE
static T arrive_inc_get(T* mc_ptr, T* uc_ptr, int thread_idx, int flag_idx, int rank, int world_size)
template <cuda::thread_scope Scope>
CUTLASS_DEVICE static T arrive_inc_get(T* mc_ptr, T* uc_ptr, int thread_idx, int flag_idx, int rank, int world_size)
{
T* mc_barrier_ptr = mc_ptr + flag_idx;
T* uc_barrier_ptr = uc_ptr + flag_idx;
@ -156,13 +162,13 @@ public:
// can be immediately reused.
bool master = rank == 0;
int val = master ? 0x80000000 - (world_size - 1) : 1;
red_release(mc_barrier_ptr, val);
red_release<Scope>(mc_barrier_ptr, val);
}
return old_arrive;
}
CUTLASS_DEVICE
static void arrive_inc(Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
template <cuda::thread_scope Scope = cuda::thread_scope::thread_scope_system>
CUTLASS_DEVICE static void arrive_inc(Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
{
T* mc_barrier = params.mc_barrier_ptr + flag_idx;
@ -170,23 +176,24 @@ public:
if (thread_idx == 0)
{
red_release(mc_barrier, 1);
red_release<Scope>(mc_barrier, 1);
}
}
CUTLASS_DEVICE
static void arrive_and_wait(Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
template <cuda::thread_scope Scope = cuda::thread_scope::thread_scope_system>
CUTLASS_DEVICE static void arrive_and_wait(
Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
{
auto mc_ptr = params.mc_barrier_ptr;
auto uc_ptr = params.uc_barrier_ptr;
if constexpr (SafeBetweenPhases)
{
auto old_arrive = arrive_inc_get(mc_ptr, uc_ptr, thread_idx, flag_idx, rank, world_size);
auto old_arrive = arrive_inc_get<Scope>(mc_ptr, uc_ptr, thread_idx, flag_idx, rank, world_size);
wait(old_arrive, uc_ptr, thread_idx, flag_idx);
}
else
{
arrive_inc(params, thread_idx, flag_idx, rank, world_size);
arrive_inc<Scope>(params, thread_idx, flag_idx, rank, world_size);
wait_eq_reset(uc_ptr, thread_idx, flag_idx, world_size);
}
}
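As a reference for the new Scope template parameter, here is a hedged sketch of the same compile-time dispatch using libcu++ atomics on an ordinary pointer, rather than the barrier's multimem.red PTX on a multicast pointer; function and kernel names are illustrative:

#include <cuda/atomic>

template <cuda::thread_scope Scope>
__device__ void release_add(unsigned int* flag, unsigned int val)
{
    if constexpr (Scope == cuda::thread_scope_device)
    {
        // device (.gpu) scope: prior writes are ordered for peers on the same GPU
        cuda::atomic_ref<unsigned int, cuda::thread_scope_device> ref(*flag);
        ref.fetch_add(val, cuda::std::memory_order_release);
    }
    else
    {
        static_assert(Scope == cuda::thread_scope_system, "unsupported scope");
        // system (.sys) scope: prior writes are ordered for other GPUs and the host
        cuda::atomic_ref<unsigned int, cuda::thread_scope_system> ref(*flag);
        ref.fetch_add(val, cuda::std::memory_order_release);
    }
}

__global__ void barrier_scope_demo(unsigned int* device_flag, unsigned int* system_flag)
{
    if (threadIdx.x == 0)
    {
        release_add<cuda::thread_scope_device>(device_flag, 1); // intra-GPU arrive
        release_add<cuda::thread_scope_system>(system_flag, 1); // inter-GPU arrive
    }
}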


@ -249,11 +249,14 @@ endif()
if(USING_OSS_CUTLASS_ALLREDUCE_GEMM)
add_library(
ar_gemm_src STATIC
${ARGEMM_SRC_CU}
${CMAKE_CURRENT_SOURCE_DIR}/../../runtime/ipcNvlsMemory.cpp)
${ARGEMM_SRC_CU} ${CMAKE_CURRENT_SOURCE_DIR}/../../runtime/ipcNvlsMemory.cu)
target_include_directories(
ar_gemm_src
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../internal_cutlass_kernels/include)
if(ENABLE_NVSHMEM)
target_link_libraries(ar_gemm_src PRIVATE nvshmem::nvshmem_host
nvshmem::nvshmem_device)
endif()
set_cuda_architectures(ar_gemm_src 90 100f)
endif()


@ -138,7 +138,7 @@ public:
// Epilogue
////////////////
using FusionCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD, float, void, float>;
using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true, true>;
using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true>;
using EpilogueScheduleType = typename MmaAdapter<MmaType, IsFP4>::EpilogueSchedule;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOp


@ -100,8 +100,7 @@ public:
using RasterOrderOptions =
typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions;
using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true /* Safe across phases */,
true /* membar.gpu */>;
using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true /* Safe across phases */>;
// 16B alignment for TMA
static constexpr int AlignmentA = 16 / sizeof(ElementA);


@ -201,7 +201,7 @@ public:
auto [M, N, K, L] = problem_shape;
auto [m, n, k, l] = tile_coord;
if (!tile_valid(m, n) || params_ptr->world_size == 1)
if (!tile_valid(m, n) || params_ptr->world_size <= 2)
{
return; // nothing to do
}
@ -212,7 +212,7 @@ public:
// Wait for all multicast writes to be visible to us.
// This is safe between phases.
SystemBarrier::arrive_and_wait(
SystemBarrier::arrive_and_wait<cuda::thread_scope::thread_scope_system>(
params_ptr->barrier_params_final_sync, thread_idx, tile_index, params_ptr->rank, params_ptr->world_size);
}
@ -297,13 +297,20 @@ public:
Tensor tGR_gD1_vec = zipped_divide(tGR_gD1(_, _, _, red_m, red_n), Vec);
Tensor tRG_gOut_vec = zipped_divide(tRG_gOut(_, _, _, red_m, red_n), Vec);
auto pred_fn
= [&](auto const&... coords) { return elem_less(tGR_pD_vec(_0{}, coords...), problem_shape); };
// Create predicate tensor for bounds checking
Tensor pred_tensor = make_tensor<bool>(make_shape(size(tGR_pD_vec)), Stride<_1>{});
// Set predicate values based on coordinate bounds
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred_tensor); ++i)
{
pred_tensor(i) = elem_less(tGR_pD_vec(_0{}, i), problem_shape);
}
// Read from self.
cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_gD0_vec, tGR_rD0_vec);
cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_gD0_vec, tGR_rD0_vec);
// Read from remote.
cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_gD1_vec, tGR_rD1_vec);
cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_gD1_vec, tGR_rD1_vec);
// Reduce
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tGR_rD0_vec); i++)
@ -311,7 +318,7 @@ public:
tGR_rD0_vec(i) += tGR_rD1_vec(i);
}
// store to self.
cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_rD0_vec, tRG_gOut_vec);
cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_rD0_vec, tRG_gOut_vec);
}
}
}
@ -386,13 +393,21 @@ public:
Tensor tGR_gD_vec = zipped_divide(tGR_gD(_, _, _, red_m, red_n), Vec);
Tensor tRG_gD_vec = zipped_divide(tRG_gD(_, _, _, red_m, red_n), Vec);
Tensor tGR_pD_vec = zipped_divide(tGR_pD(_, _, _, red_m, red_n), Vec);
// problem shape bounds check
auto pred_fn
= [&](auto const&... coords) { return elem_less(tGR_pD_vec(_0{}, coords...), problem_shape); };
// Create predicate tensor for bounds checking
Tensor pred_tensor = make_tensor<bool>(make_shape(size(tGR_gD_vec)), Stride<_1>{});
// Set predicate values based on coordinate bounds
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred_tensor); ++i)
{
pred_tensor(i) = elem_less(tGR_pD_vec(_0{}, i), problem_shape);
}
// load-reduce in switch
cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_gD_vec, tGR_rD_vec);
cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_gD_vec, tGR_rD_vec);
// store switch multicast
cute::copy_if(CopyAtomR2G{}, pred_fn, tGR_rD_vec, tRG_gD_vec);
cute::copy_if(CopyAtomR2G{}, pred_tensor, tGR_rD_vec, tRG_gD_vec);
}
}
}
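Conceptually, the change above replaces a callable predicate with a mask that is materialized once and then consumed per element by the copy; a minimal CuTe-free sketch of the same idea (names are illustrative):

#include <cstddef>

// The bounds check is evaluated up front into a boolean mask, and the copy
// routine only consults the mask instead of re-invoking a lambda.
template <typename T, std::size_t N>
void copy_if_masked(bool const (&mask)[N], T const (&src)[N], T (&dst)[N])
{
    for (std::size_t i = 0; i < N; ++i)
    {
        if (mask[i])
        {
            dst[i] = src[i];
        }
    }
}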


@ -171,7 +171,7 @@ struct Sm100AllReduceArrive
tma_store_wait<0>();
int tile_idx = params_ptr->tile_layout(m, n);
SystemBarrier::arrive_inc(
SystemBarrier::arrive_inc<cuda::thread_scope::thread_scope_device>(
params_ptr->barrier_params, thread_idx, tile_idx, params_ptr->rank, params_ptr->world_size);
}
}


@ -268,7 +268,7 @@ struct Sm90AuxAllReduce
tma_store_wait<0>();
int tile_idx = params_ptr->tile_layout(m, n);
SystemBarrier::arrive_inc(
SystemBarrier::arrive_inc<cuda::thread_scope::thread_scope_device>(
params_ptr->barrier_params, thread_idx, tile_idx, params_ptr->rank, params_ptr->world_size);
}
};


@ -969,8 +969,8 @@ public:
{
bool do_tail_store = false;
constexpr auto ARBarrier = (uint32_t) cutlass::arch::ReservedNamedBarriers::FirstUserBarrier;
CollectiveAllReduce collective_allreduce(params.allreduce, ARBarrier);
const uint32_t AR_barrier_id = 0;
CollectiveAllReduce collective_allreduce(params.allreduce, AR_barrier_id);
int thread_idx = threadIdx.x - (MaxThreadsPerBlock - NumARThreads);
auto init_cta_coord_mnkl = cta_coord_mnkl;


@ -31,6 +31,10 @@
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cutlass/workspace.h"
#include "cute/tensor.hpp"
#include "cutlass/arch/grid_dependency_control.h"
#include "gemm_universal_allreduce.hpp"
#include "cute/tensor.hpp"
@ -42,6 +46,8 @@
namespace cutlass::gemm::kernel
{
///////////////////////////////////////////////////////////////////////////////
template <class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, class CollectiveAllReduce_,
class TileScheduler_>
class GemmARUniversal<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, CollectiveAllReduce_, TileScheduler_,
@ -55,6 +61,7 @@ public:
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled;
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
@ -88,22 +95,46 @@ public:
static_assert(!cute::is_same_v<TileScheduler_, StreamKScheduler>,
"Ping-pong kernel does not currently support stream-K scheduler.");
static constexpr uint32_t TileSchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount;
using TileSchedulerTag = TileScheduler_;
using TileScheduler =
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
using TileScheduler = typename detail::TileSchedulerSelector<TileSchedulerTag, ArchTag, TileShape, ClusterShape,
TileSchedulerPipelineStageCount>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
using TileSchedulerParams = typename TileScheduler::Params;
using TileSchedulerPipeline = typename TileScheduler::Pipeline;
using TileSchedulerPipelineState = typename TileSchedulerPipeline::PipelineState;
using TileSchedulerStorage = typename TileScheduler::SharedStorage;
static constexpr uint32_t NumAllReduceThreads = 2 * NumThreadsPerWarp;
using TileSchedulerThrottlePipeline = typename TileScheduler::ThrottlePipeline;
using TileSchedulerThrottlePipelineState = typename TileSchedulerThrottlePipeline::PipelineState;
static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent;
static_assert(!IsSchedDynamicPersistent, "Tile scheduling order needs to be consistent across ranks.");
// Warp specialization thread count per threadblock
static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp
static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp
static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp for C
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups = 2;
static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents;
static constexpr uint32_t NumMMAThreads = size(TiledMma{}); // 4 warp
static constexpr uint32_t MaxThreadsPerBlock
= CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup);
= NumMMAThreads * NumMmaWarpGroups + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
static constexpr bool IsMainloopAuxiliaryLoadNeeded
= detail::HasAuxiliaryLoad_v<typename CollectiveMainloop::DispatchPolicy>;
static_assert(NumMMAThreads == 128, "Pingpong kernel must have TiledMMA operating using 128 threads.");
static_assert(MaxThreadsPerBlock == 384, "Pingpong kernel must have 384 threads in total.");
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
static constexpr uint32_t MmaRegisterRequirement = 232;
static constexpr int RegsPerThread = (size<0>(TileShape{}) * size<1>(TileShape{}) * sizeof(ElementAccumulator))
/ (NumMMAThreads * sizeof(uint32_t));
static constexpr bool HeavyRegisterPressure = RegsPerThread >= 208;
static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24;
static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240;
// 1 stage ordered sequence between mainloop and epilogue producer load threads
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
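A worked instance of the register-pressure heuristic above (host-compilable, assuming float accumulators and the 128 MMA threads asserted above): a 128x128 CTA tile keeps the 40/232 load/MMA register split, while a 128x256 tile crosses the 208-register threshold and switches to 24/240:

#include <cstdint>

constexpr uint32_t regs_per_thread(uint32_t tile_m, uint32_t tile_n, uint32_t num_mma_threads)
{
    return (tile_m * tile_n * sizeof(float)) / (num_mma_threads * sizeof(uint32_t));
}

constexpr uint32_t kLight = regs_per_thread(128, 128, 128); // 128 regs/thread -> not heavy, 40/232 split
constexpr uint32_t kHeavy = regs_per_thread(128, 256, 128); // 256 regs/thread -> heavy (>= 208), 24/240 split
static_assert(kLight == 128 && kHeavy == 256, "worked example of HeavyRegisterPressure");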
@ -114,7 +145,6 @@ public:
// Order Sequence barrier with two stages: one for Mainloop and one for Epilogue
static constexpr uint32_t StagesPerMathWarpGroup = 2;
using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<StagesPerMathWarpGroup, NumMmaWarpGroups>;
using MathWarpGroupOrderBarrierSharedStorage
= cutlass::PipelineDetail::OrderedSequenceBarrierSharedStorage<MathWarpGroupOrderBarrier::SequenceDepth,
MathWarpGroupOrderBarrier::SequenceLength>;
@ -122,7 +152,7 @@ public:
// Kernel level shared memory storage
struct SharedStorage
{
struct PipelineStorage : cute::aligned_struct<16>
struct PipelineStorage : cute::aligned_struct<16, _1>
{
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
@ -135,7 +165,9 @@ public:
alignas(16) typename AllReduceOrderBarrier::SharedStorage allreduce_order;
} pipelines;
struct TensotStorage : cute::aligned_struct<128>
alignas(16) TileSchedulerStorage scheduler;
struct TensorStorage : cute::aligned_struct<128, _1>
{
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
@ -199,31 +231,48 @@ public:
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
// Get maximum number of clusters that could co-exist on the target device
int max_active_clusters = args.hw_info.max_active_clusters;
if (max_active_clusters <= 0)
{
max_active_clusters = 0;
CUTLASS_TRACE_HOST(
" WARNING: Arguments do not include a valid max cluster count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the "
"max_active_clusters.");
}
else
{
CUTLASS_TRACE_HOST(
"to_underlying_arguments(): Setting persistent grid cluster count to " << max_active_clusters);
}
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count, max_active_clusters};
// Calculate workspace pointers
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
size_t workspace_offset = 0;
void* scheduler_workspace = workspace_ptr;
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* epilogue_workspace = workspace_ptr + workspace_offset;
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* scheduler_workspace = workspace_ptr + workspace_offset;
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void* mainloop_workspace = nullptr;
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
return {args.mode, problem_shape,
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace),
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace),
CollectiveAllReduce::to_underlying_arguments(args.problem_shape, args.all_reduce), hw_info,
TileScheduler::to_underlying_arguments(
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace)};
TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info,
args.scheduler, scheduler_workspace, NumEpilogueSubTiles)};
}
static bool can_implement(Arguments const& args)
@ -245,13 +294,14 @@ public:
static size_t get_workspace_size(Arguments const& args)
{
size_t workspace_size = 0;
workspace_size += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
return workspace_size;
}
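For reference, a hedged sketch (illustrative names, not this class's code) of the invariant behind the reordering in to_underlying_arguments, get_workspace_size, and initialize_workspace: all three must walk the workspace components in the same order (here epilogue first, then scheduler), because each component's pointer is derived from the accumulated, aligned offset of everything placed before it:

#include <cstddef>

constexpr std::size_t round_up(std::size_t bytes, std::size_t alignment)
{
    return (bytes + alignment - 1) / alignment * alignment;
}

struct WorkspaceLayout
{
    std::size_t epilogue_offset;  // component #1
    std::size_t scheduler_offset; // component #2
    std::size_t total_bytes;
};

WorkspaceLayout layout_workspace(std::size_t epilogue_bytes, std::size_t scheduler_bytes, std::size_t alignment)
{
    WorkspaceLayout out{};
    std::size_t offset = 0;
    out.epilogue_offset = offset;
    offset = round_up(offset + epilogue_bytes, alignment);
    out.scheduler_offset = offset;
    offset = round_up(offset + scheduler_bytes, alignment);
    out.total_bytes = offset;
    return out; // sizing and initialization must both follow this same walk
}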
@ -261,20 +311,23 @@ public:
Status status = Status::kSuccess;
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
size_t workspace_offset = 0;
static constexpr uint32_t NumEpilogueSubTiles = 1;
static constexpr uint32_t NumAccumulatorMtxs = 1;
status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(
args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, 1);
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
status = CollectiveEpilogue::initialize_workspace(
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess)
{
return status;
}
status = CollectiveEpilogue::initialize_workspace(
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(args.scheduler,
workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups,
NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter);
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess)
{
@ -347,9 +400,9 @@ public:
enum class ProducerWarpRole
{
Mainloop = 0,
Epilogue = 1,
Warp2 = 2,
Warp3 = 3
Warp1 = 1,
Epilogue = 2,
MainloopAux = 3
};
// Kernel level shared memory storage
@ -372,10 +425,18 @@ public:
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
}
// TileScheduler pipeline
typename TileSchedulerPipeline::Params scheduler_pipeline_params;
typename TileSchedulerThrottlePipeline::Params scheduler_throttle_pipeline_params;
TileSchedulerPipeline scheduler_pipeline(shared_storage.scheduler.pipeline(), scheduler_pipeline_params);
TileSchedulerPipelineState scheduler_pipe_consumer_state;
// Mainloop Load pipeline
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
typename MainloopPipeline::Params mainloop_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop)
if (warp_group_role == WarpGroupRole::Producer
&& (producer_warp_role == ProducerWarpRole::Mainloop
|| producer_warp_role == ProducerWarpRole::MainloopAux))
{
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
}
@ -385,6 +446,7 @@ public:
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
mainloop_pipeline_params.num_producers = NumProducerThreads;
mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
@ -420,7 +482,6 @@ public:
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier);
typename LoadWarpOrderBarrier::Params params_allreduce_order_barrier;
// DMA Load WG will not participate in these Ordered Barrier syncs
params_allreduce_order_barrier.group_id
= canonical_warp_group_idx() - static_cast<int>(WarpGroupRole::Consumer0);
params_allreduce_order_barrier.group_size = NumThreadsPerWarpGroup;
@ -493,8 +554,10 @@ public:
if (warp_group_role == WarpGroupRole::Consumer1)
{
// Advance 2nd Math WG to the next work tile for the startup
scheduler.advance_to_next_work();
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
mainloop_pipe_consumer_state.advance(k_tile_count);
epi_load_pipe_consumer_state.advance(c_tile_count);
@ -546,8 +609,40 @@ public:
// Make sure all Consumer Warp Groups have been waited upon
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
} // Mainloop Producer Warp End
else if (producer_warp_role == ProducerWarpRole::MainloopAux)
{
if constexpr (IsMainloopAuxiliaryLoadNeeded)
{
// Ensure that the prefetched kernel does not touch
// unflushed global memory prior to this instruction
cutlass::arch::wait_on_dependent_grids();
while (work_tile_info.is_valid())
{
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
collective_mainloop.load_auxiliary(params.mainloop, mainloop_pipeline,
mainloop_pipe_producer_state, load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx,
block_rank_in_cluster, shared_storage.tensors.mainloop);
// Update starting pipeline state for the next tile
mainloop_pipe_producer_state.advance(k_tile_count);
scheduler.advance_to_next_work();
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
}
}
// Epilogue Producer Warp
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed())
{
@ -556,9 +651,15 @@ public:
// unflushed global memory prior to this instruction
cutlass::arch::wait_on_dependent_grids();
load_order_barrier.wait();
bool do_load_order_wait = true;
while (work_tile_info.is_valid())
{
if (do_load_order_wait)
{
load_order_barrier.wait();
do_load_order_wait = false;
}
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
@ -592,18 +693,13 @@ public:
// The timing of calling this function only influences performance,
// not functional correctness.
cutlass::arch::launch_dependent_grids();
return;
}
#endif
constexpr uint32_t AllReduceBarrier0 = uint32_t(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier) + 1;
constexpr uint32_t AllReduceBarrier1 = uint32_t(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier) + 2;
auto const named_barrier
= warp_group_role == WarpGroupRole::Consumer0 ? AllReduceBarrier0 : AllReduceBarrier1;
CollectiveAllReduce collective_all_reduce(params.all_reduce, named_barrier);
cute::tuple<int32_t, int32_t, cute::Underscore, int32_t> prev_blk_coord;
int tile_count = 0;
const uint32_t AR_barrier_id = warp_group_role == WarpGroupRole::Consumer0 ? 0 : 1;
CollectiveAllReduce collective_allreduce(params.all_reduce, AR_barrier_id);
while (work_tile_info.is_valid())
{
@ -621,6 +717,7 @@ public:
collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators, k_tile_count,
warp_group_thread_idx, shared_storage.tensors.mainloop, params.mainloop);
// Cue for next Math WG's MMA to start
math_wg_order_barrier.arrive();
@ -638,6 +735,7 @@ public:
cutlass::arch::launch_dependent_grids();
}
#endif
// Order two Math WG's Epilogue one after the other
math_wg_order_barrier.wait();
@ -661,32 +759,28 @@ public:
epi_load_pipe_consumer_state.advance(c_tile_count);
epi_store_pipe_producer_state.advance(d_tile_count);
// Cue next WG's Epilogue
// Cue for next Math WG's Epilogue to start
math_wg_order_barrier.arrive();
// Order two consumer WG's Allreduce one after the other
allreduce_order_barrier.wait();
collective_all_reduce.gather_reduce_broadcast(problem_shape_MNKL, blk_coord, warp_group_thread_idx);
collective_allreduce.gather_reduce_broadcast(problem_shape_MNKL, blk_coord, warp_group_thread_idx);
// Cue next WG's Allreduce
allreduce_order_barrier.arrive();
prev_blk_coord = blk_coord;
tile_count++;
if (scheduler.is_last_tile(work_tile_info, NumMmaWarpGroups))
{
// Ensure broadcast from other ranks are visible to us.
collective_allreduce.tile_global_sync(problem_shape_MNKL, blk_coord, warp_group_thread_idx);
}
// Get next work tile
scheduler.advance_to_next_work(NumMmaWarpGroups);
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
// Last tile in CTA, flush and sync
if (tile_count > 0)
{
collective_all_reduce.tile_global_sync(problem_shape_MNKL, prev_blk_coord, warp_group_thread_idx);
}
} // Consumer Warp Groups End
} // Consumer Warp Groups End
}
};


@ -127,6 +127,53 @@ void GemmAllReducePlugin::allocatePersistentWorkspace()
TLLM_CHECK(mWorkspace != nullptr);
}
LaunchConfig GemmAllReducePlugin::getStaticHeuristicLaunchConfig(int M) const
{
using namespace tensorrt_llm::cutlass_extensions;
// This is only applicable when we swap and transpose A & B.
// When M is small we want to select the tile that best fits it to maximize MMA efficiency.
auto filterByM = [&](std::vector<LaunchConfig> candidateConfigs)
{
std::vector<LaunchConfig> result;
if (M <= 16)
{
std::copy_if(candidateConfigs.begin(), candidateConfigs.end(), std::back_inserter(result),
[](const LaunchConfig& config)
{ return config.tile_shape == TileShape::TileShape_128x16x128 and config.transposed; });
}
else if (M <= 32)
{
std::copy_if(candidateConfigs.begin(), candidateConfigs.end(), std::back_inserter(result),
[](const LaunchConfig& config)
{ return config.tile_shape == TileShape::TileShape_128x32x128 and config.transposed; });
}
else if (M <= 64)
{
std::copy_if(candidateConfigs.begin(), candidateConfigs.end(), std::back_inserter(result),
[](const LaunchConfig& config)
{ return config.tile_shape == TileShape::TileShape_128x64x128 and config.transposed; });
}
else
{
std::copy_if(candidateConfigs.begin(), candidateConfigs.end(), std::back_inserter(result),
[](const LaunchConfig& config)
{ return config.tile_shape == TileShape::TileShape_128x128x128 and config.transposed; });
}
// If result empty then use any.
if (result.empty())
{
result = candidateConfigs;
}
return result;
};
auto bestLaunchConfigs = mGemm->getSupportedLaunchConfigs();
bestLaunchConfigs = filterByM(bestLaunchConfigs);
TLLM_CHECK(!bestLaunchConfigs.empty());
// Return first one, because who knows which is best.
return bestLaunchConfigs.front();
}
static GemmAllReducePluginOptions deserializeOptions(void const*& data, size_t length)
{
char const* begin = reinterpret_cast<char const*>(data);
@ -164,8 +211,10 @@ static GemmAllReducePluginOptions deserializeOptions(void const*& data, size_t l
GemmAllReducePlugin::GemmAllReducePlugin(void const* data, size_t length)
: GemmAllReducePlugin(deserializeOptions(std::ref(data), length))
{
// char const* end = reinterpret_cast<char const*>(data);
mProfiler->deserializeFromOwnFile(mGemmId, mOptions.maxProblemShape);
if (mProfiler->useProfiler())
{
mProfiler->deserializeFromOwnFile(mGemmId, mOptions.maxProblemShape);
}
}
//////////////////////////////////
@ -351,7 +400,15 @@ int GemmAllReducePlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensor
TLLM_CHECK_WITH_INFO(K > 0, "GemmAllReducePlugin K is 0.");
TLLM_CHECK_WITH_INFO(mWorkspace != nullptr, "GemmAllReducePlugin workspace is null.");
auto bestLaunchConfig = mProfiler->getBestConfig(M, mGemmId).value();
LaunchConfig bestLaunchConfig;
if (mProfiler->useProfiler())
{
bestLaunchConfig = mProfiler->getBestConfig(M, mGemmId).value();
}
else
{
bestLaunchConfig = getStaticHeuristicLaunchConfig(M);
}
void const* activation = inputs[mArgInvMap[TensorArg::IN_ACTIVATION]];
void const* weight = inputs[mArgInvMap[TensorArg::IN_WEIGHT]];
@ -435,7 +492,7 @@ int GemmAllReducePlugin::getNbOutputs() const noexcept
int GemmAllReducePlugin::initialize() noexcept
{
if (isBuilding())
if (isBuilding() && mProfiler->useProfiler())
{
// TODO (xsimmons): interfaces between GemmPluginProfiler and Plugin
// needs to be relooked at - current interface implicitly assigns runner to profiler
@ -509,7 +566,7 @@ void GemmAllReducePlugin::serialize(void* buffer) const noexcept
// Since by default each rank will generate and serialize its own profiler mapping
// this can lead to different mappings between ranks which will result in fatal
// error. Therefore only generate and use profiler mapping for single rank.
if (COMM_SESSION.getRank() == 0)
if (mProfiler->useProfiler() && COMM_SESSION.getRank() == 0)
{
mProfiler->serializeToOwnFile(mGemmId);
}


@ -34,6 +34,9 @@ namespace cutlass_kernels = ::tensorrt_llm::kernels::opened_cutlass_kernels;
#else
namespace cutlass_kernels = ::tensorrt_llm::kernels::cutlass_kernels;
#endif
using LaunchConfig = typename cutlass_kernels::GemmAllReduceImplInterface::LaunchConfig;
namespace tensorrt_llm::plugins
{
struct GemmAllReducePluginOptions
@ -125,6 +128,8 @@ private:
void allocatePersistentWorkspace();
LaunchConfig getStaticHeuristicLaunchConfig(int M) const;
// Params that are initialized during constructor
using KeyType = std::tuple<DataType, DataType, DataType>;
using ValueType = std::function<cutlass_kernels::GemmAllReduceImplInterface*()>;


@ -58,10 +58,18 @@ void GemmAllReducePluginProfiler::deserializeFromOwnFile(GemmIdCore gemmId, Gemm
assert(end == begin + size);
}
bool GemmAllReducePluginProfiler::useProfiler()
{
char const* envDir = getenv("GEMM_AR_PLUGIN_PROFILE_DIR");
return envDir != nullptr;
}
std::string GemmAllReducePluginProfiler::getCacheFileName(GemmIdCore gemmId)
{
std::stringstream fileName;
fileName << "/tmp/gemm-AR";
char const* envDir = getenv("GEMM_AR_PLUGIN_PROFILE_DIR");
std::string directory = envDir ? std::string(envDir) : "/tmp/";
fileName << directory + "/gemm-AR";
fileName << "-n" << std::to_string(gemmId.n);
fileName << "-k" << std::to_string(gemmId.k);
fileName << "-" << tc::getDtypeString(gemmId.dtype);


@ -29,6 +29,8 @@ namespace tensorrt_llm::plugins
* Used for tuning to find best GEMM configs for different problem shapes.
* WARNING: Tuning the GEMM+AR kernel may not be fully representative of real
* multi-GPU workloads as tuning only runs on a single GPU.
* IMPORTANT: TRT-LLM does not support deterministic tuning across ranks.
* Because of this, we have to serialize/deserialize our own configuration file.
*/
#if defined(USING_OSS_CUTLASS_ALLREDUCE_GEMM)
@ -45,6 +47,8 @@ public:
void deserializeFromOwnFile(GemmIdCore gemmId, GemmDims problemShape);
bool useProfiler();
protected:
////////////////////////////////////
// GemmPluginProfiler methods


@ -18,7 +18,7 @@ set(SRCS
testing/modelSpecBinding.cpp
runtime/moeBindings.cpp
userbuffers/bindings.cpp
../runtime/ipcNvlsMemory.cpp
../runtime/ipcNvlsMemory.cu
bindings.cpp)
include_directories(${PROJECT_SOURCE_DIR}/include)
@ -30,10 +30,21 @@ set_property(TARGET ${TRTLLM_PYBIND_MODULE} PROPERTY POSITION_INDEPENDENT_CODE
target_link_directories(${TRTLLM_PYBIND_MODULE} PUBLIC
"${TORCH_INSTALL_PREFIX}/lib")
if(ENABLE_NVSHMEM)
target_link_libraries(${TRTLLM_PYBIND_MODULE} PUBLIC nvshmem::nvshmem_host
nvshmem::nvshmem_device)
endif()
target_link_libraries(
${TRTLLM_PYBIND_MODULE}
PUBLIC ${SHARED_TARGET} ${UNDEFINED_FLAG} ${NO_AS_NEEDED_FLAG}
${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python)
PUBLIC ${SHARED_TARGET}
${UNDEFINED_FLAG}
${NO_AS_NEEDED_FLAG}
${Python3_LIBRARIES}
${TORCH_LIBRARIES}
torch_python
${CUDA_NVML_LIB})
target_compile_definitions(
${TRTLLM_PYBIND_MODULE} PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE}
PYBIND11_DETAILED_ERROR_MESSAGES=1)
@ -43,6 +54,6 @@ if(NOT WIN32)
${TRTLLM_PYBIND_MODULE}
PROPERTIES
LINK_FLAGS
"-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}"
"-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' -Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/lib/stubs' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}"
)
endif()


@ -40,7 +40,7 @@ set(SRCS
iTensor.cpp
ipcUtils.cpp
ipcSocket.cpp
ipcNvlsMemory.cpp
ipcNvlsMemory.cu
mcastDeviceMemory.cpp
memoryCounters.cpp
moeLoadBalancer/gdrwrap.cpp
@ -80,6 +80,7 @@ set_property(TARGET runtime_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_cuda_architectures(runtime_src 89)
target_include_directories(runtime_src PRIVATE ${MPI_C_INCLUDE_DIRS})
target_link_libraries(runtime_src PUBLIC ${CUDA_NVML_LIB})
if(ENABLE_MULTI_DEVICE)
target_link_libraries(runtime_src PUBLIC ${NCCL_LIB})


@ -1,359 +0,0 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "ipcSocket.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <unistd.h>
#define CUCHECK(cmd) \
do \
{ \
CUresult retval = cmd; \
if (retval != CUDA_SUCCESS) \
{ \
const char* error_string; \
cuGetErrorString(retval, &error_string); \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, error_string); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define ALIGN_SIZE(x, align) x = ((x + align - 1) / align) * align;
namespace tensorrt_llm::runtime
{
using namespace tensorrt_llm::mpi;
void MPI_group_rank(std::set<int> group, int* groupRank)
{
#if ENABLE_MULTI_DEVICE
int rank = COMM_SESSION.getRank();
auto it = std::find(group.begin(), group.end(), rank);
TLLM_CHECK_WITH_INFO(
it != group.end(), "Incorrect group specified - rank " + std::to_string(rank) + " not found in group");
*groupRank = std::distance(group.begin(), it);
#else
TLLM_THROW("MPI_group_rank needs to be compiled with ENABLE_MULTI_DEVICE");
#endif
}
/**
* @brief MPI_Barrier when subset of ranks present
*/
void MPI_group_barrier(std::set<int> group)
{
#if ENABLE_MULTI_DEVICE
std::vector<int> ranks(group.begin(), group.end());
int size = group.size();
int group_rank;
MPI_group_rank(group, &group_rank);
int root = 0;
if (group_rank == root)
{
int dummy = 0;
// Root receives messages from all other processes
for (int i = 1; i < size; i++)
{
COMM_SESSION.recv(&dummy, 1, MpiType::kINT32, ranks[i], MpiTag::kDefault);
}
// Root sends messages back to all other processes
for (int i = 1; i < size; i++)
{
COMM_SESSION.send(&dummy, 1, MpiType::kINT32, ranks[i], MpiTag::kDefault);
}
}
else
{
int dummy = 0;
// Non-root processes send a message to root
COMM_SESSION.send(&dummy, 1, MpiType::kINT32, ranks[root], MpiTag::kDefault);
// Non-root processes receive a message from root
COMM_SESSION.recv(&dummy, 1, MpiType::kINT32, ranks[root], MpiTag::kDefault);
}
#else
TLLM_THROW("MPI_group_barrier needs to be compiled with ENABLE_MULTI_DEVICE");
#endif
}
/**
* @brief MPI_Bcast when subset of ranks present
*/
void MPI_group_bcast(std::set<int> group, void* buffer, int count, MpiType datatype, int root)
{
#if ENABLE_MULTI_DEVICE
int group_rank;
MPI_group_rank(group, &group_rank);
std::vector<int> ranks(group.begin(), group.end());
if (group_rank == root)
{
// Root sends message to all other processes
for (size_t i = 1; i < ranks.size(); ++i)
{
COMM_SESSION.send(buffer, count, datatype, ranks[i], MpiTag::kDefault);
}
}
else
{
// Non-root processes receive a message from root
COMM_SESSION.recv(buffer, count, datatype, ranks[root], MpiTag::kDefault);
}
MPI_group_barrier(group);
#else
TLLM_THROW("MPI_group_bcast needs to be compiled with ENABLE_MULTI_DEVICE");
#endif
}
bool ipcNvlsSupported()
{
CUdevice current_dev;
int cuda_dev = -1;
int cuda_driver_version = -1;
int dev_count = 0;
TLLM_CUDA_CHECK(cudaDriverGetVersion(&cuda_driver_version));
if (cuda_driver_version < 12010)
{
return false;
}
TLLM_CUDA_CHECK(cudaGetDeviceCount(&dev_count));
for (int i = 0; i < dev_count; ++i)
{
TLLM_CUDA_CHECK(cudaGetDevice(&cuda_dev));
CUCHECK(cuDeviceGet(&current_dev, cuda_dev));
int mc_support = 0;
CUCHECK(cuDeviceGetAttribute(
&mc_support, static_cast<CUdevice_attribute>(CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED), current_dev));
if (mc_support == 0)
{
return false;
}
}
return true;
}
IpcNvlsHandle* ipcNvlsAllocate(size_t size, std::set<int> group)
{
#if ENABLE_MULTI_DEVICE
TLLM_CHECK(size > 0);
std::vector<int> ranks(group.begin(), group.end());
int rank = COMM_SESSION.getRank();
int group_rank;
MPI_group_rank(group, &group_rank);
int device_id = ranks[group_rank];
cudaSetDevice(device_id);
CUmemAllocationProp ucprop;
CUmulticastObjectProp mcprop;
size_t uc_align = 0;
size_t mc_align = 0;
CUmemAccessDesc uc_mc_access;
memset(&uc_mc_access, 0, sizeof(CUmemAccessDesc));
uc_mc_access.location.id = device_id;
uc_mc_access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
uc_mc_access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
memset(&ucprop, 0, sizeof(CUmemAllocationProp));
ucprop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
ucprop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
ucprop.location.id = device_id;
ucprop.allocFlags.gpuDirectRDMACapable = 1;
ucprop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
CUCHECK(cuMemGetAllocationGranularity(&uc_align, &ucprop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
ALIGN_SIZE(size, uc_align);
memset(&mcprop, 0, sizeof(CUmulticastObjectProp));
mcprop.numDevices = ranks.size();
mcprop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
mcprop.flags = 0;
mcprop.size = size;
CUCHECK(cuMulticastGetGranularity(&mc_align, &mcprop, CU_MULTICAST_GRANULARITY_MINIMUM));
ALIGN_SIZE(size, mc_align);
mcprop.size = size;
// Init NVLS handle
IpcNvlsHandle handle;
handle.size = mcprop.size;
// Get time
timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
// High res time down to nanosec
unsigned long seed = ts.tv_sec * 1000000000L + ts.tv_nsec;
// Initialize with rand seed.
srand(seed);
int root = 0;
uint64_t unique_op_id = (uint64_t) (rand()) ^ ((uint64_t) (rand()) << 32);
MPI_group_bcast(group, &unique_op_id, sizeof(unique_op_id), MpiType::kBYTE, root);
uint32_t volatile abort_flag = 0;
std::shared_ptr<NcclIpcSocket> socket = ncclIpcSocketInit(rank, unique_op_id, &abort_flag);
MPI_group_barrier(group);
int fd;
if (group_rank == root)
{
CUCHECK(cuMulticastCreate(&handle.mc_handle, &mcprop));
CUCHECK(
cuMemExportToShareableHandle(&fd, handle.mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0 /*flags*/));
// Root to send fd to all other processes
for (size_t i = 1; i < group.size(); ++i)
{
ncclIpcSocketSendFd(socket, fd, ranks[i], unique_op_id);
}
MPI_group_barrier(group);
}
else
{
MPI_group_barrier(group);
fd = ncclIpcSocketRecvFd(socket);
CUCHECK(cuMemImportFromShareableHandle(
&handle.mc_handle, (void*) (uintptr_t) fd, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
}
MPI_group_barrier(group);
close(fd);
// Add device to multicast object
CUdevice dev;
CUCHECK(cuDeviceGet(&dev, device_id));
CUCHECK(cuMulticastAddDevice(handle.mc_handle, dev));
// Create multicast VA
CUCHECK(cuMemAddressReserve(&handle.mc_va, size, mc_align, 0U, 0));
CUCHECK(cuMemMap(handle.mc_va, size, 0, handle.mc_handle, 0));
CUCHECK(cuMemSetAccess(handle.mc_va, size, &uc_mc_access, 1 /* count */));
// Allocate unicast VA
CUCHECK(cuMemCreate(&handle.uc_handle, size, &ucprop, 0));
CUCHECK(cuMemAddressReserve(&handle.uc_va, size, uc_align, 0U, 0));
CUCHECK(cuMemMap(handle.uc_va, size, 0, handle.uc_handle, 0));
// set access on UC address, for all GPUs so that UVA works
for (int gpu_id : group)
{
uc_mc_access.location.id = gpu_id;
CUCHECK(cuMemSetAccess(handle.uc_va, size, &uc_mc_access, 1 /* count */));
}
// Bind unicast memory to multicast group
CUCHECK(cuMulticastBindMem(handle.mc_handle, 0 /*mcOffset*/, handle.uc_handle, 0 /*memOffset*/, size, 0 /*flags*/));
handle.mc_ptr = reinterpret_cast<uintptr_t>((void*) handle.mc_va);
handle.uc_ptr = reinterpret_cast<uintptr_t>((void*) handle.uc_va);
printf("Rank %d nvlsAllocated %zu bytes successfully %p %p\n", rank, size, (void*) handle.uc_ptr,
(void*) handle.mc_ptr);
// Export to unicast VA to shareable handle
int fd_uc;
CUCHECK(cuMemExportToShareableHandle(
(void*) &fd_uc, handle.uc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0 /*flags*/));
handle.ipc_uc_ptrs.resize(ranks.size());
handle.ipc_uc_vas.resize(ranks.size());
handle.ipc_uc_handles.resize(ranks.size());
// Allgather unicast shareable handles
std::vector<int> peer_fds_uc(ranks.size());
peer_fds_uc[group_rank] = fd_uc;
for (size_t i = 1; i < ranks.size(); ++i)
{
MPI_group_barrier(group);
int send_rank = (group_rank + i) % ranks.size();
int recv_rank = (group_rank + ranks.size() - i) % ranks.size();
ncclIpcSocketSendFd(socket, fd_uc, ranks[send_rank], unique_op_id);
peer_fds_uc[recv_rank] = ncclIpcSocketRecvFd(socket);
}
ncclIpcSocketClose(socket);
// Import unicast shareable handles
for (size_t i = 0; i < ranks.size(); ++i)
{
if (ranks[i] == rank)
{
handle.ipc_uc_ptrs[i] = handle.uc_ptr;
handle.ipc_uc_vas[i] = handle.uc_va;
handle.ipc_uc_handles[i] = handle.uc_handle;
}
else
{
CUCHECK(cuMemImportFromShareableHandle(&handle.ipc_uc_handles[i], (void*) (uintptr_t) peer_fds_uc[i],
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
CUCHECK(cuMemAddressReserve(&handle.ipc_uc_vas[i], size, uc_align, 0U, 0));
CUCHECK(cuMemMap(handle.ipc_uc_vas[i], size, 0, handle.ipc_uc_handles[i], 0));
// set access on UC address, for all GPUs so that UVA works
for (int gpu_id : group)
{
uc_mc_access.location.id = gpu_id;
CUCHECK(cuMemSetAccess(handle.ipc_uc_vas[i], size, &uc_mc_access, 1 /* count */));
}
handle.ipc_uc_ptrs[i] = reinterpret_cast<uintptr_t>((void*) handle.ipc_uc_vas[i]);
}
// close FD UC
close(peer_fds_uc[i]);
}
MPI_group_barrier(group);
printf("Rank %d imported IPC handles successfully\n", rank);
return new IpcNvlsHandle(std::move(handle));
#else
TLLM_THROW("ipcNvlsAllocate needs to be compiled with ENABLE_MULTI_DEVICE");
#endif
}
void ipcNvlsFree(IpcNvlsHandle* handle)
{
#if ENABLE_MULTI_DEVICE
if (handle == nullptr)
{
return;
}
// Unmap and release MC VA
CUCHECK(cuMemUnmap(handle->mc_va, handle->size));
CUCHECK(cuMemRelease(handle->mc_handle));
CUCHECK(cuMemAddressFree(handle->mc_va, handle->size));
// Unmap and release UC VA
for (size_t i = 0; i < handle->ipc_uc_vas.size(); ++i)
{
CUCHECK(cuMemUnmap(handle->ipc_uc_vas[i], handle->size));
CUCHECK(cuMemRelease(handle->ipc_uc_handles[i]));
CUCHECK(cuMemAddressFree(handle->ipc_uc_vas[i], handle->size));
}
delete handle;
#else
TLLM_THROW("ipcNvlsFree needs to be compiled with ENABLE_MULTI_DEVICE");
#endif
}
} // namespace tensorrt_llm::runtime


@ -0,0 +1,537 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "tensorrt_llm/runtime/ipcSocket.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#if ENABLE_NVSHMEM
#include <nvshmem/nvshmem.h>
#include <nvshmem/nvshmemx.h>
#endif
#if ENABLE_MULTI_DEVICE
#include <nvml.h>
#endif
#include <unistd.h>
#define CUCHECK(cmd) \
do \
{ \
CUresult retval = cmd; \
if (retval != CUDA_SUCCESS) \
{ \
const char* error_string; \
cuGetErrorString(retval, &error_string); \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, error_string); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define NVMLCHECK(cmd) \
do \
{ \
nvmlReturn_t retval = cmd; \
if (retval != NVML_SUCCESS) \
{ \
printf("Failed: NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(retval)); \
exit(EXIT_FAILURE); \
} \
} while (0)
// if n is already a multiple of "multiple", n is returned unchanged, otherwise round up to next multiple.
#define ROUND_UP(n, multiple) (((n + multiple - 1) / multiple) * multiple)
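A small worked example (hypothetical granularities; the real values are queried from the driver) of how this macro is applied twice in allocate() below, first with the unicast and then with the multicast granularity:

// A 3,000,000-byte request, a 64 KiB unicast granularity and a 2 MiB multicast
// granularity grow the allocation to 3,014,656 and then 4,194,304 bytes.
static_assert(ROUND_UP(3000000ull, 65536ull) == 3014656ull, "unicast rounding");
static_assert(ROUND_UP(3014656ull, 2097152ull) == 4194304ull, "multicast rounding");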
namespace tensorrt_llm::runtime
{
using namespace tensorrt_llm::mpi;
#if ENABLE_MULTI_DEVICE && !ENABLE_NVSHMEM
union IpcMemHandle
{
uint64_t fd;
CUmemFabricHandle fh;
};
class IpcCommunicator
{
public:
virtual ~IpcCommunicator() = default;
virtual void bcastMemHandle(IpcMemHandle* handle, int root) = 0;
};
class IpcSocketCommunicator : public IpcCommunicator
{
public:
IpcSocketCommunicator(int world_rank, int group_rank, std::vector<int> group_ranks, MPI_Comm group_comm)
: mGroupRank(group_rank)
, mGroupRanks(group_ranks)
, mGroupComm(group_comm)
{
timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
unsigned long seed = ts.tv_sec * 1000000000L + ts.tv_nsec;
srand(seed);
uint64_t unique_op_id = (uint64_t) (rand()) ^ ((uint64_t) (rand()) << 32);
MPI_Bcast(&unique_op_id, sizeof(unique_op_id), MPI_BYTE, 0, group_comm);
uint32_t volatile abort_flag = 0;
mSocket = ncclIpcSocketInit(world_rank, unique_op_id, &abort_flag);
MPI_Barrier(group_comm);
}
~IpcSocketCommunicator()
{
ncclIpcSocketClose(mSocket);
}
void bcastMemHandle(IpcMemHandle* handle, int root) override
{
if (mGroupRank == root)
{
for (size_t i = 0; i < mGroupRanks.size(); ++i)
{
if (i != root)
{
ncclIpcSocketSendFd(mSocket, handle->fd, mGroupRanks[i]);
}
}
MPI_Barrier(mGroupComm);
}
else
{
MPI_Barrier(mGroupComm);
handle->fd = ncclIpcSocketRecvFd(mSocket);
}
MPI_Barrier(mGroupComm);
}
private:
int mGroupRank;
std::vector<int> mGroupRanks;
MPI_Comm mGroupComm;
std::shared_ptr<NcclIpcSocket> mSocket;
};
class IpcFabricCommunicator : public IpcCommunicator
{
public:
IpcFabricCommunicator(MPI_Comm group_comm)
: mGroupComm(group_comm)
{
}
~IpcFabricCommunicator() = default;
void bcastMemHandle(IpcMemHandle* handle, int root) override
{
MPI_Bcast(handle, sizeof(CUmemFabricHandle), MPI_BYTE, root, mGroupComm);
}
private:
MPI_Comm mGroupComm;
};
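The two communicators reflect how the underlying handle types travel: a CUmemFabricHandle is an opaque, fixed-size blob that can be broadcast directly over MPI, whereas a POSIX file descriptor is only meaningful inside the exporting process and has to be duplicated into each peer through a Unix-domain socket (which is what the NCCL-style IPC socket helpers do), hence the per-peer sends and extra barriers in the socket path.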
class NVLSCudaAllocator
{
public:
static IpcNvlsHandle* allocate(size_t size, std::vector<int> ranks)
{
auto nvls_handle = new IpcNvlsHandle();
// Create a new communicator for the subset of ranks.
MPI_Group world_group, new_group;
MPI_Comm new_comm;
// Get the group of the world communicator.
MPI_Comm_group(COMM_SESSION, &world_group);
// Create a new group containing only the ranks we want.
MPI_Group_incl(world_group, ranks.size(), ranks.data(), &new_group);
// Create a new communicator from the group.
MPI_Comm_create_group(COMM_SESSION, new_group, 0, &new_comm);
// Get rank and group rank.
int world_rank;
int group_rank;
MPI_Comm_rank(COMM_SESSION, &world_rank);
MPI_Comm_rank(new_comm, &group_rank);
// Get runtime and driver device IDs.
int device_id;
int CU_dev;
TLLM_CUDA_CHECK(cudaGetDevice(&device_id));
CUCHECK(cuDeviceGet(&CU_dev, device_id));
// Get handle type used to share memory handles between devices.
auto handle_type = getMemHandleType();
// Define allocation access permissions (same for unicast and multicast).
CUmemAccessDesc access_desc;
memset(&access_desc, 0, sizeof(CUmemAccessDesc));
access_desc.location.id = device_id;
access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
// Define unicast allocation properties.
CUmemAllocationProp prop;
memset(&prop, 0, sizeof(CUmemAllocationProp));
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = device_id;
prop.requestedHandleTypes = handle_type;
// Define multicast allocation properties.
CUmulticastObjectProp mcprop;
memset(&mcprop, 0, sizeof(CUmulticastObjectProp));
mcprop.numDevices = ranks.size();
mcprop.handleTypes = handle_type;
mcprop.flags = 0;
// Round up allocation size to the nearest multiple of the unicast allocation granularity.
size_t granularity = 0;
CUCHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
size = ROUND_UP(size, granularity);
// Round up allocation size to the nearest multiple of the multicast allocation granularity.
size_t mc_granularity = 0;
CUCHECK(cuMulticastGetGranularity(&mc_granularity, &mcprop, CU_MULTICAST_GRANULARITY_MINIMUM));
size = ROUND_UP(size, mc_granularity);
mcprop.size = size;
nvls_handle->size = size;
// Allocate physical pages of memory on GPU.
CUCHECK(cuMemCreate(&nvls_handle->uc_handle, size, &prop, 0));
// Reserve unicast virtual address space for the memory.
CUCHECK(cuMemAddressReserve(&nvls_handle->uc_va, size, granularity, 0U, 0));
// Map the unicast virtual address space to the physical pages.
CUCHECK(cuMemMap(nvls_handle->uc_va, size, 0, nvls_handle->uc_handle, 0));
// Set the access permissions for the unicast memory.
CUCHECK(cuMemSetAccess(nvls_handle->uc_va, size, &access_desc, 1));
nvls_handle->uc_ptr = reinterpret_cast<uintptr_t>((void*) nvls_handle->uc_va);
// Setup communicator needed for multicast and unicast pointer exchange.
std::shared_ptr<IpcCommunicator> ipc_communicator;
if (handle_type == CU_MEM_HANDLE_TYPE_FABRIC)
{
ipc_communicator = std::make_shared<IpcFabricCommunicator>(new_comm);
}
else
{
ipc_communicator = std::make_shared<IpcSocketCommunicator>(world_rank, group_rank, ranks, new_comm);
}
// Unicast pointer exchange between ranks.
IpcMemHandle ipc_handle;
CUCHECK(cuMemExportToShareableHandle((void*) &ipc_handle, nvls_handle->uc_handle, handle_type, 0 /*flags*/));
nvls_handle->ipc_uc_ptrs.resize(ranks.size());
nvls_handle->ipc_uc_vas.resize(ranks.size());
nvls_handle->ipc_uc_handles.resize(ranks.size());
for (int i = 0; i < ranks.size(); i++)
{
IpcMemHandle peer_ipc_handle = ipc_handle;
ipc_communicator->bcastMemHandle(&peer_ipc_handle, i);
if (i != group_rank)
{
void* os_handle
= handle_type == CU_MEM_HANDLE_TYPE_FABRIC ? (void*) &peer_ipc_handle : (void*) peer_ipc_handle.fd;
CUCHECK(cuMemImportFromShareableHandle(&nvls_handle->ipc_uc_handles[i], os_handle, handle_type));
// Reserve peer unicast virtual address space for the memory.
CUCHECK(cuMemAddressReserve(&nvls_handle->ipc_uc_vas[i], size, granularity, 0U, 0));
// Map the peer unicast virtual address space to the physical pages.
CUCHECK(cuMemMap(nvls_handle->ipc_uc_vas[i], size, 0, nvls_handle->ipc_uc_handles[i], 0));
// Set the access permissions for the peer unicast memory.
CUCHECK(cuMemSetAccess(nvls_handle->ipc_uc_vas[i], size, &access_desc, 1));
nvls_handle->ipc_uc_ptrs[i] = reinterpret_cast<uintptr_t>((void*) nvls_handle->ipc_uc_vas[i]);
}
else
{
nvls_handle->ipc_uc_ptrs[i] = nvls_handle->uc_ptr;
nvls_handle->ipc_uc_vas[i] = nvls_handle->uc_va;
nvls_handle->ipc_uc_handles[i] = nvls_handle->uc_handle;
}
}
// Initialize multicast object for all ranks.
if (group_rank == 0)
{
CUCHECK(cuMulticastCreate(&nvls_handle->mc_handle, &mcprop));
// Export the allocation for the importing process.
CUCHECK(cuMemExportToShareableHandle(&ipc_handle, nvls_handle->mc_handle, handle_type, 0 /*flags*/));
ipc_communicator->bcastMemHandle(&ipc_handle, 0);
}
else
{
ipc_communicator->bcastMemHandle(&ipc_handle, 0);
void* os_handle = handle_type == CU_MEM_HANDLE_TYPE_FABRIC ? (void*) &ipc_handle : (void*) ipc_handle.fd;
CUCHECK(cuMemImportFromShareableHandle(&nvls_handle->mc_handle, os_handle, handle_type));
}
// Add device to multicast object
CUCHECK(cuMulticastAddDevice(nvls_handle->mc_handle, CU_dev));
// Bind physical memory to the Multicast group.
// Note: It will block until all ranks have been added to the group.
CUCHECK(cuMulticastBindMem(nvls_handle->mc_handle, 0, nvls_handle->uc_handle, 0, size, 0));
// Reserve multicast virtual address space for the memory.
CUCHECK(cuMemAddressReserve(&nvls_handle->mc_va, size, mc_granularity, 0U, 0));
// Map the multicast virtual address space to the physical pages.
CUCHECK(cuMemMap(nvls_handle->mc_va, size, 0, nvls_handle->mc_handle, 0));
// Set the access permissions for the multicast memory.
CUCHECK(cuMemSetAccess(nvls_handle->mc_va, size, &access_desc, 1 /* count */));
nvls_handle->mc_ptr = reinterpret_cast<uintptr_t>((void*) nvls_handle->mc_va);
// Clean up
MPI_Group_free(&new_group);
MPI_Group_free(&world_group);
return nvls_handle;
}
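// Illustrative usage sketch (not part of the implementation; the rank set and size are
// hypothetical): allocate a multicast buffer shared by two ranks, then release it.
//   std::vector<int> ranks = {0, 1};
//   IpcNvlsHandle* handle = NVLSCudaAllocator::allocate(1 << 20, ranks);
//   // ... write through handle->uc_ptr, reduce through handle->mc_ptr ...
//   NVLSCudaAllocator::free(handle);
//   delete handle;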
static void free(IpcNvlsHandle* nvls_handle)
{
CUCHECK(cuMemUnmap(nvls_handle->mc_va, nvls_handle->size));
CUCHECK(cuMemRelease(nvls_handle->mc_handle));
CUCHECK(cuMemAddressFree(nvls_handle->mc_va, nvls_handle->size));
for (size_t i = 0; i < nvls_handle->ipc_uc_vas.size(); ++i)
{
CUCHECK(cuMemUnmap(nvls_handle->ipc_uc_vas[i], nvls_handle->size));
CUCHECK(cuMemRelease(nvls_handle->ipc_uc_handles[i]));
CUCHECK(cuMemAddressFree(nvls_handle->ipc_uc_vas[i], nvls_handle->size));
}
}
private:
static CUmemAllocationHandleType getMemHandleType()
{
int device_id;
TLLM_CUDA_CHECK(cudaGetDevice(&device_id));
// Check if fabric handle support is available.
int fabric_supported = 0;
CUCHECK(cuDeviceGetAttribute(&fabric_supported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, device_id));
if (!fabric_supported)
{
TLLM_LOG_TRACE(
"checking fabric support... device does not report CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED.");
return CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
}
nvmlDevice_t nvml_device;
nvmlGpuFabricInfo_t fabric_info;
NVMLCHECK(nvmlInit_v2());
NVMLCHECK(nvmlDeviceGetHandleByIndex(device_id, &nvml_device));
NVMLCHECK(nvmlDeviceGetGpuFabricInfo(nvml_device, &fabric_info));
NVMLCHECK(nvmlShutdown());
// Check if the fabric is fully initialized.
if (fabric_info.state != NVML_GPU_FABRIC_STATE_COMPLETED || fabric_info.status != NVML_SUCCESS)
{
TLLM_LOG_TRACE("checking fabric support... fabric state is NOT COMPLETE: state=%u status=%u.",
fabric_info.state, fabric_info.status);
return CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
}
// Check that fabric handles can be created.
CUmemAllocationProp prop;
memset(&prop, 0, sizeof(CUmemAllocationProp));
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = device_id;
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
size_t alloc_size = 1024; // anything > 0
size_t min_gran = 0;
CUCHECK(cuMemGetAllocationGranularity(&min_gran, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
alloc_size = ROUND_UP(alloc_size, min_gran);
CUmemGenericAllocationHandle handle;
CUresult err = cuMemCreate(&handle, alloc_size, &prop, 0);
if (err == CUDA_ERROR_NOT_PERMITTED || err == CUDA_ERROR_NOT_SUPPORTED)
{
TLLM_LOG_TRACE("checking fabric support... cuMemCreate failed with not %s.",
err == CUDA_ERROR_NOT_PERMITTED ? "permitted" : "supported");
return CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
}
else
{
CUCHECK(err);
}
// Check if fabric handles can be exported & imported by IMEX (Internode Memory Exchange)
CUmemFabricHandle fh;
err = cuMemExportToShareableHandle(&fh, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0);
if (err != CUDA_SUCCESS
|| (err = cuMemImportFromShareableHandle(&handle, &fh, CU_MEM_HANDLE_TYPE_FABRIC)) != CUDA_SUCCESS)
{
TLLM_LOG_TRACE("checking fabric support... cuMemExport/cuMemImport failed.");
CUCHECK(cuMemRelease(handle));
return CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
}
TLLM_LOG_TRACE("fabric status: state=%u status=%u clique=%u", device_id, fabric_info.state, fabric_info.status,
fabric_info.cliqueId);
CUCHECK(cuMemRelease(handle));
// If we get here, fabric handles are supported.
return CU_MEM_HANDLE_TYPE_FABRIC;
}
};
#endif
/**
* @brief MPI_Barrier over a subset of ranks.
*/
void MPI_group_barrier(std::set<int> group)
{
#if ENABLE_MULTI_DEVICE
// Create a new communicator for the subset of ranks
MPI_Group world_group, new_group;
MPI_Comm new_comm;
// Get the group of the world communicator
MPI_Comm_group(MPI_COMM_WORLD, &world_group);
// Create a new group containing only the ranks we want
std::vector<int> ranks(group.begin(), group.end());
MPI_Group_incl(world_group, ranks.size(), ranks.data(), &new_group);
// Create a new communicator from the group
MPI_Comm_create_group(MPI_COMM_WORLD, new_group, 0, &new_comm);
// Use the new communicator for the barrier
MPI_Barrier(new_comm);
// Clean up
MPI_Group_free(&new_group);
MPI_Group_free(&world_group);
MPI_Comm_free(&new_comm);
#else
TLLM_THROW("MPI_group_barrier needs to be compiled with ENABLE_MULTI_DEVICE");
#endif
}
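// Illustrative call (the rank set is hypothetical): block until ranks 0-3 have all reached this point.
//   MPI_group_barrier({0, 1, 2, 3});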
bool ipcNvlsSupported()
{
#if ENABLE_MULTI_DEVICE
CUdevice current_dev;
int cuda_driver_version = -1;
int dev_count = 0;
TLLM_CUDA_CHECK(cudaDriverGetVersion(&cuda_driver_version));
if (cuda_driver_version < 12010)
{
TLLM_LOG_ERROR("CUDA driver version %d < 12010; NVLS multicast requires CUDA 12.1 or newer.", cuda_driver_version);
return false;
}
TLLM_CUDA_CHECK(cudaGetDeviceCount(&dev_count));
// Check multicast support on every visible device, not just the current one.
for (int i = 0; i < dev_count; ++i)
{
CUCHECK(cuDeviceGet(&current_dev, i));
int multicast_supported = 0;
CUCHECK(cuDeviceGetAttribute(&multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, current_dev));
if (!multicast_supported)
{
TLLM_LOG_ERROR("CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED not supported on GPU%d.", i);
return false;
}
}
return true;
#else
return false;
#endif
}
IpcNvlsHandle* ipcNvlsAllocate(size_t size, std::set<int> group)
{
#if ENABLE_MULTI_DEVICE
TLLM_CHECK_WITH_INFO(ipcNvlsSupported(), "NVSwitch multicast (NVLS) is not supported on this system.");
TLLM_CHECK(size > 0);
TLLM_CHECK(group.size() >= 2);
std::vector<int> ranks(group.begin(), group.end());
int group_size = ranks.size();
MPI_Comm mpi_comm = COMM_SESSION;
// Create a new communicator with only the ranks in the group
MPI_Group world_group, new_group;
MPI_Comm_group(mpi_comm, &world_group);
MPI_Group_incl(world_group, group_size, ranks.data(), &new_group);
MPI_Comm new_comm;
MPI_Comm_create_group(mpi_comm, new_group, 0, &new_comm);
#if ENABLE_NVSHMEM
// Initialize NVSHMEM with the new communicator
nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER;
attr.mpi_comm = &new_comm;
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr);
// Allocate NVSHMEM memory
void* ptr = nvshmem_malloc(size);
// Create handle to return
auto handle = new IpcNvlsHandle();
handle->size = size;
handle->uc_ptr = reinterpret_cast<uintptr_t>(ptr);
handle->mc_ptr = reinterpret_cast<uintptr_t>(nvshmemx_mc_ptr(NVSHMEM_TEAM_WORLD, ptr));
for (int i = 0; i < ranks.size(); i++)
{
handle->ipc_uc_ptrs.push_back(reinterpret_cast<uintptr_t>(nvshmem_ptr(ptr, i)));
}
#else // !ENABLE_NVSHMEM
auto handle = NVLSCudaAllocator::allocate(size, ranks);
#endif
TLLM_LOG_INFO("Rank %d NVLS allocate %zu bytes, uc_ptr:%p mc_ptr:%p", COMM_SESSION.getRank(), size,
(void*) handle->uc_ptr, (void*) handle->mc_ptr);
// Cleanup
MPI_Group_free(&new_group);
MPI_Group_free(&world_group);
MPI_Barrier(new_comm);
return handle;
#else
TLLM_THROW("ipcNvlsAllocate needs to be compiled with ENABLE_MULTI_DEVICE");
#endif
}
void ipcNvlsFree(IpcNvlsHandle* handle)
{
#if ENABLE_MULTI_DEVICE
if (handle == nullptr)
{
return;
}
#if ENABLE_NVSHMEM
nvshmem_free((void*) handle->uc_ptr);
#else
NVLSCudaAllocator::free(handle);
#endif
delete handle;
#else
TLLM_THROW("ipcNvlsFree needs to be compiled with ENABLE_MULTI_DEVICE");
#endif
}
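// Minimal end-to-end sketch of the public API above (illustrative; the group and size are hypothetical):
//   std::set<int> group = {0, 1, 2, 3};
//   if (ipcNvlsSupported())
//   {
//       IpcNvlsHandle* handle = ipcNvlsAllocate(1 << 20, group);
//       // handle->uc_ptr is this rank's unicast mapping; handle->mc_ptr is the multicast mapping.
//       ipcNvlsFree(handle);
//   }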
} // namespace tensorrt_llm::runtime

View File

@ -72,6 +72,7 @@ struct NcclIpcSocket
int fd;
char socketName[NCCL_IPC_SOCKNAME_LEN];
uint32_t volatile* abortFlag;
uint64_t hash;
};
/*
@ -89,6 +90,7 @@ std::shared_ptr<NcclIpcSocket> ncclIpcSocketInit(int rank, uint64_t hash, uint32
TLLM_NCCL_CHECK(ncclInternalError); // throws
}
handle->hash = hash;
handle->fd = -1;
handle->socketName[0] = '\0';
if ((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0)
@ -310,9 +312,9 @@ void ncclIpcSocketSendMsg(
}
}
void ncclIpcSocketSendFd(std::shared_ptr<NcclIpcSocket> handle, int sendFd, int rank, uint64_t hash)
void ncclIpcSocketSendFd(std::shared_ptr<NcclIpcSocket> handle, int sendFd, int rank)
{
ncclIpcSocketSendMsg(handle, NULL, 0, sendFd, rank, hash);
ncclIpcSocketSendMsg(handle, NULL, 0, sendFd, rank, handle->hash);
}
} // namespace tensorrt_llm::runtime

View File

@ -33,7 +33,7 @@ void ncclIpcSocketClose(std::shared_ptr<NcclIpcSocket> handle);
int ncclIpcSocketRecvFd(std::shared_ptr<NcclIpcSocket> handle);
void ncclIpcSocketSendFd(std::shared_ptr<NcclIpcSocket> handle, int fd, int rank, uint64_t hash);
void ncclIpcSocketSendFd(std::shared_ptr<NcclIpcSocket> handle, int fd, int rank);
} // namespace tensorrt_llm::runtime

View File

@ -803,7 +803,6 @@ public:
if (mCapacity < newSize)
{
release();
printf("MulticastBuffer resize: %d B\n", int(toBytes(newSize)));
mHandle = ipcNvlsAllocate(toBytes(newSize), mRanks);
TLLM_CHECK(mHandle->size % BufferDataType(mType).getSize() == 0);

View File

@ -45,10 +45,12 @@ add_gtest(mlaChunkedPrefillTest mlaChunkedPrefillTest.cu)
if(NOT ENABLE_MULTI_DEVICE EQUAL 0)
add_gtest(allReduceKernelTest allReduce/allReduceKernelTest.cu)
add_gtest(allReduceFusionTest allReduce/allReduceFusionTest.cu)
# add_gtest(gemmAllReduceTest allReduce/gemmAllReduceTest.cu)
# if(USING_OSS_CUTLASS_ALLREDUCE_GEMM) target_link_libraries(gemmAllReduceTest
# PRIVATE ar_gemm_src) target_compile_definitions(gemmAllReduceTest PRIVATE
# USING_OSS_CUTLASS_ALLREDUCE_GEMM) endif()
add_gtest(gemmAllReduceTest allReduce/gemmAllReduceTest.cu)
if(USING_OSS_CUTLASS_ALLREDUCE_GEMM)
target_link_libraries(gemmAllReduceTest PRIVATE ar_gemm_src)
target_compile_definitions(gemmAllReduceTest
PRIVATE USING_OSS_CUTLASS_ALLREDUCE_GEMM)
endif()
endif()
add_gtest(

View File

@ -914,10 +914,14 @@ int main(int argc, char** argv)
notSupported = true;
}
TLLM_CUDA_CHECK(cudaSetDevice(COMM_SESSION.getRank()));
int device_count;
TLLM_CUDA_CHECK(cudaGetDeviceCount(&device_count));
int device_id = COMM_SESSION.getRank() % device_count;
TLLM_CUDA_CHECK(cudaSetDevice(device_id));
cudaDeviceProp props;
TLLM_CUDA_CHECK(cudaGetDeviceProperties(&props, COMM_SESSION.getRank()));
TLLM_CUDA_CHECK(cudaGetDeviceProperties(&props, device_id));
if (props.major < 9)
{

View File

@ -2058,6 +2058,7 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
multiGpuJobs = parallelJobs.findAll{(it.key.contains("4_GPUs") || it.key.contains("8_GPUs")) && !it.key.contains("Post-Merge")}
println multiGpuJobs.keySet()
multiGpuJobsPostMerge = parallelJobs.findAll{(it.key.contains("4_GPUs") || it.key.contains("8_GPUs")) && it.key.contains("Post-Merge")}
parallelJobs += docBuildJobs
parallelJobs += sanityCheckJobs
@ -2105,7 +2106,11 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
// Check --only-multi-gpu-test, if true, only run multi-GPU test stages.
if (testFilter[(ONLY_MULTI_GPU_TEST)]) {
parallelJobsFiltered = multiGpuJobs
if (testFilter[(IS_POST_MERGE)]) {
parallelJobsFiltered = multiGpuJobsPostMerge
} else {
parallelJobsFiltered = multiGpuJobs
}
}
// Check --disable-multi-gpu-test, if true, remove multi-GPU test stages.

View File

@ -681,7 +681,28 @@ def main(*,
new_library_path = "/usr/local/cuda/compat:/usr/local/cuda/compat/lib:/usr/local/cuda/compat/lib.real"
if 'LD_LIBRARY_PATH' in env_ld:
new_library_path += f":{env_ld['LD_LIBRARY_PATH']}"
result = build_run("find /usr -name '*libnvidia-ml.so*'",
capture_output=True,
text=True)
assert result.returncode == 0, f"Failed to run find *libnvidia-ml.so*: {result.stderr}"
# Build containers only contain the stub version of libnvidia-ml.so, not the real driver library.
# If the real library is not present on the system, create a symbolic link to the stub to prevent import errors.
if "libnvidia-ml.so.1" not in result.stdout:
if "libnvidia-ml.so" in result.stdout:
line = result.stdout.splitlines()[0]
path = os.path.dirname(line)
new_library_path += f":{path}"
build_run(f"ln -s {line} {path}/libnvidia-ml.so.1")
else:
print(
f"Failed to find libnvidia-ml.so: {result.stderr}",
file=sys.stderr)
exit(1)
env_ld["LD_LIBRARY_PATH"] = new_library_path
build_run(
f"\"{venv_python}\" -m pybind11_stubgen -o . bindings --exit-code",
env=env_ld)

View File

@ -765,8 +765,7 @@ class TestLlama3_1_8B(CliFlowAccuracyTestHarness):
tp_size=4,
extra_build_args=extra_build_args)
@skip_pre_ada
@skip_post_blackwell
@skip_pre_hopper
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize(
"gemm_allreduce", [False, pytest.param(True, marks=skip_no_nvls)],

View File

@ -7,6 +7,7 @@ from typing import List, Optional
import defs.cpp.cpp_common as _cpp
import pytest
from defs.conftest import skip_no_nvls
# Helper filter for disagg google tests
@ -65,6 +66,28 @@ def run_mpi_utils_tests(build_dir, timeout=300):
timeout=timeout)
def run_gemm_allreduce_tests(build_dir, nprocs, timeout=300):
tests_dir = build_dir / "tests"
mgpu_env = get_multi_gpu_env()
gemm_allreduce_test = [
"mpirun",
"-n",
f"{nprocs}",
"--allow-run-as-root",
"unit_tests/kernels/gemmAllReduceTest",
"--m=2032",
"--n=8200",
"--k=1024",
"--iterations=1",
]
_cpp.run_command(gemm_allreduce_test,
cwd=tests_dir,
env=mgpu_env,
timeout=timeout)
def run_cache_transceiver_tests(build_dir: _pl.Path,
nprocs=2,
kv_cache_type=KVCacheType.MPI,
@ -451,6 +474,15 @@ def test_mpi_utils(build_google_tests, build_dir):
run_mpi_utils_tests(build_dir, timeout=300)
@skip_no_nvls
@pytest.mark.parametrize("build_google_tests", ["90", "100"], indirect=True)
@pytest.mark.parametrize("nprocs", [2, 4], ids=["2proc", "4proc"])
def test_fused_gemm_allreduce(build_google_tests, nprocs, build_dir):
if platform.system() != "Windows":
run_gemm_allreduce_tests(build_dir, nprocs, timeout=300)
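# Illustrative manual invocation (hypothetical paths): run just this stage for SM90 builds on 4 processes:
#   pytest "cpp/test_multi_gpu.py::test_fused_gemm_allreduce[4proc-90]"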
@pytest.mark.parametrize("build_google_tests", ["80", "86", "89", "90"],
indirect=True)
@pytest.mark.parametrize("kvcache_type", [KVCacheType.MPI, KVCacheType.UCX],

View File

@ -3368,8 +3368,6 @@ def test_llm_llama_v3_2_smoothquant_1node_single_gpu(
venv_check_call(llm_venv, summary_cmd)
# TODO: remove skip after support fp8 rowwise gemm on B200
@skip_post_blackwell
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("fp8_quant",

View File

@ -118,6 +118,7 @@ l0_dgx_h100:
tests:
# ------------- CPP tests ---------------
- cpp/test_multi_gpu.py::test_mpi_utils[90]
- cpp/test_multi_gpu.py::test_fused_gemm_allreduce[4proc-90]
- cpp/test_multi_gpu.py::test_cache_transceiver[2proc-mpi_kvcache-90]
- cpp/test_multi_gpu.py::test_cache_transceiver[2proc-ucx_kvcache-90]
- cpp/test_multi_gpu.py::test_cache_transceiver[8proc-mpi_kvcache-90]

View File

@ -335,7 +335,6 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd
disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5247271)
unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5274229)
unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep4-enable_adp-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5274229)
accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[enable_gemm_allreduce_plugin] SKIP (https://nvbugs/5247786)
full:B200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
full:B200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
full:B200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
@ -390,7 +389,6 @@ full:B200/examples/test_gemma.py::test_llm_gemma_1gpu_summary_vswa[gemma-3-1b-it
full:B200/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5295470)
examples/test_mistral.py::test_llm_mistral_v1_1gpu[mistral-7b-v0.1-float16-max_attention_window_size_4096-summarization_long] SKIP (https://nvbugs/5324976)
examples/test_prompt_lookup.py::test_llm_prompt_lookup_1gpu[no_streaming-gpt2-use_cpp_session-use_tokens-max_matching_ngram_size_2-prompt_lookup_num_tokens_8-float16-bs1] SKIP (https://nvbugs/5344070)
examples/test_llama.py::test_llm_llama_v3_1_1node_multi_gpus[enable_gemm_allreduce_plugin-llama-3.1-70b-disable_fp8] SKIP (https://nvbugs/5343850)
examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_py_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1] SKIP (https://nvbugs/5333849)
examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5333818)
examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5333818)