TensorRT-LLMs/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h
Kaiyu Xie bca9a33b02
Update TensorRT-LLM (#2008)
* Update TensorRT-LLM

---------

Co-authored-by: Timur Abishev <abishev.timur@gmail.com>
Co-authored-by: MahmoudAshraf97 <hassouna97.ma@gmail.com>
Co-authored-by: Saeyoon Oh <saeyoon.oh@furiosa.ai>
Co-authored-by: hattizai <hattizai@gmail.com>
2024-07-23 23:05:09 +08:00

161 lines
4.5 KiB
C++

/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <NvInferRuntime.h>
#include <cstddef>
#include <cuda_fp16.h>
#include <memory>
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/cubinObjRegistry.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
using namespace tensorrt_llm::common;
namespace tensorrt_llm
{
namespace kernels
{
// Compile-time trait: does the XQA path support element type `DataT` combined with
// KV-cache container type `KVCacheBufferT`?
// The primary template answers "no" for every pair; supported pairs are opted in
// through explicit specializations that set CanSupport to true.
template <typename DataT, typename KVCacheBufferT>
struct XQADispatchHelper
{
    static constexpr bool CanSupport{false};
};
// Opt-in specializations of XQADispatchHelper.
// Each listed (element type, KV-cache container) pair reports CanSupport == true;
// every other pair falls back to the primary template (false).

// __half with KVLinearBuffer.
template <>
struct XQADispatchHelper<__half, KVLinearBuffer>
{
    static constexpr bool CanSupport{true};
};

// __half with KVBlockArray.
template <>
struct XQADispatchHelper<__half, KVBlockArray>
{
    static constexpr bool CanSupport{true};
};

#ifdef ENABLE_BF16
// The bfloat16 pairs only exist when the build defines ENABLE_BF16.

// __nv_bfloat16 with KVLinearBuffer.
template <>
struct XQADispatchHelper<__nv_bfloat16, KVLinearBuffer>
{
    static constexpr bool CanSupport{true};
};

// __nv_bfloat16 with KVBlockArray.
template <>
struct XQADispatchHelper<__nv_bfloat16, KVBlockArray>
{
    static constexpr bool CanSupport{true};
};
#endif // ENABLE_BF16
/// Host-side front end for the XQA decoder attention kernels.
///
/// Stores the attention configuration and owns two implementations, a JIT one
/// (mJITImpl) and a precompiled one (mPrecompiledImpl). getImplFromXQAParams()
/// (defined out of line) chooses between them.
class DecoderXQARunner
{
public:
    DecoderXQARunner(
        const XQADataType data_type, int num_heads, int num_kv_heads, int head_size, bool multi_block_mode);
    ~DecoderXQARunner();

    /**
     * Checks whether the XQA path applies to the given parameters.
     *
     * \param[in] xqaParams the xqaParams to be tested against.
     * \param[in] forConfigurePlugin indicates whether this method is called in configurePlugin, or in
     * enqueueGeneration.
     * \return presumably true when XQA should handle this configuration (TODO confirm in the .cpp).
     */
    bool shouldUse(XQAParams const& xqaParams, bool forConfigurePlugin);

    // Workspace size required for the given batch*beam and token limits
    // (units not visible here; presumably bytes).
    size_t getWorkspaceSize(int max_batch_beam_size, int max_num_tokens);

    /// Public preparation hook; simply forwards to the private prepareForRun().
    void prepare(XQAParams const& xqa_params)
    {
        prepareForRun(xqa_params);
    }

    /// Launches XQA for one KV-cache container type on `stream`.
    /// sync_check_cuda_error() runs first, presumably so that an error left over from
    /// earlier work is reported before this launch; the actual work is in run().
    template <typename KVCacheBuffer>
    void dispatch(XQAParams const& xqa_params, KVCacheBuffer const& kv_cache_buffer, cudaStream_t const& stream)
    {
        sync_check_cuda_error();
        run(xqa_params, kv_cache_buffer, stream);
    }

    class Resource;
    // Accessor for a shared Resource instance (defined out of line; ownership not visible here).
    static Resource* getResourceGlobal();

private:
    void prepareForRun(XQAParams const& xqa_params);

    template <typename KVCacheBuffer>
    void run(XQAParams const& xqa_params, KVCacheBuffer const& kv_cache_buffer, cudaStream_t const& stream);

    // Presumably the largest beam width this runner accepts (TODO confirm where it is checked).
    static constexpr int kMaxBeamWidth = 4;

    // NOTE: the declaration order of the data members below fixes their initialization
    // order in the out-of-line constructor, so it is intentionally left unchanged.
    XQADataType mDataType;
    int mNumHeads;
    int mNumKVHeads;
    int mHeadSize;
    bool mMultiBlockMode;
    int mMultiProcessorCount;
    std::unique_ptr<DecoderXQAImpl> mJITImpl, mPrecompiledImpl;

    // Returns a non-owning pointer to whichever owned implementation should serve `params`.
    DecoderXQAImpl* getImplFromXQAParams(XQAParams const& params, bool for_configure_plugin);

    friend DecoderXQAImplPrecompiled;
    friend DecoderXQAImplJIT;
};
/// Holds a jit::CubinObjRegistry through a unique_ptr and supports copying, merging,
/// and serializing to / constructing from a byte buffer.
class DecoderXQARunner::Resource
{
public:
    Resource();
    Resource(Resource const& other);
    Resource& operator=(Resource const& other);
    // Defaulted moves transfer the unique_ptr, so a moved-from Resource holds a null registry.
    Resource(Resource&& other) = default;
    Resource& operator=(Resource&& other) = default;
    // Construct from a serialized buffer.
    Resource(void const* buffer, size_t buffer_size);
    ~Resource() = default;

    /// Merges the registry contents of `other` into this resource.
    ///
    /// Two cases are no-ops:
    ///  - merging a resource into itself, so the registry is never handed its own
    ///    contents as the merge source;
    ///  - merging from a resource whose registry is null (e.g. one that was moved from),
    ///    which previously dereferenced a null pointer.
    /// This resource itself must still hold a registry (must not be in a moved-from state).
    void merge(Resource const& other)
    {
        if (this == &other || other.mCubinObjRegistry == nullptr)
        {
            return;
        }
        mCubinObjRegistry->merge(*other.mCubinObjRegistry);
    }

    // Non-owning access to the registry; may be null after this object was moved from.
    jit::CubinObjRegistry* getCubinObjRegistry()
    {
        return mCubinObjRegistry.get();
    }

    // Const overload of the accessor above.
    jit::CubinObjRegistry const* getCubinObjRegistry() const
    {
        return mCubinObjRegistry.get();
    }

    size_t getSerializationSize() const noexcept;
    void serialize(void* buffer, size_t buffer_size) const noexcept;

private:
    std::unique_ptr<jit::CubinObjRegistry> mCubinObjRegistry;
};
} // namespace kernels
} // namespace tensorrt_llm