mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com> Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com> Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com> Co-authored-by: Yao Yao <lowsfer@users.noreply.github.com> Co-authored-by: peaceh-nv <103117813+peaceh-nv@users.noreply.github.com>
115 lines
3.4 KiB
C++
115 lines
3.4 KiB
C++
/*
|
|
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#pragma once
|
|
|
|
#include "tensorrt_llm/common/opUtils.h"
|
|
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h"
|
|
#include "tensorrt_llm/kernels/kvCacheUtils.h"
|
|
#include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
|
|
#include "tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunner.h"
|
|
|
|
using namespace tensorrt_llm::common;
|
|
using tensorrt_llm::common::op::UniqPtrWNullCopy;
|
|
|
|
namespace tensorrt_llm::kernels
|
|
{
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
struct XqaFixedParams
|
|
{
|
|
// Whether the attention is MLA.
|
|
bool isMLA;
|
|
// The QKV input data type.
|
|
kernels::Data_type inputDataType;
|
|
// The XQA KV cache data type.
|
|
kernels::Data_type kvDataType;
|
|
// The XQA output data type.
|
|
kernels::Data_type outputDataType;
|
|
// The XQA BMM dtype.
|
|
kernels::Data_type mathDataType;
|
|
// The number of Q heads.
|
|
int numQHeads;
|
|
// The number of Kv Heads.
|
|
int numKvHeads;
|
|
// The number of tokens per kv cache block.
|
|
int numTokensPerBlock;
|
|
// The head size.
|
|
int headSize;
|
|
// The scaling applied to bmm1_scale.
|
|
float qScaling;
|
|
// Whether to enable multi block mode.
|
|
bool multiBlockMode;
|
|
// The KV cache layout.
|
|
bool isPagedKv;
|
|
// Is speculative decoding enabled.
|
|
bool isSpecDecoding;
|
|
// Do we apply alibi ?
|
|
bool hasAlibi;
|
|
};
|
|
|
|
class XqaDispatcher
|
|
{
|
|
public:
|
|
// Constructor.
|
|
XqaDispatcher(XqaFixedParams fixedParams);
|
|
|
|
// Deconstructor.
|
|
~XqaDispatcher() = default;
|
|
|
|
// Prepare for DecoderXQARunner.
|
|
void prepare(XQAParams const& params);
|
|
|
|
// Check whether XQA is supported.
|
|
bool isSupported();
|
|
|
|
// Run the XQA kernel.
|
|
void run(XQAParams const& params, KVLinearBuffer const& kv_cache_buffer);
|
|
|
|
void run(XQAParams const& params, KVBlockArray const& kv_cache_buffer);
|
|
|
|
int getWorkspaceAlignment();
|
|
|
|
size_t getWorkspaceSize(int max_num_tokens);
|
|
|
|
bool shouldUse(XQAParams const& params);
|
|
|
|
private:
|
|
// The fixed XQA parameters.
|
|
XqaFixedParams mFixedParams;
|
|
// The data type of tensor Q, which determines the Q input data type of fmha kernels.
|
|
Data_type mQDataType;
|
|
// Whether to enable trtllm-gen kernels.
|
|
bool mUseTllmGen;
|
|
// The multi-processor count.
|
|
int mMultiProcessorCount;
|
|
// Runner for decoder XQA kernels (for SM <= 90)
|
|
UniqPtrWNullCopy<DecoderXQARunner> mDecoderXqaRunner;
|
|
// Runner for trtllm-gen XQA kernels (for SM == 100)
|
|
UniqPtrWNullCopy<TllmGenFmhaRunner> mTllmGenFMHARunner;
|
|
|
|
protected:
|
|
template <typename T, typename KVCacheBuffer>
|
|
void runImpl(XQAParams params, KVCacheBuffer const& kv_cache_buffer);
|
|
};
|
|
|
|
// Size constant for the XQA MLA CGA "X" buffer, kept as 8704 * 2 (= 17408).
// NOTE(review): the unit and the meaning of the two factors are not visible in this
// header - confirm against the kernel that consumes it.
// Declared `inline` so that this header-defined constant is a single entity across all
// translation units rather than one internal-linkage copy per includer.
inline constexpr uint32_t xqaMlaCgaXBufSize = 8704 * 2;
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
} // namespace tensorrt_llm::kernels
|