mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-16 15:55:08 +08:00
[None][fix] Fix out-of-bounds array access in kernel factory Get() methods (#11373)
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
2d5ebb3fe8
commit
2c4a4c7b94
@ -16,6 +16,7 @@
|
||||
|
||||
#include "fused_multihead_attention_v2.h"
|
||||
#include "tensorrt_llm/common/config.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
@ -152,10 +153,9 @@ TFusedMHAKernelList const* TFusedMHAKernelFactory<TFusedMHAKernelList>::getXMMAK
|
||||
template <typename TFusedMHAKernelList>
|
||||
TFusedMHAKernelFactory<TFusedMHAKernelList>& TFusedMHAKernelFactory<TFusedMHAKernelList>::Get()
|
||||
{
|
||||
int device_id;
|
||||
cudaGetDevice(&device_id);
|
||||
static std::unique_ptr<TFusedMHAKernelFactory<TFusedMHAKernelList>> s_factory[32] = {nullptr};
|
||||
TLLM_CHECK(device_id <= 32);
|
||||
int const device_id = tensorrt_llm::common::getDevice();
|
||||
TLLM_CHECK_WITH_INFO(device_id < 32, "Invalid device_id %d (must be < 32)", device_id);
|
||||
if (s_factory[device_id] == nullptr)
|
||||
{
|
||||
s_factory[device_id] = std::make_unique<TFusedMHAKernelFactory<TFusedMHAKernelList>>(
|
||||
|
||||
@ -394,11 +394,11 @@ public:
|
||||
|
||||
static XQAKernelLoader& Get()
|
||||
{
|
||||
int device_id = tensorrt_llm::common::getDevice();
|
||||
static std::unique_ptr<XQAKernelLoader> s_factory[32] = {nullptr};
|
||||
int const device_id = tensorrt_llm::common::getDevice();
|
||||
TLLM_CHECK_WITH_INFO(device_id < 32, "Invalid device_id %d (must be < 32)", device_id);
|
||||
if (s_factory[device_id] == nullptr)
|
||||
{
|
||||
assert(device_id <= 32);
|
||||
s_factory[device_id] = std::make_unique<XQAKernelLoader>(XQAKernelLoader());
|
||||
}
|
||||
|
||||
|
||||
@ -925,12 +925,11 @@ public:
|
||||
|
||||
static TllmFmhaKernelFactory& Get()
|
||||
{
|
||||
int deviceId;
|
||||
cudaGetDevice(&deviceId);
|
||||
static std::unique_ptr<TllmFmhaKernelFactory> sFactory[32] = {nullptr};
|
||||
int const deviceId = tensorrt_llm::common::getDevice();
|
||||
TLLM_CHECK_WITH_INFO(deviceId < 32, "Invalid deviceId %d (must be < 32)", deviceId);
|
||||
if (sFactory[deviceId] == nullptr)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(deviceId < 32, "Invalid deviceId %d", deviceId);
|
||||
sFactory[deviceId] = std::make_unique<TllmFmhaKernelFactory>(TllmFmhaKernelFactory());
|
||||
}
|
||||
return *(sFactory[deviceId]);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user