TensorRT-LLMs/tensorrt_llm/_torch/metadata.py

from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

import torch


@dataclass
class KVCacheParams:
"""
Parameters for the key- value cache.
"""
# Whether to use the cache or not.
use_cache: bool
# The number of the cached tokens of each sequence
num_cached_tokens_per_seq: List[int]
# Block IDs of the each sequence
# The shape is depending on the cache type:
# - LINEAR: (1)
# - PAGED: (num_pages)
# - PER_TOKEN: (num_tokens)
# The dtype is int64.
block_ids_per_seq: Optional[List[list]] = None
# The maximum attention window size for each layer.
host_max_attention_window_sizes: Optional[torch.Tensor] = None
# The number of sink tokens for each layer.
host_sink_token_length: Optional[torch.Tensor] = None


class CacheType(Enum):
    # Linear KV cache stores all the cached tokens of a sequence in a single page.
    LINEAR = 0
    # Paged KV cache stores the cached tokens of a sequence in multiple pages.
    PAGED = 1
    # Per-token KV cache stores each token's cached value separately.
    PER_TOKEN = 2
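

# Illustrative usage sketch (not part of the original module): how a caller
# might populate KVCacheParams for a PAGED cache. The page IDs, token counts,
# layer count, and tensor values below are hypothetical example values, not
# taken from TensorRT-LLM itself.
#
# params = KVCacheParams(
#     use_cache=True,
#     # Two sequences, with 16 and 8 tokens already cached.
#     num_cached_tokens_per_seq=[16, 8],
#     # One list of page IDs per sequence; for CacheType.PAGED each list has
#     # shape (num_pages,).
#     block_ids_per_seq=[[0, 1], [2]],
#     # Per-layer values, one entry per layer (two layers assumed here).
#     host_max_attention_window_sizes=torch.tensor([4096, 4096]),
#     host_sink_token_length=torch.tensor([0, 0]),
# )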