Add attention workspace memory check (#3970)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
hlu1 2025-04-30 23:51:09 -07:00 committed by GitHub
parent 6ded5f984b
commit 1294ecb12f

@@ -522,6 +522,17 @@ torch::Tensor attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch
        = beam_width == 1 ? attention_window_size : cache_indirection.value().size(2);
    int64_t const workspace_size = runner->getWorkspaceSize(*op, num_tokens, max_attention_window_size, num_gen_tokens);
    TLLM_LOG_TRACE("Expected workspace size is %ld bytes", workspace_size);
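    // For very large workspaces (>= 16 GiB), verify the request fits in free device memory before allocating.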
    if (workspace_size >= (16ll << 30))
    {
        auto const [free_mem, total_mem] = tensorrt_llm::common::getDeviceMemoryInfo(false);
        if (workspace_size >= static_cast<int64_t>(free_mem))
        {
            throw std::runtime_error("attention workspace size " + std::to_string(workspace_size)
                + " bytes, exceeds available CUDA memory " + std::to_string(free_mem) + " bytes");
        }
    }
    torch::Tensor workspace;
    if (workspace_.has_value())
    {
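
For context outside the TensorRT-LLM tree: the new guard amounts to comparing the requested workspace size against the free-memory figure reported by the CUDA runtime before any allocation happens. Below is a minimal, self-contained sketch of that pattern; the helper name checkWorkspaceFits is hypothetical, and it calls cudaMemGetInfo directly where the diff goes through tensorrt_llm::common::getDeviceMemoryInfo.

#include <cstdint>
#include <stdexcept>
#include <string>

#include <cuda_runtime_api.h>

// Hypothetical helper mirroring the check added in this commit: query device
// memory only for very large workspaces (>= 16 GiB) and fail with a readable
// error instead of letting a later allocation die with an opaque OOM.
void checkWorkspaceFits(int64_t workspace_size)
{
    if (workspace_size >= (16ll << 30)) // same 16 GiB threshold as the diff
    {
        size_t free_mem = 0;
        size_t total_mem = 0;
        if (cudaMemGetInfo(&free_mem, &total_mem) != cudaSuccess)
        {
            throw std::runtime_error("cudaMemGetInfo failed");
        }
        if (workspace_size >= static_cast<int64_t>(free_mem))
        {
            throw std::runtime_error("attention workspace size " + std::to_string(workspace_size)
                + " bytes, exceeds available CUDA memory " + std::to_string(free_mem) + " bytes");
        }
    }
}

Gating the query behind the 16 GiB threshold keeps the extra runtime call off the common path, where attention workspaces are far smaller, while still turning a near-certain out-of-memory failure into a readable error message.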