[None][fix] Fix out-of-bounds array access in kernel factory Get() methods (#11373)

Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Harris Nover 2026-02-11 17:21:01 -07:00 committed by GitHub
parent 2d5ebb3fe8
commit 2c4a4c7b94
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 7 additions and 8 deletions

View File

@ -16,6 +16,7 @@
#include "fused_multihead_attention_v2.h"
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include <algorithm>
#include <cmath>
@ -152,10 +153,9 @@ TFusedMHAKernelList const* TFusedMHAKernelFactory<TFusedMHAKernelList>::getXMMAK
template <typename TFusedMHAKernelList>
TFusedMHAKernelFactory<TFusedMHAKernelList>& TFusedMHAKernelFactory<TFusedMHAKernelList>::Get()
{
int device_id;
cudaGetDevice(&device_id);
static std::unique_ptr<TFusedMHAKernelFactory<TFusedMHAKernelList>> s_factory[32] = {nullptr};
TLLM_CHECK(device_id <= 32);
int const device_id = tensorrt_llm::common::getDevice();
TLLM_CHECK_WITH_INFO(device_id < 32, "Invalid device_id %d (must be < 32)", device_id);
if (s_factory[device_id] == nullptr)
{
s_factory[device_id] = std::make_unique<TFusedMHAKernelFactory<TFusedMHAKernelList>>(

View File

@ -394,11 +394,11 @@ public:
static XQAKernelLoader& Get()
{
int device_id = tensorrt_llm::common::getDevice();
static std::unique_ptr<XQAKernelLoader> s_factory[32] = {nullptr};
int const device_id = tensorrt_llm::common::getDevice();
TLLM_CHECK_WITH_INFO(device_id < 32, "Invalid device_id %d (must be < 32)", device_id);
if (s_factory[device_id] == nullptr)
{
assert(device_id <= 32);
s_factory[device_id] = std::make_unique<XQAKernelLoader>(XQAKernelLoader());
}

View File

@ -925,12 +925,11 @@ public:
static TllmFmhaKernelFactory& Get()
{
int deviceId;
cudaGetDevice(&deviceId);
static std::unique_ptr<TllmFmhaKernelFactory> sFactory[32] = {nullptr};
int const deviceId = tensorrt_llm::common::getDevice();
TLLM_CHECK_WITH_INFO(deviceId < 32, "Invalid deviceId %d (must be < 32)", deviceId);
if (sFactory[deviceId] == nullptr)
{
TLLM_CHECK_WITH_INFO(deviceId < 32, "Invalid deviceId %d", deviceId);
sFactory[deviceId] = std::make_unique<TllmFmhaKernelFactory>(TllmFmhaKernelFactory());
}
return *(sFactory[deviceId]);