From 0306c0f12c4d8e3f19523e15d41e03333d0b6078 Mon Sep 17 00:00:00 2001 From: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com> Date: Mon, 2 Feb 2026 14:29:02 +0800 Subject: [PATCH] [TRTLLM-9766][feat] Integration of the KVCacheManager V2 to TRTLLM Runtime (#10659) Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com> --- .gitignore | 2 + .../batch_manager/kvCacheManagerV2Utils.cpp | 76 ++ .../batch_manager/kvCacheManagerV2Utils.cu | 123 +++ .../batch_manager/kvCacheManagerV2Utils.h | 50 ++ .../batch_manager/kvCacheManagerV2Utils.cpp | 41 + .../nanobind/executor/bindings.cpp | 18 +- examples/llm-api/quickstart_advanced.py | 7 + .../_torch/attention_backend/interface.py | 4 +- tensorrt_llm/_torch/pyexecutor/_util.py | 34 +- .../_torch/pyexecutor/cuda_graph_runner.py | 10 +- .../_torch/pyexecutor/model_engine.py | 52 +- .../_torch/pyexecutor/resource_manager.py | 767 +++++++++++++++++- tensorrt_llm/_torch/pyexecutor/scheduler.py | 62 ++ tensorrt_llm/_utils.py | 13 +- tensorrt_llm/llmapi/llm_args.py | 20 + .../runtime/kv_cache_manager_v2/__init__.pyi | 2 +- .../_core/_kv_cache_manager.py | 15 +- tests/integration/defs/.test_durations | 51 +- .../defs/accuracy/test_llm_api_pytorch.py | 34 +- .../test_lists/qa/llm_function_core.txt | 40 +- .../qa/llm_function_core_sanity.txt | 40 +- .../test_lists/qa/llm_function_rtx6k.txt | 10 +- .../test_lists/test-db/l0_b200.yml | 10 +- .../test_lists/test-db/l0_dgx_b200.yml | 16 +- .../test_lists/test-db/l0_dgx_b300.yml | 11 +- .../test_lists/test-db/l0_dgx_h100.yml | 16 +- .../test-db/l0_gb200_multi_gpus.yml | 18 +- .../test_lists/test-db/l0_gb300.yml | 3 +- .../test_lists/test-db/l0_rtx_pro_6000.yml | 2 +- tests/integration/test_lists/waives.txt | 18 +- .../_torch/attention/test_attention.py | 51 +- .../_torch/attention/test_attention_mla.py | 38 +- 32 files changed, 1467 insertions(+), 187 deletions(-) diff --git a/.gitignore b/.gitignore index d7c360cce6..1f1d89079d 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ __pycache__/ *.cache *.nsys-rep *.npy +*.so +*.whl .VSCodeCounter cpp/build* cpp/Release diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cpp index fb369b0f0f..079de1a188 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cpp @@ -17,6 +17,7 @@ #include "tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h" #include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/memoryUtils.h" #include #include #include @@ -25,6 +26,9 @@ #include #include +namespace tc = tensorrt_llm::common; +using namespace tensorrt_llm::runtime; + namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 { @@ -160,4 +164,76 @@ CUresult copyHostToHost(std::vector> tasks, ssize_t return cuLaunchHostFunc(stream, hostFnHostToHostCopy, data.release()); } +SizeType32 IndexMapper::addNewSequence(LlmRequest::RequestIdType requestId) +{ + TLLM_CHECK(indexMap_.find(requestId) == indexMap_.end()); + auto iter = freeIndices_.begin(); + TLLM_CHECK_WITH_INFO(iter != freeIndices_.end(), "No free index found"); + auto index = *iter; + freeIndices_.erase(iter); + indexMap_[requestId] = index; + return index; +} + +SizeType32 IndexMapper::getIndex(LlmRequest::RequestIdType requestId) +{ + auto iter = indexMap_.find(requestId); + TLLM_CHECK_WITH_INFO(iter != indexMap_.end(), "Request ID not found in IndexMapper"); + return iter->second; +} + +void 
IndexMapper::removeSequence(LlmRequest::RequestIdType requestId) +{ + auto iter = indexMap_.find(requestId); + TLLM_CHECK(iter != indexMap_.end()); + auto index = iter->second; + freeIndices_.insert(index); + indexMap_.erase(iter); +} + +at::Tensor IndexMapper::getCopyIndex( + std::vector const& requestIds, SizeType32 numContext, SizeType32 beamWidth) +{ + int numSeqs = numContext + beamWidth * (requestIds.size() - numContext); + SizeType32 batchSize = static_cast(requestIds.size()); + SizeType32 idx = 0; + for (SizeType32 i = 0; i < batchSize; i++) + { + if (i < numContext) + { + copyIndex_[idx++] = this->getIndex(requestIds[i]) * maxBeamWidth_; + } + else + { + for (SizeType32 j = 0; j < beamWidth; j++) + { + copyIndex_[idx++] = this->getIndex(requestIds[i]) * maxBeamWidth_ + j; + } + } + } + + TLLM_CHECK_WITH_INFO(idx == numSeqs, "Index mapper failed to generate copy index"); + + return copyIndex_.slice(0, 0, numSeqs); +} + +IndexMapper::IndexMapper(SizeType32 maxBatchSize, SizeType32 maxBeamWidth) + : maxBeamWidth_(maxBeamWidth) +{ + indexMap_.reserve(maxBatchSize); + for (SizeType32 i = 0; i < maxBatchSize; i++) + { + freeIndices_.insert(i); + } + // Allocate copyIndex_ memory as pinned (page-locked) host memory + copyIndex_ + = at::empty({maxBatchSize * maxBeamWidth}, at::TensorOptions().dtype(at::ScalarType::Int).pinned_memory(true)); +} + +IndexMapper::~IndexMapper() +{ + indexMap_.clear(); + freeIndices_.clear(); +} + } // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cu b/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cu index 4a134e5d08..2c7fb9c9ff 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cu +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cu @@ -18,15 +18,20 @@ #include "kvCacheManagerV2Utils.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/envUtils.h" +#include "tensorrt_llm/common/memoryUtils.h" #include #include #include #include +#include namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 { using Grain = uint4; constexpr uint32_t ctaSize = 128; +constexpr uint32_t copyBlockCtaSize = 128; +constexpr uint32_t copyBlocknbBufs = 2; constexpr uint32_t nbBufs = 4; constexpr uint32_t grainBytes = sizeof(Grain); @@ -179,4 +184,122 @@ CUresult copyDeviceToDevice(std::vector const& tasks, ssize_t numBytes, return launchBatchedCopy(false, tasks, numBytes, stream); } +// dst_tensor[:, :num_seqs, 0] = src_tensor[:, copy_idx] +// dst_tensor[:, :num_seqs, 1] = dst_tensor[:, :num_seqs, 0] + 1 +template +__global__ void copyBatchBlockOffsetsToDeviceKernel(SizeType32 const* __restrict__ srcPtr, + SizeType32* __restrict__ dstPtr, SizeType32 const srcMaxNumSequences, SizeType32 const dstMaxNumSequences, + SizeType32 numBlocksPerSeq, SizeType32 const* __restrict__ copyIndex) +{ + constexpr uint32_t kvFactor = 2; + constexpr auto elemPerAccess = sizeof(PackedInt) / sizeof(SizeType32); + + __shared__ PackedInt data[copyBlocknbBufs][copyBlockCtaSize]; + + auto const iterPerSeq = divUp(numBlocksPerSeq * sizeof(SizeType32), sizeof(PackedInt) * copyBlockCtaSize); + auto const tid = threadIdx.x; + auto const poolIdx = blockIdx.x; + auto const seqIdx = blockIdx.y; + auto const seqDimStride = kvFactor * numBlocksPerSeq; + uint32_t const srcIdxBeg = tid * elemPerAccess + (poolIdx * srcMaxNumSequences + copyIndex[seqIdx]) * seqDimStride; + uint32_t const dstIdxKBeg = tid * elemPerAccess + (poolIdx * 
dstMaxNumSequences + seqIdx) * seqDimStride; + uint32_t const dstIdxVBeg = dstIdxKBeg + numBlocksPerSeq; + + uint32_t const srcIdxEnd = (poolIdx * srcMaxNumSequences + copyIndex[seqIdx]) * seqDimStride + numBlocksPerSeq; + + for (uint32_t i = 0; i < iterPerSeq + copyBlocknbBufs; i++) + { + uint32_t const idxBuf = i % copyBlocknbBufs; + if (i >= copyBlocknbBufs) + { + uint32_t const stIter = i - copyBlocknbBufs; + assert(idxBuf == (stIter % copyBlocknbBufs)); + auto const offset = copyBlockCtaSize * stIter * elemPerAccess; + SizeType32 const srcIdx = srcIdxBeg + offset; + SizeType32 const dstIdxK = dstIdxKBeg + offset; + SizeType32 const dstIdxV = dstIdxVBeg + offset; + PackedInt const& src = data[idxBuf][tid]; + PackedInt& dstK = *reinterpret_cast(dstPtr + dstIdxK); + PackedInt& dstV = *reinterpret_cast(dstPtr + dstIdxV); + asm volatile("cp.async.wait_group %0;\n" ::"n"(copyBlocknbBufs - 1) : "memory"); + if (srcIdx < srcIdxEnd) + { + dstK = src; + if (COPY_V_IDX) + { + dstV = src; + } + else + { +#pragma unroll + for (uint32_t j = 0; j < elemPerAccess; j++) + { + auto const val = src.unpacked[j]; + dstV.unpacked[j] = (val == BAD_PAGE_INDEX) ? val : (val + 1); + } + } + } + } + uint32_t const ldIter = i; + PackedInt* const dst = &data[idxBuf][tid]; + uint32_t const srcIdx = srcIdxBeg + copyBlockCtaSize * ldIter * elemPerAccess; + PackedInt const* const src = reinterpret_cast(srcPtr + srcIdx); + if (srcIdx < srcIdxEnd) + { + uint32_t const size = sizeof(PackedInt); + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"l"(__cvta_generic_to_shared(dst)), + "l"(src), "n"(size), "r"(size) + : "memory"); + } + asm volatile("cp.async.commit_group;\n" : : : "memory"); + } +} + +// Host-side launcher +void copyBatchBlockOffsetsToDevice( + ITensor const& input, ITensor& output, ITensor const& copyIndex, bool copyVIdx, CUstream stream) noexcept +{ + using namespace tensorrt_llm::runtime; + + auto const* srcPtr = bufferCast(input); + auto* dstPtr = bufferCast( + output); // [numPools, maxNumSequences, kvFactor, numBlocksPerSeq] + auto const* copyIndexPtr = bufferCast(copyIndex); + auto const& srcShape = input.getShape(); + auto const& dstShape = output.getShape(); + auto const& copyIndexShape = copyIndex.getShape(); + + TLLM_CHECK(srcShape.nbDims == 4); // [numPools, srcMaxNumSequences, kvFactor, numBlocksPerSeq] + TLLM_CHECK(dstShape.nbDims == 4); // [numPools, dstMaxNumSequences, kvFactor, numBlocksPerSeq] + + SizeType32 numPools = srcShape.d[0]; + SizeType32 srcMaxNumSequences = srcShape.d[1]; + SizeType32 dstMaxNumSequences = dstShape.d[1]; + SizeType32 numBlocksPerSeq = srcShape.d[3]; + SizeType32 numSeqs = copyIndexShape.d[0]; + + if (numSeqs == 0) + { + return; + } + + TLLM_CHECK_WITH_INFO((numBlocksPerSeq * sizeof(SizeType32)) % sizeof(PackedInt) == 0, + "Not implemented case: numBlocksPerSeq * sizeof(SizeType32) = %zu must be a multiple of %zu.", + static_cast(numBlocksPerSeq * sizeof(SizeType32)), static_cast(sizeof(PackedInt))); + + dim3 gridDim(numPools, numSeqs, 1); + dim3 blockDim(copyBlockCtaSize); + + if (copyVIdx) + { + copyBatchBlockOffsetsToDeviceKernel<<>>( + srcPtr, dstPtr, srcMaxNumSequences, dstMaxNumSequences, numBlocksPerSeq, copyIndexPtr); + } + else + { + copyBatchBlockOffsetsToDeviceKernel<<>>( + srcPtr, dstPtr, srcMaxNumSequences, dstMaxNumSequences, numBlocksPerSeq, copyIndexPtr); + } +} + } // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h 
b/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h index 2acb81e522..7ff742073f 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h @@ -17,10 +17,21 @@ #pragma once +#include "tensorrt_llm/batch_manager/llmRequest.h" +#include "tensorrt_llm/kernels/kvCacheIndex.h" +#include "tensorrt_llm/runtime/iBuffer.h" +#include "tensorrt_llm/runtime/iTensor.h" +#include #include #include +#include +#include #include +namespace tk = tensorrt_llm::kernels; +using SizeType32 = tensorrt_llm::runtime::SizeType32; +using ITensor = tensorrt_llm::runtime::ITensor; + namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 { struct DiskAddress @@ -31,6 +42,9 @@ struct DiskAddress using MemAddress = std::uintptr_t; +// Please make sure to align with the definition in tensorrt_llm/runtime/kv_cache_manager_v2/_common.py +constexpr tk::KVCacheIndex::UnderlyingType BAD_PAGE_INDEX = -1; + template struct Task { @@ -38,6 +52,38 @@ struct Task SrcAddr src; }; +using PackedInt = union +{ + int4 packed; + tk::KVCacheIndex::UnderlyingType unpacked[4]; +}; + +class IndexMapper +{ +public: + IndexMapper(SizeType32 maxBatchSize, SizeType32 maxBeamWidth); + + ~IndexMapper(); + + IndexMapper(IndexMapper const&) = delete; + IndexMapper& operator=(IndexMapper const&) = delete; + + SizeType32 addNewSequence(LlmRequest::RequestIdType requestId); + + SizeType32 getIndex(LlmRequest::RequestIdType requestId); + + void removeSequence(LlmRequest::RequestIdType requestId); + + at::Tensor getCopyIndex( + std::vector const& requestIds, SizeType32 numContext, SizeType32 beamWidth); + +private: + std::unordered_map indexMap_; + std::set freeIndices_; + SizeType32 maxBeamWidth_; + at::Tensor copyIndex_; +}; + CUresult copyDiskToDisk(std::vector> tasks, ssize_t numBytes, CUstream stream) noexcept; CUresult copyDiskToHost(std::vector> tasks, ssize_t numBytes, CUstream stream) noexcept; CUresult copyHostToDisk(std::vector> tasks, ssize_t numBytes, CUstream stream) noexcept; @@ -48,4 +94,8 @@ CUresult copyDeviceToHost( std::vector> const& tasks, ssize_t numBytes, CUstream stream) noexcept; CUresult copyDeviceToDevice( std::vector> const& tasks, ssize_t numBytes, CUstream stream) noexcept; + +void copyBatchBlockOffsetsToDevice( + ITensor const& input, ITensor& output, ITensor const& copyIndex, bool copyVIdx, CUstream stream) noexcept; + } // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.cpp index 0985d0299e..d1376476c2 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.cpp @@ -17,14 +17,32 @@ #include "kvCacheManagerV2Utils.h" #include "tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h" +#include "tensorrt_llm/nanobind/common/customCasters.h" +#include "tensorrt_llm/runtime/iTensor.h" +#include "tensorrt_llm/runtime/torchView.h" +#include #include +#include #include +#include +namespace tr = tensorrt_llm::runtime; namespace nb = nanobind; +using SizeType32 = tensorrt_llm::runtime::SizeType32; + namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 { +std::optional from_torch(std::optional torchPtr) +{ + if (torchPtr) + { + return tr::TorchView::of(torchPtr.value()); + } + return std::nullopt; +} + void KVCacheManagerV2UtilsBindings::initBindings(nb::module_& module) { // Bind DiskAddress struct @@ -54,6 
+72,13 @@ void KVCacheManagerV2UtilsBindings::initBindings(nb::module_& module) .def_rw("dst", &Task::dst) .def_rw("src", &Task::src); + nb::class_(module, "IndexMapper") + .def(nb::init(), nb::arg("max_batch_size"), nb::arg("max_beam_width")) + .def("add_new_sequence", &IndexMapper::addNewSequence) + .def("get_index", &IndexMapper::getIndex) + .def("remove_sequence", &IndexMapper::removeSequence) + .def("get_copy_index", &IndexMapper::getCopyIndex); + // Bind copy functions module.def( "copy_disk_to_disk", @@ -103,6 +128,22 @@ void KVCacheManagerV2UtilsBindings::initBindings(nb::module_& module) { return copyDeviceToDevice(tasks, numBytes, reinterpret_cast(stream)); }, nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard(), "Copy data from device to device using CUDA kernels"); + + module.def( + "copy_batch_block_offsets_to_device", + [](at::Tensor input, at::Tensor output, at::Tensor copyIndex, bool copyVIdx, uintptr_t stream) + { + auto _input = from_torch(input); + auto _output = from_torch(output); + auto _copyIndex = from_torch(copyIndex); + TLLM_CHECK_WITH_INFO(_input.has_value(), "Invalid input tensor."); + TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor."); + TLLM_CHECK_WITH_INFO(_copyIndex.has_value(), "Invalid copy index tensor."); + copyBatchBlockOffsetsToDevice(*(_input.value()), *(_output.value()), *(_copyIndex.value()), copyVIdx, + reinterpret_cast(stream)); + }, + nb::arg("input"), nb::arg("output"), nb::arg("copy_index"), nb::arg("copy_v_idx"), nb::arg("stream"), + nb::call_guard(), "Copy batch block indices to device"); } } // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 diff --git a/cpp/tensorrt_llm/nanobind/executor/bindings.cpp b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp index 388af63cac..f8e69fa1ad 100644 --- a/cpp/tensorrt_llm/nanobind/executor/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/bindings.cpp @@ -63,15 +63,15 @@ void initBindings(nb::module_& m) new (&self) tle::DecodingMode(nb::cast(state[0])); }; nb::class_(m, "DecodingMode") - .def("Auto", &tle::DecodingMode::Auto) - .def("TopK", &tle::DecodingMode::TopK) - .def("TopP", &tle::DecodingMode::TopP) - .def("TopKTopP", &tle::DecodingMode::TopKTopP) - .def("BeamSearch", &tle::DecodingMode::BeamSearch) - .def("Medusa", &tle::DecodingMode::Medusa) - .def("Lookahead", &tle::DecodingMode::Lookahead) - .def("ExplicitDraftTokens", &tle::DecodingMode::ExplicitDraftTokens) - .def("Eagle", &tle::DecodingMode::Eagle) + .def_static("Auto", &tle::DecodingMode::Auto) + .def_static("TopK", &tle::DecodingMode::TopK) + .def_static("TopP", &tle::DecodingMode::TopP) + .def_static("TopKTopP", &tle::DecodingMode::TopKTopP) + .def_static("BeamSearch", &tle::DecodingMode::BeamSearch) + .def_static("Medusa", &tle::DecodingMode::Medusa) + .def_static("Lookahead", &tle::DecodingMode::Lookahead) + .def_static("ExplicitDraftTokens", &tle::DecodingMode::ExplicitDraftTokens) + .def_static("Eagle", &tle::DecodingMode::Eagle) .def("isAuto", &tle::DecodingMode::isAuto) .def("isTopK", &tle::DecodingMode::isTopK) .def("isTopP", &tle::DecodingMode::isTopP) diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index f31854782a..c1351c1754 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -109,6 +109,12 @@ def add_llm_args(parser): parser.add_argument('--log_kv_cache_events', default=False, action='store_true') + parser.add_argument( + '--use_kv_cache_manager_v2', + default=False, + 
action='store_true', + help='Use KVCacheManagerV2 for KV cache management (PyTorch backend).', + ) # Runtime parser.add_argument('--disable_overlap_scheduler', @@ -214,6 +220,7 @@ def setup_llm(args, **kwargs): free_gpu_memory_fraction=args.kv_cache_fraction, dtype=args.kv_cache_dtype, tokens_per_block=args.tokens_per_block, + use_kv_cache_manager_v2=args.use_kv_cache_manager_v2, mamba_ssm_cache_dtype=args.mamba_ssm_cache_dtype, event_buffer_max_size=1024 if args.log_kv_cache_events else 0) diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index 71f7c8be67..cfa33def9d 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -21,7 +21,7 @@ from tensorrt_llm.models.modeling_utils import QuantConfig from ..memory_buffer_utils import Buffers from ..metadata import KVCacheParams -from ..pyexecutor.resource_manager import KVCacheManager +from ..pyexecutor.resource_manager import KVCacheManager, KVCacheManagerV2 from ..utils import get_model_extra_attrs try: @@ -63,7 +63,7 @@ class AttentionMetadata: # The max number of sequences in a single batch. max_num_sequences: Optional[int] = None # The KV cache manager. - kv_cache_manager: KVCacheManager + kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2] mapping: Optional[Mapping] = None enable_flash_mla: bool = False diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 898fd1575b..4fea1e0b4e 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -34,12 +34,14 @@ from .llm_request import ExecutorResponse from .mamba_cache_manager import MambaHybridCacheManager from .model_engine import PyTorchModelEngine from .py_executor import PyExecutor -from .resource_manager import (KVCacheManager, PeftCacheManager, - ResourceManager, ResourceManagerType) +from .resource_manager import (KVCacheManager, KVCacheManagerV2, + PeftCacheManager, ResourceManager, + ResourceManagerType) from .sampler import (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler, TRTLLMSampler) from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler, - SimpleScheduler, SimpleUnifiedScheduler) + KVCacheV2DummyScheduler, SimpleScheduler, + SimpleUnifiedScheduler) from .seq_slot_manager import SeqSlotManager GB = 1 << 30 @@ -99,6 +101,8 @@ class KvCacheCreator: self._kv_cache_manager_cls = get_kv_cache_manager_cls( model_engine.model.model_config) self._execution_stream = execution_stream + if self._kv_cache_manager_cls == KVCacheManager and kv_cache_config.use_kv_cache_manager_v2: + self._kv_cache_manager_cls = KVCacheManagerV2 def _get_kv_size_per_token(self): model_config = self._model_engine.model.model_config @@ -583,6 +587,7 @@ def _create_kv_cache_manager( mapping=mapping, dtype=kv_cache_dtype, spec_config=spec_config, + vocab_size=config.vocab_size, max_beam_width=max_beam_width, is_draft=model_engine.is_draft_model, kv_connector_manager=kv_connector_manager @@ -704,6 +709,7 @@ def _create_kv_cache_manager( mapping=mapping, dtype=kv_cache_dtype, spec_config=spec_config, + vocab_size=config.vocab_size, max_num_tokens=max_num_tokens, model_config=binding_model_config, max_beam_width=max_beam_width, @@ -855,7 +861,8 @@ def create_py_executor_instance( scheduler_capacity += 1 use_python_scheduler = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1" - if use_python_scheduler: + if use_python_scheduler and not isinstance(kv_cache_manager, + 
KVCacheManagerV2): scheduler = SimpleUnifiedScheduler( max_batch_size=max_batch_size, max_num_tokens=max_num_tokens, @@ -868,12 +875,19 @@ def create_py_executor_instance( two_step_lookahead=mapping.has_pp(), scheduler_capacity=scheduler_capacity) else: - capacity_scheduler = BindCapacityScheduler( - scheduler_capacity, - kv_cache_manager.impl if kv_cache_manager is not None else None, - peft_cache_manager.impl if peft_cache_manager is not None else None, - scheduler_config.capacity_scheduler_policy, - two_step_lookahead=mapping.has_pp()) + if isinstance(kv_cache_manager, KVCacheManagerV2): + capacity_scheduler = KVCacheV2DummyScheduler( + scheduler_capacity, + kv_cache_manager if kv_cache_manager is not None else None) + else: + capacity_scheduler = BindCapacityScheduler( + scheduler_capacity, + kv_cache_manager.impl if kv_cache_manager is not None else None, + peft_cache_manager.impl + if peft_cache_manager is not None else None, + scheduler_config.capacity_scheduler_policy, + two_step_lookahead=mapping.has_pp()) + mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens, ctx_chunk_config) scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler) diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 30fdaff391..9b0b51fa31 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -434,17 +434,17 @@ class CUDAGraphRunner: # This is not strictly required, but we should probably # respect the requirement just in case that changes in the future. if self.padding_dummy_request is None: - available_blocks = kv_cache_manager.get_num_free_blocks() - # No padding if not enough KV cache space - if available_blocks < 1: - return 0 self.padding_dummy_request = kv_cache_manager.add_dummy_requests( [CUDA_GRAPH_DUMMY_REQUEST_ID], is_gen=True, max_num_draft_tokens=runtime_draft_len, use_mrope=self.config.use_mrope, - max_beam_width=self.config.max_beam_width)[0] + max_beam_width=self.config.max_beam_width) + if self.padding_dummy_request is None: + return 0 + else: + self.padding_dummy_request = self.padding_dummy_request[0] self.padding_dummy_request.is_cuda_graph_dummy = True spec_res_mgr = resource_manager.get_resource_manager( ResourceManagerType.SPEC_RESOURCE_MANAGER) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 603ccc1e67..110b05b4f6 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -8,7 +8,7 @@ import os import weakref from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch._dynamo.config @@ -64,8 +64,8 @@ from .layerwise_nvtx_marker import LayerwiseNvtxMarker from .llm_request import LlmRequest, get_draft_token_length from .model_loader import ModelLoader, _construct_checkpoint_loader from .resource_manager import (BaseResourceManager, KVCacheManager, - PeftCacheManager, ResourceManager, - ResourceManagerType) + KVCacheManagerV2, PeftCacheManager, + ResourceManager, ResourceManagerType) from .sampler import SampleStateTensors from .scheduler import ScheduledRequests @@ -668,8 +668,8 @@ class PyTorchModelEngine(ModelEngine): self.kv_cache_manager_key) curr_max_num_tokens = min( kv_cache_manager.get_num_available_tokens( - 
self.original_max_draft_len), self.max_num_tokens, - self.batch_size * (self.max_seq_len - 1)) + max_num_draft_tokens=self.original_max_draft_len), + self.max_num_tokens, self.batch_size * (self.max_seq_len - 1)) max_batch_size = min( self.batch_size, curr_max_num_tokens // (1 + self.runtime_draft_len)) @@ -720,8 +720,8 @@ class PyTorchModelEngine(ModelEngine): self.kv_cache_manager_key) curr_max_num_tokens = min( kv_cache_manager.get_num_available_tokens( - self.original_max_draft_len), self.max_num_tokens, - self.batch_size * (self.max_seq_len - 1)) + max_num_draft_tokens=self.original_max_draft_len), + self.max_num_tokens, self.batch_size * (self.max_seq_len - 1)) cache_path = os.environ.get("TLLM_AUTOTUNER_CACHE_PATH", None) with self.no_cuda_graph(), autotune(cache_path=cache_path): @@ -945,7 +945,7 @@ class PyTorchModelEngine(ModelEngine): ResourceManagerType.SPEC_RESOURCE_MANAGER) available_tokens = kv_cache_manager.get_num_available_tokens( - self.runtime_draft_len) + max_num_draft_tokens=self.runtime_draft_len) available_blocks = kv_cache_manager.get_num_free_blocks() if num_tokens > self.max_num_tokens or num_tokens > available_tokens: return None @@ -998,7 +998,8 @@ class PyTorchModelEngine(ModelEngine): num_left_over_tokens / kv_cache_manager.tokens_per_block) + num_gen_requests - if blocks_to_use > available_blocks: + if blocks_to_use > available_blocks and isinstance( + kv_cache_manager, KVCacheManager): return None if num_ctx_tokens > 0: @@ -1014,6 +1015,9 @@ class PyTorchModelEngine(ModelEngine): use_mrope=self.use_mrope, num_extra_decoding_steps=num_extra_decoding_steps) + if ctx_requests is None: + return None + if spec_resource_manager is not None: spec_resource_manager.add_dummy_requests( request_ids=list(range(num_ctx_requests))) @@ -1029,6 +1033,12 @@ class PyTorchModelEngine(ModelEngine): use_mrope=self.use_mrope, max_beam_width=self.max_beam_width, num_extra_decoding_steps=num_extra_decoding_steps) + + if gen_requests is None: + for r in ctx_requests: + kv_cache_manager.free_resources(r) + return None + if spec_resource_manager is not None: spec_resource_manager.add_dummy_requests(request_ids=list( range(num_ctx_requests, num_ctx_requests + @@ -1069,7 +1079,11 @@ class PyTorchModelEngine(ModelEngine): max_beam_width=self.max_beam_width, num_extra_decoding_steps=num_extra_decoding_steps) - available_tokens = kv_cache_manager.get_num_available_tokens(draft_len) + if requests is None: + return None + + available_tokens = kv_cache_manager.get_num_available_tokens( + batch_size=batch_size, max_num_draft_tokens=draft_len) # Add one dummy request with the maximum possible sequence length. max_seq_len = min( @@ -1098,7 +1112,14 @@ class PyTorchModelEngine(ModelEngine): max_num_draft_tokens=draft_len, use_mrope=self.use_mrope, max_beam_width=self.max_beam_width, - num_extra_decoding_steps=num_extra_decoding_steps)[0] + num_extra_decoding_steps=num_extra_decoding_steps) + + if max_seq_len_request is None: + for r in requests: + kv_cache_manager.free_resources(r) + return None + else: + max_seq_len_request = max_seq_len_request[0] # Insert the longest request first to simulate padding for the CUDA graph. 
requests.insert(0, max_seq_len_request) @@ -1122,7 +1143,8 @@ class PyTorchModelEngine(ModelEngine): req.py_is_first_draft = True req.py_draft_tokens = [] - def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager): + def _set_up_attn_metadata(self, kv_cache_manager: Union[KVCacheManager, + KVCacheManagerV2]): enable_context_mla_with_cached_kv = is_mla( self.model.model_config.pretrained_config) and ( self.attn_runtime_features.cache_reuse @@ -1529,7 +1551,7 @@ class PyTorchModelEngine(ModelEngine): def _apply_incremental_update( self, scheduled_requests: ScheduledRequests, - kv_cache_manager: KVCacheManager, + kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2], attn_metadata: AttentionMetadata, spec_metadata: Optional[SpecMetadata] = None, new_tensors_device: Optional[SampleStateTensors] = None, @@ -1961,7 +1983,7 @@ class PyTorchModelEngine(ModelEngine): def _prepare_tp_inputs( self, scheduled_requests: ScheduledRequests, - kv_cache_manager: KVCacheManager, + kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2], attn_metadata: AttentionMetadata, spec_metadata: Optional[SpecMetadata] = None, new_tensors_device: Optional[SampleStateTensors] = None, @@ -3306,7 +3328,7 @@ class PyTorchModelEngine(ModelEngine): def _prepare_inputs( self, scheduled_requests: ScheduledRequests, - kv_cache_manager: KVCacheManager, + kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2], attn_metadata: AttentionMetadata, spec_metadata: Optional[SpecMetadata] = None, new_tensors_device: Optional[SampleStateTensors] = None, diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index a55e7d3402..c915ea3f4a 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -3,20 +3,40 @@ import enum import math from abc import ABC, abstractmethod from collections import OrderedDict, defaultdict, deque -from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, - Union) +from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, + Set, Tuple, Union) +import numpy as np import torch import tensorrt_llm import tensorrt_llm.bindings from tensorrt_llm._torch.distributed.communicator import Distributed, ReduceOp +from tensorrt_llm._utils import (TensorWrapper, convert_to_torch_tensor, + get_size_in_bytes) +from tensorrt_llm.bindings.internal.batch_manager.kv_cache_manager_v2_utils import ( + IndexMapper, copy_batch_block_offsets_to_device) from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig, PybindMirror) from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig +from tensorrt_llm.math_utils import ceil_div from tensorrt_llm.runtime import ModelConfig as ModelConfigPython +from tensorrt_llm.runtime.kv_cache_manager_v2 import (AttentionLayerConfig, + BufferConfig, + GpuCacheTierConfig, + HostCacheTierConfig) +from tensorrt_llm.runtime.kv_cache_manager_v2 import \ + KVCacheManager as KVCacheManagerPy +from tensorrt_llm.runtime.kv_cache_manager_v2 import \ + KVCacheManagerConfig as KVCacheManagerConfigPy +from tensorrt_llm.runtime.kv_cache_manager_v2 import (LayerId, TokenIdExt, + _KVCache) +from tensorrt_llm.runtime.kv_cache_manager_v2._common import GPU_LEVEL +from tensorrt_llm.runtime.kv_cache_manager_v2._config import DataRole +from tensorrt_llm.runtime.kv_cache_manager_v2._utils import 
(exact_div, + typed_range) from tensorrt_llm.sampling_params import SamplingParams from ..._utils import (binding_to_str_dtype, get_size_in_bytes, mpi_rank, @@ -56,6 +76,14 @@ class ResourceManagerType(enum.Enum): SPEC_RESOURCE_MANAGER = "SPEC_RESOURCE_MANAGER" +class Role: + KEY = DataRole("key") + VALUE = DataRole("value") + KEY_BLOCK_QUANT = DataRole("key_block_quant") + VALUE_BLOCK_QUANT = DataRole("value_block_quant") + ALL = DataRole("all") + + def compute_page_count(token_count: int, tokens_per_page: int) -> int: return (token_count + tokens_per_page) // tokens_per_page @@ -523,6 +551,11 @@ class KVCacheManager(BaseResourceManager): # occur. num_extra_decoding_steps: int = 0, ): + available_blocks = self.get_num_free_blocks() + # No padding if not enough KV cache space + if available_blocks < 1: + return None + beam_width = max_beam_width requests = [] for i, req_id in enumerate(request_ids): @@ -870,12 +903,16 @@ class KVCacheManager(BaseResourceManager): def get_batch_cache_indices( self, request_ids: List[int], - window_size: Optional[int] = None, + layer_idx: Optional[int] = None, ) -> List[List[int]]: - if window_size is None: + if layer_idx is None: if len(self.max_attention_window_vec) > 1: - raise ValueError("window_size must be provided for VSWA") + raise ValueError("layer_idx must be provided for VSWA") window_size = self.max_attention_window_vec[0] + else: + layer_offset = self.layer_offsets[layer_idx] + window_size = self.max_attention_window_vec[layer_offset % len( + self.max_attention_window_vec)] result = self.impl.get_batch_cache_block_ids(request_ids, window_size) for i in range(len(result)): @@ -896,7 +933,9 @@ class KVCacheManager(BaseResourceManager): def get_num_kv_blocks(self, num_tokens: int) -> int: return (num_tokens + self.tokens_per_block - 1) // self.tokens_per_block - def get_num_available_tokens(self, max_num_draft_tokens: int = 0) -> int: + def get_num_available_tokens(self, + max_num_draft_tokens: int = 0, + **kwargs) -> int: return (self.get_num_free_blocks() * self.tokens_per_block - self.num_extra_kv_tokens - max_num_draft_tokens) @@ -1326,6 +1365,722 @@ class KVCacheManager(BaseResourceManager): self.impl.reset_reuse_state() +class KVCacheManagerV2(BaseResourceManager): + + def __init__( + self, + kv_cache_config: KvCacheConfig, + kv_cache_type: CacheTypeCpp, + *, + num_layers: int, + num_kv_heads: Union[int, List[Optional[int]]], + head_dim: int, + tokens_per_block: int, + # Note that max_seq_len is not necessarily equal to kv_cache_config.num_tokens. + # It's derived from the model's BuildConfig for consistency with the C++ backend. 
+ max_seq_len: int, + max_batch_size: int, + mapping: Mapping, + dtype: DataType = DataType.HALF, + spec_config=None, + layer_mask: Optional[List[bool]] = None, + vocab_size: int = None, + max_num_tokens: int = 8192, + model_config: Optional[ModelConfigCpp] = None, + max_beam_width: int = 1, + is_draft: bool = False, + kv_connector_manager: Optional[KvCacheConnectorManager] = None, + **kwargs, + ) -> None: + self.mapping = mapping + self.dtype = dtype + + assert self.dtype != DataType.NVFP4, "NVFP4 is not supported for KVCacheManagerV2" + assert kv_connector_manager is None, "kv_connector_manager is not supported for KVCacheManagerV2" + assert max_beam_width == 1, "max_beam_width must be 1 for KVCacheManagerV2" + + self.kv_cache_type = kv_cache_type + self.pp_layers, self.num_layers = get_pp_layers( + num_layers, + mapping, + spec_config=spec_config, + layer_mask=layer_mask, + ) + self.is_draft = is_draft + self.num_local_layers = len(self.pp_layers) + self.layer_offsets = { + idx: offset + for offset, idx in enumerate(self.pp_layers) + } + self.max_beam_width = max_beam_width + + tp_size = mapping.tp_size + if mapping.enable_attention_dp: + tp_size = 1 + + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.tokens_per_block = tokens_per_block + self.max_seq_len = max_seq_len + self.max_batch_size = max_batch_size + self.kv_factor = 1 if kv_cache_type == CacheTypeCpp.SELFKONLY else 2 + from ..speculative import get_num_extra_kv_tokens + self.num_extra_kv_tokens = get_num_extra_kv_tokens(spec_config) + + self.event_buffer_max_size = kv_cache_config.event_buffer_max_size + + assert self.event_buffer_max_size == 0, "event_buffer_max_size must be 0" + + # Determine max_attention_window_vec + if kv_cache_config.max_attention_window is not None: + + self.max_attention_window_vec = kv_cache_config.max_attention_window.copy( + ) # Make a copy to avoid modifying original + # Clamp all window sizes to max_seq_len before calculating the + # number of KV cache blocks. This prevents the KV cache pool from + # being skewed by the largest window values. 
+ self.max_attention_window_vec = [ + min(max_seq_len, w) for w in self.max_attention_window_vec + ] + + self.max_attention_window_vec = [ + None if w == max_seq_len else w + for w in self.max_attention_window_vec + ] + + else: + self.max_attention_window_vec = [None] + + if isinstance(num_kv_heads, int): + self.num_kv_heads_per_layer = [ + (num_kv_heads + tp_size - 1) // tp_size + for _ in range(self.num_local_layers) + ] + self.total_num_kv_heads_per_layer = [ + (num_kv_heads + tp_size - 1) // tp_size + for _ in range(self.num_layers) + ] + else: + assert len(num_kv_heads) == self.num_layers + + def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], + kv_head: Optional[int]): + if kv_head is not None: + num_kv_heads_per_layer.append( + (kv_head + tp_size - 1) // tp_size) + else: + num_kv_heads_per_layer.append(0) + + self.num_kv_heads_per_layer = [] + if self.num_local_layers > 0: + for i in self.pp_layers: + kv_head = num_kv_heads[i] + append_to_kv_heads_per_layer(self.num_kv_heads_per_layer, + kv_head) + + self.total_num_kv_heads_per_layer = [] + for i in range(self.num_layers): + kv_head = num_kv_heads[i] + append_to_kv_heads_per_layer(self.total_num_kv_heads_per_layer, + kv_head) + + self.is_vswa = len(set(self.max_attention_window_vec)) > 1 + + self.kv_connector_manager = kv_connector_manager + + quota = float('inf') + + if kv_cache_config.max_tokens is not None: + quota = int( + ceil_div( + kv_cache_config.max_tokens * + self.get_cache_bytes_per_token(), + kv_cache_config.max_util_for_resume)) + if kv_cache_config.free_gpu_memory_fraction is not None: + logger.warning( + f"Both max_tokens and free_gpu_memory_fraction are set to {kv_cache_config.max_tokens} and {kv_cache_config.free_gpu_memory_fraction}, the smaller value will be used." + ) + if kv_cache_config.max_gpu_total_bytes is not None and kv_cache_config.max_gpu_total_bytes > 0: + if quota > int(kv_cache_config.max_gpu_total_bytes): + logger.warning( + f"max_gpu_total_bytes {kv_cache_config.max_gpu_total_bytes / (1 << 30)}GiB is smaller than the calculated quota {quota / (1 << 30)}GiB, clamping quota to {kv_cache_config.max_gpu_total_bytes / (1 << 30)}GiB" + ) + quota = min(quota, int(kv_cache_config.max_gpu_total_bytes)) + + assert quota != float( + 'inf' + ), "Quota not set. 
Check kv_cache_config.max_tokens or kv_cache_config.max_gpu_total_bytes" + logger.info( + f"KV cache manager v2 device quota set to {quota / (1 << 30)}GiB") + + cache_tiers = [GpuCacheTierConfig(quota=quota)] + if kv_cache_config.host_cache_size is not None and kv_cache_config.host_cache_size > 0: + cache_tiers.append( + HostCacheTierConfig(quota=kv_cache_config.host_cache_size)) + logger.info( + f"KV cache manager v2 host cache quota set to {kv_cache_config.host_cache_size / (1 << 30)}GiB" + ) + + buffer_type = [Role.KEY] + if kv_cache_type != CacheTypeCpp.SELFKONLY: + buffer_type.append(Role.VALUE) + + config = KVCacheManagerConfigPy( + tokens_per_block=tokens_per_block, + vocab_size=vocab_size, + cache_tiers=cache_tiers, + max_util_for_resume=kv_cache_config.max_util_for_resume, + layers=[ + AttentionLayerConfig( + layer_id=layer_id, + buffers=[ + BufferConfig( + role=role, + size=self.get_cache_bytes_per_token( + local_layer_idx=layer_id, data_role=role) * + tokens_per_block, + ) for role in buffer_type + ], + sliding_window_size=self.max_attention_window_vec[ + layer_id % len(self.max_attention_window_vec)], + num_sink_tokens=None, + ) for layer_id in typed_range(LayerId(self.num_local_layers)) + ], + ) + + self.kv_cache_manager_py_config = config + + self.impl = KVCacheManagerPy(config) + + self.num_pools = len(self.impl.layer_grouping) + + self.layer_to_pool_mapping_dict: dict[int, int] = { + layer_id: self.impl.get_layer_group_id(layer_id) + for layer_id in typed_range(LayerId(self.num_local_layers)) + } + + self.kv_cache_pool_pointers = torch.tensor([[ + self.impl.get_mem_pool_base_address( + self.impl.layer_grouping[pool_id][0], Role.KEY), 0 + ] for pool_id in range(self.num_pools)], + dtype=torch.int64, + device="cpu", + pin_memory=True) + + kv_cache_pool_mapping_list = [] + for layer_id in typed_range(LayerId(self.num_local_layers)): + layer_group_id = self.impl.get_layer_group_id(layer_id) + offset = exact_div( + self.impl.get_mem_pool_base_address(layer_id, Role.KEY) - + int(self.kv_cache_pool_pointers[layer_group_id][0]), + self.get_cache_bytes_per_token(layer_id, Role.KEY) * + self.kv_factor * self.tokens_per_block) + kv_cache_pool_mapping_list.append([layer_group_id, offset]) + + self.kv_cache_pool_mapping = torch.tensor(kv_cache_pool_mapping_list, + dtype=torch.int32, + device="cpu", + pin_memory=True) + # Pad max_blocks_per_seq to next multiple of 4 for copy_block_offsets kernel + self.max_blocks_per_seq = (max_seq_len + tokens_per_block - + 1) // tokens_per_block + if self.max_blocks_per_seq % 4 != 0: + self.max_blocks_per_seq = ((self.max_blocks_per_seq + 3) // 4) * 4 + + self.kv_cache_map: dict[int, _KVCache] = {} + + max_num_tokens = self.get_num_available_tokens() + + if max_seq_len > max_num_tokens: + logger.warning( + f"max_seq_len {max_seq_len} is greater than max_num_tokens {max_num_tokens} that can be allocated in kv cache manager, setting max_seq_len to {max_num_tokens}" + ) + self.max_seq_len = max_num_tokens + + self.enable_block_reuse = kv_cache_config.enable_block_reuse + + # Plus 1 for cuda graph dummy request + self.index_mapper = IndexMapper(max_batch_size + 1, max_beam_width) + + self.host_kv_cache_block_offsets = torch.empty( + self.num_pools, + (max_batch_size + 1) * max_beam_width, + 2, # key and value + self.max_blocks_per_seq, + dtype=torch.int32, + pin_memory=True, + device='cpu') + + @property + def blocks_in_primary_pool(self) -> int: + """ + Get the number of blocks in the primary pool. 
+ """ + return self.impl.get_page_index_upper_bound(0, Role.KEY) + + def get_buffers(self, + layer_idx: int, + kv_layout: str = "NHD") -> Optional[torch.Tensor]: + layer_offset = self.layer_offsets[layer_idx] + addr_key = self.impl.get_mem_pool_base_address(layer_offset, Role.KEY) + if self.kv_cache_type != CacheTypeCpp.SELFKONLY: + addr_value = self.impl.get_mem_pool_base_address( + layer_offset, Role.VALUE) + page_size_key = self.impl.get_page_stride(layer_offset, Role.KEY) + page_size_value = self.impl.get_page_stride(layer_offset, + Role.VALUE) + + assert addr_key + page_size_value == addr_value and page_size_key == page_size_value + + assert kv_layout in ["NHD", + "HND"], f"Unsupported kv_layout: {kv_layout}" + + if kv_layout == "NHD": + shape = [ + self.impl.get_page_index_upper_bound(layer_offset, Role.KEY) // + self.kv_factor, + self.kv_factor, + self.tokens_per_block, + self.num_kv_heads_per_layer[layer_offset], + self.head_dim, + ] + else: + shape = [ + self.impl.get_page_index_upper_bound(layer_offset, Role.KEY) // + self.kv_factor, + self.kv_factor, + self.num_kv_heads_per_layer[layer_offset], + self.tokens_per_block, + self.head_dim, + ] + + return convert_to_torch_tensor( + TensorWrapper( + addr_key, + self.dtype, + shape, + )) + + def get_num_available_tokens(self, + *, + batch_size: int = 1, + max_num_draft_tokens: int = 0) -> int: + if max_num_draft_tokens > 0: + raise ValueError( + "max_num_draft_tokens is not supported for KVCacheManagerV2") + return int( + self.impl.clamp_max_seq_len_for_mem(batch_size) * + self.kv_cache_manager_py_config.max_util_for_resume + ) - self.num_extra_kv_tokens - max_num_draft_tokens + + def get_num_free_blocks(self) -> int: + # NOTE This method is used to get the number of blocks in the primary pool not the FREE blocks. + # However, since we only use this function when the kv cache manager is empty, so it is safe to do so. + assert len( + self.kv_cache_map + ) == 0, "get_num_free_blocks is only used when the kv cache manager is empty" + max_num_pages = max([ + self.impl.get_page_index_upper_bound(layer_id, Role.KEY) + for layer_id in typed_range(LayerId(self.num_local_layers)) + ]) + return max_num_pages // self.kv_factor + + @nvtx_range("prepare_resources_kv_cache_manager_v2") + def prepare_resources(self, scheduled_batch: ScheduledRequests): + with request_context(self.is_draft, scheduled_batch): + context_batch = scheduled_batch.context_requests + generation_batch = scheduled_batch.generation_requests + # allocate KV Cache + for req in context_batch: + beam_width = req.sampling_config.beam_width + if 'cp_type' in self.mapping.cp_config and CpType.STAR == self.mapping.cp_config[ + 'cp_type']: + raise RuntimeError( + "Star attention is not supported for kv cache manager v2" + ) + else: + if req.is_first_context_chunk and self._kv_connector_should_add_sequence( + req): + # Last token cannot be recovered, so we don't include it in the input tokens to look up for the block that can be reused. 
+ kv_cache = self._create_kv_cache( + req.py_request_id, req.lora_task_id, + req.get_tokens(0)[:-1] + if self.enable_block_reuse else None) + assert beam_width == 1, "Currently, KVCacheManagerV2 only supports beam width 1" + if not self.enable_block_reuse: + assert kv_cache.num_committed_tokens == 0 + kv_cache.stop_committing() + else: + req.context_current_position = kv_cache.num_committed_tokens + chunk_size = req.context_chunk_size + if req.context_current_position + req.context_chunk_size < req.prompt_len: + floored_end_position = ( + req.context_current_position + + req.context_chunk_size + ) // self.tokens_per_block * self.tokens_per_block + chunk_size = floored_end_position - req.context_current_position + + req.context_chunk_size = min( + chunk_size, + req.prompt_len - req.context_current_position) + + success = kv_cache.resume( + torch.cuda.current_stream().cuda_stream) + assert success + + kv_cache.resize(req.prompt_len) + + if self.kv_connector_manager is not None: + block_ids = self.get_cache_indices(req) + self.kv_connector_manager.update_state_after_alloc( + req, block_ids) + + for req in generation_batch: + kv_cache = self.kv_cache_map[req.py_request_id] + kv_cache.resize(kv_cache.capacity + 1) + + if self.kv_connector_manager is not None: + self.kv_connector_manager.build_scheduler_output( + scheduled_batch, self) + + def _kv_connector_should_add_sequence(self, request: LlmRequest) -> bool: + return self.kv_connector_manager is None or self.kv_connector_manager.should_add_sequence( + request) + + def get_kv_cache_stats(self): + + class KVCacheStatus: + + def __init__(self, allocated_bytes: int): + self.allocated_bytes = allocated_bytes + + return KVCacheStatus(allocated_bytes=self.impl.get_quota(GPU_LEVEL)) + + def add_dummy_requests( + self, + request_ids: List[int], + # Note that token_nums should be past_kv_len + input_len (without + # spec decoding). The draft tokens will be added in this function, + # so we don't need to take care of it in the caller. When preparing + # token_nums, we should not take the draft tokens into account, so + # don't use the kv_cache_manager.max_seq_len, which includes both + # extra tokens and draft tokens. + token_nums: Optional[List[int]] = None, + is_gen: bool = False, + prepare_resource: bool = True, + max_num_draft_tokens: int = 0, + use_mrope: bool = False, + max_beam_width: int = 1, + num_extra_decoding_steps: + int = 0, # TODO: support num_extra_decoding_steps + ): + + beam_width = max_beam_width + requests = [] + for i, req_id in enumerate(request_ids): + # exact choice of n can be ignored for dummy requests + sampling_params = SamplingParams(n=beam_width, + best_of=beam_width, + use_beam_search=beam_width > 1) + # Here 1+max_num_draft_tokens is used to extend the prompt length to + # a non-zero number to skip illegal memory access issue in MLA kernel + # during warmup. + token_num = token_nums[ + i] if token_nums is not None else 1 + max_num_draft_tokens + # TODO: support cross attention + encoder_input_tokens = None + # Using 1 instead of 0 prevents NaN during warmup in e.g. 
Deepseek + input_tokens = [1 for _ in range(token_num)] + req = LlmRequest(request_id=req_id, + max_new_tokens=1, + input_tokens=input_tokens, + sampling_config=SamplingConfig( + sampling_params._get_sampling_config()), + is_streaming=False, + encoder_input_tokens=encoder_input_tokens) + req.is_dummy_request = True + req.paged_kv_block_ids = [] + if prepare_resource: + kv_cache = self._create_kv_cache(req.py_request_id, + req.lora_task_id, input_tokens) + assert kv_cache.num_committed_tokens == 0 + success = kv_cache.resume( + torch.cuda.current_stream().cuda_stream) + if not success: + for r in requests: + self.free_resources(r) + self.free_resources(req) + return None + kv_cache.stop_committing() + kv_cache.resize(token_num) + + if is_gen: + req.state = LlmRequestState.GENERATION_IN_PROGRESS + req.prompt_len = token_num - 1 + req.py_prompt_len = req.prompt_len + + # TODO: Planning to get dummy_data from each model. Before that, we need to add dummy mrop_config to the request here. + if use_mrope: + dummy_mrope_position_ids = torch.arange( + 0, token_num, dtype=torch.int32).expand(3, 1, -1).clone() + req.py_multimodal_data = { + "mrope_config": { + "mrope_position_ids": dummy_mrope_position_ids + } + } + if is_gen: + dummy_mrope_position_deltas = torch.zeros( + 1, dtype=torch.int32).unsqueeze(0) + req.py_multimodal_data["mrope_config"][ + "mrope_position_deltas"] = dummy_mrope_position_deltas + requests.append(req) + + return requests + + def free_resources(self, request: LlmRequest, pin_on_release: bool = False): + kv_cache = self.kv_cache_map.pop(request.py_request_id) + kv_cache.close() + self.index_mapper.remove_sequence(request.py_request_id) + + def get_batch_cache_indices(self, + request_ids: List[int], + layer_id: int = 0) -> List[List[int]]: + + return self._get_batch_cache_indices_by_pool_id( + request_ids, + pool_id=self.layer_to_pool_mapping_dict[layer_id], + is_kv_aggregate=True) + + def _get_batch_cache_indices_by_pool_id( + self, + request_ids: List[int], + *, + pool_id: int = 0, + is_kv_aggregate: bool = True) -> List[List[int]]: + + if is_kv_aggregate: + # Div by kv_factor to index kv cache with size [num_blocks, kv_factor, tokens_per_block, num_kv_heads, head_dim] + div_factor = self.kv_factor + else: + div_factor = 1 + + return [ + (np.asarray(self.kv_cache_map[req_id].get_page_indices(pool_id)) // + div_factor).tolist() for req_id in request_ids + ] + + def get_cache_bytes_per_token( + self, + local_layer_idx: Optional[int] = None, + data_role: Role = Role.ALL): # None means all layers/data_roles + if self.dtype not in ( + DataType.FP8, + DataType.HALF, + DataType.BF16, + DataType.FLOAT, + DataType.NVFP4, + ): + raise ValueError(f"Cannot support {self.dtype} KV cache.") + + if data_role == Role.ALL: + kv_factor = self.kv_factor + elif data_role in [ + Role.KEY, Role.VALUE, Role.KEY_BLOCK_QUANT, + Role.VALUE_BLOCK_QUANT + ]: + if data_role in [Role.KEY_BLOCK_QUANT, Role.VALUE_BLOCK_QUANT]: + assert self.dtype == DataType.NVFP4, "NVFP4 is the only supported dtype for block quant data roles" + if data_role == Role.VALUE: + assert self.kv_cache_type != CacheTypeCpp.SELFKONLY, "SELFKONLY is the only supported cache type for value data role" + kv_factor = 1 + else: + raise ValueError(f"Invalid data role: {data_role}") + + if local_layer_idx is None: + cache_size_per_token = (kv_factor * + sum(self.num_kv_heads_per_layer) * + self.head_dim) + else: + cache_size_per_token = ( + kv_factor * self.num_kv_heads_per_layer[local_layer_idx] * + self.head_dim) + + 
cache_size_bytes_per_token = get_size_in_bytes(cache_size_per_token, + self.dtype) + + if data_role in [Role.KEY, Role.VALUE]: + return cache_size_bytes_per_token + + quant_size_per_token = 0 + + if self.dtype == DataType.NVFP4: + quant_size_per_token = self.calculate_scaling_factor_size_bytes( + cache_size_per_token, + quant_vector_size=16, + scaling_factor_dtype=DataType.FP8, + ) + + if data_role in [Role.KEY_BLOCK_QUANT, Role.VALUE_BLOCK_QUANT]: + return quant_size_per_token + + return cache_size_bytes_per_token + quant_size_per_token + + @staticmethod + def calculate_scaling_factor_size_bytes( + cache_size: int, quant_vector_size: int, + scaling_factor_dtype: DataType) -> int: + assert cache_size % quant_vector_size == 0, "NVFP4 cache size must be divisible by quant vector size" + return get_size_in_bytes(cache_size // quant_vector_size, + scaling_factor_dtype) + + def check_invalid_values_in_kv_cache(self, + fill_with_zero: bool = False) -> bool: + some_checks_unavailable = False + has_invalid_values = torch.tensor([False], + dtype=torch.bool, + device=torch.cuda.current_device()) + pool_handled = set() + + # Handle each layer from start to end to traverse the whole KV cache. + for layer_id in typed_range(LayerId(self.num_local_layers)): + pool_id = self.layer_to_pool_mapping_dict[layer_id] + if pool_id in pool_handled: + continue + buffer = self.get_buffers(layer_id) + # process in chunks of 256 pages to avoid OoM + for i in range(0, buffer.shape[0], 256): + buffer_slice = buffer[i:i + 256] + try: + has_invalid_values.logical_or_( + torch.isnan(buffer_slice).any()) + has_invalid_values.logical_or_( + torch.isinf(buffer_slice).any()) + except NotImplementedError: + some_checks_unavailable = True + if fill_with_zero: + buffer.zero_() + pool_handled.add(pool_id) + torch.cuda.synchronize() + + if some_checks_unavailable: + logger.warning( + "`torch.isnan` or `torch.isinf` is not implemented for current kv cache dtype, related checks are skipped" + ) + return bool(has_invalid_values) + + def shutdown(self): + for kv_cache in self.kv_cache_map.values(): + kv_cache.close() + self.kv_cache_map.clear() + self.impl.clear_reusable_blocks() + + def get_max_resource_count(self) -> int: + # TODO: implement this + return 1 + + def get_needed_resource_to_completion(self, request: LlmRequest) -> int: + # TODO: implement this + # context_token_count = request.orig_prompt_len + # num_context_blocks = context_token_count // self.tokens_per_block + # remaining_tokens = context_token_count + request.max_new_tokens - num_context_blocks * self.tokens_per_block + # need_blocks = num_context_blocks + math.ceil( + # remaining_tokens / self.tokens_per_block) + # return need_blocks + return 0 + + # TODO: refactor get_cache_size_per_token and get_cache_bytes_per_token to use the same logic + @staticmethod + def get_cache_size_per_token(model_config: ModelConfigPython, + mapping: Mapping, **kwargs): + # get kv cache dtype bytes + mem_per_token = 2 + quant_config = model_config.quant_config + if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache( + ): + mem_per_token = 1 + + # get num key value heads + config = model_config.pretrained_config + num_key_value_heads = getattr(config, 'num_key_value_heads', + config.num_attention_heads) + if isinstance(num_key_value_heads, Iterable): + num_key_value_heads = sum(num_key_value_heads) / len( + num_key_value_heads) + + # get head dim + mla = hasattr(config, "kv_lora_rank") + if mla: + head_dim = config.kv_lora_rank + config.qk_rope_head_dim + kv_factor = 
1 + else: + tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size + head_dim = getattr(config, "head_dim", None) + if not isinstance(head_dim, int): + head_dim = config.hidden_size // config.num_attention_heads + head_dim = head_dim * num_key_value_heads // tp_size + kv_factor = 2 + + # provide at least 1 layer to prevent division by zero cache size + num_attention_layers = max( + len(mapping.pp_layers(model_config.get_num_attention_layers())), 1) + mem_per_token *= num_attention_layers * head_dim + + # K and V + mem_per_token *= kv_factor + return mem_per_token + + def update_resources(self, + scheduled_batch: ScheduledRequests, + attn_metadata: "AttentionMetadata" = None, + kv_cache_dtype_byte_size: float = None): + for req in scheduled_batch.context_requests: + if req.py_request_id not in self.kv_cache_map: + continue + kv_cache = self.kv_cache_map[req.py_request_id] + if self.enable_block_reuse and not req.is_dummy_request: + if req.context_current_position > kv_cache.num_committed_tokens: + kv_cache.commit( + req.get_tokens(0)[kv_cache.num_committed_tokens:req. + context_current_position]) + kv_cache.stop_committing() + else: + kv_cache.resize(None, req.context_current_position) + + for req in scheduled_batch.generation_requests: + if req.py_request_id not in self.kv_cache_map: + continue + kv_cache = self.kv_cache_map[req.py_request_id] + kv_cache.resize(None, req.max_beam_num_tokens - 1) + + def copy_batch_block_offsets(self, dst_tensor: torch.Tensor, + request_ids: List[int], beam_width: int, + num_contexts: int, num_seqs: int): + assert beam_width == 1, "beam_width must be 1 for KVCacheManagerV2" + + copy_idx = self.index_mapper.get_copy_index(request_ids, num_contexts, + beam_width) + assert copy_idx.shape[0] == num_seqs + + copy_batch_block_offsets_to_device( + self.host_kv_cache_block_offsets, dst_tensor, copy_idx, + self.kv_cache_type == CacheTypeCpp.SELFKONLY, + torch.cuda.current_stream().cuda_stream) + + def _create_kv_cache(self, request_id: int, lora_task_id: int | None, + input_tokens: Sequence[TokenIdExt] | None): + assert request_id not in self.kv_cache_map, f"KV cache for request {request_id} already exists" + kv_cache = self.impl.create_kv_cache(lora_task_id, input_tokens) + self.kv_cache_map[request_id] = kv_cache + index = self.index_mapper.add_new_sequence(request_id) + for i in range(self.max_beam_width): + for pool_idx in range(self.num_pools): + buffer: torch.Tensor = self.host_kv_cache_block_offsets[ + pool_idx, index * self.max_beam_width + i, 0] + kv_cache.set_page_index_buf(i, pool_idx, + memoryview(buffer.numpy())) + return kv_cache + + class SlotManager: def __init__(self, max_num_requests: int): diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 03c4bf0aff..6631057251 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -168,6 +168,68 @@ class BindCapacityScheduler(CapacityScheduler): self.peft_cache_manager) +class KVCacheV2DummyScheduler(CapacityScheduler): + # only schedule requests has no_schedule_until_state <= state < no_schedule_after_state + no_schedule_until_state = LlmRequestState.CONTEXT_INIT + no_schedule_after_state = LlmRequestState.GENERATION_COMPLETE + + def __init__(self, max_num_requests: int, kv_cache_manager): + super(KVCacheV2DummyScheduler, self).__init__() + self.max_num_requests = max_num_requests + self.kv_cache_manager = kv_cache_manager + + def schedule_request( + self, active_requests: RequestList + ) -> 
tuple[list[LlmRequest], list[LlmRequest], list[LlmRequest]]:
+        scheduled_requests = []
+        scheduled_disagg_gen_init_requests = []
+        pending_requests = []
+        reserved_blocks = 0
+        max_blocks = self.kv_cache_manager.get_max_resource_count()
+        for request in active_requests:
+            req_state = request.state
+            # if the request cannot be scheduled yet or should no longer be scheduled, skip it
+            if not req_state == LlmRequestState.DISAGG_GENERATION_INIT and (
+                    req_state.value < self.no_schedule_until_state.value
+                    or req_state.value >= self.no_schedule_after_state.value):
+                continue
+
+            if len(scheduled_requests
+                   ) >= self.max_num_requests or reserved_blocks >= max_blocks:
+                break
+            elif req_state == LlmRequestState.GENERATION_IN_PROGRESS or req_state == LlmRequestState.GENERATION_TO_COMPLETE:
+                scheduled_requests.append(request)
+                reserved_blocks += self.kv_cache_manager.get_needed_resource_to_completion(
+                    request)
+            elif req_state == LlmRequestState.DISAGG_GENERATION_INIT:
+                scheduled_disagg_gen_init_requests.append(request)
+                reserved_blocks += self.kv_cache_manager.get_needed_resource_to_completion(
+                    request)
+            else:
+                pending_requests.append(request)
+
+        available_blocks = max_blocks - reserved_blocks
+        for request in pending_requests:
+            req_state = request.state
+            if len(scheduled_requests) >= self.max_num_requests:
+                break
+            elif req_state == LlmRequestState.CONTEXT_INIT:
+                needed_blocks = self.kv_cache_manager.get_needed_resource_to_completion(
+                    request)
+                if needed_blocks <= available_blocks:
+                    scheduled_requests.append(request)
+                    available_blocks -= needed_blocks
+                elif needed_blocks > available_blocks:
+                    # If one request fails to be scheduled, stop scheduling the rest
+                    break
+
+        assert len(scheduled_requests) + len(
+            scheduled_disagg_gen_init_requests) > 0, (
+                "no pending request can get enough resources to complete; "
+                "please increase the KV cache pool size.")
+        return scheduled_requests, scheduled_disagg_gen_init_requests, []
+
+
 class MicroBatchScheduler(ABC):

     @abstractmethod
diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py
index bfb4d32b42..8180487fa3 100644
--- a/tensorrt_llm/_utils.py
+++ b/tensorrt_llm/_utils.py
@@ -210,6 +210,12 @@ def binding_to_str_dtype(binding_dtype) -> str:
     return ret


+def binding_to_torch_dtype(binding_dtype) -> torch.dtype:
+    ret = _binding_to_str_dtype.get(binding_dtype)
+    assert ret is not None, f'Unsupported binding dtype: {binding_dtype}'
+    return str_dtype_to_torch(ret)
+
+
 def binding_dtype_size(dtype: DataType):
     return _binding_dtype_size[dtype]

@@ -989,7 +995,7 @@ class TensorWrapper:
     def __init__(
         self,
         data_ptr: int,
-        dtype: Union[torch.dtype, str, np.dtype, trt.DataType],
+        dtype: Union[torch.dtype, str, np.dtype, trt.DataType, DataType],
         shape: Sequence[int],
         strides: Optional[Sequence[int]] = None,
     ):
@@ -1011,7 +1017,8 @@ class TensorWrapper:
         return getattr(self, "_shape", None)

     @dtype.setter
-    def dtype(self, dtype: Union[torch.dtype, str, np.dtype, trt.DataType]):
+    def dtype(self, dtype: Union[torch.dtype, str, np.dtype, trt.DataType,
+                                 DataType]):
         if isinstance(dtype, torch.dtype):
             self._dtype = dtype
         elif isinstance(dtype, str):
@@ -1020,6 +1027,8 @@ class TensorWrapper:
             self._dtype = np_dtype_to_torch(dtype)
         elif isinstance(dtype, trt.DataType):
             self._dtype = trt_dtype_to_torch(dtype)
+        elif isinstance(dtype, DataType):
+            self._dtype = binding_to_torch_dtype(dtype)
         else:
             raise TypeError(f"Unsupported dtype: {dtype}")

diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 7286ced1f0..90010b6fa3 100644
--- 
a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1760,6 +1760,18 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): tokens_per_block: int = Field(default=32, description="The number of tokens per block.") + use_kv_cache_manager_v2: bool = Field( + default=False, + status="prototype", + description="Whether to use the KV cache manager v2 (experimental).") + + max_util_for_resume: float = Field( + default=0.95, + status="prototype", + description= + "The maximum utilization of the KV cache for resume. Default is 95%. Only used when using KV cache manager v2 (experimental)." + ) + def _to_pybind(self): return _KvCacheConfig( enable_block_reuse=self.enable_block_reuse, @@ -1820,6 +1832,14 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): ) return v + @field_validator('max_util_for_resume') + @classmethod + def validate_max_util_for_resume(cls, v: float): + if not 0 <= v <= 1: + raise ValueError( + "kv_cache_config.max_util_for_resume must be between 0 and 1") + return v + @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig) class ExtendedRuntimePerfKnobConfig(StrictBaseModel, PybindMirror): diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/__init__.pyi b/tensorrt_llm/runtime/kv_cache_manager_v2/__init__.pyi index f2eaa079bf..549f0617a1 100644 --- a/tensorrt_llm/runtime/kv_cache_manager_v2/__init__.pyi +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/__init__.pyi @@ -258,4 +258,4 @@ class KVCacheManager: def get_aggregated_pages( self, buffers: Iterable[BufferSlice] ) -> Iterator[AggregatedPageDesc]: ... - def clamp_max_seq_len_for_mem(self, batch_size: int, model_max_seq_len: int) -> int: ... + def clamp_max_seq_len_for_mem(self, batch_size: int) -> int: ... diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache_manager.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache_manager.py index 46ab44a340..925aca4be8 100644 --- a/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache_manager.py +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache_manager.py @@ -295,7 +295,7 @@ class KVCacheManager: ) # @TODO: need updating when dynamic resizing is supported. - def clamp_max_seq_len_for_mem(self, batch_size: int, model_max_seq_len: int) -> int: + def clamp_max_seq_len_for_mem(self, batch_size: int) -> int: "Get the max possible sequence length limited by the GPU memory pools." 
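The rewritten search below drops the model_max_seq_len argument: it doubles an upper bound until the memory check fails, then binary-searches the largest per-sequence block count that still fits. A minimal standalone sketch of the same idea, where is_enough stands in for the manager's internal memory check and is an assumption rather than the real API:

def max_blocks_that_fit(is_enough) -> int:
    # Precondition mirrored from the patch: a single block must always fit.
    assert is_enough(1)
    lb = ub = 1
    # Exponential growth: find the first bound that no longer fits.
    while is_enough(ub):
        lb, ub = ub, ub * 2
    # Binary search on (lb, ub): lb always fits, ub never does.
    while lb < ub - 1:
        mid = (lb + ub) // 2
        if is_enough(mid):
            lb = mid
        else:
            ub = mid
    return lb  # multiply by tokens_per_block to get the clamped sequence length
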
assert batch_size > 0 tokens_per_block = self.tokens_per_block @@ -330,13 +330,14 @@ class KVCacheManager: assert is_enough(1) lb = 1 - ub = div_up(model_max_seq_len, tokens_per_block) - if is_enough(ub): - return model_max_seq_len - while lb < ub: + ub = lb + while is_enough(ub): + lb = ub + ub *= 2 + while lb < ub - 1: mid = (lb + ub) // 2 if is_enough(mid): lb = mid else: - ub = mid - 1 - return min(lb * tokens_per_block, model_max_seq_len) + ub = mid + return lb * tokens_per_block diff --git a/tests/integration/defs/.test_durations b/tests/integration/defs/.test_durations index 6da286cf0c..75a7dad17a 100644 --- a/tests/integration/defs/.test_durations +++ b/tests/integration/defs/.test_durations @@ -256,11 +256,16 @@ "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True-moe_backend=WIDEEP]": 360.0002855450729839504, "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_batch_waiting[batch_wait_timeout_iters=10-batch_wait_max_tokens_ratio=0.75-mtp_nextn=0-fp8kv=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 360.0003064870252273977, "accuracy/test_llm_api_pytorch.py::TestEXAONE4::test_auto_dtype": 3600.0004039629711769521, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-auto]": 360.00032637204276397824, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-fp8]": 360.0003586999955587089, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton-auto]": 360.6586053780047223, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm-auto]": 360.0003633099840953946, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm-fp8]": 360.00036422599805518985, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-cutlass-auto]": 360.00032637204276397824, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-cutlass-fp8]": 360.0003586999955587089, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-triton-auto]": 360.6586053780047223, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-trtllm-auto]": 360.0003633099840953946, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-trtllm-fp8]": 360.00036422599805518985, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v2_kv_cache-True-True-cutlass-auto]": 360.00032637204276397824, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v2_kv_cache-True-True-cutlass-fp8]": 360.0003586999955587089, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v2_kv_cache-True-True-triton-auto]": 360.6586053780047223, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v2_kv_cache-True-True-trtllm-auto]": 360.0003633099840953946, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v2_kv_cache-True-True-trtllm-fp8]": 360.00036422599805518985, "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[dp2-cutlass-auto]": 360.0003378289984539151, "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[dp2-triton-auto]": 360.9436147869564593, "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[dp2-trtllm-auto]": 360.0003398499684408307, @@ -273,18 +278,30 @@ "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[tp2-triton-auto]": 360.8670774899655953, "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[tp2-trtllm-auto]": 
360.00040231598541140556, "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[tp2-trtllm-fp8]": 360.0003254589391872287, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]": 745.8583740849863, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-triton-auto]": 745.9345730679342523, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-auto]": 745.0004936959594488144, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-fp8]": 745.00031642295653000474, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-cutlass-auto]": 658.1757711600512, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-triton-auto]": 745.9436021829606034, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-auto]": 745.0004371170070953667, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-fp8]": 745.0004142870311625302, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass-auto]": 676.3980704760179, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-triton-auto]": 745.0292645250447094, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto]": 745.0003769229515455663, - "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-fp8]": 677.000331886054482311, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-cutlass-auto]": 745.8583740849863, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-triton-auto]": 745.9345730679342523, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-trtllm-auto]": 745.0004936959594488144, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-trtllm-fp8]": 745.00031642295653000474, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-cutlass-auto]": 658.1757711600512, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-triton-auto]": 745.9436021829606034, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-trtllm-auto]": 745.0004371170070953667, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-trtllm-fp8]": 745.0004142870311625302, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-cutlass-auto]": 676.3980704760179, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-triton-auto]": 745.0292645250447094, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-trtllm-auto]": 745.0003769229515455663, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-trtllm-fp8]": 677.000331886054482311, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-cutlass-auto]": 745.8583740849863, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-triton-auto]": 745.9345730679342523, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-trtllm-auto]": 745.0004936959594488144, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-trtllm-fp8]": 745.00031642295653000474, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-cutlass-auto]": 658.1757711600512, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-triton-auto]": 745.9436021829606034, + 
"accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-auto]": 745.0004371170070953667, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-fp8]": 745.0004142870311625302, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-tp4-cutlass-auto]": 676.3980704760179, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-tp4-triton-auto]": 745.0292645250447094, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-tp4-trtllm-auto]": 745.0003769229515455663, + "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-tp4-trtllm-fp8]": 677.000331886054482311, "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto]": 643.3513998010312, "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[triton-auto]": 764.9216735750087537, "accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-auto]": 764.0002969659981317818, diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 12b1e1b282..969e26e4e1 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -4472,8 +4472,10 @@ class TestGPTOSS(LlmapiAccuracyTestHarness): @pytest.mark.parametrize("cuda_graph,overlap_scheduler", [ (True, True), ]) + @pytest.mark.parametrize("v2_kv_cache", [True, False], + ids=["v2_kv_cache", "v1_kv_cache"]) def test_w4_1gpu(self, kv_cache_dtype, moe_backend, cuda_graph, - overlap_scheduler, mocker): + overlap_scheduler, mocker, v2_kv_cache): mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192) mocker.patch.dict(GSM8K.EVALUATE_KWARGS, {"scores_filter": "exact_match,flexible-extract"}) @@ -4482,14 +4484,16 @@ class TestGPTOSS(LlmapiAccuracyTestHarness): disable_overlap_scheduler=not overlap_scheduler, cuda_graph_config=CudaGraphConfig() if cuda_graph else None) - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5, - dtype=kv_cache_dtype) + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7, + dtype=kv_cache_dtype, + use_kv_cache_manager_v2=v2_kv_cache) llm = LLM(self.MODEL_PATH, tensor_parallel_size=1, pipeline_parallel_size=1, moe_expert_parallel_size=1, kv_cache_config=kv_cache_config, + max_batch_size=720, **pytorch_config, moe_config=MoeConfig(backend=moe_backend)) @@ -4526,26 +4530,11 @@ class TestGPTOSS(LlmapiAccuracyTestHarness): (4, 1, 4, True, True, True), ], ids=["tp4", "ep4", "dp4"]) - @pytest.mark.parametrize("enable_configurable_moe", [0, 1], - ids=lambda x: "" - if x == 0 else "enable_configurable_moe") + @pytest.mark.parametrize("v2_kv_cache", [True, False], + ids=["v2_kv_cache", "v1_kv_cache"]) def test_w4_4gpus(self, kv_cache_dtype, moe_backend, tp_size, pp_size, ep_size, attention_dp, cuda_graph, overlap_scheduler, - enable_configurable_moe, mocker): - # Handle ENABLE_CONFIGURABLE_MOE environment variable - if enable_configurable_moe == 1 and moe_backend not in [ - "TRTLLM", "CUTLASS" - ]: - pytest.skip( - f"ENABLE_CONFIGURABLE_MOE=1 is only supported with TRTLLM and CUTLASS backend, " - f"current backend is {moe_backend}") - - # Patch MpiPoolSession to propagate env vars to MPI worker processes - env_value = "1" if enable_configurable_moe == 1 and moe_backend in [ - "TRTLLM", "CUTLASS" - ] else "0" - patch_mpi_pool_session_for_env(mocker, - {"ENABLE_CONFIGURABLE_MOE": env_value}) + mocker, v2_kv_cache): MAX_OUTPUT_LEN = 128179 
MAX_INPUT_LEN = 32768 @@ -4563,7 +4552,8 @@ class TestGPTOSS(LlmapiAccuracyTestHarness): moe_config=MoeConfig(backend=moe_backend)) kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7, - dtype=kv_cache_dtype) + dtype=kv_cache_dtype, + use_kv_cache_manager_v2=v2_kv_cache) max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN llm = LLM(self.MODEL_PATH, diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index cae1767865..b6a57ad1c0 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -159,11 +159,13 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_cu accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_attention_dp] accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-fp8] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-cutlass-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-cutlass-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-triton-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-trtllm-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-trtllm-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v2_kv_cache-True-True-trtllm-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v2_kv_cache-True-True-trtllm-fp8] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[tp2-cutlass-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[tp2-trtllm-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[tp2-trtllm-fp8] @@ -176,18 +178,22 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[dp2-cutlass-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[dp2-trtllm-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[dp2-trtllm-fp8] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[dp2-triton-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-triton-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-fp8] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-cutlass-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-triton-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-fp8] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-triton-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto] 
-accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-cutlass-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-triton-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-trtllm-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-trtllm-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-cutlass-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-triton-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-trtllm-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-trtllm-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-cutlass-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-triton-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-trtllm-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-trtllm-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-trtllm-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-trtllm-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-fp8] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-fp8] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto] diff --git a/tests/integration/test_lists/qa/llm_function_core_sanity.txt b/tests/integration/test_lists/qa/llm_function_core_sanity.txt index 5b5c801d70..2d7994eba3 100644 --- a/tests/integration/test_lists/qa/llm_function_core_sanity.txt +++ b/tests/integration/test_lists/qa/llm_function_core_sanity.txt @@ -49,11 +49,13 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-one_model accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-one_model-overlap_scheduler] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-no_overlap_scheduler] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-overlap_scheduler] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-fp8] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-cutlass-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-cutlass-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-triton-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-trtllm-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-trtllm-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v2_kv_cache-True-True-trtllm-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v2_kv_cache-True-True-trtllm-fp8] 
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[dp2-cutlass-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[dp2-triton-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[dp2-trtllm-auto] @@ -66,18 +68,22 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[tp2-cutlass-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[tp2-triton-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[tp2-trtllm-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[tp2-trtllm-fp8] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-triton-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-fp8] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-cutlass-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-triton-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-fp8] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-triton-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-cutlass-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-triton-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-trtllm-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-trtllm-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-cutlass-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-triton-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-trtllm-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-trtllm-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-cutlass-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-triton-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-trtllm-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-trtllm-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-trtllm-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-trtllm-fp8] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-fp8] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[triton-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[trtllm-auto] diff --git a/tests/integration/test_lists/qa/llm_function_rtx6k.txt b/tests/integration/test_lists/qa/llm_function_rtx6k.txt index c9b42399eb..6f1c1601ee 100644 --- a/tests/integration/test_lists/qa/llm_function_rtx6k.txt +++ b/tests/integration/test_lists/qa/llm_function_rtx6k.txt @@ -144,13 +144,15 @@ 
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutl accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[throughput_latency] accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[latency] accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_cutlass] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-cutlass-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[tp2-cutlass-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[ep2-cutlass-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_2gpus[dp2-cutlass-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-cutlass-auto] -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-cutlass-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-cutlass-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-cutlass-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-cutlass-auto] +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-cutlass-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto] accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_chunked_prefill[cutlass-auto] accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm] diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 9c95a4f221..4823745d5a 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -38,10 +38,12 @@ l0_b200: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_chunked_prefill[quant_dtype=nvfp4-kv_cache_reuse=True-fp8kv=True-overlap_scheduler=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_dummy_load_format - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-trtllm-fp8] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-triton-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-trtllm-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-trtllm-fp8] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-triton-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v2_kv_cache-True-True-trtllm-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v2_kv_cache-True-True-trtllm-fp8] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_dummy_load_format - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] # Cover nvbugs 5461712 and 5505402 - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency] diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml 
b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index 1b40aea3a8..14b62db485 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -181,12 +181,16 @@ l0_dgx_b200: - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8_chunked_prefill[tp4ep4-cuda_graph=True] - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0] - accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-triton-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-fp8] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-triton-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-triton-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-trtllm-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-trtllm-fp8] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-triton-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-fp8] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-triton-auto] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-one_model-overlap_scheduler] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-one_model-no_overlap_scheduler] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-no_overlap_scheduler] diff --git a/tests/integration/test_lists/test-db/l0_dgx_b300.yml b/tests/integration/test_lists/test-db/l0_dgx_b300.yml index 81031354d3..8c3885183e 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b300.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b300.yml @@ -54,9 +54,11 @@ l0_dgx_b300: - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=True] - accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass-fp8] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-triton-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-cutlass-fp8] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-triton-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-tp4-cutlass-auto] + - 
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-triton-auto] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-fp8] - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16] @@ -86,6 +88,7 @@ l0_dgx_b300: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=False] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-trtllm-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-auto] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index 6d02460179..d9eb2e5b36 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -210,12 +210,16 @@ l0_dgx_h100: auto_trigger: gpt_oss orchestrator: mpi tests: - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-triton-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-cutlass-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-triton-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-triton-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-triton-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-triton-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-triton-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-triton-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-triton-auto] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto] - condition: ranges: diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml index f1e4bdd2c8..cce8f9f648 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml +++ 
b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml @@ -53,13 +53,19 @@ l0_gb200_multi_gpus: - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=True-enable_gemm_allreduce_fusion=False] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-trtllm-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-tp4-trtllm-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-cutlass-auto] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus_online_eplb[fp8] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-cutlass-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-triton-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-auto] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-fp8] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-triton-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-trtllm-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-trtllm-fp8] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-triton-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-trtllm-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-trtllm-fp8] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-fp8] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-overlap_scheduler] diff --git a/tests/integration/test_lists/test-db/l0_gb300.yml b/tests/integration/test_lists/test-db/l0_gb300.yml index 60e3826e41..a244192c39 100644 --- a/tests/integration/test_lists/test-db/l0_gb300.yml +++ b/tests/integration/test_lists/test-db/l0_gb300.yml @@ -17,6 +17,7 @@ l0_gb300: tests: # ------------- PyTorch tests --------------- - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency] - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v2_kv_cache-True-True-cutlass-auto] - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] # Cover nvbugs 5461712 and 5505402 - unittest/_torch/thop/parallel TIMEOUT (90) diff --git a/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml b/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml index 5072fdb320..c8dfb98a40 100644 --- a/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml +++ b/tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml @@ -39,7 +39,7 @@ l0_rtx_pro_6000: - 
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-fp8-multimodals/Phi-4-multimodal-instruct-FP8-image_audio] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass-torch_compile=False] # 8mins - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[latency_moe_cutlass-torch_compile=True] # 8 mins - - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[True-True-cutlass-auto] + - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-cutlass-auto] - accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_fp4 - accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_fp8 - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp1-cutlass] diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index c26c45bbac..e8e4364507 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -181,8 +181,10 @@ full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ unittest/llmapi/test_memory_profiling.py::test_profile_kvcache SKIP (https://nvbugs/5580781) triton_server/test_triton.py::test_llava[llava] SKIP (https://nvbugs/5547414) full:RTX/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5569696) -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-cutlass-auto] SKIP (https://nvbugs/5596343) -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass-auto] SKIP (https://nvbugs/5596343) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-cutlass-auto] SKIP (https://nvbugs/5596343) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-cutlass-auto] SKIP (https://nvbugs/5596343) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-cutlass-auto] SKIP (https://nvbugs/5596343) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-tp4-cutlass-auto] SKIP (https://nvbugs/5596343) examples/test_phi.py::test_llm_phi_lora_1gpu[Phi-3-mini-4k-instruct-ru-lora-Phi-3-mini-4k-instruct-lora_fp16-base_fp16] SKIP (https://nvbugs/5612313) triton_server/test_triton.py::test_cpp_unit_tests[cpp-unit-tests] SKIP (https://nvbugs/5619359) triton_server/test_triton_rcca.py::test_rcca_bug_4934893[Temperature:0.5-TOP_P:0.95-TOP_K:10-False-1---False-True-False-0-2048-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap--max_utilization---1-1-1-False-ensemble] SKIP (https://nvbugs/5619369) @@ -312,7 +314,8 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance-eagle3_one_model=True] SKIP (https://nvbugs/5821053) accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=True] SKIP (https://nvbugs/5821415) accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4[torch_compile=True] SKIP (https://nvbugs/5821415) -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-auto] SKIP (https://nvbugs/5651865) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-trtllm-auto] SKIP (https://nvbugs/5651865) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-auto] SKIP (https://nvbugs/5651865) accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=True] SKIP 
(https://nvbugs/5822983) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-sampler_async_worker=False] SKIP (https://nvbugs/5701445) accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0] SKIP (https://nvbugs/5748600) @@ -325,7 +328,8 @@ perf/test_perf_sanity.py::test_e2e[disagg_upload-deepseek-r1-fp4_1k1k_ctx1_gen1_ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=TRTLLM-torch_compile=True] SKIP (https://nvbugs/5826604) disaggregated/test_disaggregated.py::test_disaggregated_mixed[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5834212) accuracy/test_llm_api_pytorch.py::TestGLM4_5Air::test_nvfp4_multi_gpus[throughput_trtllm] SKIP (https://nvbugs/5837275) -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-fp8] SKIP (https://nvbugs/5640697) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-trtllm-fp8] SKIP (https://nvbugs/5640697) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-ep4-trtllm-fp8] SKIP (https://nvbugs/5640697) accuracy/test_llm_api_pytorch.py::TestGLM4_5Air::test_nvfp4_multi_gpus[throughput] SKIP (https://nvbugs/5837275) test_e2e.py::test_ptp_quickstart_advanced_deepseek_r1_w4afp8_8gpus[DeepSeek-R1-W4AFP8-DeepSeek-R1/DeepSeek-R1-W4AFP8] SKIP (https://nvbugs/5836830) accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-False] SKIP (https://nvbugs/5823587) @@ -342,7 +346,8 @@ perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp4_v2_grace_blackwel examples/test_llama.py::test_llama_3_x_with_bf16_lora_torch[llama-3.2-1b-instruct] SKIP (https://nvbugs/5838178) accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_bf16 SKIP (https://nvbugs/5838184) cpp/test_multi_gpu.py::test_cache_transceiver[8proc-mooncake_kvcache-90] SKIP (https://nvbugs/5838199) -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto] SKIP (https://nvbugs/5838211) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-cutlass-auto] SKIP (https://nvbugs/5838211) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-cutlass-auto] SKIP (https://nvbugs/5838211) test_e2e.py::test_ptp_quickstart_advanced_ngram[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct] SKIP (https://nvbugs/5843112) accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp] SKIP (https://nvbugs/5839028) full:A10/unittest/kv_cache_manager_v2_tests/ SKIP (https://nvbugs/5841954) @@ -366,7 +371,8 @@ full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5846154) perf/test_perf_sanity.py::test_e2e[aggr_upload-deepseek_r1_fp4_v2_grace_blackwell-r1_fp4_v2_dep4_mtp1_1k8k] SKIP (https://nvbugs/5846166) accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] SKIP (https://nvbugs/5847284) -accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-fp8] SKIP (https://nvbugs/5850183) +accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-trtllm-fp8] SKIP (https://nvbugs/5850183) 
+accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v2_kv_cache-dp4-trtllm-fp8] SKIP (https://nvbugs/5850183) examples/test_ad_speculative_decoding.py::test_autodeploy_eagle3_acceptance_rate SKIP (https://nvbugs/5855540) unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py::test_ad_speculative_decoding_smoke[False] SKIP (https://nvbugs/5859869) unittest/_torch/thop/serial/test_moe.py::TestMoeFp4::test_autotune_fp8_fp4[RoutingDSlite-384-1024-1] SKIP (https://nvbugs/5859881) diff --git a/tests/unittest/_torch/attention/test_attention.py b/tests/unittest/_torch/attention/test_attention.py index 9754f547e1..d0ea3d0a1c 100644 --- a/tests/unittest/_torch/attention/test_attention.py +++ b/tests/unittest/_torch/attention/test_attention.py @@ -14,8 +14,9 @@ from tensorrt_llm._torch.attention_backend import (AttentionBackend, from tensorrt_llm._torch.attention_backend.interface import \ PredefinedAttentionMask from tensorrt_llm._torch.metadata import KVCacheParams -from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager -from tensorrt_llm.bindings.executor import KvCacheConfig +from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager, + KVCacheManagerV2) +from tensorrt_llm.llmapi.llm_args import KvCacheConfig from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.quantization.mode import QuantAlgo @@ -131,8 +132,13 @@ paged_backends = { } -def kv_cache_manager_from(Attention: type[AttentionBackend], s: Scenario, - kv_cache: torch.Tensor) -> KVCacheManager: +def kv_cache_manager_from( + Attention: type[AttentionBackend], + s: Scenario, + kv_cache: torch.Tensor, + request_ids: list[int], + token_nums: list[int], + use_kv_cache_manager_v2: bool = False) -> KVCacheManager: paged = paged_backends[Attention] num_blocks = s.max_num_pages if paged else s.batch_size @@ -158,7 +164,12 @@ def kv_cache_manager_from(Attention: type[AttentionBackend], s: Scenario, cache_type = tensorrt_llm.bindings.internal.batch_manager.CacheType.CROSS if s.cross else tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF - result = KVCacheManager( + if use_kv_cache_manager_v2: + kv_cache_manager_cls = KVCacheManagerV2 + else: + kv_cache_manager_cls = KVCacheManager + + kv_cache_manager = kv_cache_manager_cls( kv_cache_config, cache_type, num_layers=num_layers, @@ -171,9 +182,19 @@ def kv_cache_manager_from(Attention: type[AttentionBackend], s: Scenario, dtype=kv_cache_dtype, ) + kv_cache_manager.add_dummy_requests(request_ids, token_nums) + for i in range(s.num_layers): - result.get_buffers(i).view_as(kv_cache[i]).copy_(kv_cache[i]) - return result + buffer = kv_cache_manager.get_buffers(i) + block_ids = [ + block_id + for req_block_ids in kv_cache_manager.get_batch_cache_indices( + request_ids, i) for block_id in req_block_ids + if block_id is not -1 + ] + for idx, block_id in enumerate(block_ids): + buffer[block_id].view_as(kv_cache[i][idx]).copy_(kv_cache[i][idx]) + return kv_cache_manager def produce_outputs( @@ -181,6 +202,7 @@ def produce_outputs( q_at_layer: torch.Tensor, kv: Optional[torch.Tensor], s: Scenario, + use_kv_cache_manager_v2: bool = False, *, kv_cache: torch.Tensor, num_cached_tokens: Callable[[int], int] | int, @@ -197,12 +219,13 @@ def produce_outputs( kv_cache_params = KVCacheParams( use_cache=True, num_cached_tokens_per_seq=num_cached_tokens_per_seq) - kv_cache_manager = kv_cache_manager_from(Attention, s, kv_cache) request_ids = list(range(s.batch_size)) 
seq_lens_append = seq_lens_kv if seq_lens_kv is not None else seq_lens token_nums = (torch.tensor(num_cached_tokens_per_seq) + seq_lens_append).tolist() - kv_cache_manager.add_dummy_requests(request_ids, token_nums) + kv_cache_manager = kv_cache_manager_from(Attention, s, kv_cache, + request_ids, token_nums, + use_kv_cache_manager_v2) metadata = Attention.Metadata( num_contexts=num_contexts if num_contexts is not None else s.batch_size, @@ -414,7 +437,9 @@ def test_flashinfer_prefill(): Scenario(num_layers=1, qo_len=32, kv_len=64, causal=False) ], ids=["typical", "non-causal", "cross", "cross-diff-kv-len"]) -def test_attention_backend(s: Scenario): +@pytest.mark.parametrize("use_kv_cache_manager_v2", [True, False], + ids=["v2_kv_cache", "v1_kv_cache"]) +def test_attention_backend(s: Scenario, use_kv_cache_manager_v2: bool): dtype = s.dtype num_layers = s.num_layers num_heads = s.num_heads @@ -457,6 +482,7 @@ def test_attention_backend(s: Scenario): q_at_layer, kv, s, + use_kv_cache_manager_v2=use_kv_cache_manager_v2, kv_cache=kv_cache, num_cached_tokens=past_kv_len, seq_lens=torch.full((batch_size, ), qo_len).int(), @@ -559,7 +585,9 @@ def generate_causal_mask(seq_lens, qo_lens, batch_size, dtype): kvcache_dtype=torch.float8_e4m3fn), ], ids=["fp16", "fp16-cross", "fp8", "fp8-cross"]) -def test_attention_backend_ifb(s: PagedScenario): +@pytest.mark.parametrize("use_kv_cache_manager_v2", [True, False], + ids=["v2_kv_cache", "v1_kv_cache"]) +def test_attention_backend_ifb(s: PagedScenario, use_kv_cache_manager_v2: bool): dtype = s.dtype is_fp8 = s.kvcache_dtype == torch.float8_e4m3fn if is_fp8 and getSMVersion() < 89: @@ -625,6 +653,7 @@ def test_attention_backend_ifb(s: PagedScenario): q_at_layer, kv, s, + use_kv_cache_manager_v2=use_kv_cache_manager_v2, kv_cache=kv_cache, num_cached_tokens=lambda i: num_cached_tokens_prefill if i < num_contexts else num_cached_tokens_decode, diff --git a/tests/unittest/_torch/attention/test_attention_mla.py b/tests/unittest/_torch/attention/test_attention_mla.py index 43b9aadd73..9bb7ec8ff9 100644 --- a/tests/unittest/_torch/attention/test_attention_mla.py +++ b/tests/unittest/_torch/attention/test_attention_mla.py @@ -5,6 +5,7 @@ from typing import List import pytest import torch +from utils.util import getSMVersion import tensorrt_llm from tensorrt_llm._torch.attention_backend.interface import ( @@ -14,10 +15,11 @@ from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest, LlmRequestState, SamplingConfig) -from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager +from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager, + KVCacheManagerV2) from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str -from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.functional import PositionEmbeddingType, RopeEmbeddingUtils +from tensorrt_llm.llmapi.llm_args import KvCacheConfig from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig from tensorrt_llm.quantization.mode import QuantAlgo @@ -359,10 +361,17 @@ accuracy_dict = { @pytest.mark.parametrize("num_generation_steps", num_generation_steps, ids=lambda x: f"num_generation_steps: {x}") +@pytest.mark.parametrize("v2_kv_cache", [True, False], + ids=["v2_kv_cache", "v1_kv_cache"]) def test_attention_mla(scenario: Scenario, context_sequence_lengths: List[int], generation_seq_len_q: int, - num_generation_steps: List[int]): + 
num_generation_steps: List[int], v2_kv_cache: bool): """Test MLA computation for both context and generation phases""" + + if v2_kv_cache and getSMVersion() != 100: + pytest.skip( + "v2_kv_cache is only supported for MLA on Blackwell architectures") + num_heads = scenario.num_heads num_kv_heads = scenario.num_kv_heads q_lora_rank = scenario.q_lora_rank @@ -403,7 +412,8 @@ def test_attention_mla(scenario: Scenario, context_sequence_lengths: List[int], qk_rope_head_dim, v_head_dim, rope_config, kv_cache_tokens_per_block, device, dtype, kv_cache_dtype, context_sequence_lengths, - generation_seq_len_q, num_generation_steps) + generation_seq_len_q, num_generation_steps, + v2_kv_cache) def _run_test_for_backend(backend_name, num_heads, num_kv_heads, num_layers, @@ -411,7 +421,8 @@ def _run_test_for_backend(backend_name, num_heads, num_kv_heads, num_layers, qk_rope_head_dim, v_head_dim, rope_config, kv_cache_tokens_per_block, device, dtype, kv_cache_dtype, context_sequence_lengths, - generation_seq_len_q, num_generation_steps): + generation_seq_len_q, num_generation_steps, + v2_kv_cache): AttentionCls = get_attention_backend(backend_name) qk_head_dim = qk_nope_head_dim + qk_rope_head_dim @@ -597,7 +608,8 @@ def _run_test_for_backend(backend_name, num_heads, num_kv_heads, num_layers, (num_generation_steps + 1) * generation_seq_len_q + kv_cache_tokens_per_block - 1 ) // kv_cache_tokens_per_block * kv_cache_tokens_per_block * max_num_contexts - kv_cache_manager = KVCacheManager( + kv_cache_cls = KVCacheManagerV2 if v2_kv_cache else KVCacheManager + kv_cache_manager = kv_cache_cls( KvCacheConfig( max_tokens=max_tokens, enable_block_reuse=False, @@ -625,8 +637,14 @@ def _run_test_for_backend(backend_name, num_heads, num_kv_heads, num_layers, ) req.paged_kv_block_ids = [] beam_width = 1 - kv_cache_manager.impl.add_sequence(req_id, ctx_len, beam_width, req) request_list.append(req) + if v2_kv_cache: + kv_cache = kv_cache_manager._create_kv_cache(req_id, None, None) + success = kv_cache.resume(torch.cuda.current_stream().cuda_stream) + assert success, f"Failed to resume KV cache for request {req_id}" + kv_cache.capacity = ctx_len + else: + kv_cache_manager.impl.add_sequence(req_id, ctx_len, beam_width, req) attn_metadata = AttentionCls.Metadata( seq_lens=torch.tensor(context_sequence_lengths, dtype=torch.int), request_ids=list(range(len(context_sequence_lengths))), @@ -649,7 +667,11 @@ def _run_test_for_backend(backend_name, num_heads, num_kv_heads, num_layers, if step > 0: for req_id in range(len(context_sequence_lengths)): for _ in range(generation_seq_len_q): - kv_cache_manager.impl.add_token(req_id) + if v2_kv_cache: + kv_cache = kv_cache_manager.kv_cache_map[req_id] + kv_cache.capacity += 1 + else: + kv_cache_manager.impl.add_token(req_id) attn_metadata = AttentionCls.Metadata( seq_lens=torch.tensor([generation_seq_len_q] * len(context_sequence_lengths),