TensorRT-LLMs/cpp/tensorrt_llm/kernels/userbuffers/userbuffers.cu
石晓伟 548b5b7310
Update TensorRT-LLM (#2532)
* blossom-ci.yml: run vulnerability scan on blossom

* open source efb18c1256f8c9c3d47b7d0c740b83e5d5ebe0ec

---------

Co-authored-by: niukuo <6831097+niukuo@users.noreply.github.com>
Co-authored-by: pei0033 <59505847+pei0033@users.noreply.github.com>
Co-authored-by: Kyungmin Lee <30465912+lkm2835@users.noreply.github.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
2024-12-04 21:16:56 +08:00

1224 lines
58 KiB
Plaintext

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "userbuffers.h"
#include "utils.h"
#include <assert.h>
#include <stdio.h>
namespace tensorrt_llm::kernels::ub
{
using namespace tensorrt_llm::runtime::ub;
// Upper bound on threads per block; used as the __launch_bounds__ of every kernel below.
#define MAX_THREADS 1024
// Spin-wait budget, in clock64() cycles, after which a flag-polling loop prints a diagnostic and stops
// waiting. Some kernels compile this check only under UB_TIMEOUT_ENABLED, and some use 2 * TIMEOUT.
#define TIMEOUT 200000000000ull
// In-place sum all-reduce over RANKS peers ("rw" variant). Each rank takes a strided subset of 16-byte
// (int4) lines, reads that line from every peer's buffer, sums the values as DType, and writes the sum back
// into EVERY peer's buffer, so no separate all-gather phase is needed.
// Template parameters:
//   DType - element type; each int4 line is summed as 8 DType values (presumably a 16-bit type such as
//           half / __nv_bfloat16 - confirm against the instantiations).
//   RANKS - number of participating peers; the `& (RANKS - 1)` rotation assumes a power of two.
// Arguments:
//   op                   - not used by this kernel.
//   flagoffset           - int offset of this operation's flag area inside commbuff[rank].
//   firstrank, gpustep   - logical rank r maps to commbuff index r * gpustep + firstrank.
//   myrank               - this rank's logical index.
//   lineoffset, numlines - first int4 line and number of lines to reduce.
//   commbuff, handleridx - table of peer pointers; commbuff[idx + handleridx] is peer idx's data buffer.
// Assumes blockDim.x >= RANKS: threads 0..RANKS-1 each run one side of the cross-rank barriers.
template <typename DType, int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rw(int const op, int const flagoffset, int const firstrank, int const myrank,
int const gpustep, size_t const lineoffset, int const numlines, void** commbuff, int const handleridx)
{
// Data-buffer pointer of each peer, filled in by threads 0..RANKS-1 during barrier 1.
__shared__ int4* userptr[RANKS];
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
// Barrier 1: thread t (< RANKS) signals peer t by writing reduce_id into this rank's slot of peer t's
// flag array, then spins until peer t has written the same value into this rank's flag array.
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
// Blocks use separate flag areas spaced MAX_NVLINK * 2 ints apart.
int const blockflagoffset = MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int*>(commbuff[physgpu])) + flagoffset;
// The operation counter lives MAX_OPS ints before this operation's flag area.
reduceidptr = myptr - MAX_OPS;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
flagptr[physgpu] = reduce_id;
int volatile* flag = (int volatile*) &(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4*>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
// On timeout a diagnostic is printed and the wait is abandoned; the kernel then proceeds anyway.
while (*flag < reduce_id)
{
if (clock64() - s > TIMEOUT)
{
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag);
break;
}
}
// Barrier 2 below waits for the next counter value.
reduce_id++;
}
__syncthreads();
// Peer visiting order is rotated by rank and by a block+warp index.
// Presumably this spreads concurrent traffic across peers - confirm.
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++)
dest[i] = (i + myrank + warp) & (RANKS - 1);
__syncthreads();
// This rank owns lines threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x) + k * blockDim.x * gridDim.x * RANKS,
// i.e. ranks interleave in chunks of blockDim.x lines.
for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines;
line += blockDim.x * gridDim.x * RANKS)
{
int4 val[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++)
{
val[i] = userptr[dest[i]][lineoffset + line];
}
int4 sum = val[0];
DType* s = reinterpret_cast<DType*>(&sum);
// Element-wise sum of the 8 DType values packed in each int4.
#pragma unroll
for (int i = 1; i < RANKS; i++)
{
DType* x = reinterpret_cast<DType*>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++)
s[j] += x[j];
}
// Broadcast the result: write it back into every peer's buffer.
#pragma unroll
for (int i = 0; i < RANKS; i++)
{
userptr[dest[i]][lineoffset + line] = sum;
}
}
__syncthreads();
// Order the remote stores above (system scope) before the flag writes of barrier 2.
if (threadIdx.x == 0)
__threadfence_system();
__syncthreads();
// Barrier 2: every peer has signalled that its stores into the buffers are done.
if (threadIdx.x < RANKS)
{
flagptr[physgpu] = reduce_id;
int volatile* flag = (int volatile*) &myptr[targetgpu];
clock_t s = clock64();
while (*flag < reduce_id)
{
if (clock64() - s > TIMEOUT)
{
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag);
break;
}
}
}
// Persist the last counter value used; the next launch starts from it (read at the top of the kernel).
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
} // fp16 inplace reduce kernel (Volta,Hopper)
// In-place sum all-reduce over RANKS peers ("rr" variant), done as reduce-scatter + all-gather:
//   phase 1 - each rank sums its own strided slice of 16-byte (int4) lines, read from every peer's buffer,
//             and stores the result only into its OWN buffer;
//   phase 2 - after a second cross-rank barrier, each rank copies the slices reduced by the other ranks
//             from their buffers into its own buffer.
// Template parameters and arguments have the same meaning as in userbuffers_fp16_sum_inplace_gpu_rw
// (DType summed 8 per int4 line, RANKS assumed a power of two, `op` unused). Assumes blockDim.x >= RANKS.
// Compiling with ALLREDUCEONLYRS stops after phase 1 (reduce-scatter only).
template <typename DType, int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
    userbuffers_fp16_sum_inplace_gpu_rr(int const op, int const flagoffset, int const firstrank, int const myrank,
        int const gpustep, size_t const lineoffset, int const numlines, void** commbuff, int const handleridx)
{
    // Data-buffer pointer of each peer, filled in during barrier 1.
    __shared__ int4* userptr[RANKS];
    int *flagptr, physgpu, targetgpu, *myptr;
    int *reduceidptr, reduce_id;
    // Barrier 1: thread t (< RANKS) writes reduce_id into this rank's slot of peer t's flag array, then
    // spins until peer t has done the same for us.
    if (threadIdx.x < RANKS)
    {
        physgpu = myrank * gpustep + firstrank;
        targetgpu = threadIdx.x * gpustep + firstrank;
        int const blockflagoffset = MAX_NVLINK * 2 * blockIdx.x;
        myptr = (reinterpret_cast<int*>(commbuff[physgpu])) + flagoffset;
        reduceidptr = myptr - MAX_OPS;
        reduce_id = (*reduceidptr) + 1;
        flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
        myptr += blockflagoffset;
        flagptr[physgpu] = reduce_id;
        int volatile* flag = (int volatile*) &(myptr[targetgpu]);
        userptr[threadIdx.x] = reinterpret_cast<int4*>(commbuff[targetgpu + handleridx]);
        clock_t s = clock64();
        while (*flag < reduce_id)
        {
            if (clock64() - s > TIMEOUT)
            {
                printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag);
                break;
            }
        }
        reduce_id++;
    }
    __syncthreads();
    int warp = blockIdx.x + (threadIdx.x >> 5);
    int dest[RANKS];
#pragma unroll
    for (int i = 0; i < RANKS; i++)
        dest[i] = (i + myrank + warp) & (RANKS - 1);
    __syncthreads();
    // Phase 1 (reduce-scatter): this rank owns lines
    //   threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x) + k * blockDim.x * gridDim.x * RANKS.
    for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines;
         line += blockDim.x * gridDim.x * RANKS)
    {
        int4 val[RANKS];
#pragma unroll
        for (int i = 0; i < RANKS; i++)
        {
            val[i] = userptr[dest[i]][lineoffset + line];
        }
        int4 sum = val[0];
        DType* s = reinterpret_cast<DType*>(&sum);
#pragma unroll
        for (int i = 1; i < RANKS; i++)
        {
            DType* x = reinterpret_cast<DType*>(&val[i]);
#pragma unroll
            for (int j = 0; j < 8; j++)
                s[j] += x[j];
        }
        userptr[myrank][lineoffset + line] = sum;
    }
#ifdef ALLREDUCEONLYRS
    if (threadIdx.x == 0 && blockIdx.x == 0)
        *reduceidptr = reduce_id;
    return;
#endif
    __syncthreads();
    // Peers read our reduced slice directly from this GPU's memory once they observe the flag written in
    // barrier 2. The stores must therefore be ordered at system scope, as the _rw and _mc kernels do; a
    // GPU-scope __threadfence() only orders them for observers on this device.
    if (threadIdx.x == 0)
        __threadfence_system();
    __syncthreads();
    // Barrier 2: every peer has finished phase 1.
    if (threadIdx.x < RANKS)
    {
        flagptr[physgpu] = reduce_id;
        int volatile* flag = (int volatile*) &myptr[targetgpu];
        clock_t s = clock64();
        while (*flag < reduce_id)
        {
            if (clock64() - s > TIMEOUT)
            {
                printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag);
                break;
            }
        }
    }
    // Rebuild the rotated peer list without this rank: dest[0 .. RANKS-2] are the other ranks.
    int skipmy = 0;
#pragma unroll
    for (int i = 0; i < RANKS; i++)
    {
        int dst = (i + warp + myrank) & (RANKS - 1);
        if (dst == myrank)
        {
            skipmy++;
            continue;
        }
        dest[i - skipmy] = dst;
    }
    __syncthreads();
    // Phase 2 (all-gather): `line` walks the chunk bases shared by all ranks. Rank r's phase-1 line for a
    // given base is line + blockDim.x * r. The loop condition bounds only the base, so each per-rank index
    // is checked too. Otherwise a numlines that is not a multiple of blockDim.x * RANKS reads past the
    // reduced range and overwrites lines of this rank's buffer beyond numlines.
    for (int line = threadIdx.x + blockDim.x * RANKS * blockIdx.x; line < numlines;
         line += blockDim.x * gridDim.x * RANKS)
    {
        int4 val[RANKS - 1];
#pragma unroll
        for (int i = 0; i < RANKS - 1; i++)
        {
            int const src = line + static_cast<int>(blockDim.x) * dest[i];
            if (src < numlines)
                val[i] = userptr[dest[i]][lineoffset + src];
        }
#pragma unroll
        for (int i = 0; i < RANKS - 1; i++)
        {
            int const src = line + static_cast<int>(blockDim.x) * dest[i];
            if (src < numlines)
                userptr[myrank][lineoffset + src] = val[i];
        }
    }
    // Persist the last counter value used; the next launch starts from it.
    if (threadIdx.x == 0 && blockIdx.x == 0)
        *reduceidptr = reduce_id;
} // fp16 inplace reduce kernel (Ampere)
// ATOMIC_CONSUMER(chunk): acts only when `counters` (a name that must be in scope at the expansion site) is
// non-null. Thread 0 of block 0 spins until the 32-bit counter at index `chunk` reads 0, using
// atomicCAS(ptr, 0, 0) as an atomic load. It then sets that counter to 1 and issues a GPU-scope
// sequentially-consistent fence. All threads of block 0 then meet at __syncthreads(); other blocks do not
// wait here. Expands to a bare `if` statement (not wrapped in do { } while (0)).
#define ATOMIC_CONSUMER(chunk) \
if (counters) \
{ \
if (threadIdx.x == 0 && blockIdx.x == 0) \
{ \
int old_val; \
while (0 != (old_val = atomicCAS(((unsigned int*) counters) + chunk, 0, 0))) \
{ \
} \
((unsigned int*) counters)[chunk] = 1; \
asm volatile("fence.sc.gpu;\n"); \
} \
if (blockIdx.x == 0) \
__syncthreads(); \
}
// ATOMIC_PRODUCER(chunk): when `counters` (must be in scope at the expansion site) is non-null, stores 0
// into counter `chunk` and issues a GPU-scope sequentially-consistent fence. There is no thread or block
// guard, so every thread that reaches the expansion performs the store.
// ATOMIC_CONSUMER waits for this counter to read 0.
#define ATOMIC_PRODUCER(chunk) \
if (counters) \
{ \
((unsigned int*) counters)[chunk] = 0; \
asm volatile("fence.sc.gpu;\n"); \
}
#if __CUDA_ARCH__ >= 900
// Multicast store of one 16-byte vector through the multimem address `ptr` (PTX multimem.st). The four
// 32-bit lanes val.x..val.w are bound as "r" registers. Only compiled for SM90+ in this file.
template <typename VecT, typename AddrT>
__device__ __forceinline__ void MULTIMEM_ST(VecT val, AddrT ptr)
{
    asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};"
                 :
                 : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)
                 : "memory");
}
// Multicast store of one 8-byte vector (two 32-bit lanes val.x, val.y) through the multimem address `ptr`
// (PTX multimem.st). Only compiled for SM90+ in this file.
template <typename VecT, typename AddrT>
__device__ __forceinline__ void MULTIMEM_ST2(VecT& val, AddrT ptr)
{
    asm volatile("multimem.st.global.v2.f32 [%0], {%1,%2};"
                 :
                 : "l"(ptr), "r"(val.x), "r"(val.y)
                 : "memory");
}
// Multicast load-reduce (PTX multimem.ld_reduce ... add) of the 16-byte vector at multimem address `ptr`.
// The element-wise sum is written into the four 32-bit words val.x..val.w. Packing is f16x2 when
// DType == half, or bf16x2 when DType == __nv_bfloat16 (only when ENABLE_BF16 is defined).
// For any other DType no instruction is emitted and `val` is left untouched.
template <typename DType, typename VecT, typename AddrT>
__device__ __forceinline__ void MULTIMEM_LD(VecT& val, AddrT ptr)
{
    if constexpr (std::is_same_v<DType, half>)
    {
        asm("multimem.ld_reduce.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
            : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
            : "l"(ptr)
            : "memory");
    }
#ifdef ENABLE_BF16
    else if constexpr (std::is_same_v<DType, __nv_bfloat16>)
    {
        asm("multimem.ld_reduce.global.add.v4.bf16x2 {%0,%1,%2,%3}, [%4];"
            : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
            : "l"(ptr)
            : "memory");
    }
#endif
}
// All MC kernels here
// Multicast (MC) in-place sum all-reduce over RANKS peers, SM90+ only (defined inside the
// `__CUDA_ARCH__ >= 900` branch). Each rank takes a strided subset of 16-byte lines. For each line, one
// MULTIMEM_LD (multimem.ld_reduce) through `mc_ptr` yields the element-wise sum across the bound buffers,
// and one MULTIMEM_ST through the same address stores it back.
// Arguments mirror userbuffers_fp16_sum_inplace_gpu_rw. `op` and `handleridx` are not used here; the data is
// reached only through mc_ptr, while commbuff is used for the barrier flags. Assumes blockDim.x >= RANKS.
// The barrier timeout checks are compiled only when UB_TIMEOUT_ENABLED is defined.
template <typename DType, int RANKS>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc(int const op, int const flagoffset,
int const firstrank, int const myrank, int const gpustep, size_t const lineoffset, int const numlines,
void** commbuff, int const handleridx, float4* mc_ptr)
{
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
// Barrier 1: thread t (< RANKS) signals peer t and waits for peer t's signal (same protocol as the _rw kernel).
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
int const blockflagoffset = MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int*>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - MAX_OPS;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
flagptr[physgpu] = reduce_id;
int volatile* flag = (int volatile*) &(myptr[targetgpu]);
#ifdef UB_TIMEOUT_ENABLED
clock_t s = clock64();
#endif
while (*flag < reduce_id)
{
#ifdef UB_TIMEOUT_ENABLED
if (clock64() - s > TIMEOUT)
{
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag);
break;
}
#endif
}
reduce_id++;
}
__syncthreads();
// UNROLL_MC is a preprocessor macro, so it stays defined for the rest of the translation unit.
#define UNROLL_MC 8
int const loop_step0 = blockDim.x * gridDim.x * RANKS;
int const loop_step = loop_step0 * UNROLL_MC;
int const start_elem = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x);
int const end_elem = max(start_elem, numlines);
// Largest multiple of loop_step lines (from start_elem) handled by the unrolled main loop.
int const aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step;
int const end_aligned = start_elem + aligned_elem;
// Main loop: UNROLL_MC lines per thread per iteration, spaced loop_step0 apart; all loads are issued before the stores.
for (int line = start_elem; line < end_aligned; line += loop_step)
{
uint4 val[UNROLL_MC];
#pragma unroll
for (int i = 0; i < UNROLL_MC; i++)
MULTIMEM_LD<DType>(val[i], mc_ptr + (lineoffset + line + i * loop_step0));
#pragma unroll
for (int i = 0; i < UNROLL_MC; i++)
MULTIMEM_ST(val[i], mc_ptr + (lineoffset + line + i * loop_step0));
}
// Tail: remaining lines of this thread, one at a time.
for (int line = end_aligned; line < end_elem; line += loop_step0)
{
uint4 val;
MULTIMEM_LD<DType>(val, mc_ptr + (lineoffset + line));
MULTIMEM_ST(val, mc_ptr + (lineoffset + line));
}
__syncthreads();
// Order the multicast stores (system scope) before the barrier-2 flag writes.
if (threadIdx.x == 0)
__threadfence_system();
__syncthreads();
// Barrier 2 (timeout budget is 2 * TIMEOUT here).
if (threadIdx.x < RANKS)
{
flagptr[physgpu] = reduce_id;
int volatile* flag = (int volatile*) &myptr[targetgpu];
#ifdef UB_TIMEOUT_ENABLED
clock_t s = clock64();
#endif
while (*flag < reduce_id)
{
#ifdef UB_TIMEOUT_ENABLED
if (clock64() - s > 2ull * TIMEOUT)
{
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag);
break;
}
#endif
}
}
// Persist the last counter value used for the next launch.
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
} // fp16 inplace reduce kernel (Hopper) MC
#else
// Placeholder compiled for SM < 90, where the multimem instructions of the real kernel do not exist.
// It prints a message and then executes a `brkpt` instruction. The signature matches the SM90 version.
template <typename DType, int RANKS>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc(
    int const op, int const flagoffset, int const firstrank, int const myrank, int const gpustep,
    size_t const lineoffset, int const numlines, void** commbuff, int const handleridx, float4* mc_ptr)
{
    printf("userbuffer based kernels not implemented when SM < 90\n");
    asm volatile("brkpt;\n");
}
#endif
// callranks(x): when ar_nvsize == x, launches the RANKS = x instantiation of the non-multicast in-place
// all-reduce through cudaLaunchKernelExC with launch config `cfg`. The kernel is
// userbuffers_fp16_sum_inplace_gpu_rr if comm->use_rr_kernel, otherwise userbuffers_fp16_sum_inplace_gpu_rw.
// Names required at the expansion site: DType, op, comm, ar_firstgpu, ar_nvrank, ar_step, ar_nvsize,
// offset, elements, handler, cfg.
// offset and elements are divided by 8 to get int4 line units (8 elements per line, presumably 2-byte
// elements). arg2 becomes the kernel's flagoffset. Expands to a bare `if` statement.
#define callranks(x) \
if (ar_nvsize == x) \
{ \
int arg1 = op - MAX_OPS, \
arg2 = REG0_OFFSET(comm) - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * REG0_SINGLENODE + MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step; \
size_t arg6 = offset / 8; \
int arg7 = elements / 8; \
void** arg8 = (void**) (comm->gpu_ptrs); \
int arg9 = handler * comm->nvsize; \
void* kernelArgs[] \
= {reinterpret_cast<void*>(&arg1), reinterpret_cast<void*>(&arg2), reinterpret_cast<void*>(&arg3), \
reinterpret_cast<void*>(&arg4), reinterpret_cast<void*>(&arg5), reinterpret_cast<void*>(&arg6), \
reinterpret_cast<void*>(&arg7), reinterpret_cast<void*>(&arg8), reinterpret_cast<void*>(&arg9)}; \
TLLM_CUDA_CHECK(cudaLaunchKernelExC(&cfg, \
(void*) (comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr<DType, x> \
: userbuffers_fp16_sum_inplace_gpu_rw<DType, x>), \
kernelArgs)); \
}
// callranksMC(x): when ar_nvsize == x, launches userbuffers_fp16_sum_inplace_gpu_mc<DType, x> through
// cudaLaunchKernelExC with launch config `cfg`. The arguments are those of callranks plus the multicast
// pointer comm->mc_ptr[handler].
// Names required at the expansion site: DType, op, comm, ar_firstgpu, ar_nvrank, ar_step, ar_nvsize,
// offset, elements, handler, cfg. Expands to a bare `if` statement.
#define callranksMC(x) \
if (ar_nvsize == x) \
{ \
int arg1 = op - MAX_OPS, \
arg2 = REG0_OFFSET(comm) - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * REG0_SINGLENODE + MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step; \
size_t arg6 = offset / 8; \
int arg7 = elements / 8; \
void** arg8 = (void**) (comm->gpu_ptrs); \
int arg9 = handler * comm->nvsize; \
void* arg10 = comm->mc_ptr[handler]; \
void* kernelArgs[] = {reinterpret_cast<void*>(&arg1), reinterpret_cast<void*>(&arg2), \
reinterpret_cast<void*>(&arg3), reinterpret_cast<void*>(&arg4), reinterpret_cast<void*>(&arg5), \
reinterpret_cast<void*>(&arg6), reinterpret_cast<void*>(&arg7), reinterpret_cast<void*>(&arg8), \
reinterpret_cast<void*>(&arg9), reinterpret_cast<void*>(&arg10)}; \
TLLM_CUDA_CHECK( \
cudaLaunchKernelExC(&cfg, (void*) (userbuffers_fp16_sum_inplace_gpu_mc<DType, x>), kernelArgs)); \
}
// SETUP_LAUNCH_CONFIG(sms, threads, stream): declares `cfg` and its attribute array `attribute_ub` for use
// with cudaLaunchKernelExC. `cfg` is a cudaLaunchConfig_t for `sms` blocks of `threads` threads, with no
// dynamic shared memory, on `stream`. Reads comm->cga_size, comm->pdl_launch and comm->sm_arch from a
// `comm` in scope at the expansion site.
//   attribute 0: cooperative launch (always passed);
//   attribute 1: programmatic stream serialization = comm->pdl_launch (passed when sm_arch >= 9);
//   attribute 2: 1-D cluster of comm->cga_size blocks if that divides sms, else 1 (passed when sm_arch >= 9).
// The array is value-initialized and val.cooperative is set explicitly. Assigning only the attribute id
// would leave the cooperative flag indeterminate, so whether the launch was cooperative would depend on
// stack contents.
#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \
    cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
    cudaLaunchAttribute attribute_ub[3] = {}; \
    attribute_ub[2].id = cudaLaunchAttributeClusterDimension; \
    attribute_ub[2].val.clusterDim.x = (sms) % comm->cga_size == 0 ? comm->cga_size : 1; \
    attribute_ub[2].val.clusterDim.y = 1; \
    attribute_ub[2].val.clusterDim.z = 1; \
    attribute_ub[0].id = cudaLaunchAttributeCooperative; \
    attribute_ub[0].val.cooperative = 1; \
    attribute_ub[1].id = cudaLaunchAttributeProgrammaticStreamSerialization; \
    attribute_ub[1].val.programmaticStreamSerializationAllowed = comm->pdl_launch; \
    cfg.attrs = attribute_ub; \
    cfg.numAttrs = comm->sm_arch >= 9 ? 3 : 1;
// RMSNorm affine step for one element: val * s_variance * gamma[i], plus beta[i] when a bias pointer is given.
// `s_variance` is the inverse-RMS scale supplied by the caller. All arithmetic is done in float.
template <typename DType>
__inline__ __device__ float compute_rmsnorm2(float val, float s_variance, DType const* gamma, DType const* beta, int i)
{
    float const scaled = val * s_variance * static_cast<float>(gamma[i]);
    return beta == nullptr ? scaled : scaled + static_cast<float>(beta[i]);
}
// shard_tokens(ntokens, nranks, myrank): splits `ntokens` rows as evenly as possible across `nranks` ranks.
// It declares two ints in the enclosing scope: `my_tokens`, the number of rows owned by rank `myrank`, and
// `first_token`, the index of its first row. The first (ntokens % nranks) ranks get one extra row, so shards
// are contiguous and ordered by rank.
// Every argument is parenthesized so expression arguments (e.g. `a + b`) expand with the intended
// precedence; unparenthesized, `a + b / nranks` was computed.
#define shard_tokens(ntokens, nranks, myrank) \
    int first_token = 0, my_tokens; \
    { \
        int remapped_rank = (myrank); \
        my_tokens = (ntokens) / (nranks); \
        int extra_tokens = (ntokens) % (nranks); \
        first_token = remapped_rank * my_tokens; \
        first_token += remapped_rank < extra_tokens ? remapped_rank : extra_tokens; \
        if (remapped_rank < extra_tokens) \
            my_tokens++; \
    }
#if __CUDA_ARCH__ >= 900
// Fused multicast all-reduce + optional residual add + RMSNorm + FP8 (e4m3) quantization, "two-shot" form.
// The kernel handles only lines [lineoffset, lineoffset + numlines) of the shared buffer, multicasts its
// quantized output through mc_ptr_out, and ends with a cross-rank barrier.
// Row layout: one row is blockDim.x * UNROLL_NLINES int4 lines (hidden_dim elements; presumably one token -
// confirm). In each loop iteration a block processes one row, and thread t holds lines
// t + g * blockDim.x, g < UNROLL_NLINES.
// Per row:
//   - sum across peers via MULTIMEM_LD;
//   - if residual_in != nullptr, add residual_in and store the sum to residual_out (residual_out is then
//     assumed non-null);
//   - s_variance = rsqrtf(mean(x^2) + eps), then y = x * s_variance * gamma (+ beta when beta != nullptr);
//   - y * (1 / *scale) is converted to __nv_fp8_e4m3 and multicast-stored, 8 bytes per line, at
//     out_lineoffset + line.
// Assumes numlines is a multiple of the row size, so every thread of a block runs the same number of loop
// iterations (the body contains blockReduceSumV2 and __syncthreads()). The residual add uses 8 elements per
// int4, which presumes a 2-byte DType. `op` and `handleridx` are not used. SM90+ only.
template <typename DType, int UNROLL_NLINES>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant(int const op,
int const flagoffset, int const firstrank, int const myrank, int const gpustep, size_t const lineoffset,
int const numlines, void** commbuff, int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma,
float const eps, int const RANKS, float2* mc_ptr_out, size_t const out_lineoffset, float const* scale,
uint4* residual_in, uint4* residual_out, int res_offset)
{
// Quantization multiplier: outputs are divided by *scale.
float const sf = 1.f / (*scale);
__shared__ float s_variance;
int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
// Barrier 1: same flag protocol as the other kernels (thread t < RANKS pairs with peer t).
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
int const blockflagoffset = MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int*>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - MAX_OPS;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
flagptr[physgpu] = reduce_id;
int volatile* flag = (int volatile*) &(myptr[targetgpu]);
#ifdef UB_TIMEOUT_ENABLED
clock_t s = clock64();
#endif
while (*flag < reduce_id)
{
#ifdef UB_TIMEOUT_ENABLED
if (clock64() - s > TIMEOUT)
{
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag);
break;
}
#endif
}
reduce_id++;
}
__syncthreads();
int const loop_step0 = blockDim.x;
int const loop_step = loop_step0 * UNROLL_NLINES * gridDim.x;
int const start_elem = threadIdx.x + blockDim.x * blockIdx.x * UNROLL_NLINES;
int const end_elem = max(start_elem, numlines);
for (int line = start_elem; line < end_elem; line += loop_step)
{
uint4 val[UNROLL_NLINES];
// x views val[] as a flat array of UNROLL_NLINES * 8 DType values.
DType* x = reinterpret_cast<DType*>(&val[0]);
#pragma unroll
for (int i = 0; i < UNROLL_NLINES; i++)
MULTIMEM_LD<DType>(val[i], mc_ptr + (lineoffset + line + i * loop_step0));
// Optional residual add; the pre-norm sum is written out as the new residual.
if (residual_in != nullptr)
{
#pragma unroll
for (int i = 0; i < UNROLL_NLINES; i++)
{
uint4 resval = residual_in[res_offset + line + i * loop_step0];
DType* y = reinterpret_cast<DType*>(&resval);
#pragma unroll
for (int j = 0; j < 8; j++)
x[i * 8 + j] += y[j];
residual_out[res_offset + line + i * loop_step0] = val[i];
}
}
float local_var_sum = 0.0f;
for (int j = 0; j < UNROLL_NLINES * sizeof(int4) / sizeof(DType); j++)
local_var_sum += (float) (x[j]) * (float) (x[j]);
float packed[1] = {local_var_sum};
blockReduceSumV2<float, 1>(packed);
float variance = packed[0];
// Thread 0 turns the row's sum of squares into the inverse-RMS factor shared with the block.
if (threadIdx.x == 0)
{
variance = (variance / hidden_dim); // Var[x] = E[x²]
s_variance = rsqrtf(variance + eps);
}
__syncthreads();
int i = 0;
// 8 FP8 outputs (8 bytes) per input int4 line.
uint2 valout;
__nv_fp8_e4m3* y = reinterpret_cast<__nv_fp8_e4m3*>(&valout);
#pragma unroll
for (int g = 0; g < UNROLL_NLINES; g++)
{
#pragma unroll
for (int j = 0; j < sizeof(int4) / sizeof(DType); j++)
{
y[j] = cuda_cast<__nv_fp8_e4m3>(sf
* compute_rmsnorm2<DType>((float) x[i], s_variance, gamma, beta,
(threadIdx.x + g * loop_step0) * sizeof(int4) / sizeof(DType) + j));
i++;
}
MULTIMEM_ST2(valout, mc_ptr_out + (out_lineoffset + line + g * loop_step0));
}
}
__syncthreads();
// Order the multicast stores (system scope) before the barrier-2 flag writes.
if (threadIdx.x == 0)
__threadfence_system();
__syncthreads();
// Barrier 2 (timeout budget 2 * TIMEOUT).
if (threadIdx.x < RANKS)
{
flagptr[physgpu] = reduce_id;
int volatile* flag = (int volatile*) &myptr[targetgpu];
#ifdef UB_TIMEOUT_ENABLED
clock_t s = clock64();
#endif
while (*flag < reduce_id)
{
#ifdef UB_TIMEOUT_ENABLED
if (clock64() - s > 2ull * TIMEOUT)
{
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag);
break;
}
#endif
}
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
} // quant kernel fp16->fp8 twoshot
// Fused multicast all-reduce + optional residual add + RMSNorm + FP8 (e4m3) quantization, "one-shot" form.
// The loop start does not depend on myrank, so every rank processes all `numlines` lines itself. The
// quantized output is written with ordinary stores to mc_ptr_out: no multicast and no trailing barrier.
// Only one cross-rank barrier runs (before the reduction), so the counter at reduceidptr advances by one
// per launch.
// Row layout, residual handling and normalization are the same as in the two-shot kernel: one row =
// blockDim.x * UNROLL_NLINES lines per block iteration, residual_out assumed non-null when residual_in is,
// 8 elements per int4. The same precondition holds: numlines is assumed to be a multiple of the row size
// because the loop contains blockReduceSumV2 and __syncthreads(). `op` and `handleridx` are not used.
// SM90+ only.
template <typename DType, int UNROLL_NLINES>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant_oneshot(int const op,
int const flagoffset, int const firstrank, int const myrank, int const gpustep, size_t const lineoffset,
int const numlines, void** commbuff, int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma,
float const eps, int const RANKS, uint2* mc_ptr_out, size_t const out_lineoffset, float const* scale,
uint4* residual_in, uint4* residual_out, int res_offset)
{
float const sf = 1.f / (*scale);
__shared__ float s_variance;
int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
// The only barrier: all peers have written their inputs before anyone reduces.
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
int const blockflagoffset = MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int*>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - MAX_OPS;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
flagptr[physgpu] = reduce_id;
int volatile* flag = (int volatile*) &(myptr[targetgpu]);
#ifdef UB_TIMEOUT_ENABLED
clock_t s = clock64();
#endif
while (*flag < reduce_id)
{
#ifdef UB_TIMEOUT_ENABLED
if (clock64() - s > TIMEOUT)
{
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag);
break;
}
#endif
}
}
__syncthreads();
int const loop_step0 = blockDim.x;
int const loop_step = loop_step0 * UNROLL_NLINES * gridDim.x;
int const start_elem = threadIdx.x + blockDim.x * blockIdx.x * UNROLL_NLINES;
int const end_elem = max(start_elem, numlines);
for (int line = start_elem; line < end_elem; line += loop_step)
{
uint4 val[UNROLL_NLINES];
// x views val[] as a flat array of UNROLL_NLINES * 8 DType values.
DType* x = reinterpret_cast<DType*>(&val[0]);
#pragma unroll
for (int i = 0; i < UNROLL_NLINES; i++)
MULTIMEM_LD<DType>(val[i], mc_ptr + (lineoffset + line + i * loop_step0));
if (residual_in != nullptr)
{
#pragma unroll
for (int i = 0; i < UNROLL_NLINES; i++)
{
uint4 resval = residual_in[res_offset + line + i * loop_step0];
DType* y = reinterpret_cast<DType*>(&resval);
#pragma unroll
for (int j = 0; j < 8; j++)
x[i * 8 + j] += y[j];
residual_out[res_offset + line + i * loop_step0] = val[i];
}
}
float local_var_sum = 0.0f;
for (int j = 0; j < UNROLL_NLINES * sizeof(int4) / sizeof(DType); j++)
local_var_sum += (float) (x[j]) * (float) (x[j]);
float packed[1] = {local_var_sum};
blockReduceSumV2<float, 1>(packed);
float variance = packed[0];
if (threadIdx.x == 0)
{
variance = (variance / hidden_dim); // Var[x] = E[x²]
s_variance = rsqrtf(variance + eps);
}
__syncthreads();
int i = 0;
uint2 valout;
__nv_fp8_e4m3* y = reinterpret_cast<__nv_fp8_e4m3*>(&valout);
#pragma unroll
for (int g = 0; g < UNROLL_NLINES; g++)
{
#pragma unroll
for (int j = 0; j < sizeof(int4) / sizeof(DType); j++)
{
y[j] = cuda_cast<__nv_fp8_e4m3>(sf
* compute_rmsnorm2<DType>((float) x[i], s_variance, gamma, beta,
(threadIdx.x + g * loop_step0) * sizeof(int4) / sizeof(DType) + j));
i++;
}
// Plain (non-multicast) 8-byte store despite the parameter name.
mc_ptr_out[out_lineoffset + line + g * loop_step0] = valout;
}
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
} // quant kernel fp16->fp8 oneshot
// 128-bit load of *ptr into `val` using ld.volatile.global.v2.u64, filling val as two 64-bit halves.
// The asm is volatile and the load is .volatile, so repeated calls are not merged; the polling loop of the
// Lamport kernel relies on this.
inline __device__ void load128(uint4 const* ptr, uint4& val)
{
    auto* halves = reinterpret_cast<uint64_t*>(&val);
    asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(halves[0]), "=l"(halves[1]) : "l"(ptr));
}
// Lamport-style ("LL") one-shot variant, SM90+ only. As written it implements only the signalling part:
//   1. Every line of this rank's slot (offset myrank * numlines from mc_ptr) is filled with the new counter
//      value in all four 32-bit words and multicast-stored.
//   2. Each thread re-reads its lines of buff_ptr with volatile 128-bit loads. The warp repeats the reads
//      while any lane with threadIdx.x % 8 == 7 still sees .w != reduce_id in one of its lines.
// The loaded values are not consumed. ptr_in, beta, gamma, eps, RANKS, ptr_out, out_lineoffset, scale,
// residual_in and residual_out are unused, so no normalized/quantized output is produced.
// NOTE(review): this looks unfinished - confirm before relying on this path.
// Assumes all 32 lanes of a warp run the same number of loop iterations (full-mask __any_sync).
template <typename DType, int UNROLL_NLINES>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant_oneshot_lamport(
    int const myrank, uint4* ptr_in, int const numlines, int* reduceidptr, uint4* buff_ptr, float4* mc_ptr,
    DType const* beta, DType const* gamma, float const eps, int const RANKS, uint2* ptr_out,
    size_t const out_lineoffset, float const* scale, uint4* residual_in, uint4* residual_out)
{
    // New counter value: read once per block by thread 0 and shared with the whole block.
    __shared__ int reduce_id;
    if (threadIdx.x == 0)
        reduce_id = (*reduceidptr) + 1;
    __syncthreads();

    int const loop_step0 = blockDim.x;
    int const loop_step = loop_step0 * UNROLL_NLINES * gridDim.x;
    int const start_elem = threadIdx.x + blockDim.x * blockIdx.x * UNROLL_NLINES;
    int const end_elem = max(start_elem, numlines);

    // 1. Publish reduce_id into this rank's slot through the multicast address.
    for (int line = start_elem; line < end_elem; line += loop_step)
    {
        uint4 val[UNROLL_NLINES];
#pragma unroll
        for (int i = 0; i < UNROLL_NLINES; i++)
        {
            val[i].x = reduce_id;
            val[i].y = reduce_id;
            val[i].z = reduce_id;
            val[i].w = reduce_id;
        }
#pragma unroll
        for (int i = 0; i < UNROLL_NLINES; i++)
            MULTIMEM_ST(val[i], mc_ptr + (line + i * loop_step0 + myrank * numlines));
    }

    // 2. Poll the local buffer until the expected value has arrived.
    for (int line = start_elem; line < end_elem; line += loop_step)
    {
        uint4 val[UNROLL_NLINES];
        bool readAgain;
        do
        {
            readAgain = false;
#pragma unroll
            for (int i = 0; i < UNROLL_NLINES; i++)
            {
                load128(buff_ptr + (line + i * loop_step0), val[i]);
                readAgain |= ((threadIdx.x % 8) == 7) && (val[i].w != reduce_id);
            }
        } while (__any_sync(0xffffffff, readAgain));
    }

    // Advance the persistent counter only after this block's work, as the other kernels in this file do.
    // Storing it right after the read races with thread 0 of the other blocks. Those blocks could then read
    // the already-incremented value, publish and poll for reduce_id + 1 while peers use reduce_id, and the
    // `!=` polling would not complete. Deferring the store narrows the window to the time block 0 spends
    // waiting on peers. Like the other kernels, it relies on every block having started by then.
    if (threadIdx.x == 0 && blockIdx.x == 0)
        *reduceidptr = reduce_id;
} // quant kernel fp16->fp8 oneshot(LL style)
// Multicast all-gather of a residual shard, SM90+ only. Copies lines [res_offset, res_offset + numlines) of
// residual_in into the multicast buffer at [lineoffset, lineoffset + numlines) with MULTIMEM_ST. It then runs
// one cross-rank barrier, so every peer's shard has been stored before the kernel returns.
// There is no barrier before the copy: threads 0..RANKS-1 only compute the flag pointers and the next counter
// value up front. Row/loop layout matches the rmsnorm kernels (UNROLL_NLINES lines per thread per iteration,
// spaced blockDim.x apart). `op` and `handleridx` are not used.
template <typename DType, int UNROLL_NLINES>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_mc_res_allgather(int const op, int const flagoffset, int const firstrank,
int const myrank, int const gpustep, size_t const lineoffset, int const numlines, void** commbuff,
int const handleridx, float4* mc_ptr, int const RANKS, uint4* residual_in, int res_offset)
{
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
// Set up flag pointers and the counter value for the single (trailing) barrier.
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
int const blockflagoffset = MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int*>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - MAX_OPS;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
}
__syncthreads();
int const loop_step0 = blockDim.x;
int const loop_step = loop_step0 * UNROLL_NLINES * gridDim.x;
int const start_elem = threadIdx.x + blockDim.x * blockIdx.x * UNROLL_NLINES;
int const end_elem = max(start_elem, numlines);
for (int line = start_elem; line < end_elem; line += loop_step)
{
uint4 val[UNROLL_NLINES];
#pragma unroll
for (int i = 0; i < UNROLL_NLINES; i++)
val[i] = residual_in[res_offset + line + i * loop_step0];
#pragma unroll
for (int i = 0; i < UNROLL_NLINES; i++)
MULTIMEM_ST(val[i], mc_ptr + (lineoffset + line + i * loop_step0));
}
__syncthreads();
// Order the multicast stores (system scope) before the flag writes.
if (threadIdx.x == 0)
__threadfence_system();
__syncthreads();
if (threadIdx.x < RANKS)
{
flagptr[physgpu] = reduce_id;
int volatile* flag = (int volatile*) &myptr[targetgpu];
#ifdef UB_TIMEOUT_ENABLED
clock_t s = clock64();
#endif
while (*flag < reduce_id)
{
#ifdef UB_TIMEOUT_ENABLED
if (clock64() - s > 2ull * TIMEOUT)
{
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag);
break;
}
#endif
}
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
} // residual allgather kernel
#else
// Placeholder compiled for SM < 90, where the multimem instructions of the real kernel do not exist.
// It prints a message and then executes a `brkpt` instruction. The signature matches the SM90 version.
template <typename DType, int UNROLL_NLINES>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant(
    int const op, int const flagoffset, int const firstrank, int const myrank, int const gpustep,
    size_t const lineoffset, int const numlines, void** commbuff, int const handleridx, float4* mc_ptr,
    DType const* beta, DType const* gamma, float const eps, int const RANKS, float2* mc_ptr_out,
    size_t const out_lineoffset, float const* scale, uint4* residual_in, uint4* residual_out, int res_offset)
{
    printf("userbuffer based kernels not implemented when SM < 90\n");
    asm volatile("brkpt;\n");
}
// Placeholder compiled for SM < 90 (no multimem support). It prints a message and then executes `brkpt`.
// The signature matches the SM90 version.
template <typename DType, int UNROLL_NLINES>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_res_allgather(
    int const op, int const flagoffset, int const firstrank, int const myrank, int const gpustep,
    size_t const lineoffset, int const numlines, void** commbuff, int const handleridx, float4* mc_ptr,
    int const RANKS, uint4* residual_in, int res_offset)
{
    printf("userbuffer based kernels not implemented when SM < 90\n");
    asm volatile("brkpt;\n");
}
// Placeholder compiled for SM < 90 (no multimem support). It prints a message and then executes `brkpt`.
// The parameter list matches the SM90 version.
template <typename DType, int UNROLL_NLINES>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant_oneshot(
    int const op, int const flagoffset, int const firstrank, int const myrank, int const gpustep,
    size_t const lineoffset, int const numlines, void** commbuff, int const handleridx, float4* mc_ptr,
    DType const* beta, DType const* gamma, float const eps, int const RANKS, uint2* ptr_out,
    size_t const out_lineoffset, float const* scale, uint4* residual_in, uint4* residual_out, int res_offset)
{
    printf("userbuffer based kernels not implemented when SM < 90\n");
    asm volatile("brkpt;\n");
}
// Placeholder compiled for SM < 90 (no multimem support). It prints a message and then executes `brkpt`.
// The signature matches the SM90 version.
template <typename DType, int UNROLL_NLINES>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant_oneshot_lamport(
    int const myrank, uint4* ptr_in, int const numlines, int* reduceidptr, uint4* buff_ptr, float4* mc_ptr,
    DType const* beta, DType const* gamma, float const eps, int const RANKS, uint2* ptr_out,
    size_t const out_lineoffset, float const* scale, uint4* residual_in, uint4* residual_out)
{
    printf("userbuffer based kernels not implemented when SM < 90\n");
    asm volatile("brkpt;\n");
}
#endif
// callranksMC_RMSNORM_QUANT(x): when nlines == x, launches the two-shot fused kernel
// userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant<DType, x> on this rank's token shard.
//   - lineoffset = offset / 8 + first_token * hidden_lines and numlines = hidden_lines * my_tokens;
//     first_token / my_tokens are presumably the variables declared by shard_tokens.
//   - res_offset = first_token * hidden_lines.
//   - output goes through the multicast pointer comm->mc_ptr[out_handler] at out_offset / 8 + the same
//     shard offset.
// Names required at the expansion site: DType, nlines, op, comm, ar_firstgpu, ar_nvrank, ar_step, ar_nvsize,
// offset, first_token, hidden_lines, my_tokens, handler, beta, gamma, eps, out_handler, out_offset,
// scalefactor, residual_in, residual_out, cfg. Expands to a bare `if` statement.
#define callranksMC_RMSNORM_QUANT(x) \
if (nlines == x) \
{ \
int arg1 = op - MAX_OPS, \
arg2 = REG0_OFFSET(comm) - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * REG0_SINGLENODE + MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step; \
size_t arg6 = offset / 8 + first_token * hidden_lines; \
int arg7 = hidden_lines * my_tokens; \
void** arg8 = (void**) (comm->gpu_ptrs); \
int arg9 = handler * comm->nvsize; \
void* arg10 = comm->mc_ptr[handler]; \
DType* arg11 = (DType*) beta; \
DType* arg12 = (DType*) gamma; \
float arg13 = eps; \
int arg14 = ar_nvsize; \
void* arg15 = comm->mc_ptr[out_handler]; \
size_t arg16 = out_offset / 8 + first_token * hidden_lines; \
float* arg17 = scalefactor; \
void* arg18 = residual_in; \
void* arg19 = residual_out; \
int arg20 = first_token * hidden_lines; \
void* kernelArgs[] = {reinterpret_cast<void*>(&arg1), reinterpret_cast<void*>(&arg2), \
reinterpret_cast<void*>(&arg3), reinterpret_cast<void*>(&arg4), reinterpret_cast<void*>(&arg5), \
reinterpret_cast<void*>(&arg6), reinterpret_cast<void*>(&arg7), reinterpret_cast<void*>(&arg8), \
reinterpret_cast<void*>(&arg9), reinterpret_cast<void*>(&arg10), reinterpret_cast<void*>(&arg11), \
reinterpret_cast<void*>(&arg12), reinterpret_cast<void*>(&arg13), reinterpret_cast<void*>(&arg14), \
reinterpret_cast<void*>(&arg15), reinterpret_cast<void*>(&arg16), reinterpret_cast<void*>(&arg17), \
reinterpret_cast<void*>(&arg18), reinterpret_cast<void*>(&arg19), reinterpret_cast<void*>(&arg20)}; \
TLLM_CUDA_CHECK(cudaLaunchKernelExC( \
&cfg, (void*) (userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant<DType, x>), kernelArgs)); \
}
// callranksMC_RMSNORM_QUANT_ONESHOT(x): when nlines == x, launches the one-shot fused kernel
// userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant_oneshot<DType, x> over the full range:
//   - lineoffset = offset / 8 and numlines = elements / 8;
//   - res_offset = 0;
//   - output goes to the local pointer comm->mem_ptr[out_handler] (not a multicast pointer) at out_offset / 8.
// Names required at the expansion site: DType, nlines, op, comm, ar_firstgpu, ar_nvrank, ar_step, ar_nvsize,
// offset, elements, handler, beta, gamma, eps, out_handler, out_offset, scalefactor, residual_in,
// residual_out, cfg. Expands to a bare `if` statement.
#define callranksMC_RMSNORM_QUANT_ONESHOT(x) \
if (nlines == x) \
{ \
int arg1 = op - MAX_OPS, \
arg2 = REG0_OFFSET(comm) - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * REG0_SINGLENODE + MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step; \
size_t arg6 = offset / 8; \
int arg7 = elements / 8; \
void** arg8 = (void**) (comm->gpu_ptrs); \
int arg9 = handler * comm->nvsize; \
void* arg10 = comm->mc_ptr[handler]; \
DType* arg11 = (DType*) beta; \
DType* arg12 = (DType*) gamma; \
float arg13 = eps; \
int arg14 = ar_nvsize; \
void* arg15 = comm->mem_ptr[out_handler]; \
size_t arg16 = out_offset / 8; \
float* arg17 = scalefactor; \
void* arg18 = residual_in; \
void* arg19 = residual_out; \
int arg20 = 0; \
void* kernelArgs[] = {reinterpret_cast<void*>(&arg1), reinterpret_cast<void*>(&arg2), \
reinterpret_cast<void*>(&arg3), reinterpret_cast<void*>(&arg4), reinterpret_cast<void*>(&arg5), \
reinterpret_cast<void*>(&arg6), reinterpret_cast<void*>(&arg7), reinterpret_cast<void*>(&arg8), \
reinterpret_cast<void*>(&arg9), reinterpret_cast<void*>(&arg10), reinterpret_cast<void*>(&arg11), \
reinterpret_cast<void*>(&arg12), reinterpret_cast<void*>(&arg13), reinterpret_cast<void*>(&arg14), \
reinterpret_cast<void*>(&arg15), reinterpret_cast<void*>(&arg16), reinterpret_cast<void*>(&arg17), \
reinterpret_cast<void*>(&arg18), reinterpret_cast<void*>(&arg19), reinterpret_cast<void*>(&arg20)}; \
TLLM_CUDA_CHECK(cudaLaunchKernelExC( \
&cfg, (void*) (userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant_oneshot<DType, x>), kernelArgs)); \
}
// callranksMC_RMSNORM_QUANT_ONESHOT_LL(x): when nlines == x, launches
// userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant_oneshot_lamport<DType, x>. Arguments:
//   - ptr_in = comm->mem_ptr[handler] advanced by offset * 2 bytes (offset presumably counts 2-byte
//     elements - confirm);
//   - numlines = elements / 8;
//   - reduceidptr = int at index REG0_OFFSET(comm) - REG0_SINGLENODE of comm->mem_ptr[0];
//   - buff_ptr / mc_ptr = the region at int index REG0_OFFSET(comm) + REG0_FLAGS of comm->mem_ptr[0]
//     (local) and comm->mc_ptr[0] (multicast).
// Names required at the expansion site: DType, nlines, comm, ar_nvrank, ar_nvsize, handler, offset, elements,
// beta, gamma, eps, out_handler, out_offset, scalefactor, residual_in, residual_out, cfg.
// Expands to a bare `if` statement.
#define callranksMC_RMSNORM_QUANT_ONESHOT_LL(x) \
if (nlines == x) \
{ \
int arg1 = ar_nvrank; \
void* arg2 = reinterpret_cast<uint8_t*>(comm->mem_ptr[handler]) + (offset * 2); \
int arg3 = elements / 8; \
void* arg4 \
= reinterpret_cast<uint8_t*>(comm->mem_ptr[0]) + (REG0_OFFSET(comm) - REG0_SINGLENODE) * sizeof(int); \
void* arg5 = reinterpret_cast<uint8_t*>(comm->mem_ptr[0]) + sizeof(int) * (REG0_OFFSET(comm) + REG0_FLAGS); \
void* arg6 = reinterpret_cast<uint8_t*>(comm->mc_ptr[0]) + sizeof(int) * (REG0_OFFSET(comm) + REG0_FLAGS); \
DType* arg7 = (DType*) beta; \
DType* arg8 = (DType*) gamma; \
float arg9 = eps; \
int arg10 = ar_nvsize; \
void* arg11 = comm->mem_ptr[out_handler]; \
size_t arg12 = out_offset / 8; \
float* arg13 = scalefactor; \
void* arg14 = residual_in; \
void* arg15 = residual_out; \
void* kernelArgs[] \
= {reinterpret_cast<void*>(&arg1), reinterpret_cast<void*>(&arg2), reinterpret_cast<void*>(&arg3), \
reinterpret_cast<void*>(&arg4), reinterpret_cast<void*>(&arg5), reinterpret_cast<void*>(&arg6), \
reinterpret_cast<void*>(&arg7), reinterpret_cast<void*>(&arg8), reinterpret_cast<void*>(&arg9), \
reinterpret_cast<void*>(&arg10), reinterpret_cast<void*>(&arg11), reinterpret_cast<void*>(&arg12), \
reinterpret_cast<void*>(&arg13), reinterpret_cast<void*>(&arg14), reinterpret_cast<void*>(&arg15)}; \
TLLM_CUDA_CHECK(cudaLaunchKernelExC( \
&cfg, (void*) (userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant_oneshot_lamport<DType, x>), kernelArgs)); \
}
// Launch helper for the multicast residual allgather kernel.
// Expands to a guarded launch of userbuffers_fp16_sum_inplace_gpu_mc_res_allgather<DType, x> through
// cudaLaunchKernelExC, taken only when the caller's `nlines` equals x.
// Like the other callranks* macros, it reads names from the enclosing scope:
//   cfg, op, comm, handler, offset, ar_firstgpu, ar_nvrank, ar_step, ar_nvsize, hidden_lines, residual_in.
//   It also reads first_token and my_tokens. These are not declared in the visible callers and are presumably
//   introduced by the shard_tokens(...) macro -- confirm.
// Notes on the packed arguments:
//  - arg1 is op - MAX_OPS.
//  - arg2 is a flag offset that steps back two REG0_SINGLENODE regions for userbuffers_allreduceop_nonsharp and
//    one region otherwise.
//  - arg6/arg7 select hidden_lines * my_tokens lines, starting first_token * hidden_lines lines past offset / 8.
//    This is presumably this rank's token shard.
//  - arg8 is comm->gpu_ptrs (presumably per-peer buffer pointers), indexed from base arg9 = handler * comm->nvsize.
//  - arg10 is the multicast pointer comm->mc_ptr[handler].
//  - arg12/arg13 pass residual_in and the same first_token * hidden_lines start.
#define callranksMC_RES_AG(x) \
    if (nlines == x) \
    { \
        int arg1 = op - MAX_OPS, \
            arg2 = REG0_OFFSET(comm) - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * REG0_SINGLENODE + MAX_OPS, \
            arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step; \
        size_t arg6 = offset / 8 + first_token * hidden_lines; \
        int arg7 = hidden_lines * my_tokens; \
        void** arg8 = (void**) (comm->gpu_ptrs); \
        int arg9 = handler * comm->nvsize; \
        void* arg10 = comm->mc_ptr[handler]; \
        int arg11 = ar_nvsize; \
        uint4* arg12 = (uint4*) residual_in; \
        int arg13 = first_token * hidden_lines; \
        void* kernelArgs[] = {reinterpret_cast<void*>(&arg1), reinterpret_cast<void*>(&arg2), \
            reinterpret_cast<void*>(&arg3), reinterpret_cast<void*>(&arg4), reinterpret_cast<void*>(&arg5), \
            reinterpret_cast<void*>(&arg6), reinterpret_cast<void*>(&arg7), reinterpret_cast<void*>(&arg8), \
            reinterpret_cast<void*>(&arg9), reinterpret_cast<void*>(&arg10), reinterpret_cast<void*>(&arg11), \
            reinterpret_cast<void*>(&arg12), reinterpret_cast<void*>(&arg13)}; \
        TLLM_CUDA_CHECK(cudaLaunchKernelExC( \
            &cfg, (void*) (userbuffers_fp16_sum_inplace_gpu_mc_res_allgather<DType, x>), kernelArgs)); \
    }
// Enqueue the in-place allreduce kernel on `stream`.
// It reduces `elements` DType values held in userbuffer `handler` at element offset `offset`, using mode `op`.
// Only the GPU kernel is scheduled here; any CPU/SHARP stage is left to the caller.
//
// maxcredit and blocksize are not referenced directly in this body. They may be consumed by the
// callranks/callranksMC macros, which are defined elsewhere -- TODO confirm before relying on them.
//
// Returns 0 without launching when elements < 8.
// Otherwise returns the `sms` value handed to SETUP_LAUNCH_CONFIG: 2 for a single-GPU group, comm->sms otherwise.
template <typename DType>
int allreduce2_userbuff_inplace_gpu(int const maxcredit, int const handler, size_t const offset, size_t const elements,
    int const blocksize, communicator* comm, cudaStream_t stream, int op)
{
    // schedule GPU kernel only
    // CPU/SHARP part is responsibility of caller
    // Group selection: userbuffers_allreduceop_nonsharp uses the ar_* group; every other op uses the ar2_* group.
    // ar_step is 1 for nonsharp2 and comm->ar2_nvsize otherwise (presumably the rank stride within the group).
    // These locals are read by the callranks*/callranksMC* macro expansions below; keep their names.
    int const ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
    int const ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
    int const ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
    int const ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
    if (elements < 8)
        return 0;
    int sms = ar_nvsize == 1 ? 2 : comm->sms;
    // Block size is comm->threads rounded down to whole warps, but never fewer warps than ranks in the group.
    int warps = comm->threads / 32;
    if (warps < ar_nvsize)
        warps = ar_nvsize;
    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
    // The multicast kernels require op == nonsharp2, multicast enabled on the communicator, and this buffer
    // flagged UB_MEM_MC_CREATED. Otherwise the unicast kernels are used.
    if (op == userbuffers_allreduceop_nonsharp2 && comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED))
    {
        // Each callranksMC(x) presumably launches only when the group size matches x.
        // MNNVL builds add the 16- and 32-rank variants.
        callranksMC(2) callranksMC(4) callranksMC(8)
#ifdef MNNVL
            callranksMC(16) callranksMC(32)
#endif
    }
    else
    {
        callranks(2) callranks(4) callranks(8)
#ifdef MNNVL
            callranks(16) callranks(32)
#endif
    }
    return sms;
}
// Schedule an in-place non-SHARP allreduce on `stream`.
// It reduces `elements` DType values held in userbuffer `handler` at element offset `offset`, using mode `op`.
// Inputs with fewer than 64 elements are skipped and no kernel is launched.
template <typename DType>
void allreduce_nonsharp_inplace(
    int const handler, size_t const offset, size_t const elements, communicator* comm, cudaStream_t stream, int op)
{
    if (elements < 64)
        return;
    // blocksize is narrowed to int exactly as before (elements * 2, i.e. the byte size for a 2-byte DType).
    int const blocksize = static_cast<int>(elements * 2);
    int const maxcredit = 0;
    // The SM count returned by the launcher is not needed here; discard it explicitly
    // instead of keeping an unused local.
    (void) allreduce2_userbuff_inplace_gpu<DType>(maxcredit, handler, offset, elements, blocksize, comm, stream, op);
}
// Default in-place allreduce entry point.
// It always uses the userbuffers_allreduceop_nonsharp2 reduction mode and delegates to allreduce_nonsharp_inplace.
template <typename DType>
void allreduce2_userbuff_inplace(
    int const handler, size_t const offset, size_t const elements, communicator* comm, cudaStream_t stream)
{
    int const reduceOp = userbuffers_allreduceop_nonsharp2;
    allreduce_nonsharp_inplace<DType>(handler, offset, elements, comm, stream, reduceOp);
}
// Schedule the fused multicast allreduce + RMSNorm + quantization on `stream`.
// The input is `elements` DType values held in userbuffer `handler` at element offset `offset`.
// The result goes to userbuffer `out_handler` at `out_offset`.
// Only the GPU kernel is enqueued; the CPU/SHARP variant is not supported.
//
// beta and gamma are reinterpreted as DType arrays and eps is forwarded to the kernel.
// scalefactor, residual_in and residual_out are passed through unchanged; their use is defined by the kernels.
//
// Launch shape: each token has hidden_size / 8 vector "lines". The smallest nlines in 1..4 that divides
// hidden_lines evenly, with hidden_lines / nlines <= 1024, is chosen. nthreads (the block size) is
// hidden_lines / nlines.
//
// Returns 0 without launching when hidden_size is not positive or elements is not a multiple of hidden_size.
// Otherwise returns `sms`. Only the multicast path is implemented; any other configuration trips assert(0).
template <typename DType>
int allreduce2_userbuff_inplace_rmsnorm_quant(int const handler, size_t const offset, int const out_handler,
    size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
    float* scalefactor, void* residual_in, void* residual_out, communicator* comm, cudaStream_t stream)
{
    // schedule GPU kernel only
    // CPU/SHARP part is not supported yet;
    // NOTE: the local names below (op, ar_*, hidden_lines, nlines, sms, cfg, ...) are read by the
    // callranksMC_* macro expansions further down; they must not be renamed.
    int op = userbuffers_allreduceop_nonsharp2;
    int const ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
    int const ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
    int const ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
    int const ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
    // A zero hidden_size would make `elements % hidden_size` divide by zero (undefined behavior).
    // A negative one would be converted to a huge size_t. Neither describes a valid layout, so report
    // "nothing launched" the same way as the non-divisible case below.
    if (hidden_size <= 0)
        return 0;
    if (elements % hidden_size)
        return 0;
    assert(hidden_size % 8 == 0);
    int hidden_lines = hidden_size / 8;
    shard_tokens(elements / hidden_size, ar_nvsize, ar_nvrank);
    int sms = ar_nvsize == 1 ? 2 : comm->sms;
    int nthreads = hidden_size / 8;
    int nlines = 1;
    // Increase nlines until it divides the line count and brings the block size down to <= 1024 threads.
    while (nthreads > 1024)
    {
        nlines++;
        assert(nlines <= 4);
        if ((hidden_size / 8) % nlines == 0)
            nthreads = ((hidden_size / 8)) / nlines;
    }
    SETUP_LAUNCH_CONFIG(sms, nthreads, stream);
    if (op == userbuffers_allreduceop_nonsharp2 && comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED))
    {
        // One-shot kernels are used when forced (oneshot > 1), or with oneshot == 1 for small inputs
        // (elements * ar_nvsize <= 131072). The two-shot kernels are used otherwise.
        if (comm->oneshot > 1 || (comm->oneshot == 1 && (elements * ar_nvsize <= 131072)))
        {
            if (comm->oneshot < 3)
            {
                callranksMC_RMSNORM_QUANT_ONESHOT(1) callranksMC_RMSNORM_QUANT_ONESHOT(2)
                    callranksMC_RMSNORM_QUANT_ONESHOT(3) callranksMC_RMSNORM_QUANT_ONESHOT(4)
            }
            else
            {
                // "_LL" (lamport) kernel: only nlines 1 and 2 are instantiated, so larger nlines launch nothing.
                // NOTE(review): cfg was already built by SETUP_LAUNCH_CONFIG above. This assignment therefore
                // presumably changes only the returned value, not the launched grid -- confirm.
                sms = 1;
                callranksMC_RMSNORM_QUANT_ONESHOT_LL(1) callranksMC_RMSNORM_QUANT_ONESHOT_LL(2)
            }
        }
        else
        {
            callranksMC_RMSNORM_QUANT(1) callranksMC_RMSNORM_QUANT(2) callranksMC_RMSNORM_QUANT(3)
                callranksMC_RMSNORM_QUANT(4)
        }
    }
    else
    {
        assert(0);
    }
    return sms;
}
// Schedule the residual allgather on `stream` for `elements` DType values in userbuffer `handler` at
// element offset `offset`.
//
// One-shot mode applies when comm->oneshot == 2, or when comm->oneshot == 1 and
// elements * comm->ar2_nvsize <= 131072. In that mode no kernel is launched: residual_in is copied in full into
// the userbuffer with cudaMemcpyAsync and 0 is returned.
// Otherwise the multicast allgather kernel (callranksMC_RES_AG) is launched.
//
// Returns 0 when no kernel is launched: the one-shot copy, a non-positive hidden_size, or elements not being a
// multiple of hidden_size. Otherwise returns `sms`.
// Only the multicast path is implemented; any other configuration trips assert(0).
// NOTE(review): the rmsnorm entry point treats any comm->oneshot > 1 as one-shot, while this function only
// short-circuits for oneshot == 2. Confirm that oneshot == 3 (lamport) is meant to take the kernel path.
template <typename DType>
int allgather2_userbuff_residual(int const handler, size_t const offset, size_t const elements, int const hidden_size,
    void* residual_in, communicator* comm, cudaStream_t stream)
{
    // schedule GPU kernel only
    // CPU/SHARP part is not supported yet;
    if (comm->oneshot == 2 || (comm->oneshot == 1 && (elements * comm->ar2_nvsize <= 131072)))
    {
        // offset and elements count DType values; convert both to bytes for the raw device-to-device copy.
        TLLM_CUDA_CHECK(cudaMemcpyAsync(reinterpret_cast<uint8_t*>(comm->mem_ptr[handler]) + offset * sizeof(DType),
            residual_in, elements * sizeof(DType), cudaMemcpyDeviceToDevice, stream));
        return 0;
    }
    // NOTE: the local names below are read by the callranksMC_RES_AG expansions; they must not be renamed.
    int op = userbuffers_allreduceop_nonsharp2;
    int const ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
    int const ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
    int const ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
    int const ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
    // A zero hidden_size would make `elements % hidden_size` divide by zero (undefined behavior).
    // A negative one would be converted to a huge size_t. Treat both like the non-divisible case.
    if (hidden_size <= 0)
        return 0;
    if (elements % hidden_size)
        return 0;
    assert(hidden_size % 8 == 0);
    int hidden_lines = hidden_size / 8;
    shard_tokens(elements / hidden_size, ar_nvsize, ar_nvrank);
    int sms = ar_nvsize == 1 ? 2 : comm->sms;
    int nthreads = hidden_size / 8;
    int nlines = 1;
    // Increase nlines until it divides the line count and brings the block size down to <= 1024 threads.
    while (nthreads > 1024)
    {
        nlines++;
        assert(nlines <= 4);
        if ((hidden_size / 8) % nlines == 0)
            nthreads = ((hidden_size / 8)) / nlines;
    }
    SETUP_LAUNCH_CONFIG(sms, nthreads, stream);
    if (op == userbuffers_allreduceop_nonsharp2 && comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED))
    {
        callranksMC_RES_AG(1) callranksMC_RES_AG(2) callranksMC_RES_AG(3) callranksMC_RES_AG(4)
    }
    else
    {
        assert(0);
    }
    return sms;
}
// Runtime dtype dispatch for the in-place userbuffer allreduce.
// kHALF is always supported; kBF16 is supported only when built with ENABLE_BF16.
// Any other dtype raises via TLLM_THROW.
void allreduce2_userbuff_inplace_impl(int const handler, size_t const offset, size_t const elements,
    nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream)
{
    if (dataType == nvinfer1::DataType::kHALF)
    {
        allreduce2_userbuff_inplace<half>(handler, offset, elements, comm, stream);
        return;
    }
#ifdef ENABLE_BF16
    if (dataType == nvinfer1::DataType::kBF16)
    {
        allreduce2_userbuff_inplace<__nv_bfloat16>(handler, offset, elements, comm, stream);
        return;
    }
#endif
    TLLM_THROW("Unsupported dataType for allreduce2_userbuff_inplace_impl");
}
// Runtime dtype dispatch for the residual allgather.
// Returns whatever the typed implementation returns. kBF16 is available only with ENABLE_BF16.
// Unsupported dtypes raise via TLLM_THROW.
int allgather2_userbuff_residual_impl(int const handler, size_t const offset, size_t const elements,
    int const hidden_size, void* residual, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream)
{
    if (dataType == nvinfer1::DataType::kHALF)
    {
        return allgather2_userbuff_residual<half>(handler, offset, elements, hidden_size, residual, comm, stream);
    }
#ifdef ENABLE_BF16
    if (dataType == nvinfer1::DataType::kBF16)
    {
        return allgather2_userbuff_residual<__nv_bfloat16>(
            handler, offset, elements, hidden_size, residual, comm, stream);
    }
#endif
    TLLM_THROW("Unsupported dataType for allgather2_userbuff_residual_impl");
}
// Runtime dtype dispatch for the fused allreduce + RMSNorm + quantization.
// Returns whatever the typed implementation returns. kBF16 is available only with ENABLE_BF16.
// Unsupported dtypes raise via TLLM_THROW.
int allreduce2_userbuff_inplace_rmsnorm_quant_impl(int const handler, size_t const offset, int const out_handler,
    size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
    float* scalefactor, void* residual_in, void* residual_out, nvinfer1::DataType dataType, communicator* comm,
    cudaStream_t stream)
{
    if (dataType == nvinfer1::DataType::kHALF)
    {
        return allreduce2_userbuff_inplace_rmsnorm_quant<half>(handler, offset, out_handler, out_offset, elements,
            hidden_size, beta, gamma, eps, scalefactor, residual_in, residual_out, comm, stream);
    }
#ifdef ENABLE_BF16
    if (dataType == nvinfer1::DataType::kBF16)
    {
        return allreduce2_userbuff_inplace_rmsnorm_quant<__nv_bfloat16>(handler, offset, out_handler, out_offset,
            elements, hidden_size, beta, gamma, eps, scalefactor, residual_in, residual_out, comm, stream);
    }
#endif
    TLLM_THROW("Unsupported dataType for allreduce2_userbuff_inplace_rmsnorm_quant_impl");
}
} // namespace tensorrt_llm::kernels::ub