mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Add attention workspace memory check (#3970)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
This commit is contained in:
parent
6ded5f984b
commit
1294ecb12f
@ -522,6 +522,17 @@ torch::Tensor attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch
|
||||
= beam_width == 1 ? attention_window_size : cache_indirection.value().size(2);
|
||||
int64_t const workspace_size = runner->getWorkspaceSize(*op, num_tokens, max_attention_window_size, num_gen_tokens);
|
||||
TLLM_LOG_TRACE("Expected workspace size is %ld bytes", workspace_size);
|
||||
|
||||
if (workspace_size >= (16l << 30))
|
||||
{
|
||||
auto const [free_mem, total_mem] = tensorrt_llm::common::getDeviceMemoryInfo(false);
|
||||
if (workspace_size >= static_cast<int64_t const>(free_mem))
|
||||
{
|
||||
throw std::runtime_error("attention workspace size " + std::to_string(workspace_size)
|
||||
+ " bytes, exceeds available CUDA memory " + std::to_string(free_mem) + " bytes");
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor workspace;
|
||||
if (workspace_.has_value())
|
||||
{
|
||||
|
||||
Loading…
Reference in New Issue
Block a user