# tensorrt_llm/executor/request.py
import os
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import torch

from ..disaggregated_params import DisaggregatedParams
from ..llmapi.llm_utils import KvCacheRetentionConfig
from ..sampling_params import SamplingParams
from .postproc_worker import PostprocParams

__all__ = [
    "LoRARequest",
    "PromptAdapterRequest",
    "GenerationRequest",
]


@dataclass(slots=True)
class LoRARequest:
    """Request for a LoRA adapter."""
    lora_name: str
    lora_int_id: int
    lora_path: str = ""

    def __post_init__(self):
        # An empty lora_path means "no path provided"; only validate
        # existence when a path was actually given. (Checking `is not None`
        # here would reject the default "", since os.path.exists("") is
        # always False.)
        if self.lora_path and not os.path.exists(self.lora_path):
            raise ValueError(f"lora_path ({self.lora_path}) does not exist.")

    @property
    def adapter_id(self):
        return self.lora_int_id

    @property
    def name(self):
        return self.lora_name

    @property
    def path(self):
        return self.lora_path
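
# A minimal usage sketch (illustrative only, not executed by this module);
# the adapter name and directory below are hypothetical:
#
#   lora_req = LoRARequest(lora_name="my-adapter",
#                          lora_int_id=1,
#                          lora_path="/path/to/adapter_dir")
#   assert lora_req.adapter_id == 1
#   assert lora_req.name == "my-adapter"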


@dataclass(slots=True)
class PromptAdapterRequest:
    """Request for a Prompt adapter."""
    prompt_adapter_name: str
    prompt_adapter_id: int
    prompt_adapter_local_path: str = ""

    def __post_init__(self):
        # Mirror LoRARequest: an empty path means "not provided", so only
        # validate existence when a path was actually given.
        if self.prompt_adapter_local_path and not os.path.exists(
                self.prompt_adapter_local_path):
            raise RuntimeError(
                f"prompt_adapter_local_path ({self.prompt_adapter_local_path}) does not exist."
            )

    @property
    def adapter_id(self):
        return self.prompt_adapter_id

    @property
    def name(self):
        return self.prompt_adapter_name

    @property
    def local_path(self):
        return self.prompt_adapter_local_path
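
# A minimal usage sketch (illustrative only; the name and local path are
# hypothetical):
#
#   pa_req = PromptAdapterRequest(prompt_adapter_name="my-prompt-adapter",
#                                 prompt_adapter_id=2,
#                                 prompt_adapter_local_path="/path/to/ckpt")
#   assert pa_req.adapter_id == 2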


class GenerationRequest:

    def __init__(
        self,
        prompt_token_ids: Union[torch.Tensor, np.ndarray, List[int],
                                List[List[int]]],
        sampling_params: SamplingParams,
        query_token_ids: Optional[Union[torch.Tensor, np.ndarray,
                                        list]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        streaming: bool = False,
        prompt_tuning_config: Optional[list] = None,
        mrope_config: Optional[dict] = None,
        kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
        disaggregated_params: Optional[DisaggregatedParams] = None,
        postproc_params: Optional[PostprocParams] = None,
    ):
        if isinstance(prompt_token_ids, list):
            self.prompt_token_ids = prompt_token_ids
            self.query_token_ids = query_token_ids
        elif isinstance(prompt_token_ids, (torch.Tensor, np.ndarray)):
            self.prompt_token_ids = prompt_token_ids.tolist()
            # Compare against None explicitly: truth-testing a multi-element
            # tensor/array raises, and the attribute must be set on every
            # branch so later reads do not hit an AttributeError.
            self.query_token_ids = (query_token_ids.tolist()
                                    if query_token_ids is not None else None)
        else:
            raise TypeError(
                f"prompt_token_ids ({prompt_token_ids}) should be an instance of torch.Tensor, np.ndarray or list"
            )

        self.sampling_params = sampling_params
        self.postproc_params = postproc_params
        self.lora_request = lora_request
        self.prompt_adapter_request = prompt_adapter_request
        self.streaming = streaming
        self.prompt_tuning_config = prompt_tuning_config
        self.mrope_config = mrope_config
        self.kv_cache_retention_config = kv_cache_retention_config
        self.id: Optional[int] = None
        self.disaggregated_params = disaggregated_params

    def set_id(self, id: int):
        assert self.id is None, f"Request ID is already set: {self.id}"
        self.id = id
        return self
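
# A minimal construction sketch (illustrative only; the token ids and
# sampling settings are made up). Token ids may be a list, torch.Tensor, or
# np.ndarray; tensors/arrays are converted to lists on construction, and the
# executor assigns the request id later via set_id():
#
#   request = GenerationRequest(
#       prompt_token_ids=[101, 2023, 2003, 1037, 3231],
#       sampling_params=SamplingParams(max_tokens=32),
#       streaming=True,
#   ).set_id(0)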


class CancellingRequest:
    """The request to cancel a generation."""

    def __init__(self, id: int):
        self.id = id
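
# A minimal cancellation sketch (illustrative only; `request_queue` is an
# assumption about the surrounding executor, not part of this module):
#
#   request_queue.put(GenerationRequest(...).set_id(0))
#   # ... later, cancel generation 0 by enqueuing a CancellingRequest
#   request_queue.put(CancellingRequest(id=0))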