diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_v2.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_v2.cpp index bef395a4e3..4bd4407f0e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_v2.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_v2.cpp @@ -16,6 +16,7 @@ #include "fused_multihead_attention_v2.h" #include "tensorrt_llm/common/config.h" +#include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/logger.h" #include #include @@ -152,10 +153,9 @@ TFusedMHAKernelList const* TFusedMHAKernelFactory::getXMMAK template TFusedMHAKernelFactory& TFusedMHAKernelFactory::Get() { - int device_id; - cudaGetDevice(&device_id); static std::unique_ptr> s_factory[32] = {nullptr}; - TLLM_CHECK(device_id <= 32); + int const device_id = tensorrt_llm::common::getDevice(); + TLLM_CHECK_WITH_INFO(device_id < 32, "Invalid device_id %d (must be < 32)", device_id); if (s_factory[device_id] == nullptr) { s_factory[device_id] = std::make_unique>( diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp index 822483560c..146aa4bdba 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp @@ -394,11 +394,11 @@ public: static XQAKernelLoader& Get() { - int device_id = tensorrt_llm::common::getDevice(); static std::unique_ptr s_factory[32] = {nullptr}; + int const device_id = tensorrt_llm::common::getDevice(); + TLLM_CHECK_WITH_INFO(device_id < 32, "Invalid device_id %d (must be < 32)", device_id); if (s_factory[device_id] == nullptr) { - assert(device_id <= 32); s_factory[device_id] = std::make_unique(XQAKernelLoader()); } diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h index b858fa047f..e3ad8f99ab 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h @@ -925,12 +925,11 @@ public: static TllmFmhaKernelFactory& Get() { - int deviceId; - cudaGetDevice(&deviceId); static std::unique_ptr sFactory[32] = {nullptr}; + int const deviceId = tensorrt_llm::common::getDevice(); + TLLM_CHECK_WITH_INFO(deviceId < 32, "Invalid deviceId %d (must be < 32)", deviceId); if (sFactory[deviceId] == nullptr) { - TLLM_CHECK_WITH_INFO(deviceId < 32, "Invalid deviceId %d", deviceId); sFactory[deviceId] = std::make_unique(TllmFmhaKernelFactory()); } return *(sFactory[deviceId]);