Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Synced 2026-02-18 16:55:08 +08:00
[None][feat] Use new index api, add block scale support, fix max_seq_len estimation, add flash mla support (#11334)
Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
parent 59b6bee7e6
commit 361ff36784

@@ -186,10 +186,10 @@ CUresult copyDeviceToDevice(std::vector<MMTask> const& tasks, ssize_t numBytes,
// dst_tensor[:, :num_seqs, 0] = src_tensor[:, copy_idx]
// dst_tensor[:, :num_seqs, 1] = dst_tensor[:, :num_seqs, 0] + 1
template <bool COPY_V_IDX = true>
__global__ void copyBatchBlockOffsetsToDeviceKernel(SizeType32 const* __restrict__ srcPtr,
SizeType32* __restrict__ dstPtr, SizeType32 const srcMaxNumSequences, SizeType32 const dstMaxNumSequences,
SizeType32 numBlocksPerSeq, SizeType32 const* __restrict__ copyIndex)
SizeType32 numBlocksPerSeq, SizeType32 const* __restrict__ copyIndex, SizeType32 const* __restrict__ indexScales,
SizeType32 const* __restrict__ kvOffset)
{
constexpr uint32_t kvFactor = 2;
constexpr auto elemPerAccess = sizeof(PackedInt) / sizeof(SizeType32);

@@ -224,19 +224,12 @@ __global__ void copyBatchBlockOffsetsToDeviceKernel(SizeType32 const* __restrict
asm volatile("cp.async.wait_group %0;\n" ::"n"(copyBlocknbBufs - 1) : "memory");
if (srcIdx < srcIdxEnd)
{
dstK = src;
if (COPY_V_IDX)
{
dstV = src;
}
else
{
#pragma unroll
for (uint32_t j = 0; j < elemPerAccess; j++)
{
auto const val = src.unpacked[j];
dstV.unpacked[j] = (val == BAD_PAGE_INDEX) ? val : (val + 1);
}
for (uint32_t j = 0; j < elemPerAccess; j++)
{
auto const val = src.unpacked[j];
dstK.unpacked[j] = (val == BAD_PAGE_INDEX) ? val : (indexScales[poolIdx] * val);
dstV.unpacked[j] = (val == BAD_PAGE_INDEX) ? val : (indexScales[poolIdx] * val + kvOffset[poolIdx]);
}
}
}

@@ -256,8 +249,8 @@ __global__ void copyBatchBlockOffsetsToDeviceKernel(SizeType32 const* __restrict
}

// Host-side launcher
void copyBatchBlockOffsetsToDevice(
ITensor const& input, ITensor& output, ITensor const& copyIndex, bool copyVIdx, CUstream stream) noexcept
void copyBatchBlockOffsetsToDevice(ITensor const& input, ITensor& output, ITensor const& copyIndex,
ITensor const& indexScales, ITensor const& kvOffset, CUstream stream) noexcept
{
using namespace tensorrt_llm::runtime;

@@ -265,6 +258,8 @@ void copyBatchBlockOffsetsToDevice(
auto* dstPtr = bufferCast<tk::KVCacheIndex::UnderlyingType>(
output); // [numPools, maxNumSequences, kvFactor, numBlocksPerSeq]
auto const* copyIndexPtr = bufferCast<SizeType32 const>(copyIndex);
auto const* indexScalesPtr = bufferCast<SizeType32 const>(indexScales);
auto const* kvOffsetPtr = bufferCast<SizeType32 const>(kvOffset);
auto const& srcShape = input.getShape();
auto const& dstShape = output.getShape();
auto const& copyIndexShape = copyIndex.getShape();

@@ -290,16 +285,8 @@ void copyBatchBlockOffsetsToDevice(
dim3 gridDim(numPools, numSeqs, 1);
dim3 blockDim(copyBlockCtaSize);

if (copyVIdx)
{
copyBatchBlockOffsetsToDeviceKernel<true><<<gridDim, blockDim, 0, stream>>>(
srcPtr, dstPtr, srcMaxNumSequences, dstMaxNumSequences, numBlocksPerSeq, copyIndexPtr);
}
else
{
copyBatchBlockOffsetsToDeviceKernel<false><<<gridDim, blockDim, 0, stream>>>(
srcPtr, dstPtr, srcMaxNumSequences, dstMaxNumSequences, numBlocksPerSeq, copyIndexPtr);
}
copyBatchBlockOffsetsToDeviceKernel<<<gridDim, blockDim, 0, stream>>>(srcPtr, dstPtr, srcMaxNumSequences,
dstMaxNumSequences, numBlocksPerSeq, copyIndexPtr, indexScalesPtr, kvOffsetPtr);
}

} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
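
As a reading aid, here is a minimal Python sketch of the per-pool index math the updated kernel applies; the helper and the BAD_PAGE_INDEX value are illustrative assumptions, not part of the commit. K offsets become index_scale * idx and V offsets become index_scale * idx + kv_offset, replacing the old "V = K + 1" convention, while bad pages pass through untouched.

BAD_PAGE_INDEX = -1  # assumed sentinel value, for illustration only

def transform_block_offsets(base_indices, index_scale, kv_offset):
    # Mirrors the kernel loop above: scale base page indices for K and
    # add the per-pool V offset for V; bad pages are copied as-is.
    k = [i if i == BAD_PAGE_INDEX else index_scale * i for i in base_indices]
    v = [i if i == BAD_PAGE_INDEX else index_scale * i + kv_offset for i in base_indices]
    return k, v

# With index_scale=4 and kv_offset=2, pages [0, 3, BAD_PAGE_INDEX]
# map to K [0, 12, -1] and V [2, 14, -1].
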
@@ -95,7 +95,7 @@ CUresult copyDeviceToHost(
CUresult copyDeviceToDevice(
std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, CUstream stream) noexcept;

void copyBatchBlockOffsetsToDevice(
ITensor const& input, ITensor& output, ITensor const& copyIndex, bool copyVIdx, CUstream stream) noexcept;
void copyBatchBlockOffsetsToDevice(ITensor const& input, ITensor& output, ITensor const& copyIndex,
ITensor const& indexScales, ITensor const& kvOffset, CUstream stream) noexcept;

} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2

@@ -131,19 +131,24 @@ void KVCacheManagerV2UtilsBindings::initBindings(nb::module_& module)
module.def(
"copy_batch_block_offsets_to_device",
[](at::Tensor input, at::Tensor output, at::Tensor copyIndex, bool copyVIdx, uintptr_t stream)
[](at::Tensor input, at::Tensor output, at::Tensor copyIndex, at::Tensor indexScales, at::Tensor kvOffset,
uintptr_t stream)
{
auto _input = from_torch(input);
auto _output = from_torch(output);
auto _copyIndex = from_torch(copyIndex);
auto _indexScales = from_torch(indexScales);
auto _kvOffset = from_torch(kvOffset);
TLLM_CHECK_WITH_INFO(_input.has_value(), "Invalid input tensor.");
TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor.");
TLLM_CHECK_WITH_INFO(_copyIndex.has_value(), "Invalid copy index tensor.");
copyBatchBlockOffsetsToDevice(*(_input.value()), *(_output.value()), *(_copyIndex.value()), copyVIdx,
reinterpret_cast<CUstream>(stream));
TLLM_CHECK_WITH_INFO(_indexScales.has_value(), "Invalid index scales tensor.");
TLLM_CHECK_WITH_INFO(_kvOffset.has_value(), "Invalid kv offset tensor.");
copyBatchBlockOffsetsToDevice(*(_input.value()), *(_output.value()), *(_copyIndex.value()),
*(_indexScales.value()), *(_kvOffset.value()), reinterpret_cast<CUstream>(stream));
},
nb::arg("input"), nb::arg("output"), nb::arg("copy_index"), nb::arg("copy_v_idx"), nb::arg("stream"),
nb::call_guard<nb::gil_scoped_release>(), "Copy batch block indices to device");
nb::arg("input"), nb::arg("output"), nb::arg("copy_index"), nb::arg("index_scales"), nb::arg("kv_offset"),
nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(), "Copy batch block indices to device");
}

} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
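
A hedged usage sketch of the rebound entry point from Python. The tensor sizes are illustrative placeholders, and copy_batch_block_offsets_to_device is assumed to be imported from the extension module registered above (resource_manager.py calls it under this name later in the diff); the argument order follows the nb::arg list.

import torch

# Illustrative sizes and placeholder tensors; real shapes come from the KV cache manager.
num_pools, max_num_seqs, kv_factor, blocks_per_seq = 1, 8, 2, 16
host_block_offsets = torch.zeros(num_pools, max_num_seqs, kv_factor, blocks_per_seq,
                                 dtype=torch.int32, pin_memory=True)
device_block_offsets = host_block_offsets.to("cuda")
copy_index = torch.arange(max_num_seqs, dtype=torch.int32)
index_scales = torch.ones(num_pools, dtype=torch.int32, pin_memory=True)
kv_offset = torch.zeros(num_pools, dtype=torch.int32, pin_memory=True)

# Assumed import of the binding shown above.
copy_batch_block_offsets_to_device(host_block_offsets, device_block_offsets, copy_index,
                                   index_scales, kv_offset,
                                   torch.cuda.current_stream().cuda_stream)
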
@@ -676,10 +676,11 @@ class PyTorchModelEngine(ModelEngine):
reverse: bool = False):
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
curr_max_num_tokens = min(
kv_cache_manager.get_num_available_tokens(
max_num_draft_tokens=self.original_max_draft_len),
self.max_num_tokens, self.batch_size * (self.max_seq_len - 1))
token_num_upper_bound = min(self.max_num_tokens,
self.batch_size * (self.max_seq_len - 1))
curr_max_num_tokens = kv_cache_manager.get_num_available_tokens(
token_num_upper_bound=token_num_upper_bound,
max_num_draft_tokens=self.original_max_draft_len)
max_batch_size = min(
self.batch_size,
curr_max_num_tokens // (1 + self.runtime_draft_len))

@@ -728,10 +729,11 @@ class PyTorchModelEngine(ModelEngine):
logger.info("Running autotuner warmup...")
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
curr_max_num_tokens = min(
kv_cache_manager.get_num_available_tokens(
max_num_draft_tokens=self.original_max_draft_len),
self.max_num_tokens, self.batch_size * (self.max_seq_len - 1))
token_num_upper_bound = min(self.max_num_tokens,
self.batch_size * (self.max_seq_len - 1))
curr_max_num_tokens = kv_cache_manager.get_num_available_tokens(
token_num_upper_bound=token_num_upper_bound,
max_num_draft_tokens=self.original_max_draft_len)

cache_path = os.environ.get("TLLM_AUTOTUNER_CACHE_PATH", None)
with self.no_cuda_graph(), autotune(cache_path=cache_path):

@@ -962,6 +964,7 @@ class PyTorchModelEngine(ModelEngine):
ResourceManagerType.SPEC_RESOURCE_MANAGER)

available_tokens = kv_cache_manager.get_num_available_tokens(
token_num_upper_bound=num_tokens,
max_num_draft_tokens=self.runtime_draft_len)
available_blocks = kv_cache_manager.get_num_free_blocks()
if num_tokens > self.max_num_tokens or num_tokens > available_tokens:

@@ -1104,19 +1107,24 @@ class PyTorchModelEngine(ModelEngine):
if requests is None:
return None

available_tokens = kv_cache_manager.get_num_available_tokens(
batch_size=batch_size, max_num_draft_tokens=draft_len)

# Also consider draft KV cache capacity when it exists
if draft_kv_cache_manager is not None:
draft_available_tokens = draft_kv_cache_manager.get_num_available_tokens(
batch_size=batch_size, max_num_draft_tokens=draft_len)
available_tokens = min(available_tokens, draft_available_tokens)

# Add one dummy request with the maximum possible sequence length.
max_seq_len = min(
self.max_seq_len if max_seq_len is None else max_seq_len,
kv_cache_manager.max_seq_len)

available_tokens = kv_cache_manager.get_num_available_tokens(
token_num_upper_bound=max_seq_len,
batch_size=batch_size,
max_num_draft_tokens=draft_len)

# Also consider draft KV cache capacity when it exists
if draft_kv_cache_manager is not None:
draft_available_tokens = draft_kv_cache_manager.get_num_available_tokens(
batch_size=batch_size,
token_num_upper_bound=max_seq_len,
max_num_draft_tokens=draft_len)
available_tokens = min(available_tokens, draft_available_tokens)

token_num = max(
1,
min(

@@ -7,7 +7,6 @@ from collections import OrderedDict, defaultdict, deque
from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence,
Set, Tuple, Union)

import numpy as np
import torch
from mpi4py import MPI

@@ -23,7 +22,6 @@ from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PeftCacheConfig
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
from tensorrt_llm.math_utils import ceil_div
from tensorrt_llm.runtime import ModelConfig as ModelConfigPython
from tensorrt_llm.runtime.kv_cache_manager_v2 import (DEFAULT_BEAM_INDEX,
AttentionLayerConfig,

@@ -36,7 +34,8 @@ from tensorrt_llm.runtime.kv_cache_manager_v2 import \
KVCacheManagerConfig as KVCacheManagerConfigPy
from tensorrt_llm.runtime.kv_cache_manager_v2 import (LayerId, TokenIdExt,
_KVCache)
from tensorrt_llm.runtime.kv_cache_manager_v2._common import GPU_LEVEL
from tensorrt_llm.runtime.kv_cache_manager_v2._common import (BAD_PAGE_INDEX,
GPU_LEVEL)
from tensorrt_llm.runtime.kv_cache_manager_v2._config import DataRole
from tensorrt_llm.runtime.kv_cache_manager_v2._utils import (exact_div,
typed_range)

@@ -82,8 +81,8 @@ class ResourceManagerType(enum.Enum):
class Role:
KEY = DataRole("key")
VALUE = DataRole("value")
KEY_BLOCK_QUANT = DataRole("key_block_quant")
VALUE_BLOCK_QUANT = DataRole("value_block_quant")
KEY_BLOCK_SCALE = DataRole("key_block_scale")
VALUE_BLOCK_SCALE = DataRole("value_block_scale")
ALL = DataRole("all")

@@ -1009,10 +1008,13 @@ class KVCacheManager(BaseResourceManager):
return (num_tokens + self.tokens_per_block - 1) // self.tokens_per_block

def get_num_available_tokens(self,
token_num_upper_bound: int,
max_num_draft_tokens: int = 0,
**kwargs) -> int:
return (self.get_num_free_blocks() * self.tokens_per_block -
self.num_extra_kv_tokens - max_num_draft_tokens)
return min(
token_num_upper_bound,
self.get_num_free_blocks() * self.tokens_per_block -
self.num_extra_kv_tokens - max_num_draft_tokens)

def get_buffers(self,
layer_idx: int,
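
For reference, a standalone restatement of the clamped estimate introduced above for the V1 manager; the function name and arguments are illustrative, since the real method reads these values off the manager instance.

def available_tokens(token_num_upper_bound: int,
                     num_free_blocks: int,
                     tokens_per_block: int,
                     num_extra_kv_tokens: int = 0,
                     max_num_draft_tokens: int = 0) -> int:
    # Raw capacity from free blocks, minus reserved extra/draft tokens,
    # now capped by the caller-supplied upper bound instead of returned unbounded.
    raw = num_free_blocks * tokens_per_block - num_extra_kv_tokens - max_num_draft_tokens
    return min(token_num_upper_bound, raw)

# 10 free blocks of 32 tokens with 8 draft tokens -> 312 tokens,
# but a caller whose budget is 256 tokens now gets 256.
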
@@ -1524,12 +1526,12 @@ class KVCacheManagerV2(BaseResourceManager):
max_beam_width: int = 1,
is_draft: bool = False,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
execution_stream: Optional[torch.cuda.Stream] = None,
**kwargs,
) -> None:
self.mapping = mapping
self.dtype = dtype

assert self.dtype != DataType.NVFP4, "NVFP4 is not supported for KVCacheManagerV2"
assert kv_connector_manager is None, "kv_connector_manager is not supported for KVCacheManagerV2"
assert max_beam_width == 1, "max_beam_width must be 1 for KVCacheManagerV2"

@@ -1565,6 +1567,10 @@ class KVCacheManagerV2(BaseResourceManager):
assert self.event_buffer_max_size == 0, "event_buffer_max_size must be 0"

self._stream = execution_stream if execution_stream is not None else torch.cuda.Stream(
)
logger.info(f"[KVCacheManager] execution_stream: {self._stream}")

# Determine max_attention_window_vec
if kv_cache_config.max_attention_window is not None:

@@ -1626,10 +1632,9 @@ class KVCacheManagerV2(BaseResourceManager):
if kv_cache_config.max_tokens is not None:
quota = int(
ceil_div(
kv_cache_config.max_tokens *
self.get_cache_bytes_per_token(),
kv_cache_config.max_util_for_resume))
math.ceil(kv_cache_config.max_tokens *
self.get_cache_bytes_per_token() /
kv_cache_config.max_util_for_resume))
if kv_cache_config.free_gpu_memory_fraction is not None:
logger.warning(
f"Both max_tokens and free_gpu_memory_fraction are set to {kv_cache_config.max_tokens} and {kv_cache_config.free_gpu_memory_fraction}, the smaller value will be used."

@@ -1658,6 +1663,11 @@ class KVCacheManagerV2(BaseResourceManager):
buffer_type = [Role.KEY]
if kv_cache_type != CacheTypeCpp.SELFKONLY:
buffer_type.append(Role.VALUE)
if kv_cache_config.dtype == "nvfp4":
assert head_dim % 2 == 0, "head_dim must be divisible by 2 for nvfp4 kv cache"
buffer_type.append(Role.KEY_BLOCK_SCALE)
if kv_cache_type != CacheTypeCpp.SELFKONLY:
buffer_type.append(Role.VALUE_BLOCK_SCALE)

config = KVCacheManagerConfigPy(
tokens_per_block=tokens_per_block,

@@ -1701,20 +1711,54 @@ class KVCacheManagerV2(BaseResourceManager):
device="cpu",
pin_memory=True)

if kv_cache_config.dtype == "nvfp4":
self.kv_cache_pool_pointers = torch.stack([
self.kv_cache_pool_pointers,
torch.tensor([[
self.impl.get_mem_pool_base_address(
self.impl.layer_grouping[pool_id][0],
Role.KEY_BLOCK_SCALE), 0
] for pool_id in range(self.num_pools)],
dtype=torch.int64,
device="cpu",
pin_memory=True)
],
dim=-1)

kv_cache_pool_mapping_list = []
for layer_id in typed_range(LayerId(self.num_local_layers)):
layer_group_id = self.impl.get_layer_group_id(layer_id)
if kv_cache_config.dtype != "nvfp4":
addr_offset = self.impl.get_mem_pool_base_address(
layer_id, Role.KEY) - int(
self.kv_cache_pool_pointers[layer_group_id][0])
else:
addr_offset = self.impl.get_mem_pool_base_address(
layer_id, Role.KEY) - int(
self.kv_cache_pool_pointers[layer_group_id][0][0])
block_scale_addr_offset = self.impl.get_mem_pool_base_address(
layer_id, Role.KEY_BLOCK_SCALE) - int(
self.kv_cache_pool_pointers[layer_group_id][0][1])
block_scale_offset = exact_div(
block_scale_addr_offset,
self.get_cache_bytes_per_token(
layer_id, Role.KEY_BLOCK_SCALE) * self.kv_factor *
self.tokens_per_block)
offset = exact_div(
self.impl.get_mem_pool_base_address(layer_id, Role.KEY) -
int(self.kv_cache_pool_pointers[layer_group_id][0]),
addr_offset,
self.get_cache_bytes_per_token(layer_id, Role.KEY) *
self.kv_factor * self.tokens_per_block)

if kv_cache_config.dtype == "nvfp4":
assert block_scale_offset == offset, "Block scale offset and offset should be the same"

kv_cache_pool_mapping_list.append([layer_group_id, offset])

self.kv_cache_pool_mapping = torch.tensor(kv_cache_pool_mapping_list,
dtype=torch.int32,
device="cpu",
pin_memory=True)

# Pad max_blocks_per_seq to next multiple of 4 for copy_block_offsets kernel
self.max_blocks_per_seq = (max_seq_len + tokens_per_block -
1) // tokens_per_block

@@ -1723,7 +1767,8 @@ class KVCacheManagerV2(BaseResourceManager):
self.kv_cache_map: dict[int, _KVCache] = {}

max_num_tokens = self.get_num_available_tokens()
max_num_tokens = self.get_num_available_tokens(
token_num_upper_bound=max_seq_len)

if max_seq_len > max_num_tokens:
logger.warning(

@@ -1736,6 +1781,25 @@ class KVCacheManagerV2(BaseResourceManager):
# Plus 1 for cuda graph dummy request
self.index_mapper = IndexMapper(max_batch_size + 1, max_beam_width)
self.index_scales = torch.empty(self.num_pools,
dtype=torch.int32,
pin_memory=True,
device='cpu')
self.kv_offset = torch.empty(self.num_pools,
dtype=torch.int32,
pin_memory=True,
device='cpu')
for pool_id in range(self.num_pools):
layer_id = self.impl.layer_grouping[pool_id][0]
self.index_scales[pool_id] = self.impl.get_page_index_scale(
layer_id, Role.KEY)
if self.kv_cache_type != CacheTypeCpp.SELFKONLY:
self.kv_offset[pool_id] = exact_div(
self.impl.get_mem_pool_base_address(layer_id, Role.VALUE) -
self.impl.get_mem_pool_base_address(layer_id, Role.KEY),
self.impl.get_page_stride(layer_id, Role.KEY))
else:
self.kv_offset[pool_id] = 0

self.host_kv_cache_block_offsets = torch.empty(
self.num_pools,

@@ -1770,6 +1834,12 @@ class KVCacheManagerV2(BaseResourceManager):
assert kv_layout in ["NHD",
"HND"], f"Unsupported kv_layout: {kv_layout}"

element_per_container = 1
dtype = self.dtype
if dtype == DataType.NVFP4:
element_per_container = 2
dtype = torch.int8

if kv_layout == "NHD":
shape = [
self.impl.get_page_index_upper_bound(layer_offset, Role.KEY) //

@@ -1777,7 +1847,7 @@ class KVCacheManagerV2(BaseResourceManager):
self.kv_factor,
self.tokens_per_block,
self.num_kv_heads_per_layer[layer_offset],
self.head_dim,
self.head_dim // element_per_container,
]
else:
shape = [

@@ -1786,27 +1856,28 @@ class KVCacheManagerV2(BaseResourceManager):
self.kv_factor,
self.num_kv_heads_per_layer[layer_offset],
self.tokens_per_block,
self.head_dim,
self.head_dim // element_per_container,
]

return convert_to_torch_tensor(
TensorWrapper(
addr_key,
self.dtype,
shape,
))
return convert_to_torch_tensor(TensorWrapper(
addr_key,
dtype,
shape,
))

def get_num_available_tokens(self,
*,
token_num_upper_bound: int,
batch_size: int = 1,
max_num_draft_tokens: int = 0) -> int:
if max_num_draft_tokens > 0:
raise ValueError(
"max_num_draft_tokens is not supported for KVCacheManagerV2")
return int(
self.impl.clamp_max_seq_len_for_mem(batch_size) *
self.kv_cache_manager_py_config.max_util_for_resume
) - self.num_extra_kv_tokens - max_num_draft_tokens
extra_tokens = self.num_extra_kv_tokens + max_num_draft_tokens
# Token num upper bound is the maximum number of tokens that can be allocated in the kv cache manager.
# We need to add extra tokens to the token num upper bound to account for the extra tokens.
return self.impl.clamp_max_seq_len_for_mem(
batch_size, token_num_upper_bound + extra_tokens) - extra_tokens

def get_num_free_blocks(self) -> int:
# NOTE This method is used to get the number of blocks in the primary pool not the FREE blocks.
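
The V2 manager expresses the same clamping through clamp_max_seq_len_for_mem; a hedged restatement follows, with impl standing in for the underlying manager shown further down in this diff.

def v2_available_tokens(impl, batch_size: int, token_num_upper_bound: int,
                        num_extra_kv_tokens: int, max_num_draft_tokens: int = 0) -> int:
    # Extra and draft tokens are added to the upper bound before clamping and
    # subtracted again afterwards: they consume memory but not the caller's token budget.
    extra = num_extra_kv_tokens + max_num_draft_tokens
    return impl.clamp_max_seq_len_for_mem(batch_size, token_num_upper_bound + extra) - extra
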
@@ -1859,11 +1930,14 @@ class KVCacheManagerV2(BaseResourceManager):
chunk_size,
req.prompt_len - req.context_current_position)

success = kv_cache.resume(
torch.cuda.current_stream().cuda_stream)
success = kv_cache.resume(self._stream.cuda_stream)
assert success

kv_cache.resize(req.prompt_len)
success = kv_cache.resize(req.prompt_len)
if not success:
raise ValueError(
f"Failed to resize capacity of KV cache for request {req.py_request_id} to {req.prompt_len} tokens for context update"
)

if self.kv_connector_manager is not None:
block_ids = self.get_cache_indices(req)

@@ -1872,7 +1946,11 @@ class KVCacheManagerV2(BaseResourceManager):
for req in generation_batch:
kv_cache = self.kv_cache_map[req.py_request_id]
kv_cache.resize(kv_cache.capacity + 1)
success = kv_cache.resize(kv_cache.capacity + 1)
if not success:
raise ValueError(
f"Failed to resize capacity of KV cache for request {req.py_request_id} to {kv_cache.capacity + 1} tokens for generation update"
)

if self.kv_connector_manager is not None:
self.kv_connector_manager.build_scheduler_output(

@@ -1891,6 +1969,19 @@ class KVCacheManagerV2(BaseResourceManager):
return KVCacheStatus(allocated_bytes=self.impl.get_quota(GPU_LEVEL))

def get_block_ids_per_seq(self, request_ids: List[int]) -> torch.Tensor:
block_ids_per_seq = self.get_batch_cache_indices(request_ids)
block_ids_per_seq_tensors = [
torch.tensor([
i // self.num_local_layers if i != BAD_PAGE_INDEX else i
for i in sublist
],
dtype=torch.int) for sublist in block_ids_per_seq
]
padded_tensor = torch.nn.utils.rnn.pad_sequence(
block_ids_per_seq_tensors, batch_first=True, padding_value=0)
return padded_tensor

def add_dummy_requests(
self,
request_ids: List[int],

@@ -1938,15 +2029,18 @@ class KVCacheManagerV2(BaseResourceManager):
kv_cache = self._create_kv_cache(req.py_request_id,
req.lora_task_id, input_tokens)
assert kv_cache.num_committed_tokens == 0
success = kv_cache.resume(
torch.cuda.current_stream().cuda_stream)
success = kv_cache.resume(self._stream.cuda_stream)
if not success:
for r in requests:
self.free_resources(r)
self.free_resources(req)
return None
kv_cache.stop_committing()
kv_cache.resize(token_num)
success = kv_cache.resize(token_num)
if not success:
raise ValueError(
f"Failed to resize capacity of KV cache for request {req.py_request_id} to {token_num} tokens for dummy request"
)

if is_gen:
req.state = LlmRequestState.GENERATION_IN_PROGRESS

@@ -1999,10 +2093,17 @@ class KVCacheManagerV2(BaseResourceManager):
else:
div_factor = 1

return [
(np.asarray(self.kv_cache_map[req_id].get_page_indices(pool_id)) //
div_factor).tolist() for req_id in request_ids
]
res = []

for req_id in request_ids:
idx_tensor = torch.as_tensor(
self.kv_cache_map[req_id].get_base_page_indices(pool_id))
res.append((torch.where(
idx_tensor != BAD_PAGE_INDEX,
idx_tensor * self.index_scales[pool_id] // div_factor,
BAD_PAGE_INDEX)).tolist())

return res

def get_cache_bytes_per_token(
self,

@@ -2020,10 +2121,10 @@ class KVCacheManagerV2(BaseResourceManager):
if data_role == Role.ALL:
kv_factor = self.kv_factor
elif data_role in [
Role.KEY, Role.VALUE, Role.KEY_BLOCK_QUANT,
Role.VALUE_BLOCK_QUANT
Role.KEY, Role.VALUE, Role.KEY_BLOCK_SCALE,
Role.VALUE_BLOCK_SCALE
]:
if data_role in [Role.KEY_BLOCK_QUANT, Role.VALUE_BLOCK_QUANT]:
if data_role in [Role.KEY_BLOCK_SCALE, Role.VALUE_BLOCK_SCALE]:
assert self.dtype == DataType.NVFP4, "NVFP4 is the only supported dtype for block quant data roles"
if data_role == Role.VALUE:
assert self.kv_cache_type != CacheTypeCpp.SELFKONLY, "SELFKONLY is the only supported cache type for value data role"

@@ -2055,7 +2156,7 @@ class KVCacheManagerV2(BaseResourceManager):
scaling_factor_dtype=DataType.FP8,
)

if data_role in [Role.KEY_BLOCK_QUANT, Role.VALUE_BLOCK_QUANT]:
if data_role in [Role.KEY_BLOCK_SCALE, Role.VALUE_BLOCK_SCALE]:
return quant_size_per_token

return cache_size_bytes_per_token + quant_size_per_token

@@ -2180,13 +2281,21 @@ class KVCacheManagerV2(BaseResourceManager):
context_current_position])
kv_cache.stop_committing()
else:
kv_cache.resize(None, req.context_current_position)
success = kv_cache.resize(None, req.context_current_position)
if not success:
raise ValueError(
f"Failed to resize history length of KV cache for request {req.py_request_id} to {req.context_current_position} tokens at context update"
)

for req in scheduled_batch.generation_requests:
if req.py_request_id not in self.kv_cache_map:
continue
kv_cache = self.kv_cache_map[req.py_request_id]
kv_cache.resize(None, req.max_beam_num_tokens - 1)
success = kv_cache.resize(None, req.max_beam_num_tokens - 1)
if not success:
raise ValueError(
f"Failed to resize history length of KV cache for request {req.py_request_id} to {req.max_beam_num_tokens - 1} tokens at generation update"
)

def copy_batch_block_offsets(self, dst_tensor: torch.Tensor,
request_ids: List[int], beam_width: int,

@@ -2197,10 +2306,10 @@ class KVCacheManagerV2(BaseResourceManager):
beam_width)
assert copy_idx.shape[0] == num_seqs

copy_batch_block_offsets_to_device(
self.host_kv_cache_block_offsets, dst_tensor, copy_idx,
self.kv_cache_type == CacheTypeCpp.SELFKONLY,
torch.cuda.current_stream().cuda_stream)
copy_batch_block_offsets_to_device(self.host_kv_cache_block_offsets,
dst_tensor, copy_idx,
self.index_scales, self.kv_offset,
self._stream.cuda_stream)

def _create_kv_cache(self, request_id: int, lora_task_id: int | None,
input_tokens: Sequence[TokenIdExt] | None):

@@ -2212,8 +2321,8 @@ class KVCacheManagerV2(BaseResourceManager):
for pool_idx in range(self.num_pools):
buffer: torch.Tensor = self.host_kv_cache_block_offsets[
pool_idx, index * self.max_beam_width + i, 0]
kv_cache.set_page_index_buf(i, pool_idx,
memoryview(buffer.numpy()))
kv_cache.set_base_page_index_buf(i, pool_idx,
memoryview(buffer.numpy()))
return kv_cache

@@ -265,4 +265,4 @@ class KVCacheManager:
def get_aggregated_pages(
self, buffers: Iterable[BufferSlice]
) -> Iterator[AggregatedPageDesc]: ...
def clamp_max_seq_len_for_mem(self, batch_size: int) -> int: ...
def clamp_max_seq_len_for_mem(self, batch_size: int, token_num_upper_bound: int) -> int: ...

@@ -303,7 +303,7 @@ class KVCacheManager:
)

# @TODO: need updating when dynamic resizing is supported.
def clamp_max_seq_len_for_mem(self, batch_size: int) -> int:
def clamp_max_seq_len_for_mem(self, batch_size: int, token_num_upper_bound: int) -> int:
"Get the max possible sequence length limited by the GPU memory pools."
assert batch_size > 0
tokens_per_block = self.tokens_per_block

@@ -338,14 +338,13 @@ class KVCacheManager:
assert is_enough(1)
lb = 1
ub = lb
while is_enough(ub):
lb = ub
ub *= 2
ub = div_up(token_num_upper_bound, tokens_per_block)
if is_enough(ub):
return token_num_upper_bound
while lb < ub - 1:
mid = (lb + ub) // 2
if is_enough(mid):
lb = mid
else:
ub = mid
return lb * tokens_per_block
return min(lb * tokens_per_block, token_num_upper_bound)
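
The search is now bounded by the requested token count instead of doubling until memory runs out; a self-contained sketch follows, with is_enough as a placeholder for the real per-block memory check.

def clamp_max_seq_len_for_mem_sketch(token_num_upper_bound, tokens_per_block, is_enough):
    # Assumes is_enough(1) holds, as asserted in the real implementation.
    def div_up(a, b):
        return (a + b - 1) // b

    lb = 1
    ub = div_up(token_num_upper_bound, tokens_per_block)
    if is_enough(ub):
        return token_num_upper_bound
    while lb < ub - 1:
        mid = (lb + ub) // 2
        if is_enough(mid):
            lb = mid
        else:
            ub = mid
    return min(lb * tokens_per_block, token_num_upper_bound)

# Example: a 1000-token bound, 32-token blocks, memory for at most 20 blocks
print(clamp_max_seq_len_for_mem_sketch(1000, 32, lambda n: n <= 20))  # 640
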
@@ -23,7 +23,11 @@ from ._utils import ItemHolderWithSharedPool, PooledFactoryBase, _unwrap, div_up
def _is_prop_supported(prop: drv.CUmemAllocationProp) -> bool:
err, handle = drv.cuMemCreate(2 << 20, prop, 0)
if err == drv.CUresult.CUDA_ERROR_NOT_PERMITTED or err == drv.CUresult.CUDA_ERROR_NOT_SUPPORTED:
if (
err == drv.CUresult.CUDA_ERROR_NOT_PERMITTED
or err == drv.CUresult.CUDA_ERROR_NOT_SUPPORTED
or err == drv.CUresult.CUDA_ERROR_INVALID_DEVICE
):
return False
elif err == drv.CUresult.CUDA_SUCCESS:
_unwrap(drv.cuMemRelease(handle))

@@ -52,6 +56,10 @@ class NativePhysMemAllocator:
prop.requestedHandleTypes = drv.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
if not _is_prop_supported(prop):
prop.requestedHandleTypes = drv.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_NONE
if not _is_prop_supported(prop):
prop.allocFlags.gpuDirectRDMACapable = 0
if not _is_prop_supported(prop):
raise ValueError("Failed to create physical memory allocation property")
self._prop = prop
self._outstanding_handles = set()
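
The allocator change above probes each candidate allocation property with a small cuMemCreate call and relaxes requirements step by step; here is a hedged sketch of that fallback chain, reusing drv and _is_prop_supported from the surrounding module (the helper name itself is illustrative).

def _pick_supported_prop(prop):
    # Prefer fabric-shareable handles, then plain handles, then drop the
    # GPUDirect RDMA flag; fail only if none of the variants can be created.
    prop.requestedHandleTypes = drv.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
    if _is_prop_supported(prop):
        return prop
    prop.requestedHandleTypes = drv.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_NONE
    if _is_prop_supported(prop):
        return prop
    prop.allocFlags.gpuDirectRDMACapable = 0
    if _is_prop_supported(prop):
        return prop
    raise ValueError("Failed to create physical memory allocation property")
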
@@ -335,7 +335,8 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
@skip_pre_blackwell
@parametrize_with_ids("torch_compile", [False, True])
@parametrize_with_ids("attn_backend", ["TRTLLM"])
def test_nvfp4_kv(self, attn_backend, torch_compile):
@parametrize_with_ids("v2_kv_cache", [True, False])
def test_nvfp4_kv(self, attn_backend, torch_compile, v2_kv_cache):
torch_compile_config = _get_default_torch_compile_config(torch_compile)
pytorch_config = dict(
torch_compile_config=torch_compile_config,

@@ -344,7 +345,8 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
attn_backend=attn_backend,
disable_overlap_scheduler=torch_compile,
)
pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="nvfp4")
pytorch_config["kv_cache_config"] = KvCacheConfig(
dtype="nvfp4", use_kv_cache_manager_v2=v2_kv_cache)
with LLM(f"{llm_models_root()}/Llama-3_1-8B-Instruct_fp8_kv_nvfp4",
**pytorch_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
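
The new parametrization drives both cache managers from a single flag; a minimal configuration sketch using the KvCacheConfig import shown earlier in this diff (model selection and the remaining LLM options are elided).

from tensorrt_llm.llmapi.llm_args import KvCacheConfig

# dtype="nvfp4" enables the block-scaled FP4 KV cache; the flag toggles KVCacheManagerV2.
kv_cache_config = KvCacheConfig(dtype="nvfp4", use_kv_cache_manager_v2=True)
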
@@ -17,8 +17,9 @@ l0_b200:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4_streaming[stream_interval_4]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4_streaming[stream_interval_64]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[v2_kv_cache=False-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[v2_kv_cache=False-attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_nvfp4_kv[v2_kv_cache=True-attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=True]

@@ -5,7 +5,6 @@ from typing import List
import pytest
import torch
from utils.util import getSMVersion

import tensorrt_llm
from tensorrt_llm._torch.attention_backend.interface import (

@@ -368,10 +367,6 @@ def test_attention_mla(scenario: Scenario, context_sequence_lengths: List[int],
num_generation_steps: List[int], v2_kv_cache: bool):
"""Test MLA computation for both context and generation phases"""

if v2_kv_cache and getSMVersion() != 100:
pytest.skip(
"v2_kv_cache is only supported for MLA on Blackwell architectures")

num_heads = scenario.num_heads
num_kv_heads = scenario.num_kv_heads
q_lora_rank = scenario.q_lora_rank