TensorRT-LLM/examples/auto_deploy/simple_config.py
"""A simple config for Llama-2 building and generating scripts.
Modify directly if you want to change settings.
"""
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Union
@dataclass
class SimpleConfig:
    """Experiment configuration."""

    ### MODEL ARG ##################################################################################
    # Path or repo_id of a HF model directory.
    # The model directory should contain the model weights and tokenizer configs. Model weights can
    # be provided as either of the following:
    #   1. A sharded checkpoint (multiple files) in the safetensors format
    #   2. A single, unsharded checkpoint in the safetensors format
    #   3. A single, unsharded checkpoint in the pytorch format (.pt/.pth file ending)
    # If no `model` argument is provided, the checkpoint directory is used to infer the model
    # architecture.
    model: Optional[str] = None
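    # Illustrative values (assumptions, not defaults): `model` could be a HF repo id such as
    # "meta-llama/Llama-2-7b-hf" or a local checkpoint directory like "/path/to/checkpoint".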
    model_factory: Literal["AutoModelForCausalLM", "AutoModelForImageTextToText"] = (
        "AutoModelForCausalLM"
    )
    skip_loading_weights: bool = False  # only load the architecture, not the weights
    customize_tokenizer: bool = False  # True: tokenizer from the model factory; False: from the LLM API

    ### MODEL EXTRA KWARGS #########################################################################
    # Extra kwargs for the model config class to customize the model config. These arguments take
    # precedence over the default values or config values in the model config file in the HF
    # directory. Arguments are resolved in the following order:
    #   1. Default values in the model config class
    #   2. Values in the model config file in the HF directory
    #   3. Values in the model_kwargs
    # Note that if a kwarg does not exist in the model config class, it will be ignored.
    # An example model config class can be found [here](https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/models/llama/configuration_llama.py#L26).
    model_kwargs: Dict = field(default_factory=dict)
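    # Illustrative example (values are assumptions): passing `model_kwargs={"num_hidden_layers": 2}`
    # overrides both the class default and the value from the HF config file, while an unknown key
    # such as "not_a_real_field" is silently ignored.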
    # TODO: temp fix for dashboard to modify the number of hidden layers
    num_hidden_layers: int = -1

    ### TOKENIZER EXTRA KWARGS #####################################################################
    # Extra kwargs for the tokenizer class to customize the tokenizer. Same as model_kwargs.
    # For example, the default HF Llama tokenizer can be initialized with the arguments specified
    # [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama_fast.py#L127).
    # NOTE: This is only used if customize_tokenizer is True.
    tokenizer_kwargs: Dict = field(default_factory=dict)
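    # Illustrative example (an assumption, not a default): `tokenizer_kwargs={"padding_side": "left"}`
    # is only applied when `customize_tokenizer=True`.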

    ### CONFIGURE BACKEND, RUNTIME, AND WORLD SIZE #################################################
    world_size: int = 1  # number of GPUs for TP (0 --> no TP, no spawned processes)
    runtime: Literal["demollm", "trtllm"] = "trtllm"
    compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
        "torch-compile"
    )
    attn_backend: Literal["TritonWithFlattenedInputs", "FlashInfer"] = "FlashInfer"
    mla_backend: Literal["MultiHeadLatentAttention"] = "MultiHeadLatentAttention"
    max_seq_len: int = 512  # max sequence length for inference/cache
    max_batch_size: int = 8  # max dimension for statically allocated kv cache
    page_size: int = 64  # page size for attention

    ### SOME SIMPLE PROMPTING CONFIG ###############################################################
    batch_size: int = 2  # example input shape
    device: str = "cuda"
    prompt: Union[str, List[str]] = field(
        default_factory=lambda: [
            "How big is the universe? ",
            "In simple words and in a single sentence, explain the concept of gravity: ",
            "How to fix slicing in golf? ",
            "Where is the capital of Iceland? ",
            "How big is the universe? ",
            "In simple words and in a single sentence, explain the concept of gravity: ",
            "How to fix slicing in golf? ",
            "Where is the capital of Iceland? ",
        ]
    )
    max_tokens: int = 100
    top_k: int = 200
    temperature: float = 1.0
    visualize: bool = False

    ### BENCHMARKING CONFIG ########################################################################
    free_mem_ratio: float = 0.0  # specifies the fraction of available memory to occupy for the cache
    benchmark: bool = False  # If true, set ISL to 2048 random tokens and OSL to 128
    benchmark_num: int = 10  # By default run 10 times and report the average
    benchmark_isl: int = 2048  # input seq length for benchmarking
    benchmark_osl: int = 128  # output seq length for benchmarking
    benchmark_bs: int = 1  # batch size for benchmarking
    benchmark_results_path: Optional[str] = "./benchmark_results.json"
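    # Illustrative: with benchmark=True, __post_init__ below raises max_seq_len to at least
    # benchmark_isl + benchmark_osl (2048 + 128 = 2176 with the defaults above) and max_batch_size
    # to at least benchmark_bs.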

    ### POST INITIALIZATION ########################################################################
    def __post_init__(self):
        # check if model was supplied
        assert self.model, "model must be supplied!"

        # We don't want to lose the default values for model_kwargs unless explicitly set by the
        # user. They are not preserved by the standard initialization process since the whole dict
        # gets replaced by the user-provided one. We don't want that though.
        f_default = self.__dataclass_fields__["model_kwargs"].default_factory()
        setattr(self, "model_kwargs", {**f_default, **getattr(self, "model_kwargs")})

        # special handling for torch_dtype in model_kwargs since HF does not correctly update the
        # torch_dtype string to an actual torch.dtype object (only with the default)
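        # Illustrative example (an assumption, not a default): model_kwargs={"torch_dtype": "bfloat16"}
        # is converted to torch.bfloat16 by the block below.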
if "torch_dtype" in self.model_kwargs:
import torch
dtype = self.model_kwargs["torch_dtype"]
if isinstance(dtype, str):
dtype = getattr(torch, self.model_kwargs["torch_dtype"])
assert isinstance(dtype, torch.dtype), f"Invalid dtype: {dtype}"
self.model_kwargs["torch_dtype"] = dtype
self.max_batch_size = max(self.max_batch_size, self.batch_size)
# make sure benchmark isl/osl/bs fits into available resources
if self.benchmark:
self.max_batch_size = max(self.benchmark_bs, self.max_batch_size)
self.max_seq_len = max(self.max_seq_len, self.benchmark_isl + self.benchmark_osl)
# No paging allowed in TritonWithFlattenedInputs
if self.attn_backend in ["TritonWithFlattenedInputs"]:
self.page_size = self.max_seq_len
# use min instead of max to avoid OOM for large batch size
self.model_kwargs["max_position_embeddings"] = min(
self.max_seq_len,
self.model_kwargs.get("max_position_embeddings", self.max_seq_len),
)
if isinstance(self.prompt, str):
self.prompt = [self.prompt]
# replicate prompts to get to batch_size
prompts = self.prompt * (self.batch_size // len(self.prompt) + 1)
self.prompt = prompts[: self.batch_size]
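        # Illustrative: with batch_size=3 and prompt=["a", "b"], the result is ["a", "b", "a"].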

        if self.num_hidden_layers != -1:
            self.model_kwargs["num_hidden_layers"] = self.num_hidden_layers