mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-16 07:53:55 +08:00
[None][fix] Respect CUDA_LAUNCH_BLOCKING by fixing doCheckError (#11261)
Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
This commit is contained in:
parent
c37531c3f7
commit
100bfdc516
@ -151,26 +151,6 @@ void checkEx(
|
||||
#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
|
||||
#define check_cuda_error_2(val, file, line) check((val), #val, file, line)
|
||||
|
||||
inline std::optional<bool> isCudaLaunchBlocking()
|
||||
{
|
||||
thread_local bool firstCall = true;
|
||||
thread_local std::optional<bool> result = std::nullopt;
|
||||
if (!firstCall)
|
||||
{
|
||||
char const* env = std::getenv("CUDA_LAUNCH_BLOCKING");
|
||||
if (env != nullptr && std::string(env) == "1")
|
||||
{
|
||||
result = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
result = false;
|
||||
}
|
||||
firstCall = false;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
inline bool isCapturing(cudaStream_t stream)
|
||||
{
|
||||
cudaStreamCaptureStatus status;
|
||||
@ -180,21 +160,23 @@ inline bool isCapturing(cudaStream_t stream)
|
||||
|
||||
inline bool doCheckError(cudaStream_t stream)
|
||||
{
|
||||
auto const cudaLaunchBlocking = isCudaLaunchBlocking();
|
||||
if (cudaLaunchBlocking.has_value() && cudaLaunchBlocking.value())
|
||||
// If we're capturing a CUDA graph we don't check. Otherwise, we
|
||||
// default to only checking in debug builds. But we always listen to
|
||||
// the env variable.
|
||||
static bool const doCheckIfNotCapturing = []()
|
||||
{
|
||||
return !isCapturing(stream);
|
||||
}
|
||||
|
||||
char const* env = std::getenv("CUDA_LAUNCH_BLOCKING");
|
||||
if (env != nullptr)
|
||||
{
|
||||
return std::string(env) == "1";
|
||||
}
|
||||
#ifndef NDEBUG
|
||||
// Debug builds will sync when we're not capturing unless explicitly
|
||||
// disabled.
|
||||
bool const checkError = cudaLaunchBlocking.value_or(!isCapturing(stream));
|
||||
return true;
|
||||
#else
|
||||
bool const checkError = cudaLaunchBlocking.value_or(false);
|
||||
return false;
|
||||
#endif
|
||||
|
||||
return checkError;
|
||||
}();
|
||||
return doCheckIfNotCapturing && !isCapturing(stream);
|
||||
}
|
||||
|
||||
inline void syncAndCheck(cudaStream_t stream, char const* const file, int const line)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user