import atexit
import json
import os
import shutil
import socket
import tempfile
import time
import weakref
from collections.abc import Mapping
from pathlib import Path
from typing import Any, List, Literal, Optional, Sequence, Union, cast

import transformers
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase

from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.inputs.data import TextPrompt
from tensorrt_llm.inputs.multimodal import MultimodalInput, MultimodalParams
from tensorrt_llm.inputs.registry import (BaseMultimodalInputProcessor,
                                          DefaultInputProcessor)
from tensorrt_llm.llmapi import tracing
from tensorrt_llm.metrics.enums import MetricNames

from .._utils import nvtx_range_debug
from ..bindings import executor as tllm
from ..bindings import steady_clock_now
from ..builder import EngineConfig
from ..disaggregated_params import DisaggregatedParams
from ..executor import (DetokenizedGenerationResultBase, GenerationExecutor,
                        GenerationResult, IterationResult, LoRARequest,
                        PostprocWorkerConfig, PromptAdapterRequest)
from ..executor.postproc_worker import PostprocParams
from ..executor.utils import (create_mpi_comm_session,
                              get_spawn_proxy_process_env)
from ..inputs import (PromptInputs, create_input_processor,
                      create_input_processor_with_hash, get_cache_salt_id,
                      prompt_inputs)
from ..logger import logger
from ..sampling_params import SamplingParams
from ..scheduling_params import SchedulingParams
from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
                       TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig,
                       PybindMirror, TorchLlmArgs, TrtLlmArgs)
from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
                        LlmBuildStats, ModelLoader, _ModelRuntimeContext)
from .mpi_session import MpiPoolSession, external_mpi_comm_available
from .tokenizer import TokenizerBase, _xgrammar_tokenizer_info
# TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
from .utils import (append_docstring, exception_handler, get_device_count,
                    logger_debug, set_api_status)


class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
    """The output data of a completion request to the LLM.

    Attributes:
        request_id (int): The unique ID of the request.
        prompt (str, optional): The prompt string of the request.
        prompt_token_ids (List[int]): The token ids of the prompt.
        outputs (List[CompletionOutput]): The output sequences of the request.
        context_logits (torch.Tensor, optional): The logits on the prompt token ids.
        mm_embedding_handle (Dict[str, Any], optional): The multimodal embedding handle of the request.
        finished (bool): Whether the whole request is finished.
    """

    def __init__(self) -> None:
        raise RuntimeError(
            f"{self.__class__.__name__} is designed to be instantiated using {self.__class__.__name__}._from_generation_result by GenerationExecutor. "
            f"Users are not expected to create {self.__class__.__name__} directly."
        )

    @classmethod
    def _from_generation_result(
            cls,
            generation_result: GenerationResult,
            prompt: Optional[str] = None,
            tokenizer: Optional[TokenizerBase] = None) -> 'RequestOutput':
        inst = cls.__new__(cls)
        inst.__dict__.update(generation_result.__dict__)
        inst.tokenizer = tokenizer
        inst._streaming = generation_result._streaming
        inst._prompt = prompt
        return inst

    @property
    def prompt(self) -> Optional[str]:
        return self._prompt

    def _repr_fields(self):
        return [
            "request_id",
            "prompt",
            "prompt_token_ids",
            "outputs",
            "finished",
            "mm_embedding_handle",
        ]


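# Example (illustrative sketch, not executed on import): a `RequestOutput` is what
# `LLM.generate()` / `LLM.generate_async()` return, so typical read-side usage looks
# like the commented snippet below. The model name is a placeholder, and this assumes
# the package-level re-exports `tensorrt_llm.LLM` and `tensorrt_llm.SamplingParams`.
#
#     from tensorrt_llm import LLM, SamplingParams
#
#     llm = LLM(model="<hf-model-id-or-local-path>")
#     out = llm.generate("Hello, my name is",
#                        sampling_params=SamplingParams(max_tokens=32))
#     print(out.request_id, out.finished)
#     print(out.prompt)             # the original prompt string
#     print(out.prompt_token_ids)   # tokenized prompt
#     print(out.outputs[0].text)    # first generated sequence

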
TRT_LLM_DOCSTRING = TRT_LLMARGS_EXPLICIT_DOCSTRING + """

Attributes:
    tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by the LLM instance, if any.
    workspace (pathlib.Path): The directory to store intermediate files.
    llm_id (str): The unique ID of the LLM instance.
"""

TORCH_LLM_DOCSTRING = TORCH_LLMARGS_EXPLICIT_DOCSTRING + """

Attributes:
    tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by the LLM instance, if any.
    llm_id (str): The unique ID of the LLM instance.
"""


class BaseLLM:
    """
    The base class for all LLM classes.
    """

    def __init__(self,
                 model: Union[str, Path],
                 tokenizer: Optional[Union[str, Path, TokenizerBase,
                                           PreTrainedTokenizerBase]] = None,
                 tokenizer_mode: Literal['auto', 'slow'] = 'auto',
                 skip_tokenizer_init: bool = False,
                 trust_remote_code: bool = False,
                 tensor_parallel_size: int = 1,
                 dtype: str = "auto",
                 revision: Optional[str] = None,
                 tokenizer_revision: Optional[str] = None,
                 **kwargs: Any) -> None:

        self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
        self._orchestrator_type = kwargs.get("orchestrator_type", None)
        self._llm_id = None

        log_level = logger.level
        logger.set_level("info")  # force display the backend

        try:
            env_overrides = kwargs.get("env_overrides", None)
            self._process_env_overrides(env_overrides)

            backend = kwargs.get('backend', None)
            if backend == "pytorch":
                logger.info("Using LLM with PyTorch backend")
                llm_args_cls = TorchLlmArgs
                if self._orchestrator_type == "ray" or mpi_disabled():
                    self._orchestrator_type = "ray"
                    os.environ["TLLM_DISABLE_MPI"] = "1"
                    # Propagate to args construction
                    kwargs["orchestrator_type"] = "ray"

            elif backend == '_autodeploy':
                logger.info("Using LLM with AutoDeploy backend")
                from .._torch.auto_deploy.llm_args import \
                    LlmArgs as AutoDeployLlmArgs
                llm_args_cls = AutoDeployLlmArgs
            else:
                logger.info("Using LLM with TensorRT backend")
                llm_args_cls = TrtLlmArgs

            # Check the kwargs and raise a ValueError directly.
            valid_keys = set(
                list(llm_args_cls.model_fields.keys()) +
                ['_mpi_session', 'backend'])
            for key in kwargs:
                if key not in valid_keys:
                    raise ValueError(
                        f"{self.__class__.__name__} got invalid argument: {key}"
                    )

            self.args = llm_args_cls.from_kwargs(
                model=model,
                tokenizer=tokenizer,
                tokenizer_mode=tokenizer_mode,
                skip_tokenizer_init=skip_tokenizer_init,
                trust_remote_code=trust_remote_code,
                tensor_parallel_size=tensor_parallel_size,
                dtype=dtype,
                revision=revision,
                tokenizer_revision=tokenizer_revision,
                **kwargs)

        except Exception as e:
            logger.error(
                f"Failed to parse the arguments for the LLM constructor: {e}")
            raise e

        finally:
            logger.set_level(log_level)  # restore the log level

        logger_debug(f"LLM.args.mpi_session: {self.args.mpi_session}\n",
                     "yellow")
        self.mpi_session = self.args.mpi_session

        if self.args.parallel_config.is_multi_gpu:
            if os.getenv("RAY_LOCAL_WORLD_SIZE") is None and get_device_count(
            ) < self.args.parallel_config.world_size_per_node:
                raise RuntimeError(
                    f"Only {get_device_count()} GPUs are available, but {self.args.parallel_config.world_size} are required."
                )

            logger.info(
                f'start MpiSession with {self.args.parallel_config.world_size} workers'
            )
            if not self.mpi_session:
                mpi_process_pre_spawned: bool = get_spawn_proxy_process_env()
                if not mpi_process_pre_spawned:
                    logger_debug(f"LLM create MpiPoolSession\n", "yellow")
                    self.mpi_session = MpiPoolSession(
                        n_workers=self.args.parallel_config.world_size)
                else:
                    logger_debug(f"LLM create MpiCommSession\n", "yellow")
                    self.mpi_session = create_mpi_comm_session(
                        self.args.parallel_config.world_size)

        try:
            # Because the Executor can only accept an engine path, we need to save the engine to a directory.
            self._engine_dir: Optional[Path] = None
            self._executor: Optional[GenerationExecutor] = None
            if self._on_trt_backend:
                self._workspace = tempfile.TemporaryDirectory(
                    suffix="-llm-workspace", dir=self.args.workspace)
            else:
                self._workspace = None

            self._hf_model_dir: Optional[Path] = None
            self._hf_model_config = None
            self._generation_config = None

            self.runtime_context: Optional[_ModelRuntimeContext] = None
            self.llm_build_stats = LlmBuildStats()
            self._build_model()

        except Exception:
            if self.mpi_session is not None:
                self.mpi_session.shutdown()
            raise

        try:
            if self.args.otlp_traces_endpoint:
                tracing.init_tracer("trt.llm", self.args.otlp_traces_endpoint)
                logger.info(
                    f"Initialized OTLP tracer successfully, endpoint: {self.args.otlp_traces_endpoint}"
                )
        except Exception as e:
            logger.error(f"Failed to initialize OTLP tracer: {e}")

        exception_handler.register(self, 'shutdown')
        atexit.register(LLM._shutdown_wrapper, weakref.ref(self))

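    # Example (illustrative sketch, not executed on import): the public `LLM` subclass
    # defined at the bottom of this module shares this constructor signature, so a
    # multi-GPU PyTorch-backend instance is typically created like the commented snippet
    # below. The model name is a placeholder, and any extra keyword arguments must be
    # valid fields of the corresponding LlmArgs class.
    #
    #     from tensorrt_llm import LLM
    #
    #     llm = LLM(model="<hf-model-id-or-local-path>",
    #               tensor_parallel_size=2,
    #               dtype="auto",
    #               trust_remote_code=True)
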
    @property
    @set_api_status("beta")
    def llm_id(self) -> str:
        if self._llm_id is None:
            hostname = socket.gethostname()
            pid = os.getpid()
            timestamp = int(time.time() * 1000)
            self._llm_id = f"{hostname}-{pid}-{timestamp}"

        return self._llm_id

    def generate(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[LoRARequest,
                                     Sequence[LoRARequest]]] = None,
        prompt_adapter_request: Optional[Union[
            PromptAdapterRequest, Sequence[PromptAdapterRequest]]] = None,
        kv_cache_retention_config: Optional[Union[
            KvCacheRetentionConfig, Sequence[KvCacheRetentionConfig]]] = None,
        disaggregated_params: Optional[Union[
            DisaggregatedParams, Sequence[DisaggregatedParams]]] = None,
        scheduling_params: Optional[Union[SchedulingParams,
                                          List[SchedulingParams]]] = None,
        cache_salt: Optional[Union[str, Sequence[str]]] = None,
    ) -> Union[RequestOutput, List[RequestOutput]]:
        """Generate output for the given prompts in synchronous mode.

        Synchronous generation accepts either a single prompt or batched prompts.

        Args:
            inputs (tensorrt_llm.inputs.data.PromptInputs, Sequence[tensorrt_llm.inputs.data.PromptInputs]): The prompt text or token ids.
                It can be a single prompt or batched prompts.
            sampling_params (tensorrt_llm.sampling_params.SamplingParams, List[tensorrt_llm.sampling_params.SamplingParams], optional): The sampling params for the generation. Defaults to None.
                A default one will be used if not provided.
            use_tqdm (bool): Whether to use tqdm to display the progress bar. Defaults to True.
            lora_request (tensorrt_llm.executor.request.LoRARequest, Sequence[tensorrt_llm.executor.request.LoRARequest], optional):
                LoRA request to use for generation, if any. Defaults to None.
            prompt_adapter_request (tensorrt_llm.executor.request.PromptAdapterRequest, Sequence[tensorrt_llm.executor.request.PromptAdapterRequest], optional):
                Prompt Adapter request to use for generation, if any. Defaults to None.
            kv_cache_retention_config (tensorrt_llm.bindings.executor.KvCacheRetentionConfig, Sequence[tensorrt_llm.bindings.executor.KvCacheRetentionConfig], optional):
                Configuration for the request's retention in the KV Cache. Defaults to None.
            disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, Sequence[tensorrt_llm.disaggregated_params.DisaggregatedParams], optional):
                Disaggregated parameters. Defaults to None.
            scheduling_params (tensorrt_llm.scheduling_params.SchedulingParams, List[tensorrt_llm.scheduling_params.SchedulingParams], optional):
                Scheduling parameters. Defaults to None.
            cache_salt (str, Sequence[str], optional): If specified, the KV cache is salted with the provided string, limiting KV cache reuse to requests that use the same string. Defaults to None.

        Returns:
            Union[tensorrt_llm.llmapi.RequestOutput, List[tensorrt_llm.llmapi.RequestOutput]]: The output data of the completion request to the LLM.
        """
        unbatched = not isinstance(inputs, list)
        if not unbatched:
            if isinstance(inputs[0], int):
                unbatched = True

        if unbatched:
            inputs = [inputs]

        inputs = [prompt_inputs(i) for i in inputs]

        def _item_at(maybe_batched: Union[Any, Sequence[Any]], pos: int) -> Any:
            if isinstance(maybe_batched, list):
                return maybe_batched[pos]
            else:
                return maybe_batched

        futures = []
        for i, request_inputs in enumerate(inputs):
            future = self.generate_async(
                request_inputs,
                sampling_params=_item_at(sampling_params, i),
                lora_request=_item_at(lora_request, i),
                prompt_adapter_request=_item_at(prompt_adapter_request, i),
                kv_cache_retention_config=_item_at(kv_cache_retention_config,
                                                   i),
                disaggregated_params=_item_at(disaggregated_params, i),
                scheduling_params=_item_at(scheduling_params, i),
                cache_salt=_item_at(cache_salt, i),
                streaming=False,
            )
            futures.append(future)

        for future in tqdm(futures,
                           desc="Processed requests",
                           dynamic_ncols=True,
                           disable=not use_tqdm):
            future.result()

        if unbatched:
            futures = futures[0]

        return futures

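    # Example (illustrative sketch, not executed on import): batched synchronous
    # generation with per-request sampling params. The model name is a placeholder,
    # and this assumes the package-level re-exports `tensorrt_llm.LLM` and
    # `tensorrt_llm.SamplingParams`.
    #
    #     from tensorrt_llm import LLM, SamplingParams
    #
    #     llm = LLM(model="<hf-model-id-or-local-path>")
    #     prompts = ["The capital of France is", "The capital of Japan is"]
    #     params = [SamplingParams(max_tokens=16, temperature=0.0),
    #               SamplingParams(max_tokens=32, top_p=0.9)]
    #     for out in llm.generate(prompts, sampling_params=params):
    #         print(out.outputs[0].text)
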
    @nvtx_range_debug("LLM.generate_async", color="green", category="LLM")
    def generate_async(
        self,
        inputs: PromptInputs,
        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        streaming: bool = False,
        kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
        disaggregated_params: Optional[DisaggregatedParams] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        _postproc_params: Optional[PostprocParams] = None,
        scheduling_params: Optional[SchedulingParams] = None,
        cache_salt: Optional[str] = None,
    ) -> RequestOutput:
        """Generate output for the given prompt in asynchronous mode.

        Asynchronous generation accepts a single prompt only.

        Args:
            inputs (tensorrt_llm.inputs.data.PromptInputs): The prompt text or token ids; it must be a single prompt.
            sampling_params (tensorrt_llm.sampling_params.SamplingParams, optional): The sampling params for the generation. Defaults to None.
                A default one will be used if not provided.
            lora_request (tensorrt_llm.executor.request.LoRARequest, optional): LoRA request to use for generation, if any. Defaults to None.
            prompt_adapter_request (tensorrt_llm.executor.request.PromptAdapterRequest, optional): Prompt Adapter request to use for generation, if any. Defaults to None.
            streaming (bool): Whether to use the streaming mode for the generation. Defaults to False.
            kv_cache_retention_config (tensorrt_llm.bindings.executor.KvCacheRetentionConfig, optional): Configuration for the request's retention in the KV Cache. Defaults to None.
            disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, optional): Disaggregated parameters. Defaults to None.
            trace_headers (Mapping[str, str], optional): Trace headers. Defaults to None.
            scheduling_params (tensorrt_llm.scheduling_params.SchedulingParams, optional): Scheduling parameters. Defaults to None.
            cache_salt (str, optional): If specified, the KV cache is salted with the provided string, limiting KV cache reuse to requests that use the same string. Defaults to None.

        Returns:
            tensorrt_llm.llmapi.RequestOutput: The output data of the completion request to the LLM.
        """

        # Check if the worker is shutting down.
        if self._executor is None or self._executor.is_shutdown():
            raise RuntimeError("LLM is shutting down")

        arrival_time = steady_clock_now(
        ) if self.args.return_perf_metrics else None

        sampling_params = self._prepare_sampling_params(sampling_params)
        cache_salt_id = get_cache_salt_id(
            cache_salt) if cache_salt is not None else None

        # With the pytorch backend, py_executor has logic to handle max_tokens of 1,
        # so set it to 1 to avoid allocating unnecessary KV cache blocks for a single request.
        # TODO: Also support this for the trt backend.
        is_ctx_only = disaggregated_params is not None and disaggregated_params.request_type == "context_only"
        is_gen_only = disaggregated_params is not None and disaggregated_params.request_type == "generation_only"
        is_mm_disagg = disaggregated_params is not None and disaggregated_params.multimodal_embedding_handles is not None

        if is_ctx_only and not self._on_trt_backend:
            sampling_params.max_tokens = 1

        inputs = prompt_inputs(inputs)

        if not inputs.get("prompt") and inputs.get("prompt_token_ids") and (
                inputs.get("multi_modal_data")
                or inputs.get("multi_modal_embeddings")) and not isinstance(
                    self.input_processor, DefaultInputProcessor):
            # VLMs need to process/tokenize the prompt in their own way
            prompt = self.tokenizer.decode(inputs['prompt_token_ids'])
            inputs = TextPrompt(
                prompt=prompt,
                multi_modal_data=inputs.get("multi_modal_data"),
                mm_processor_kwargs=inputs.get("mm_processor_kwargs"))
            if sampling_params.add_special_tokens:
                logger.debug(
                    "Setting add_special_tokens to False because prompt_token_ids were provided to generate. VLMs will re-encode the prompt."
                )
                sampling_params.add_special_tokens = False

        query_token_ids = None
        multimodal_params = None

        if is_mm_disagg:
            if not getattr(self.input_processor, "support_mm_disagg", False):
                raise ValueError(
                    "Multimodal disaggregated inference is not supported for this model"
                )
            mm_handles = disaggregated_params.multimodal_embedding_handles
            prompt_token_ids, mm_token_length, mm_token_positions = self.input_processor.get_prompt_token_ids(
                inputs, mm_handles)
            prompt = inputs.get("prompt", None)
            query_token_ids = inputs.get("query_token_ids", None)
            if is_gen_only:
                raise ValueError(
                    "Generation-only mode should not need multimodal parameters"
                )
            else:
                mm_hashes = disaggregated_params.multimodal_hashes
                multimodal_input = MultimodalInput.from_components(
                    mm_hashes, mm_token_positions, mm_token_length)
                multimodal_data = {"multimodal_embedding": mm_handles}
                if disaggregated_params.mrope_position_ids_handle is not None:
                    # NOTE: `PyTorchModelEngine` assumes both are present when using mrope.
                    assert disaggregated_params.mrope_position_deltas_handle is not None
                    mrope_config = {}
                    mrope_config[
                        "mrope_position_ids"] = disaggregated_params.mrope_position_ids_handle
                    mrope_config[
                        "mrope_position_deltas"] = disaggregated_params.mrope_position_deltas_handle
                    multimodal_data["mrope_config"] = mrope_config
                multimodal_params = MultimodalParams(
                    multimodal_input=multimodal_input,
                    multimodal_data=multimodal_data,
                )

        elif "prompt_token_ids" in inputs:
            prompt_token_ids = inputs['prompt_token_ids']
            prompt = None
            query_token_ids = inputs.get("query_token_ids", None)
            multimodal_data = {}
            # NOTE: when running in `generation_only` for disagg, this is the code path we expect to hit.
            if disaggregated_params is not None and disaggregated_params.mrope_position_ids_handle is not None:
                # It looks like `PyTorchModelEngine` assumes both are present when using mrope?
                if disaggregated_params.mrope_position_deltas_handle is None:
                    raise RuntimeError(
                        "`mrope_position_ids_handle` and `mrope_position_deltas_handle` must both "
                        "be provided, or both `None`.")
                mrope_config = {}
                mrope_config[
                    "mrope_position_ids"] = disaggregated_params.mrope_position_ids_handle
                mrope_config[
                    "mrope_position_deltas"] = disaggregated_params.mrope_position_deltas_handle
                multimodal_data["mrope_config"] = mrope_config
            if multimodal_data:
                multimodal_params = MultimodalParams(
                    multimodal_data=multimodal_data)
        elif "prompt" in inputs:
            if 'multi_modal_data' in inputs:
                # TODO: The current design uses a wrapper around the existing input processor (input_processor_with_hash)
                # to handle/add multimodal hashes, positions, and lengths. Currently we only support the image modality.
                # In the future, we should refactor this to:
                # 1. Extend support to more modalities and models
                # 2. Decouple the input processor into distinct phases (preprocessor (all preprocessing logic), vision model (fused into the model forward), etc.)
                input_processor_with_hash = create_input_processor_with_hash(
                    self.input_processor)
                with nvtx_range_debug("input_processor_with_hash"):
                    prompt_token_ids, extra_processed_inputs = input_processor_with_hash(
                        inputs, sampling_params)
            elif 'multi_modal_embeddings' in inputs:
                mm_embedding_info = inputs['multi_modal_embeddings']
                prompt_token_ids, extra_processed_inputs = cast(
                    BaseMultimodalInputProcessor,
                    self.input_processor).attach_multimodal_embeddings(
                        inputs, mm_embedding_info, sampling_params)
            else:
                with nvtx_range_debug("input_processor"):
                    prompt_token_ids, extra_processed_inputs = self.input_processor(
                        inputs, sampling_params)
            prompt = inputs['prompt']
            if extra_processed_inputs is not None:
                query_token_ids = extra_processed_inputs.get('query_token_ids')
                # Create unified MultimodalParams
                multimodal_params = MultimodalParams(
                    multimodal_input=extra_processed_inputs.get(
                        'multimodal_input'),
                    multimodal_data=extra_processed_inputs.get(
                        'multimodal_data'))
                # Only pass it if it has content
                if not multimodal_params.has_content():
                    multimodal_params = None
                else:
                    # Convert to a shared tensor handle to reduce IPC overhead
                    multimodal_params.to_handle("multimodal_data")
        else:
            raise TypeError(
                f"The inputs must be type str or list of int, but got {type(inputs)}"
            )

        self._check_arguments(
            len(prompt_token_ids),
            len(query_token_ids) if query_token_ids is not None else 0,
            sampling_params,
            is_gen_only=is_gen_only)
        if _postproc_params:
            _postproc_params.postproc_args.num_prompt_tokens = len(
                prompt_token_ids)
        result = self._executor.generate_async(
            prompt_token_ids,
            query_token_ids=query_token_ids,
            sampling_params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            streaming=streaming,
            kv_cache_retention_config=kv_cache_retention_config,
            disaggregated_params=disaggregated_params,
            trace_headers=trace_headers,
            postproc_params=_postproc_params,
            multimodal_params=multimodal_params,
            scheduling_params=scheduling_params,
            cache_salt_id=cache_salt_id,
            arrival_time=arrival_time,
        )

        if sampling_params.return_perf_metrics:
            result.metrics_dict.update(
                {MetricNames.ARRIVAL_TIMESTAMP: time.time()})

        return RequestOutput._from_generation_result(result, prompt,
                                                     self.tokenizer)

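    # Example (illustrative sketch, not executed on import): asynchronous and streaming
    # usage. The blocking `.result()` call is what `generate()` itself uses above; the
    # `async for` form assumes the returned `RequestOutput` supports async iteration
    # when `streaming=True`, as in the upstream llm_inference_async examples. The model
    # name is a placeholder.
    #
    #     import asyncio
    #     from tensorrt_llm import LLM, SamplingParams
    #
    #     llm = LLM(model="<hf-model-id-or-local-path>")
    #     params = SamplingParams(max_tokens=32)
    #
    #     # Blocking: submit now, collect later.
    #     future = llm.generate_async("Hello, my name is", sampling_params=params)
    #     print(future.result().outputs[0].text)
    #
    #     # Streaming: consume partial outputs as they arrive.
    #     async def stream():
    #         async for chunk in llm.generate_async(
    #                 "Tell me a story", sampling_params=params, streaming=True):
    #             print(chunk.outputs[0].text, flush=True)
    #
    #     asyncio.run(stream())
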
    @set_api_status("beta")
    def get_stats(self, timeout: Optional[float] = 2) -> List[dict]:
        '''Get iteration statistics from the runtime.
        To collect statistics, call this function after prompts have been submitted with LLM().generate().

        Args:
            timeout (float, optional): Max wait time in seconds when retrieving stats from the queue. Defaults to 2.

        Returns:
            List[dict]: A list of runtime stats as dicts,
                e.g., [{"cpuMemUsage": ..., "iter": 0, ...}, {"cpuMemUsage": ..., "iter": 1, ...}]
        '''
        return self._executor.get_stats(timeout=timeout)

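    # Example (illustrative sketch, not executed on import): pull per-iteration runtime
    # stats after a round of generation. Field names inside each dict follow whatever the
    # runtime reports (e.g. "iter" and "cpuMemUsage", as mentioned above).
    #
    #     outputs = llm.generate(prompts)
    #     for iteration_stats in llm.get_stats(timeout=2):
    #         print(iteration_stats.get("iter"), iteration_stats.get("cpuMemUsage"))
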
    @set_api_status("beta")
    def get_stats_async(self, timeout: Optional[float] = 2) -> IterationResult:
        '''Get iteration statistics from the runtime.
        To collect statistics, you can call this function in an async coroutine or from the /metrics endpoint (if you're using trtllm-serve)
        after prompts have been submitted.

        Args:
            timeout (float, optional): Max wait time in seconds when retrieving stats from the queue. Defaults to 2.

        Returns:
            tensorrt_llm.executor.result.IterationResult: An async iterable object containing runtime stats.
        '''
        return self._executor.aget_stats(timeout=timeout)

    @set_api_status("beta")
    def get_kv_cache_events(self, timeout: Optional[float] = 2) -> List[dict]:
        '''Get iteration KV events from the runtime.

        KV events are used to track changes and operations within the KV Cache. Types of events:
            - KVCacheCreatedData: Indicates the creation of cache blocks.
            - KVCacheStoredData: Represents a sequence of stored blocks.
            - KVCacheRemovedData: Contains the hashes of blocks that are being removed from the cache.
            - KVCacheUpdatedData: Captures updates to existing cache blocks.

        To enable KV events:
            - set `event_buffer_max_size` to a positive integer in the `KvCacheConfig`.
            - set `enable_block_reuse` to True in the `KvCacheConfig`.

        Args:
            timeout (float, optional): Max wait time in seconds when retrieving events from the queue. Defaults to 2.

        Returns:
            List[dict]: A list of runtime events as dicts.
        '''
        return self._executor.get_kv_events(timeout=timeout)

    @set_api_status("beta")
    def get_kv_cache_events_async(self,
                                  timeout: Optional[float] = 2
                                  ) -> IterationResult:
        '''Get iteration KV events from the runtime.

        KV events are used to track changes and operations within the KV Cache. Types of events:
            - KVCacheCreatedData: Indicates the creation of cache blocks.
            - KVCacheStoredData: Represents a sequence of stored blocks.
            - KVCacheRemovedData: Contains the hashes of blocks that are being removed from the cache.
            - KVCacheUpdatedData: Captures updates to existing cache blocks.

        To enable KV events:
            - set `event_buffer_max_size` to a positive integer in the `KvCacheConfig`.
            - set `enable_block_reuse` to True in the `KvCacheConfig`.

        Args:
            timeout (float, optional): Max wait time in seconds when retrieving events from the queue. Defaults to 2.

        Returns:
            tensorrt_llm.executor.result.IterationResult: An async iterable object containing runtime events.
        '''
        return self._executor.aget_kv_events(timeout=timeout)

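    # Example (illustrative sketch, not executed on import): enabling KV cache events as
    # described in the docstrings above. This assumes `KvCacheConfig` is importable from
    # `tensorrt_llm.llmapi` and is accepted via the `kv_cache_config` argument.
    #
    #     from tensorrt_llm.llmapi import KvCacheConfig
    #
    #     llm = LLM(model="<hf-model-id-or-local-path>",
    #               kv_cache_config=KvCacheConfig(enable_block_reuse=True,
    #                                             event_buffer_max_size=1024))
    #     llm.generate(prompts)
    #     for event in llm.get_kv_cache_events(timeout=2):
    #         print(event)
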
    def _process_env_overrides(self,
                               env_overrides: Optional[dict[str, str]]) -> None:
        if env_overrides is None:
            return
        logger.info("Processing LLM API environment variable overrides")
        # TODO: If an env var is cached at import time somewhere in the code, overriding
        # os.environ will unfortunately not update wherever that cached value is used.
        # This is a known issue, and the only way to fix it is to read the variable from
        # os.environ on demand at every such usage.
        for key, value in env_overrides.items():
            str_value = str(value)
            if key in os.environ:
                old_value = os.environ[key]
                os.environ[key] = str_value
                logger.info(f"Overriding {key}: '{old_value}' -> '{str_value}'")
            else:
                os.environ[key] = str_value
                logger.info(f"Setting {key}='{str_value}'")

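    # Example (illustrative sketch, not executed on import): passing environment
    # variable overrides at construction time. This assumes `env_overrides` is accepted
    # as a keyword argument by the corresponding LlmArgs class; the values are applied
    # to os.environ before the backend is selected. The variable name and value below
    # are placeholders.
    #
    #     llm = LLM(model="<hf-model-id-or-local-path>",
    #               env_overrides={"<ENV_VAR_NAME>": "<value>"})
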
    def _prepare_sampling_params(
            self,
            sampling_params: Optional[SamplingParams] = None) -> SamplingParams:
        if sampling_params is None:
            sampling_params = SamplingParams()
        if isinstance(sampling_params, SamplingParams):
            if sampling_params.end_id is None:
                if self.tokenizer is None:
                    raise ValueError(
                        "A tokenizer is required to set end_id when it is None; alternatively, explicitly specify end_id in sampling_params."
                    )
                sampling_params._setup(self.tokenizer, self._hf_model_config,
                                       self._generation_config)
        else:
            raise TypeError(
                f"The sampling_params must be type SamplingParams or None, but got {type(sampling_params)}"
            )

        # Auto-enable the context and/or generation logits flags, as they are required
        # by logprob computation for the TRT backend.
        if self.args.backend not in ["pytorch", "_autodeploy"]:
            if sampling_params.prompt_logprobs and not sampling_params.return_context_logits:
                sampling_params.return_context_logits = True
                sampling_params._context_logits_auto_enabled = True
            if sampling_params.logprobs and not sampling_params.return_generation_logits:
                sampling_params.return_generation_logits = True
                sampling_params._generation_logits_auto_enabled = True

        if sampling_params._stream_interval is None:
            sampling_params._stream_interval = getattr(self.args,
                                                       "stream_interval", 1)
        sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics
        return sampling_params

    def _check_arguments(self, prompt_len: int, query_len: int,
                         sampling_params: SamplingParams,
                         is_gen_only: bool) -> None:

        if self.args.backend in ["pytorch", "_autodeploy"]:
            # Check the prompt length and query length against max_num_tokens to filter out illegal requests.
            # Skip the check for gen-only requests.
            if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill and not is_gen_only:
                max_num_tokens = self.args.max_num_tokens
                if max_num_tokens and prompt_len / self.args.parallel_config.cp_size + query_len > max_num_tokens:
                    raise ValueError(
                        f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}) and query length ({query_len}) should not exceed "
                        f"max_num_tokens ({max_num_tokens})")
            return

        build_config = self.args.build_config

        built_engine_cfg_file = Path(self.args.model) / 'config.json'
        with open(built_engine_cfg_file) as f:
            built_engine_cfg = json.load(f)
        max_seq_len = built_engine_cfg['build_config'][
            'max_seq_len'] if 'build_config' in built_engine_cfg else build_config.max_seq_len
        # TODO: Remove this check and leave the request verification to the cpp runtime.

        if (not self.args.enable_chunked_prefill) and (
                prompt_len / self.args.parallel_config.cp_size + query_len +
            (sampling_params.max_tokens or 0) > max_seq_len):
            raise ValueError(
                f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}), query length ({query_len}) and max_tokens ({sampling_params.max_tokens}) should not exceed "
                f"max_seq_len ({max_seq_len})")

        if sampling_params.use_beam_search and sampling_params.best_of > build_config.max_beam_width:
            if sampling_params.n == sampling_params.best_of:
                raise ValueError(
                    f"sampling_params.n ({sampling_params.n}) cannot exceed max_beam_width ({build_config.max_beam_width}) when use_beam_search is True"
                )
            else:
                raise ValueError(
                    f"sampling_params.best_of ({sampling_params.best_of}) cannot exceed max_beam_width ({build_config.max_beam_width}) when use_beam_search is True"
                )

        max_batch_size = self.args.max_batch_size
        if max_batch_size is None:
            max_batch_size = build_config.max_batch_size
        if not sampling_params.use_beam_search and sampling_params.best_of > max_batch_size:
            if sampling_params.n == sampling_params.best_of:
                raise ValueError(
                    f"sampling_params.n ({sampling_params.n}) cannot exceed max_batch_size ({max_batch_size}) when use_beam_search is False"
                )
            else:
                raise ValueError(
                    f"sampling_params.best_of ({sampling_params.best_of}) cannot exceed max_batch_size ({max_batch_size}) when use_beam_search is False"
                )

        if sampling_params.prompt_logprobs and not build_config.gather_context_logits:
            raise ValueError(
                f"`sampling_params.prompt_logprobs={sampling_params.prompt_logprobs}` requires `gather_context_logits=True` "
                f"in the `BuildConfig` when constructing the LLM. "
                f"Example: LLM(..., build_config=BuildConfig(gather_context_logits=True))."
            )

        if sampling_params.logprobs and not self.args.gather_generation_logits:
            raise ValueError(
                f"`sampling_params.logprobs={sampling_params.logprobs}` requires `gather_generation_logits=True` "
                f"to be passed explicitly to the `LLM()` constructor.")

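    # Worked example (illustrative) of the sequence-length check above for the TRT
    # backend: with cp_size=2, a 6000-token prompt, no query tokens, and max_tokens=200,
    # the check evaluates 6000 / 2 + 0 + 200 = 3200, so an engine built with
    # max_seq_len=2048 rejects the request while one built with max_seq_len=4096
    # accepts it (unless chunked prefill is enabled, which skips the check).
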
    def _build_model(self):
        model_loader = CachedModelLoader(self.args,
                                         mpi_session=self.mpi_session,
                                         workspace=self._workspace,
                                         llm_build_stats=weakref.proxy(
                                             self.llm_build_stats))
        self._engine_dir, self._hf_model_dir = model_loader()

    @property
    def _on_trt_backend(self) -> bool:
        return isinstance(self.args, TrtLlmArgs)

    def _try_load_tokenizer(self) -> Optional[TokenizerBase]:
        if self.args.skip_tokenizer_init:
            return None

        if self.args.tokenizer is not None:
            assert isinstance(self.args.tokenizer, TokenizerBase)
            return self.args.tokenizer

        if self.runtime_context is not None:
            return self.runtime_context.tokenizer

        # TODO smor- need to refine what the desired behavior is when LoRA is enabled,
        # in terms of the tokenizer initialization process.
        if hasattr(self.args, "backend") and self.args.backend in [
                "pytorch", "_autodeploy"
        ] and self.args.lora_config is not None:
            num_lora_dirs = len(self.args.lora_config.lora_dir)
            if num_lora_dirs == 1:
                tokenizer_path = self.args.lora_config.lora_dir[0]
                try:
                    tokenizer = ModelLoader.load_hf_tokenizer(
                        tokenizer_path,
                        trust_remote_code=self.args.trust_remote_code,
                        use_fast=self.args.tokenizer_mode != 'slow')
                    if tokenizer is None:
                        tokenizer_path = self.args.model
                    else:
                        return tokenizer
                except Exception:
                    tokenizer_path = self.args.model
            else:
                tokenizer_path = self.args.model
        else:
            tokenizer_path = self.args.model
        return ModelLoader.load_hf_tokenizer(
            tokenizer_path,
            trust_remote_code=self.args.trust_remote_code,
            use_fast=self.args.tokenizer_mode != 'slow')

    @property
    def tokenizer(self) -> Optional[TokenizerBase]:
        if hasattr(self, "input_processor"):
            if hasattr(self.input_processor, "tokenizer"):
                return self.input_processor.tokenizer
        return self._tokenizer

    @tokenizer.setter
    def tokenizer(self, tokenizer: TokenizerBase):
        self._tokenizer = tokenizer

    def _try_load_generation_config(
            self) -> Optional[transformers.GenerationConfig]:
        return ModelLoader.load_hf_generation_config(self.args.model)

    def _try_load_hf_model_config(
            self) -> Optional[transformers.PretrainedConfig]:
        return ModelLoader.load_hf_model_config(self.args.model)

    @set_api_status("beta")
    def shutdown(self) -> None:
        if hasattr(self, "_executor") and self._executor is not None:
            self._executor.shutdown()
            self._executor = None

        if hasattr(self, 'mpi_session') and self.mpi_session is not None:
            self.mpi_session.shutdown()
            self.mpi_session = None

    def _check_health(self) -> bool:
        """Check if the LLM is healthy.

        Returns:
            bool: True if the executor is running and not shut down, False otherwise.
        """
        if hasattr(self, "_executor") and self._executor is not None:
            return not self._executor.is_shutdown()

        return False

    @staticmethod
    def _shutdown_wrapper(self_ref):
        # Retrieve the instance if it still exists
        instance = self_ref()
        if instance is not None:
            instance.shutdown()

    def __enter__(self):
        return self

    def __exit__(
        self, exc_type, exc_value, traceback
    ) -> Literal[
            False]:  # https://github.com/microsoft/pyright/issues/7009#issuecomment-1894135045
        del exc_value, traceback
        self.shutdown()
        return False  # propagate exceptions

    def __getstate__(self):
        raise RuntimeError("LLM object can not be pickled.")

    def __del__(self):
        self.shutdown()


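# Example (illustrative sketch, not executed on import): `BaseLLM` implements the
# context-manager protocol, so the executor and MPI session are released deterministically
# when the block exits instead of waiting for atexit/GC. The model name is a placeholder.
#
#     from tensorrt_llm import LLM, SamplingParams
#
#     with LLM(model="<hf-model-id-or-local-path>") as llm:
#         out = llm.generate("Hello", sampling_params=SamplingParams(max_tokens=16))
#         print(out.outputs[0].text)
#     # llm.shutdown() has been called here by __exit__.

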
@append_docstring(TRT_LLM_DOCSTRING)
class _TrtLLM(BaseLLM):
    """LLM class is the main class for running an LLM model using the TensorRT LLM backend.

    Parameters:
    """

    def __init__(self,
                 model: Union[str, Path],
                 tokenizer: Optional[Union[str, Path, TokenizerBase,
                                           PreTrainedTokenizerBase]] = None,
                 tokenizer_mode: Literal['auto', 'slow'] = 'auto',
                 skip_tokenizer_init: bool = False,
                 trust_remote_code: bool = False,
                 tensor_parallel_size: int = 1,
                 dtype: str = "auto",
                 revision: Optional[str] = None,
                 tokenizer_revision: Optional[str] = None,
                 **kwargs: Any) -> None:
        # TODO: deprecate backend in LLM kwargs

        super().__init__(model, tokenizer, tokenizer_mode, skip_tokenizer_init,
                         trust_remote_code, tensor_parallel_size, dtype,
                         revision, tokenizer_revision, **kwargs)

    @property
    def workspace(self) -> Path:
        return Path(self._workspace.name) if self._on_trt_backend else None

    def save(self, engine_dir: str) -> None:
        """Save the built engine to the given path.

        Args:
            engine_dir (str): The path to save the engine.
        """
        logger.info(f"Save model to {engine_dir}")
        if self._engine_dir is None:
            raise RuntimeError("The engine is not built yet.")

        if self._engine_dir.absolute() == os.path.abspath(engine_dir):
            return

        if not self.mpi_session or not self.mpi_session.is_comm_session():
            shutil.copytree(self._engine_dir, engine_dir, dirs_exist_ok=True)
        else:
            # NFS is fragile, so we copy files one by one
            target_engine_dir = Path(engine_dir)
            target_engine_dir.mkdir(parents=True, exist_ok=True)
            # copy files one by one
            for file in self._engine_dir.iterdir():
                logger_debug(
                    f"Copying {file} to {target_engine_dir / file.name}\n")
                shutil.copy(file, target_engine_dir / file.name)

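    # Example (illustrative sketch, not executed on import): with the TensorRT backend
    # (exposed as `from tensorrt_llm._tensorrt_engine import LLM`, per the error message
    # in `_validate_args_for_torch_backend` below), the engine built from a HF checkpoint
    # can be persisted and later reloaded to skip rebuilding. Paths are placeholders.
    #
    #     from tensorrt_llm._tensorrt_engine import LLM
    #
    #     llm = LLM(model="<hf-model-id-or-local-path>")
    #     llm.save("<engine-output-dir>")         # copy the built engine out of the workspace
    #
    #     llm = LLM(model="<engine-output-dir>")  # later: load the prebuilt engine directly
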
    def _build_model(self):
        super()._build_model()
        # Update the model_dir to a local dir for the runtime, e.g. for tokenizer loading.
        if self._engine_dir is not None:
            self.args.model = self._engine_dir

        # Tokenizer and config loading should happen after calling model_loader(), since model_loader() may download the model from the HF hub.
        # It should also happen before creating the bindings ExecutorConfig, which may depend on tokenizer info.
        self._tokenizer = self._try_load_tokenizer()
        self._hf_model_config = self._try_load_hf_model_config()
        self._generation_config = self._try_load_generation_config()

        # Multimodal special handling:
        # 1. The default load_tokenizer may fail because MM models have a different tokenizer configuration. Hence we initialize it inside the input processor.
        # 2. We may need to modify model weights for MM (e.g., resize the vocab embedding). We must do such operations via the input processor's __init__.
        self.input_processor = create_input_processor(self._hf_model_dir,
                                                      self.tokenizer)
        self._tokenizer = self.input_processor.tokenizer

        max_batch_size = self.args.max_batch_size
        max_num_tokens = self.args.max_num_tokens
        max_seq_len = self.args.max_seq_len

        build_config = self.args.build_config

        max_batch_size = max_batch_size or build_config.max_batch_size
        max_num_tokens = max_num_tokens or build_config.max_num_tokens
        max_seq_len = max_seq_len or build_config.max_seq_len

        self._executor_config = tllm.ExecutorConfig(
            max_beam_width=self.args.max_beam_width,
            scheduler_config=PybindMirror.maybe_to_pybind(
                self.args.scheduler_config),
            batching_type=PybindMirror.maybe_to_pybind(self.args.batching_type)
            or tllm.BatchingType.INFLIGHT,
            max_batch_size=max_batch_size,
            max_num_tokens=max_num_tokens,
            gather_generation_logits=self.args.gather_generation_logits,
            fail_fast_on_attention_window_too_large=getattr(
                self.args, 'fail_fast_on_attention_window_too_large', False))

        # Also set executor_config.max_seq_len in the TRT workflow, to deduce the default max_tokens.
        if max_seq_len is not None:
            self._executor_config.max_seq_len = max_seq_len
        else:
            engine_config = EngineConfig.from_json_file(self._engine_dir /
                                                        "config.json")
            self._executor_config.max_seq_len = engine_config.build_config.max_seq_len

        if self.args.kv_cache_config is not None:
            self._executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
                self.args.kv_cache_config)
        if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
            # Disable KV cache reuse for deterministic mode
            self._executor_config.kv_cache_config.enable_block_reuse = False
            self._executor_config.kv_cache_config.enable_partial_reuse = False
        if self.args.peft_cache_config is not None:
            self._executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
                self.args.peft_cache_config)

        lora_config = None
        if self.args.build_config.plugin_config.lora_plugin:
            engine_config = EngineConfig.from_json_file(self._engine_dir /
                                                        "config.json")
            lora_config = engine_config.build_config.lora_config
            if self.args.lora_config is not None:
                logger.info(
                    "Overriding lora_config from engine with lora_config from LLM args"
                )
                lora_config = self.args.lora_config

            max_lora_rank = lora_config.max_lora_rank
            num_lora_modules = engine_config.pretrained_config.num_hidden_layers * \
                len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)

            peft_cache_config_model = PeftCacheConfig.from_pybind(
                self._executor_config.peft_cache_config
            ) if self._executor_config.peft_cache_config is not None else PeftCacheConfig(
            )
            if lora_config.max_loras is not None:
                peft_cache_config_model.num_device_module_layer = \
                    max_lora_rank * num_lora_modules * lora_config.max_loras
            if lora_config.max_cpu_loras is not None:
                peft_cache_config_model.num_host_module_layer = \
                    max_lora_rank * num_lora_modules * lora_config.max_cpu_loras
            self._executor_config.peft_cache_config = peft_cache_config_model._to_pybind(
            )

        if self.args.decoding_config is not None:
            self._executor_config.decoding_config = self.args.decoding_config
        if self.args.guided_decoding_backend == 'xgrammar':
            self._executor_config.guided_decoding_config = tllm.GuidedDecodingConfig(
                backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend.
                XGRAMMAR,
                **_xgrammar_tokenizer_info(self.tokenizer))
        elif self.args.guided_decoding_backend is not None:
            raise ValueError(
                f"Unsupported guided decoding backend {self.args.guided_decoding_backend}"
            )

        self._executor_config.normalize_log_probs = self.args.normalize_log_probs
        self._executor_config.enable_chunked_context = self.args.enable_chunked_prefill
        self._executor_config.max_beam_width = self.args.max_beam_width or self.args.build_config.max_beam_width
        if self.args.extended_runtime_perf_knob_config is not None:
            self._executor_config.extended_runtime_perf_knob_config = PybindMirror.maybe_to_pybind(
                self.args.extended_runtime_perf_knob_config)
        if self.args.cache_transceiver_config is not None:
            self._executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
                self.args.cache_transceiver_config)
        self._executor_config.llm_parallel_config = self.args.parallel_config
        return_logits = (self.args.gather_generation_logits
                         or (self.args.build_config
                             and self.args.build_config.gather_context_logits))

        self._executor = self._executor_cls.create(
            self._engine_dir,
            executor_config=self._executor_config,
            batched_logits_processor=self.args.batched_logits_processor,
            model_world_size=self.args.parallel_config.world_size,
            mpi_session=self.mpi_session,
            reuse_mpi_comm=external_mpi_comm_available(
                self.args.parallel_config.world_size),
            return_logits=return_logits,
            postproc_worker_config=PostprocWorkerConfig(
                num_postprocess_workers=self.args.num_postprocess_workers,
                postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir,
            ),
            is_llm_executor=True)


@append_docstring(TORCH_LLM_DOCSTRING)
class _TorchLLM(BaseLLM):
    """LLM class is the main class for running an LLM model using the PyTorch backend.

    Parameters:
    """

    def __init__(self,
                 model: Union[str, Path],
                 tokenizer: Optional[Union[str, Path, TokenizerBase,
                                           PreTrainedTokenizerBase]] = None,
                 tokenizer_mode: Literal['auto', 'slow'] = 'auto',
                 skip_tokenizer_init: bool = False,
                 trust_remote_code: bool = False,
                 tensor_parallel_size: int = 1,
                 dtype: str = "auto",
                 revision: Optional[str] = None,
                 tokenizer_revision: Optional[str] = None,
                 **kwargs: Any) -> None:

        # TODO: deprecate backend in LLM kwargs
        backend = kwargs.pop("backend", "pytorch")

        # Validate that users don't pass TrtLlmArgs-specific arguments
        self._validate_args_for_torch_backend(kwargs)

        super().__init__(model,
                         tokenizer,
                         tokenizer_mode,
                         skip_tokenizer_init,
                         trust_remote_code,
                         tensor_parallel_size,
                         dtype,
                         revision,
                         tokenizer_revision,
                         backend=backend,
                         **kwargs)

    @set_api_status("prototype")
    def _collective_rpc(self,
                        method: str,
                        args: tuple[Any, ...] = (),
                        kwargs: Optional[dict] = None,
                        non_block: bool = False,
                        unique_reply_rank: Optional[int] = None) -> list[Any]:
        """
        Execute an RPC call on all GPU workers. Currently, this is only supported for RayExecutor.

        Args:
            method (str): The name of the worker method to execute.
            args (tuple[Any, ...]): Positional arguments to pass to the worker method. Defaults to ().
            kwargs (dict, optional): Keyword arguments to pass to the worker method. Defaults to None.
            non_block (bool): If True, return immediately instead of blocking until all workers have completed the RPC call. Defaults to False.
            unique_reply_rank (int, optional): The rank of the worker that will be used to send the reply. Defaults to None.

        Returns:
            list[Any]: A list of results from each worker.
        """
        if hasattr(self._executor, 'collective_rpc'):
            return self._executor.collective_rpc(method, args, kwargs,
                                                 non_block, unique_reply_rank)
        else:
            raise ValueError(
                f"Executor type {type(self._executor)} does not support collective RPC."
            )

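    # Example (illustrative sketch, not executed on import): `_collective_rpc` is a
    # prototype API that only works when the executor supports collective RPC (the Ray
    # executor, e.g. when `orchestrator_type="ray"` is selected or MPI is disabled).
    # The worker method name below is a hypothetical placeholder.
    #
    #     llm = LLM(model="<hf-model-id-or-local-path>", orchestrator_type="ray")
    #     results = llm._collective_rpc("<worker_method_name>", args=(1, 2),
    #                                   unique_reply_rank=0)
    #     print(results)
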
    def _build_model(self):
        super()._build_model()
        assert self._engine_dir is None

        # Tokenizer and config loading should happen after calling model_loader(), since model_loader() may download the model from the HF hub.
        # It should also happen before creating the bindings ExecutorConfig, which may depend on tokenizer info.
        self._tokenizer = self._try_load_tokenizer()
        self._hf_model_config = self._try_load_hf_model_config()
        self._generation_config = self._try_load_generation_config()

        # Multimodal special handling:
        # 1. The default load_tokenizer may fail because MM models have a different tokenizer configuration. Hence we initialize it inside the input processor.
        # 2. We may need to modify model weights for MM (e.g., resize the vocab embedding). We must do such operations via the input processor's __init__.
        checkpoint_format = getattr(self.args, "checkpoint_format", None)
        self.input_processor = create_input_processor(self._hf_model_dir,
                                                      self.tokenizer,
                                                      checkpoint_format)
        self._tokenizer = self.input_processor.tokenizer

        # TODO: revisit gather_context_logits
        return_logits = self.args.gather_generation_logits
        self._executor = self._executor_cls.create(
            self._engine_dir,
            executor_config=None,
            batched_logits_processor=self.args.batched_logits_processor,
            model_world_size=self.args.parallel_config.world_size,
            mpi_session=self.mpi_session,
            reuse_mpi_comm=external_mpi_comm_available(
                self.args.parallel_config.world_size),
            return_logits=return_logits,
            postproc_worker_config=PostprocWorkerConfig(
                num_postprocess_workers=self.args.num_postprocess_workers,
                postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir,
            ),
            is_llm_executor=True,
            hf_model_dir=self._hf_model_dir,
            tokenizer=self.tokenizer,
            llm_args=self.args)

    def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
        """Validate that users don't pass TrtLlmArgs-specific arguments when using the PyTorch backend.
        """
        trtllm_fields = set(TrtLlmArgs.model_fields.keys())
        torchllm_fields = set(TorchLlmArgs.model_fields.keys())

        trtllm_specific_fields = trtllm_fields - torchllm_fields

        # Check if any TrtLlmArgs-specific arguments are passed
        trtllm_specific_args = []
        for key in kwargs:
            if key in trtllm_specific_fields:
                trtllm_specific_args.append(key)

        if trtllm_specific_args:
            raise ValueError(
                f"The following arguments are specific to the TensorRT backend and cannot be used with the PyTorch backend: {trtllm_specific_args}.\n"
                f"Please use 'from tensorrt_llm._tensorrt_engine import LLM' instead to use the TensorRT backend."
            )


class LLM(_TorchLLM):

    def __init__(self,
                 model: Union[str, Path],
                 tokenizer: Optional[Union[str, Path, TokenizerBase,
                                           PreTrainedTokenizerBase]] = None,
                 tokenizer_mode: Literal['auto', 'slow'] = 'auto',
                 skip_tokenizer_init: bool = False,
                 trust_remote_code: bool = False,
                 tensor_parallel_size: int = 1,
                 dtype: str = "auto",
                 revision: Optional[str] = None,
                 tokenizer_revision: Optional[str] = None,
                 **kwargs: Any) -> None:
        super().__init__(model, tokenizer, tokenizer_mode, skip_tokenizer_init,
                         trust_remote_code, tensor_parallel_size, dtype,
                         revision, tokenizer_revision, **kwargs)


# Sphinx will ignore the LLM docstring if it is not explicitly set.
LLM.__doc__ = \
    f"""LLM class is the main class for running an LLM model.

For more details about the arguments, please refer to :class:`TorchLlmArgs`.

Parameters:
""" + TORCH_LLM_DOCSTRING
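
# Example (illustrative sketch, not executed on import): end-to-end quickstart with the
# default PyTorch-backend `LLM` defined above. The model name is a placeholder; see
# :class:`TorchLlmArgs` for the full set of accepted keyword arguments.
#
#     from tensorrt_llm import LLM, SamplingParams
#
#     def main():
#         prompts = ["Hello, my name is", "The future of AI is"]
#         params = SamplingParams(max_tokens=64, temperature=0.8, top_p=0.95)
#         with LLM(model="<hf-model-id-or-local-path>") as llm:
#             for out in llm.generate(prompts, sampling_params=params):
#                 print(out.prompt, "->", out.outputs[0].text)
#
#     if __name__ == "__main__":
#         main()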