From 1592dfab6daaf17e983aa63043eddca58891a5c2 Mon Sep 17 00:00:00 2001 From: HuiGao-NV Date: Wed, 21 Jan 2026 14:17:29 +0800 Subject: [PATCH] [https://nvbugs/5740377][fix] Lock resource to fix potential access to released data (#10827) Signed-off-by: Hui Gao --- cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h | 5 ++++- cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp | 6 ++++++ tests/integration/test_lists/waives.txt | 3 --- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 94717307b6..9a48b3391e 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -368,6 +368,9 @@ private: std::optional mExpirationTime; // Hash for the event manager size_t mHash; + + // Mutex for the next blocks + mutable std::mutex mNextBlocksMutex; }; class GenerationRequest @@ -1021,7 +1024,7 @@ private: std::shared_ptr mKvCacheConnectorManager; // Mutex for the cached blocks root - std::mutex mCachedBlocksRootMutex; + mutable std::mutex mCachedBlocksRootMutex; // Record which sequence is using the block std::map mBlockToSequence; diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 1ce9a08a91..4138e4c605 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -416,6 +416,7 @@ void KVCacheBlock::setPrevBlockInSeq(BlockPtr prevBlock) void KVCacheBlock::addNextBlock(BlockKey const& blockKey, BlockPtr block) { + std::lock_guard lock(mNextBlocksMutex); if (mNextBlocks.find(blockKey) == mNextBlocks.end()) { mNextBlocks[blockKey] = std::move(block); @@ -425,6 +426,8 @@ void KVCacheBlock::addNextBlock(BlockKey const& blockKey, BlockPtr block) std::tuple KVCacheBlock::findMatchingBlock( BlockKey const& blockKey, bool enablePartialReuse, bool copyOnPartialReuse) const { + std::lock_guard lock(mNextBlocksMutex); + if (blockKey.uniqueTokens.size() == 0 || mNextBlocks.size() == 0) { return {false, 0, nullptr}; @@ -474,11 +477,13 @@ void KVCacheBlock::freeLeafBlock() void KVCacheBlock::removeNextBlock(BlockKey const& blockKey) { + std::lock_guard lock(mNextBlocksMutex); mNextBlocks.erase(blockKey); } void KVCacheBlock::freeDescendantsRecursively() { + std::lock_guard lock(mNextBlocksMutex); bool hasChildren = !mNextBlocks.empty(); if (hasChildren) { @@ -1176,6 +1181,7 @@ std::optional WindowBlockManager::findNewContextBlock( auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest); BlockKey ret; ret.loraTaskId = llmRequest.getLoraTaskId(); + std::lock_guard lock(mCachedBlocksRootMutex); auto searchRoot = mCachedBlocksRoot; for (auto const& blockKey : blockKeys) { diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index d1a28d25db..b98d108a53 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -239,13 +239,10 @@ full:sm89/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] SKIP (https://nvbugs/5721672) accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True] SKIP (https://nvbugs/5741304) unittest/executor/test_rpc.py::TestRpcCorrectness::test_incremental_task_async SKIP (https://nvbugs/5741476) -accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/5740377) -accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False] SKIP (https://nvbugs/5740377) test_e2e.py::test_trtllm_bench_llmapi_launch[pytorch_backend-llama-v3-llama3-8b] SKIP (https://nvbugs/5744432) test_e2e.py::test_trtllm_serve_multimodal_example SKIP (https://nvbugs/5747920) test_e2e.py::test_trtllm_serve_example SKIP (https://nvbugs/5747938) triton_server/test_triton.py::test_opt[opt] SKIP (https://nvbugs/5739981) -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True-moe_backend=TRTLLM] SKIP (https://nvbugs/5740377) cpp/test_multi_gpu.py::TestDisagg::test_symmetric_executor[gpt-2proc-mpi_kvcache-90] SKIP (https://nvbugs/5755941) examples/test_granite.py::test_llm_granite[granite-3.0-1b-a400m-instruct-bfloat16] SKIP (https://nvbugs/5608979) examples/test_granite.py::test_llm_granite[granite-3.0-2b-instruct-bfloat16] SKIP (https://nvbugs/5608979)