mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com> Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Co-authored-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com>
65 lines
2.3 KiB
Python
65 lines
2.3 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Iterator, List
|
|
|
|
import yaml
|
|
from pydantic import BaseModel, Field, RootModel
|
|
|
|
DATABASE_LIST_PATH = Path(__file__).parent / "lookup.yaml"
|
|
|
|
|
|
class RecipeConstraints(BaseModel):
|
|
"""Recipe record for scenario list."""
|
|
|
|
model: str = Field(description="Model name")
|
|
gpu: str = Field(description="GPU name")
|
|
isl: int = Field(description="Input sequence length")
|
|
osl: int = Field(description="Output sequence length")
|
|
concurrency: int = Field(description="Concurrency")
|
|
config_path: str = Field(description="Configuration path")
|
|
num_gpus: int = Field(description="Number of GPUs")
|
|
|
|
def load_config(self) -> Dict[str, Any]:
|
|
"""Load and return the YAML config at config_path."""
|
|
with open(self.config_path) as f:
|
|
data = yaml.safe_load(f)
|
|
return data if data is not None else {}
|
|
|
|
|
|
class Recipe(BaseModel):
|
|
"""Recipe that describes a single scenario."""
|
|
|
|
constraints: RecipeConstraints = Field(description="Recipe constraints")
|
|
env_overrides: Dict[str, Any] = Field(description="Environment overrides", default_factory=dict)
|
|
config: Dict[str, Any] = Field(description="Configuration overrides", default_factory=dict)
|
|
|
|
|
|
class RecipeList(RootModel[List[RecipeConstraints]]):
|
|
@classmethod
|
|
def from_yaml(cls, yaml_path: Path) -> "RecipeList":
|
|
"""Load and validate recipe list from YAML file."""
|
|
with open(yaml_path) as f:
|
|
data = yaml.safe_load(f)
|
|
return cls(data)
|
|
|
|
def __iter__(self) -> Iterator[RecipeConstraints]:
|
|
return iter(self.root)
|
|
|
|
def __len__(self) -> int:
|
|
return len(self.root)
|