/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/config.h"
|
|
#include "tensorrt_llm/common/envUtils.h"
|
|
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
|
|
#include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h"
|
|
#include "tensorrt_llm/kernels/quantization.cuh"
|
|
#include <cooperative_groups.h>
|
|
|
|
TRTLLM_NAMESPACE_BEGIN

namespace kernels::ar_fusion
{
// Namespace alias used by the cluster-based (SM90+) code paths in this file.
namespace cg = cooperative_groups;

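// Workspace layout shared by the communication helpers below (inferred from the indexing
// in SyncComm and LamportComm; a sketch, not an authoritative spec):
//   workspace[0 .. NRanks-1]            per-rank comm buffers used by the two-shot kernel
//   workspace[NRanks .. 2*NRanks-1]     per-rank barrier-flag arrays used by Barrier
//   workspace[2*NRanks .. 3*NRanks-1]   per-rank triple-buffered Lamport buffers
//   workspace[3*NRanks]                 int metadata: [0] block-arrival counter,
//                                       [1] two-shot flag, [2] Lamport flag (mod 3),
//                                       [3] Lamport slot size, [4] clear size for the next launch
//
// SyncComm snapshots the flag value and buffer pointers for the two-shot path. Every block
// bumps the arrival counter in its constructor; update() has one thread spin until all
// gridDim.x blocks have arrived before publishing the new flag and resetting the counter.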
template <int NRanks>
struct SyncComm
{
    __device__ __forceinline__ SyncComm(void** workspace)
    {
        counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
        flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[1];
        flag_value = *flag_ptr;
        for (int r = 0; r < NRanks; ++r)
        {
            comm_bufs[r] = workspace[r];
            barrier_flags[r] = workspace[NRanks + r];
        }
        __syncthreads();
        if (threadIdx.x == 0)
        {
            atomicAdd(counter_ptr, 1);
        }
    }

    __device__ __forceinline__ void update(int new_flag_value)
    {
        if (blockIdx.x == 0 && threadIdx.x == 0)
        {
            while (*reinterpret_cast<int volatile*>(counter_ptr) != gridDim.x)
            {
            }
            *flag_ptr = new_flag_value;
            *counter_ptr = 0;
        }
    }

    int* counter_ptr;
    int* flag_ptr;
    void* comm_bufs[NRanks];
    void* barrier_flags[NRanks];
    int flag_value;
};

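// LamportComm implements the one-shot path's triple-buffered Lamport scheme. The flag value
// cycles through 0/1/2: slot (flag % 3) holds this launch's data, while slot ((flag + 2) % 3),
// i.e. the slot the previous launch wrote, must be re-initialized to the negative-zero
// sentinel (see get_neg_zero below) so a future launch can poll it. update() advances the
// flag mod 3 and records how many elements the next launch has to clear.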
template <int NRanks>
struct LamportComm
{
    __device__ __forceinline__ LamportComm(void** workspace, int rank)
    {
        counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
        flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[2];
        clear_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[4];
        flag_value = *flag_ptr;
        int comm_size = reinterpret_cast<int*>(workspace[NRanks * 3])[3];
        clear_size = *clear_ptr;
        int data_offset = flag_value % 3;
        int clear_offset = (flag_value + 2) % 3;
        for (int r = 0; r < NRanks; ++r)
        {
            data_bufs[r] = reinterpret_cast<uint8_t*>(workspace[2 * NRanks + r]) + data_offset * comm_size;
        }
        clear_buf = reinterpret_cast<uint8_t*>(workspace[2 * NRanks + rank]) + clear_offset * comm_size;
        __syncthreads();
        if (threadIdx.x == 0)
        {
            atomicAdd(counter_ptr, 1);
        }
    }

    __device__ __forceinline__ void update(int new_clear_size)
    {
        if (blockIdx.x == 0 && threadIdx.x == 0)
        {
            while (*reinterpret_cast<int volatile*>(counter_ptr) != gridDim.x)
            {
            }
            *flag_ptr = (flag_value + 1) % 3;
            *clear_ptr = new_clear_size;
            *counter_ptr = 0;
        }
    }

    int* counter_ptr;
    int* flag_ptr;
    int* clear_ptr;
    uint8_t* data_bufs[NRanks];
    uint8_t* clear_buf;
    int clear_size;
    int flag_value;
};

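// Grid-wide, multi-rank barrier built on the per-rank barrier_flags arrays. Thread t
// (t < NRanks) of every block writes the new flag value into rank t's slot for the local
// rank, publishes the stores with a single release fence, and then spins with an acquire
// load until rank t's matching write for this block has landed. The flag is written for all
// kBarrierFlagCount slots so stale values never alias a future flag (see the ABA note below).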
template <int NRanks>
class Barrier
{
public:
    __device__ __forceinline__ Barrier(int rank, SyncComm<NRanks> const& comm)
    {
        if (threadIdx.x < NRanks)
        {
            m_flag_value = comm.flag_value;
            int current_rank = rank;
            int target_rank = threadIdx.x;
            m_target_flag = reinterpret_cast<int*>(comm.barrier_flags[target_rank]) + current_rank;
            m_current_flag
                = reinterpret_cast<int*>(comm.barrier_flags[current_rank]) + blockIdx.x * NRanks + target_rank;
        }
    }

    __device__ __forceinline__ void sync()
    {
        __syncthreads();
        if (threadIdx.x < NRanks)
        {
            m_flag_value = next_flag(m_flag_value);
            // To avoid the ABA problem, we need to synchronize the correct flag value to all barrier_flags,
            // even if the corresponding CTA has not been launched.
            for (int flag_idx = blockIdx.x; flag_idx < kBarrierFlagCount; flag_idx += gridDim.x)
            {
                asm volatile(
                    "st.global.relaxed.sys.b32 [%1], %0;" ::"r"(m_flag_value), "l"(m_target_flag + flag_idx * NRanks));
            }
            // Single release fence
            asm volatile("fence.release.sys;");

            while (ld_flag(m_current_flag) == prev_flag(m_flag_value))
            {
            }
        }
        __syncthreads();
    }

protected:
    __device__ __forceinline__ void st_flag(int* addr, int flag)
    {
        asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(addr));
    }

    __device__ __forceinline__ int ld_flag(int* addr)
    {
        int flag;
        asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(addr));
        return flag;
    }

    __device__ __forceinline__ int next_flag(int flag)
    {
        return flag == 2 ? 0 : flag + 1;
    }

    __device__ __forceinline__ int prev_flag(int flag)
    {
        return flag == 0 ? 2 : flag - 1;
    }

public:
    int m_flag_value;

private:
    int* m_target_flag;
    int* m_current_flag;
};

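// add128 performs a lane-wise addition of two 128-bit packets interpreted as DType lanes
// (e.g. 8 half lanes in a float4). A minimal usage sketch with hypothetical values:
//   float4 a = ...;                 // 8 packed half values
//   float4 b = ...;
//   float4 c = add128<half>(a, b);  // per-lane half addition, no fp32 accumulation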
template <typename DType, typename PackedType>
__device__ __forceinline__ PackedType add128(PackedType const& a, PackedType const& b)
{
    static constexpr int kMathCount = sizeof(PackedType) / sizeof(DType);
    PackedType c;
#pragma unroll
    for (int i = 0; i < kMathCount; ++i)
    {
        reinterpret_cast<DType*>(&c)[i] = reinterpret_cast<DType const*>(&a)[i] + reinterpret_cast<DType const*>(&b)[i];
    }
    return c;
}

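// FusedOp is the compile-time-selected epilogue applied to each reduced 128-bit packet:
// optionally store the all-reduce result, add the residual (and optionally store it),
// apply RMSNorm (and optionally store it), then quantize to FP4 or FP8 or leave the value
// unquantized, as dictated by the Pattern traits (HasAllReduceOut, HasResidual, HasRMSNorm,
// GetQuantType, ...). update() re-loads the residual when the caller moves to a new access.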
template <AllReduceFusionPattern Pattern, typename DType>
class FusedOp
{
    static constexpr int kMathCount = sizeof(float4) / sizeof(DType);

public:
    __device__ __forceinline__ FusedOp(AllReduceFusionParams const& params, int access_id, int access_id_in_token)
        : m_params(params)
        , m_access_id(access_id)
        , m_access_id_in_token(access_id_in_token)
    {
        if constexpr (HasRMSNorm<Pattern>)
        {
            m_gamma_val = reinterpret_cast<float4*>(params.rms_gamma)[m_access_id_in_token];
        }
        if constexpr (HasResidual<Pattern>)
        {
            m_residual_val = reinterpret_cast<float4*>(params.residual_in)[m_access_id];
        }
        if constexpr (GetQuantType<Pattern> == QuantType::kFP8)
        {
            m_scale_factor = 1.f / *params.scale_factor;
        }
        else if constexpr (GetQuantType<Pattern> == QuantType::kFP4)
        {
            m_scale_factor = *params.scale_factor;
        }
    }

    __device__ __forceinline__ void update(int access_id)
    {
        if (m_access_id != access_id)
        {
            m_access_id = access_id;
            if constexpr (HasResidual<Pattern>)
            {
                m_residual_val = reinterpret_cast<float4*>(m_params.residual_in)[m_access_id];
            }
        }
    }

    __device__ __forceinline__ void operator()(float4 val, int token_id)
    {
        if constexpr (HasAllReduceOut<Pattern>)
        {
            reinterpret_cast<float4*>(m_params.allreduce_out)[m_access_id] = val;
        }
        if constexpr (HasResidual<Pattern>)
        {
            val = add128<DType>(val, m_residual_val);
            if constexpr (HasResidualOut<Pattern>)
            {
                reinterpret_cast<float4*>(m_params.residual_out)[m_access_id] = val;
            }
        }
        if constexpr (HasRMSNorm<Pattern>)
        {
            val = rms_norm(val, m_gamma_val);
            if constexpr (HasNormOut<Pattern>)
            {
                reinterpret_cast<float4*>(m_params.norm_out)[m_access_id] = val;
            }
        }
        if constexpr (GetQuantType<Pattern> == QuantType::kFP4)
        {
            constexpr int SF_VEC_SIZE = 16;
            using PackedVec = PackedVec<DType>;
            PackedVec pack_val = *reinterpret_cast<PackedVec const*>(&val);
            auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, 2>(std::nullopt, token_id, m_access_id_in_token,
                std::nullopt, m_params.hidden_dim / SF_VEC_SIZE, reinterpret_cast<uint32_t*>(m_params.scale_out),
                m_params.layout);
            reinterpret_cast<uint32_t*>(m_params.quant_out)[m_access_id]
                = cvt_warp_fp16_to_fp4<DType, SF_VEC_SIZE, false>(pack_val, m_scale_factor, sf_out);
        }
        else if constexpr (GetQuantType<Pattern> == QuantType::kFP8)
        {
            using PackedQuantizedType = std::conditional_t<std::is_same_v<DType, float>, float, float2>;
            PackedQuantizedType ret;
#pragma unroll
            for (int i = 0; i < kMathCount; ++i)
            {
                reinterpret_cast<__nv_fp8_e4m3*>(&ret)[i] = static_cast<__nv_fp8_e4m3>(
                    static_cast<float>(reinterpret_cast<DType*>(&val)[i]) * m_scale_factor);
            }
            reinterpret_cast<PackedQuantizedType*>(m_params.quant_out)[m_access_id] = ret;
        }
        else
        {
            static_assert(GetQuantType<Pattern> == QuantType::kNone, "Invalid quant type");
        }
    }

protected:
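    // rms_norm computes RMSNorm over one token: each thread accumulates the sum of squares
    // of its lanes, blockReduceSumV2 combines it across the block, and on SM90+ the partial
    // sums of all blocks in the cluster are combined through distributed shared memory
    // (cluster.map_shared_rank). Thread 0 then derives rsqrt(mean_square + eps), and every
    // lane is scaled by it and by the corresponding gamma lane.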
    __device__ __forceinline__ float4 rms_norm(float4 const& residual, float4 const& gamma)
    {
        __shared__ float s_val;
        float4 norm_out;
        float acc = 0.f;
#pragma unroll
        for (int i = 0; i < kMathCount; ++i)
        {
            float v = static_cast<float>(reinterpret_cast<DType const*>(&residual)[i]);
            acc += v * v;
        }
        tensorrt_llm::common::blockReduceSumV2<float, 1>(&acc);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        cg::cluster_group cluster = cg::this_cluster();
        if (cluster.num_blocks() > 1)
        {
            if (threadIdx.x == 0)
            {
                s_val = acc;
                acc = 0.f;
            }
            cluster.sync();
            if (threadIdx.x == 0)
            {
                for (int i = 0; i < cluster.num_blocks(); ++i)
                {
                    acc += *cluster.map_shared_rank(&s_val, i);
                }
            }
            cluster.sync();
        }
#endif
        if (threadIdx.x == 0)
        {
            s_val = rsqrtf(acc / m_params.hidden_dim + m_params.rms_eps);
        }
        __syncthreads();
#pragma unroll
        for (int i = 0; i < kMathCount; ++i)
        {
            reinterpret_cast<DType*>(&norm_out)[i]
                = static_cast<DType>(static_cast<float>(reinterpret_cast<DType const*>(&residual)[i]) * s_val
                    * static_cast<float>(reinterpret_cast<DType const*>(&gamma)[i]));
        }
        return norm_out;
    }

private:
    AllReduceFusionParams const& m_params;
    int m_access_id;
    int m_access_id_in_token;
    float m_scale_factor;
    float4 m_residual_val;
    float4 m_gamma_val;
};

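// The one-shot path uses negative zero (bit pattern 0x80000000) as its "not yet written"
// sentinel: Lamport buffers are cleared to -0.0f, and producers flush any genuine -0.0f
// input to +0.0f before pushing it, so a reader polling with is_neg_zero can tell real data
// from an empty slot. ld_global_volatile forces each poll to re-read memory rather than
// reuse a previously loaded value.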
__device__ __forceinline__ bool is_neg_zero(float v)
{
    return *reinterpret_cast<uint32_t*>(&v) == 0x80000000;
}

__device__ __forceinline__ bool is_neg_zero(float4 v)
{
    return is_neg_zero(v.x) || is_neg_zero(v.y) || is_neg_zero(v.z) || is_neg_zero(v.w);
}

__device__ __forceinline__ float4 get_neg_zero()
{
    float4 vec;
#pragma unroll
    for (int i = 0; i < 4; ++i)
    {
        reinterpret_cast<uint32_t*>(&vec)[i] = 0x80000000;
    }
    return vec;
}

__device__ __forceinline__ float4 ld_global_volatile(float4* addr)
{
    float4 val;
    asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];"
                 : "=f"(val.x), "=f"(val.y), "=f"(val.z), "=f"(val.w)
                 : "l"(addr));
    return val;
}

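// allreduce_sum reduces NRanks 128-bit packets into one. With Fp32Acc the DType lanes
// (kElemsPerAccess<DType> of them, e.g. 8 for half under the usual 16-byte access; this
// lane count is an assumption about kElemsPerAccess) are upcast to float, accumulated, and
// cast back; otherwise accumulation stays in DType via add128. Fp32Acc is only meaningful
// for 16-bit DTypes, which the static_assert enforces.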
template <typename DType, int NRanks, bool Fp32Acc>
__device__ __forceinline__ float4 allreduce_sum(float4* vals)
{
    if constexpr (Fp32Acc)
    {
        static_assert(!std::is_same_v<DType, float>);
        float acc_f32[kElemsPerAccess<DType>];
#pragma unroll
        for (int i = 0; i < kElemsPerAccess<DType>; ++i)
        {
            acc_f32[i] = static_cast<float>(reinterpret_cast<DType*>(&vals[0])[i]);
        }
#pragma unroll
        for (int r = 1; r < NRanks; ++r)
        {
#pragma unroll
            for (int i = 0; i < kElemsPerAccess<DType>; ++i)
            {
                acc_f32[i] += static_cast<float>(reinterpret_cast<DType*>(&vals[r])[i]);
            }
        }
        float4 acc;
#pragma unroll
        for (int i = 0; i < kElemsPerAccess<DType>; ++i)
        {
            reinterpret_cast<DType*>(&acc)[i] = static_cast<DType>(acc_f32[i]);
        }
        return acc;
    }
    else
    {
        float4 acc = vals[0];
#pragma unroll
        for (int r = 1; r < NRanks; ++r)
        {
            acc = add128<DType>(acc, vals[r]);
        }
        return acc;
    }
}

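// IndexHelper maps the launch geometry onto tokens: on SM90+ one thread-block cluster covers
// one token, cluster.thread_rank() selects the 16-byte access inside the token, and clusters
// stride over tokens; before SM90 a single block plays the cluster's role. access_id and
// access_stride are expressed in float4 (kElemsPerAccess<DType>-element) units.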
template <typename DType>
class IndexHelper
{
public:
    __device__ __forceinline__ IndexHelper(AllReduceFusionParams const& params)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        namespace cg = cooperative_groups;
        cg::cluster_group cluster = cg::this_cluster();
        cg::grid_group grid = cg::this_grid();
        token_id = grid.cluster_rank();
        access_id_in_token = cluster.thread_rank();
        token_stride = grid.num_clusters();
#else
        token_id = blockIdx.x;
        access_id_in_token = threadIdx.x;
        token_stride = gridDim.x;
#endif
        access_id = token_id * params.hidden_dim / kElemsPerAccess<DType> + access_id_in_token;
        access_stride = token_stride * params.hidden_dim / kElemsPerAccess<DType>;
        tot_access = params.size / kElemsPerAccess<DType>;
    }

    int token_id;
    int access_id_in_token;
    int token_stride;
    int access_id;
    int access_stride;
    int tot_access;
};

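// One-shot all-reduce (small token counts): every rank pushes its whole input slice into
// the Lamport buffer of every rank (flushing -0.0f inputs to +0.0f first), clears the slot
// the previous launch used, then polls its own buffer until all NRanks contributions are
// non-sentinel, reduces them, and applies the fused epilogue. The PDL hooks
// (cudaGridDependencySynchronize / cudaTriggerProgrammaticLaunchCompletion) let the kernel
// overlap with its producer/consumer when programmatic dependent launch is enabled.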
template <AllReduceFusionPattern Pattern, typename DType, int NRanks, bool Fp32Acc, bool TriggerCompletionAtEnd = true>
__global__ void __launch_bounds__(1024) allreduce_fusion_kernel_oneshot_lamport(AllReduceFusionParams params)
{
    IndexHelper<DType> index_helper(params);
    int token_id = index_helper.token_id;
    int access_id_in_token = index_helper.access_id_in_token;
    int token_stride = index_helper.token_stride;
    int access_id = index_helper.access_id;
    int access_stride = index_helper.access_stride;
    int tot_access = index_helper.tot_access;
    float4 clear_vec = get_neg_zero();
    FusedOp<Pattern, DType> fused_op(params, access_id, access_id_in_token);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize();
    if constexpr (!TriggerCompletionAtEnd)
    {
        cudaTriggerProgrammaticLaunchCompletion();
    }
#endif
    LamportComm<NRanks> comm(params.workspace, params.rank);
    int clear_access = comm.clear_size / kElemsPerAccess<DType>;

    for (int idx = access_id; idx < tot_access; idx += access_stride)
    {
        alignas(16) float val[4];
        *reinterpret_cast<float4*>(val) = reinterpret_cast<float4*>(params.allreduce_in)[idx];
#pragma unroll
        for (int i = 0; i < 4; ++i)
        {
            if (is_neg_zero(val[i]))
            {
                val[i] = 0.f;
            }
        }
#pragma unroll
        for (int r = 0; r < NRanks; ++r)
        {
            // Push the local slice to every rank's Lamport buffer
            reinterpret_cast<float4*>(comm.data_bufs[r])[params.rank * tot_access + idx]
                = *reinterpret_cast<float4*>(val);
        }
    }
    for (int idx = access_id; idx < clear_access; idx += access_stride)
    {
        // Clear the comm buffer slot that the previous launch used
        reinterpret_cast<float4*>(comm.clear_buf)[idx] = clear_vec;
    }

    for (int idx = access_id, tidx = token_id; idx < tot_access; idx += access_stride, tidx += token_stride)
    {
        fused_op.update(idx);
        float4 vals[NRanks];
        bool done = false;
        while (!done)
        {
            done = true;
#pragma unroll
            for (int r = 0; r < NRanks; ++r)
            {
                // LDG.128 from the local rank's buffer
                vals[r]
                    = ld_global_volatile(&reinterpret_cast<float4*>(comm.data_bufs[params.rank])[r * tot_access + idx]);
                done &= !is_neg_zero(vals[r]);
            }
        }
        float4 sum_val = allreduce_sum<DType, NRanks, Fp32Acc>(vals);
        fused_op(sum_val, tidx);
    }

    comm.update(params.size * NRanks);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    if constexpr (TriggerCompletionAtEnd)
    {
        cudaTriggerProgrammaticLaunchCompletion();
    }
#endif
}

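// Two-shot all-reduce (larger token counts): (1) each rank copies its input into its own
// comm buffer; barrier; (2) each rank reduces only the token range assigned to it
// (begin_tokens / token_num_per_ranks), reading from all ranks, and scatters the result
// into every rank's buffer at offset tot_access; barrier; (3) every rank applies the fused
// epilogue over all token ranges from the reduced region it now holds locally.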
template <AllReduceFusionPattern Pattern, typename DType, int NRanks, bool Fp32Acc>
__global__ void __launch_bounds__(1024) allreduce_fusion_kernel_twoshot_sync(
    AllReduceFusionParams params, std::array<int, NRanks> begin_tokens, std::array<int, NRanks> token_num_per_ranks)
{
    IndexHelper<DType> index_helper(params);
    int token_id = index_helper.token_id;
    int access_id_in_token = index_helper.access_id_in_token;
    int token_stride = index_helper.token_stride;
    int access_id = index_helper.access_id;
    int access_stride = index_helper.access_stride;
    int tot_access = index_helper.tot_access;
    FusedOp<Pattern, DType> fused_op(params, access_id, access_id_in_token);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize();
#endif
    SyncComm<NRanks> comm(params.workspace);
#pragma unroll
    for (int r = 0; r < NRanks; ++r)
    {
        int comm_access_id = access_id + begin_tokens[r] * params.hidden_dim / kElemsPerAccess<DType>;
        int comm_tot_access = (begin_tokens[r] + token_num_per_ranks[r]) * params.hidden_dim / kElemsPerAccess<DType>;
        for (int idx = comm_access_id; idx < comm_tot_access; idx += access_stride)
        {
            reinterpret_cast<float4*>(comm.comm_bufs[params.rank])[idx]
                = reinterpret_cast<float4*>(params.allreduce_in)[idx];
        }
    }
    Barrier<NRanks> barrier(params.rank, comm);
    barrier.sync();
    int comm_access_id = access_id + begin_tokens[params.rank] * params.hidden_dim / kElemsPerAccess<DType>;
    int comm_tot_access
        = (begin_tokens[params.rank] + token_num_per_ranks[params.rank]) * params.hidden_dim / kElemsPerAccess<DType>;
    for (int idx = comm_access_id; idx < comm_tot_access; idx += access_stride)
    {
        float4 vals[NRanks];
#pragma unroll
        for (int r = 0; r < NRanks; ++r)
        {
            vals[r] = reinterpret_cast<float4*>(comm.comm_bufs[r])[idx];
        }
        float4 sum_val = allreduce_sum<DType, NRanks, Fp32Acc>(vals);
#pragma unroll
        for (int r = 0; r < NRanks; ++r)
        {
            reinterpret_cast<float4*>(comm.comm_bufs[r])[tot_access + idx] = sum_val;
        }
    }
    barrier.sync();
#pragma unroll
    for (int r = 0; r < NRanks; ++r)
    {
        int comm_access_id = access_id + begin_tokens[r] * params.hidden_dim / kElemsPerAccess<DType>;
        int comm_token_id = token_id + begin_tokens[r];
        int comm_tot_access = (begin_tokens[r] + token_num_per_ranks[r]) * params.hidden_dim / kElemsPerAccess<DType>;
        for (int idx = comm_access_id, tidx = comm_token_id; idx < comm_tot_access;
             idx += access_stride, tidx += token_stride)
        {
            fused_op.update(idx);
            float4 sum_val = reinterpret_cast<float4*>(comm.comm_bufs[params.rank])[tot_access + idx];
            fused_op(sum_val, tidx);
        }
    }
    comm.update(barrier.m_flag_value);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

int get_sm_count()
{
    static int sm_count = 0;
    if (sm_count == 0)
    {
        int device_id;
        TLLM_CUDA_CHECK(cudaGetDevice(&device_id));
        cudaDeviceProp device_prop;
        cudaGetDeviceProperties(&device_prop, device_id);
        sm_count = device_prop.multiProcessorCount;
    }
    return sm_count;
}

template <AllReduceFusionPattern Pattern, typename DType, int NRanks, bool Fp32Acc, bool TriggerCompletionAtEnd = true>
void launch_oneshot_lamport(AllReduceFusionParams const& params, cudaLaunchConfig_t& cfg)
{
    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg,
        allreduce_fusion_kernel_oneshot_lamport<Pattern, DType, NRanks, Fp32Acc, TriggerCompletionAtEnd>, params));
}

template <AllReduceFusionPattern Pattern, typename DType, int NRanks, bool Fp32Acc>
void launch_twoshot_sync(AllReduceFusionParams const& params, cudaLaunchConfig_t& cfg,
    std::array<int, NRanks> begin_tokens, std::array<int, NRanks> token_num_per_ranks)
{
    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&cfg, allreduce_fusion_kernel_twoshot_sync<Pattern, DType, NRanks, Fp32Acc>,
        params, begin_tokens, token_num_per_ranks));
}

bool use_oneshot(int token_num)
{
    return token_num <= kOneShotMaxToken;
}

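// Launch heuristic: one cluster per token (one-shot) or per token of the local range
// (two-shot). cluster_size starts at 8 on SM90+ and is halved until it divides
// threads_per_token, the block holds at least 128 threads, and the grid fits into the
// available SMs; the grid is then rounded down to a multiple of cluster_size.
// A worked example, assuming kElemsPerAccess<half> == 8 (16 bytes / sizeof(half)):
//   hidden_dim = 8192, DType = half  ->  threads_per_token = 8192 / 8 = 1024
//   cluster_size = 8 divides 1024    ->  threads_per_block = 1024 / 8 = 128
// so each token is handled by a cluster of 8 blocks x 128 threads on SM90+, before the
// SM-occupancy adjustment.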
template <AllReduceFusionPattern Pattern, typename DType, int NRanks, bool Fp32Acc>
void allreduce_fusion_kernel_launcher(AllReduceFusionParams const& params)
{
    TLLM_CHECK(params.size % params.hidden_dim == 0);
    TLLM_CHECK(params.hidden_dim % kElemsPerAccess<DType> == 0);
    static int SM = tensorrt_llm::common::getSMVersion();
    int token_num = params.size / params.hidden_dim;
    bool oneshot = params.use_oneshot;
    int cluster_num = token_num;
    std::array<int, NRanks> begin_tokens, token_num_per_ranks;
    if (!oneshot)
    {
        int remaining_token = token_num % NRanks;
        int token_num_per_rank = token_num / NRanks;
        cluster_num = token_num_per_rank;
        if (remaining_token)
        {
            cluster_num++;
        }
        for (int r = 0; r < NRanks; ++r)
        {
            begin_tokens[r] = r * token_num_per_rank + (remaining_token > r ? r : remaining_token);
            token_num_per_ranks[r] = token_num_per_rank + (remaining_token > r ? 1 : 0);
        }
    }
    int threads_per_token = params.hidden_dim / kElemsPerAccess<DType>;
    int cluster_size;
    if (SM >= 90)
    {
        cluster_size = 8;
    }
    else
    {
        cluster_size = 1;
    }
    while (threads_per_token % cluster_size != 0 && cluster_size > 1)
    {
        cluster_size /= 2;
    }
    int threads_per_block = threads_per_token / cluster_size;
    while (threads_per_block < 128 && cluster_size >= 2)
    {
        threads_per_block *= 2;
        cluster_size /= 2;
    }
    int sm_count = get_sm_count();
    while (cluster_num * cluster_size > sm_count && cluster_size > 1 && threads_per_block <= 512)
    {
        threads_per_block *= 2;
        cluster_size /= 2;
    }
    TLLM_CHECK(oneshot || threads_per_block >= params.nranks);
    int block_size = threads_per_block;
    TLLM_CHECK(block_size <= 1024 && cluster_size > 0);

    int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size;
    cudaLaunchConfig_t cfg;
    cudaLaunchAttribute attribute[2];
    cfg.gridDim = grid_size;
    cfg.blockDim = block_size;
    cfg.dynamicSmemBytes = 0;
    cfg.stream = params.stream;
    attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attribute[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL() ? 1 : 0;
    attribute[1].id = cudaLaunchAttributeClusterDimension;
    attribute[1].val.clusterDim.x = cluster_size;
    attribute[1].val.clusterDim.y = 1;
    attribute[1].val.clusterDim.z = 1;
    cfg.attrs = attribute;
    cfg.numAttrs = SM >= 90 ? 2 : 0;
    if (oneshot)
    {
        bool trigger_completion_at_end = params.trigger_completion_at_end;
        if (trigger_completion_at_end)
        {
            launch_oneshot_lamport<Pattern, DType, NRanks, Fp32Acc, true>(params, cfg);
        }
        else
        {
            launch_oneshot_lamport<Pattern, DType, NRanks, Fp32Acc, false>(params, cfg);
        }
    }
    else
    {
        launch_twoshot_sync<Pattern, DType, NRanks, Fp32Acc>(params, cfg, begin_tokens, token_num_per_ranks);
    }
}

bool use_fp32_acc()
{
    // We use the fp16 accumulation type by default to keep aligned with NCCL; setting this
    // environment variable forces fp32 accumulation instead.
    static char* fp32_acc = std::getenv("ALL_REDUCE_FUSION_KERNEL_ACC_FP32");
    return fp32_acc != nullptr;
}

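// allreduce_fusion_op expands a macro cascade that dispatches on nranks (2/4/8/16), dtype
// (half / bf16 / float), fusion pattern, and accumulation type. FP4 quantization patterns
// are rejected for float inputs, and fp32 accumulation (selected via the
// ALL_REDUCE_FUSION_KERNEL_ACC_FP32 environment variable) only applies to 16-bit dtypes.
// Each DISPATCH_ACC_TYPE branch returns, so falling through to the final check means the
// requested configuration is unsupported.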
void allreduce_fusion_op(AllReduceFusionParams const& params)
{
#define DISPATCH_ACC_TYPE(DType, Pattern, NRanks) \
    if constexpr (std::is_same_v<DType, float>) \
    { \
        return allreduce_fusion_kernel_launcher<Pattern, DType, NRanks, false>(params); \
    } \
    else \
    { \
        if (fp32_acc) \
        { \
            return allreduce_fusion_kernel_launcher<Pattern, DType, NRanks, true>(params); \
        } \
        else \
        { \
            return allreduce_fusion_kernel_launcher<Pattern, DType, NRanks, false>(params); \
        } \
    }

#define DISPATCH_PATTERN(DType, NRanks) \
    if (params.pattern == AllReduceFusionPattern::kAllReduce) \
    { \
        DISPATCH_ACC_TYPE(DType, AllReduceFusionPattern::kAllReduce, NRanks); \
    } \
    else if (params.pattern == AllReduceFusionPattern::kARResidualRMSNorm) \
    { \
        DISPATCH_ACC_TYPE(DType, AllReduceFusionPattern::kARResidualRMSNorm, NRanks); \
    } \
    else if (params.pattern == AllReduceFusionPattern::kARResidualRMSNormFP8Quant) \
    { \
        DISPATCH_ACC_TYPE(DType, AllReduceFusionPattern::kARResidualRMSNormFP8Quant, NRanks); \
    } \
    else if (params.pattern == AllReduceFusionPattern::kARResidualRMSNormFP4Quant) \
    { \
        if constexpr (!std::is_same_v<DType, float>) \
        { \
            DISPATCH_ACC_TYPE(DType, AllReduceFusionPattern::kARResidualRMSNormFP4Quant, NRanks); \
        } \
        else \
        { \
            TLLM_CHECK_WITH_INFO(false, \
                "allreduce_fusion_kernel: AllReduceFusionPattern=kARResidualRMSNormFP4Quant can not work with " \
                "DType=float!"); \
        } \
    } \
    else if (params.pattern == AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant) \
    { \
        DISPATCH_ACC_TYPE(DType, AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant, NRanks); \
    } \
    else if (params.pattern == AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant) \
    { \
        if constexpr (!std::is_same_v<DType, float>) \
        { \
            DISPATCH_ACC_TYPE(DType, AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant, NRanks); \
        } \
        else \
        { \
            TLLM_CHECK_WITH_INFO(false, \
                "allreduce_fusion_kernel: AllReduceFusionPattern=kARResidualRMSNormOutFP4Quant can not work with " \
                "DType=float!"); \
        } \
    } \
    else \
    { \
        TLLM_CHECK_WITH_INFO(false, "allreduce_fusion_kernel: unsupported pattern!"); \
    }

#define DISPATCH_DTYPE(NRanks) \
    if (params.dtype == nvinfer1::DataType::kHALF) \
    { \
        DISPATCH_PATTERN(half, NRanks); \
    } \
    else if (params.dtype == nvinfer1::DataType::kBF16) \
    { \
        DISPATCH_PATTERN(__nv_bfloat16, NRanks); \
    } \
    else if (params.dtype == nvinfer1::DataType::kFLOAT) \
    { \
        DISPATCH_PATTERN(float, NRanks); \
    } \
    else \
    { \
        TLLM_CHECK_WITH_INFO(false, "allreduce_fusion_kernel: unsupported dtype!"); \
    }

#define DISPATCH_RANKS(NRanks) \
    if (params.nranks == NRanks) \
    { \
        DISPATCH_DTYPE(NRanks); \
    }

    bool fp32_acc = use_fp32_acc();
    DISPATCH_RANKS(2);
    DISPATCH_RANKS(4);
    DISPATCH_RANKS(8);
    DISPATCH_RANKS(16);
    TLLM_CHECK_WITH_INFO(false, "allreduce_fusion_kernel: unsupported ranks number!");
}
} // namespace kernels::ar_fusion

TRTLLM_NAMESPACE_END