TensorRT-LLMs/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp
Omer Ullman Argov 8731f5f14f
chore: Mass integration of release/0.20 (#4898)
Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
Signed-off-by: Yiqing Yan <yiqingy@nvidia.com>
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Signed-off-by: Hui Gao <huig@nvidia.com>
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Signed-off-by: Ruodi <200874449+ruodil@users.noreply.github.com>
Signed-off-by: ruodil <200874449+ruodil@users.noreply.github.com>
Signed-off-by: Stanley Sun <190317771+StanleySun639@users.noreply.github.com>
Signed-off-by: Pamela Peng <179191831+pamelap-nvidia@users.noreply.github.com>
Signed-off-by: Anurag Mukkara <134339030+amukkara@users.noreply.github.com>
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Signed-off-by: moraxu <mguzek@nvidia.com>
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
Co-authored-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
Co-authored-by: Yiqing Yan <yiqingy@nvidia.com>
Co-authored-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Co-authored-by: HuiGao-NV <huig@nvidia.com>
Co-authored-by: brb-nv <169953907+brb-nv@users.noreply.github.com>
Co-authored-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Co-authored-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Co-authored-by: Iman Tabrizian <10105175+Tabrizian@users.noreply.github.com>
Co-authored-by: ruodil <200874449+ruodil@users.noreply.github.com>
Co-authored-by: Stanley Sun <190317771+StanleySun639@users.noreply.github.com>
Co-authored-by: Pamela Peng <179191831+pamelap-nvidia@users.noreply.github.com>
Co-authored-by: Anurag Mukkara <134339030+amukkara@users.noreply.github.com>
Co-authored-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Co-authored-by: Faraz <58580514+farazkh80@users.noreply.github.com>
Co-authored-by: Michal Guzek <moraxu@users.noreply.github.com>
Co-authored-by: Larry <197874197+LarryXFly@users.noreply.github.com>
Co-authored-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Co-authored-by: Yechan Kim <161688079+yechank-nvidia@users.noreply.github.com>
2025-06-08 23:26:26 +08:00

170 lines
5.5 KiB
C++

/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "decoderXQARunner.h"
#include <assert.h>
#include <string.h>
#include <mutex>
#include <unordered_map>
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/workspace.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/cubin/xqa_kernel_cubin.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAConstants.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImpl.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
namespace tensorrt_llm
{
namespace kernels
{
/// Constructs a runner for one attention configuration and creates both XQA implementations (JIT and precompiled).
///
/// @param data_type         Stored in mDataType.
/// @param num_heads         Stored in mNumHeads.
/// @param num_kv_heads      Stored in mNumKVHeads.
/// @param head_size         Stored in mHeadSize.
/// @param multi_block_mode  Stored in mMultiBlockMode; presumably toggles multi-block XQA kernels — TODO confirm.
DecoderXQARunner::DecoderXQARunner(
const XQADataType data_type, int num_heads, int num_kv_heads, int head_size, bool multi_block_mode)
: mDataType(data_type)
, mNumHeads(num_heads)
, mNumKVHeads(num_kv_heads)
, mHeadSize(head_size)
, mMultiBlockMode(multi_block_mode)
{
// Cache the multiprocessor count once at construction instead of querying it on every call.
mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
// TODO: needs both impls because medusa kernels haven't been migrated to JIT yet (they should be).
// mJITImpl/mPrecompiledImpl assignments must be the last lines of this constructor. DecoderXQAImpl::create() relies
// on *this being fully initialized.
mJITImpl = DecoderXQAImpl::create(this, DecoderXQAImpl::ImplType::kJIT);
mPrecompiledImpl = DecoderXQAImpl::create(this, DecoderXQAImpl::ImplType::kPrecompiled);
}
// Defaulted out of line in this .cpp rather than in the header; presumably so the types held by
// mJITImpl/mPrecompiledImpl only need to be complete here — TODO confirm against the header.
DecoderXQARunner::~DecoderXQARunner() = default;
namespace
{
// Integer ceiling division: smallest q with q * b >= a, for non-negative a and positive b.
// Written as a / b + (a % b != 0) instead of the classic (a + b - 1) / b because the latter overflows when a is
// close to the maximum value of T (wrap-around for unsigned T, e.g. divUp(0xFFFFFFFFu, 2u) would yield 0;
// undefined behavior for signed T).
template <typename T>
constexpr inline T divUp(T a, T b)
{
    return a / b + static_cast<T>(a % b != 0);
}

// Rounds a up to the nearest multiple of b (non-negative a, positive b). The final multiplication can still
// overflow if the rounded value itself is not representable in T.
template <typename T>
constexpr inline T roundUp(T a, T b)
{
    return divUp(a, b) * b;
}
} // namespace
/// Chooses which XQA implementation (JIT or precompiled cubins) should handle the given parameters.
///
/// @param xqaParams            Parameters of the attention call being dispatched.
/// @param for_configure_plugin Accepted for symmetry with shouldUse(); the selection logic below does not read it.
/// @return Non-owning pointer to either mJITImpl or mPrecompiledImpl, both owned by this runner.
DecoderXQAImpl* DecoderXQARunner::getImplFromXQAParams(XQAParams const& xqaParams, bool for_configure_plugin)
{
    int const smVersion = tensorrt_llm::common::getSMVersion();
    if (xqaParams.multi_query_tokens)
    {
        // Query heads per KV head. Guarded so a degenerate head configuration (num_kv_heads == 0, or
        // num_q_heads < num_kv_heads giving grpSize == 0) cannot reach an integer division/modulo by zero.
        auto const grpSize = xqaParams.num_kv_heads > 0 ? xqaParams.num_q_heads / xqaParams.num_kv_heads : 0;
        // Ampere XQA supports spec dec with pre-compiled cubins (may also work with JIT but not implemented yet).
        // Hopper XQA supports spec dec with JIT, but only for E4M3 kv cache data type. Only allow 64 % grpSize == 0
        // for now.
        bool const supportedByHopperXqa = (smVersion == 90
            && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3 && grpSize > 0 && 64 % grpSize == 0);
        // SM120 routes MLA with an E4M3 KV cache to the JIT implementation as well.
        bool const supportedBySm120Mla
            = (smVersion == 120 && xqaParams.isMLA() && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3);
        return (supportedByHopperXqa || supportedBySm120Mla) ? mJITImpl.get() : mPrecompiledImpl.get();
    }
    // Outside the multi-query-token path JIT is the default; an explicit value from getEnvEnableXQAJIT()
    // overrides it in either direction.
    std::optional<bool> const envEnableXQAJIT = tensorrt_llm::common::getEnvEnableXQAJIT();
    return envEnableXQAJIT.value_or(true) ? mJITImpl.get() : mPrecompiledImpl.get();
}
/// Reports whether XQA can handle these parameters, as decided by the implementation that would run them.
bool DecoderXQARunner::shouldUse(XQAParams const& xqa_params, bool for_configure_plugin)
{
    DecoderXQAImpl* const selectedImpl = getImplFromXQAParams(xqa_params, for_configure_plugin);
    return selectedImpl->shouldUse(xqa_params, for_configure_plugin);
}
/// Lets the selected implementation prepare for a later run() with these parameters.
void DecoderXQARunner::prepareForRun(XQAParams const& xqa_params)
{
    DecoderXQAImpl* const selectedImpl = getImplFromXQAParams(xqa_params, /*for_configure_plugin=*/true);
    selectedImpl->prepare(xqa_params);
}
/// Executes XQA attention on `stream` via the implementation selected for these parameters.
template <typename KVCacheBuffer>
void DecoderXQARunner::run(
    XQAParams const& xqa_params, KVCacheBuffer const& kv_cache_buffer, cudaStream_t const& stream)
{
    DecoderXQAImpl* const selectedImpl = getImplFromXQAParams(xqa_params, /*for_configure_plugin=*/false);
    selectedImpl->run(xqa_params, kv_cache_buffer, stream);
}
/// Returns the process-wide Resource instance.
/// It is a function-local static, so it is constructed on first call (initialization is thread-safe since C++11)
/// and the same object is returned to every caller.
DecoderXQARunner::Resource* DecoderXQARunner::getResourceGlobal()
{
    static Resource globalResource;
    return &globalResource;
}
template void DecoderXQARunner::run(
XQAParams const& xqa_params, KVLinearBuffer const& kv_linear_buffer, cudaStream_t const& stream);
template void DecoderXQARunner::run(
XQAParams const& xqa_params, KVBlockArray const& kv_block_array, cudaStream_t const& stream);
//// DecoderXQARunner::Resource
DecoderXQARunner::Resource::Resource()
: mCubinObjRegistry(std::make_unique<jit::CubinObjRegistry>())
{
}
DecoderXQARunner::Resource::Resource(DecoderXQARunner::Resource const& other)
: mCubinObjRegistry(other.mCubinObjRegistry->clone())
{
}
// Copy assignment replaces this Resource's registry with a clone() of other's; self-assignment leaves it untouched.
DecoderXQARunner::Resource& DecoderXQARunner::Resource::operator=(DecoderXQARunner::Resource const& other)
{
    if (this != &other)
    {
        mCubinObjRegistry = other.mCubinObjRegistry->clone();
    }
    return *this;
}
DecoderXQARunner::Resource::Resource(void const* buffer, size_t buffer_size)
: mCubinObjRegistry(std::make_unique<jit::CubinObjRegistry>(buffer, buffer_size))
{
}
// Size reported by the registry; presumably the number of bytes serialize() writes — TODO confirm.
size_t DecoderXQARunner::Resource::getSerializationSize() const noexcept
{
    auto& registry = *mCubinObjRegistry;
    return registry.getSerializationSize();
}

// Forwards serialization of the registry into `buffer` (capacity `buffer_size` bytes).
void DecoderXQARunner::Resource::serialize(void* buffer, size_t buffer_size) const noexcept
{
    auto& registry = *mCubinObjRegistry;
    registry.serialize(buffer, buffer_size);
}
} // namespace kernels
} // namespace tensorrt_llm