TensorRT-LLMs/cpp/kernels/xqa/tensorMap.h
Jinyang Yuan 20d0649f19
[feat] Support XQA-based MLA on SM120 (#4858)
Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com>
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Co-authored-by: Yao Yao <lowsfer@users.noreply.github.com>
Co-authored-by: peaceh-nv <103117813+peaceh-nv@users.noreply.github.com>
2025-06-06 22:32:49 +08:00

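#pragma once
#include <cstdint>
#include <cuda.h>

// Size in bytes of one element of the given tensor-map data type.
uint32_t getElemBytes(CUtensorMapDataType_enum dataType);

// Builds a TMA tensor map (CUtensorMap) for a contiguous, non-paged KV cache starting at addr.
CUtensorMap makeTensorMapForContiguousKVCache(void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems,
    uint32_t nbKHeads, uint32_t maxCacheLen, uint32_t beamWidth, uint32_t batchSize, uint32_t partElems,
    uint32_t nbTokens);

// Builds a TMA tensor map for a paged KV cache holding tokensPerPage tokens per page.
CUtensorMap makeTensorMapForPagedKVCache(void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems,
    uint32_t nbKHeads, uint32_t tokensPerPage, uint32_t partElems, uint32_t nbTokensPerTile);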
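The header only declares these helpers, so the sketch below illustrates, under stated assumptions, the arithmetic their parameters likely feed into. It does not reproduce the actual implementation. `ElemType`, `elemBytes`, `isValidBox`, `TokenLocation`, `locateToken` and the page table are all hypothetical names introduced here. Treating `partElems` as the innermost TMA box dimension and `nbTokensPerTile` as the next one is also an assumption. The box limits themselves (each box dimension in [1, 256], and for non-interleaved layouts an innermost box extent in bytes that is a multiple of 16) come from the CUDA driver's `cuTensorMapEncodeTiled` requirements. The paged mapping assumes the common layout in which logical token `t` sits in logical page `t / tokensPerPage` at slot `t % tokensPerPage`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a few element types; the real header uses
// CUtensorMapDataType_enum from <cuda.h> instead.
enum class ElemType { UInt8, Float16, BFloat16, Float32 };

// Bytes per element, similar in spirit to getElemBytes().
constexpr uint32_t elemBytes(ElemType t)
{
    switch (t)
    {
    case ElemType::UInt8: return 1;
    case ElemType::Float16:
    case ElemType::BFloat16: return 2;
    case ElemType::Float32: return 4;
    }
    return 0;
}

// Assumption: partElems is the innermost TMA box dimension, nbTokensPerTile the next.
// cuTensorMapEncodeTiled requires every box dimension to be in [1, 256] and, without
// interleaving, the innermost box extent in bytes to be a multiple of 16.
constexpr bool isValidBox(ElemType t, uint32_t partElems, uint32_t nbTokensPerTile)
{
    return partElems >= 1 && partElems <= 256 && nbTokensPerTile >= 1 && nbTokensPerTile <= 256
        && (partElems * elemBytes(t)) % 16 == 0;
}

static_assert(isValidBox(ElemType::Float16, 64, 32), "64 fp16 elements = 128 bytes, a multiple of 16");

// Physical location of one token in an assumed paged KV cache.
struct TokenLocation
{
    uint32_t physicalPage; // page index in the shared page pool
    uint32_t slotInPage;   // token position inside that page
};

// Maps a sequence's logical token index to (physical page, slot) through a
// hypothetical per-sequence page table of logical-to-physical page ids.
TokenLocation locateToken(uint32_t const* pageTable, uint32_t token, uint32_t tokensPerPage)
{
    assert(tokensPerPage > 0);
    return TokenLocation{pageTable[token / tokensPerPage], token % tokensPerPage};
}
```

For example, with `tokensPerPage = 32` and a page table `{7, 2, 9}`, token 70 falls in logical page 2, slot 6, so `locateToken` returns physical page 9, slot 6. A box of 8 one-byte elements fails `isValidBox` because its 8-byte inner extent is not a multiple of 16.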