mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00

Fix GEMM+AR fusion on Blackwell (#5563)

Signed-off-by: xsimmons <xsimmons@nvidia.com>

This commit is contained in:
parent a79b73f577
commit b6013da198
@@ -45,6 +45,7 @@ option(ENABLE_MULTI_DEVICE
option(ENABLE_UCX "Enable building with UCX (Uniform Communication X) support"
ON)
option(NVRTC_DYNAMIC_LINKING "Link against the dynamic NVRTC libraries" OFF)
option(ENABLE_NVSHMEM "Enable building with NVSHMEM support" OFF)
option(USING_OSS_CUTLASS_LOW_LATENCY_GEMM
"Using open sourced Cutlass low latency gemm kernel" ON)
option(USING_OSS_CUTLASS_FP4_GEMM "Using open sourced Cutlass fp4 gemm kernel"
@@ -54,6 +55,8 @@ option(USING_OSS_CUTLASS_MOE_GEMM "Using open sourced Cutlass moe gemm kernel"
option(USING_OSS_CUTLASS_ALLREDUCE_GEMM
"Using open sourced Cutlass AR gemm kernel" ON)

message(STATUS "ENABLE_NVSHMEM is ${ENABLE_NVSHMEM}")

if(NVTX_DISABLE)
add_compile_definitions("NVTX_DISABLE")
message(STATUS "NVTX is disabled")
@@ -171,6 +174,7 @@ message(STATUS "CUDA library status:")
message(STATUS " version: ${CUDAToolkit_VERSION}")
message(STATUS " libraries: ${CUDAToolkit_LIBRARY_DIR}")
message(STATUS " include path: ${CUDAToolkit_INCLUDE_DIRS}")
message(STATUS "CUDA_NVML_LIB: ${CUDA_NVML_LIB}")

# Prevent CMake from creating a response file for CUDA compiler, so clangd can
# pick up on the includes
@@ -262,9 +266,21 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_SYSTEM=cmake_oss ")
# note: cmake expr generation $<BOOL:${ENABLE_MULTI_DEVICE}> is a build time
# evaluation so hard to debug at cmake time
if(ENABLE_MULTI_DEVICE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_MULTI_DEVICE=1")
# Add target definitions for both C++ and CUDA
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_MULTI_DEVICE=1>
$<$<COMPILE_LANGUAGE:CUDA>:ENABLE_MULTI_DEVICE=1>)
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_MULTI_DEVICE=0")
# Add target definitions for both C++ and CUDA
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_MULTI_DEVICE=0>
$<$<COMPILE_LANGUAGE:CUDA>:ENABLE_MULTI_DEVICE=0>)
endif()

if(ENABLE_NVSHMEM)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_NVSHMEM=1>
$<$<COMPILE_LANGUAGE:CUDA>:ENABLE_NVSHMEM=1>)
else()
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:ENABLE_NVSHMEM=0>
$<$<COMPILE_LANGUAGE:CUDA>:ENABLE_NVSHMEM=0>)
endif()

# Fix linking issue with TRT 10, the detailed description about `--mcmodel` can

@@ -72,6 +72,12 @@ if(ENABLE_MULTI_DEVICE)
include_directories(${MPI_C_INCLUDE_DIRS})
endif()

if(ENABLE_NVSHMEM)
# Add hints for aarch64
find_package(NVSHMEM REQUIRED HINTS /usr/lib/sbsa-linux-gnu/cmake/nvshmem/)
include_directories(/usr/include/nvshmem/)
endif()

if(NOT WIN32)
set(DECODER_SHARED_TARGET_0 decoder_attention_0)
set(DECODER_SHARED_TARGET_1 decoder_attention_1)
@@ -231,7 +237,10 @@ if(ENABLE_MULTI_DEVICE)
set(TRTLLM_LINK_LIBS ${TRTLLM_LINK_LIBS} ${MPI_C_LIBRARIES} ${NCCL_LIB})
endif()

message("TRTLLM_LINK_LIBS: ${TRTLLM_LINK_LIBS}")
if(ENABLE_NVSHMEM)
set(TRTLLM_LINK_LIBS ${TRTLLM_LINK_LIBS} nvshmem::nvshmem_host
nvshmem::nvshmem_device)
endif()

if(NOT WIN32) # Unix-like compilers
set(UNDEFINED_FLAG "-Wl,--no-undefined")

@@ -332,12 +332,12 @@ enum class ClusterShape
ClusterShape_1x2x1,
ClusterShape_2x2x1,
ClusterShape_1x4x1,
ClusterShape_4x1x1,
ClusterShape_4x2x1,
ClusterShape_2x4x1,
ClusterShape_4x4x1,
ClusterShape_1x8x1,
ClusterShape_8x1x1,
ClusterShape_4x1x1
ClusterShape_8x1x1
};

static auto get_cluster_shape_name(ClusterShape Shape_MNK)

@@ -22,6 +22,8 @@

#include "cutlass/barrier.h"

#include <cuda/atomic>

namespace cutlass
{

@@ -43,7 +45,7 @@ __forceinline__ __device__ uint32_t atomicCAS_system_acq(uint32_t* p, uint32_t c

} // namespace detail

template <class Sync, bool SafeBetweenPhases, bool UseMembarGPU>
template <class Sync, bool SafeBetweenPhases>
struct MulticastSystemBarrier : public GenericBarrier<Sync>
{

@@ -57,8 +59,8 @@ struct MulticastSystemBarrier : public GenericBarrier<Sync>

protected:
/// Reduce into flag, with release pattern (int specialization)
CUTLASS_DEVICE
static void red_release(T* mc_ptr, int val)
template <cuda::thread_scope Scope>
CUTLASS_DEVICE static void red_release(T* mc_ptr, int val)
{
#if defined(CUTE_ARCH_MULTIMEM_SM90_ENABLED)
// atomic reduction to all replicas
@@ -66,14 +68,18 @@ protected:
// See
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-multimem-ld-reduce-multimem-st-multimem-red
// for multimem PTX doc
if constexpr (UseMembarGPU)
if constexpr (Scope == cuda::thread_scope::thread_scope_device)
{
asm volatile("multimem.red.release.gpu.global.add.u32 [%0], %1;" ::"l"(mc_ptr), "r"(val) : "memory");
}
else
else if constexpr (Scope == cuda::thread_scope::thread_scope_system)
{
asm volatile("multimem.red.release.sys.global.add.u32 [%0], %1;" ::"l"(mc_ptr), "r"(val) : "memory");
}
else
{
CUTE_INVALID_CONTROL_PATH("Invalid thread scope for MulticastSystemBarrier.");
}

// Need a fence between MC and UC access to the same memory:
// - fence.proxy instructions establish an ordering between memory accesses that may happen through different
@@ -128,8 +134,8 @@ public:
Sync::sync();
}

CUTLASS_DEVICE
static T arrive_inc_get(T* mc_ptr, T* uc_ptr, int thread_idx, int flag_idx, int rank, int world_size)
template <cuda::thread_scope Scope>
CUTLASS_DEVICE static T arrive_inc_get(T* mc_ptr, T* uc_ptr, int thread_idx, int flag_idx, int rank, int world_size)
{
T* mc_barrier_ptr = mc_ptr + flag_idx;
T* uc_barrier_ptr = uc_ptr + flag_idx;
@@ -156,13 +162,13 @@ public:
// can be immediately reused.
bool master = rank == 0;
int val = master ? 0x80000000 - (world_size - 1) : 1;
red_release(mc_barrier_ptr, val);
red_release<Scope>(mc_barrier_ptr, val);
}
return old_arrive;
}

CUTLASS_DEVICE
static void arrive_inc(Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
template <cuda::thread_scope Scope = cuda::thread_scope::thread_scope_system>
CUTLASS_DEVICE static void arrive_inc(Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
{
T* mc_barrier = params.mc_barrier_ptr + flag_idx;

@@ -170,23 +176,24 @@ public:

if (thread_idx == 0)
{
red_release(mc_barrier, 1);
red_release<Scope>(mc_barrier, 1);
}
}

CUTLASS_DEVICE
static void arrive_and_wait(Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
template <cuda::thread_scope Scope = cuda::thread_scope::thread_scope_system>
CUTLASS_DEVICE static void arrive_and_wait(
Params const& params, int thread_idx, int flag_idx, int rank, int world_size)
{
auto mc_ptr = params.mc_barrier_ptr;
auto uc_ptr = params.uc_barrier_ptr;
if constexpr (SafeBetweenPhases)
{
auto old_arrive = arrive_inc_get(mc_ptr, uc_ptr, thread_idx, flag_idx, rank, world_size);
auto old_arrive = arrive_inc_get<Scope>(mc_ptr, uc_ptr, thread_idx, flag_idx, rank, world_size);
wait(old_arrive, uc_ptr, thread_idx, flag_idx);
}
else
{
arrive_inc(params, thread_idx, flag_idx, rank, world_size);
arrive_inc<Scope>(params, thread_idx, flag_idx, rank, world_size);
wait_eq_reset(uc_ptr, thread_idx, flag_idx, world_size);
}
}

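The UseMembarGPU boolean is replaced by a cuda::thread_scope template parameter, so each call site now states whether a GPU-scoped release (multimem.red.release.gpu) or a system-scoped one (multimem.red.release.sys) is required. As a rough sketch only, not the CUTLASS code path above (which issues multimem PTX on a multicast pointer), the same scope distinction can be expressed with the libcu++ atomics pulled in by the new #include <cuda/atomic>:

#include <cuda/atomic>

// Illustrative sketch: a release-ordered "arrive" on a plain flag, parameterized on the
// same cuda::thread_scope values the barrier now takes. thread_scope_device roughly
// corresponds to multimem.red.release.gpu, thread_scope_system to multimem.red.release.sys.
template <cuda::thread_scope Scope>
__global__ void arrive_release(unsigned int* flag, unsigned int val)
{
    cuda::atomic_ref<unsigned int, Scope> ref(*flag);
    ref.fetch_add(val, cuda::std::memory_order_release);
}

// Hypothetical launches choosing the scope per call site, mirroring how this diff
// uses arrive_inc<thread_scope_device> and arrive_and_wait<thread_scope_system>:
// arrive_release<cuda::thread_scope_device><<<1, 32>>>(flag, 1);
// arrive_release<cuda::thread_scope_system><<<1, 32>>>(flag, 1);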
@@ -249,11 +249,14 @@ endif()
if(USING_OSS_CUTLASS_ALLREDUCE_GEMM)
add_library(
ar_gemm_src STATIC
${ARGEMM_SRC_CU}
${CMAKE_CURRENT_SOURCE_DIR}/../../runtime/ipcNvlsMemory.cpp)
${ARGEMM_SRC_CU} ${CMAKE_CURRENT_SOURCE_DIR}/../../runtime/ipcNvlsMemory.cu)
target_include_directories(
ar_gemm_src
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../internal_cutlass_kernels/include)
if(ENABLE_NVSHMEM)
target_link_libraries(ar_gemm_src PRIVATE nvshmem::nvshmem_host
nvshmem::nvshmem_device)
endif()
set_cuda_architectures(ar_gemm_src 90 100f)
endif()

@@ -138,7 +138,7 @@ public:
// Epilogue
////////////////
using FusionCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD, float, void, float>;
using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true, true>;
using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true>;
using EpilogueScheduleType = typename MmaAdapter<MmaType, IsFP4>::EpilogueSchedule;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOp

@@ -100,8 +100,7 @@ public:
using RasterOrderOptions =
typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions;

using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true /* Safe across phases */,
true /* membar.gpu */>;
using TileBarrierType = cutlass::MulticastSystemBarrier<cutlass::detail::SyncNoOp, true /* Safe across phases */>;

// 16B alignment for TMA
static constexpr int AlignmentA = 16 / sizeof(ElementA);

@@ -201,7 +201,7 @@ public:
auto [M, N, K, L] = problem_shape;
auto [m, n, k, l] = tile_coord;

if (!tile_valid(m, n) || params_ptr->world_size == 1)
if (!tile_valid(m, n) || params_ptr->world_size <= 2)
{
return; // nothing to do
}
@@ -212,7 +212,7 @@ public:

// Wait for all multicast writes to be visible to us.
// This is safe between phases.
SystemBarrier::arrive_and_wait(
SystemBarrier::arrive_and_wait<cuda::thread_scope::thread_scope_system>(
params_ptr->barrier_params_final_sync, thread_idx, tile_index, params_ptr->rank, params_ptr->world_size);
}

@ -297,13 +297,20 @@ public:
|
||||
Tensor tGR_gD1_vec = zipped_divide(tGR_gD1(_, _, _, red_m, red_n), Vec);
|
||||
Tensor tRG_gOut_vec = zipped_divide(tRG_gOut(_, _, _, red_m, red_n), Vec);
|
||||
|
||||
auto pred_fn
|
||||
= [&](auto const&... coords) { return elem_less(tGR_pD_vec(_0{}, coords...), problem_shape); };
|
||||
// Create predicate tensor for bounds checking
|
||||
Tensor pred_tensor = make_tensor<bool>(make_shape(size(tGR_pD_vec)), Stride<_1>{});
|
||||
|
||||
// Set predicate values based on coordinate bounds
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred_tensor); ++i)
|
||||
{
|
||||
pred_tensor(i) = elem_less(tGR_pD_vec(_0{}, i), problem_shape);
|
||||
}
|
||||
|
||||
// Read from self.
|
||||
cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_gD0_vec, tGR_rD0_vec);
|
||||
cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_gD0_vec, tGR_rD0_vec);
|
||||
// Read from remote.
|
||||
cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_gD1_vec, tGR_rD1_vec);
|
||||
cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_gD1_vec, tGR_rD1_vec);
|
||||
// Reduce
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(tGR_rD0_vec); i++)
|
||||
@ -311,7 +318,7 @@ public:
|
||||
tGR_rD0_vec(i) += tGR_rD1_vec(i);
|
||||
}
|
||||
// store to self.
|
||||
cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_rD0_vec, tRG_gOut_vec);
|
||||
cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_rD0_vec, tRG_gOut_vec);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -386,13 +393,21 @@ public:
|
||||
Tensor tGR_gD_vec = zipped_divide(tGR_gD(_, _, _, red_m, red_n), Vec);
|
||||
Tensor tRG_gD_vec = zipped_divide(tRG_gD(_, _, _, red_m, red_n), Vec);
|
||||
Tensor tGR_pD_vec = zipped_divide(tGR_pD(_, _, _, red_m, red_n), Vec);
|
||||
// problem shape bounds check
|
||||
auto pred_fn
|
||||
= [&](auto const&... coords) { return elem_less(tGR_pD_vec(_0{}, coords...), problem_shape); };
|
||||
|
||||
// Create predicate tensor for bounds checking
|
||||
Tensor pred_tensor = make_tensor<bool>(make_shape(size(tGR_gD_vec)), Stride<_1>{});
|
||||
|
||||
// Set predicate values based on coordinate bounds
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred_tensor); ++i)
|
||||
{
|
||||
pred_tensor(i) = elem_less(tGR_pD_vec(_0{}, i), problem_shape);
|
||||
}
|
||||
|
||||
// load-reduce in switch
|
||||
cute::copy_if(CopyAtomG2R{}, pred_fn, tGR_gD_vec, tGR_rD_vec);
|
||||
cute::copy_if(CopyAtomG2R{}, pred_tensor, tGR_gD_vec, tGR_rD_vec);
|
||||
// store switch multicast
|
||||
cute::copy_if(CopyAtomR2G{}, pred_fn, tGR_rD_vec, tRG_gD_vec);
|
||||
cute::copy_if(CopyAtomR2G{}, pred_tensor, tGR_rD_vec, tRG_gD_vec);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -171,7 +171,7 @@ struct Sm100AllReduceArrive
|
||||
tma_store_wait<0>();
|
||||
|
||||
int tile_idx = params_ptr->tile_layout(m, n);
|
||||
SystemBarrier::arrive_inc(
|
||||
SystemBarrier::arrive_inc<cuda::thread_scope::thread_scope_device>(
|
||||
params_ptr->barrier_params, thread_idx, tile_idx, params_ptr->rank, params_ptr->world_size);
|
||||
}
|
||||
}
|
||||
|
||||
@ -268,7 +268,7 @@ struct Sm90AuxAllReduce
|
||||
tma_store_wait<0>();
|
||||
|
||||
int tile_idx = params_ptr->tile_layout(m, n);
|
||||
SystemBarrier::arrive_inc(
|
||||
SystemBarrier::arrive_inc<cuda::thread_scope::thread_scope_device>(
|
||||
params_ptr->barrier_params, thread_idx, tile_idx, params_ptr->rank, params_ptr->world_size);
|
||||
}
|
||||
};
|
||||
|
||||
@ -969,8 +969,8 @@ public:
|
||||
{
|
||||
bool do_tail_store = false;
|
||||
|
||||
constexpr auto ARBarrier = (uint32_t) cutlass::arch::ReservedNamedBarriers::FirstUserBarrier;
|
||||
CollectiveAllReduce collective_allreduce(params.allreduce, ARBarrier);
|
||||
const uint32_t AR_barrier_id = 0;
|
||||
CollectiveAllReduce collective_allreduce(params.allreduce, AR_barrier_id);
|
||||
int thread_idx = threadIdx.x - (MaxThreadsPerBlock - NumARThreads);
|
||||
auto init_cta_coord_mnkl = cta_coord_mnkl;
|
||||
|
||||
|
||||
@ -31,6 +31,10 @@
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/workspace.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/arch/grid_dependency_control.h"
|
||||
|
||||
#include "gemm_universal_allreduce.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
@ -42,6 +46,8 @@
|
||||
namespace cutlass::gemm::kernel
|
||||
{
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, class CollectiveAllReduce_,
|
||||
class TileScheduler_>
|
||||
class GemmARUniversal<ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, CollectiveAllReduce_, TileScheduler_,
|
||||
@ -55,6 +61,7 @@ public:
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled;
|
||||
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
@ -88,22 +95,46 @@ public:
|
||||
|
||||
static_assert(!cute::is_same_v<TileScheduler_, StreamKScheduler>,
|
||||
"Ping-pong kernel does not currently support stream-K scheduler.");
|
||||
static constexpr uint32_t TileSchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount;
|
||||
using TileSchedulerTag = TileScheduler_;
|
||||
using TileScheduler =
|
||||
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler;
|
||||
using TileScheduler = typename detail::TileSchedulerSelector<TileSchedulerTag, ArchTag, TileShape, ClusterShape,
|
||||
TileSchedulerPipelineStageCount>::Scheduler;
|
||||
|
||||
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
||||
using TileSchedulerParams = typename TileScheduler::Params;
|
||||
using TileSchedulerPipeline = typename TileScheduler::Pipeline;
|
||||
using TileSchedulerPipelineState = typename TileSchedulerPipeline::PipelineState;
|
||||
using TileSchedulerStorage = typename TileScheduler::SharedStorage;
|
||||
|
||||
static constexpr uint32_t NumAllReduceThreads = 2 * NumThreadsPerWarp;
|
||||
using TileSchedulerThrottlePipeline = typename TileScheduler::ThrottlePipeline;
|
||||
using TileSchedulerThrottlePipelineState = typename TileSchedulerThrottlePipeline::PipelineState;
|
||||
|
||||
static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent;
|
||||
static_assert(!IsSchedDynamicPersistent, "Tile scheduling order needs to be consistent across ranks.");
|
||||
|
||||
// Warp specialization thread count per threadblock
|
||||
static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp
|
||||
static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp
|
||||
static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp for C
|
||||
static constexpr uint32_t NumLoadWarpGroups = 1;
|
||||
static constexpr uint32_t NumMmaWarpGroups = 2;
|
||||
static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents;
|
||||
static constexpr uint32_t NumMMAThreads = size(TiledMma{}); // 4 warp
|
||||
static constexpr uint32_t MaxThreadsPerBlock
|
||||
= CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup);
|
||||
= NumMMAThreads * NumMmaWarpGroups + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
static constexpr bool IsMainloopAuxiliaryLoadNeeded
|
||||
= detail::HasAuxiliaryLoad_v<typename CollectiveMainloop::DispatchPolicy>;
|
||||
|
||||
static_assert(NumMMAThreads == 128, "Pingpong kernel must have TiledMMA operating using 128 threads.");
|
||||
static_assert(MaxThreadsPerBlock == 384, "Pingpong kernel must have 384 threads in total.");
|
||||
|
||||
/// Register requirement for Load and Math WGs
|
||||
static constexpr uint32_t LoadRegisterRequirement = 40;
|
||||
static constexpr uint32_t MmaRegisterRequirement = 232;
|
||||
static constexpr int RegsPerThread = (size<0>(TileShape{}) * size<1>(TileShape{}) * sizeof(ElementAccumulator))
|
||||
/ (NumMMAThreads * sizeof(uint32_t));
|
||||
static constexpr bool HeavyRegisterPressure = RegsPerThread >= 208;
|
||||
static constexpr uint32_t LoadRegisterRequirement = !HeavyRegisterPressure ? 40 : 24;
|
||||
static constexpr uint32_t MmaRegisterRequirement = !HeavyRegisterPressure ? 232 : 240;
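A quick worked check of this heuristic (a sketch under the assumption of a 4-byte float accumulator and the 128-thread TiledMma asserted above): a 128x128 tile yields 128 registers per thread, below the 208 threshold, so the original 40/232 register split is kept, while a 128x256 tile yields 256, which sets HeavyRegisterPressure and switches to the 24/240 split.

#include <cstdint>

// Standalone arithmetic check of the RegsPerThread formula above
// (assumes sizeof(float) == 4 for ElementAccumulator, as on CUDA targets).
static_assert((128 * 128 * sizeof(float)) / (128 * sizeof(std::uint32_t)) == 128,
    "128x128 tile: 128 regs/thread, below the 208 threshold");
static_assert((128 * 256 * sizeof(float)) / (128 * sizeof(std::uint32_t)) == 256,
    "128x256 tile: 256 regs/thread, HeavyRegisterPressure");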
|
||||
|
||||
// 1 stage ordered sequence between mainloop and epilogue producer load threads
|
||||
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
|
||||
@ -114,7 +145,6 @@ public:
|
||||
// Order Sequence barrier with two stages: one for Mainloop and one for Epilogue
|
||||
static constexpr uint32_t StagesPerMathWarpGroup = 2;
|
||||
using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier<StagesPerMathWarpGroup, NumMmaWarpGroups>;
|
||||
|
||||
using MathWarpGroupOrderBarrierSharedStorage
|
||||
= cutlass::PipelineDetail::OrderedSequenceBarrierSharedStorage<MathWarpGroupOrderBarrier::SequenceDepth,
|
||||
MathWarpGroupOrderBarrier::SequenceLength>;
|
||||
@ -122,7 +152,7 @@ public:
|
||||
// Kernel level shared memory storage
|
||||
struct SharedStorage
|
||||
{
|
||||
struct PipelineStorage : cute::aligned_struct<16>
|
||||
struct PipelineStorage : cute::aligned_struct<16, _1>
|
||||
{
|
||||
using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage;
|
||||
using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage;
|
||||
@ -135,7 +165,9 @@ public:
|
||||
alignas(16) typename AllReduceOrderBarrier::SharedStorage allreduce_order;
|
||||
} pipelines;
|
||||
|
||||
struct TensotStorage : cute::aligned_struct<128>
|
||||
alignas(16) TileSchedulerStorage scheduler;
|
||||
|
||||
struct TensorStorage : cute::aligned_struct<128, _1>
|
||||
{
|
||||
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
|
||||
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
|
||||
@ -199,31 +231,48 @@ public:
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
||||
|
||||
// Get maximum number of clusters that could co-exist on the target device
|
||||
int max_active_clusters = args.hw_info.max_active_clusters;
|
||||
if (max_active_clusters <= 0)
|
||||
{
|
||||
max_active_clusters = 0;
|
||||
CUTLASS_TRACE_HOST(
|
||||
" WARNING: Arguments do not include a valid max cluster count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the "
|
||||
"max_active_clusters.");
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_TRACE_HOST(
|
||||
"to_underlying_arguments(): Setting persistent grid cluster count to " << max_active_clusters);
|
||||
}
|
||||
|
||||
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count, max_active_clusters};
|
||||
|
||||
// Calculate workspace pointers
|
||||
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
|
||||
void* scheduler_workspace = workspace_ptr;
|
||||
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* epilogue_workspace = workspace_ptr + workspace_offset;
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* scheduler_workspace = workspace_ptr + workspace_offset;
|
||||
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void* mainloop_workspace = nullptr;
|
||||
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
|
||||
return {args.mode, problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace),
|
||||
CollectiveAllReduce::to_underlying_arguments(args.problem_shape, args.all_reduce), hw_info,
|
||||
TileScheduler::to_underlying_arguments(
|
||||
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace)};
|
||||
TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info,
|
||||
args.scheduler, scheduler_workspace, NumEpilogueSubTiles)};
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args)
|
||||
@ -245,13 +294,14 @@ public:
|
||||
static size_t get_workspace_size(Arguments const& args)
|
||||
{
|
||||
size_t workspace_size = 0;
|
||||
workspace_size += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
workspace_size += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
@ -261,20 +311,23 @@ public:
|
||||
Status status = Status::kSuccess;
|
||||
uint8_t* workspace_ptr = reinterpret_cast<uint8_t*>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
static constexpr uint32_t NumEpilogueSubTiles = 1;
|
||||
static constexpr uint32_t NumAccumulatorMtxs = 1;
|
||||
|
||||
status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, 1);
|
||||
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
status = CollectiveEpilogue::initialize_workspace(
|
||||
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess)
|
||||
{
|
||||
return status;
|
||||
}
|
||||
|
||||
status = CollectiveEpilogue::initialize_workspace(
|
||||
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter);
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue);
|
||||
status = TileScheduler::template initialize_workspace<ProblemShape, ElementAccumulator>(args.scheduler,
|
||||
workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups,
|
||||
NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter);
|
||||
workspace_offset += TileScheduler::template get_workspace_size<ProblemShape, ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess)
|
||||
{
|
||||
@ -347,9 +400,9 @@ public:
|
||||
enum class ProducerWarpRole
|
||||
{
|
||||
Mainloop = 0,
|
||||
Epilogue = 1,
|
||||
Warp2 = 2,
|
||||
Warp3 = 3
|
||||
Warp1 = 1,
|
||||
Epilogue = 2,
|
||||
MainloopAux = 3
|
||||
};
|
||||
|
||||
// Kernel level shared memory storage
|
||||
@ -372,10 +425,18 @@ public:
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
// TileScheduler pipeline
|
||||
typename TileSchedulerPipeline::Params scheduler_pipeline_params;
|
||||
typename TileSchedulerThrottlePipeline::Params scheduler_throttle_pipeline_params;
|
||||
TileSchedulerPipeline scheduler_pipeline(shared_storage.scheduler.pipeline(), scheduler_pipeline_params);
|
||||
TileSchedulerPipelineState scheduler_pipe_consumer_state;
|
||||
|
||||
// Mainloop Load pipeline
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
typename MainloopPipeline::Params mainloop_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop)
|
||||
if (warp_group_role == WarpGroupRole::Producer
|
||||
&& (producer_warp_role == ProducerWarpRole::Mainloop
|
||||
|| producer_warp_role == ProducerWarpRole::MainloopAux))
|
||||
{
|
||||
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
@ -385,6 +446,7 @@ public:
|
||||
}
|
||||
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
|
||||
mainloop_pipeline_params.num_producers = NumProducerThreads;
|
||||
mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes;
|
||||
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
|
||||
|
||||
@ -420,7 +482,6 @@ public:
|
||||
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier);
|
||||
|
||||
typename LoadWarpOrderBarrier::Params params_allreduce_order_barrier;
|
||||
// DMA Load WG will not participate in these Ordered Barrier syncs
|
||||
params_allreduce_order_barrier.group_id
|
||||
= canonical_warp_group_idx() - static_cast<int>(WarpGroupRole::Consumer0);
|
||||
params_allreduce_order_barrier.group_size = NumThreadsPerWarpGroup;
|
||||
@ -493,8 +554,10 @@ public:
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Consumer1)
|
||||
{
|
||||
|
||||
// Advance 2nd Math WG to the next work tile for the startup
|
||||
scheduler.advance_to_next_work();
|
||||
|
||||
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
|
||||
mainloop_pipe_consumer_state.advance(k_tile_count);
|
||||
epi_load_pipe_consumer_state.advance(c_tile_count);
|
||||
@ -546,8 +609,40 @@ public:
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
|
||||
|
||||
} // Mainloop Producer Warp End
|
||||
|
||||
else if (producer_warp_role == ProducerWarpRole::MainloopAux)
|
||||
{
|
||||
if constexpr (IsMainloopAuxiliaryLoadNeeded)
|
||||
{
|
||||
// Ensure that the prefetched kernel does not touch
|
||||
// unflushed global memory prior to this instruction
|
||||
cutlass::arch::wait_on_dependent_grids();
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
|
||||
collective_mainloop.load_auxiliary(params.mainloop, mainloop_pipeline,
|
||||
mainloop_pipe_producer_state, load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx,
|
||||
block_rank_in_cluster, shared_storage.tensors.mainloop);
|
||||
// Update starting pipeline state for the next tile
|
||||
mainloop_pipe_producer_state.advance(k_tile_count);
|
||||
|
||||
scheduler.advance_to_next_work();
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
|
||||
}
|
||||
}
|
||||
|
||||
// Epilogue Producer Warp
|
||||
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed())
|
||||
{
|
||||
@ -556,9 +651,15 @@ public:
|
||||
// unflushed global memory prior to this instruction
|
||||
cutlass::arch::wait_on_dependent_grids();
|
||||
|
||||
load_order_barrier.wait();
|
||||
bool do_load_order_wait = true;
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
if (do_load_order_wait)
|
||||
{
|
||||
load_order_barrier.wait();
|
||||
do_load_order_wait = false;
|
||||
}
|
||||
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
@ -592,18 +693,13 @@ public:
|
||||
// The timing of calling this function only influences performance,
|
||||
// not functional correctness.
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr uint32_t AllReduceBarrier0 = uint32_t(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier) + 1;
|
||||
constexpr uint32_t AllReduceBarrier1 = uint32_t(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier) + 2;
|
||||
auto const named_barrier
|
||||
= warp_group_role == WarpGroupRole::Consumer0 ? AllReduceBarrier0 : AllReduceBarrier1;
|
||||
CollectiveAllReduce collective_all_reduce(params.all_reduce, named_barrier);
|
||||
|
||||
cute::tuple<int32_t, int32_t, cute::Underscore, int32_t> prev_blk_coord;
|
||||
int tile_count = 0;
|
||||
const uint32_t AR_barrier_id = warp_group_role == WarpGroupRole::Consumer0 ? 0 : 1;
|
||||
CollectiveAllReduce collective_allreduce(params.all_reduce, AR_barrier_id);
|
||||
|
||||
while (work_tile_info.is_valid())
|
||||
{
|
||||
@ -621,6 +717,7 @@ public:
|
||||
|
||||
collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators, k_tile_count,
|
||||
warp_group_thread_idx, shared_storage.tensors.mainloop, params.mainloop);
|
||||
|
||||
// Cue for next Math WG's MMA to start
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
@ -638,6 +735,7 @@ public:
|
||||
cutlass::arch::launch_dependent_grids();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Order two Math WG's Epilogue one after the other
|
||||
math_wg_order_barrier.wait();
|
||||
|
||||
@ -661,32 +759,28 @@ public:
|
||||
epi_load_pipe_consumer_state.advance(c_tile_count);
|
||||
epi_store_pipe_producer_state.advance(d_tile_count);
|
||||
|
||||
// Cue next WG's Epilogue
|
||||
// Cue for next Math WG's Epilogue to start
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
// Order two consumer WG's Allreduce one after the other
|
||||
allreduce_order_barrier.wait();
|
||||
|
||||
collective_all_reduce.gather_reduce_broadcast(problem_shape_MNKL, blk_coord, warp_group_thread_idx);
|
||||
collective_allreduce.gather_reduce_broadcast(problem_shape_MNKL, blk_coord, warp_group_thread_idx);
|
||||
|
||||
// Cue next WG's Allreduce
|
||||
allreduce_order_barrier.arrive();
|
||||
|
||||
prev_blk_coord = blk_coord;
|
||||
tile_count++;
|
||||
if (scheduler.is_last_tile(work_tile_info, NumMmaWarpGroups))
|
||||
{
|
||||
// Ensure broadcast from other ranks are visible to us.
|
||||
collective_allreduce.tile_global_sync(problem_shape_MNKL, blk_coord, warp_group_thread_idx);
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work(NumMmaWarpGroups);
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Last tile in CTA, flush and sync
|
||||
if (tile_count > 0)
|
||||
{
|
||||
collective_all_reduce.tile_global_sync(problem_shape_MNKL, prev_blk_coord, warp_group_thread_idx);
|
||||
}
|
||||
|
||||
} // Consumer Warp Groups End
|
||||
} // Consumer Warp Groups End
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -127,6 +127,53 @@ void GemmAllReducePlugin::allocatePersistentWorkspace()
|
||||
TLLM_CHECK(mWorkspace != nullptr);
|
||||
}
|
||||
|
||||
LaunchConfig GemmAllReducePlugin::getStaticHeuristicLaunchConfig(int M) const
|
||||
{
|
||||
using namespace tensorrt_llm::cutlass_extensions;
|
||||
// This is only applicable when we swap and transpose A & B.
|
||||
// When M is small we want to select tile that best fits it to maximize MMA efficiency.
|
||||
auto filterByM = [&](std::vector<LaunchConfig> candidateConfigs)
|
||||
{
|
||||
std::vector<LaunchConfig> result;
|
||||
if (M <= 16)
|
||||
{
|
||||
std::copy_if(candidateConfigs.begin(), candidateConfigs.end(), std::back_inserter(result),
|
||||
[](const LaunchConfig& config)
|
||||
{ return config.tile_shape == TileShape::TileShape_128x16x128 and config.transposed; });
|
||||
}
|
||||
else if (M <= 32)
|
||||
{
|
||||
std::copy_if(candidateConfigs.begin(), candidateConfigs.end(), std::back_inserter(result),
|
||||
[](const LaunchConfig& config)
|
||||
{ return config.tile_shape == TileShape::TileShape_128x32x128 and config.transposed; });
|
||||
}
|
||||
else if (M <= 64)
|
||||
{
|
||||
std::copy_if(candidateConfigs.begin(), candidateConfigs.end(), std::back_inserter(result),
|
||||
[](const LaunchConfig& config)
|
||||
{ return config.tile_shape == TileShape::TileShape_128x64x128 and config.transposed; });
|
||||
}
|
||||
else
|
||||
{
|
||||
std::copy_if(candidateConfigs.begin(), candidateConfigs.end(), std::back_inserter(result),
|
||||
[](const LaunchConfig& config)
|
||||
{ return config.tile_shape == TileShape::TileShape_128x128x128 and config.transposed; });
|
||||
}
|
||||
// If result empty then use any.
|
||||
if (result.empty())
|
||||
{
|
||||
result = candidateConfigs;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
auto bestLaunchConfigs = mGemm->getSupportedLaunchConfigs();
|
||||
bestLaunchConfigs = filterByM(bestLaunchConfigs);
|
||||
TLLM_CHECK(!bestLaunchConfigs.empty());
|
||||
// Return first one, because who knows which is best.
|
||||
return bestLaunchConfigs.front();
|
||||
}
|
||||
|
||||
static GemmAllReducePluginOptions deserializeOptions(void const*& data, size_t length)
|
||||
{
|
||||
char const* begin = reinterpret_cast<char const*>(data);
|
||||
@ -164,8 +211,10 @@ static GemmAllReducePluginOptions deserializeOptions(void const*& data, size_t l
|
||||
GemmAllReducePlugin::GemmAllReducePlugin(void const* data, size_t length)
|
||||
: GemmAllReducePlugin(deserializeOptions(std::ref(data), length))
|
||||
{
|
||||
// char const* end = reinterpret_cast<char const*>(data);
|
||||
mProfiler->deserializeFromOwnFile(mGemmId, mOptions.maxProblemShape);
|
||||
if (mProfiler->useProfiler())
|
||||
{
|
||||
mProfiler->deserializeFromOwnFile(mGemmId, mOptions.maxProblemShape);
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////
|
||||
@ -351,7 +400,15 @@ int GemmAllReducePlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensor
|
||||
TLLM_CHECK_WITH_INFO(K > 0, "GemmAllReducePlugin K is 0.");
|
||||
TLLM_CHECK_WITH_INFO(mWorkspace != nullptr, "GemmAllReducePlugin workspace is null.");
|
||||
|
||||
auto bestLaunchConfig = mProfiler->getBestConfig(M, mGemmId).value();
|
||||
LaunchConfig bestLaunchConfig;
|
||||
if (mProfiler->useProfiler())
|
||||
{
|
||||
bestLaunchConfig = mProfiler->getBestConfig(M, mGemmId).value();
|
||||
}
|
||||
else
|
||||
{
|
||||
bestLaunchConfig = getStaticHeuristicLaunchConfig(M);
|
||||
}
|
||||
|
||||
void const* activation = inputs[mArgInvMap[TensorArg::IN_ACTIVATION]];
|
||||
void const* weight = inputs[mArgInvMap[TensorArg::IN_WEIGHT]];
|
||||
@ -435,7 +492,7 @@ int GemmAllReducePlugin::getNbOutputs() const noexcept
|
||||
|
||||
int GemmAllReducePlugin::initialize() noexcept
|
||||
{
|
||||
if (isBuilding())
|
||||
if (isBuilding() && mProfiler->useProfiler())
|
||||
{
|
||||
// TODO (xsimmons): interfaces between GemmPluginProfiler and Plugin
|
||||
// needs to be relooked at - current interface implicitly assigns runner to profiler
|
||||
@ -509,7 +566,7 @@ void GemmAllReducePlugin::serialize(void* buffer) const noexcept
|
||||
// Since by default each rank will generate and serialize its own profiler mapping
|
||||
// this can lead to different mappings between ranks which will result in fatal
|
||||
// error. Therefore only generate and use profiler mapping for single rank.
|
||||
if (COMM_SESSION.getRank() == 0)
|
||||
if (mProfiler->useProfiler() && COMM_SESSION.getRank() == 0)
|
||||
{
|
||||
mProfiler->serializeToOwnFile(mGemmId);
|
||||
}
|
||||
|
||||
@ -34,6 +34,9 @@ namespace cutlass_kernels = ::tensorrt_llm::kernels::opened_cutlass_kernels;
|
||||
#else
|
||||
namespace cutlass_kernels = ::tensorrt_llm::kernels::cutlass_kernels;
|
||||
#endif
|
||||
|
||||
using LaunchConfig = typename cutlass_kernels::GemmAllReduceImplInterface::LaunchConfig;
|
||||
|
||||
namespace tensorrt_llm::plugins
|
||||
{
|
||||
struct GemmAllReducePluginOptions
|
||||
@ -125,6 +128,8 @@ private:
|
||||
|
||||
void allocatePersistentWorkspace();
|
||||
|
||||
LaunchConfig getStaticHeuristicLaunchConfig(int M) const;
|
||||
|
||||
// Params that are initialized during constructor
|
||||
using KeyType = std::tuple<DataType, DataType, DataType>;
|
||||
using ValueType = std::function<cutlass_kernels::GemmAllReduceImplInterface*()>;
|
||||
|
||||
@@ -58,10 +58,18 @@ void GemmAllReducePluginProfiler::deserializeFromOwnFile(GemmIdCore gemmId, Gemm
assert(end == begin + size);
}

bool GemmAllReducePluginProfiler::useProfiler()
{
char const* envDir = getenv("GEMM_AR_PLUGIN_PROFILE_DIR");
return envDir != nullptr;
}

std::string GemmAllReducePluginProfiler::getCacheFileName(GemmIdCore gemmId)
{
std::stringstream fileName;
fileName << "/tmp/gemm-AR";
char const* envDir = getenv("GEMM_AR_PLUGIN_PROFILE_DIR");
std::string directory = envDir ? std::string(envDir) : "/tmp/";
fileName << directory + "/gemm-AR";
fileName << "-n" << std::to_string(gemmId.n);
fileName << "-k" << std::to_string(gemmId.k);
fileName << "-" << tc::getDtypeString(gemmId.dtype);

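So GEMM_AR_PLUGIN_PROFILE_DIR now both enables the profiler and chooses where the per-shape cache file is written, with /tmp/ as the fallback directory. A small standalone sketch of the naming scheme above (a hypothetical free function for illustration; the real code is a member of GemmAllReducePluginProfiler and derives the suffix via tc::getDtypeString(gemmId.dtype)):

#include <cstdlib>
#include <sstream>
#include <string>

// Hypothetical helper mirroring getCacheFileName(): <dir>/gemm-AR-n<N>-k<K>-<dtype>
std::string gemmArCacheFile(int n, int k, std::string const& dtypeName)
{
    char const* envDir = std::getenv("GEMM_AR_PLUGIN_PROFILE_DIR");
    std::string directory = envDir ? std::string(envDir) : "/tmp/";
    std::stringstream fileName;
    fileName << directory + "/gemm-AR";
    fileName << "-n" << std::to_string(n);
    fileName << "-k" << std::to_string(k);
    fileName << "-" << dtypeName;
    return fileName.str();
}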
@ -29,6 +29,8 @@ namespace tensorrt_llm::plugins
|
||||
* Used for tuning to find best GEMM configs for different problem shapes.
|
||||
* WARNING: Tuning GEMM+AR kernel may not be fully representable of real
|
||||
* multi-GPU workloads as tuning only runs on single-GPU.
|
||||
* IMPORTANT: TRT-LLM does not support deterministic tuning across ranks.
|
||||
* Because of this, we have to serialize/deserialize our own configuration file.
|
||||
*/
|
||||
|
||||
#if defined(USING_OSS_CUTLASS_ALLREDUCE_GEMM)
|
||||
@ -45,6 +47,8 @@ public:
|
||||
|
||||
void deserializeFromOwnFile(GemmIdCore gemmId, GemmDims problemShape);
|
||||
|
||||
bool useProfiler();
|
||||
|
||||
protected:
|
||||
////////////////////////////////////
|
||||
// GemmPluginProfiler methods
|
||||
|
||||
@ -18,7 +18,7 @@ set(SRCS
|
||||
testing/modelSpecBinding.cpp
|
||||
runtime/moeBindings.cpp
|
||||
userbuffers/bindings.cpp
|
||||
../runtime/ipcNvlsMemory.cpp
|
||||
../runtime/ipcNvlsMemory.cu
|
||||
bindings.cpp)
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
@ -30,10 +30,21 @@ set_property(TARGET ${TRTLLM_PYBIND_MODULE} PROPERTY POSITION_INDEPENDENT_CODE
|
||||
|
||||
target_link_directories(${TRTLLM_PYBIND_MODULE} PUBLIC
|
||||
"${TORCH_INSTALL_PREFIX}/lib")
|
||||
|
||||
if(ENABLE_NVSHMEM)
|
||||
target_link_libraries(${TRTLLM_PYBIND_MODULE} PUBLIC nvshmem::nvshmem_host
|
||||
nvshmem::nvshmem_device)
|
||||
endif()
|
||||
|
||||
target_link_libraries(
|
||||
${TRTLLM_PYBIND_MODULE}
|
||||
PUBLIC ${SHARED_TARGET} ${UNDEFINED_FLAG} ${NO_AS_NEEDED_FLAG}
|
||||
${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python)
|
||||
PUBLIC ${SHARED_TARGET}
|
||||
${UNDEFINED_FLAG}
|
||||
${NO_AS_NEEDED_FLAG}
|
||||
${Python3_LIBRARIES}
|
||||
${TORCH_LIBRARIES}
|
||||
torch_python
|
||||
${CUDA_NVML_LIB})
|
||||
target_compile_definitions(
|
||||
${TRTLLM_PYBIND_MODULE} PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE}
|
||||
PYBIND11_DETAILED_ERROR_MESSAGES=1)
|
||||
@ -43,6 +54,6 @@ if(NOT WIN32)
|
||||
${TRTLLM_PYBIND_MODULE}
|
||||
PROPERTIES
|
||||
LINK_FLAGS
|
||||
"-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}"
|
||||
"-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' -Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/lib/stubs' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}"
|
||||
)
|
||||
endif()
|
||||
|
||||
@ -40,7 +40,7 @@ set(SRCS
|
||||
iTensor.cpp
|
||||
ipcUtils.cpp
|
||||
ipcSocket.cpp
|
||||
ipcNvlsMemory.cpp
|
||||
ipcNvlsMemory.cu
|
||||
mcastDeviceMemory.cpp
|
||||
memoryCounters.cpp
|
||||
moeLoadBalancer/gdrwrap.cpp
|
||||
@ -80,6 +80,7 @@ set_property(TARGET runtime_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
|
||||
add_cuda_architectures(runtime_src 89)
|
||||
|
||||
target_include_directories(runtime_src PRIVATE ${MPI_C_INCLUDE_DIRS})
|
||||
target_link_libraries(runtime_src PUBLIC ${CUDA_NVML_LIB})
|
||||
|
||||
if(ENABLE_MULTI_DEVICE)
|
||||
target_link_libraries(runtime_src PUBLIC ${NCCL_LIB})
|
||||
|
||||
@ -1,359 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
|
||||
#include "ipcSocket.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
#define CUCHECK(cmd) \
|
||||
do \
|
||||
{ \
|
||||
CUresult retval = cmd; \
|
||||
if (retval != CUDA_SUCCESS) \
|
||||
{ \
|
||||
const char* error_string; \
|
||||
cuGetErrorString(retval, &error_string); \
|
||||
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, error_string); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define ALIGN_SIZE(x, align) x = ((x + align - 1) / align) * align;
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
using namespace tensorrt_llm::mpi;
|
||||
|
||||
void MPI_group_rank(std::set<int> group, int* groupRank)
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
int rank = COMM_SESSION.getRank();
|
||||
auto it = std::find(group.begin(), group.end(), rank);
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
it != group.end(), "Incorrect group specified - rank " + std::to_string(rank) + " not found in group");
|
||||
*groupRank = std::distance(group.begin(), it);
|
||||
#else
|
||||
TLLM_THROW("MPI_group_rank needs to be compiled with ENABLE_MULTI_DEVICE");
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief MPI_Barrier when subset of ranks present
|
||||
*/
|
||||
void MPI_group_barrier(std::set<int> group)
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
std::vector<int> ranks(group.begin(), group.end());
|
||||
int size = group.size();
|
||||
int group_rank;
|
||||
MPI_group_rank(group, &group_rank);
|
||||
|
||||
int root = 0;
|
||||
|
||||
if (group_rank == root)
|
||||
{
|
||||
int dummy = 0;
|
||||
// Root receives messages from all other processes
|
||||
for (int i = 1; i < size; i++)
|
||||
{
|
||||
COMM_SESSION.recv(&dummy, 1, MpiType::kINT32, ranks[i], MpiTag::kDefault);
|
||||
}
|
||||
// Root sends messages back to all other processes
|
||||
for (int i = 1; i < size; i++)
|
||||
{
|
||||
COMM_SESSION.send(&dummy, 1, MpiType::kINT32, ranks[i], MpiTag::kDefault);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
int dummy = 0;
|
||||
// Non-root processes send a message to root
|
||||
COMM_SESSION.send(&dummy, 1, MpiType::kINT32, ranks[root], MpiTag::kDefault);
|
||||
// Non-root processes receive a message from root
|
||||
COMM_SESSION.recv(&dummy, 1, MpiType::kINT32, ranks[root], MpiTag::kDefault);
|
||||
}
|
||||
#else
|
||||
TLLM_THROW("MPI_group_barrier needs to be compiled with ENABLE_MULTI_DEVICE");
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief MPI_Bcast when subset of ranks present
|
||||
*/
|
||||
void MPI_group_bcast(std::set<int> group, void* buffer, int count, MpiType datatype, int root)
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
int group_rank;
|
||||
MPI_group_rank(group, &group_rank);
|
||||
std::vector<int> ranks(group.begin(), group.end());
|
||||
|
||||
if (group_rank == root)
|
||||
{
|
||||
// Root sends message to all other processes
|
||||
for (size_t i = 1; i < ranks.size(); ++i)
|
||||
{
|
||||
COMM_SESSION.send(buffer, count, datatype, ranks[i], MpiTag::kDefault);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Non-root processes receive a message from root
|
||||
COMM_SESSION.recv(buffer, count, datatype, ranks[root], MpiTag::kDefault);
|
||||
}
|
||||
MPI_group_barrier(group);
|
||||
#else
|
||||
TLLM_THROW("MPI_group_bcast needs to be compiled with ENABLE_MULTI_DEVICE");
|
||||
#endif
|
||||
}
|
||||
|
||||
bool ipcNvlsSupported()
|
||||
{
|
||||
CUdevice current_dev;
|
||||
int cuda_dev = -1;
|
||||
int cuda_driver_version = -1;
|
||||
int dev_count = 0;
|
||||
|
||||
TLLM_CUDA_CHECK(cudaDriverGetVersion(&cuda_driver_version));
|
||||
if (cuda_driver_version < 12010)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
TLLM_CUDA_CHECK(cudaGetDeviceCount(&dev_count));
|
||||
for (int i = 0; i < dev_count; ++i)
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaGetDevice(&cuda_dev));
|
||||
CUCHECK(cuDeviceGet(¤t_dev, cuda_dev));
|
||||
|
||||
int mc_support = 0;
|
||||
CUCHECK(cuDeviceGetAttribute(
|
||||
&mc_support, static_cast<CUdevice_attribute>(CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED), current_dev));
|
||||
if (mc_support == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
IpcNvlsHandle* ipcNvlsAllocate(size_t size, std::set<int> group)
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
TLLM_CHECK(size > 0);
|
||||
|
||||
std::vector<int> ranks(group.begin(), group.end());
|
||||
|
||||
int rank = COMM_SESSION.getRank();
|
||||
|
||||
int group_rank;
|
||||
MPI_group_rank(group, &group_rank);
|
||||
int device_id = ranks[group_rank];
|
||||
|
||||
cudaSetDevice(device_id);
|
||||
|
||||
CUmemAllocationProp ucprop;
|
||||
CUmulticastObjectProp mcprop;
|
||||
size_t uc_align = 0;
|
||||
size_t mc_align = 0;
|
||||
|
||||
CUmemAccessDesc uc_mc_access;
|
||||
memset(&uc_mc_access, 0, sizeof(CUmemAccessDesc));
|
||||
uc_mc_access.location.id = device_id;
|
||||
uc_mc_access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
uc_mc_access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
|
||||
|
||||
memset(&ucprop, 0, sizeof(CUmemAllocationProp));
|
||||
ucprop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
ucprop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
ucprop.location.id = device_id;
|
||||
ucprop.allocFlags.gpuDirectRDMACapable = 1;
|
||||
ucprop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
|
||||
CUCHECK(cuMemGetAllocationGranularity(&uc_align, &ucprop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
|
||||
ALIGN_SIZE(size, uc_align);
|
||||
|
||||
memset(&mcprop, 0, sizeof(CUmulticastObjectProp));
|
||||
mcprop.numDevices = ranks.size();
|
||||
mcprop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
|
||||
mcprop.flags = 0;
|
||||
mcprop.size = size;
|
||||
CUCHECK(cuMulticastGetGranularity(&mc_align, &mcprop, CU_MULTICAST_GRANULARITY_MINIMUM));
|
||||
ALIGN_SIZE(size, mc_align);
|
||||
mcprop.size = size;
|
||||
|
||||
// Init NVLS handle
|
||||
IpcNvlsHandle handle;
|
||||
handle.size = mcprop.size;
|
||||
|
||||
// Get time
|
||||
timespec ts;
|
||||
clock_gettime(CLOCK_MONOTONIC, &ts);
|
||||
// High res time down to nanosec
|
||||
unsigned long seed = ts.tv_sec * 1000000000L + ts.tv_nsec;
|
||||
// Initialize with rand seed.
|
||||
srand(seed);
|
||||
int root = 0;
|
||||
uint64_t unique_op_id = (uint64_t) (rand()) ^ ((uint64_t) (rand()) << 32);
|
||||
MPI_group_bcast(group, &unique_op_id, sizeof(unique_op_id), MpiType::kBYTE, root);
|
||||
|
||||
uint32_t volatile abort_flag = 0;
|
||||
std::shared_ptr<NcclIpcSocket> socket = ncclIpcSocketInit(rank, unique_op_id, &abort_flag);
|
||||
MPI_group_barrier(group);
|
||||
|
||||
int fd;
|
||||
if (group_rank == root)
|
||||
{
|
||||
CUCHECK(cuMulticastCreate(&handle.mc_handle, &mcprop));
|
||||
CUCHECK(
|
||||
cuMemExportToShareableHandle(&fd, handle.mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0 /*flags*/));
|
||||
// Root to send fd to all other processes
|
||||
for (size_t i = 1; i < group.size(); ++i)
|
||||
{
|
||||
ncclIpcSocketSendFd(socket, fd, ranks[i], unique_op_id);
|
||||
}
|
||||
MPI_group_barrier(group);
|
||||
}
|
||||
else
|
||||
{
|
||||
MPI_group_barrier(group);
|
||||
fd = ncclIpcSocketRecvFd(socket);
|
||||
CUCHECK(cuMemImportFromShareableHandle(
|
||||
&handle.mc_handle, (void*) (uintptr_t) fd, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
|
||||
}
|
||||
|
||||
MPI_group_barrier(group);
|
||||
close(fd);
|
||||
|
||||
// Add device to multicast object
|
||||
CUdevice dev;
|
||||
CUCHECK(cuDeviceGet(&dev, device_id));
|
||||
CUCHECK(cuMulticastAddDevice(handle.mc_handle, dev));
|
||||
|
||||
// Create multicast VA
|
||||
CUCHECK(cuMemAddressReserve(&handle.mc_va, size, mc_align, 0U, 0));
|
||||
    CUCHECK(cuMemMap(handle.mc_va, size, 0, handle.mc_handle, 0));
    CUCHECK(cuMemSetAccess(handle.mc_va, size, &uc_mc_access, 1 /* count */));

    // Allocate unicast VA
    CUCHECK(cuMemCreate(&handle.uc_handle, size, &ucprop, 0));
    CUCHECK(cuMemAddressReserve(&handle.uc_va, size, uc_align, 0U, 0));
    CUCHECK(cuMemMap(handle.uc_va, size, 0, handle.uc_handle, 0));

    // set access on UC address, for all GPUs so that UVA works
    for (int gpu_id : group)
    {
        uc_mc_access.location.id = gpu_id;
        CUCHECK(cuMemSetAccess(handle.uc_va, size, &uc_mc_access, 1 /* count */));
    }

    // Bind unicast memory to multicast group
    CUCHECK(cuMulticastBindMem(handle.mc_handle, 0 /*mcOffset*/, handle.uc_handle, 0 /*memOffset*/, size, 0 /*flags*/));

    handle.mc_ptr = reinterpret_cast<uintptr_t>((void*) handle.mc_va);
    handle.uc_ptr = reinterpret_cast<uintptr_t>((void*) handle.uc_va);

    printf("Rank %d nvlsAllocated %zu bytes successfully %p %p\n", rank, size, (void*) handle.uc_ptr,
        (void*) handle.mc_ptr);

    // Export to unicast VA to shareable handle
    int fd_uc;
    CUCHECK(cuMemExportToShareableHandle(
        (void*) &fd_uc, handle.uc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0 /*flags*/));

    handle.ipc_uc_ptrs.resize(ranks.size());
    handle.ipc_uc_vas.resize(ranks.size());
    handle.ipc_uc_handles.resize(ranks.size());

    // Allgather unicast shareable handles
    std::vector<int> peer_fds_uc(ranks.size());
    peer_fds_uc[group_rank] = fd_uc;
    for (size_t i = 1; i < ranks.size(); ++i)
    {
        MPI_group_barrier(group);
        int send_rank = (group_rank + i) % ranks.size();
        int recv_rank = (group_rank + ranks.size() - i) % ranks.size();
        ncclIpcSocketSendFd(socket, fd_uc, ranks[send_rank], unique_op_id);
        peer_fds_uc[recv_rank] = ncclIpcSocketRecvFd(socket);
    }
    ncclIpcSocketClose(socket);

    // Import unicast shareable handles
    for (size_t i = 0; i < ranks.size(); ++i)
    {
        if (ranks[i] == rank)
        {
            handle.ipc_uc_ptrs[i] = handle.uc_ptr;
            handle.ipc_uc_vas[i] = handle.uc_va;
            handle.ipc_uc_handles[i] = handle.uc_handle;
        }
        else
        {
            CUCHECK(cuMemImportFromShareableHandle(&handle.ipc_uc_handles[i], (void*) (uintptr_t) peer_fds_uc[i],
                CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
            CUCHECK(cuMemAddressReserve(&handle.ipc_uc_vas[i], size, uc_align, 0U, 0));
            CUCHECK(cuMemMap(handle.ipc_uc_vas[i], size, 0, handle.ipc_uc_handles[i], 0));
            // set access on UC address, for all GPUs so that UVA works
            for (int gpu_id : group)
            {
                uc_mc_access.location.id = gpu_id;
                CUCHECK(cuMemSetAccess(handle.ipc_uc_vas[i], size, &uc_mc_access, 1 /* count */));
            }

            handle.ipc_uc_ptrs[i] = reinterpret_cast<uintptr_t>((void*) handle.ipc_uc_vas[i]);
        }
        // close FD UC
        close(peer_fds_uc[i]);
    }

    MPI_group_barrier(group);

    printf("Rank %d imported IPC handles successfully\n", rank);

    return new IpcNvlsHandle(std::move(handle));
#else
    TLLM_THROW("ipcNvlsAllocate needs to be compiled with ENABLE_MULTI_DEVICE");
#endif
}

void ipcNvlsFree(IpcNvlsHandle* handle)
{
#if ENABLE_MULTI_DEVICE
    if (handle == nullptr)
    {
        return;
    }

    // Unmap and release MC VA
    CUCHECK(cuMemUnmap(handle->mc_va, handle->size));
    CUCHECK(cuMemRelease(handle->mc_handle));
    CUCHECK(cuMemAddressFree(handle->mc_va, handle->size));
    // Unmap and release UC VA
    for (size_t i = 0; i < handle->ipc_uc_vas.size(); ++i)
    {
        CUCHECK(cuMemUnmap(handle->ipc_uc_vas[i], handle->size));
        CUCHECK(cuMemRelease(handle->ipc_uc_handles[i]));
        CUCHECK(cuMemAddressFree(handle->ipc_uc_vas[i], handle->size));
    }

    delete handle;
#else
    TLLM_THROW("ipcNvlsFree needs to be compiled with ENABLE_MULTI_DEVICE");
#endif
}

} // namespace tensorrt_llm::runtime
cpp/tensorrt_llm/runtime/ipcNvlsMemory.cu (new file, 537 lines)
@ -0,0 +1,537 @@
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "tensorrt_llm/runtime/ipcSocket.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#if ENABLE_NVSHMEM
#include <nvshmem/nvshmem.h>
#include <nvshmem/nvshmemx.h>
#endif
#if ENABLE_MULTI_DEVICE
#include <nvml.h>
#endif
#include <unistd.h>

#define CUCHECK(cmd)                                                                                                   \
    do                                                                                                                 \
    {                                                                                                                  \
        CUresult retval = cmd;                                                                                         \
        if (retval != CUDA_SUCCESS)                                                                                    \
        {                                                                                                              \
            const char* error_string;                                                                                  \
            cuGetErrorString(retval, &error_string);                                                                   \
            printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, error_string);                               \
            exit(EXIT_FAILURE);                                                                                        \
        }                                                                                                              \
    } while (0)

#define NVMLCHECK(cmd)                                                                                                 \
    do                                                                                                                 \
    {                                                                                                                  \
        nvmlReturn_t retval = cmd;                                                                                     \
        if (retval != NVML_SUCCESS)                                                                                    \
        {                                                                                                              \
            printf("Failed: NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(retval));                    \
            exit(EXIT_FAILURE);                                                                                        \
        }                                                                                                              \
    } while (0)

// if n is already a multiple of "multiple", n is returned unchanged, otherwise round up to next multiple.
#define ROUND_UP(n, multiple) (((n + multiple - 1) / multiple) * multiple)

namespace tensorrt_llm::runtime
{
using namespace tensorrt_llm::mpi;

#if ENABLE_MULTI_DEVICE && !ENABLE_NVSHMEM
union IpcMemHandle
{
    uint64_t fd;
    CUmemFabricHandle fh;
};

class IpcCommunicator
{
public:
    virtual ~IpcCommunicator() = default;
    virtual void bcastMemHandle(IpcMemHandle* handle, int root) = 0;
};

class IpcSocketCommunicator : public IpcCommunicator
{
public:
    IpcSocketCommunicator(int world_rank, int group_rank, std::vector<int> group_ranks, MPI_Comm group_comm)
        : mGroupRank(group_rank)
        , mGroupRanks(group_ranks)
        , mGroupComm(group_comm)
    {
        timespec ts;
        clock_gettime(CLOCK_MONOTONIC, &ts);
        unsigned long seed = ts.tv_sec * 1000000000L + ts.tv_nsec;
        srand(seed);
        uint64_t unique_op_id = (uint64_t) (rand()) ^ ((uint64_t) (rand()) << 32);
        MPI_Bcast(&unique_op_id, sizeof(unique_op_id), MPI_BYTE, 0, group_comm);

        uint32_t volatile abort_flag = 0;
        mSocket = ncclIpcSocketInit(world_rank, unique_op_id, &abort_flag);
        MPI_Barrier(group_comm);
    }

    ~IpcSocketCommunicator()
    {
        ncclIpcSocketClose(mSocket);
    }

    void bcastMemHandle(IpcMemHandle* handle, int root) override
    {
        if (mGroupRank == root)
        {
            for (size_t i = 0; i < mGroupRanks.size(); ++i)
            {
                if (i != root)
                {
                    ncclIpcSocketSendFd(mSocket, handle->fd, mGroupRanks[i]);
                }
            }
            MPI_Barrier(mGroupComm);
        }
        else
        {
            MPI_Barrier(mGroupComm);
            handle->fd = ncclIpcSocketRecvFd(mSocket);
        }
        MPI_Barrier(mGroupComm);
    }

private:
    int mGroupRank;
    std::vector<int> mGroupRanks;
    MPI_Comm mGroupComm;
    std::shared_ptr<NcclIpcSocket> mSocket;
};

class IpcFabricCommunicator : public IpcCommunicator
{
public:
    IpcFabricCommunicator(MPI_Comm group_comm)
        : mGroupComm(group_comm)
    {
    }

    ~IpcFabricCommunicator() = default;

    void bcastMemHandle(IpcMemHandle* handle, int root) override
    {
        MPI_Bcast(handle, sizeof(CUmemFabricHandle), MPI_BYTE, root, mGroupComm);
    }

private:
    MPI_Comm mGroupComm;
};

class NVLSCudaAllocator
{
public:
    static IpcNvlsHandle* allocate(size_t size, std::vector<int> ranks)
    {
        auto nvls_handle = new IpcNvlsHandle();

        // Create a new communicator for the subset of ranks.
        MPI_Group world_group, new_group;
        MPI_Comm new_comm;
        // Get the group of the world communicator.
        MPI_Comm_group(COMM_SESSION, &world_group);
        // Create a new group containing only the ranks we want.
        MPI_Group_incl(world_group, ranks.size(), ranks.data(), &new_group);
        // Create a new communicator from the group.
        MPI_Comm_create_group(COMM_SESSION, new_group, 0, &new_comm);

        // Get rank and group rank.
        int world_rank;
        int group_rank;
        MPI_Comm_rank(COMM_SESSION, &world_rank);
        MPI_Comm_rank(new_comm, &group_rank);

        // Get runtime and driver device IDs.
        int device_id;
        int CU_dev;
        TLLM_CUDA_CHECK(cudaGetDevice(&device_id));
        CUCHECK(cuDeviceGet(&CU_dev, device_id));

        // Get handle type used to share memory handles between devices.
        auto handle_type = getMemHandleType();

        // Define allocation access permissions (same for unicast and multicast).
        CUmemAccessDesc access_desc;
        memset(&access_desc, 0, sizeof(CUmemAccessDesc));
        access_desc.location.id = device_id;
        access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
        access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;

        // Define unicast allocation properties.
        CUmemAllocationProp prop;
        memset(&prop, 0, sizeof(CUmemAllocationProp));
        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
        prop.location.id = device_id;
        prop.requestedHandleTypes = handle_type;

        // Define multicast allocation properties.
        CUmulticastObjectProp mcprop;
        memset(&mcprop, 0, sizeof(CUmulticastObjectProp));
        mcprop.numDevices = ranks.size();
        mcprop.handleTypes = handle_type;
        mcprop.flags = 0;

        // Round up allocation size to the nearest multiple of the unicast allocation granularity.
        size_t granularity = 0;
        CUCHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
        size = ROUND_UP(size, granularity);

        // Round up allocation size to the nearest multiple of the multicast allocation granularity.
        size_t mc_granularity = 0;
        CUCHECK(cuMulticastGetGranularity(&mc_granularity, &mcprop, CU_MULTICAST_GRANULARITY_MINIMUM));
        size = ROUND_UP(size, mc_granularity);
        mcprop.size = size;
        nvls_handle->size = size;

        // Allocate physical pages of memory on GPU.
        CUCHECK(cuMemCreate(&nvls_handle->uc_handle, size, &prop, 0));
        // Reserve unicast virtual address space for the memory.
        CUCHECK(cuMemAddressReserve(&nvls_handle->uc_va, size, granularity, 0U, 0));
        // Map the unicast virtual address space to the physical pages.
        CUCHECK(cuMemMap(nvls_handle->uc_va, size, 0, nvls_handle->uc_handle, 0));
        // Set the access permissions for the unicast memory.
        CUCHECK(cuMemSetAccess(nvls_handle->uc_va, size, &access_desc, 1));
        nvls_handle->uc_ptr = reinterpret_cast<uintptr_t>((void*) nvls_handle->uc_va);

        // Setup communicator needed for multicast and unicast pointer exchange.
        std::shared_ptr<IpcCommunicator> ipc_communicator;
        if (handle_type == CU_MEM_HANDLE_TYPE_FABRIC)
        {
            ipc_communicator = std::make_shared<IpcFabricCommunicator>(new_comm);
        }
        else
        {
            ipc_communicator = std::make_shared<IpcSocketCommunicator>(world_rank, group_rank, ranks, new_comm);
        }

        // Unicast pointer exchange between ranks.
        IpcMemHandle ipc_handle;
        CUCHECK(cuMemExportToShareableHandle((void*) &ipc_handle, nvls_handle->uc_handle, handle_type, 0 /*flags*/));

        nvls_handle->ipc_uc_ptrs.resize(ranks.size());
        nvls_handle->ipc_uc_vas.resize(ranks.size());
        nvls_handle->ipc_uc_handles.resize(ranks.size());

        for (int i = 0; i < ranks.size(); i++)
        {
            IpcMemHandle peer_ipc_handle = ipc_handle;
            ipc_communicator->bcastMemHandle(&peer_ipc_handle, i);
            if (i != group_rank)
            {
                void* os_handle
                    = handle_type == CU_MEM_HANDLE_TYPE_FABRIC ? (void*) &peer_ipc_handle : (void*) peer_ipc_handle.fd;
                CUCHECK(cuMemImportFromShareableHandle(&nvls_handle->ipc_uc_handles[i], os_handle, handle_type));
                // Reserve peer unicast virtual address space for the memory.
                CUCHECK(cuMemAddressReserve(&nvls_handle->ipc_uc_vas[i], size, granularity, 0U, 0));
                // Map the peer unicast virtual address space to the physical pages.
                CUCHECK(cuMemMap(nvls_handle->ipc_uc_vas[i], size, 0, nvls_handle->ipc_uc_handles[i], 0));
                // Set the access permissions for the peer unicast memory.
                CUCHECK(cuMemSetAccess(nvls_handle->ipc_uc_vas[i], size, &access_desc, 1));
                nvls_handle->ipc_uc_ptrs[i] = reinterpret_cast<uintptr_t>((void*) nvls_handle->ipc_uc_vas[i]);
            }
            else
            {
                nvls_handle->ipc_uc_ptrs[i] = nvls_handle->uc_ptr;
                nvls_handle->ipc_uc_vas[i] = nvls_handle->uc_va;
                nvls_handle->ipc_uc_handles[i] = nvls_handle->uc_handle;
            }
        }

        // Initialize multicast object for all ranks.
        if (group_rank == 0)
        {
            CUCHECK(cuMulticastCreate(&nvls_handle->mc_handle, &mcprop));
            // Export the allocation for the importing process.
            CUCHECK(cuMemExportToShareableHandle(&ipc_handle, nvls_handle->mc_handle, handle_type, 0 /*flags*/));
            ipc_communicator->bcastMemHandle(&ipc_handle, 0);
        }
        else
        {
            ipc_communicator->bcastMemHandle(&ipc_handle, 0);
            void* os_handle = handle_type == CU_MEM_HANDLE_TYPE_FABRIC ? (void*) &ipc_handle : (void*) ipc_handle.fd;
            CUCHECK(cuMemImportFromShareableHandle(&nvls_handle->mc_handle, os_handle, handle_type));
        }

        // Add device to multicast object
        CUCHECK(cuMulticastAddDevice(nvls_handle->mc_handle, CU_dev));
        // Bind physical memory to the Multicast group.
        // Note: It will block until all ranks have been added to the group.
        CUCHECK(cuMulticastBindMem(nvls_handle->mc_handle, 0, nvls_handle->uc_handle, 0, size, 0));
        // Reserve multicast virtual address space for the memory.
        CUCHECK(cuMemAddressReserve(&nvls_handle->mc_va, size, mc_granularity, 0U, 0));
        // Map the multicast virtual address space to the physical pages.
        CUCHECK(cuMemMap(nvls_handle->mc_va, size, 0, nvls_handle->mc_handle, 0));
        // Set the access permissions for the multicast memory.
        CUCHECK(cuMemSetAccess(nvls_handle->mc_va, size, &access_desc, 1 /* count */));
        nvls_handle->mc_ptr = reinterpret_cast<uintptr_t>((void*) nvls_handle->mc_va);

        // Clean up
        MPI_Group_free(&new_group);
        MPI_Group_free(&world_group);

        return nvls_handle;
    }

    static void free(IpcNvlsHandle* nvls_handle)
    {
        CUCHECK(cuMemUnmap(nvls_handle->mc_va, nvls_handle->size));
        CUCHECK(cuMemRelease(nvls_handle->mc_handle));
        CUCHECK(cuMemAddressFree(nvls_handle->mc_va, nvls_handle->size));
        for (size_t i = 0; i < nvls_handle->ipc_uc_vas.size(); ++i)
        {
            CUCHECK(cuMemUnmap(nvls_handle->ipc_uc_vas[i], nvls_handle->size));
            CUCHECK(cuMemRelease(nvls_handle->ipc_uc_handles[i]));
            CUCHECK(cuMemAddressFree(nvls_handle->ipc_uc_vas[i], nvls_handle->size));
        }
    }

private:
    static CUmemAllocationHandleType getMemHandleType()
    {
        int device_id;
        TLLM_CUDA_CHECK(cudaGetDevice(&device_id));

        // Check if fabric handle support is available.
        int fabric_supported = 0;
        CUCHECK(cuDeviceGetAttribute(&fabric_supported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, device_id));
        if (!fabric_supported)
        {
            TLLM_LOG_TRACE(
                "checking fabric support... CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED not supported.");
            return CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
        }

        nvmlDevice_t nvml_device;
        nvmlGpuFabricInfo_t fabric_info;
        NVMLCHECK(nvmlInit_v2());
        NVMLCHECK(nvmlDeviceGetHandleByIndex(device_id, &nvml_device));
        NVMLCHECK(nvmlDeviceGetGpuFabricInfo(nvml_device, &fabric_info));
        NVMLCHECK(nvmlShutdown());

        // Check if the fabric is fully initialized.
        if (fabric_info.state != NVML_GPU_FABRIC_STATE_COMPLETED || fabric_info.status != NVML_SUCCESS)
        {
            TLLM_LOG_TRACE("checking fabric support... fabric state is NOT COMPLETE: state=%u status=%u.",
                fabric_info.state, fabric_info.status);
            return CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
        }

        // Check that fabric handles can be created.
        CUmemAllocationProp prop;
        memset(&prop, 0, sizeof(CUmemAllocationProp));
        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
        prop.location.id = device_id;
        prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;

        size_t alloc_size = 1024; // anything > 0
        size_t min_gran = 0;
        CUCHECK(cuMemGetAllocationGranularity(&min_gran, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
        alloc_size = ROUND_UP(alloc_size, min_gran);

        CUmemGenericAllocationHandle handle;
        CUresult err = cuMemCreate(&handle, alloc_size, &prop, 0);
        if (err == CUDA_ERROR_NOT_PERMITTED || err == CUDA_ERROR_NOT_SUPPORTED)
        {
            TLLM_LOG_TRACE("checking fabric support... cuMemCreate failed with not %s.",
                err == CUDA_ERROR_NOT_PERMITTED ? "permitted" : "supported");
            return CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
        }
        else
        {
            CUCHECK(err);
        }

        // Check if fabric handles can be exported & imported by IMEX (Internode Memory Exchange)
        CUmemFabricHandle fh;
        err = cuMemExportToShareableHandle(&fh, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0);
        if (err != CUDA_SUCCESS
            || (err = cuMemImportFromShareableHandle(&handle, &fh, CU_MEM_HANDLE_TYPE_FABRIC)) != CUDA_SUCCESS)
        {
            TLLM_LOG_TRACE("checking fabric support... cuMemExport/cuMemImport failed.");
            CUCHECK(cuMemRelease(handle));
            return CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
        }

        TLLM_LOG_TRACE("fabric status: state=%u status=%u clique=%u", device_id, fabric_info.state, fabric_info.status,
            fabric_info.cliqueId);

        CUCHECK(cuMemRelease(handle));
        // If we get here, fabric handles are supported.
        return CU_MEM_HANDLE_TYPE_FABRIC;
    }
};
#endif

/**
 * @brief MPI_Barrier when subset of ranks present
 */
void MPI_group_barrier(std::set<int> group)
{
#if ENABLE_MULTI_DEVICE
    // Create a new communicator for the subset of ranks
    MPI_Group world_group, new_group;
    MPI_Comm new_comm;

    // Get the group of the world communicator
    MPI_Comm_group(MPI_COMM_WORLD, &world_group);

    // Create a new group containing only the ranks we want
    std::vector<int> ranks(group.begin(), group.end());
    MPI_Group_incl(world_group, ranks.size(), ranks.data(), &new_group);

    // Create a new communicator from the group
    MPI_Comm_create_group(MPI_COMM_WORLD, new_group, 0, &new_comm);

    // Use the new communicator for the barrier
    MPI_Barrier(new_comm);

    // Clean up
    MPI_Group_free(&new_group);
    MPI_Group_free(&world_group);
    MPI_Comm_free(&new_comm);
#else
    TLLM_THROW("MPI_group_barrier needs to be compiled with ENABLE_MULTI_DEVICE");
#endif
}

bool ipcNvlsSupported()
{
#if ENABLE_MULTI_DEVICE
    CUdevice current_dev;
    int cuda_dev = -1;
    int cuda_driver_version = -1;
    int dev_count = 0;

    TLLM_CUDA_CHECK(cudaDriverGetVersion(&cuda_driver_version));
    if (cuda_driver_version < 12010)
    {
        TLLM_LOG_ERROR("CUDA Driver version < 12010");
        return false;
    }

    TLLM_CUDA_CHECK(cudaGetDeviceCount(&dev_count));
    for (int i = 0; i < dev_count; ++i)
    {
        TLLM_CUDA_CHECK(cudaGetDevice(&cuda_dev));
        CUCHECK(cuDeviceGet(&current_dev, cuda_dev));

        int multicast_supported = 0;
        CUCHECK(cuDeviceGetAttribute(&multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, current_dev));
        if (!multicast_supported)
        {
            TLLM_LOG_ERROR("CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED not supported on GPU%d.", cuda_dev);
            return false;
        }
    }

    return true;
#else
    return false;
#endif
}

IpcNvlsHandle* ipcNvlsAllocate(size_t size, std::set<int> group)
{
#if ENABLE_MULTI_DEVICE
    TLLM_CHECK_WITH_INFO(ipcNvlsSupported(), "Switch multicast is not supported on this system.");
    TLLM_CHECK(size > 0);
    TLLM_CHECK(group.size() >= 2);

    std::vector<int> ranks(group.begin(), group.end());
    int group_size = ranks.size();

    MPI_Comm mpi_comm = COMM_SESSION;

    // Create a new communicator with only the ranks in the group
    MPI_Group world_group, new_group;
    MPI_Comm_group(mpi_comm, &world_group);
    MPI_Group_incl(world_group, group_size, ranks.data(), &new_group);

    MPI_Comm new_comm;
    MPI_Comm_create_group(mpi_comm, new_group, 0, &new_comm);

#if ENABLE_NVSHMEM
    // Initialize NVSHMEM with the new communicator
    nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER;
    attr.mpi_comm = &new_comm;
    nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr);

    // Allocate NVSHMEM memory
    void* ptr = nvshmem_malloc(size);

    // Create handle to return
    auto handle = new IpcNvlsHandle();

    handle->size = size;
    handle->uc_ptr = reinterpret_cast<uintptr_t>(ptr);
    handle->mc_ptr = reinterpret_cast<uintptr_t>(nvshmemx_mc_ptr(NVSHMEM_TEAM_WORLD, ptr));
    for (int i = 0; i < ranks.size(); i++)
    {
        handle->ipc_uc_ptrs.push_back(reinterpret_cast<uintptr_t>(nvshmem_ptr(ptr, i)));
    }
#else  // !ENABLE_NVSHMEM
    auto handle = NVLSCudaAllocator::allocate(size, ranks);
#endif

    TLLM_LOG_INFO("Rank %d NVLS allocate %zu bytes, uc_ptr:%p mc_ptr:%p", COMM_SESSION.getRank(), size,
        (void*) handle->uc_ptr, (void*) handle->mc_ptr);

    // Cleanup
    MPI_Group_free(&new_group);
    MPI_Group_free(&world_group);

    MPI_Barrier(new_comm);

    return handle;
#else
    TLLM_THROW("ipcNvlsAllocate needs to be compiled with ENABLE_MULTI_DEVICE");
#endif
}

void ipcNvlsFree(IpcNvlsHandle* handle)
{
#if ENABLE_MULTI_DEVICE
    if (handle == nullptr)
    {
        return;
    }
#if ENABLE_NVSHMEM
    nvshmem_free((void*) handle->uc_ptr);
#else
    NVLSCudaAllocator::free(handle);
#endif
    delete handle;
#else
    TLLM_THROW("ipcNvlsFree needs to be compiled with ENABLE_MULTI_DEVICE");
#endif
}

} // namespace tensorrt_llm::runtime
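For reference, a minimal caller-side sketch of the allocation API defined above (not part of the commit; the group membership, buffer size, and the function name exampleNvlsBuffer are illustrative). Every rank in the group must reach these calls collectively, since the allocation builds an MPI sub-communicator and synchronizes internally.

// Hedged sketch of using the NVLS IPC allocation API declared in ipcNvlsMemory.h.
#include <set>
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"

void exampleNvlsBuffer()
{
    using namespace tensorrt_llm::runtime;
    std::set<int> group{0, 1, 2, 3};                          // illustrative 4-rank group
    if (!ipcNvlsSupported())                                  // needs driver >= 12.1 and multicast-capable GPUs
    {
        return;
    }
    IpcNvlsHandle* handle = ipcNvlsAllocate(1 << 20, group);  // collective; size is rounded up to granularity
    // handle->uc_ptr         : this rank's unicast mapping
    // handle->mc_ptr         : multicast mapping shared by the group
    // handle->ipc_uc_ptrs[i] : peer i's unicast mapping (UVA-accessible)
    ipcNvlsFree(handle);                                      // also collective
}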
@ -72,6 +72,7 @@ struct NcclIpcSocket
    int fd;
    char socketName[NCCL_IPC_SOCKNAME_LEN];
    uint32_t volatile* abortFlag;
    uint64_t hash;
};

/*
@ -89,6 +90,7 @@ std::shared_ptr<NcclIpcSocket> ncclIpcSocketInit(int rank, uint64_t hash, uint32
        TLLM_NCCL_CHECK(ncclInternalError); // throws
    }

    handle->hash = hash;
    handle->fd = -1;
    handle->socketName[0] = '\0';
    if ((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0)
@ -310,9 +312,9 @@ void ncclIpcSocketSendMsg(
    }
}

void ncclIpcSocketSendFd(std::shared_ptr<NcclIpcSocket> handle, int sendFd, int rank, uint64_t hash)
void ncclIpcSocketSendFd(std::shared_ptr<NcclIpcSocket> handle, int sendFd, int rank)
{
    ncclIpcSocketSendMsg(handle, NULL, 0, sendFd, rank, hash);
    ncclIpcSocketSendMsg(handle, NULL, 0, sendFd, rank, handle->hash);
}

} // namespace tensorrt_llm::runtime

@ -33,7 +33,7 @@ void ncclIpcSocketClose(std::shared_ptr<NcclIpcSocket> handle);

int ncclIpcSocketRecvFd(std::shared_ptr<NcclIpcSocket> handle);

void ncclIpcSocketSendFd(std::shared_ptr<NcclIpcSocket> handle, int fd, int rank, uint64_t hash);
void ncclIpcSocketSendFd(std::shared_ptr<NcclIpcSocket> handle, int fd, int rank);

} // namespace tensorrt_llm::runtime

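A brief caller-side sketch of the signature change above (not part of the commit; world_rank, unique_op_id, fd_to_share, and peer_rank are illustrative): the op hash is now captured once by ncclIpcSocketInit() and stored in the socket handle, so per-send calls no longer pass it.

    uint32_t volatile abort_flag = 0;
    auto socket = ncclIpcSocketInit(world_rank, unique_op_id, &abort_flag); // hash kept in handle->hash
    ncclIpcSocketSendFd(socket, fd_to_share, peer_rank);                    // previously: (..., peer_rank, unique_op_id)
    int peer_fd = ncclIpcSocketRecvFd(socket);
    ncclIpcSocketClose(socket);
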
@ -803,7 +803,6 @@ public:
        if (mCapacity < newSize)
        {
            release();
            printf("MulticastBuffer resize: %d B\n", int(toBytes(newSize)));
            mHandle = ipcNvlsAllocate(toBytes(newSize), mRanks);

            TLLM_CHECK(mHandle->size % BufferDataType(mType).getSize() == 0);

@ -45,10 +45,12 @@ add_gtest(mlaChunkedPrefillTest mlaChunkedPrefillTest.cu)
if(NOT ENABLE_MULTI_DEVICE EQUAL 0)
  add_gtest(allReduceKernelTest allReduce/allReduceKernelTest.cu)
  add_gtest(allReduceFusionTest allReduce/allReduceFusionTest.cu)
  # add_gtest(gemmAllReduceTest allReduce/gemmAllReduceTest.cu)
  # if(USING_OSS_CUTLASS_ALLREDUCE_GEMM) target_link_libraries(gemmAllReduceTest
  # PRIVATE ar_gemm_src) target_compile_definitions(gemmAllReduceTest PRIVATE
  # USING_OSS_CUTLASS_ALLREDUCE_GEMM) endif()
  add_gtest(gemmAllReduceTest allReduce/gemmAllReduceTest.cu)
  if(USING_OSS_CUTLASS_ALLREDUCE_GEMM)
    target_link_libraries(gemmAllReduceTest PRIVATE ar_gemm_src)
    target_compile_definitions(gemmAllReduceTest
                               PRIVATE USING_OSS_CUTLASS_ALLREDUCE_GEMM)
  endif()
endif()

add_gtest(

@ -914,10 +914,14 @@ int main(int argc, char** argv)
        notSupported = true;
    }

    TLLM_CUDA_CHECK(cudaSetDevice(COMM_SESSION.getRank()));
    int device_count;
    TLLM_CUDA_CHECK(cudaGetDeviceCount(&device_count));

    int device_id = COMM_SESSION.getRank() % device_count;
    TLLM_CUDA_CHECK(cudaSetDevice(device_id));

    cudaDeviceProp props;
    TLLM_CUDA_CHECK(cudaGetDeviceProperties(&props, COMM_SESSION.getRank()));
    TLLM_CUDA_CHECK(cudaGetDeviceProperties(&props, device_id));

    if (props.major < 9)
    {

@ -2058,6 +2058,7 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)

    multiGpuJobs = parallelJobs.findAll{(it.key.contains("4_GPUs") || it.key.contains("8_GPUs")) && !it.key.contains("Post-Merge")}
    println multiGpuJobs.keySet()
    multiGpuJobsPostMerge = parallelJobs.findAll{(it.key.contains("4_GPUs") || it.key.contains("8_GPUs")) && it.key.contains("Post-Merge")}

    parallelJobs += docBuildJobs
    parallelJobs += sanityCheckJobs
@ -2105,7 +2106,11 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)

    // Check --only-multi-gpu-test, if true, only run multi-GPU test stages.
    if (testFilter[(ONLY_MULTI_GPU_TEST)]) {
        parallelJobsFiltered = multiGpuJobs
        if (testFilter[(IS_POST_MERGE)]) {
            parallelJobsFiltered = multiGpuJobsPostMerge
        } else {
            parallelJobsFiltered = multiGpuJobs
        }
    }

    // Check --disable-multi-gpu-test, if true, remove multi-GPU test stages.

@ -681,7 +681,28 @@ def main(*,
    new_library_path = "/usr/local/cuda/compat:/usr/local/cuda/compat/lib:/usr/local/cuda/compat/lib.real"
    if 'LD_LIBRARY_PATH' in env_ld:
        new_library_path += f":{env_ld['LD_LIBRARY_PATH']}"

    result = build_run("find /usr -name *libnvidia-ml.so*",
                       capture_output=True,
                       text=True)
    assert result.returncode == 0, f"Failed to run find *libnvidia-ml.so*: {result.stderr}"

    # Build containers only contain stub version of libnvidia-ml.so and not the real version.
    # If real version not in system, we need to create symbolic link to stub version to prevent import errors.
    if "libnvidia-ml.so.1" not in result.stdout:
        if "libnvidia-ml.so" in result.stdout:
            line = result.stdout.splitlines()[0]
            path = os.path.dirname(line)
            new_library_path += f":{path}"
            build_run(f"ln -s {line} {path}/libnvidia-ml.so.1")
        else:
            print(
                f"Failed to find libnvidia-ml.so: {result.stderr}",
                file=sys.stderr)
            exit(1)

    env_ld["LD_LIBRARY_PATH"] = new_library_path

    build_run(
        f"\"{venv_python}\" -m pybind11_stubgen -o . bindings --exit-code",
        env=env_ld)

@ -765,8 +765,7 @@ class TestLlama3_1_8B(CliFlowAccuracyTestHarness):
                     tp_size=4,
                     extra_build_args=extra_build_args)

    @skip_pre_ada
    @skip_post_blackwell
    @skip_pre_hopper
    @pytest.mark.skip_less_device(4)
    @pytest.mark.parametrize(
        "gemm_allreduce", [False, pytest.param(True, marks=skip_no_nvls)],

@ -7,6 +7,7 @@ from typing import List, Optional

import defs.cpp.cpp_common as _cpp
import pytest
from defs.conftest import skip_no_nvls


# Helper filter for disagg google tests
@ -65,6 +66,28 @@ def run_mpi_utils_tests(build_dir, timeout=300):
                     timeout=timeout)


def run_gemm_allreduce_tests(build_dir, nprocs, timeout=300):

    tests_dir = build_dir / "tests"
    mgpu_env = get_multi_gpu_env()

    gemm_allreduce_test = [
        "mpirun",
        "-n",
        f"{nprocs}",
        "--allow-run-as-root",
        "unit_tests/kernels/gemmAllReduceTest",
        "--m=2032",
        "--n=8200",
        "--k=1024",
        "--iterations=1",
    ]
    _cpp.run_command(gemm_allreduce_test,
                     cwd=tests_dir,
                     env=mgpu_env,
                     timeout=timeout)


def run_cache_transceiver_tests(build_dir: _pl.Path,
                                nprocs=2,
                                kv_cache_type=KVCacheType.MPI,
@ -451,6 +474,15 @@ def test_mpi_utils(build_google_tests, build_dir):
    run_mpi_utils_tests(build_dir, timeout=300)


@skip_no_nvls
@pytest.mark.parametrize("build_google_tests", ["90", "100"], indirect=True)
@pytest.mark.parametrize("nprocs", [2, 4], ids=["2proc", "4proc"])
def test_fused_gemm_allreduce(build_google_tests, nprocs, build_dir):

    if platform.system() != "Windows":
        run_gemm_allreduce_tests(build_dir, nprocs, timeout=300)


@pytest.mark.parametrize("build_google_tests", ["80", "86", "89", "90"],
                         indirect=True)
@pytest.mark.parametrize("kvcache_type", [KVCacheType.MPI, KVCacheType.UCX],

@ -3368,8 +3368,6 @@ def test_llm_llama_v3_2_smoothquant_1node_single_gpu(
    venv_check_call(llm_venv, summary_cmd)


# TODO: remove skip after support fp8 rowwise gemm on B200
@skip_post_blackwell
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("fp8_quant",

@ -118,6 +118,7 @@ l0_dgx_h100:
  tests:
  # ------------- CPP tests ---------------
  - cpp/test_multi_gpu.py::test_mpi_utils[90]
  - cpp/test_multi_gpu.py::test_fused_gemm_allreduce[4proc-90]
  - cpp/test_multi_gpu.py::test_cache_transceiver[2proc-mpi_kvcache-90]
  - cpp/test_multi_gpu.py::test_cache_transceiver[2proc-ucx_kvcache-90]
  - cpp/test_multi_gpu.py::test_cache_transceiver[8proc-mpi_kvcache-90]

@ -335,7 +335,6 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd
disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5247271)
unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5274229)
unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep4-enable_adp-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5274229)
accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[enable_gemm_allreduce_plugin] SKIP (https://nvbugs/5247786)
full:B200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen1.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
full:B200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
full:B200/examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2.5_7b_chat-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5247837)
@ -390,7 +389,6 @@ full:B200/examples/test_gemma.py::test_llm_gemma_1gpu_summary_vswa[gemma-3-1b-it
full:B200/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5295470)
examples/test_mistral.py::test_llm_mistral_v1_1gpu[mistral-7b-v0.1-float16-max_attention_window_size_4096-summarization_long] SKIP (https://nvbugs/5324976)
examples/test_prompt_lookup.py::test_llm_prompt_lookup_1gpu[no_streaming-gpt2-use_cpp_session-use_tokens-max_matching_ngram_size_2-prompt_lookup_num_tokens_8-float16-bs1] SKIP (https://nvbugs/5344070)
examples/test_llama.py::test_llm_llama_v3_1_1node_multi_gpus[enable_gemm_allreduce_plugin-llama-3.1-70b-disable_fp8] SKIP (https://nvbugs/5343850)
examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_py_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1] SKIP (https://nvbugs/5333849)
examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5333818)
examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5333818)