import json
import os
import subprocess
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import pytest
from utils.llm_data import llm_models_root

# Constants for test configuration
_DEFAULT_NUM_REQUESTS = 3
_DEFAULT_INPUT_MEAN = 100
_DEFAULT_INPUT_STDEV = 10
_DEFAULT_OUTPUT_MEAN = 100
_DEFAULT_OUTPUT_STDEV = 10
_TEST_TASK_IDS = [0, 1, 2]
_TOKENIZER_SUBPATH = "llama-models-v2/tinyllama-tarot-v1/"
_PREPARE_DATASET_SCRIPT_PATH = "benchmarks/cpp/prepare_dataset.py"
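# For reference, the helpers below assemble a trtllm-bench invocation of
# roughly this shape (paths are illustrative placeholders, not real values):
#
#   trtllm-bench --model <LLM_MODELS_ROOT>/llama-models-v2/tinyllama-tarot-v1/ \
#       prepare-dataset --output <tmp>/dataset.jsonl \
#       --lora-dir <tmp>/loras --task-id 1 \
#       token-norm-dist --num-requests 3 --input-mean 100 --input-stdev 10 \
#       --output-mean 100 --output-stdev 10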
class TestPrepareDatasetLora:
    """
    Test suite for the LoRA metadata generation functionality of the dataset
    preparation CLI (invoked here as ``trtllm-bench prepare-dataset``).

    This test class validates that the command correctly generates LoRA
    request metadata when LoRA-specific parameters are provided. It covers
    both fixed task ID and random task ID scenarios.
    """
    @pytest.fixture
    def temp_lora_dir(self) -> str:
        """
        Create a temporary LoRA directory structure for testing.

        Creates a temporary directory with subdirectories for each test task
        ID, simulating the expected LoRA adapter directory structure.

        Yields:
            str: Path to the temporary LoRA directory
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            lora_dir = Path(temp_dir) / "loras"
            # Create a dummy LoRA adapter directory for each test task ID
            for task_id in _TEST_TASK_IDS:
                task_dir = lora_dir / str(task_id)
                task_dir.mkdir(parents=True, exist_ok=True)
            yield str(lora_dir)
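    # The fixture above yields a directory layout like (illustrative):
    #   <tmp>/loras/0/
    #   <tmp>/loras/1/
    #   <tmp>/loras/2/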
    def _build_base_command(self, output_path: Path) -> List[str]:
        """
        Build the base trtllm-bench command for running prepare-dataset.

        Args:
            output_path: Path to the output dataset file

        Returns:
            List[str]: Base command components

        Raises:
            pytest.skip: If LLM_MODELS_ROOT is not available
        """
        cmd = ["trtllm-bench"]

        # Add the required tokenizer argument
        model_cache = llm_models_root()
        if model_cache is None:
            pytest.skip("LLM_MODELS_ROOT not available")

        tokenizer_dir = model_cache / _TOKENIZER_SUBPATH
        cmd.extend(["--model", str(tokenizer_dir)])

        # Write the dataset to a file; _run_prepare_dataset reads it back
        cmd.extend(["prepare-dataset", "--output", f"{output_path}"])

        return cmd
    def _add_lora_arguments(self, cmd: List[str], **kwargs) -> None:
        """
        Add LoRA-specific arguments to the command.

        Args:
            cmd: Command list to modify in-place
            **kwargs: Keyword arguments containing LoRA parameters
        """
        if "lora_dir" in kwargs:
            cmd.extend(["--lora-dir", kwargs["lora_dir"]])
        if "task_id" in kwargs:
            cmd.extend(["--task-id", str(kwargs["task_id"])])
        if "rand_task_id" in kwargs:
            min_id, max_id = kwargs["rand_task_id"]
            cmd.extend(["--rand-task-id", str(min_id), str(max_id)])
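    # For example, _add_lora_arguments(cmd, lora_dir="/tmp/loras", task_id=1)
    # appends "--lora-dir /tmp/loras --task-id 1", while rand_task_id=(0, 2)
    # appends "--rand-task-id 0 2" instead.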
    def _add_synthetic_data_arguments(self, cmd: List[str]) -> None:
        """
        Add synthetic data generation arguments to the command.

        Args:
            cmd: Command list to modify in-place
        """
        cmd.extend([
            "token-norm-dist",
            "--num-requests", str(_DEFAULT_NUM_REQUESTS),
            "--input-mean", str(_DEFAULT_INPUT_MEAN),
            "--input-stdev", str(_DEFAULT_INPUT_STDEV),
            "--output-mean", str(_DEFAULT_OUTPUT_MEAN),
            "--output-stdev", str(_DEFAULT_OUTPUT_STDEV),
        ])
    def _run_prepare_dataset(self, **kwargs) -> str:
        """
        Execute prepare-dataset with the specified parameters and return the
        generated dataset contents.

        Args:
            **kwargs: Keyword arguments for LoRA configuration

        Returns:
            str: Contents of the generated dataset file (JSON lines)

        Raises:
            subprocess.CalledProcessError: If the command execution fails
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            output_path = Path(temp_dir) / "dataset.jsonl"
            cmd = self._build_base_command(output_path)
            self._add_lora_arguments(cmd, **kwargs)
            self._add_synthetic_data_arguments(cmd)

            # Execute the command; the dataset is written to output_path
            subprocess.run(cmd, check=True, cwd=temp_dir)

            # Read the dataset back before the temporary directory is removed
            with open(output_path, "r") as f:
                return f.read()
    def _parse_json_output(self, output: str) -> List[Dict[str, Any]]:
        """
        Parse JSON lines from the prepare-dataset output.

        Args:
            output: Raw dataset contents containing one JSON object per line

        Returns:
            List[Dict[str, Any]]: Parsed JSON data objects
        """
        lines = output.strip().split('\n')
        json_data = []

        for line in lines:
            if line.strip():
                try:
                    data = json.loads(line)
                    json_data.append(data)
                except json.JSONDecodeError:
                    # Skip non-JSON lines (such as debug prints)
                    continue

        return json_data
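    # Each parsed entry is expected to carry a "lora_request" object shaped
    # like the following (other per-request fields omitted; values shown are
    # illustrative):
    #   {"lora_name": "lora_1", "lora_int_id": 1, "lora_path": "<lora_dir>/1"}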
    def _validate_lora_request(
            self,
            lora_request: Dict[str, Any],
            expected_lora_dir: str,
            task_id_range: Optional[Tuple[int, int]] = None) -> None:
        """Validate LoRA request structure and content."""
        # Check required fields
        required_fields = ["lora_name", "lora_int_id", "lora_path"]
        for field in required_fields:
            assert field in lora_request, f"Missing '{field}' in LoRA request"

        task_id = lora_request["lora_int_id"]

        # Validate task ID range if specified
        if task_id_range:
            min_id, max_id = task_id_range
            assert min_id <= task_id <= max_id, (
                f"Task ID {task_id} out of range [{min_id}, {max_id}]")

        # Validate the expected naming and path conventions
        expected_name = f"lora_{task_id}"
        expected_path = os.path.join(expected_lora_dir, str(task_id))

        assert lora_request["lora_name"] == expected_name, (
            f"Expected LoRA name '{expected_name}', "
            f"got '{lora_request['lora_name']}'")
        assert lora_request["lora_path"] == expected_path, (
            f"Expected LoRA path '{expected_path}', "
            f"got '{lora_request['lora_path']}'")
    @pytest.mark.parametrize("test_params", [
        pytest.param({"task_id": 1, "description": "fixed task ID"},
                     id="fixed_task_id"),
        pytest.param({"rand_task_id": (0, 2),
                      "description": "random task ID range"},
                     id="random_task_id"),
    ])
    def test_lora_metadata_generation(self, temp_lora_dir: str,
                                      test_params: Dict) -> None:
        """Test LoRA metadata generation with various configurations."""
        # Extract test parameters
        task_id = test_params.get("task_id")
        rand_task_id = test_params.get("rand_task_id")
        description = test_params["description"]

        # Run prepare-dataset with the LoRA directory and task ID settings
        kwargs = {"lora_dir": temp_lora_dir}
        if task_id is not None:
            kwargs["task_id"] = task_id
        if rand_task_id is not None:
            kwargs["rand_task_id"] = rand_task_id

        output = self._run_prepare_dataset(**kwargs)
        json_data = self._parse_json_output(output)

        assert len(json_data) > 0, f"No JSON data generated for {description}"

        # Validate LoRA requests
        for i, item in enumerate(json_data):
            assert "lora_request" in item, (
                f"Missing 'lora_request' in JSON entry {i} for {description}")
            self._validate_lora_request(item["lora_request"],
                                        temp_lora_dir,
                                        task_id_range=rand_task_id)
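# A typical local invocation of this suite (assuming pytest is installed and
# the model cache is available; the file name here is hypothetical) might be:
#   LLM_MODELS_ROOT=/path/to/models pytest test_prepare_dataset_lora.py -v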