TensorRT-LLMs/cpp/kernels/xqa/hostUtils.h
Ming Wei ed887940d4
infra: open source XQA kernels (#3762)
Replace libtensorrt_llm_nvrtc_wrapper.so with its source code, which
consists of two parts:

1. NVRTC glue code
2. XQA kernel code

During the TensorRT-LLM build, XQA kernel code is embedded as C++ arrays via
gen_cpp_header.py and passed to NVRTC for JIT compilation.

Signed-off-by: Ming Wei <2345434+ming-wei@users.noreply.github.com>
2025-04-30 18:05:15 +08:00


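#pragma once
#include <cuda_runtime.h>

// Builds a launch configuration for cudaLaunchKernelEx.
// useFDL toggles the programmatic stream serialization launch attribute.
inline cudaLaunchConfig_t makeLaunchConfig(
    dim3 const& gridDim, dim3 const& ctaDim, size_t dynShmBytes, cudaStream_t stream, bool useFDL)
{
    // Static so that the pointer stored in cfg.attrs stays valid after return.
    // The attribute is shared by all callers, so this is not thread-safe.
    static cudaLaunchAttribute fdlAttr;
    fdlAttr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
    fdlAttr.val.programmaticStreamSerializationAllowed = (useFDL ? 1 : 0);
    cudaLaunchConfig_t cfg{gridDim, ctaDim, dynShmBytes, stream, &fdlAttr, 1};
    return cfg;
}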
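
The returned config stores a pointer to a function-local static attribute, so every config produced by makeLaunchConfig points at the same storage. The sketch below illustrates that aliasing without requiring CUDA. FakeAttr, FakeConfig, makeConfigShared and makeConfigOwned are hypothetical stand-ins introduced only for this illustration; they are not part of the CUDA runtime or of TensorRT-LLM. The caller-owned variant is one possible alternative, shown under that assumption, not the upstream design.

```cpp
#include <cassert>

// Hypothetical stand-ins for cudaLaunchAttribute / cudaLaunchConfig_t,
// reduced to the fields that matter here so the sketch builds without CUDA.
struct FakeAttr
{
    int allowed;
};

struct FakeConfig
{
    FakeAttr* attrs;
    unsigned numAttrs;
};

// Same pattern as makeLaunchConfig: one function-local static attribute
// shared by every config the function returns.
inline FakeConfig makeConfigShared(bool useFDL)
{
    static FakeAttr attr;
    attr.allowed = useFDL ? 1 : 0;
    return FakeConfig{&attr, 1u};
}

// Alternative (an assumption, not the upstream code): the caller owns the
// attribute storage, so each config points at its own attribute.
inline FakeConfig makeConfigOwned(FakeAttr& attr, bool useFDL)
{
    attr.allowed = useFDL ? 1 : 0;
    return FakeConfig{&attr, 1u};
}

// True when two configs built with different flags alias one attribute,
// i.e. the second call has overwritten what the first config sees.
inline bool sharedConfigsAlias()
{
    FakeConfig a = makeConfigShared(true);
    FakeConfig b = makeConfigShared(false);
    return a.attrs == b.attrs && a.attrs->allowed == 0;
}

// True when caller-owned attributes keep the two configs independent.
inline bool ownedConfigsAreIndependent()
{
    FakeAttr attrA;
    FakeAttr attrB;
    FakeConfig a = makeConfigOwned(attrA, true);
    FakeConfig b = makeConfigOwned(attrB, false);
    return a.attrs != b.attrs && a.attrs->allowed == 1 && b.attrs->allowed == 0;
}
```

For the real header, the practical consequence is that a config should be handed to the launch call before makeLaunchConfig is invoked again with a different useFDL value, and that concurrent calls from several host threads would race on the shared attribute.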