TensorRT-LLMs/cpp/kernels/xqa/pairedF32Op.cuh
Ming Wei ed887940d4
infra: open source XQA kernels (#3762)
Replace libtensorrt_llm_nvrtc_wrapper.so with its source code, which
consists of two parts:

1. NVRTC glue code
2. XQA kernel code

During TensorRT-LLM build, XQA kernel code is embedded as C++ arrays via
gen_cpp_header.py and passed to NVRTC for JIT compilation.

Signed-off-by: Ming Wei <2345434+ming-wei@users.noreply.github.com>
2025-04-30 18:05:15 +08:00

57 lines
2.7 KiB
Plaintext

#include <vector_types.h>
// Declarations of compiler-provided paired-FP32 builtins (no definitions here;
// NVRTC/the CUDA toolchain resolves these names at JIT-compile time).
// Each function presumably operates element-wise on both lanes of a float2
// (the "2" suffix and float2 signatures suggest paired f32 ops) — TODO confirm
// against the PTX ISA / compiler builtin documentation.
//
// extern "C" prevents C++ name mangling so the names match the builtins exactly.
extern "C"
{
/*
 * Rounding mode modifiers:
 * _rn : round to nearest even (default)
 * _rm : round towards negative infinity
 * _rp : round towards positive infinity
 * _rz : round towards zero
 *
 * _ftz : flush denormalized values to zero
 */
/*
 * FFMA2 - fused multiply-add
 * Computes a*b+c per lane; variants select rounding mode and FTZ behavior
 * per the modifier table above. The unsuffixed form uses the default
 * rounding mode (_rn).
 */
__device__ float2 __nv_ptx_builtin_ocg_ffma2(float2 a, float2 b, float2 c);
__device__ float2 __nv_ptx_builtin_ocg_ffma2_rn(float2 a, float2 b, float2 c);
__device__ float2 __nv_ptx_builtin_ocg_ffma2_rm(float2 a, float2 b, float2 c);
__device__ float2 __nv_ptx_builtin_ocg_ffma2_rp(float2 a, float2 b, float2 c);
__device__ float2 __nv_ptx_builtin_ocg_ffma2_rz(float2 a, float2 b, float2 c);
__device__ float2 __nv_ptx_builtin_ocg_ffma2_ftz(float2 a, float2 b, float2 c);
__device__ float2 __nv_ptx_builtin_ocg_ffma2_ftz_rn(float2 a, float2 b, float2 c);
__device__ float2 __nv_ptx_builtin_ocg_ffma2_ftz_rm(float2 a, float2 b, float2 c);
__device__ float2 __nv_ptx_builtin_ocg_ffma2_ftz_rp(float2 a, float2 b, float2 c);
__device__ float2 __nv_ptx_builtin_ocg_ffma2_ftz_rz(float2 a, float2 b, float2 c);
/*
 * FADD2 - add
 * Computes a+b per lane; same rounding/FTZ variant scheme as FFMA2.
 */
__device__ float2 __nv_ptx_builtin_ocg_fadd2(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fadd2_rn(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fadd2_rm(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fadd2_rp(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fadd2_rz(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fadd2_ftz(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fadd2_ftz_rn(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fadd2_ftz_rm(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fadd2_ftz_rp(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fadd2_ftz_rz(float2 a, float2 b);
/*
 * FMUL2 - multiply
 * Computes a*b per lane; same rounding/FTZ variant scheme as FFMA2.
 */
__device__ float2 __nv_ptx_builtin_ocg_fmul2(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fmul2_rn(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fmul2_rm(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fmul2_rp(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fmul2_rz(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fmul2_ftz(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fmul2_ftz_rn(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fmul2_ftz_rm(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fmul2_ftz_rp(float2 a, float2 b);
__device__ float2 __nv_ptx_builtin_ocg_fmul2_ftz_rz(float2 a, float2 b);
}