TensorRT-LLMs/tensorrt_llm/prompt_adapter_manager.py
2ez4bz dc52b67492
linting(python): Enable ruff on more files (wave 1/N) (#5140)
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
2025-06-14 19:19:34 +08:00

50 lines
1.5 KiB
Python

from typing import TYPE_CHECKING, Dict, List, Optional
import torch
from ._utils import str_dtype_to_torch
from .models.convert_utils import get_model_path, load_state_dict
if TYPE_CHECKING:
from .runtime import ModelConfig
class PromptAdapterManager:
    """Keep prompt-adapter embedding tensors in memory, keyed by a string uid.

    Weights are read from checkpoint directories by ``load_from_ckpt`` and
    exposed through the ``uid_to_weights`` mapping.
    """

    def __init__(self):
        # Next integer candidate used when a uid has to be generated.
        self._uid_counter = 0
        # uid -> "prompt_embeddings" tensor, already cast to the model dtype.
        self._uid_to_weights: Dict[str, torch.Tensor] = {}

    def load_from_ckpt(
        self, model_dirs: List[str], model_config: "ModelConfig", uids: Optional[List[str]] = None
    ) -> None:
        """Load the prompt embeddings found under each entry of ``model_dirs``.

        Args:
            model_dirs: Checkpoint directories, one per adapter. Each one is
                resolved with ``get_model_path(model_dir, "adapter_model")``.
            model_config: Only ``model_config.dtype`` is read; it is converted
                with ``str_dtype_to_torch`` and used as the cast target for the
                loaded embeddings.
            uids: Optional identifiers, paired positionally with
                ``model_dirs``. When omitted, a fresh uid is generated for
                every directory.

        Raises:
            ValueError: If ``uids`` is given and its length differs from that
                of ``model_dirs``.
        """
        if uids is None:
            uids = [self._generate_uid() for _ in model_dirs]
        # Explicit check instead of ``assert``: asserts are removed under
        # ``python -O`` and ``zip`` would then silently drop unmatched entries.
        if len(uids) != len(model_dirs):
            raise ValueError(
                "uids and model_dirs must have the same length, "
                f"got {len(uids)} and {len(model_dirs)}"
            )

        # A uid that is already registered keeps its current weights.
        pending = [
            (uid, model_dir)
            for uid, model_dir in zip(uids, model_dirs)
            if uid not in self._uid_to_weights
        ]
        if not pending:
            return

        # The target dtype is the same for every adapter, so resolve it once.
        target_dtype = str_dtype_to_torch(model_config.dtype)
        for uid, model_dir in pending:
            state_dict = load_state_dict(get_model_path(model_dir, "adapter_model"))
            self._uid_to_weights[uid] = state_dict["prompt_embeddings"].to(target_dtype)

    @property
    def uid_to_weights(self) -> Dict[str, torch.Tensor]:
        """Return the internal uid -> weights dictionary (the live object, not a copy)."""
        return self._uid_to_weights

    def _generate_uid(self) -> str:
        """Return the next counter value, as a string, that is not already a registered uid."""
        # Skip values that collide with uids supplied explicitly by callers.
        while str(self._uid_counter) in self._uid_to_weights:
            self._uid_counter += 1
        uid = str(self._uid_counter)
        # Advance so consecutive calls never hand out the same uid, even
        # before the returned uid has been stored in the mapping.
        self._uid_counter += 1
        return uid