TensorRT-LLMs/cpp/kernels/xqa/test/warmup.cu
Jinyang Yuan 20d0649f19
[feat] Support XQA-based MLA on SM120 (#4858)
Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com>
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Co-authored-by: Yao Yao <lowsfer@users.noreply.github.com>
Co-authored-by: peaceh-nv <103117813+peaceh-nv@users.noreply.github.com>
2025-06-06 22:32:49 +08:00

19 lines
464 B
Plaintext

#include "../utils.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cuda_runtime.h>
// Busy-waits until at least `cycles` clock64() ticks have elapsed since the
// calling thread entered the kernel, presumably so the GPU ramps its clocks
// up before timed work runs (TODO confirm intent with callers).
//
// Launch config: any grid/block shape is valid. Every thread spins
// independently. There is no shared memory, no synchronization, and no memory
// traffic. clock64() reads a per-multiprocessor cycle counter, so each thread
// measures its own elapsed time on its own SM.
__global__ void kernel_warmup(uint64_t cycles)
{
    uint64_t const tic = static_cast<uint64_t>(clock64());
    // Keep spinning *while* elapsed < budget. Doing the subtraction in
    // unsigned 64-bit arithmetic keeps it well-defined even if the counter
    // wraps. The old condition `tic + cycles < clock64()` was inverted: it is
    // false on the first check, so the kernel returned immediately.
    while (static_cast<uint64_t>(clock64()) - tic < cycles)
    {
    }
}
// Enqueues kernel_warmup on `stream` so the device busy-waits for roughly
// `ms` milliseconds. It then passes cudaGetLastError() to checkCuda (from
// ../utils.h) to surface launch-configuration errors. The launch is
// asynchronous: this function does not wait for the kernel to finish.
//
// prop   - device properties. prop.clockRate is in kHz (cycles per ms) and
//          converts `ms` into a cycle budget. The result is approximate,
//          because the actual SM clock can differ from the reported rate.
// ms     - requested busy-wait duration in milliseconds. Negative or NaN
//          values are treated as 0; converting a negative floating value to
//          uint64_t would be undefined behaviour.
// stream - stream to launch on; nullptr selects the default stream.
//
// NOTE(review): only 16 blocks x 128 threads are launched, which may not
// occupy every SM on large GPUs. prop.multiProcessorCount could be used if
// full-device warmup is desired.
// NOTE(review): cudaDeviceProp::clockRate is deprecated in recent CUDA
// toolkits. cudaDeviceGetAttribute(..., cudaDevAttrClockRate, dev) is the
// replacement. Confirm against the target toolkit version.
void warmup(cudaDeviceProp const& prop, float ms, cudaStream_t stream = nullptr)
{
    // Compute in double to avoid float precision loss for long durations.
    double const requestedCycles = static_cast<double>(prop.clockRate) * ms;
    // std::max(0.0, x) maps both negative values and NaN to 0.0.
    uint64_t const nbCycles = static_cast<uint64_t>(std::round(std::max(0.0, requestedCycles)));
    kernel_warmup<<<16, 128, 0, stream>>>(nbCycles);
    checkCuda(cudaGetLastError());
}