diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
index c4c6659294..939c7741fb 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -891,6 +891,11 @@ public:
         return mIsSWA;
     }

+    [[nodiscard]] bool isEnablePartialReuse() const
+    {
+        return mEnablePartialReuse;
+    }
+
     [[nodiscard]] std::shared_ptr findBlocksInReuseTreeByBlockKey(BlockKey const& blockKey);

     //! \brief Unpin blocks by block ids directly
@@ -1078,6 +1083,11 @@ public:
         return mIndexerKCacheIndexHeadDim;
     }

+    [[nodiscard]] bool isEnablePartialReuse() const
+    {
+        return mWindowBlockManagers.begin()->second.isEnablePartialReuse();
+    }
+
     BlockManager(BlockManager const&) = delete;
     BlockManager& operator=(BlockManager const&) = delete;

@@ -1565,6 +1575,8 @@ public:

     [[nodiscard]] virtual bool isEnableBlockReuse() const = 0;

+    [[nodiscard]] virtual bool isEnablePartialReuse() const = 0;
+
     [[nodiscard]] virtual bool isEnableIndexerKCache() const = 0;
     [[nodiscard]] virtual SizeType32 getIndexerKCacheIndexHeadDim() const = 0;
     [[nodiscard]] virtual SizeType32 getIndexerKCacheQuantBlockSize() const = 0;
@@ -1912,6 +1924,11 @@ public:
         return mEnableBlockReuse;
     }

+    [[nodiscard]] bool isEnablePartialReuse() const override
+    {
+        return mBlockManager.isEnablePartialReuse();
+    }
+
     [[nodiscard]] bool isEnableIndexerKCache() const override
     {
         return mBlockManager.isEnableIndexerKCache();
diff --git a/cpp/include/tensorrt_llm/executor/dataTransceiverState.h b/cpp/include/tensorrt_llm/executor/dataTransceiverState.h
index 4a67520b96..aedbbec13e 100644
--- a/cpp/include/tensorrt_llm/executor/dataTransceiverState.h
+++ b/cpp/include/tensorrt_llm/executor/dataTransceiverState.h
@@ -51,7 +51,8 @@ public:
     CacheState(ModelConfig modelConfig, runtime::WorldConfig const& worldConfig,
         std::vector const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableBlockReuse = false,
-        bool hasIndexerKCache = false, SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
+        bool enablePartialReuse = false, bool hasIndexerKCache = false, SizeType32 indexerDimPerHead = 0,
+        SizeType32 indexerKCacheQuantBlockSize = 128)
         : mModelConfig(std::move(modelConfig))
         , mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(),
               worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(),
@@ -60,6 +61,7 @@ public:
         , mAttentionConfig(attentionType, kvFactor)
     {
         mEnableBlockReuse = enableBlockReuse;
+        mEnablePartialReuse = enablePartialReuse;
         mHasIndexerKCache = hasIndexerKCache;
         mIndexerDimPerHead = indexerDimPerHead;
         mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
@@ -69,8 +71,8 @@ public:
         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
         std::vector const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
-        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false, bool hasIndexerKCache = false,
-        SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
+        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false, bool enablePartialReuse = false,
+        bool hasIndexerKCache = false, SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
         : mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
         , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
               attentionLayerNumPerPP}
@@ -78,6 +80,7 @@ public:
         , mAttentionConfig(attentionType, kvFactor)
     {
         mEnableBlockReuse = enableBlockReuse;
+        mEnablePartialReuse = enablePartialReuse;
         mHasIndexerKCache = hasIndexerKCache;
         mIndexerDimPerHead = indexerDimPerHead;
         mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
@@ -87,8 +90,8 @@ public:
         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
         std::vector const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
-        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false, bool hasIndexerKCache = false,
-        SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
+        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false, bool enablePartialReuse = false,
+        bool hasIndexerKCache = false, SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
         : mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
         , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
               attentionLayerNumPerPP}
@@ -96,6 +99,7 @@ public:
         , mAttentionConfig(attentionType, kvFactor)
     {
         mEnableBlockReuse = enableBlockReuse;
+        mEnablePartialReuse = enablePartialReuse;
         mHasIndexerKCache = hasIndexerKCache;
         mIndexerDimPerHead = indexerDimPerHead;
         mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
@@ -186,6 +190,11 @@ public:
         return mEnableBlockReuse;
     }

+    [[nodiscard]] bool getEnablePartialReuse() const
+    {
+        return mEnablePartialReuse;
+    }
+
     [[nodiscard]] bool getHasIndexerKCache() const
     {
         return mHasIndexerKCache;
@@ -221,6 +230,7 @@ public:
         sstring << "dpRank:" << mParallelConfig.mDPrank << "\n";
         sstring << "dpSize:" << mParallelConfig.mDPsize << "\n";
         sstring << "enableBlockReuse:" << mEnableBlockReuse << "\n";
+        sstring << "enablePartialReuse:" << mEnablePartialReuse << "\n";
         sstring << "hasIndexerKCache:" << mHasIndexerKCache << "\n";
         sstring << "indexerDimPerHead:" << mIndexerDimPerHead << "\n";
         sstring << "indexerKCacheQuantBlockSize:" << mIndexerKCacheQuantBlockSize << "\n";
@@ -234,6 +244,7 @@ private:
     nvinfer1::DataType mDataType;
     AttentionConfig mAttentionConfig;
     bool mEnableBlockReuse{false};
+    bool mEnablePartialReuse{false};
     bool mHasIndexerKCache{false};
     SizeType32 mIndexerDimPerHead{0};
     SizeType32 mIndexerKCacheQuantBlockSize{128};
diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
index 96eec0fd04..77e0668ab4 100644
--- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
+++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
@@ -50,7 +50,8 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest

     // Note: When recv side has CP, the requested seqLen is lesser than seqLen on the sender side as seqLen is
     // distributed among CP ranks. So, we transfer all blocks from send side.
-    if (poolNum > 1 || !cacheManager->isEnableBlockReuse() || lastBlockKey.uniqueTokens.size() == 0 || recvSideHasCP)
+    if (poolNum > 1 || !cacheManager->isEnableBlockReuse() || !cacheManager->isEnablePartialReuse()
+        || lastBlockKey.uniqueTokens.size() == 0 || recvSideHasCP)
     {
         // disable reuse path, and vwsa don't support reuse.
         bool needSendAllForWindow = common::getEnvKVCacheTransferAllBlocksForWindow();
@@ -87,13 +88,13 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest
     return BlockRange::fromReuseTree(*cacheManager, lastBlockKey, indexFromEnd);
 }

-BlockRange getBlockRangeForReceiving(
-    BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, bool srcEnableBlockReuse, bool recvSideHasCP)
+BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
+    bool srcEnableBlockReuse, bool srcEnablePartialReuse, bool recvSideHasCP)
 {
     // Note: When recv side has CP, we request all blocks from send side right now.
     auto poolNum = cacheManager->getBlockManager().getNumPools(
         /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
-    if (poolNum == 1 && srcEnableBlockReuse && !recvSideHasCP)
+    if (poolNum == 1 && srcEnableBlockReuse && srcEnablePartialReuse && !recvSideHasCP)
     {
         // Build from all block ids, then slice off the reused blocks so we only transfer newly allocated ones.
         auto windowSize = cacheManager->getBlockManager().getWindowSizesMetadata().begin()->first;
@@ -555,7 +556,8 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
     auto const& destConfig = session.getOtherState().getCacheState().value();
     auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx();
     auto& bufferManager = session.getBufferManager();
-    auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse());
+    auto blockRange = getBlockRangeForReceiving(
+        mCacheManager, llmRequest, destConfig.getEnableBlockReuse(), destConfig.getEnablePartialReuse());

     auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig);

diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h
index 3d04db26ed..4ddad832d7 100644
--- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h
+++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h
@@ -53,7 +53,7 @@ using CacheTransBufferManager = kv_cache_manager::CacheTransBufferManager;
 using BlockRange = kv_cache_manager::BlockRange;

 BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
-    bool srcEnableBlockReuse, bool recvSideHasCP = false);
+    bool srcEnableBlockReuse, bool srcEnablePartialReuse, bool recvSideHasCP = false);

 // Used to support the cache transmission with different layouts and different protocols.
 class BaseCacheFormatter
diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
index 63a7ab2a38..e12cefca14 100644
--- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
+++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
@@ -143,8 +143,8 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
     }
     mCacheState = std::make_unique(cacheStateModelCfg, worldConfig,
         attentionLayerNumPerPP, dataType, attentionType, kvFactor, cacheManager->isEnableBlockReuse(),
-        cacheManager->isEnableIndexerKCache(), cacheManager->getIndexerKCacheIndexHeadDim(),
-        cacheManager->getIndexerKCacheQuantBlockSize());
+        cacheManager->isEnablePartialReuse(), cacheManager->isEnableIndexerKCache(),
+        cacheManager->getIndexerKCacheIndexHeadDim(), cacheManager->getIndexerKCacheQuantBlockSize());

     if (mCacheState->getParallelConfig().mEnableAttentionDP)
     {
diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
index 06e28eab97..472ee63eb0 100644
--- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
+++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
@@ -825,8 +825,8 @@ public:
     {
         auto* cacheManager = mFormatter->getCacheManager();
         auto beam = 0;
-        auto requestedBlockRange
-            = getBlockRangeForReceiving(cacheManager, llmRequest, destCacheState.getEnableBlockReuse());
+        auto requestedBlockRange = getBlockRangeForReceiving(
+            cacheManager, llmRequest, destCacheState.getEnableBlockReuse(), destCacheState.getEnablePartialReuse());

         auto const& uniqueTokens = llmRequest.getUniqueTokens(beam);
         auto lastBlockKey
diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
index b2a60a3eda..95858cbd60 100644
--- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
+++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
@@ -357,8 +357,8 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
     auto& bufferManager = session.getBufferManager();
     auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig);
     bool const recvSideHasCP = selfConfig.getParallelConfig().mContextParallelism > 1;
-    auto blockRange
-        = getBlockRangeForReceiving(mCacheManager, llmRequest, destConfig.getEnableBlockReuse(), recvSideHasCP);
+    auto blockRange = getBlockRangeForReceiving(
+        mCacheManager, llmRequest, destConfig.getEnableBlockReuse(), destConfig.getEnablePartialReuse(), recvSideHasCP);
     auto const numPools = mCacheManager->getBlockManager().getNumPools(
         /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
     auto const& windowSizes = blockRange.getWindowSizes();
diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp
index 79d015d585..10f238fa75 100644
--- a/cpp/tensorrt_llm/executor/serialization.cpp
+++ b/cpp/tensorrt_llm/executor/serialization.cpp
@@ -544,12 +544,13 @@ kv_cache::CacheState Serialization::deserializeCacheState(std::istream& is)
     auto attentionType = su::deserialize(is);
     auto kvFactor = su::deserialize(is);
     auto enableBlockReuse = su::deserialize(is);
+    auto enablePartialReuse = su::deserialize(is);
     auto hasIndexerKCache = su::deserialize(is);
     auto indexerDimPerHead = su::deserialize(is);
     auto indexerKCacheQuantBlockSize = su::deserialize(is);
     return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism,
         contextParallelism, attentionLayerNumPerPP, dataType, attentionType, kvFactor, enableAttentionDP, DPrank,
-        DPsize, enableBlockReuse, hasIndexerKCache, indexerDimPerHead, indexerKCacheQuantBlockSize};
+        DPsize, enableBlockReuse, enablePartialReuse, hasIndexerKCache, indexerDimPerHead, indexerKCacheQuantBlockSize};
 }

 void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& os)
@@ -568,6 +569,7 @@ void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& o
     su::serialize(state.mAttentionConfig.mAttentionType, os);
     su::serialize(state.mAttentionConfig.mKvFactor, os);
     su::serialize(state.mEnableBlockReuse, os);
+    su::serialize(state.mEnablePartialReuse, os);
     su::serialize(state.getHasIndexerKCache(), os);
     su::serialize(state.getIndexerDimPerHead(), os);
     su::serialize(state.getIndexerKCacheQuantBlockSize(), os);
@@ -590,6 +592,7 @@ size_t Serialization::serializedSize(kv_cache::CacheState const& state)
     totalSize += su::serializedSize(state.mAttentionConfig.mAttentionType);
     totalSize += su::serializedSize(state.mAttentionConfig.mKvFactor);
     totalSize += su::serializedSize(state.mEnableBlockReuse);
+    totalSize += su::serializedSize(state.mEnablePartialReuse);
     totalSize += su::serializedSize(state.getHasIndexerKCache());
     totalSize += su::serializedSize(state.getIndexerDimPerHead());
     totalSize += su::serializedSize(state.getIndexerKCacheQuantBlockSize());
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
index 8c9018bc2b..b2b7370094 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
@@ -167,6 +167,11 @@ public:
         NB_OVERRIDE_PURE(isEnableBlockReuse);
     }

+    bool isEnablePartialReuse() const override
+    {
+        NB_OVERRIDE_PURE(isEnablePartialReuse);
+    }
+
     void rewindKVCache(tb::LlmRequest::RequestIdType requestId, SizeType32 rewindLengths) override
     {
         NB_OVERRIDE_PURE(rewindKVCache, requestId, rewindLengths);
diff --git a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp
index a2aefd90e6..23683f36c7 100644
--- a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp
+++ b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp
@@ -1296,13 +1296,14 @@ TEST(SerializeUtilsTest, CacheStateIndexerKCache)
     int dpRank = 0;
     int dpSize = 1;
     bool enableBlockReuse = true;
+    bool enablePartialReuse = true;
     bool hasIndexerKCache = true;
     texec::SizeType32 indexerDimPerHead = 96;
     texec::SizeType32 indexerKCacheQuantBlockSize = 128;

     CacheState state{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tp, pp, cp, attentionLayerNumPerPP, dataType,
-        attentionType, kvFactor, enableAttentionDP, dpRank, dpSize, enableBlockReuse, hasIndexerKCache,
-        indexerDimPerHead, indexerKCacheQuantBlockSize};
+        attentionType, kvFactor, enableAttentionDP, dpRank, dpSize, enableBlockReuse, enablePartialReuse,
+        hasIndexerKCache, indexerDimPerHead, indexerKCacheQuantBlockSize};

     std::ostringstream oss;
     texec::Serialization::serialize(state, oss);
@@ -1320,6 +1321,7 @@ TEST(SerializeUtilsTest, CacheStateIndexerKCache)
     EXPECT_EQ(state.getAttentionConfig().mAttentionType, state2.getAttentionConfig().mAttentionType);
     EXPECT_EQ(state.getAttentionConfig().mKvFactor, state2.getAttentionConfig().mKvFactor);
     EXPECT_EQ(state.getEnableBlockReuse(), state2.getEnableBlockReuse());
+    EXPECT_EQ(state.getEnablePartialReuse(), state2.getEnablePartialReuse());
     EXPECT_EQ(state.getHasIndexerKCache(), state2.getHasIndexerKCache());
     EXPECT_EQ(state.getIndexerDimPerHead(), state2.getIndexerDimPerHead());
     EXPECT_EQ(state.getIndexerKCacheQuantBlockSize(), state2.getIndexerKCacheQuantBlockSize());
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index e8be80483f..c8267d5745 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -351,6 +351,9 @@ class PyExecutor:
             ResourceManagerType.KV_CACHE_MANAGER)
         self.enable_kv_cache_events = self.kv_cache_manager is not None and self.kv_cache_manager.event_buffer_max_size > 0
         self.enable_kv_cache_reuse = self.kv_cache_manager is not None and self.kv_cache_manager.enable_block_reuse
+        self.enable_partial_reuse_for_disagg = (
+            self.enable_kv_cache_reuse
+            and self.kv_cache_manager.enable_partial_reuse)
         self.max_input_len = max_input_len

         # _executor_loop private data
@@ -359,7 +362,7 @@ class PyExecutor:
         self.expected_num_active_requests = 0
         self.async_transfer_manager = AsyncTransferManager(
             self.resource_manager,
-            should_store_blocks=self.enable_kv_cache_reuse
+            should_store_blocks=self.enable_partial_reuse_for_disagg
             and not self.kv_cache_manager.is_vswa)
         self.previous_batch: Optional[BatchState] = None
         self.has_previous_draft_tokens = False
@@ -3003,7 +3006,7 @@ class PyExecutor:
                 logger.debug(
                     f"Request {request.py_request_id} has no avg_decoded_tokens_per_iter"
                 )
-            if self.enable_kv_cache_reuse and not self.kv_cache_manager.is_vswa:
+            if self.enable_partial_reuse_for_disagg and not self.kv_cache_manager.is_vswa:
                 requests_to_terminate.append(request)
             else:
                 if not request.is_disagg_context_transmission_state:
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index f3cb048140..2b973af2e6 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -471,6 +471,7 @@ class KVCacheManager(BaseResourceManager):
         self.num_pools = self.impl.num_pools
         self.max_blocks_per_seq = self.impl.max_blocks_per_seq
         self.enable_block_reuse = kv_cache_config.enable_block_reuse
+        self.enable_partial_reuse = kv_cache_config.enable_partial_reuse

         self.host_kv_cache_block_offsets = torch.empty(self.num_pools,
                                                        max_batch_size * max_beam_width,
@@ -1711,6 +1712,7 @@ class KVCacheManagerV2(BaseResourceManager):

         self.max_seq_len = max_num_tokens
         self.enable_block_reuse = kv_cache_config.enable_block_reuse
+        self.enable_partial_reuse = kv_cache_config.enable_partial_reuse

         # Plus 1 for cuda graph dummy request
         self.index_mapper = IndexMapper(max_batch_size + 1, max_beam_width)
diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py
index 4050f8a26e..d7be3c6272 100644
--- a/tests/integration/defs/accuracy/test_disaggregated_serving.py
+++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -1273,20 +1273,27 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
     @skip_pre_hopper
     @pytest.mark.skip_less_device(2)
     @pytest.mark.parametrize("overlap_scheduler", [False, True])
-    def test_auto_dtype(self, overlap_scheduler):
+    @pytest.mark.parametrize("enable_partial_reuse", [True, False])
+    def test_auto_dtype(self, overlap_scheduler, enable_partial_reuse):
+        kv_cache_config = {
+            "enable_block_reuse": True,
+            "enable_partial_reuse": enable_partial_reuse,
+        }
         ctx_server_config = {
             "disable_overlap_scheduler": True,
             "cuda_graph_config": None,
             "cache_transceiver_config": {
"DEFAULT" - } + }, + "kv_cache_config": kv_cache_config, } gen_server_config = { "disable_overlap_scheduler": overlap_scheduler, "cuda_graph_config": None, "cache_transceiver_config": { "backend": "DEFAULT" - } + }, + "kv_cache_config": kv_cache_config, } disaggregated_server_config = { "hostname": "localhost", diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 815b0276c9..088ba536c7 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -332,8 +332,9 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_inst accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU] accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True] -accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False] -accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True] +accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False-True] +accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True-True] +accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False-False] accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend accuracy/test_disaggregated_serving.py::TestKimiK2::test_nvfp4 diff --git a/tests/integration/test_lists/qa/llm_function_rtx6k.txt b/tests/integration/test_lists/qa/llm_function_rtx6k.txt index 06e1ee2941..a7c9501e8b 100644 --- a/tests/integration/test_lists/qa/llm_function_rtx6k.txt +++ b/tests/integration/test_lists/qa/llm_function_rtx6k.txt @@ -207,8 +207,8 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_inst accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU] accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True] -accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False] -accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True] +accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False-False] +accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False-True] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-overlap_scheduler] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[cutlass-one_model-no_overlap_scheduler] diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index b791457fbb..c308bad8d7 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -25,8 +25,9 @@ l0_dgx_h100: - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar-eagle3_one_model=True] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar-eagle3_one_model=False] - - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False] - - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True] + - 
+  - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False-True]
+  - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True-True]
+  - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False-False]
   - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_chunked_prefill
   - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
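
For context on how the new flag is exercised end to end: partial reuse is driven by the per-server `kv_cache_config`, as wired up in the updated `test_auto_dtype` above, and the test passes the same dict to both the context and the generation server. Below is a minimal, hypothetical standalone sketch of those configs; only the `kv_cache_config` keys relate to this change, the other fields simply mirror the existing test, and literal values stand in for the pytest parameters.

    # Illustrative sketch (not taken verbatim from the diff) of the server configs
    # that test_auto_dtype assembles when exercising enable_partial_reuse.
    kv_cache_config = {
        "enable_block_reuse": True,
        "enable_partial_reuse": True,  # the test parametrizes this as True/False
    }

    ctx_server_config = {
        "disable_overlap_scheduler": True,
        "cuda_graph_config": None,
        "cache_transceiver_config": {"backend": "DEFAULT"},
        "kv_cache_config": kv_cache_config,
    }

    gen_server_config = {
        "disable_overlap_scheduler": False,  # the test toggles this via overlap_scheduler
        "cuda_graph_config": None,
        "cache_transceiver_config": {"backend": "DEFAULT"},
        "kv_cache_config": kv_cache_config,
    }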