diff --git a/cpp/include/tensorrt_llm/common/cudaUtils.h b/cpp/include/tensorrt_llm/common/cudaUtils.h index 3a11df85b1..cd58a7abb5 100644 --- a/cpp/include/tensorrt_llm/common/cudaUtils.h +++ b/cpp/include/tensorrt_llm/common/cudaUtils.h @@ -151,26 +151,6 @@ void checkEx( #define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) #define check_cuda_error_2(val, file, line) check((val), #val, file, line) -inline std::optional isCudaLaunchBlocking() -{ - thread_local bool firstCall = true; - thread_local std::optional result = std::nullopt; - if (!firstCall) - { - char const* env = std::getenv("CUDA_LAUNCH_BLOCKING"); - if (env != nullptr && std::string(env) == "1") - { - result = true; - } - else - { - result = false; - } - firstCall = false; - } - return result; -} - inline bool isCapturing(cudaStream_t stream) { cudaStreamCaptureStatus status; @@ -180,21 +160,23 @@ inline bool isCapturing(cudaStream_t stream) inline bool doCheckError(cudaStream_t stream) { - auto const cudaLaunchBlocking = isCudaLaunchBlocking(); - if (cudaLaunchBlocking.has_value() && cudaLaunchBlocking.value()) + // If we're capturing a CUDA graph we don't check. Otherwise, we + // default to only checking in debug builds. But we always listen to + // the env variable. + static bool const doCheckIfNotCapturing = []() { - return !isCapturing(stream); - } - + char const* env = std::getenv("CUDA_LAUNCH_BLOCKING"); + if (env != nullptr) + { + return std::string(env) == "1"; + } #ifndef NDEBUG - // Debug builds will sync when we're not capturing unless explicitly - // disabled. - bool const checkError = cudaLaunchBlocking.value_or(!isCapturing(stream)); + return true; #else - bool const checkError = cudaLaunchBlocking.value_or(false); + return false; #endif - - return checkError; + }(); + return doCheckIfNotCapturing && !isCapturing(stream); } inline void syncAndCheck(cudaStream_t stream, char const* const file, int const line)