mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
312 lines
10 KiB
C++
312 lines
10 KiB
C++
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Common utils to be shared between Precompiled and JIT implementation.
 */
|
|
#pragma once
|
|
#include "decoderXQAConstants.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/workspace.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
#include "xqaParams.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <optional>
#include <type_traits>
#include <utility>
|
|
|
|
namespace tensorrt_llm
|
|
{
|
|
namespace kernels
|
|
{
|
|
|
|
// Key identifying which XQA kernel binary to load: the activation data type
// plus the target SM architecture version.
struct XQAKernelLoadHashKey
{
    Data_type data_type;
    unsigned int sm;

    bool operator==(XQAKernelLoadHashKey const& other) const
    {
        return data_type == other.data_type && sm == other.sm;
    }
};
|
|
|
|
struct XQAKernelLoadHasher
|
|
{
|
|
size_t operator()(XQAKernelLoadHashKey const& s) const
|
|
{
|
|
size_t key = s.data_type;
|
|
key <<= 16;
|
|
key ^= s.sm;
|
|
return key;
|
|
}
|
|
};
|
|
|
|
// Key identifying a concrete kernel inside an already-loaded binary: the
// KV-cache data type plus the shape/layout parameters the kernel was
// specialized for.
struct XQAKernelRuntimeHashKey
{
    Data_type kv_data_type;
    unsigned int head_size;
    unsigned int beam_size;
    // Number of query heads sharing one KV head (grouped-query attention ratio).
    unsigned int num_q_heads_per_kv;
    unsigned int m_tilesize;
    // Tokens per KV-cache page; meaningful only when paged_kv_cache is true.
    unsigned int tokens_per_page;
    bool paged_kv_cache;
    bool multi_query_tokens;

    bool operator==(XQAKernelRuntimeHashKey const& other) const
    {
        return kv_data_type == other.kv_data_type && head_size == other.head_size
            && num_q_heads_per_kv == other.num_q_heads_per_kv && beam_size == other.beam_size
            && multi_query_tokens == other.multi_query_tokens && m_tilesize == other.m_tilesize
            && tokens_per_page == other.tokens_per_page && paged_kv_cache == other.paged_kv_cache;
    }
};
|
|
|
|
struct XQAKernelRuntimeHasher
|
|
{
|
|
size_t operator()(XQAKernelRuntimeHashKey const& s) const
|
|
{
|
|
size_t key = s.kv_data_type;
|
|
key <<= 16;
|
|
key ^= s.head_size;
|
|
key <<= 8;
|
|
key ^= s.num_q_heads_per_kv;
|
|
key <<= 8;
|
|
key ^= s.beam_size;
|
|
key <<= 6;
|
|
key ^= s.m_tilesize;
|
|
key <<= 10;
|
|
key ^= s.tokens_per_page;
|
|
key <<= 1;
|
|
key ^= s.paged_kv_cache;
|
|
key <<= 1;
|
|
key ^= s.multi_query_tokens;
|
|
return key;
|
|
}
|
|
};
|
|
|
|
// XQA kernel can be uniquely identified by (LoadHashKey, RuntimeHashKey).
|
|
struct XQAKernelFullHashKey
|
|
{
|
|
XQAKernelLoadHashKey load_key;
|
|
XQAKernelRuntimeHashKey runtime_key;
|
|
|
|
XQAKernelFullHashKey() = default;
|
|
|
|
XQAKernelFullHashKey(XQAKernelLoadHashKey const& load_key, XQAKernelRuntimeHashKey const& runtime_key)
|
|
: load_key(load_key)
|
|
, runtime_key(runtime_key)
|
|
{
|
|
}
|
|
|
|
XQAKernelFullHashKey(void const* buffer, size_t buffer_size)
|
|
{
|
|
TLLM_CHECK(sizeof(*this) <= buffer_size);
|
|
memcpy(this, buffer, sizeof(*this));
|
|
}
|
|
|
|
bool operator==(XQAKernelFullHashKey const& other) const
|
|
{
|
|
return load_key == other.load_key && runtime_key == other.runtime_key;
|
|
}
|
|
|
|
size_t getSerializationSize() const
|
|
{
|
|
return sizeof(*this);
|
|
}
|
|
|
|
void serialize(void* buffer, size_t buffer_size) const
|
|
{
|
|
TLLM_CHECK(sizeof(*this) <= buffer_size);
|
|
memcpy(buffer, this, sizeof(*this));
|
|
}
|
|
};
|
|
|
|
struct XQAKernelFullHasher
|
|
{
|
|
size_t operator()(XQAKernelFullHashKey const& s) const
|
|
{
|
|
return XQAKernelLoadHasher()(s.load_key) ^ XQAKernelRuntimeHasher()(s.runtime_key);
|
|
}
|
|
};
|
|
|
|
// NOTE: we use int32_t sequence lengths as gpt attention plugins use int32_t for that.
// XQA kernels assume all length should use uint32_t.
// NOTE: Linear KV cache and paged KV cache uses the same structure.

// Primary template, intentionally empty: only the KVBlockArray (paged) and
// KVLinearBuffer (linear) specializations below are used.
template <typename KVCacheBuffer>
struct KVCache
{
};
|
|
|
|
// Paged KV-cache parameters passed to the XQA kernel.
template <>
struct KVCache<KVBlockArray>
{
    // Start address of the paged kv block pool.
    void* poolPtr = nullptr;
    // Block indices in the memory pool.
    int32_t const* blockIndices = nullptr;
    // Per-sequence lengths; int32_t to match the gpt attention plugin API.
    int32_t const* sequence_lengths = nullptr;
    // NOTE: max_num_blocks_per_sequence for paged kv cache.
    uint32_t capacity = 0;

    KVCache(KVBlockArray& kv_cache_buffer)
    {
        poolPtr = kv_cache_buffer.mPrimaryPoolPtr;
        // kv_cache_buffer.data holds KVCacheIndex entries; the assignment to
        // int32_t const* requires KVCacheIndex::UnderlyingType to be int32_t.
        blockIndices = reinterpret_cast<KVCacheIndex::UnderlyingType const*>(kv_cache_buffer.data);
    }

    KVCache() = default;
};
|
|
|
|
// Linear (contiguous) KV-cache parameters passed to the XQA kernel.
template <>
struct KVCache<KVLinearBuffer>
{
    // Buffer address.
    void* data = nullptr;
    // Per-sequence lengths; int32_t to match the gpt attention plugin API.
    int32_t const* sequence_lengths = nullptr;
    // NOTE: max_sequence_length for linear kv cache.
    uint32_t capacity = 0;

    KVCache(KVLinearBuffer& kv_cache_buffer)
    {
        data = kv_cache_buffer.data;
    }

    KVCache() = default;
};
|
|
|
|
// Beam-search related inputs forwarded to the XQA kernel when beam_width > 1.
struct BeamSearchParams
{
    int32_t const* indices;    // cacheIndir with shape: [batchSize][beamWidth][capacity]
    int32_t capacity;
    int32_t const* ctxLenList; // shape: [batchSize][beamWidth]. Should be [batchSize] but we have to match trt-llm API.
};
|
|
|
|
// XQA kernels assume all integer values should use uint32_t.
// Aggregate of everything needed to launch one XQA kernel; populated by
// buildXQALaunchParams below.
template <typename KVCacheBuffer>
struct XQALaunchParam
{
    uint32_t num_k_heads;
    void* output;
    void const* qkv;
    KVCache<KVCacheBuffer> kvCacheParams;
    // Engaged only when beam_width > 1.
    std::optional<BeamSearchParams> beamSearchParams;
    uint32_t batch_size;
    float const* kv_scale_quant_orig = nullptr;
    // The three pointers below are carved out of params.workspaces.
    int* cu_seq_lens = nullptr;
    float* rotary_inv_freq_buf = nullptr;
    void* scratch = nullptr;
};
|
|
|
|
// Setup launch params.
|
|
template <typename KVCacheBuffer>
|
|
void buildXQALaunchParams(
|
|
XQALaunchParam<KVCacheBuffer>& launchParams, XQAParams const& params, KVCacheBuffer kv_cache_buffer)
|
|
{
|
|
TLLM_CHECK_WITH_INFO(
|
|
params.data_type == DATA_TYPE_FP16 || params.data_type == DATA_TYPE_BF16, "Only fp16 or bf16 supported now.");
|
|
memset(&launchParams, 0, sizeof(XQALaunchParam<KVCacheBuffer>));
|
|
launchParams.num_k_heads = params.num_kv_heads;
|
|
launchParams.output = static_cast<uint8_t*>(params.output);
|
|
launchParams.qkv = static_cast<uint8_t const*>(params.qkv);
|
|
launchParams.batch_size = params.batch_size;
|
|
launchParams.kv_scale_quant_orig = params.kv_scale_quant_orig;
|
|
|
|
// Workspace.
|
|
size_t offset = 0;
|
|
int8_t* workspace = reinterpret_cast<int8_t*>(params.workspaces);
|
|
unsigned int batch_beam_size = params.batch_size * params.beam_width;
|
|
const size_t cu_seqlens_size = sizeof(int) * (batch_beam_size + 1);
|
|
const size_t rotary_inv_freq_size = sizeof(float) * batch_beam_size * params.rotary_embedding_dim / 2;
|
|
launchParams.cu_seq_lens = reinterpret_cast<int*>(workspace);
|
|
workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, cu_seqlens_size);
|
|
launchParams.rotary_inv_freq_buf = reinterpret_cast<float*>(workspace);
|
|
auto const multi_block_workspace_alignment = tensorrt_llm::common::roundUp(
|
|
sizeof(half) * params.head_size * (params.num_q_heads / params.num_kv_heads) * params.beam_width, 128);
|
|
workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(
|
|
workspace, rotary_inv_freq_size, multi_block_workspace_alignment);
|
|
launchParams.scratch = reinterpret_cast<void*>(workspace);
|
|
|
|
launchParams.kvCacheParams = KVCache<KVCacheBuffer>(kv_cache_buffer);
|
|
launchParams.kvCacheParams.sequence_lengths = params.sequence_lengths;
|
|
launchParams.kvCacheParams.capacity
|
|
= params.paged_kv_cache ? params.max_blocks_per_sequence : params.max_attention_window_size;
|
|
// TODO: beam searching has not been implemented yet.
|
|
if (params.beam_width > 1)
|
|
{
|
|
launchParams.beamSearchParams
|
|
= BeamSearchParams{params.cache_indir, params.max_attention_window_size, params.context_lengths};
|
|
}
|
|
else
|
|
{
|
|
launchParams.beamSearchParams = std::nullopt;
|
|
}
|
|
}
|
|
|
|
// Reads the value of the global variable `name` from a loaded cubin module.
// Returns std::nullopt when the symbol is absent and `required` is false;
// a missing required symbol or any other driver error is routed to cuErrCheck.
template <typename T>
std::optional<T> getGlobalVar(std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> const& driver, CUmodule hmod,
    char const* const name, bool required = false)
{
    T* pVar = nullptr;
    size_t size = 0;
    auto const error = driver->cuModuleGetGlobal(reinterpret_cast<CUdeviceptr*>(&pVar), &size, hmod, name);
    T ret;
    switch (error)
    {
    case CUDA_SUCCESS:
        // Sanity-check the symbol size, then copy the device-side value to host.
        TLLM_CHECK(size == sizeof(T));
        tensorrt_llm::common::check_cuda_error(cudaMemcpy(&ret, pVar, size, cudaMemcpyDeviceToHost));
        break;
    case CUDA_ERROR_NOT_FOUND:
        if (!required)
        {
            return std::nullopt;
        }
        // Required symbol missing: treat as a hard error below.
        [[fallthrough]];
    default: cuErrCheck(("Failed to retrieve global variable from cubin.", error), driver);
    }
    return std::optional<T>{std::move(ret)};
}
|
|
|
|
// Chooses how many CTAs to launch per KV head for the multi-block XQA path.
// Starts from the history length, then scales down so the total wave count
// stays near kTargetWaveFactor, and clamps to the configured upper bound.
inline int computeMultiBlockCount(XQAParams const& xqaParams, int batch_size, int multiprocessor_count)
{
    // Environment override takes precedence; query the env var only once.
    auto const envOverride = tensorrt_llm::common::envXqaNbCtaPerKVHead();
    if (envOverride.has_value())
    {
        return envOverride.value();
    }
    int multi_block_count = 1;
    int num_kv_heads = xqaParams.num_kv_heads;
    int history_length = xqaParams.timestep;

    multi_block_count = history_length / kMinHistoryTokensPerBlock;
    multi_block_count = std::max(multi_block_count, 1);
    // adjust to kTargetWaveFactor, as already initialized using kMinHistoryTokensPerBlock, only need to decrease.
    double wave_count = (double) batch_size * num_kv_heads * multi_block_count / (double) multiprocessor_count;
    double adj_factor = wave_count / (double) kTargetWaveFactor;
    if (adj_factor > 1.0)
    {
        multi_block_count = static_cast<int>(std::floor(multi_block_count / adj_factor));
    }
    multi_block_count = std::max(multi_block_count, 1);

    // add limitation on upper bound.
    multi_block_count = std::min(tensorrt_llm::common::xqaMaxNbCtaPerKVHeadFactor(), multi_block_count);

    // The check is >= 1, so the message says "at least 1" (old text said "larger than 1").
    TLLM_CHECK_WITH_INFO(multi_block_count >= 1, "MultiBlock count should be at least 1");
    return multi_block_count;
}
|
|
|
|
} // namespace kernels
|
|
} // namespace tensorrt_llm
|