Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-05 02:31:33 +08:00
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
parent d8e6e22060
commit 2146c23786
@@ -14,7 +14,7 @@ from pydantic_settings import (
SettingsConfigDict,
)
from tensorrt_llm._torch.auto_deploy import LLM, AutoDeployConfig, DemoLLM
from tensorrt_llm._torch.auto_deploy import LLM, DemoLLM
from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
from tensorrt_llm._torch.auto_deploy.utils._config import (
DynamicYamlMixInForSettings,

@@ -142,7 +142,6 @@ class ExperimentConfig(DynamicYamlMixInForSettings, BaseSettings):
# The main AutoDeploy arguments - contains model, tokenizer, backend configs, etc.
args: LlmArgs = Field(
description="The main AutoDeploy arguments containing model, tokenizer, backend configs, etc. "
"Contains all the fields from `AutoDeployConfig` and `BaseLlmArgs`. "
"Please check `tensorrt_llm._torch.auto_deploy.llm_args.LlmArgs` for more details."
)
@@ -213,7 +212,7 @@ class ExperimentConfig(DynamicYamlMixInForSettings, BaseSettings):
def sync_model_with_args(cls, model_value, info):
if "args" not in info.data:
return model_value
args: AutoDeployConfig = info.data["args"]
args: LlmArgs = info.data["args"]
return args.model

@field_validator("prompt", mode="after")

@@ -221,7 +220,7 @@ class ExperimentConfig(DynamicYamlMixInForSettings, BaseSettings):
def sync_prompt_batch_size_with_args_max_batch_size(cls, prompt: PromptConfig, info):
if "args" not in info.data:
return prompt
args: AutoDeployConfig = info.data["args"]
args: LlmArgs = info.data["args"]
if args.max_batch_size < prompt.batch_size:
args.max_batch_size = prompt.batch_size
return prompt

@@ -231,7 +230,7 @@ class ExperimentConfig(DynamicYamlMixInForSettings, BaseSettings):
def adjust_args_for_benchmark(cls, benchmark: BenchmarkConfig, info):
if "args" not in info.data:
return benchmark
args: AutoDeployConfig = info.data["args"]
args: LlmArgs = info.data["args"]
if benchmark.enabled:
# propagate benchmark settings to args
args.max_batch_size = max(benchmark.bs, args.max_batch_size)
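These `ExperimentConfig` validators share one pydantic-v2 pattern: the validator for a later field reads the already-validated `args` object from `info.data` and adjusts it in place. A minimal, self-contained sketch of that pattern (toy models, not the real `LlmArgs`/`BenchmarkConfig` classes):

    from pydantic import BaseModel, Field, ValidationInfo, field_validator

    class ToyArgs(BaseModel):
        max_batch_size: int = 8

    class ToyBenchmark(BaseModel):
        enabled: bool = False
        bs: int = 32

    class ToyExperiment(BaseModel):
        args: ToyArgs = Field(default_factory=ToyArgs)
        benchmark: ToyBenchmark = Field(default_factory=ToyBenchmark)

        @field_validator("benchmark", mode="after")
        @classmethod
        def adjust_args_for_benchmark(cls, benchmark: ToyBenchmark, info: ValidationInfo):
            # fields declared earlier ("args") are already validated and exposed via info.data
            if "args" not in info.data:
                return benchmark
            args: ToyArgs = info.data["args"]
            if benchmark.enabled:
                args.max_batch_size = max(benchmark.bs, args.max_batch_size)
            return benchmark

    cfg = ToyExperiment(args={}, benchmark={"enabled": True, "bs": 32})
    assert cfg.args.max_batch_size == 32  # propagated from the benchmark settings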
@@ -246,7 +245,7 @@ def build_llm_from_config(config: ExperimentConfig) -> LLM:
"demollm": DemoLLM,
"trtllm": LLM,
}
llm = llm_lookup[config.args.runtime](**config.args.to_llm_kwargs())
llm = llm_lookup[config.args.runtime](**config.args.model_dump(exclude_unset=True))
return llm
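Replacing `to_llm_kwargs()` with `model_dump(exclude_unset=True)` means only fields the user explicitly set are forwarded to the LLM constructor, with everything else left to that constructor's own defaults; the overridden `LlmArgs.model_dump` further drops `build_config` and `mpi_session`. A toy sketch of the `exclude_unset` semantics (illustrative model, not the real `LlmArgs`):

    from pydantic import BaseModel

    class ToyArgs(BaseModel):
        model: str
        max_batch_size: int = 8
        world_size: int = 1

    args = ToyArgs(model="some/model-path")  # hypothetical model id
    # only explicitly provided fields survive exclude_unset=True
    assert args.model_dump(exclude_unset=True) == {"model": "some/model-path"}
    # a full dump would also carry the defaults
    assert args.model_dump() == {"model": "some/model-path", "max_batch_size": 8, "world_size": 1}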
@@ -3,17 +3,15 @@ from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Type, Union

import torch
from pydantic import Field, PrivateAttr, ValidationInfo, field_validator, model_validator
from pydantic import Field, ValidationInfo, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

from tensorrt_llm.models.modeling_utils import QuantConfig

from ...llmapi.llm_args import (
BaseLlmArgs,
BuildConfig,
EagleDecodingConfig,
KvCacheConfig,
SamplerType,
TorchLlmArgs,
_ParallelConfig,
)
from .models import ModelFactory, ModelFactoryRegistry

@@ -58,23 +56,101 @@ def _shortcut_description(description: str, shortcut: str) -> str:
return f"{description} Alias for: {long_names_str}."

class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
"""An argument class stripped down to AutoDeploy-specific configurations.
This class can be used as a drop-in replacement to simplify configuring the AutoDeploy backend and
should be used in place of LlmArgs unless more advanced features are needed.

It is compatible with AutoDeploy's LLM API (``tensorrt_llm._torch.auto_deploy.llm.LLM``) and
exposes the full set of parameters used in AutoDeploy's ``InferenceOptimizer``.
"""
class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
"""LlmArgs config class for providing full expert configurability of the AutoDeploy backend."""

model_config = _get_config_dict()

### MODEL AND TOKENIZER FACTORY ################################################################
model: PathLike = Field(
description="The path to the model checkpoint or the model name from the Hugging Face Hub."
build_config: Optional[BuildConfig] = Field(
default_factory=BuildConfig,
description="!!! DO NOT USE !!! Internal only; needed for BaseLlmArgs compatibility.",
exclude_from_json=True,
frozen=True,
repr=False,
)
backend: Literal["_autodeploy"] = Field(
default="_autodeploy",
description="The backend to use for this LLM instance.",
frozen=True,
)

gpus_per_node: int = Field(
default=torch.cuda.device_count(),
description="The number of GPUs per node.",
frozen=True,
)

@field_validator("max_seq_len", mode="before")
@classmethod
def ensure_max_seq_len(cls, value: Any, info: ValidationInfo) -> Any:
# NOTE: the base class's default value is `None`, which is incompatible with the validators
# defined in this child class. This is problematic when e.g. TRTLLM serve explicitly passes
# the base class's default in.
if value is None:
# Fallback to the AutoDeployConfig default when not provided.
return cls.model_fields["max_seq_len"].get_default(call_default_factory=True)
return value

@field_validator("build_config", mode="before")
@classmethod
def ensure_no_build_config(cls, value: Any, info: ValidationInfo) -> Any:
msg = "build_config is not in use by AutoDeploy's LlmArgs"
return _check_for_default_value_only(cls, value, info, msg)

@field_validator(
"tensor_parallel_size",
"pipeline_parallel_size",
"context_parallel_size",
"moe_cluster_parallel_size",
"moe_tensor_parallel_size",
"moe_expert_parallel_size",
"enable_attention_dp",
"cp_config",
mode="before",
)
@classmethod
def ensure_no_custom_parallel_config(cls, value: Any, info: ValidationInfo) -> Any:
msg = "AutoDeploy only supports parallelization via the `world_size` argument."
return _check_for_default_value_only(cls, value, info, msg)

@model_validator(mode="after")
def setup_hidden_state_capture(self):
if self.speculative_config is None or not isinstance(
self.speculative_config, EagleDecodingConfig
):
return self

self.transforms["detect_hidden_states_for_capture"]["capture_hidden_states"] = True
self.transforms["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"] = (
self.speculative_config.eagle3_layers_to_capture
)
return self

@model_validator(mode="after")
def validate_parallel_config(self):
"""Setup parallel config according to world_size.

NOTE: AutoDeploy does *not* use parallel_config directly. It simply uses world_size and
rank to automatically shard the model. This is just to ensure that other objects in the
runtime that may read parallel_config can do so.
"""

# Set tp_size = self.world_size so that _ParallelConfig.world_size will return the
# correct value (computed as tp_size * pp_size * cp_size). This does not necessarily
# mean that TP will actually be used.
self._parallel_config = _ParallelConfig(
tp_size=self.world_size, gpus_per_node=self.gpus_per_node
)
return self

@model_validator(mode="after")
def validate_and_init_tokenizer(self):
"""Skip tokenizer initialization in config. We do this in the AutoDeploy LLM class."""
return self

## !! Remnants (fields and validators) from the now removed `AutoDeployConfig`.

### MODEL AND TOKENIZER FACTORY ################################################################
model_factory: str = Field(
default="AutoModelForCausalLM",
description="The model factory to use for loading the model.",
@@ -95,12 +171,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
"If True, only the model architecture is loaded.",
)

tokenizer: Optional[PathLike] = Field(
description="The tokenizer",
default=None,
repr=False,
)

tokenizer_kwargs: Dict[str, Any] = Field(
default_factory=dict,
description="Extra kwargs for the tokenizer class to customize the tokenizer. Same as "

@@ -109,16 +179,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
"https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama_fast.py#L127.",
)

skip_tokenizer_init: bool = Field(
default=False, description="Whether to skip the tokenizer initialization."
)

### RUNTIME FEATURES ###########################################################################
disable_overlap_scheduler: bool = Field(
default=False,
description="Disable the overlap scheduler in trtllm runtime",
)

world_size: int = Field(
default=1,
ge=0,

@@ -155,8 +216,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
frozen=True,
)

enable_chunked_prefill: bool = Field(default=False, description="Enable chunked prefill.")

draft_checkpoint_loader: Optional[object] = Field(
default=None,
description="The checkpoint loader to use for the draft model when using speculative decoding with two models.",

@@ -193,6 +252,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
"The backend to use for compiling the model.", "compile_backend"
),
)
# TODO(#9306): fold this into `CudaGraphConfig`.
cuda_graph_batch_sizes: Optional[List[int]] = Field(
default=None,
description=_shortcut_description(

@@ -203,7 +263,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
)

### SEQUENCE INTERFACE CONFIG ##################################################################
max_input_len: int = Field(default=1024, description="The maximum input length.")
max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.")
max_seq_len: int = Field(default=512, ge=1, description="The maximum sequence length.")
max_batch_size: int = Field(default=8, ge=1, description="The maximum batch size.")
@@ -214,16 +273,23 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
"backends, this should equal max_seq_len. Temporary field until tokens_per_block gets "
"properly passed through.",
)
enable_iter_perf_stats: bool = Field(
default=False, description="Enable iteration performance statistics.", status="prototype"
)

enable_iter_req_stats: bool = Field(
default=False,
description="If true, enables per request stats per iteration. Must also set "
"enable_iter_perf_stats to true to get request stats.",
status="prototype",
)
def model_dump(self, *args, **kwargs):
"""Convert the arguments to a dictionary that can be used as kwargs for the LLM API."""
kwargs = super().model_dump(*args, **kwargs)

# ensure we remove the mode and yaml_default fields since they otherwise may conflict each
# other.
if "mode" not in self.model_fields_set:
kwargs.pop("mode", None)
if "yaml_default" not in self.model_fields_set:
kwargs.pop("yaml_default", None)

# We never want these.
kwargs.pop("build_config", None)
kwargs.pop("mpi_session", None)

return kwargs

### VALIDATION #################################################################################
@model_validator(mode="after")

@@ -316,22 +382,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
max_seq_len=self.max_seq_len,
)

def to_dict(self) -> Dict[str, Any]:
"""Convert the arguments to a dictionary."""
return self.model_dump()

def to_llm_kwargs(self) -> Dict[str, Any]:
"""Convert the arguments to a dictionary that can be used as kwargs for the LLM API."""
kwargs = self.to_dict()

# ensure we remove the mode and yaml_default fields since they otherwise may conflict each
# other.
if "mode" not in self.model_fields_set:
kwargs.pop("mode")
if "yaml_default" not in self.model_fields_set:
kwargs.pop("yaml_default")
return kwargs

def is_cuda_graph_enabled(self) -> bool:
return self.compile_backend in ["torch-cudagraph", "torch-opt"]
@@ -344,134 +394,3 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
"transformers": str(config_path / "transformers.yaml"),
}
return mapping.get(mode)

class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
"""LlmArgs config class for providing full expert configurability of the AutoDeploy backend.

Specifically, this class extends AutoDeployConfig with all the fields from BaseLlmArgs for
providing configurability beyond what is provided by AutoDeployConfig.

Just like AutoDeployConfig, this class is compatible with AutoDeploy's LLM API
(``tensorrt_llm._torch.auto_deploy.llm.LLM``) but provides greater configurability.

NOTE: this class should only be used directly for advanced use cases. For most use cases,
AutoDeployConfig should be used instead.

NOTE: this class may expose redundant fields from BaseLlmArgs or fields that are ignored or
have overlapping functionality with AutoDeployConfig. Please be careful when using this class.
"""

model_config = _get_config_dict()

build_config: Optional[BuildConfig] = Field(
default_factory=BuildConfig,
description="!!! DO NOT USE !!! Internal only; needed for BaseLlmArgs compatibility.",
exclude_from_json=True,
frozen=True,
repr=False,
)
backend: Literal["_autodeploy"] = Field(
default="_autodeploy",
description="The backend to use for this LLM instance.",
frozen=True,
)
gpus_per_node: int = Field(
default=torch.cuda.device_count(),
description="The number of GPUs per node.",
frozen=True,
)
garbage_collection_gen0_threshold: int = Field(default=20000, description="See TorchLlmArgs.")

_quant_config: Optional[QuantConfig] = PrivateAttr(default=None)

max_stats_len: int = Field(
default=1000,
description="The max number of performance statistic entries.",
status="prototype",
)

@property
def quant_config(self) -> QuantConfig:
if self._quant_config is None:
self._quant_config = QuantConfig()
return self._quant_config

@quant_config.setter
def quant_config(self, value: QuantConfig):
self._quant_config = value

### VALIDATION #################################################################################
@field_validator("max_seq_len", mode="before")
@classmethod
def ensure_max_seq_len(cls, value: Any, info: ValidationInfo) -> Any:
if value is None:
# Fallback to the AutoDeployConfig default when not provided
return AutoDeployConfig.model_fields["max_seq_len"].get_default(
call_default_factory=True
)
return value

@field_validator("build_config", mode="before")
@classmethod
def ensure_no_build_config(cls, value: Any, info: ValidationInfo) -> Any:
msg = "build_config is not in use by AutoDeploy's LlmArgs"
return _check_for_default_value_only(cls, value, info, msg)

@field_validator(
"tensor_parallel_size",
"pipeline_parallel_size",
"context_parallel_size",
"moe_cluster_parallel_size",
"moe_tensor_parallel_size",
"moe_expert_parallel_size",
"enable_attention_dp",
"cp_config",
mode="before",
)
@classmethod
def ensure_no_custom_parallel_config(cls, value: Any, info: ValidationInfo) -> Any:
msg = "AutoDeploy only supports parallelization via the `world_size` argument."
return _check_for_default_value_only(cls, value, info, msg)

@model_validator(mode="after")
def setup_hidden_state_capture(self):
if self.speculative_config is None or not isinstance(
self.speculative_config, EagleDecodingConfig
):
return self

self.transforms["detect_hidden_states_for_capture"]["capture_hidden_states"] = True
self.transforms["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"] = (
self.speculative_config.eagle3_layers_to_capture
)
return self

@model_validator(mode="after")
def validate_parallel_config(self):
"""Setup parallel config according to world_size.

NOTE: AutoDeploy does *not* use parallel_config directly. It simply uses world_size and
rank to automatically shard the model. This is just to ensure that other objects in the
runtime that may read parallel_config can do so.
"""

# Set tp_size = self.world_size so that _ParallelConfig.world_size will return the
# correct value (computed as tp_size * pp_size * cp_size). This does not necessarily
# mean that TP will actually be used.
self._parallel_config = _ParallelConfig(
tp_size=self.world_size, gpus_per_node=self.gpus_per_node
)
return self

@model_validator(mode="after")
def validate_and_init_tokenizer(self):
"""Skip tokenizer initialization in config. We do this in the AutoDeploy LLM class."""
return self

def to_dict(self) -> Dict:
"""Convert model to a dictionary such that cls(**self.to_dict()) == self."""
self_dict = super().to_dict()
self_dict.pop("build_config", None)
self_dict.pop("mpi_session", None)
return self_dict
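With `AutoDeployConfig` folded into `LlmArgs`, callers now construct the single config class directly, as the updated tests do. A hedged sketch of that usage (placeholder model name; assumes TensorRT-LLM is installed and a CUDA-capable environment):

    from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs

    args = LlmArgs(
        model="some-org/some-model",  # placeholder checkpoint path or HF model name
        world_size=2,                 # AutoDeploy shards via world_size, not TP/PP fields
        max_seq_len=2048,
        compile_backend="torch-opt",  # counts as CUDA-graph enabled per is_cuda_graph_enabled()
    )
    # passing e.g. tensor_parallel_size here would be rejected by
    # ensure_no_custom_parallel_config; only world_size controls parallelism
    factory = args.create_factory()
    model = factory.build_model("meta")  # build on the meta device, as in the tests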
@@ -4,7 +4,7 @@ from _model_test_utils import get_small_model_config
from torch.export import Dim

from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig
from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device

# NOTE: find example inputs with the same tokenization length to avoid seq concat.

@@ -51,7 +51,7 @@ def test_bamba_patches(
"dtype": "bfloat16",
},
}
llm_args = AutoDeployConfig(**llm_args)
llm_args = LlmArgs(**llm_args)

torch.manual_seed(0)
if torch.cuda.is_available():
@@ -11,7 +11,7 @@ from transformers import AutoConfig, AutoModelForCausalLM
from utils.llm_data import llm_models_root

from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig
from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
from tensorrt_llm._torch.auto_deploy.models.custom.modeling_nemotron_h import NemotronHForCausalLM
from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device

@@ -164,7 +164,7 @@ def test_custom_model_implementation_can_be_exported(
"dtype": "bfloat16",
},
}
llm_args = AutoDeployConfig(**llm_args)
llm_args = LlmArgs(**llm_args)

factory = llm_args.create_factory()
model = factory.build_model("meta")
@@ -4,7 +4,7 @@ import pytest
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig, main

from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig, LlmArgs, _ParallelConfig
from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs, _ParallelConfig
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine

@@ -12,15 +12,12 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
# Verify that llm_args was captured
assert llm_args is not None, "llm_args should have been captured"

# Check that llm_args is an instance of LlmArgs and also an instance of AutoDeployConfig
# Check that llm_args is an instance of LlmArgs.
assert isinstance(llm_args, LlmArgs), f"Expected LlmArgs, got {type(llm_args)}"
assert isinstance(llm_args, AutoDeployConfig), (
f"Expected AutoDeployConfig, got {type(llm_args)}"
)

# check that llm_args and experiment_config have the same args
expected_ad_config: AutoDeployConfig = experiment_config.args
expected_llm_args: LlmArgs = LlmArgs(**expected_ad_config.to_llm_kwargs())
expected_ad_config: LlmArgs = experiment_config.args
expected_llm_args: LlmArgs = LlmArgs(**expected_ad_config.model_dump())
assert expected_llm_args == llm_args, f"Expected llm args {expected_llm_args}, got {llm_args}"

# check expected parallel config
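The equality assertion above leans on `LlmArgs` round-tripping through its own `model_dump()`; a hedged sketch of that invariant (placeholder model name, assumes TensorRT-LLM is installed):

    from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs

    args = LlmArgs(model="some-org/some-model", world_size=2)  # placeholder model id
    rebuilt = LlmArgs(**args.model_dump())
    assert rebuilt == args
    # world_size is mirrored into the internal parallel config as tp_size
    # (see validate_parallel_config), which the test's parallel-config check relies on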