mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00

Add one-shot version for UB AR NORM FP16/BF16 (#2995)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>

This commit is contained in:
parent 794f61c997
commit e0d0dde058
@@ -89,12 +89,12 @@ int allgather2_userbuff_residual_launcher(int const handler, size_t const offset
         handler, offset, elements, hidden_size, residual, dataType, comm, stream, force_enable);
 }
 
-int allreduce2_userbuff_inplace_rmsnorm_launcher(int const handler, size_t const offset, size_t const elements,
-    int const hidden_size, void* beta, void* gamma, float eps, void* residual_in, void* residual_out,
-    nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream)
+int allreduce2_userbuff_rmsnorm_launcher(int const handler, size_t const offset, int const out_handler,
+    size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
+    void* residual_in, void* residual_out, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream)
 {
-    return allreduce2_userbuff_inplace_rmsnorm_impl(
-        handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, dataType, comm, stream);
+    return allreduce2_userbuff_rmsnorm_impl(handler, offset, out_handler, out_offset, elements, hidden_size, beta,
+        gamma, eps, residual_in, residual_out, dataType, comm, stream);
 }
 
 int allreduce2_userbuff_inplace_rmsnorm_quant_launcher(int const handler, size_t const offset, int const out_handler,
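For orientation, here is a host-side sketch of how the renamed launcher would be called, mirroring the plugin call site updated later in this commit. This is not code from the change: the wrapper name `run_fused_ar_rmsnorm`, the handle/pointer parameters, and the epsilon value are illustrative placeholders, and the TensorRT-LLM userbuffers headers and communicator setup are assumed.

    // Sketch only (assumes TensorRT-LLM's userbuffers headers and an initialized communicator).
    // The reduced + normalized output now goes to a second registered buffer (out_handler/out_offset)
    // instead of being written in place.
    int run_fused_ar_rmsnorm(int in_handle, int out_handle, size_t num_elements, int hidden_size, void* gamma,
        void* residual_in, void* residual_out, communicator* comm, cudaStream_t stream)
    {
        float const eps = 1e-5f; // illustrative value
        return tensorrt_llm::kernels::ub::allreduce2_userbuff_rmsnorm_launcher(
            /*handler=*/in_handle, /*offset=*/0, /*out_handler=*/out_handle, /*out_offset=*/0, num_elements,
            hidden_size, /*beta=*/nullptr, gamma, eps, residual_in, residual_out, nvinfer1::DataType::kHALF, comm,
            stream);
    }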
@@ -41,9 +41,9 @@ int allgather2_userbuff_residual_launcher(int const handler, size_t const offset
     int const hidden_size, void* residual, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream,
     bool force_enable = false);
 
-int allreduce2_userbuff_inplace_rmsnorm_launcher(int const handler, size_t const offset, size_t const elements,
-    int const hidden_size, void* beta, void* gamma, float eps, void* residual_in, void* residual_out,
-    nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream);
+int allreduce2_userbuff_rmsnorm_launcher(int const handler, size_t const offset, int const out_handler,
+    size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
+    void* residual_in, void* residual_out, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream);
 
 int allreduce2_userbuff_inplace_rmsnorm_quant_launcher(int const handler, size_t const offset, int const out_handler,
     size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
@@ -773,10 +773,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
 #if __CUDA_ARCH__ >= 900
 
 template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
-__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm(int const op,
-    int const flagoffset, int const firstrank, int const myrank, int const gpustep, const size_t lineoffset,
-    int const numlines, void** commbuff, int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma,
-    float const eps, int const RANKS, uint4* residual_in, uint4* residual_out, int res_offset)
+__global__ void __launch_bounds__(MAX_THREADS)
+    userbuffers_fp16_sum_gpu_mc_rmsnorm(int const op, int const flagoffset, int const firstrank, int const myrank,
+        int const gpustep, const size_t lineoffset, int const numlines, void** commbuff, int const handleridx,
+        float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS, float4* mc_ptr_out,
+        size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
 {
     cudaTriggerProgrammaticLaunchCompletion();
     __shared__ float s_variance;
@@ -853,7 +854,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
                     (threadIdx.x + g * loop_step0) * sizeof(int4) / sizeof(DType) + j));
                 i++;
             }
-            MULTIMEM_ST(val[g], mc_ptr + (lineoffset + line + g * loop_step0));
+            MULTIMEM_ST(val[g], mc_ptr_out + (out_lineoffset + line + g * loop_step0));
         }
     }
     __syncthreads();
@@ -870,6 +871,94 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
         *reduceidptr = reduce_id;
 } // fp16 inplace reduce kernel (Hopper) MC with rmsNorm fused
 
+template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
+__global__ void __launch_bounds__(MAX_THREADS)
+    userbuffers_fp16_sum_gpu_mc_rmsnorm_oneshot(int const op, int const flagoffset, int const firstrank,
+        int const myrank, int const gpustep, const size_t lineoffset, int const numlines, void** commbuff,
+        int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS,
+        uint4* uc_ptr_out, size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
+{
+    cudaTriggerProgrammaticLaunchCompletion();
+    __shared__ float s_variance;
+    int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);
+    int *flagptr, physgpu, targetgpu, *myptr;
+    int *reduceidptr, reduce_id;
+    if (threadIdx.x < RANKS)
+    {
+        physgpu = myrank * gpustep + firstrank;
+        targetgpu = threadIdx.x * gpustep + firstrank;
+        int const blockflagoffset = MAX_NVLINK * 2 * blockIdx.x;
+        myptr = (reinterpret_cast<int*>(commbuff[physgpu])) + flagoffset;
+        reduceidptr = myptr - MAX_OPS; //+op;
+        reduce_id = next_flag(*reduceidptr);
+        flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
+        myptr += blockflagoffset;
+        cudaGridDependencySynchronize();
+        flagptr[physgpu] = reduce_id;
+        multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
+    }
+    __syncthreads();
+
+    int const loop_step0 = blockDim.x;
+    int const loop_step = loop_step0 * UNROLL_NLINES * gridDim.x;
+    int const start_elem = threadIdx.x + blockDim.x * blockIdx.x * UNROLL_NLINES;
+    int const end_elem = max(start_elem, numlines);
+
+    for (int line = start_elem; line < end_elem; line += loop_step)
+    {
+        uint4 val[UNROLL_NLINES];
+        DType* x = reinterpret_cast<DType*>(&val[0]);
+#pragma unroll
+        for (int i = 0; i < UNROLL_NLINES; i++)
+            MULTIMEM_LD<DType, DISABLE_FP32_ACC>(val[i], mc_ptr + (lineoffset + line + i * loop_step0));
+
+        if (residual_in != nullptr)
+        {
+#pragma unroll
+            for (int i = 0; i < UNROLL_NLINES; i++)
+            {
+                uint4 resval = residual_in[res_offset + line + i * loop_step0];
+                DType* y = reinterpret_cast<DType*>(&resval);
+#pragma unroll
+                for (int j = 0; j < 8; j++)
+                    x[i * 8 + j] += y[j];
+                residual_out[res_offset + line + i * loop_step0] = val[i];
+            }
+        }
+
+        float local_var_sum = 0.0f;
+        for (int j = 0; j < UNROLL_NLINES * sizeof(int4) / sizeof(DType); j++)
+            local_var_sum += (float) (x[j]) * (float) (x[j]);
+
+        float packed[1] = {local_var_sum};
+        blockReduceSumV2<float, 1>(packed);
+        float variance = packed[0];
+
+        if (threadIdx.x == 0)
+        {
+            variance = (variance / hidden_dim); // Var[x] = E[x²]
+            s_variance = rsqrtf(variance + eps);
+        }
+        __syncthreads();
+
+        int i = 0;
+#pragma unroll
+        for (int g = 0; g < UNROLL_NLINES; g++)
+        {
+#pragma unroll
+            for (int j = 0; j < sizeof(int4) / sizeof(DType); j++)
+            {
+                x[i] = cuda_cast<DType>(compute_rmsnorm2<DType>((float) (x[i]), s_variance, gamma, beta,
+                    (threadIdx.x + g * loop_step0) * sizeof(int4) / sizeof(DType) + j));
+                i++;
+            }
+            uc_ptr_out[out_lineoffset + line + g * loop_step0] = val[g];
+        }
+    }
+    if (threadIdx.x == 0 && blockIdx.x == 0)
+        *reduceidptr = reduce_id;
+} // fp16 inplace reduce kernel (Hopper) MC with rmsNorm fused oneshot
+
 template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
 __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm_quant(int const op,
     int const flagoffset, int const firstrank, int const myrank, int const gpustep, size_t const lineoffset,
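To make the fused math easier to audit, here is a single-token, single-GPU reference of what the one-shot kernel computes once the multicast load has produced the cross-rank sum: optional residual add, mean of squares, rsqrt(var + eps), then the gamma scale. This sketch is not part of the change; it assumes compute_rmsnorm2 applies the usual gamma scale (beta is nullptr at the call sites in this commit) and ignores vectorization, ranks, and synchronization.

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Plain C++ reference for one token of hidden_dim values.
    std::vector<float> rmsnorm_reference(std::vector<float> x,   // cross-rank sum for one token
        std::vector<float> const& residual,                      // optional residual (empty to skip)
        std::vector<float> const& gamma, float eps)
    {
        std::size_t const hidden_dim = x.size();
        if (!residual.empty())
            for (std::size_t i = 0; i < hidden_dim; ++i)
                x[i] += residual[i];                             // residual_out would be written here
        float sum_sq = 0.f;
        for (float v : x)
            sum_sq += v * v;                                     // local_var_sum in the kernel
        float const s = 1.f / std::sqrt(sum_sq / hidden_dim + eps); // s_variance = rsqrtf(var + eps)
        for (std::size_t i = 0; i < hidden_dim; ++i)
            x[i] = x[i] * s * gamma[i];                          // beta omitted; call sites pass nullptr
        return x;
    }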
@@ -1127,10 +1216,22 @@ __global__ void __launch_bounds__(MAX_THREADS)
 
 #else
 template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
-__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm(int const op,
-    int const flagoffset, int const firstrank, int const myrank, int const gpustep, const size_t lineoffset,
-    int const numlines, void** commbuff, int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma,
-    float const eps, int const RANKS, uint4* residual_in, uint4* residual_out, int res_offset)
+__global__ void __launch_bounds__(MAX_THREADS)
+    userbuffers_fp16_sum_gpu_mc_rmsnorm(int const op, int const flagoffset, int const firstrank, int const myrank,
+        int const gpustep, const size_t lineoffset, int const numlines, void** commbuff, int const handleridx,
+        float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS, float4* uc_ptr_out,
+        size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
 {
     printf("userbuffer based kernels not implemented when SM < 90\n");
     asm volatile("brkpt;\n");
+}
+
+template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
+__global__ void __launch_bounds__(MAX_THREADS)
+    userbuffers_fp16_sum_gpu_mc_rmsnorm_oneshot(int const op, int const flagoffset, int const firstrank,
+        int const myrank, int const gpustep, const size_t lineoffset, int const numlines, void** commbuff,
+        int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS,
+        uint4* uc_ptr_out, size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
+{
+    printf("userbuffer based kernels not implemented when SM < 90\n");
+    asm volatile("brkpt;\n");
@@ -1339,17 +1440,50 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
         DType* arg12 = (DType*) gamma; \
         float arg13 = eps; \
         int arg14 = ar_nvsize; \
-        void* arg15 = residual_in; \
-        void* arg16 = residual_out; \
-        int arg17 = first_token * hidden_lines; \
+        void* arg15 = comm->mc_ptr[out_handler]; \
+        size_t arg16 = out_offset / 8 + first_token * hidden_lines; \
+        void* arg17 = residual_in; \
+        void* arg18 = residual_out; \
+        int arg19 = first_token * hidden_lines; \
         void* kernelArgs[] = {reinterpret_cast<void*>(&arg1), reinterpret_cast<void*>(&arg2), \
             reinterpret_cast<void*>(&arg3), reinterpret_cast<void*>(&arg4), reinterpret_cast<void*>(&arg5), \
             reinterpret_cast<void*>(&arg6), reinterpret_cast<void*>(&arg7), reinterpret_cast<void*>(&arg8), \
             reinterpret_cast<void*>(&arg9), reinterpret_cast<void*>(&arg10), reinterpret_cast<void*>(&arg11), \
             reinterpret_cast<void*>(&arg12), reinterpret_cast<void*>(&arg13), reinterpret_cast<void*>(&arg14), \
-            reinterpret_cast<void*>(&arg15), reinterpret_cast<void*>(&arg16), reinterpret_cast<void*>(&arg17)}; \
+            reinterpret_cast<void*>(&arg15), reinterpret_cast<void*>(&arg16), reinterpret_cast<void*>(&arg17), \
+            reinterpret_cast<void*>(&arg18), reinterpret_cast<void*>(&arg19)}; \
         TLLM_CUDA_CHECK(cudaLaunchKernelExC( \
-            &cfg, (void*) (userbuffers_fp16_sum_inplace_gpu_mc_rmsnorm<DType, x, DISABLE_FP32_ACC>), kernelArgs)); \
+            &cfg, (void*) (userbuffers_fp16_sum_gpu_mc_rmsnorm<DType, x, DISABLE_FP32_ACC>), kernelArgs)); \
     }
+
+#define callranksMC_RMSNORM_ONESHOT(x) \
+    if (nlines == x) \
+    { \
+        int arg1 = userbuffers_allreduceop_nonsharp2 - MAX_OPS, arg2 = REG0_OFFSET(comm) - REG0_SINGLENODE + MAX_OPS, \
+            arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step; \
+        size_t arg6 = offset / 8; \
+        int arg7 = elements / 8; \
+        void** arg8 = (void**) (comm->gpu_ptrs); \
+        int arg9 = handler * comm->nvsize; \
+        void* arg10 = comm->mc_ptr[handler]; \
+        DType* arg11 = (DType*) beta; \
+        DType* arg12 = (DType*) gamma; \
+        float arg13 = eps; \
+        int arg14 = ar_nvsize; \
+        void* arg15 = comm->mem_ptr[out_handler]; \
+        size_t arg16 = out_offset / 8; \
+        void* arg17 = residual_in; \
+        void* arg18 = residual_out; \
+        int arg19 = 0; \
+        void* kernelArgs[] = {reinterpret_cast<void*>(&arg1), reinterpret_cast<void*>(&arg2), \
+            reinterpret_cast<void*>(&arg3), reinterpret_cast<void*>(&arg4), reinterpret_cast<void*>(&arg5), \
+            reinterpret_cast<void*>(&arg6), reinterpret_cast<void*>(&arg7), reinterpret_cast<void*>(&arg8), \
+            reinterpret_cast<void*>(&arg9), reinterpret_cast<void*>(&arg10), reinterpret_cast<void*>(&arg11), \
+            reinterpret_cast<void*>(&arg12), reinterpret_cast<void*>(&arg13), reinterpret_cast<void*>(&arg14), \
+            reinterpret_cast<void*>(&arg15), reinterpret_cast<void*>(&arg16), reinterpret_cast<void*>(&arg17), \
+            reinterpret_cast<void*>(&arg18), reinterpret_cast<void*>(&arg19)}; \
+        TLLM_CUDA_CHECK(cudaLaunchKernelExC( \
+            &cfg, (void*) (userbuffers_fp16_sum_gpu_mc_rmsnorm_oneshot<DType, x, DISABLE_FP32_ACC>), kernelArgs)); \
+    }
+
 template <typename DType, bool DISABLE_FP32_ACC>
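Both macros feed cudaLaunchKernelExC, which takes an array holding one pointer per kernel parameter in declaration order; that is why adding out_handler/out_offset forces the arg15..arg19 renumbering above. The following minimal, self-contained illustration of that launch pattern is not code from this commit: the toy kernel, block size, and helper name are placeholders.

    #include <cuda_runtime.h>

    // Toy kernel standing in for the userbuffers kernels.
    __global__ void scale(float* data, int n, float factor)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n)
            data[i] *= factor;
    }

    cudaError_t launch_scale(float* d_data, int n, float factor, cudaStream_t stream)
    {
        cudaLaunchConfig_t cfg{};
        cfg.gridDim = dim3((n + 255) / 256);
        cfg.blockDim = dim3(256);
        cfg.dynamicSmemBytes = 0;
        cfg.stream = stream;

        // One pointer per kernel parameter, in declaration order, exactly as kernelArgs is built above.
        void* args[] = {&d_data, &n, &factor};
        return cudaLaunchKernelExC(&cfg, (void*) scale, args);
    }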
@@ -1436,8 +1570,9 @@ bool use_oneshot_kernel(communicator* comm, size_t elements, int hidden_size)
 }
 
 template <typename DType, bool DISABLE_FP32_ACC>
-int allreduce2_userbuff_inplace_rmsnorm(int const handler, int const offset, int const elements, int const hidden_size,
-    void* beta, void* gamma, float eps, void* residual_in, void* residual_out, communicator* comm, cudaStream_t stream)
+int allreduce2_userbuff_rmsnorm(int const handler, int const offset, int const out_handler, size_t const out_offset,
+    int const elements, int const hidden_size, void* beta, void* gamma, float eps, void* residual_in,
+    void* residual_out, communicator* comm, cudaStream_t stream)
 {
     int const ar_firstgpu = comm->tp_first_rank;
     int const ar_step = 1;
@@ -1465,7 +1600,15 @@ int allreduce2_userbuff_inplace_rmsnorm(int const handler, int const offset, int
     auto& cfg = launch_config.get();
     if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED))
     {
-        callranksMC_RMSNORM(1) callranksMC_RMSNORM(2) callranksMC_RMSNORM(3) callranksMC_RMSNORM(4)
+        if (use_oneshot_kernel(comm, elements, hidden_size))
+        {
+            callranksMC_RMSNORM_ONESHOT(1) callranksMC_RMSNORM_ONESHOT(2) callranksMC_RMSNORM_ONESHOT(3)
+                callranksMC_RMSNORM_ONESHOT(4)
+        }
+        else
+        {
+            callranksMC_RMSNORM(1) callranksMC_RMSNORM(2) callranksMC_RMSNORM(3) callranksMC_RMSNORM(4)
+        }
     }
     else
     {
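The chained invocations work because each callranksMC_RMSNORM / callranksMC_RMSNORM_ONESHOT expansion is a guarded block of the form `if (nlines == x) { ... }`, so exactly one unroll variant launches for the nlines computed from hidden_size. A stripped-down sketch of that dispatch idiom follows; the macro and the template stand-in are hypothetical simplifications, not the real launch parameters.

    #include <cstdio>

    // Hypothetical stand-in for the real launch; the template parameter plays the role of UNROLL_NLINES.
    template <int UNROLL_NLINES>
    void launch_oneshot_kernel()
    {
        std::printf("would launch the %d-line-unrolled kernel\n", UNROLL_NLINES);
    }

    // Each expansion is "if (nlines == N) { ... }", so chaining the invocations acts as a switch on nlines.
    #define LAUNCH_FOR_NLINES(N)                                                                                       \
        if (nlines == N)                                                                                               \
        {                                                                                                              \
            launch_oneshot_kernel<N>();                                                                                \
        }

    void dispatch_by_unroll(int nlines)
    {
        LAUNCH_FOR_NLINES(1) LAUNCH_FOR_NLINES(2) LAUNCH_FOR_NLINES(3) LAUNCH_FOR_NLINES(4)
    }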
@@ -1678,9 +1821,9 @@ int allgather2_userbuff_residual_impl(int const handler, size_t const offset, si
     }
 }
 
-int allreduce2_userbuff_inplace_rmsnorm_impl(int const handler, size_t const offset, size_t const elements,
-    int const hidden_size, void* beta, void* gamma, float eps, void* residual_in, void* residual_out,
-    nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream)
+int allreduce2_userbuff_rmsnorm_impl(int const handler, size_t const offset, int const out_handler,
+    size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
+    void* residual_in, void* residual_out, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream)
 {
     switch (dataType)
     {
@@ -1688,13 +1831,13 @@ int allreduce2_userbuff_inplace_rmsnorm_impl(int const handler, size_t const off
     {
         if (kDISABLE_FP32_ACCUMULATION)
         {
-            return allreduce2_userbuff_inplace_rmsnorm<half, true>(
-                handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
+            return allreduce2_userbuff_rmsnorm<half, true>(handler, offset, out_handler, out_offset, elements,
+                hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
         }
         else
         {
-            return allreduce2_userbuff_inplace_rmsnorm<half, false>(
-                handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
+            return allreduce2_userbuff_rmsnorm<half, false>(handler, offset, out_handler, out_offset, elements,
+                hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
         }
         break;
     }
@@ -1703,18 +1846,18 @@ int allreduce2_userbuff_inplace_rmsnorm_impl(int const handler, size_t const off
     {
         if (kDISABLE_FP32_ACCUMULATION)
         {
-            return allreduce2_userbuff_inplace_rmsnorm<__nv_bfloat16, true>(
-                handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
+            return allreduce2_userbuff_rmsnorm<__nv_bfloat16, true>(handler, offset, out_handler, out_offset, elements,
+                hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
         }
         else
        {
-            return allreduce2_userbuff_inplace_rmsnorm<__nv_bfloat16, false>(
-                handler, offset, elements, hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
+            return allreduce2_userbuff_rmsnorm<__nv_bfloat16, false>(handler, offset, out_handler, out_offset, elements,
+                hidden_size, beta, gamma, eps, residual_in, residual_out, comm, stream);
         }
         break;
     }
 #endif
-    default: TLLM_THROW("Unsupported dataType for allreduce2_userbuff_inplace_rmsnorm_impl");
+    default: TLLM_THROW("Unsupported dataType for allreduce2_userbuff_rmsnorm_impl");
     }
 }
 
@@ -125,9 +125,9 @@ int allgather2_userbuff_residual_impl(int const handler, size_t const offset, si
     int const hidden_size, void* residual, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream,
     bool force_enable);
 
-int allreduce2_userbuff_inplace_rmsnorm_impl(int const handler, size_t const offset, size_t const elements,
-    int const hidden_size, void* beta, void* gamma, float eps, void* residual_in, void* residual_out,
-    nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream);
+int allreduce2_userbuff_rmsnorm_impl(int const handler, size_t const offset, int const out_handler,
+    size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
+    void* residual_in, void* residual_out, nvinfer1::DataType dataType, communicator* comm, cudaStream_t stream);
 
 int allreduce2_userbuff_inplace_rmsnorm_quant_impl(int const handler, size_t const offset, int const out_handler,
     size_t const out_offset, size_t const elements, int const hidden_size, void* beta, void* gamma, float eps,
@@ -261,11 +261,11 @@ public:
             TLLM_CHECK(mAffine);
             TLLM_CHECK(!mBias);
             TLLM_CHECK(mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16);
-            tensorrt_llm::kernels::ub::allreduce2_userbuff_inplace_rmsnorm_launcher(ub_buffer0.handle, 0, size,
-                hidden_size, nullptr, gamma, mEps, residual, output.data_ptr(), mType, ub_comm, stream);
+            tensorrt_llm::kernels::ub::allreduce2_userbuff_rmsnorm_launcher(ub_buffer0.handle, 0, ub_buffer1.handle,
+                0, size, hidden_size, nullptr, gamma, mEps, residual, output.data_ptr(), mType, ub_comm, stream);
             auto dt = input.scalar_type();
             finalOutput = torch::from_blob(
-                ub_buffer0.addr, input.sizes(), input.strides(), torch::dtype(dt).device(torch::kCUDA));
+                ub_buffer1.addr, input.sizes(), input.strides(), torch::dtype(dt).device(torch::kCUDA));
         }
     }
     else if (runtimeStrategy == AllReduceStrategyType::NCCL)
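Since the kernel no longer writes its result in place, the op now wraps ub_buffer1 (the output buffer handed to the launcher) rather than ub_buffer0. For context, a minimal sketch of the zero-copy wrapping idiom used here; it assumes libtorch and a caller-managed device allocation, and the helper name is a placeholder, not code from this commit.

    #include <torch/torch.h>
    #include <vector>

    // Sketch only: expose an existing device allocation as a tensor without copying,
    // the way the op wraps the userbuffers output buffer above.
    torch::Tensor wrap_device_buffer(void* dev_ptr, std::vector<int64_t> sizes, torch::ScalarType dtype)
    {
        auto options = torch::dtype(dtype).device(torch::kCUDA);
        // No ownership is taken: the caller must keep dev_ptr alive for the tensor's lifetime.
        return torch::from_blob(dev_ptr, sizes, options);
    }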
@@ -132,7 +132,7 @@ def register_ub_allreduce_finalize(custom_pass: PatternMatcherPass):
     )
     trtllm_userbuffers_allreduce_finalize_default = CallFunction(
         torch.ops.trtllm.userbuffers_allreduce_finalize.default,
-        KeywordArg("sharded_residual"), Ignored())
+        KeywordArg("sharded_residual"), False)
     trtllm_ub_scaled_mm_allreduce_quant_scaled_mm_op_default = CallFunction(
         torch.ops.trtllm.ub_scaled_mm_allreduce_quant_scaled_mm_op.default,
         KeywordArg("mm0_a"),
@@ -142,8 +142,9 @@ def run_single_rank_ar_rms_norm(tensor_parallel_size, a, b, c, gamma):
         residual=c,
         norm_weight=gamma,
         eps=eps)
-    res, residual = ar.forward(hidden, all_reduce_params=ar_params)
-    residual = userbuffers_allreduce_finalize(residual, True)
+    res_ub, residual = ar.forward(hidden, all_reduce_params=ar_params)
+    res = res_ub.clone()
+    residual = userbuffers_allreduce_finalize(residual, False)
 
     torch.cuda.synchronize()
     if rank == 0: