From 1294ecb12f1e5fb1eb418dd4a60b01c2360aca06 Mon Sep 17 00:00:00 2001
From: hlu1 <14827759+hlu1@users.noreply.github.com>
Date: Wed, 30 Apr 2025 23:51:09 -0700
Subject: [PATCH] Add attention workspace memory check (#3970)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
---
 cpp/tensorrt_llm/thop/attentionOp.cpp | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp
index 181f792d46..5b63bc5f64 100644
--- a/cpp/tensorrt_llm/thop/attentionOp.cpp
+++ b/cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -522,6 +522,17 @@ torch::Tensor attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch
         = beam_width == 1 ? attention_window_size : cache_indirection.value().size(2);
     int64_t const workspace_size = runner->getWorkspaceSize(*op, num_tokens, max_attention_window_size, num_gen_tokens);
     TLLM_LOG_TRACE("Expected workspace size is %ld bytes", workspace_size);
+
+    if (workspace_size >= (16l << 30))
+    {
+        auto const [free_mem, total_mem] = tensorrt_llm::common::getDeviceMemoryInfo(false);
+        if (workspace_size >= static_cast<int64_t>(free_mem))
+        {
+            throw std::runtime_error("attention workspace size " + std::to_string(workspace_size)
+                + " bytes, exceeds available CUDA memory " + std::to_string(free_mem) + " bytes");
+        }
+    }
+
     torch::Tensor workspace;
     if (workspace_.has_value())
     {
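
For readers outside the TensorRT-LLM tree, the guard above reduces to a simple pattern: skip the device query entirely for small workspaces, and fail fast with a descriptive error when a very large workspace cannot fit in free device memory. Below is a minimal standalone sketch of that pattern using the CUDA runtime's cudaMemGetInfo in place of the repo's tensorrt_llm::common::getDeviceMemoryInfo helper; checkWorkspaceFits and kLargeWorkspaceThreshold are illustrative names, not part of the patch.

// Minimal sketch of the same guard, assuming only the CUDA runtime API.
// checkWorkspaceFits and kLargeWorkspaceThreshold are hypothetical names.
#include <cuda_runtime_api.h>

#include <cstdint>
#include <stdexcept>
#include <string>

// Mirrors the patch's (16l << 30) threshold: the device query is only
// performed once a requested workspace reaches 16 GiB.
constexpr int64_t kLargeWorkspaceThreshold = int64_t{16} << 30;

void checkWorkspaceFits(int64_t workspace_size)
{
    if (workspace_size < kLargeWorkspaceThreshold)
    {
        return; // common case: small workspaces skip the device query
    }
    size_t free_mem = 0;
    size_t total_mem = 0;
    // cudaMemGetInfo reports free and total memory on the current device.
    if (cudaMemGetInfo(&free_mem, &total_mem) != cudaSuccess)
    {
        throw std::runtime_error("cudaMemGetInfo failed");
    }
    if (workspace_size >= static_cast<int64_t>(free_mem))
    {
        throw std::runtime_error("attention workspace size " + std::to_string(workspace_size)
            + " bytes, exceeds available CUDA memory " + std::to_string(free_mem) + " bytes");
    }
}

Throwing here, before the workspace tensor is allocated, turns what would otherwise surface as an opaque CUDA out-of-memory failure inside the allocator into an actionable error that names the offending workspace size.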