Add one-shot version for UB AR NORM FP16/BF16 (#2995)

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
liji-nv 2025-03-31 11:16:03 +08:00 committed by GitHub
parent 794f61c997
commit e0d0dde058
7 changed files with 190 additions and 46 deletions

@@ -89,12 +89,12 @@ int allgather2_userbuff_residual_launcher(int const handler, size_t const offset
handler, offset, elements, hidden_size, residual, dataType, comm, stream, force_enable);
}
int allreduce2_userbuff_inplace_rmsnorm_launcher(int const handler, size_t const offset, size_t const elements,
int const hidden_size, void* beta, void* gamma, float eps, void* residual_in, void* residual_out,
nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream)
int allreduce2_userbuff_rmsnorm_launcher(int const handler, size_t const offset, int const out_handler,
size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
void* residual_in, void* residual_out, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream)
{
return allreduce2_userbuff_inplace_rmsnorm_impl(
handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, dataType, comm, stream);
return allreduce2_userbuff_rmsnorm_impl(handler, offset, out_handler, out_offset, elements, hidden_size, beta,
gamma, eps, residual_in, residual_out, dataType, comm, stream);
}
int allreduce2_userbuff_inplace_rmsnorm_quant_launcher(int const handler, size_t const offset, int const out_handler,

@@ -41,9 +41,9 @@ int allgather2_userbuff_residual_launcher(int const handler, size_t const offset
int const hidden_size, void* residual, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream,
bool force_enable = false);
int allreduce2_userbuff_inplace_rmsnorm_launcher(int const handler, size_t const offset, size_t const elements,
int const hidden_size, void* beta, void* gamma, float eps, void* residual_in, void* residual_out,
nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream);
int allreduce2_userbuff_rmsnorm_launcher(int const handler, size_t const offset, int const out_handler,
size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
void* residual_in, void* residual_out, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream);
int allreduce2_userbuff_inplace_rmsnorm_quant_launcher(int const handler, size_t const offset, int const out_handler,
size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,

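For orientation, here is a minimal usage sketch of the renamed launcher, based only on the declaration above and the call site in the AllReduce torch op further down: the input userbuffer is reduced across TP ranks, the residual is added, and the RMS-normalized result is written to a second registered buffer addressed by out_handler/out_offset instead of in place. The include path, namespace alias, handle names, shapes, and eps value are illustrative assumptions, not part of this commit.

// Hedged sketch -- the argument order follows the declaration above; everything
// else (include path, communicator namespace, handles, shapes, eps) is assumed.
#include "tensorrt_llm/kernels/userbuffers/ub_interface.h" // assumed header location

namespace ub = tensorrt_llm::kernels::ub;

void fused_allreduce_rmsnorm(ub::communicator* comm, int in_handle, int out_handle, void* gamma, void* residual_in,
    void* residual_out, cudaStream_t stream)
{
    size_t const tokens = 128;                    // placeholder token count
    int const hidden_size = 4096;                 // placeholder hidden dimension
    size_t const elements = tokens * hidden_size; // total values to reduce

    // Allreduce the input userbuffer, add the residual, apply RMSNorm with gamma
    // (beta omitted, as in the plugin call in this commit), and write the
    // normalized result into the separate output userbuffer.
    ub::allreduce2_userbuff_rmsnorm_launcher(
        /*handler=*/in_handle, /*offset=*/0,
        /*out_handler=*/out_handle, /*out_offset=*/0,
        elements, hidden_size,
        /*beta=*/nullptr, gamma, /*eps=*/1e-5f,
        residual_in, residual_out,
        nvinfer1::DataType::kHALF, comm, stream);
}
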
@@ -773,10 +773,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
#if __CUDA_ARCH__ >= 900
template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm(int const op,
int const flagoffset, int const firstrank, int const myrank, int const gpustep, const size_t lineoffset,
int const numlines, void** commbuff, int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma,
float const eps, int const RANKS, uint4* residual_in, uint4* residual_out, int res_offset)
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_gpu_mc_rmsnorm(int const op, int const flagoffset, int const firstrank, int const myrank,
int const gpustep, const size_t lineoffset, int const numlines, void** commbuff, int const handleridx,
float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS, float4* mc_ptr_out,
size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
{
cudaTriggerProgrammaticLaunchCompletion();
__shared__ float s_variance;
@@ -853,7 +854,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
(threadIdx.x + g * loop_step0) * sizeof(int4) / sizeof(DType) + j));
i++;
}
MULTIMEM_ST(val[g], mc_ptr + (lineoffset + line + g * loop_step0));
MULTIMEM_ST(val[g], mc_ptr_out + (out_lineoffset + line + g * loop_step0));
}
}
__syncthreads();
@@ -870,6 +871,94 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
*reduceidptr = reduce_id;
} // fp16 inplace reduce kernel (Hopper) MC with rmsNorm fused
template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_gpu_mc_rmsnorm_oneshot(int const op, int const flagoffset, int const firstrank,
int const myrank, int const gpustep, const size_t lineoffset, int const numlines, void** commbuff,
int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS,
uint4* uc_ptr_out, size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
{
cudaTriggerProgrammaticLaunchCompletion();
__shared__ float s_variance;
int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
int const blockflagoffset = MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int*>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - MAX_OPS; //+op;
reduce_id = next_flag(*reduceidptr);
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
cudaGridDependencySynchronize();
flagptr[physgpu] = reduce_id;
multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
}
__syncthreads();
int const loop_step0 = blockDim.x;
int const loop_step = loop_step0 * UNROLL_NLINES * gridDim.x;
int const start_elem = threadIdx.x + blockDim.x * blockIdx.x * UNROLL_NLINES;
int const end_elem = max(start_elem, numlines);
for (int line = start_elem; line < end_elem; line += loop_step)
{
uint4 val[UNROLL_NLINES];
DType* x = reinterpret_cast<DType*>(&val[0]);
#pragma unroll
for (int i = 0; i < UNROLL_NLINES; i++)
MULTIMEM_LD<DType, DISABLE_FP32_ACC>(val[i], mc_ptr + (lineoffset + line + i * loop_step0));
if (residual_in != nullptr)
{
#pragma unroll
for (int i = 0; i < UNROLL_NLINES; i++)
{
uint4 resval = residual_in[res_offset + line + i * loop_step0];
DType* y = reinterpret_cast<DType*>(&resval);
#pragma unroll
for (int j = 0; j < 8; j++)
x[i * 8 + j] += y[j];
residual_out[res_offset + line + i * loop_step0] = val[i];
}
}
float local_var_sum = 0.0f;
for (int j = 0; j < UNROLL_NLINES * sizeof(int4) / sizeof(DType); j++)
local_var_sum += (float) (x[j]) * (float) (x[j]);
float packed[1] = {local_var_sum};
blockReduceSumV2<float, 1>(packed);
float variance = packed[0];
if (threadIdx.x == 0)
{
variance = (variance / hidden_dim); // Var[x] = E[x²]
s_variance = rsqrtf(variance + eps);
}
__syncthreads();
int i = 0;
#pragma unroll
for (int g = 0; g < UNROLL_NLINES; g++)
{
#pragma unroll
for (int j = 0; j < sizeof(int4) / sizeof(DType); j++)
{
x[i] = cuda_cast<DType>(compute_rmsnorm2<DType>((float) (x[i]), s_variance, gamma, beta,
(threadIdx.x + g * loop_step0) * sizeof(int4) / sizeof(DType) + j));
i++;
}
uc_ptr_out[out_lineoffset + line + g * loop_step0] = val[g];
}
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
} // fp16 inplace reduce kernel (Hopper) MC with rmsNorm fused oneshot
template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant(int const op,
int const flagoffset, int const firstrank, int const myrank, int const gpustep, size_t const lineoffset,
@@ -1127,10 +1216,22 @@ __global__ void __launch_bounds__(MAX_THREADS)
#else
template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm(int const op,
int const flagoffset, int const firstrank, int const myrank, int const gpustep, const size_t lineoffset,
int const numlines, void** commbuff, int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma,
float const eps, int const RANKS, uint4* residual_in, uint4* residual_out, int res_offset)
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_gpu_mc_rmsnorm(int const op, int const flagoffset, int const firstrank, int const myrank,
int const gpustep, const size_t lineoffset, int const numlines, void** commbuff, int const handleridx,
float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS, float4* uc_ptr_out,
size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
{
printf("userbuffer based kernels not implemented when SM < 90\n");
asm volatile("brkpt;\n");
}
template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_gpu_mc_rmsnorm_oneshot(int const op, int const flagoffset, int const firstrank,
int const myrank, int const gpustep, const size_t lineoffset, int const numlines, void** commbuff,
int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS,
uint4* uc_ptr_out, size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
{
printf("userbuffer based kernels not implemented when SM < 90\n");
asm volatile("brkpt;\n");
@@ -1339,17 +1440,50 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
DType* arg12 = (DType*) gamma; \
float arg13 = eps; \
int arg14 = ar_nvsize; \
void* arg15 = residual_in; \
void* arg16 = residual_out; \
int arg17 = first_token * hidden_lines; \
void* arg15 = comm->mc_ptr[out_handler]; \
size_t arg16 = out_offset / 8 + first_token * hidden_lines; \
void* arg17 = residual_in; \
void* arg18 = residual_out; \
int arg19 = first_token * hidden_lines; \
void* kernelArgs[] = {reinterpret_cast<void*>(&arg1), reinterpret_cast<void*>(&arg2), \
reinterpret_cast<void*>(&arg3), reinterpret_cast<void*>(&arg4), reinterpret_cast<void*>(&arg5), \
reinterpret_cast<void*>(&arg6), reinterpret_cast<void*>(&arg7), reinterpret_cast<void*>(&arg8), \
reinterpret_cast<void*>(&arg9), reinterpret_cast<void*>(&arg10), reinterpret_cast<void*>(&arg11), \
reinterpret_cast<void*>(&arg12), reinterpret_cast<void*>(&arg13), reinterpret_cast<void*>(&arg14), \
reinterpret_cast<void*>(&arg15), reinterpret_cast<void*>(&arg16), reinterpret_cast<void*>(&arg17)}; \
reinterpret_cast<void*>(&arg15), reinterpret_cast<void*>(&arg16), reinterpret_cast<void*>(&arg17), \
reinterpret_cast<void*>(&arg18), reinterpret_cast<void*>(&arg19)}; \
TLLM_CUDA_CHECK(cudaLaunchKernelExC( \
&cfg, (void*) (userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm<DType, x, DISABLE_FP32_ACC>), kernelArgs)); \
&cfg, (void*) (userbuffers_fp16_sum_gpu_mc_rmsnorm<DType, x, DISABLE_FP32_ACC>), kernelArgs)); \
}
#define callranksMC_RMSNORM_ONESHOT(x) \
if (nlines == x) \
{ \
int arg1 = userbuffers_allreduceop_nonsharp2 - MAX_OPS, arg2 = REG0_OFFSET(comm) - REG0_SINGLENODE + MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step; \
size_t arg6 = offset / 8; \
int arg7 = elements / 8; \
void** arg8 = (void**) (comm->gpu_ptrs); \
int arg9 = handler * comm->nvsize; \
void* arg10 = comm->mc_ptr[handler]; \
DType* arg11 = (DType*) beta; \
DType* arg12 = (DType*) gamma; \
float arg13 = eps; \
int arg14 = ar_nvsize; \
void* arg15 = comm->mem_ptr[out_handler]; \
size_t arg16 = out_offset / 8; \
void* arg17 = residual_in; \
void* arg18 = residual_out; \
int arg19 = 0; \
void* kernelArgs[] = {reinterpret_cast<void*>(&arg1), reinterpret_cast<void*>(&arg2), \
reinterpret_cast<void*>(&arg3), reinterpret_cast<void*>(&arg4), reinterpret_cast<void*>(&arg5), \
reinterpret_cast<void*>(&arg6), reinterpret_cast<void*>(&arg7), reinterpret_cast<void*>(&arg8), \
reinterpret_cast<void*>(&arg9), reinterpret_cast<void*>(&arg10), reinterpret_cast<void*>(&arg11), \
reinterpret_cast<void*>(&arg12), reinterpret_cast<void*>(&arg13), reinterpret_cast<void*>(&arg14), \
reinterpret_cast<void*>(&arg15), reinterpret_cast<void*>(&arg16), reinterpret_cast<void*>(&arg17), \
reinterpret_cast<void*>(&arg18), reinterpret_cast<void*>(&arg19)}; \
TLLM_CUDA_CHECK(cudaLaunchKernelExC( \
&cfg, (void*) (userbuffers_fp16_sum_gpu_mc_rmsnorm_oneshot<DType, x, DISABLE_FP32_ACC>), kernelArgs)); \
}
template <typename DType, bool DISABLE_FP32_ACC>
@@ -1436,8 +1570,9 @@ bool use_oneshot_kernel(communicator* comm, size_t elements, int hidden_size)
}
template <typename DType, bool DISABLE_FP32_ACC>
int allreduce2_userbuff_inplace_rmsnorm(int const handler, int const offset, int const elements, int const hidden_size,
void* beta, void* gamma, float eps, void* residual_in, void* residual_out, communicator* comm, cudaStream_t stream)
int allreduce2_userbuff_rmsnorm(int const handler, int const offset, int const out_handler, size_t const out_offset,
int const elements, int const hidden_size, void* beta, void* gamma, float eps, void* residual_in,
void* residual_out, communicator* comm, cudaStream_t stream)
{
int const ar_firstgpu = comm->tp_first_rank;
int const ar_step = 1;
@@ -1465,7 +1600,15 @@ int allreduce2_userbuff_inplace_rmsnorm(int const handler, int const offset, int
auto& cfg = launch_config.get();
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED))
{
callranksMC_RMSNORM(1) callranksMC_RMSNORM(2) callranksMC_RMSNORM(3) callranksMC_RMSNORM(4)
if (use_oneshot_kernel(comm, elements, hidden_size))
{
callranksMC_RMSNORM_ONESHOT(1) callranksMC_RMSNORM_ONESHOT(2) callranksMC_RMSNORM_ONESHOT(3)
callranksMC_RMSNORM_ONESHOT(4)
}
else
{
callranksMC_RMSNORM(1) callranksMC_RMSNORM(2) callranksMC_RMSNORM(3) callranksMC_RMSNORM(4)
}
}
else
{
@@ -1678,9 +1821,9 @@ int allgather2_userbuff_residual_impl(int const handler, size_t const offset, si
}
}
int allreduce2_userbuff_inplace_rmsnorm_impl(int const handler, size_t const offset, size_t const elements,
int const hidden_size, void* beta, void* gamma, float eps, void* residual_in, void* residual_out,
nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream)
int allreduce2_userbuff_rmsnorm_impl(int const handler, size_t const offset, int const out_handler,
size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
void* residual_in, void* residual_out, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream)
{
switch (dataType)
{
@@ -1688,13 +1831,13 @@ int allreduce2_userbuff_inplace_rmsnorm_impl(int const handler, size_t const off
{
if (kDISABLE_FP32_ACCUMULATION)
{
return allreduce2_userbuff_inplace_rmsnorm<half, true>(
handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
return allreduce2_userbuff_rmsnorm<half, true>(handler, offset, out_handler, out_offset, elements,
hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
}
else
{
return allreduce2_userbuff_inplace_rmsnorm<half, false>(
handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
return allreduce2_userbuff_rmsnorm<half, false>(handler, offset, out_handler, out_offset, elements,
hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
}
break;
}
@@ -1703,18 +1846,18 @@
{
if (kDISABLE_FP32_ACCUMULATION)
{
return allreduce2_userbuff_inplace_rmsnorm<__nv_bfloat16, true>(
handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
return allreduce2_userbuff_rmsnorm<__nv_bfloat16, true>(handler, offset, out_handler, out_offset, elements,
hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
}
else
{
return allreduce2_userbuff_inplace_rmsnorm<__nv_bfloat16, false>(
handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
return allreduce2_userbuff_rmsnorm<__nv_bfloat16, false>(handler, offset, out_handler, out_offset, elements,
hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
}
break;
}
#endif
default: TLLM_THROW("Unsupported dataType for allreduce2_userbuff_inplace_rmsnorm_impl");
default: TLLM_THROW("Unsupported dataType for allreduce2_userbuff_rmsnorm_impl");
}
}
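
As a sanity reference for the math that both the two-shot and the new one-shot kernel fuse per token: add the residual to the rank-summed hidden states, accumulate the sum of squares over the hidden dimension, scale by rsqrt(mean + eps), then apply gamma (and beta when provided), mirroring the compute_rmsnorm2 loop above. The sketch below is a hedged host-side restatement in plain C++; the function name and float-only interface are illustrative, while the kernels work on FP16/BF16 vectors of values that have already been all-reduced.

#include <cmath>
#include <cstddef>
#include <vector>

// Host-side reference for one token: y[i] = rmsnorm(x[i] + residual[i]).
// x is assumed to already hold the sum across TP ranks (what MULTIMEM_LD
// produces in the kernels); residual_out in the kernels stores x + residual.
std::vector<float> residual_add_rmsnorm_ref(std::vector<float> x, std::vector<float> const& residual,
    std::vector<float> const& gamma, float const* beta, float eps)
{
    size_t const hidden = x.size();
    float sum_sq = 0.f;
    for (size_t i = 0; i < hidden; ++i)
    {
        x[i] += residual[i];   // residual add happens before the variance is taken
        sum_sq += x[i] * x[i];
    }
    float const inv_rms = 1.f / std::sqrt(sum_sq / hidden + eps); // rsqrtf(variance + eps)
    std::vector<float> y(hidden);
    for (size_t i = 0; i < hidden; ++i)
    {
        y[i] = x[i] * inv_rms * gamma[i] + (beta != nullptr ? beta[i] : 0.f);
    }
    return y;
}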

@@ -125,9 +125,9 @@ int allgather2_userbuff_residual_impl(int const handler, size_t const offset, si
int const hidden_size, void* residual, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream,
bool force_enable);
int allreduce2_userbuff_inplace_rmsnorm_impl(int const handler, size_t const offset, size_t const elements,
int const hidden_size, void* beta, void* gamma, float eps, void* residual_in, void* residual_out,
nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream);
int allreduce2_userbuff_rmsnorm_impl(int const handler, size_t const offset, int const out_handler,
size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
void* residual_in, void* residual_out, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream);
int allreduce2_userbuff_inplace_rmsnorm_quant_impl(int const handler, size_t const offset, int const out_handler,
size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,

@@ -261,11 +261,11 @@ public:
TLLM_CHECK(mAffine);
TLLM_CHECK(!mBias);
TLLM_CHECK(mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16);
tensorrt_llm::kernels::ub::allreduce2_userbuff_inplace_rmsnorm_launcher(ub_buffer0.handle, 0, size,
hidden_size, nullptr, gamma, mEps, residual, output.data_ptr(), mType, ub_comm, stream);
tensorrt_llm::kernels::ub::allreduce2_userbuff_rmsnorm_launcher(ub_buffer0.handle, 0, ub_buffer1.handle,
0, size, hidden_size, nullptr, gamma, mEps, residual, output.data_ptr(), mType, ub_comm, stream);
auto dt = input.scalar_type();
finalOutput = torch::from_blob(
ub_buffer0.addr, input.sizes(), input.strides(), torch::dtype(dt).device(torch::kCUDA));
ub_buffer1.addr, input.sizes(), input.strides(), torch::dtype(dt).device(torch::kCUDA));
}
}
else if (runtimeStrategy == AllReduceStrategyType::NCCL)

@@ -132,7 +132,7 @@ def register_ub_allreduce_finalize(custom_pass: PatternMatcherPass):
)
trtllm_userbuffers_allreduce_finalize_default = CallFunction(
torch.ops.trtllm.userbuffers_allreduce_finalize.default,
KeywordArg("sharded_residual"), Ignored())
KeywordArg("sharded_residual"), False)
trtllm_ub_scaled_mm_allreduce_quant_scaled_mm_op_default = CallFunction(
torch.ops.trtllm.ub_scaled_mm_allreduce_quant_scaled_mm_op.default,
KeywordArg("mm0_a"),

@@ -142,8 +142,9 @@ def run_single_rank_ar_rms_norm(tensor_parallel_size, a, b, c, gamma):
residual=c,
norm_weight=gamma,
eps=eps)
res, residual = ar.forward(hidden, all_reduce_params=ar_params)
residual = userbuffers_allreduce_finalize(residual, True)
res_ub, residual = ar.forward(hidden, all_reduce_params=ar_params)
res = res_ub.clone()
residual = userbuffers_allreduce_finalize(residual, False)
torch.cuda.synchronize()
if rank == 0: