# TensorRT-LLM/tensorrt_llm/bench/dataclasses/configuration.py

from __future__ import annotations
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import (BaseModel, Field, PositiveFloat, field_validator,
model_validator)
import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import (BatchingType, CapacitySchedulerPolicy,
ContextChunkingPolicy, DynamicBatchConfig,
ExtendedRuntimePerfKnobConfig, KvCacheConfig,
SchedulerConfig)
from tensorrt_llm.llmapi.llm_args import _AutoDeployLlmArgs
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
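
# Maps the benchmark's SpeculativeDecodingMode onto the executor's DecodingMode
# factories; NONE maps to a no-op callable that yields no decoding mode.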
SPECULATIVE_MAP = {
SpeculativeDecodingMode.NONE: lambda *args: None,
SpeculativeDecodingMode.MEDUSA: trtllm.DecodingMode.Medusa,
}
class RuntimeConfig(BaseModel):
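    """Complete runtime configuration for a benchmark run.

    Bundles the model/engine location, executor settings, parallel world
    layout, decoding options, and backend performance knobs, and converts
    them into LLM-API keyword arguments via ``get_llm_args()``.
    """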
model: str
model_path: Optional[Path] = None
engine_dir: Optional[Path] = None
sw_version: str
settings_config: ExecutorSettingsConfig
world_config: ExecutorWorldConfig
decoding_config: Optional[DecodingConfig] = None
performance_options: PerformanceOptions
backend: Literal["pytorch", "_autodeploy", None] = None
extra_llm_api_options: Optional[str] = None
iteration_log: Optional[Path] = None
def get_llm_args(self) -> Dict:
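        """Assemble keyword arguments for constructing an LLM instance.

        Backend-specific options ("pytorch" or "_autodeploy") are merged in,
        and extra LLM API options, if provided, are applied last.
        """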
model = self.engine_dir or self.model_path or self.model
llm_args = {
"scheduler_config":
self.settings_config.get_scheduler_config(),
"model":
model,
"skip_tokenizer_init":
True,
"pipeline_parallel_size":
self.world_config.pp_size,
"tensor_parallel_size":
self.world_config.tp_size,
"gpus_per_node":
self.world_config.gpus_per_node,
"moe_expert_parallel_size":
self.world_config.ep_size,
"moe_cluster_parallel_size":
self.world_config.cluster_size,
"trust_remote_code":
True,
"kv_cache_config":
self.settings_config.get_kvcache_config(),
"enable_chunked_prefill":
self.settings_config.chunking,
"extended_runtime_perf_knob_config":
self.performance_options.get_perf_config(),
"decoding_config":
self.decoding_config.get_decoding_config(),
"batching_type":
BatchingType.INFLIGHT,
"max_batch_size":
self.settings_config.max_batch_size,
"max_num_tokens":
self.settings_config.max_num_tokens,
}
backend_config_map = {
"pytorch": self.performance_options.get_pytorch_perf_config,
"_autodeploy": self.performance_options.get_autodeploy_perf_config
}
if self.backend in backend_config_map:
llm_args.update(backend_config_map[self.backend]())
return update_llm_args_with_extra_options(llm_args,
self.extra_llm_api_options)
@model_validator(mode="after")
def validate_full_config(self) -> RuntimeConfig:
# TODO: Check engine to make sure it can support Medusa.
return self
class PerformanceOptions(BaseModel):
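    """Performance-related knobs forwarded to the selected backend."""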
cuda_graphs: bool = False
multi_block_mode: bool = True
cuda_graph_cache_size: int = 1000
pytorch_config: Dict[str, Any] = Field(default_factory=dict)
def get_perf_config(self) -> ExtendedRuntimePerfKnobConfig:
config = ExtendedRuntimePerfKnobConfig()
config.cuda_graph_mode = self.cuda_graphs
config.multi_block_mode = self.multi_block_mode
config.cuda_graph_cache_size = self.cuda_graph_cache_size
return config
def get_pytorch_perf_config(self) -> PyTorchConfig:
return self.pytorch_config
def get_autodeploy_perf_config(self) -> _AutoDeployLlmArgs:
ad_config = _AutoDeployLlmArgs(**self.pytorch_config)
ad_config.attn_backend = "FlashInfer"
ad_config.torch_compile_enabled = True
ad_config.skip_loading_weights = True
return ad_config
class DecodingConfig(BaseModel):
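    """Speculative decoding configuration (currently Medusa or none)."""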
medusa_choices: Optional[List[List[int]]] = None
decoding_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
@field_validator("decoding_mode")
@classmethod
def decoding_mode_validator(
cls, value: Union[str, int,
SpeculativeDecodingMode]) -> SpeculativeDecodingMode:
return SpeculativeDecodingMode(value)
@model_validator(mode="after")
def validate_speculative_decoding(self) -> DecodingConfig:
if self.medusa_choices and self.decoding_mode != SpeculativeDecodingMode.MEDUSA:
raise RuntimeError(
"Attempting to use set Medusa choices with a non-Medusa engine."
" Verify that you are using a Medusa engine.")
return self
def get_decoding_config(self) -> trtllm.DecodingConfig:
"""Create a populated TRT-LLM DecodingConfig."""
kwargs = {"decoding_mode": SPECULATIVE_MAP[self.decoding_mode]()}
if self.medusa_choices is not None:
kwargs["medusa_choices"] = self.medusa_choices
return trtllm.DecodingConfig(**kwargs)
class ExecutorWorldConfig(BaseModel):
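    """Parallel world layout (PP/TP/EP sizes and GPUs per node)."""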
pp_size: int = 1
tp_size: int = 1
# None to make LLM-API deduce it with a rule.
gpus_per_node: Optional[int] = None
leader_mode: bool = False
ep_size: Optional[int] = None
cluster_size: Optional[int] = None
@model_validator(mode="after")
def validate_world_size(self) -> ExecutorWorldConfig:
if self.gpus_per_node is None:
return self
parallel_world = self.pp_size * self.tp_size
num_gpus = self.world_size * self.gpus_per_node
valid_world = bool(num_gpus >= parallel_world)
if not valid_world:
raise ValueError(
f"World configuration is invalid, TP * PP ({parallel_world})"
"does not equal the total number of available GPUs"
f"({num_gpus}).")
return self
@property
def world_size(self) -> int:
return self.pp_size * self.tp_size
def _get_tensorrt_llm_executor_worker_path(self) -> Path:
module_path = find_spec("tensorrt_llm").loader.get_filename()
exec_path = Path(module_path).parent / 'bin' / 'executorWorker'
return exec_path.absolute()
def get_parallel_config(self) -> trtllm.ParallelConfig:
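        """Build the executor ParallelConfig for leader or orchestrator mode."""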
if self.leader_mode:
comm_mode = trtllm.CommunicationMode.LEADER
orchestrator_config = None
else:
comm_mode = trtllm.CommunicationMode.ORCHESTRATOR
orchestrator_config = trtllm.OrchestratorConfig(
True, str(self._get_tensorrt_llm_executor_worker_path()))
return trtllm.ParallelConfig(
trtllm.CommunicationType.MPI,
comm_mode,
orchestrator_config=orchestrator_config,
)
class ExecutorSettingsConfig(BaseModel):
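    """Executor scheduling, batching, and KV-cache settings."""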
chunking: bool = True
scheduler_policy: CapacitySchedulerPolicy = Field(
default=CapacitySchedulerPolicy.MAX_UTILIZATION)
max_batch_size: int
max_num_tokens: int
kv_cache_percent: PositiveFloat = Field(default=.90, le=1.0)
kv_cache_reuse: bool = False
dynamic_max_batch_size: bool = True
dynamic_max_num_tokens: bool = False # Will enable after more validation.
def get_dynamic_config(self) -> DynamicBatchConfig:
window_size = 128 if self.dynamic_max_batch_size else 0
return DynamicBatchConfig(
enable_batch_size_tuning=self.dynamic_max_batch_size,
enable_max_num_tokens_tuning=self.dynamic_max_num_tokens,
dynamic_batch_moving_average_window=window_size,
)
def get_kvcache_config(self) -> KvCacheConfig:
return KvCacheConfig(
free_gpu_memory_fraction=self.kv_cache_percent,
            enable_block_reuse=self.kv_cache_reuse,
)
def get_scheduler_config(self) -> SchedulerConfig:
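        """Build the scheduler config; with chunking, use FCFS context chunking."""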
if self.chunking:
return SchedulerConfig(
capacity_scheduler_policy=self.scheduler_policy,
context_chunking_policy=ContextChunkingPolicy.
FIRST_COME_FIRST_SERVED,
dynamic_batch_config=self.get_dynamic_config(),
)
else:
return SchedulerConfig(
capacity_scheduler_policy=self.scheduler_policy,
dynamic_batch_config=self.get_dynamic_config())
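
# Example (illustrative sketch): how these models compose into LLM-API args.
# The model name and sizes below are placeholder assumptions, not recommended
# benchmark settings.
#
#   runtime_config = RuntimeConfig(
#       model="meta-llama/Llama-3.1-8B",  # assumed checkpoint name
#       sw_version="0.0.0",               # assumed version string
#       settings_config=ExecutorSettingsConfig(max_batch_size=256,
#                                              max_num_tokens=8192),
#       world_config=ExecutorWorldConfig(tp_size=1, pp_size=1),
#       performance_options=PerformanceOptions(),
#       decoding_config=DecodingConfig(),
#       backend="pytorch",
#   )
#   llm_args = runtime_config.get_llm_args()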