TensorRT-LLMs/tests/test_common/llm_data.py
Anish Shanbhag 76798fc52b Mock snapshot_download to avoid download from HF
Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>
2026-01-12 10:53:19 -08:00

116 lines
4.7 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for local LLM model paths and HuggingFace download mocking."""
import inspect
import os
from functools import wraps
from pathlib import Path
from typing import Optional
from unittest.mock import patch
# Explicit mapping: HuggingFace Hub ID -> subdirectory (relative to LLM_MODELS_ROOT) that
# holds the local copy of that model.
# NOTE: hf_id_to_local_model_dir below falls back to looking for a top-level directory in
# LLM_MODELS_ROOT named after the model (the part of the Hub ID after the last "/") when an
# ID is not listed here, so this table does not need to list every model exhaustively.
HF_ID_TO_LLM_MODELS_SUBDIR = {
    # The first two Hub IDs are aliases that resolve to the same local checkpoint.
    "meta-llama/Meta-Llama-3.1-8B-Instruct": "llama-3.1-model/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct": "llama-3.1-model/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.1-8B": "llama-3.1-model/Meta-Llama-3.1-8B",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0": "llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
    "meta-llama/Llama-4-Scout-17B-16E-Instruct": "llama4-models/Llama-4-Scout-17B-16E-Instruct",
    "mistralai/Mixtral-8x7B-Instruct-v0.1": "Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-Small-3.1-24B-Instruct-2503": "Mistral-Small-3.1-24B-Instruct-2503",
    "Qwen/Qwen3-30B-A3B": "Qwen3/Qwen3-30B-A3B",
    "Qwen/Qwen2.5-3B-Instruct": "Qwen2.5-3B-Instruct",
    "microsoft/Phi-3-mini-4k-instruct": "Phi-3/Phi-3-mini-4k-instruct",
    "deepseek-ai/DeepSeek-V3": "DeepSeek-V3",
    "deepseek-ai/DeepSeek-R1": "DeepSeek-R1/DeepSeek-R1",
    "ibm-ai-platform/Bamba-9B-v2": "Bamba-9B-v2",
    "nvidia/NVIDIA-Nemotron-Nano-12B-v2": "NVIDIA-Nemotron-Nano-12B-v2",
    "nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3": "NVIDIA-Nemotron-Nano-31B-A3-v3",
    "nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024": "Nemotron-Nano-3-30B-A3.5B-dev-1024",
    "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B": "EAGLE3-LLaMA3.1-Instruct-8B",
}
def llm_models_root(check: bool = False) -> Optional[Path]:
    """Locate the root directory that holds the local LLM model checkpoints.

    A non-empty LLM_MODELS_ROOT environment variable takes precedence over the default
    /home/scratch.trt_llm_data/llm-models/. If the chosen directory does not exist,
    /scratch.trt_llm_data/llm-models/ is tried as a fallback.

    Args:
        check: If True, assert that an existing models root was found.

    Returns:
        The models root path, or None if no candidate directory exists.

    Raises:
        AssertionError: If ``check`` is True and no candidate directory exists.
    """
    root = Path("/home/scratch.trt_llm_data/llm-models/")
    env_root = os.environ.get("LLM_MODELS_ROOT")
    # Treat an empty value as unset: Path("") is Path("."), which always exists and would
    # silently turn the current working directory into the models root.
    if env_root:
        root = Path(env_root)
    if not root.exists():
        root = Path("/scratch.trt_llm_data/llm-models/")
    if check:
        assert root.exists(), (
            "You must set LLM_MODELS_ROOT env or be able to access /home/scratch.trt_llm_data to run this test"
        )
    return root if root.exists() else None
def llm_datasets_root() -> str:
    """Return the path of the ``datasets`` directory inside the LLM models root.

    The models root is obtained via ``llm_models_root(check=True)``, so this asserts
    when no models root can be located.
    """
    models_root = llm_models_root(check=True)
    datasets_dir = os.path.join(models_root, "datasets")
    return datasets_dir
def hf_id_to_local_model_dir(hf_hub_id: str) -> str | None:
    """Return the local model directory under LLM_MODELS_ROOT for a HuggingFace Hub ID.

    Resolution order:
      1. The subdirectory listed in HF_ID_TO_LLM_MODELS_SUBDIR, if that directory exists.
      2. A top-level directory in LLM_MODELS_ROOT named after the model part of the Hub ID
         (the text after the last "/", ignoring trailing slashes).

    Args:
        hf_hub_id: HuggingFace Hub ID, e.g. "org/model-name".

    Returns:
        The local directory as a string, or None if LLM_MODELS_ROOT is unavailable or no
        existing matching directory is found.
    """
    root = llm_models_root()
    if root is None:
        return None

    subdir = HF_ID_TO_LLM_MODELS_SUBDIR.get(hf_hub_id)
    if subdir is not None:
        mapped_dir = root / subdir
        # Only report the mapped location if it is actually present; otherwise fall
        # through to the name-based lookup instead of returning a non-existent path.
        if mapped_dir.is_dir():
            return str(mapped_dir)

    # Fall back to checking if the model name exists as a top-level directory in LLM_MODELS_ROOT.
    model_name = hf_hub_id.rstrip("/").split("/")[-1]
    # "" would resolve to the models root itself and "."/".." to the root or its parent;
    # none of these is a model directory.
    if model_name in ("", ".", ".."):
        return None
    candidate = root / model_name
    if candidate.is_dir():
        return str(candidate)
    return None
def hf_model_dir_or_hub_id(hf_hub_id: str) -> str:
    """Prefer a local copy of ``hf_hub_id``; otherwise hand back the Hub ID unchanged.

    When ``hf_id_to_local_model_dir`` finds a local directory, that path is returned so
    callers load from disk. If it finds nothing, the Hub ID itself is returned.
    """
    local_dir = hf_id_to_local_model_dir(hf_hub_id)
    if not local_dir:
        return hf_hub_id
    return local_dir
def mock_snapshot_download(repo_id: str, **kwargs) -> str:
    """Stand-in for ``huggingface_hub.snapshot_download`` that resolves to a local directory.

    Instead of downloading anything, ``repo_id`` is looked up under LLM_MODELS_ROOT via
    ``hf_id_to_local_model_dir`` and that existing path is returned.

    NOTE: Extra keyword arguments (e.g. revision / allow_patterns / ignore_patterns) are
    accepted but ignored.

    Raises:
        ValueError: If no local directory is found for ``repo_id``.
    """
    resolved = hf_id_to_local_model_dir(repo_id)
    if resolved is not None:
        return resolved
    raise ValueError(f"Model '{repo_id}' not found in LLM_MODELS_ROOT")
def with_mocked_hf_download(func):
    """Decorator to mock huggingface_hub.snapshot_download for tests.

    While the decorated function runs, calls made through ``huggingface_hub.snapshot_download``
    are redirected to ``mock_snapshot_download``. That function returns local model paths from
    LLM_MODELS_ROOT instead of downloading from HuggingFace.

    Both regular and ``async def`` functions are supported. For coroutine functions, the patch
    is held across the ``await`` so it is active while the body actually executes. A plain sync
    wrapper would only patch coroutine creation. It would also hide the function's async nature
    from test runners that check for coroutine functions.

    NOTE(review): ``patch`` replaces the attribute on the ``huggingface_hub`` package only. Code
    that bound the name earlier via ``from huggingface_hub import snapshot_download`` keeps the
    original function.
    """

    def _patch_snapshot_download():
        # A fresh patcher per call, so nested or repeated invocations stay independent.
        return patch("huggingface_hub.snapshot_download", side_effect=mock_snapshot_download)

    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            with _patch_snapshot_download():
                return await func(*args, **kwargs)

        return async_wrapper

    @wraps(func)
    def wrapper(*args, **kwargs):
        with _patch_snapshot_download():
            return func(*args, **kwargs)

    return wrapper