From e0d0dde0580c7dd0f90dff180bd1f3004a511d96 Mon Sep 17 00:00:00 2001
From: liji-nv <59594262+liji-nv@users.noreply.github.com>
Date: Mon, 31 Mar 2025 11:16:03 +0800
Subject: [PATCH] None - Add one-shot version for UB AR NORM FP16/BF16 (#2995)

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
---
 .../kernels/userbuffers/ub_interface.cpp      |  10 +-
 .../kernels/userbuffers/ub_interface.h        |   6 +-
 .../kernels/userbuffers/userbuffers.cu        | 201 +++++++++++++++---
 .../kernels/userbuffers/userbuffers.h         |   6 +-
 cpp/tensorrt_llm/thop/allreduceOp.cpp         |   6 +-
 .../compilation/patterns/ub_allreduce.py      |   2 +-
 .../_torch/multi_gpu/test_user_buffers.py     |   5 +-
 7 files changed, 190 insertions(+), 46 deletions(-)

diff --git a/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp b/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp
index d0aa53e28e..6415454014 100644
--- a/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp
+++ b/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp
@@ -89,12 +89,12 @@ int allgather2_userbuff_residual_launcher(int const handler, size_t const offset
         handler, offset, elements, hidden_size, residual, dataType, comm, stream, force_enable);
 }

-int allreduce2_userbuff_inplace_rmsnorm_launcher(int const handler, size_t const offset, size_t const elements,
-    int const hidden_size, void* beta, void* gamma, float eps, void* residual_in, void* residual_out,
-    nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream)
+int allreduce2_userbuff_rmsnorm_launcher(int const handler, size_t const offset, int const out_handler,
+    size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
+    void* residual_in, void* residual_out, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream)
 {
-    return allreduce2_userbuff_inplace_rmsnorm_impl(
-        handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, dataType, comm, stream);
+    return allreduce2_userbuff_rmsnorm_impl(handler, offset, out_handler, out_offset, elements, hidden_size, beta,
+        gamma, eps, residual_in, residual_out, dataType, comm, stream);
 }

 int allreduce2_userbuff_inplace_rmsnorm_quant_launcher(int const handler, size_t const offset, int const out_handler,
diff --git a/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.h b/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.h
index 077124e4a5..64f8153f13 100644
--- a/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.h
+++ b/cpp/tensorrt_llm/kernels/userbuffers/ub_interface.h
@@ -41,9 +41,9 @@ int allgather2_userbuff_residual_launcher(int const handler, size_t const offset
     int const hidden_size, void* residual, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream,
     bool force_enable = false);

-int allreduce2_userbuff_inplace_rmsnorm_launcher(int const handler, size_t const offset, size_t const elements,
-    int const hidden_size, void* beta, void* gamma, float eps, void* residual_in, void* residual_out,
-    nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream);
+int allreduce2_userbuff_rmsnorm_launcher(int const handler, size_t const offset, int const out_handler,
+    size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
+    void* residual_in, void* residual_out, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream);

 int allreduce2_userbuff_inplace_rmsnorm_quant_launcher(int const handler, size_t const offset, int const out_handler,
     size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
diff --git a/cpp/tensorrt_llm/kernels/userbuffers/userbuffers.cu b/cpp/tensorrt_llm/kernels/userbuffers/userbuffers.cu
index cb16eeb284..9eeba5ab9e 100644
--- a/cpp/tensorrt_llm/kernels/userbuffers/userbuffers.cu
+++ b/cpp/tensorrt_llm/kernels/userbuffers/userbuffers.cu
@@ -773,10 +773,11 @@ __global__ void __launch_bounds__(MAX_THREADS)

 #if __CUDA_ARCH__ >= 900
 template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
-__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm(int const op,
-    int const flagoffset, int const firstrank, int const myrank, int const gpustep, const size_t lineoffset,
-    int const numlines, void** commbuff, int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma,
-    float const eps, int const RANKS, uint4* residual_in, uint4* residual_out, int res_offset)
+__global__ void __launch_bounds__(MAX_THREADS)
+    userbuffers_fp16_sum_gpu_mc_rmsnorm(int const op, int const flagoffset, int const firstrank, int const myrank,
+        int const gpustep, const size_t lineoffset, int const numlines, void** commbuff, int const handleridx,
+        float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS, float4* mc_ptr_out,
+        size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
 {
     cudaTriggerProgrammaticLaunchCompletion();
     __shared__ float s_variance;
@@ -853,7 +854,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
                     (threadIdx.x + g * loop_step0) * sizeof(int4) / sizeof(DType) + j));
                 i++;
             }
-            MULTIMEM_ST(val[g], mc_ptr + (lineoffset + line + g * loop_step0));
+            MULTIMEM_ST(val[g], mc_ptr_out + (out_lineoffset + line + g * loop_step0));
         }
     }
     __syncthreads();
@@ -870,6 +871,94 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
         *reduceidptr = reduce_id;
 } // fp16 inplace reduce kernel (Hopper) MC with rmsNorm fused

+template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
+__global__ void __launch_bounds__(MAX_THREADS)
+    userbuffers_fp16_sum_gpu_mc_rmsnorm_oneshot(int const op, int const flagoffset, int const firstrank,
+        int const myrank, int const gpustep, const size_t lineoffset, int const numlines, void** commbuff,
+        int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS,
+        uint4* uc_ptr_out, size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
+{
+    cudaTriggerProgrammaticLaunchCompletion();
+    __shared__ float s_variance;
+    int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);
+    int *flagptr, physgpu, targetgpu, *myptr;
+    int *reduceidptr, reduce_id;
+    if (threadIdx.x < RANKS)
+    {
+        physgpu = myrank * gpustep + firstrank;
+        targetgpu = threadIdx.x * gpustep + firstrank;
+        int const blockflagoffset = MAX_NVLINK * 2 * blockIdx.x;
+        myptr = (reinterpret_cast<int*>(commbuff[physgpu])) + flagoffset;
+        reduceidptr = myptr - MAX_OPS; //+op;
+        reduce_id = next_flag(*reduceidptr);
+        flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
+        myptr += blockflagoffset;
+        cudaGridDependencySynchronize();
+        flagptr[physgpu] = reduce_id;
+        multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
+    }
+    __syncthreads();
+
+    int const loop_step0 = blockDim.x;
+    int const loop_step = loop_step0 * UNROLL_NLINES * gridDim.x;
+    int const start_elem = threadIdx.x + blockDim.x * blockIdx.x * UNROLL_NLINES;
+    int const end_elem = max(start_elem, numlines);
+
+    for (int line = start_elem; line < end_elem; line += loop_step)
+    {
+        uint4 val[UNROLL_NLINES];
+        DType* x = reinterpret_cast<DType*>(&val[0]);
+#pragma unroll
+        for (int i = 0; i < UNROLL_NLINES; i++)
+            MULTIMEM_LD(val[i], mc_ptr + (lineoffset + line + i * loop_step0));
+
+        if (residual_in != nullptr)
+        {
+#pragma unroll
+            for (int i = 0; i < UNROLL_NLINES; i++)
+            {
+                uint4 resval = residual_in[res_offset + line + i * loop_step0];
+                DType* y = reinterpret_cast<DType*>(&resval);
+#pragma unroll
+                for (int j = 0; j < 8; j++)
+                    x[i * 8 + j] += y[j];
+                residual_out[res_offset + line + i * loop_step0] = val[i];
+            }
+        }
+
+        float local_var_sum = 0.0f;
+        for (int j = 0; j < UNROLL_NLINES * sizeof(int4) / sizeof(DType); j++)
+            local_var_sum += (float) (x[j]) * (float) (x[j]);
+
+        float packed[1] = {local_var_sum};
+        blockReduceSumV2<float, 1>(packed);
+        float variance = packed[0];
+
+        if (threadIdx.x == 0)
+        {
+            variance = (variance / hidden_dim); // Var[x] = E[x²]
+            s_variance = rsqrtf(variance + eps);
+        }
+        __syncthreads();
+
+        int i = 0;
+#pragma unroll
+        for (int g = 0; g < UNROLL_NLINES; g++)
+        {
+#pragma unroll
+            for (int j = 0; j < sizeof(int4) / sizeof(DType); j++)
+            {
+                x[i] = cuda_cast<DType>(compute_rmsnorm2((float) (x[i]), s_variance, gamma, beta,
+                    (threadIdx.x + g * loop_step0) * sizeof(int4) / sizeof(DType) + j));
+                i++;
+            }
+            uc_ptr_out[out_lineoffset + line + g * loop_step0] = val[g];
+        }
+    }
+    if (threadIdx.x == 0 && blockIdx.x == 0)
+        *reduceidptr = reduce_id;
+} // fp16 inplace reduce kernel (Hopper) MC with rmsNorm fused oneshot
+
 template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
 __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant(int const op,
     int const flagoffset, int const firstrank, int const myrank, int const gpustep, size_t const lineoffset,
@@ -1127,10 +1216,22 @@ __global__ void __launch_bounds__(MAX_THREADS)
 #else
 template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
-__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm(int const op,
-    int const flagoffset, int const firstrank, int const myrank, int const gpustep, const size_t lineoffset,
-    int const numlines, void** commbuff, int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma,
-    float const eps, int const RANKS, uint4* residual_in, uint4* residual_out, int res_offset)
+__global__ void __launch_bounds__(MAX_THREADS)
+    userbuffers_fp16_sum_gpu_mc_rmsnorm(int const op, int const flagoffset, int const firstrank, int const myrank,
+        int const gpustep, const size_t lineoffset, int const numlines, void** commbuff, int const handleridx,
+        float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS, float4* uc_ptr_out,
+        size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
+{
+    printf("userbuffer based kernels not implemented when SM < 90\n");
+    asm volatile("brkpt;\n");
+}
+
+template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
+__global__ void __launch_bounds__(MAX_THREADS)
+    userbuffers_fp16_sum_gpu_mc_rmsnorm_oneshot(int const op, int const flagoffset, int const firstrank,
+        int const myrank, int const gpustep, const size_t lineoffset, int const numlines, void** commbuff,
+        int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS,
+        uint4* uc_ptr_out, size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
 {
     printf("userbuffer based kernels not implemented when SM < 90\n");
     asm volatile("brkpt;\n");
 }
@@ -1339,17 +1440,50 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
         DType* arg12 = (DType*) gamma; \
         float arg13 = eps; \
         int arg14 = ar_nvsize; \
-        void* arg15 = residual_in; \
-        void* arg16 = residual_out; \
-        int arg17 = first_token * hidden_lines; \
+        void* arg15 = comm->mc_ptr[out_handler]; \
+        size_t arg16 = out_offset / 8 + first_token * hidden_lines; \
+        void* arg17 = residual_in; \
+        void* arg18 = residual_out; \
+        int arg19 = first_token * hidden_lines; \
         void* kernelArgs[] = {reinterpret_cast<void*>(&arg1), reinterpret_cast<void*>(&arg2), \
             reinterpret_cast<void*>(&arg3), reinterpret_cast<void*>(&arg4), reinterpret_cast<void*>(&arg5), \
             reinterpret_cast<void*>(&arg6), reinterpret_cast<void*>(&arg7), reinterpret_cast<void*>(&arg8), \
             reinterpret_cast<void*>(&arg9), reinterpret_cast<void*>(&arg10), reinterpret_cast<void*>(&arg11), \
             reinterpret_cast<void*>(&arg12), reinterpret_cast<void*>(&arg13), reinterpret_cast<void*>(&arg14), \
-            reinterpret_cast<void*>(&arg15), reinterpret_cast<void*>(&arg16), reinterpret_cast<void*>(&arg17)}; \
+            reinterpret_cast<void*>(&arg15), reinterpret_cast<void*>(&arg16), reinterpret_cast<void*>(&arg17), \
+            reinterpret_cast<void*>(&arg18), reinterpret_cast<void*>(&arg19)}; \
         TLLM_CUDA_CHECK(cudaLaunchKernelExC( \
-            &cfg, (void*) (userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm<DType, x, DISABLE_FP32_ACC>), kernelArgs)); \
+            &cfg, (void*) (userbuffers_fp16_sum_gpu_mc_rmsnorm<DType, x, DISABLE_FP32_ACC>), kernelArgs)); \
+    }
+
+#define callranksMC_RMSNORM_ONESHOT(x) \
+    if (nlines == x) \
+    { \
+        int arg1 = userbuffers_allreduceop_nonsharp2 - MAX_OPS, arg2 = REG0_OFFSET(comm) - REG0_SINGLENODE + MAX_OPS, \
+            arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step; \
+        size_t arg6 = offset / 8; \
+        int arg7 = elements / 8; \
+        void** arg8 = (void**) (comm->gpu_ptrs); \
+        int arg9 = handler * comm->nvsize; \
+        void* arg10 = comm->mc_ptr[handler]; \
+        DType* arg11 = (DType*) beta; \
+        DType* arg12 = (DType*) gamma; \
+        float arg13 = eps; \
+        int arg14 = ar_nvsize; \
+        void* arg15 = comm->mem_ptr[out_handler]; \
+        size_t arg16 = out_offset / 8; \
+        void* arg17 = residual_in; \
+        void* arg18 = residual_out; \
+        int arg19 = 0; \
+        void* kernelArgs[] = {reinterpret_cast<void*>(&arg1), reinterpret_cast<void*>(&arg2), \
+            reinterpret_cast<void*>(&arg3), reinterpret_cast<void*>(&arg4), reinterpret_cast<void*>(&arg5), \
+            reinterpret_cast<void*>(&arg6), reinterpret_cast<void*>(&arg7), reinterpret_cast<void*>(&arg8), \
+            reinterpret_cast<void*>(&arg9), reinterpret_cast<void*>(&arg10), reinterpret_cast<void*>(&arg11), \
+            reinterpret_cast<void*>(&arg12), reinterpret_cast<void*>(&arg13), reinterpret_cast<void*>(&arg14), \
+            reinterpret_cast<void*>(&arg15), reinterpret_cast<void*>(&arg16), reinterpret_cast<void*>(&arg17), \
+            reinterpret_cast<void*>(&arg18), reinterpret_cast<void*>(&arg19)}; \
+        TLLM_CUDA_CHECK(cudaLaunchKernelExC( \
+            &cfg, (void*) (userbuffers_fp16_sum_gpu_mc_rmsnorm_oneshot<DType, x, DISABLE_FP32_ACC>), kernelArgs)); \
     }

 template <typename DType, bool DISABLE_FP32_ACC>
@@ -1436,8 +1570,9 @@ bool use_oneshot_kernel(communicator* comm, size_t elements, int hidden_size)
 }

 template <typename DType, bool DISABLE_FP32_ACC>
-int allreduce2_userbuff_inplace_rmsnorm(int const handler, int const offset, int const elements, int const hidden_size,
-    void* beta, void* gamma, float eps, void* residual_in, void* residual_out, communicator* comm, cudaStream_t stream)
+int allreduce2_userbuff_rmsnorm(int const handler, int const offset, int const out_handler, size_t const out_offset,
+    int const elements, int const hidden_size, void* beta, void* gamma, float eps, void* residual_in,
+    void* residual_out, communicator* comm, cudaStream_t stream)
 {
     int const ar_firstgpu = comm->tp_first_rank;
     int const ar_step = 1;
@@ -1465,7 +1600,15 @@ int allreduce2_userbuff_inplace_rmsnorm(int const handler, int const offset, int
     auto& cfg = launch_config.get();
     if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED))
     {
-        callranksMC_RMSNORM(1) callranksMC_RMSNORM(2) callranksMC_RMSNORM(3) callranksMC_RMSNORM(4)
+        if (use_oneshot_kernel(comm, elements, hidden_size))
+        {
+            callranksMC_RMSNORM_ONESHOT(1) callranksMC_RMSNORM_ONESHOT(2) callranksMC_RMSNORM_ONESHOT(3)
+                callranksMC_RMSNORM_ONESHOT(4)
+        }
+        else
+        {
+            callranksMC_RMSNORM(1) callranksMC_RMSNORM(2) callranksMC_RMSNORM(3) callranksMC_RMSNORM(4)
+        }
     }
     else
     {
@@ -1678,9 +1821,9 @@ int allgather2_userbuff_residual_impl(int const handler, size_t const offset, si
     }
 }

-int allreduce2_userbuff_inplace_rmsnorm_impl(int const handler, size_t const offset, size_t const elements,
-    int const hidden_size, void* beta, void* gamma, float eps, void* residual_in, void* residual_out,
-    nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream)
+int allreduce2_userbuff_rmsnorm_impl(int const handler, size_t const offset, int const out_handler,
+    size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
+    void* residual_in, void* residual_out, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream)
 {
     switch (dataType)
     {
@@ -1688,13 +1831,13 @@ int allreduce2_userbuff_inplace_rmsnorm_impl(int const handler, size_t const off
     {
         if (kDISABLE_FP32_ACCUMULATION)
         {
-            return allreduce2_userbuff_inplace_rmsnorm<half, true>(
-                handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
+            return allreduce2_userbuff_rmsnorm<half, true>(handler, offset, out_handler, out_offset, elements,
+                hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
         }
         else
         {
-            return allreduce2_userbuff_inplace_rmsnorm<half, false>(
-                handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
+            return allreduce2_userbuff_rmsnorm<half, false>(handler, offset, out_handler, out_offset, elements,
+                hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
         }
         break;
     }
@@ -1703,18 +1846,18 @@ int allreduce2_userbuff_inplace_rmsnorm_impl(int const handler, size_t const off
     {
         if (kDISABLE_FP32_ACCUMULATION)
         {
-            return allreduce2_userbuff_inplace_rmsnorm<__nv_bfloat16, true>(
-                handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
+            return allreduce2_userbuff_rmsnorm<__nv_bfloat16, true>(handler, offset, out_handler, out_offset, elements,
+                hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
         }
         else
         {
-            return allreduce2_userbuff_inplace_rmsnorm<__nv_bfloat16, false>(
-                handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
+            return allreduce2_userbuff_rmsnorm<__nv_bfloat16, false>(handler, offset, out_handler, out_offset, elements,
+                hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
         }
         break;
     }
 #endif
-    default: TLLM_THROW("Unsupported dataType for allreduce2_userbuff_inplace_rmsnorm_impl");
+    default: TLLM_THROW("Unsupported dataType for allreduce2_userbuff_rmsnorm_impl");
     }
 }

diff --git a/cpp/tensorrt_llm/kernels/userbuffers/userbuffers.h b/cpp/tensorrt_llm/kernels/userbuffers/userbuffers.h
index a4e43a9062..9751f969d5 100644
--- a/cpp/tensorrt_llm/kernels/userbuffers/userbuffers.h
+++ b/cpp/tensorrt_llm/kernels/userbuffers/userbuffers.h
@@ -125,9 +125,9 @@ int allgather2_userbuff_residual_impl(int const handler, size_t const offset, si
     int const hidden_size, void* residual, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream,
     bool force_enable);

-int allreduce2_userbuff_inplace_rmsnorm_impl(int const handler, size_t const offset, size_t const elements,
-    int const hidden_size, void* beta, void* gamma, float eps, void* residual_in, void* residual_out,
-    nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream);
+int allreduce2_userbuff_rmsnorm_impl(int const handler, size_t const offset, int const out_handler,
+    size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
+    void* residual_in, void* residual_out, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream);

 int allreduce2_userbuff_inplace_rmsnorm_quant_impl(int const handler, size_t const offset, int const out_handler,
     size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
diff --git a/cpp/tensorrt_llm/thop/allreduceOp.cpp b/cpp/tensorrt_llm/thop/allreduceOp.cpp
index 4a69ea044a..ab70055114 100644
--- a/cpp/tensorrt_llm/thop/allreduceOp.cpp
+++ b/cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -261,11 +261,11 @@ public:
             TLLM_CHECK(mAffine);
             TLLM_CHECK(!mBias);
             TLLM_CHECK(mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16);
-            tensorrt_llm::kernels::ub::allreduce2_userbuff_inplace_rmsnorm_launcher(ub_buffer0.handle, 0, size,
-                hidden_size, nullptr, gamma, mEps, residual, output.data_ptr(), mType, ub_comm, stream);
+            tensorrt_llm::kernels::ub::allreduce2_userbuff_rmsnorm_launcher(ub_buffer0.handle, 0, ub_buffer1.handle,
+                0, size, hidden_size, nullptr, gamma, mEps, residual, output.data_ptr(), mType, ub_comm, stream);
             auto dt = input.scalar_type();
             finalOutput = torch::from_blob(
-                ub_buffer0.addr, input.sizes(), input.strides(), torch::dtype(dt).device(torch::kCUDA));
+                ub_buffer1.addr, input.sizes(), input.strides(), torch::dtype(dt).device(torch::kCUDA));
         }
     }
     else if (runtimeStrategy == AllReduceStrategyType::NCCL)
diff --git a/tensorrt_llm/_torch/compilation/patterns/ub_allreduce.py b/tensorrt_llm/_torch/compilation/patterns/ub_allreduce.py
index 35c6268b83..e8cfbea312 100644
--- a/tensorrt_llm/_torch/compilation/patterns/ub_allreduce.py
+++ b/tensorrt_llm/_torch/compilation/patterns/ub_allreduce.py
@@ -132,7 +132,7 @@ def register_ub_allreduce_finalize(custom_pass: PatternMatcherPass):
     )
     trtllm_userbuffers_allreduce_finalize_default = CallFunction(
         torch.ops.trtllm.userbuffers_allreduce_finalize.default,
-        KeywordArg("sharded_residual"), Ignored())
+        KeywordArg("sharded_residual"), False)
     trtllm_ub_scaled_mm_allreduce_quant_scaled_mm_op_default = CallFunction(
         torch.ops.trtllm.ub_scaled_mm_allreduce_quant_scaled_mm_op.default,
         KeywordArg("mm0_a"),
diff --git a/tests/unittest/_torch/multi_gpu/test_user_buffers.py b/tests/unittest/_torch/multi_gpu/test_user_buffers.py
index ffa2e13dae..33e025c53b 100644
--- a/tests/unittest/_torch/multi_gpu/test_user_buffers.py
+++ b/tests/unittest/_torch/multi_gpu/test_user_buffers.py
@@ -142,8 +142,9 @@ def run_single_rank_ar_rms_norm(tensor_parallel_size, a, b, c, gamma):
         residual=c,
         norm_weight=gamma,
         eps=eps)
-    res, residual = ar.forward(hidden, all_reduce_params=ar_params)
-    residual = userbuffers_allreduce_finalize(residual, True)
+    res_ub, residual = ar.forward(hidden, all_reduce_params=ar_params)
+    res = res_ub.clone()
+    residual = userbuffers_allreduce_finalize(residual, False)
     torch.cuda.synchronize()

     if rank == 0:
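
The snippet below is a minimal host-side sketch (not part of the patch) of how the renamed launcher is intended to be driven after this change: the allreduce input lives in one registered userbuffers region, and the fused RMSNorm output is written to a second region identified by the new out_handler/out_offset pair, mirroring the ub_buffer0/ub_buffer1 split in allreduceOp.cpp. The wrapper name, the namespace alias, the include path, and the assumption that `communicator` is reachable through `tensorrt_llm::kernels::ub` are illustrative; only the launcher name and argument order come from the patch.

// Hypothetical glue around the new launcher signature introduced by this patch.
#include "tensorrt_llm/kernels/userbuffers/ub_interface.h" // include path assumed

namespace ub = tensorrt_llm::kernels::ub;

// Reduces tokens * hidden FP16 elements held in the UB region behind in_handle across TP ranks,
// adds the residual, applies RMSNorm(gamma, eps), and writes the normalized activations into the
// UB region behind out_handle. The one-shot vs. two-shot kernel choice happens inside the launcher.
void fused_ar_rmsnorm(ub::communicator* comm, int in_handle, int out_handle, size_t tokens, int hidden, void* gamma,
    void* residual_in, void* residual_out, float eps, cudaStream_t stream)
{
    ub::allreduce2_userbuff_rmsnorm_launcher(in_handle, /*offset=*/0, out_handle, /*out_offset=*/0,
        /*elements=*/tokens * hidden, hidden, /*beta=*/nullptr, gamma, eps, residual_in, residual_out,
        nvinfer1::DataType::kHALF, comm, stream);
}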