TensorRT-LLMs/tensorrt_llm/_torch/metadata.py

from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

import torch


@dataclass
class KVCacheParams:
    """
    Parameters for the key-value cache.
    """
    # Whether to use the cache or not.
    use_cache: bool
    # The number of cached tokens for each sequence.
    num_cached_tokens_per_seq: List[int]
    # Block IDs of each sequence.
    # The shape depends on the cache type:
    # - LINEAR: (1)
    # - PAGED: (num_pages)
    # - PER_TOKEN: (num_tokens)
    # The dtype is int64.
    block_ids_per_seq: Optional[List[list]] = None
    # The maximum attention window size for each layer.
    host_max_attention_window_sizes: Optional[torch.Tensor] = None
    # The number of sink tokens for each layer.
    host_sink_token_length: Optional[torch.Tensor] = None
    # The number of extra KV tokens reserved for draft tokens.
    num_extra_kv_tokens: int = 0
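

# A minimal sketch (not part of the original module) of how a caller might
# populate KVCacheParams for a paged cache. The helper name
# `_example_paged_kv_cache_params` and all values are illustrative
# assumptions, not TensorRT-LLM API.
def _example_paged_kv_cache_params() -> KVCacheParams:
    # Two sequences with 128 and 64 cached tokens; with 64-token pages they
    # occupy two pages and one page respectively.
    return KVCacheParams(
        use_cache=True,
        num_cached_tokens_per_seq=[128, 64],
        # PAGED layout: one list of page IDs per sequence, shape (num_pages).
        block_ids_per_seq=[[0, 1], [2]],
        # One window size / sink length per layer (assume a 4-layer model).
        host_max_attention_window_sizes=torch.tensor([4096] * 4,
                                                     dtype=torch.int64),
        host_sink_token_length=torch.tensor([0] * 4, dtype=torch.int64),
    )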


class CacheType(Enum):
    # Linear KV cache stores all the cached tokens of a sequence in a single page.
    LINEAR = 0
    # Paged KV cache stores the cached tokens of a sequence in multiple pages.
    PAGED = 1
    # Per-token KV cache stores each token's cached value separately.
    PER_TOKEN = 2
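

# An illustrative helper (an assumption, not original code) that computes the
# expected length of a per-sequence block-ID list for each CacheType,
# matching the shape comment on KVCacheParams.block_ids_per_seq.
def _expected_num_block_ids(cache_type: CacheType, num_tokens: int,
                            tokens_per_page: int) -> int:
    if cache_type == CacheType.LINEAR:
        # All cached tokens of a sequence live in a single page.
        return 1
    if cache_type == CacheType.PAGED:
        # Ceiling division: enough pages to hold num_tokens.
        return -(-num_tokens // tokens_per_page)
    # PER_TOKEN: one block ID per cached token.
    return num_tokens


if __name__ == "__main__":
    # 130 cached tokens with 64-token pages -> 3 pages under PAGED.
    assert _expected_num_block_ids(CacheType.PAGED, 130, 64) == 3
    assert _expected_num_block_ids(CacheType.LINEAR, 130, 64) == 1
    assert _expected_num_block_ids(CacheType.PER_TOKEN, 130, 64) == 130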