TensorRT-LLMs/tensorrt_llm/commands/utils.py
Chang Liu 26901e4aa0
[TRTLLM-10612][feat] Initial support of AIGV models in TRTLLM (#11462)
Signed-off-by: Chang Liu (Enterprise Products) <liuc@nvidia.com>
Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
Signed-off-by: Zhenhua Wang <zhenhuaw@nvidia.com>
Co-authored-by: Freddy Qi <junq@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Zhenhua Wang <zhenhuaw@nvidia.com>
2026-02-14 06:11:11 +08:00

133 lines
4.4 KiB
Python

# Adapted from https://github.com/sgl-project/sglang/blob/030496eb06472f76fcb11de53d93f10cefb4604f/python/sglang/cli/utils.py#L27
import json
import logging
import os
from tensorrt_llm.llmapi.utils import download_hf_partial
logger = logging.getLogger(__name__)
def _maybe_download_model(
    model_name_or_path: str, local_dir: str | None = None, download: bool = True
) -> str:
    """Resolve a model path to a local directory.

    If ``model_name_or_path`` already exists on disk it is returned as-is.
    Otherwise, when ``download`` is enabled, only the lightweight config file
    (``model_index.json`` or ``config.json``) is fetched from the Hugging
    Face Hub — enough to classify the model without pulling the weights.

    Args:
        model_name_or_path: Local path or Hugging Face Hub model ID.
        local_dir: Local directory to save the downloaded file (if any).
            NOTE(review): currently unused — it is not forwarded to
            ``download_hf_partial``; confirm whether it should be.
        download: Whether to download from Hugging Face Hub when needed.

    Returns:
        Local directory path that contains the downloaded config file, or
        the original ``model_name_or_path`` when it exists locally or
        downloading is disabled.

    Raises:
        ValueError: If the path is neither local nor downloadable.
    """
    # A path that already exists locally needs no resolution.
    if os.path.exists(model_name_or_path):
        logger.info("Model already exists locally")
        return model_name_or_path

    # Caller opted out of network access; hand the ID back unchanged.
    if not download:
        return model_name_or_path

    try:
        logger.info(
            "Downloading model_index.json from HF Hub for %s...",
            model_name_or_path,
        )
        # Restrict the download to the config files only.
        fetched_path = download_hf_partial(
            model=model_name_or_path,
            allow_patterns=["model_index.json", "config.json"],
        )
        logger.info("Downloaded to %s", fetched_path)
        return str(fetched_path)
    except Exception as e:
        raise ValueError(
            (
                "Could not find model locally at %s and failed to download "
                "model_index.json/config.json from HF Hub: %s"
            )
            % (model_name_or_path, e)
        ) from e
# Copied and adapted from hf_diffusers_utils.py
def is_diffusers_model_path(model_path: str) -> bool:
    """Check whether ``model_path`` holds a valid diffusers pipeline config.

    Args:
        model_path: Path to the model directory.

    Returns:
        True if the directory contains a ``model_index.json`` carrying a
        ``_diffusers_version`` key (the marker of a diffusers pipeline),
        False otherwise.
    """
    # Prefer model_index.json which indicates a diffusers pipeline.
    config_path = os.path.join(model_path, "model_index.json")
    if not os.path.exists(config_path):
        return False
    with open(config_path) as f:
        config = json.load(f)
    # A genuine diffusers config always records the library version.
    return "_diffusers_version" in config
def get_is_diffusion_model(model_path: str):
    """Return True when ``model_path`` points to a diffusers pipeline.

    The path is first resolved via ``_maybe_download_model`` (which may
    fetch the config file from the HF Hub for remote model IDs), then the
    resolved directory is checked for a diffusers config.
    """
    resolved_path = _maybe_download_model(model_path)
    detected = is_diffusers_model_path(resolved_path)
    if detected:
        logger.info("Diffusion model detected")
    return detected
def get_model_path(extra_argv):
    """Extract the ``--model-path`` value from a raw argument list.

    Supports both the two-token form (``--model-path <value>``) and the
    single-token form (``--model-path=<value>``).

    Args:
        extra_argv: Raw CLI argument tokens (e.g. leftover argv).

    Returns:
        The model path string.

    Raises:
        Exception: If ``--model-path`` is absent. The message is a usage
            string when ``-h``/``--help`` is present, and a plain error
            otherwise.
    """
    # Find the model_path argument
    model_path = None
    for i, arg in enumerate(extra_argv):
        if arg == "--model-path":
            # Two-token form: the value is the next argument, if present.
            if i + 1 < len(extra_argv):
                model_path = extra_argv[i + 1]
                break
        elif arg.startswith("--model-path="):
            # Single-token form: the value follows the first '='.
            model_path = arg.split("=", 1)[1]
            break
    if model_path is None:
        # Fallback for --help or other cases where model-path is not provided
        if any(h in extra_argv for h in ["-h", "--help"]):
            # Fixed leftover from the sglang origin: this module ships in
            # TensorRT-LLM, whose serving command is `trtllm-serve`.
            raise Exception(
                "Usage: trtllm-serve --model-path <model-name-or-path> [additional-arguments]\n\n"
                "This command can launch either a standard language model server or a diffusion model server.\n"
                "The server type is determined by the model path.\n"
                "For specific arguments, please provide a model_path."
            )
        else:
            raise Exception(
                "Error: --model-path is required. Please provide the path to the model."
            )
    return model_path
# Case-insensitive substring of a model name/path -> internal model type.
VISUAL_GEN_PARTIAL_MODEL_NAME_TO_MODEL_TYPE = {
    "FLUX.2": "flux2",
    "LTX-2": "ltx2",
    "Wan2": "wan2",
}


def get_visual_gen_model_type(model_path: str) -> str:
    """Infer the visual-generation model type from a model path or name.

    Matches each known partial model name case-insensitively against the
    given path.

    Args:
        model_path: Local path or Hugging Face model ID.

    Returns:
        The internal model-type identifier (e.g. "flux2").

    Raises:
        ValueError: If no known model name appears in the path.
    """
    lowered_path = model_path.lower()
    for partial_model_name, model_type in VISUAL_GEN_PARTIAL_MODEL_NAME_TO_MODEL_TYPE.items():
        if partial_model_name.lower() in lowered_path:
            return model_type
    # Fixed: the two f-string fragments previously joined without a space
    # ("...,available"), and .keys() rendered as "dict_keys([...])".
    raise ValueError(
        f"Unknown VISUAL_GEN model type for model path: {model_path}, "
        f"available models: {list(VISUAL_GEN_PARTIAL_MODEL_NAME_TO_MODEL_TYPE.keys())}"
    )