[None][feat] Use new index API, add block scale support, fix max_seq_len estimation, add flash MLA support (#11334)

Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
Yi Zhang 2026-02-15 21:40:54 +08:00 committed by GitHub
parent 59b6bee7e6
commit 361ff36784
11 changed files with 231 additions and 117 deletions

View File

@ -186,10 +186,10 @@ CUresult copyDeviceToDevice(std::vector<MMTask> const& tasks, ssize_t numBytes,
// dst_tensor[:, :num_seqs, 0] = src_tensor[:, copy_idx]
// dst_tensor[:, :num_seqs, 1] = dst_tensor[:, :num_seqs, 0] + 1
template <bool COPY_V_IDX = true>
__global__ void copyBatchBlockOffsetsToDeviceKernel(SizeType32 const* __restrict__ srcPtr,
SizeType32* __restrict__ dstPtr, SizeType32 const srcMaxNumSequences, SizeType32 const dstMaxNumSequences,
SizeType32 numBlocksPerSeq, SizeType32 const* __restrict__ copyIndex)
SizeType32 numBlocksPerSeq, SizeType32 const* __restrict__ copyIndex, SizeType32 const* __restrict__ indexScales,
SizeType32 const* __restrict__ kvOffset)
{
constexpr uint32_t kvFactor = 2;
constexpr auto elemPerAccess = sizeof(PackedInt) / sizeof(SizeType32);
@ -224,19 +224,12 @@ __global__ void copyBatchBlockOffsetsToDeviceKernel(SizeType32 const* __restrict
asm volatile("cp.async.wait_group %0;\n" ::"n"(copyBlocknbBufs - 1) : "memory");
if (srcIdx < srcIdxEnd)
{
dstK = src;
if (COPY_V_IDX)
{
dstV = src;
}
else
{
#pragma unroll
for (uint32_t j = 0; j < elemPerAccess; j++)
{
auto const val = src.unpacked[j];
dstV.unpacked[j] = (val == BAD_PAGE_INDEX) ? val : (val + 1);
}
for (uint32_t j = 0; j < elemPerAccess; j++)
{
auto const val = src.unpacked[j];
dstK.unpacked[j] = (val == BAD_PAGE_INDEX) ? val : (indexScales[poolIdx] * val);
dstV.unpacked[j] = (val == BAD_PAGE_INDEX) ? val : (indexScales[poolIdx] * val + kvOffset[poolIdx]);
}
}
}
@ -256,8 +249,8 @@ __global__ void copyBatchBlockOffsetsToDeviceKernel(SizeType32 const* __restrict
}
// Host-side launcher
void copyBatchBlockOffsetsToDevice(
ITensor const& input, ITensor& output, ITensor const& copyIndex, bool copyVIdx, CUstream stream) noexcept
void copyBatchBlockOffsetsToDevice(ITensor const& input, ITensor& output, ITensor const& copyIndex,
ITensor const& indexScales, ITensor const& kvOffset, CUstream stream) noexcept
{
using namespace tensorrt_llm::runtime;
@ -265,6 +258,8 @@ void copyBatchBlockOffsetsToDevice(
auto* dstPtr = bufferCast<tk::KVCacheIndex::UnderlyingType>(
output); // [numPools, maxNumSequences, kvFactor, numBlocksPerSeq]
auto const* copyIndexPtr = bufferCast<SizeType32 const>(copyIndex);
auto const* indexScalesPtr = bufferCast<SizeType32 const>(indexScales);
auto const* kvOffsetPtr = bufferCast<SizeType32 const>(kvOffset);
auto const& srcShape = input.getShape();
auto const& dstShape = output.getShape();
auto const& copyIndexShape = copyIndex.getShape();
@ -290,16 +285,8 @@ void copyBatchBlockOffsetsToDevice(
dim3 gridDim(numPools, numSeqs, 1);
dim3 blockDim(copyBlockCtaSize);
if (copyVIdx)
{
copyBatchBlockOffsetsToDeviceKernel<true><<<gridDim, blockDim, 0, stream>>>(
srcPtr, dstPtr, srcMaxNumSequences, dstMaxNumSequences, numBlocksPerSeq, copyIndexPtr);
}
else
{
copyBatchBlockOffsetsToDeviceKernel<false><<<gridDim, blockDim, 0, stream>>>(
srcPtr, dstPtr, srcMaxNumSequences, dstMaxNumSequences, numBlocksPerSeq, copyIndexPtr);
}
copyBatchBlockOffsetsToDeviceKernel<<<gridDim, blockDim, 0, stream>>>(srcPtr, dstPtr, srcMaxNumSequences,
dstMaxNumSequences, numBlocksPerSeq, copyIndexPtr, indexScalesPtr, kvOffsetPtr);
}
} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
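
To make the pool-aware offset math above easier to follow, here is a minimal NumPy sketch of the per-element transform the updated kernel applies. It is a host-side model only: the scalar expressions stand in for the vectorized cp.async path, and BAD_PAGE_INDEX = -1 is an assumption mirroring the kv_cache_manager_v2 sources, not a value taken from this diff.

import numpy as np

BAD_PAGE_INDEX = -1  # assumed sentinel for an unmapped page


def copy_block_offsets_reference(src: np.ndarray, index_scale: int,
                                 kv_offset: int):
    """Return (dst_k, dst_v) page offsets for one pool from raw page indices."""
    bad = src == BAD_PAGE_INDEX
    dst_k = np.where(bad, src, src * index_scale)              # K offsets
    dst_v = np.where(bad, src, src * index_scale + kv_offset)  # V offsets
    return dst_k, dst_v


src = np.array([0, 3, BAD_PAGE_INDEX, 7], dtype=np.int32)
k, v = copy_block_offsets_reference(src, index_scale=2, kv_offset=1)
print(k.tolist())  # [0, 6, -1, 14]
print(v.tolist())  # [1, 7, -1, 15]

Bad page indices pass through unchanged on both the K and V sides, which is why the kernel no longer needs the COPY_V_IDX template branch.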

View File

@ -95,7 +95,7 @@ CUresult copyDeviceToHost(
CUresult copyDeviceToDevice(
std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, CUstream stream) noexcept;
void copyBatchBlockOffsetsToDevice(
ITensor const& input, ITensor& output, ITensor const& copyIndex, bool copyVIdx, CUstream stream) noexcept;
void copyBatchBlockOffsetsToDevice(ITensor const& input, ITensor& output, ITensor const& copyIndex,
ITensor const& indexScales, ITensor const& kvOffset, CUstream stream) noexcept;
} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2

View File

@ -131,19 +131,24 @@ void KVCacheManagerV2UtilsBindings::initBindings(nb::module_& module)
module.def(
"copy_batch_block_offsets_to_device",
[](at::Tensor input, at::Tensor output, at::Tensor copyIndex, bool copyVIdx, uintptr_t stream)
[](at::Tensor input, at::Tensor output, at::Tensor copyIndex, at::Tensor indexScales, at::Tensor kvOffset,
uintptr_t stream)
{
auto _input = from_torch(input);
auto _output = from_torch(output);
auto _copyIndex = from_torch(copyIndex);
auto _indexScales = from_torch(indexScales);
auto _kvOffset = from_torch(kvOffset);
TLLM_CHECK_WITH_INFO(_input.has_value(), "Invalid input tensor.");
TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor.");
TLLM_CHECK_WITH_INFO(_copyIndex.has_value(), "Invalid copy index tensor.");
copyBatchBlockOffsetsToDevice(*(_input.value()), *(_output.value()), *(_copyIndex.value()), copyVIdx,
reinterpret_cast<CUstream>(stream));
TLLM_CHECK_WITH_INFO(_indexScales.has_value(), "Invalid index scales tensor.");
TLLM_CHECK_WITH_INFO(_kvOffset.has_value(), "Invalid kv offset tensor.");
copyBatchBlockOffsetsToDevice(*(_input.value()), *(_output.value()), *(_copyIndex.value()),
*(_indexScales.value()), *(_kvOffset.value()), reinterpret_cast<CUstream>(stream));
},
nb::arg("input"), nb::arg("output"), nb::arg("copy_index"), nb::arg("copy_v_idx"), nb::arg("stream"),
nb::call_guard<nb::gil_scoped_release>(), "Copy batch block indices to device");
nb::arg("input"), nb::arg("output"), nb::arg("copy_index"), nb::arg("index_scales"), nb::arg("kv_offset"),
nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(), "Copy batch block indices to device");
}
} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2

View File

@ -676,10 +676,11 @@ class PyTorchModelEngine(ModelEngine):
reverse: bool = False):
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
curr_max_num_tokens = min(
kv_cache_manager.get_num_available_tokens(
max_num_draft_tokens=self.original_max_draft_len),
self.max_num_tokens, self.batch_size * (self.max_seq_len - 1))
token_num_upper_bound = min(self.max_num_tokens,
self.batch_size * (self.max_seq_len - 1))
curr_max_num_tokens = kv_cache_manager.get_num_available_tokens(
token_num_upper_bound=token_num_upper_bound,
max_num_draft_tokens=self.original_max_draft_len)
max_batch_size = min(
self.batch_size,
curr_max_num_tokens // (1 + self.runtime_draft_len))
@ -728,10 +729,11 @@ class PyTorchModelEngine(ModelEngine):
logger.info("Running autotuner warmup...")
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
curr_max_num_tokens = min(
kv_cache_manager.get_num_available_tokens(
max_num_draft_tokens=self.original_max_draft_len),
self.max_num_tokens, self.batch_size * (self.max_seq_len - 1))
token_num_upper_bound = min(self.max_num_tokens,
self.batch_size * (self.max_seq_len - 1))
curr_max_num_tokens = kv_cache_manager.get_num_available_tokens(
token_num_upper_bound=token_num_upper_bound,
max_num_draft_tokens=self.original_max_draft_len)
cache_path = os.environ.get("TLLM_AUTOTUNER_CACHE_PATH", None)
with self.no_cuda_graph(), autotune(cache_path=cache_path):
@ -962,6 +964,7 @@ class PyTorchModelEngine(ModelEngine):
ResourceManagerType.SPEC_RESOURCE_MANAGER)
available_tokens = kv_cache_manager.get_num_available_tokens(
token_num_upper_bound=num_tokens,
max_num_draft_tokens=self.runtime_draft_len)
available_blocks = kv_cache_manager.get_num_free_blocks()
if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
@ -1104,19 +1107,24 @@ class PyTorchModelEngine(ModelEngine):
if requests is None:
return None
available_tokens = kv_cache_manager.get_num_available_tokens(
batch_size=batch_size, max_num_draft_tokens=draft_len)
# Also consider draft KV cache capacity when it exists
if draft_kv_cache_manager is not None:
draft_available_tokens = draft_kv_cache_manager.get_num_available_tokens(
batch_size=batch_size, max_num_draft_tokens=draft_len)
available_tokens = min(available_tokens, draft_available_tokens)
# Add one dummy request with the maximum possible sequence length.
max_seq_len = min(
self.max_seq_len if max_seq_len is None else max_seq_len,
kv_cache_manager.max_seq_len)
available_tokens = kv_cache_manager.get_num_available_tokens(
token_num_upper_bound=max_seq_len,
batch_size=batch_size,
max_num_draft_tokens=draft_len)
# Also consider draft KV cache capacity when it exists
if draft_kv_cache_manager is not None:
draft_available_tokens = draft_kv_cache_manager.get_num_available_tokens(
batch_size=batch_size,
token_num_upper_bound=max_seq_len,
max_num_draft_tokens=draft_len)
available_tokens = min(available_tokens, draft_available_tokens)
token_num = max(
1,
min(

View File

@ -7,7 +7,6 @@ from collections import OrderedDict, defaultdict, deque
from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,
Set, Tuple, Union)
import numpy as np
import torch
from mpi4py import MPI
@ -23,7 +22,6 @@ from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PeftCacheConfig
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
from tensorrt_llm.math_utils import ceil_div
from tensorrt_llm.runtime import ModelConfig as ModelConfigPython
from tensorrt_llm.runtime.kv_cache_manager_v2 import (DEFAULT_BEAM_INDEX,
AttentionLayerConfig,
@ -36,7 +34,8 @@ from tensorrt_llm.runtime.kv_cache_manager_v2 import \
KVCacheManagerConfig as KVCacheManagerConfigPy
from tensorrt_llm.runtime.kv_cache_manager_v2 import (LayerId, TokenIdExt,
_KVCache)
from tensorrt_llm.runtime.kv_cache_manager_v2._common import GPU_LEVEL
from tensorrt_llm.runtime.kv_cache_manager_v2._common import (BAD_PAGE_INDEX,
GPU_LEVEL)
from tensorrt_llm.runtime.kv_cache_manager_v2._config import DataRole
from tensorrt_llm.runtime.kv_cache_manager_v2._utils import (exact_div,
typed_range)
@ -82,8 +81,8 @@ class ResourceManagerType(enum.Enum):
class Role:
KEY = DataRole("key")
VALUE = DataRole("value")
KEY_BLOCK_QUANT = DataRole("key_block_quant")
VALUE_BLOCK_QUANT = DataRole("value_block_quant")
KEY_BLOCK_SCALE = DataRole("key_block_scale")
VALUE_BLOCK_SCALE = DataRole("value_block_scale")
ALL = DataRole("all")
@ -1009,10 +1008,13 @@ class KVCacheManager(BaseResourceManager):
return (num_tokens + self.tokens_per_block - 1) // self.tokens_per_block
def get_num_available_tokens(self,
token_num_upper_bound: int,
max_num_draft_tokens: int = 0,
**kwargs) -> int:
return (self.get_num_free_blocks() * self.tokens_per_block -
self.num_extra_kv_tokens - max_num_draft_tokens)
return min(
token_num_upper_bound,
self.get_num_free_blocks() * self.tokens_per_block -
self.num_extra_kv_tokens - max_num_draft_tokens)
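
For reference, the v1 manager's updated availability computation reduces to the following free-standing sketch; the real method reads these values from the manager instance, and the parameter names here are only illustrative.

def v1_num_available_tokens(token_num_upper_bound: int, num_free_blocks: int,
                            tokens_per_block: int, num_extra_kv_tokens: int,
                            max_num_draft_tokens: int = 0) -> int:
    # The block-derived budget is unchanged; the new argument simply caps the
    # result at the caller-provided upper bound.
    budget = (num_free_blocks * tokens_per_block
              - num_extra_kv_tokens - max_num_draft_tokens)
    return min(token_num_upper_bound, budget)
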
def get_buffers(self,
layer_idx: int,
@ -1524,12 +1526,12 @@ class KVCacheManagerV2(BaseResourceManager):
max_beam_width: int = 1,
is_draft: bool = False,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
execution_stream: Optional[torch.cuda.Stream] = None,
**kwargs,
) -> None:
self.mapping = mapping
self.dtype = dtype
assert self.dtype != DataType.NVFP4, "NVFP4 is not supported for KVCacheManagerV2"
assert kv_connector_manager is None, "kv_connector_manager is not supported for KVCacheManagerV2"
assert max_beam_width == 1, "max_beam_width must be 1 for KVCacheManagerV2"
@ -1565,6 +1567,10 @@ class KVCacheManagerV2(BaseResourceManager):
assert self.event_buffer_max_size == 0, "event_buffer_max_size must be 0"
self._stream = execution_stream if execution_stream is not None else torch.cuda.Stream(
)
logger.info(f"[KVCacheManager] execution_stream: {self._stream}")
# Determine max_attention_window_vec
if kv_cache_config.max_attention_window is not None:
@ -1626,10 +1632,9 @@ class KVCacheManagerV2(BaseResourceManager):
if kv_cache_config.max_tokens is not None:
quota = int(
ceil_div(
kv_cache_config.max_tokens *
self.get_cache_bytes_per_token(),
kv_cache_config.max_util_for_resume))
math.ceil(kv_cache_config.max_tokens *
self.get_cache_bytes_per_token() /
kv_cache_config.max_util_for_resume))
if kv_cache_config.free_gpu_memory_fraction is not None:
logger.warning(
f"Both max_tokens and free_gpu_memory_fraction are set to {kv_cache_config.max_tokens} and {kv_cache_config.free_gpu_memory_fraction}, the smaller value will be used."
@ -1658,6 +1663,11 @@ class KVCacheManagerV2(BaseResourceManager):
buffer_type = [Role.KEY]
if kv_cache_type != CacheTypeCpp.SELFKONLY:
buffer_type.append(Role.VALUE)
if kv_cache_config.dtype == "nvfp4":
assert head_dim % 2 == 0, "head_dim must be divisible by 2 for nvfp4 kv cache"
buffer_type.append(Role.KEY_BLOCK_SCALE)
if kv_cache_type != CacheTypeCpp.SELFKONLY:
buffer_type.append(Role.VALUE_BLOCK_SCALE)
config = KVCacheManagerConfigPy(
tokens_per_block=tokens_per_block,
@ -1701,20 +1711,54 @@ class KVCacheManagerV2(BaseResourceManager):
device="cpu",
pin_memory=True)
if kv_cache_config.dtype == "nvfp4":
self.kv_cache_pool_pointers = torch.stack([
self.kv_cache_pool_pointers,
torch.tensor([[
self.impl.get_mem_pool_base_address(
self.impl.layer_grouping[pool_id][0],
Role.KEY_BLOCK_SCALE), 0
] for pool_id in range(self.num_pools)],
dtype=torch.int64,
device="cpu",
pin_memory=True)
],
dim=-1)
kv_cache_pool_mapping_list = []
for layer_id in typed_range(LayerId(self.num_local_layers)):
layer_group_id = self.impl.get_layer_group_id(layer_id)
if kv_cache_config.dtype != "nvfp4":
addr_offset = self.impl.get_mem_pool_base_address(
layer_id, Role.KEY) - int(
self.kv_cache_pool_pointers[layer_group_id][0])
else:
addr_offset = self.impl.get_mem_pool_base_address(
layer_id, Role.KEY) - int(
self.kv_cache_pool_pointers[layer_group_id][0][0])
block_scale_addr_offset = self.impl.get_mem_pool_base_address(
layer_id, Role.KEY_BLOCK_SCALE) - int(
self.kv_cache_pool_pointers[layer_group_id][0][1])
block_scale_offset = exact_div(
block_scale_addr_offset,
self.get_cache_bytes_per_token(
layer_id, Role.KEY_BLOCK_SCALE) * self.kv_factor *
self.tokens_per_block)
offset = exact_div(
self.impl.get_mem_pool_base_address(layer_id, Role.KEY) -
int(self.kv_cache_pool_pointers[layer_group_id][0]),
addr_offset,
self.get_cache_bytes_per_token(layer_id, Role.KEY) *
self.kv_factor * self.tokens_per_block)
if kv_cache_config.dtype == "nvfp4":
assert block_scale_offset == offset, "Block scale offset and offset should be the same"
kv_cache_pool_mapping_list.append([layer_group_id, offset])
self.kv_cache_pool_mapping = torch.tensor(kv_cache_pool_mapping_list,
dtype=torch.int32,
device="cpu",
pin_memory=True)
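
A compact sketch of the consistency check added for nvfp4 pools: the per-layer offset into the data pool and the offset into the block-scale pool, both expressed in blocks, must agree, because downstream code indexes both with the same page index. The helper below is hypothetical and only mirrors the arithmetic above.

def check_nvfp4_pool_offsets(key_base: int, key_layer_addr: int,
                             key_bytes_per_block: int,
                             scale_base: int, scale_layer_addr: int,
                             scale_bytes_per_block: int) -> int:
    data_offset, rem = divmod(key_layer_addr - key_base, key_bytes_per_block)
    assert rem == 0, "layer base must be block-aligned in the data pool"
    scale_offset, rem = divmod(scale_layer_addr - scale_base,
                               scale_bytes_per_block)
    assert rem == 0, "layer base must be block-aligned in the block-scale pool"
    assert data_offset == scale_offset, "Block scale offset and offset should be the same"
    return data_offset
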
# Pad max_blocks_per_seq to next multiple of 4 for copy_block_offsets kernel
self.max_blocks_per_seq = (max_seq_len + tokens_per_block -
1) // tokens_per_block
@ -1723,7 +1767,8 @@ class KVCacheManagerV2(BaseResourceManager):
self.kv_cache_map: dict[int, _KVCache] = {}
max_num_tokens = self.get_num_available_tokens()
max_num_tokens = self.get_num_available_tokens(
token_num_upper_bound=max_seq_len)
if max_seq_len > max_num_tokens:
logger.warning(
@ -1736,6 +1781,25 @@ class KVCacheManagerV2(BaseResourceManager):
# Plus 1 for cuda graph dummy request
self.index_mapper = IndexMapper(max_batch_size + 1, max_beam_width)
self.index_scales = torch.empty(self.num_pools,
dtype=torch.int32,
pin_memory=True,
device='cpu')
self.kv_offset = torch.empty(self.num_pools,
dtype=torch.int32,
pin_memory=True,
device='cpu')
for pool_id in range(self.num_pools):
layer_id = self.impl.layer_grouping[pool_id][0]
self.index_scales[pool_id] = self.impl.get_page_index_scale(
layer_id, Role.KEY)
if self.kv_cache_type != CacheTypeCpp.SELFKONLY:
self.kv_offset[pool_id] = exact_div(
self.impl.get_mem_pool_base_address(layer_id, Role.VALUE) -
self.impl.get_mem_pool_base_address(layer_id, Role.KEY),
self.impl.get_page_stride(layer_id, Role.KEY))
else:
self.kv_offset[pool_id] = 0
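
The kv_offset derivation above can be summarized as a small helper (hypothetical, since the real values come from the C++ impl): the V region of a pool starts kv_offset pages after the K region, measured in units of the K page stride, and collapses to zero for SELFKONLY caches.

def compute_kv_offset(value_base_addr: int, key_base_addr: int,
                      key_page_stride: int, self_k_only: bool) -> int:
    if self_k_only:
        return 0  # no separate V pages for a K-only cache
    offset, rem = divmod(value_base_addr - key_base_addr, key_page_stride)
    assert rem == 0, "V base must sit a whole number of K pages past the K base"
    return offset
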
self.host_kv_cache_block_offsets = torch.empty(
self.num_pools,
@ -1770,6 +1834,12 @@ class KVCacheManagerV2(BaseResourceManager):
assert kv_layout in ["NHD",
"HND"], f"Unsupported kv_layout: {kv_layout}"
element_per_container = 1
dtype = self.dtype
if dtype == DataType.NVFP4:
element_per_container = 2
dtype = torch.int8
if kv_layout == "NHD":
shape = [
self.impl.get_page_index_upper_bound(layer_offset, Role.KEY) //
@ -1777,7 +1847,7 @@ class KVCacheManagerV2(BaseResourceManager):
self.kv_factor,
self.tokens_per_block,
self.num_kv_heads_per_layer[layer_offset],
self.head_dim,
self.head_dim // element_per_container,
]
else:
shape = [
@ -1786,27 +1856,28 @@ class KVCacheManagerV2(BaseResourceManager):
self.kv_factor,
self.num_kv_heads_per_layer[layer_offset],
self.tokens_per_block,
self.head_dim,
self.head_dim // element_per_container,
]
return convert_to_torch_tensor(
TensorWrapper(
addr_key,
self.dtype,
shape,
))
return convert_to_torch_tensor(TensorWrapper(
addr_key,
dtype,
shape,
))
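
The container math behind the head_dim // element_per_container change, as a standalone sketch (the leading page-count dimension is omitted for brevity): two FP4 values share one int8 container, so the innermost dimension halves while the byte size per page stays the same.

def kv_page_shape(tokens_per_block: int, num_kv_heads: int, head_dim: int,
                  kv_factor: int, nvfp4: bool, kv_layout: str = "NHD") -> list:
    element_per_container = 2 if nvfp4 else 1  # two FP4 elements per int8 byte
    assert head_dim % element_per_container == 0
    inner = head_dim // element_per_container
    if kv_layout == "NHD":
        return [kv_factor, tokens_per_block, num_kv_heads, inner]
    return [kv_factor, num_kv_heads, tokens_per_block, inner]  # HND
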
def get_num_available_tokens(self,
*,
token_num_upper_bound: int,
batch_size: int = 1,
max_num_draft_tokens: int = 0) -> int:
if max_num_draft_tokens > 0:
raise ValueError(
"max_num_draft_tokens is not supported for KVCacheManagerV2")
return int(
self.impl.clamp_max_seq_len_for_mem(batch_size) *
self.kv_cache_manager_py_config.max_util_for_resume
) - self.num_extra_kv_tokens - max_num_draft_tokens
extra_tokens = self.num_extra_kv_tokens + max_num_draft_tokens
# token_num_upper_bound is the maximum number of tokens the caller may allocate from the KV cache manager.
# The extra KV and draft tokens are added to the upper bound before clamping and subtracted back afterwards, so they still fit in memory.
return self.impl.clamp_max_seq_len_for_mem(
batch_size, token_num_upper_bound + extra_tokens) - extra_tokens
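
A hedged model of the new computation: the memory clamp is queried with token_num_upper_bound plus the extra tokens, so that reserved KV tokens and any draft tokens still fit, and the extras are then subtracted from whatever the clamp returns.

def v2_num_available_tokens(clamp_max_seq_len_for_mem, batch_size: int,
                            token_num_upper_bound: int,
                            num_extra_kv_tokens: int,
                            max_num_draft_tokens: int = 0) -> int:
    # Note: the real KVCacheManagerV2 method rejects max_num_draft_tokens > 0.
    extra = num_extra_kv_tokens + max_num_draft_tokens
    clamped = clamp_max_seq_len_for_mem(batch_size,
                                        token_num_upper_bound + extra)
    return clamped - extra
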
def get_num_free_blocks(self) -> int:
# NOTE This method is used to get the number of blocks in the primary pool not the FREE blocks.
@ -1859,11 +1930,14 @@ class KVCacheManagerV2(BaseResourceManager):
chunk_size,
req.prompt_len - req.context_current_position)
success = kv_cache.resume(
torch.cuda.current_stream().cuda_stream)
success = kv_cache.resume(self._stream.cuda_stream)
assert success
kv_cache.resize(req.prompt_len)
success = kv_cache.resize(req.prompt_len)
if not success:
raise ValueError(
f"Failed to resize capacity of KV cache for request {req.py_request_id} to {req.prompt_len} tokens for context update"
)
if self.kv_connector_manager is not None:
block_ids = self.get_cache_indices(req)
@ -1872,7 +1946,11 @@ class KVCacheManagerV2(BaseResourceManager):
for req in generation_batch:
kv_cache = self.kv_cache_map[req.py_request_id]
kv_cache.resize(kv_cache.capacity + 1)
success = kv_cache.resize(kv_cache.capacity + 1)
if not success:
raise ValueError(
f"Failed to resize capacity of KV cache for request {req.py_request_id} to {kv_cache.capacity + 1} tokens for generation update"
)
if self.kv_connector_manager is not None:
self.kv_connector_manager.build_scheduler_output(
@ -1891,6 +1969,19 @@ class KVCacheManagerV2(BaseResourceManager):
return KVCacheStatus(allocated_bytes=self.impl.get_quota(GPU_LEVEL))
def get_block_ids_per_seq(self, request_ids: List[int]) -> torch.Tensor:
block_ids_per_seq = self.get_batch_cache_indices(request_ids)
block_ids_per_seq_tensors = [
torch.tensor([
i // self.num_local_layers if i != BAD_PAGE_INDEX else i
for i in sublist
],
dtype=torch.int) for sublist in block_ids_per_seq
]
padded_tensor = torch.nn.utils.rnn.pad_sequence(
block_ids_per_seq_tensors, batch_first=True, padding_value=0)
return padded_tensor
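
As a self-contained illustration of the new get_block_ids_per_seq (with BAD_PAGE_INDEX = -1 assumed, mirroring the kv_cache_manager_v2 sources): raw page indices are divided by the local layer count, unmapped pages pass through unchanged, and ragged sequences are right-padded with zeros.

import torch

BAD_PAGE_INDEX = -1  # assumed sentinel


def block_ids_per_seq(block_ids, num_local_layers: int) -> torch.Tensor:
    tensors = [
        torch.tensor([i // num_local_layers if i != BAD_PAGE_INDEX else i
                      for i in seq], dtype=torch.int)
        for seq in block_ids
    ]
    return torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True,
                                           padding_value=0)


# block_ids_per_seq([[0, 4, BAD_PAGE_INDEX], [8]], num_local_layers=4)
# -> tensor([[ 0,  1, -1],
#            [ 2,  0,  0]], dtype=torch.int32)
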
def add_dummy_requests(
self,
request_ids: List[int],
@ -1938,15 +2029,18 @@ class KVCacheManagerV2(BaseResourceManager):
kv_cache = self._create_kv_cache(req.py_request_id,
req.lora_task_id, input_tokens)
assert kv_cache.num_committed_tokens == 0
success = kv_cache.resume(
torch.cuda.current_stream().cuda_stream)
success = kv_cache.resume(self._stream.cuda_stream)
if not success:
for r in requests:
self.free_resources(r)
self.free_resources(req)
return None
kv_cache.stop_committing()
kv_cache.resize(token_num)
success = kv_cache.resize(token_num)
if not success:
raise ValueError(
f"Failed to resize capacity of KV cache for request {req.py_request_id} to {token_num} tokens for dummy request"
)
if is_gen:
req.state = LlmRequestState.GENERATION_IN_PROGRESS
@ -1999,10 +2093,17 @@ class KVCacheManagerV2(BaseResourceManager):
else:
div_factor = 1
return [
(np.asarray(self.kv_cache_map[req_id].get_page_indices(pool_id)) //
div_factor).tolist() for req_id in request_ids
]
res = []
for req_id in request_ids:
idx_tensor = torch.as_tensor(
self.kv_cache_map[req_id].get_base_page_indices(pool_id))
res.append((torch.where(
idx_tensor != BAD_PAGE_INDEX,
idx_tensor * self.index_scales[pool_id] // div_factor,
BAD_PAGE_INDEX)).tolist())
return res
def get_cache_bytes_per_token(
self,
@ -2020,10 +2121,10 @@ class KVCacheManagerV2(BaseResourceManager):
if data_role == Role.ALL:
kv_factor = self.kv_factor
elif data_role in [
Role.KEY, Role.VALUE, Role.KEY_BLOCK_QUANT,
Role.VALUE_BLOCK_QUANT
Role.KEY, Role.VALUE, Role.KEY_BLOCK_SCALE,
Role.VALUE_BLOCK_SCALE
]:
if data_role in [Role.KEY_BLOCK_QUANT, Role.VALUE_BLOCK_QUANT]:
if data_role in [Role.KEY_BLOCK_SCALE, Role.VALUE_BLOCK_SCALE]:
assert self.dtype == DataType.NVFP4, "NVFP4 is the only supported dtype for block quant data roles"
if data_role == Role.VALUE:
assert self.kv_cache_type != CacheTypeCpp.SELFKONLY, "SELFKONLY is the only supported cache type for value data role"
@ -2055,7 +2156,7 @@ class KVCacheManagerV2(BaseResourceManager):
scaling_factor_dtype=DataType.FP8,
)
if data_role in [Role.KEY_BLOCK_QUANT, Role.VALUE_BLOCK_QUANT]:
if data_role in [Role.KEY_BLOCK_SCALE, Role.VALUE_BLOCK_SCALE]:
return quant_size_per_token
return cache_size_bytes_per_token + quant_size_per_token
@ -2180,13 +2281,21 @@ class KVCacheManagerV2(BaseResourceManager):
context_current_position])
kv_cache.stop_committing()
else:
kv_cache.resize(None, req.context_current_position)
success = kv_cache.resize(None, req.context_current_position)
if not success:
raise ValueError(
f"Failed to resize history length of KV cache for request {req.py_request_id} to {req.context_current_position} tokens at context update"
)
for req in scheduled_batch.generation_requests:
if req.py_request_id not in self.kv_cache_map:
continue
kv_cache = self.kv_cache_map[req.py_request_id]
kv_cache.resize(None, req.max_beam_num_tokens - 1)
success = kv_cache.resize(None, req.max_beam_num_tokens - 1)
if not success:
raise ValueError(
f"Failed to resize history length of KV cache for request {req.py_request_id} to {req.max_beam_num_tokens - 1} tokens at generation update"
)
def copy_batch_block_offsets(self, dst_tensor: torch.Tensor,
request_ids: List[int], beam_width: int,
@ -2197,10 +2306,10 @@ class KVCacheManagerV2(BaseResourceManager):
beam_width)
assert copy_idx.shape[0] == num_seqs
copy_batch_block_offsets_to_device(
self.host_kv_cache_block_offsets, dst_tensor, copy_idx,
self.kv_cache_type == CacheTypeCpp.SELFKONLY,
torch.cuda.current_stream().cuda_stream)
copy_batch_block_offsets_to_device(self.host_kv_cache_block_offsets,
dst_tensor, copy_idx,
self.index_scales, self.kv_offset,
self._stream.cuda_stream)
def _create_kv_cache(self, request_id: int, lora_task_id: int | None,
input_tokens: Sequence[TokenIdExt] | None):
@ -2212,8 +2321,8 @@ class KVCacheManagerV2(BaseResourceManager):
for pool_idx in range(self.num_pools):
buffer: torch.Tensor = self.host_kv_cache_block_offsets[
pool_idx, index * self.max_beam_width + i, 0]
kv_cache.set_page_index_buf(i, pool_idx,
memoryview(buffer.numpy()))
kv_cache.set_base_page_index_buf(i, pool_idx,
memoryview(buffer.numpy()))
return kv_cache

View File

@ -265,4 +265,4 @@ class KVCacheManager:
def get_aggregated_pages(
self, buffers: Iterable[BufferSlice]
) -> Iterator[AggregatedPageDesc]: ...
def clamp_max_seq_len_for_mem(self, batch_size: int) -> int: ...
def clamp_max_seq_len_for_mem(self, batch_size: int, token_num_upper_bound: int) -> int: ...

View File

@ -303,7 +303,7 @@ class KVCacheManager:
)
# @TODO: need updating when dynamic resizing is supported.
def clamp_max_seq_len_for_mem(self, batch_size: int) -> int:
def clamp_max_seq_len_for_mem(self, batch_size: int, token_num_upper_bound: int) -> int:
"Get the max possible sequence length limited by the GPU memory pools."
assert batch_size > 0
tokens_per_block = self.tokens_per_block
@ -338,14 +338,13 @@ class KVCacheManager:
assert is_enough(1)
lb = 1
ub = lb
while is_enough(ub):
lb = ub
ub *= 2
ub = div_up(token_num_upper_bound, tokens_per_block)
if is_enough(ub):
return token_num_upper_bound
while lb < ub - 1:
mid = (lb + ub) // 2
if is_enough(mid):
lb = mid
else:
ub = mid
return lb * tokens_per_block
return min(lb * tokens_per_block, token_num_upper_bound)
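
A standalone sketch of the bounded search clamp_max_seq_len_for_mem now performs, where is_enough is any monotone predicate standing in for the real GPU memory-pool check: instead of doubling the upper bound without limit, it starts at div_up(token_num_upper_bound, tokens_per_block), and the result never exceeds token_num_upper_bound.

def clamp_max_seq_len(is_enough, tokens_per_block: int,
                      token_num_upper_bound: int) -> int:
    def div_up(a: int, b: int) -> int:
        return -(-a // b)

    assert is_enough(1)
    lb = 1
    ub = div_up(token_num_upper_bound, tokens_per_block)
    if is_enough(ub):
        return token_num_upper_bound
    while lb < ub - 1:
        mid = (lb + ub) // 2
        if is_enough(mid):
            lb = mid
        else:
            ub = mid
    return min(lb * tokens_per_block, token_num_upper_bound)


# e.g. with tokens_per_block=32 and a predicate that allows up to 10 blocks:
# clamp_max_seq_len(lambda n: n <= 10, 32, 1000) == 320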

View File

@ -23,7 +23,11 @@ from ._utils import ItemHolderWithSharedPool, PooledFactoryBase, _unwrap, div_up
def _is_prop_supported(prop: drv.CUmemAllocationProp) -> bool:
err, handle = drv.cuMemCreate(2 << 20, prop, 0)
if err == drv.CUresult.CUDA_ERROR_NOT_PERMITTED or err == drv.CUresult.CUDA_ERROR_NOT_SUPPORTED:
if (
err == drv.CUresult.CUDA_ERROR_NOT_PERMITTED
or err == drv.CUresult.CUDA_ERROR_NOT_SUPPORTED
or err == drv.CUresult.CUDA_ERROR_INVALID_DEVICE
):
return False
elif err == drv.CUresult.CUDA_SUCCESS:
_unwrap(drv.cuMemRelease(handle))
@ -52,6 +56,10 @@ class NativePhysMemAllocator:
prop.requestedHandleTypes = drv.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
if not _is_prop_supported(prop):
prop.requestedHandleTypes = drv.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_NONE
if not _is_prop_supported(prop):
prop.allocFlags.gpuDirectRDMACapable = 0
if not _is_prop_supported(prop):
raise ValueError("Failed to create physical memory allocation property")
self._prop = prop
self._outstanding_handles = set()
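
The widened fallback chain above, summarized as a hedged standalone sketch: probe stands in for _is_prop_supported (and ultimately cuMemCreate), string values stand in for the CUmemAllocationHandleType enum, and CUDA_ERROR_INVALID_DEVICE now also counts as "unsupported" rather than a hard error. FABRIC handles are tried first, then no handle type, and finally the gpuDirectRDMACapable flag is dropped before giving up.

from types import SimpleNamespace


def pick_alloc_prop(probe):
    """probe(prop) -> bool mimics _is_prop_supported; prop is mutated in place."""
    prop = SimpleNamespace(requestedHandleTypes="FABRIC",
                           allocFlags=SimpleNamespace(gpuDirectRDMACapable=1))
    if probe(prop):
        return prop
    prop.requestedHandleTypes = "NONE"        # CU_MEM_HANDLE_TYPE_NONE in the real code
    if probe(prop):
        return prop
    prop.allocFlags.gpuDirectRDMACapable = 0  # last resort: give up RDMA capability
    if probe(prop):
        return prop
    raise ValueError("Failed to create physical memory allocation property")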

View File

@ -335,7 +335,8 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
@skip_pre_blackwell
@parametrize_with_ids("torch_compile", [False, True])
@parametrize_with_ids("attn_backend", ["TRTLLM"])
def test_nvfp4_kv(self, attn_backend, torch_compile):
@parametrize_with_ids("v2_kv_cache", [True, False])
def test_nvfp4_kv(self, attn_backend, torch_compile, v2_kv_cache):
torch_compile_config = _get_default_torch_compile_config(torch_compile)
pytorch_config = dict(
torch_compile_config=torch_compile_config,
@ -344,7 +345,8 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
attn_backend=attn_backend,
disable_overlap_scheduler=torch_compile,
)
pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="nvfp4")
pytorch_config["kv_cache_config"] = KvCacheConfig(
dtype="nvfp4", use_kv_cache_manager_v2=v2_kv_cache)
with LLM(f"{llm_models_root()}/Llama-3_1-8B-Instruct_fp8_kv_nvfp4",
**pytorch_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8

View File

@ -17,8 +17,9 @@ l0_b200:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4_streaming[stream_interval_4]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4_streaming[stream_interval_64]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[v2_kv_cache=False-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[v2_kv_cache=False-attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[v2_kv_cache=True-attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=True]

View File

@ -5,7 +5,6 @@ from typing import List
import pytest
import torch
from utils.util import getSMVersion
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.interface import (
@ -368,10 +367,6 @@ def test_attention_mla(scenario: Scenario, context_sequence_lengths: List[int],
num_generation_steps: List[int], v2_kv_cache: bool):
"""Test MLA computation for both context and generation phases"""
if v2_kv_cache and getSMVersion() != 100:
pytest.skip(
"v2_kv_cache is only supported for MLA on Blackwell architectures")
num_heads = scenario.num_heads
num_kv_heads = scenario.num_kv_heads
q_lora_rank = scenario.q_lora_rank