diff --git a/examples/auto_deploy/build_and_run_ad.py b/examples/auto_deploy/build_and_run_ad.py
index f35c4af9bb..9a6088df38 100644
--- a/examples/auto_deploy/build_and_run_ad.py
+++ b/examples/auto_deploy/build_and_run_ad.py
@@ -14,7 +14,7 @@ from pydantic_settings import (
     SettingsConfigDict,
 )
 
-from tensorrt_llm._torch.auto_deploy import LLM, AutoDeployConfig, DemoLLM
+from tensorrt_llm._torch.auto_deploy import LLM, DemoLLM
 from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
 from tensorrt_llm._torch.auto_deploy.utils._config import (
     DynamicYamlMixInForSettings,
@@ -142,7 +142,6 @@ class ExperimentConfig(DynamicYamlMixInForSettings, BaseSettings):
     # The main AutoDeploy arguments - contains model, tokenizer, backend configs, etc.
     args: LlmArgs = Field(
         description="The main AutoDeploy arguments containing model, tokenizer, backend configs, etc. "
-        "Contains all the fields from `AutoDeployConfig` and `BaseLlmArgs`. "
         "Please check `tensorrt_llm._torch.auto_deploy.llm_args.LlmArgs` for more details."
     )
 
@@ -213,7 +212,7 @@ class ExperimentConfig(DynamicYamlMixInForSettings, BaseSettings):
     def sync_model_with_args(cls, model_value, info):
         if "args" not in info.data:
             return model_value
-        args: AutoDeployConfig = info.data["args"]
+        args: LlmArgs = info.data["args"]
         return args.model
 
     @field_validator("prompt", mode="after")
@@ -221,7 +220,7 @@ class ExperimentConfig(DynamicYamlMixInForSettings, BaseSettings):
    def sync_prompt_batch_size_with_args_max_batch_size(cls, prompt: PromptConfig, info):
         if "args" not in info.data:
             return prompt
-        args: AutoDeployConfig = info.data["args"]
+        args: LlmArgs = info.data["args"]
         if args.max_batch_size < prompt.batch_size:
             args.max_batch_size = prompt.batch_size
         return prompt
@@ -231,7 +230,7 @@ class ExperimentConfig(DynamicYamlMixInForSettings, BaseSettings):
     def adjust_args_for_benchmark(cls, benchmark: BenchmarkConfig, info):
         if "args" not in info.data:
             return benchmark
-        args: AutoDeployConfig = info.data["args"]
+        args: LlmArgs = info.data["args"]
         if benchmark.enabled:
             # propagate benchmark settings to args
             args.max_batch_size = max(benchmark.bs, args.max_batch_size)
@@ -246,7 +245,7 @@ def build_llm_from_config(config: ExperimentConfig) -> LLM:
         "demollm": DemoLLM,
         "trtllm": LLM,
     }
-    llm = llm_lookup[config.args.runtime](**config.args.to_llm_kwargs())
+    llm = llm_lookup[config.args.runtime](**config.args.model_dump(exclude_unset=True))
     return llm
 
 
diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py
index f4f05355d8..f892fa7bf7 100644
--- a/tensorrt_llm/_torch/auto_deploy/llm_args.py
+++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -3,17 +3,15 @@ from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional, Type, Union
 
 import torch
-from pydantic import Field, PrivateAttr, ValidationInfo, field_validator, model_validator
+from pydantic import Field, ValidationInfo, field_validator, model_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
-from tensorrt_llm.models.modeling_utils import QuantConfig
-
 from ...llmapi.llm_args import (
-    BaseLlmArgs,
     BuildConfig,
     EagleDecodingConfig,
     KvCacheConfig,
     SamplerType,
+    TorchLlmArgs,
     _ParallelConfig,
 )
 from .models import ModelFactory, ModelFactoryRegistry
@@ -58,23 +56,101 @@ def _shortcut_description(description: str, shortcut: str) -> str:
     return f"{description} Alias for: {long_names_str}."
 
 
-class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
-    """An argument class stripped down to AutoDeploy-specific configurations.
-
-    This class be used as a drop-in replacement to simplify configuring the AutoDeploy backend and
-    should be used in place of LlmArgs unless more advanced features are needed.
-
-    It is compatible with AutoDeploy's LLM API (``tensorrt_llm._torch.auto_deploy.llm.LLM``) and
-    exposes the full set of parameters used in AutoDeploy's ``InferenceOptimizer``.
-    """
+class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
+    """LlmArgs config class for providing full expert configurability of the AutoDeploy backend."""
 
     model_config = _get_config_dict()
 
-    ### MODEL AND TOKENIZER FACTORY ################################################################
-    model: PathLike = Field(
-        description="The path to the model checkpoint or the model name from the Hugging Face Hub."
+    build_config: Optional[BuildConfig] = Field(
+        default_factory=BuildConfig,
+        description="!!! DO NOT USE !!! Internal only; needed for BaseLlmArgs compatibility.",
+        exclude_from_json=True,
+        frozen=True,
+        repr=False,
+    )
+    backend: Literal["_autodeploy"] = Field(
+        default="_autodeploy",
+        description="The backend to use for this LLM instance.",
+        frozen=True,
     )
 
+    gpus_per_node: int = Field(
+        default=torch.cuda.device_count(),
+        description="The number of GPUs per node.",
+        frozen=True,
+    )
+
+    @field_validator("max_seq_len", mode="before")
+    @classmethod
+    def ensure_max_seq_len(cls, value: Any, info: ValidationInfo) -> Any:
+        # NOTE: the base class's default value is `None`, which is incompatible with the validators
+        # defined in this child class. This is problematic when e.g. TRTLLM serve explicitly passes
+        # the base class's default in.
+        if value is None:
+            # Fallback to the AutoDeployConfig default when not provided.
+            return cls.model_fields["max_seq_len"].get_default(call_default_factory=True)
+        return value
+
+    @field_validator("build_config", mode="before")
+    @classmethod
+    def ensure_no_build_config(cls, value: Any, info: ValidationInfo) -> Any:
+        msg = "build_config is not in use by AutoDeploy's LlmArgs"
+        return _check_for_default_value_only(cls, value, info, msg)
+
+    @field_validator(
+        "tensor_parallel_size",
+        "pipeline_parallel_size",
+        "context_parallel_size",
+        "moe_cluster_parallel_size",
+        "moe_tensor_parallel_size",
+        "moe_expert_parallel_size",
+        "enable_attention_dp",
+        "cp_config",
+        mode="before",
+    )
+    @classmethod
+    def ensure_no_custom_parallel_config(cls, value: Any, info: ValidationInfo) -> Any:
+        msg = "AutoDeploy only supports parallelization via the `world_size` argument."
+        return _check_for_default_value_only(cls, value, info, msg)
+
+    @model_validator(mode="after")
+    def setup_hidden_state_capture(self):
+        if self.speculative_config is None or not isinstance(
+            self.speculative_config, EagleDecodingConfig
+        ):
+            return self
+
+        self.transforms["detect_hidden_states_for_capture"]["capture_hidden_states"] = True
+        self.transforms["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"] = (
+            self.speculative_config.eagle3_layers_to_capture
+        )
+        return self
+
+    @model_validator(mode="after")
+    def validate_parallel_config(self):
+        """Setup parallel config according to world_size.
+
+        NOTE: AutoDeploy does *not* use parallel_config directly. It simply uses world_size and
+        rank to automatically shard the model. This is just to ensure that other objects in the
+        runtime that may read parallel_config can do so.
+        """
+
+        # Set tp_size = self.world_size so that _ParallelConfig.world_size will return the
+        # correct value (computed as tp_size * pp_size * cp_size). This does not necessarily
+        # mean that TP will actually be used.
+        self._parallel_config = _ParallelConfig(
+            tp_size=self.world_size, gpus_per_node=self.gpus_per_node
+        )
+        return self
+
+    @model_validator(mode="after")
+    def validate_and_init_tokenizer(self):
+        """Skip tokenizer initialization in config. We do this in the AutoDeploy LLM class."""
+        return self
+
+    ## !! Remnants (fields and validators) from the now removed `AutoDeployConfig`.
+
+    ### MODEL AND TOKENIZER FACTORY ################################################################
     model_factory: str = Field(
         default="AutoModelForCausalLM",
         description="The model factory to use for loading the model.",
@@ -95,12 +171,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         "If True, only the model architecture is loaded.",
     )
 
-    tokenizer: Optional[PathLike] = Field(
-        description="The tokenizer",
-        default=None,
-        repr=False,
-    )
-
     tokenizer_kwargs: Dict[str, Any] = Field(
         default_factory=dict,
         description="Extra kwargs for the tokenizer class to customize the tokenizer. Same as "
@@ -109,16 +179,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         "https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama_fast.py#L127.",
     )
 
-    skip_tokenizer_init: bool = Field(
-        default=False, description="Whether to skip the tokenizer initialization."
-    )
-
     ### RUNTIME FEATURES ###########################################################################
-    disable_overlap_scheduler: bool = Field(
-        default=False,
-        description="Disable the overlap scheduler in trtllm runtime",
-    )
-
     world_size: int = Field(
         default=1,
         ge=0,
@@ -155,8 +216,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         frozen=True,
     )
 
-    enable_chunked_prefill: bool = Field(default=False, description="Enable chunked prefill.")
-
     draft_checkpoint_loader: Optional[object] = Field(
         default=None,
         description="The checkpoint loader to use for the draft model when using speculative decoding with two models.",
@@ -193,6 +252,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
             "The backend to use for compiling the model.", "compile_backend"
         ),
     )
+    # TODO(#9306): fold this into `CudaGraphConfig`.
     cuda_graph_batch_sizes: Optional[List[int]] = Field(
         default=None,
         description=_shortcut_description(
@@ -203,7 +263,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
     )
 
     ### SEQUENCE INTERFACE CONFIG ##################################################################
-    max_input_len: int = Field(default=1024, description="The maximum input length.")
     max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.")
     max_seq_len: int = Field(default=512, ge=1, description="The maximum sequence length.")
     max_batch_size: int = Field(default=8, ge=1, description="The maximum batch size.")
@@ -214,16 +273,23 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         "backends, this should equal max_seq_len. Temporary field until tokens_per_block gets "
         "properly passed through.",
     )
 
-    enable_iter_perf_stats: bool = Field(
-        default=False, description="Enable iteration performance statistics.", status="prototype"
-    )
-    enable_iter_req_stats: bool = Field(
-        default=False,
-        description="If true, enables per request stats per iteration. Must also set "
-        "enable_iter_perf_stats to true to get request stats.",
-        status="prototype",
-    )
+    def model_dump(self, *args, **kwargs):
+        """Convert the arguments to a dictionary that can be used as kwargs for the LLM API."""
+        kwargs = super().model_dump(*args, **kwargs)
+
+        # ensure we remove the mode and yaml_default fields since they may otherwise conflict with
+        # each other.
+        if "mode" not in self.model_fields_set:
+            kwargs.pop("mode", None)
+        if "yaml_default" not in self.model_fields_set:
+            kwargs.pop("yaml_default", None)
+
+        # We never want these.
+        kwargs.pop("build_config", None)
+        kwargs.pop("mpi_session", None)
+
+        return kwargs
 
     ### VALIDATION #################################################################################
     @model_validator(mode="after")
@@ -316,22 +382,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
             max_seq_len=self.max_seq_len,
         )
 
-    def to_dict(self) -> Dict[str, Any]:
-        """Convert the arguments to a dictionary."""
-        return self.model_dump()
-
-    def to_llm_kwargs(self) -> Dict[str, Any]:
-        """Convert the arguments to a dictionary that can be used as kwargs for the LLM API."""
-        kwargs = self.to_dict()
-
-        # ensure we remove the mode and yaml_default fields since they otherwise may conflict each
-        # other.
-        if "mode" not in self.model_fields_set:
-            kwargs.pop("mode")
-        if "yaml_default" not in self.model_fields_set:
-            kwargs.pop("yaml_default")
-        return kwargs
-
     def is_cuda_graph_enabled(self) -> bool:
         return self.compile_backend in ["torch-cudagraph", "torch-opt"]
 
@@ -344,134 +394,3 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
             "transformers": str(config_path / "transformers.yaml"),
         }
         return mapping.get(mode)
-
-
-class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
-    """LlmArgs config class for providing full expert configurability of the AutoDeploy backend.
-
-    Specifically, this class extends AutoDeployConfig with all the fields from BaseLlmArgs for
-    providing configurability beyond what is provided by AutoDeployConfig.
-
-    Just like AutoDeployConfig, this class is compatible with AutoDeploy's LLM API
-    (``tensorrt_llm._torch.auto_deploy.llm.LLM``) but provides greater configurability.
-
-    NOTE: this class should only be used directly for advanced use cases. For most use cases,
-    AutoDeployConfig should be used instead.
-
-    NOTE: this class may expose redundant fields from BaseLlmArgs or fields that are ignored or
-    have overlapping functionality with AutoDeployConfig. Please be careful when using this class.
-    """
-
-    model_config = _get_config_dict()
-
-    build_config: Optional[BuildConfig] = Field(
-        default_factory=BuildConfig,
-        description="!!! DO NOT USE !!! Internal only; needed for BaseLlmArgs compatibility.",
-        exclude_from_json=True,
-        frozen=True,
-        repr=False,
-    )
-    backend: Literal["_autodeploy"] = Field(
-        default="_autodeploy",
-        description="The backend to use for this LLM instance.",
-        frozen=True,
-    )
-    gpus_per_node: int = Field(
-        default=torch.cuda.device_count(),
-        description="The number of GPUs per node.",
-        frozen=True,
-    )
-    garbage_collection_gen0_threshold: int = Field(default=20000, description="See TorchLlmArgs.")
-
-    _quant_config: Optional[QuantConfig] = PrivateAttr(default=None)
-
-    max_stats_len: int = Field(
-        default=1000,
-        description="The max number of performance statistic entries.",
-        status="prototype",
-    )
-
-    @property
-    def quant_config(self) -> QuantConfig:
-        if self._quant_config is None:
-            self._quant_config = QuantConfig()
-        return self._quant_config
-
-    @quant_config.setter
-    def quant_config(self, value: QuantConfig):
-        self._quant_config = value
-
-    ### VALIDATION #################################################################################
-    @field_validator("max_seq_len", mode="before")
-    @classmethod
-    def ensure_max_seq_len(cls, value: Any, info: ValidationInfo) -> Any:
-        if value is None:
-            # Fallback to the AutoDeployConfig default when not provided
-            return AutoDeployConfig.model_fields["max_seq_len"].get_default(
-                call_default_factory=True
-            )
-        return value
-
-    @field_validator("build_config", mode="before")
-    @classmethod
-    def ensure_no_build_config(cls, value: Any, info: ValidationInfo) -> Any:
-        msg = "build_config is not in use by AutoDeploy's LlmArgs"
-        return _check_for_default_value_only(cls, value, info, msg)
-
-    @field_validator(
-        "tensor_parallel_size",
-        "pipeline_parallel_size",
-        "context_parallel_size",
-        "moe_cluster_parallel_size",
-        "moe_tensor_parallel_size",
-        "moe_expert_parallel_size",
-        "enable_attention_dp",
-        "cp_config",
-        mode="before",
-    )
-    @classmethod
-    def ensure_no_custom_parallel_config(cls, value: Any, info: ValidationInfo) -> Any:
-        msg = "AutoDeploy only supports parallelization via the `world_size` argument."
-        return _check_for_default_value_only(cls, value, info, msg)
-
-    @model_validator(mode="after")
-    def setup_hidden_state_capture(self):
-        if self.speculative_config is None or not isinstance(
-            self.speculative_config, EagleDecodingConfig
-        ):
-            return self
-
-        self.transforms["detect_hidden_states_for_capture"]["capture_hidden_states"] = True
-        self.transforms["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"] = (
-            self.speculative_config.eagle3_layers_to_capture
-        )
-        return self
-
-    @model_validator(mode="after")
-    def validate_parallel_config(self):
-        """Setup parallel config according to world_size.
-
-        NOTE: AutoDeploy does *not* use parallel_config directly. It simply uses world_size and
-        rank to automatically shard the model. This is just to ensure that other objects in the
-        runtime that may read parallel_config can do so.
-        """
-
-        # Set tp_size = self.world_size so that _ParallelConfig.world_size will return the
-        # correct value (computed as tp_size * pp_size * cp_size). This does not necessarily
-        # mean that TP will actually be used.
-        self._parallel_config = _ParallelConfig(
-            tp_size=self.world_size, gpus_per_node=self.gpus_per_node
-        )
-        return self
-
-    @model_validator(mode="after")
-    def validate_and_init_tokenizer(self):
-        """Skip tokenizer initialization in config. We do this in the AutoDeploy LLM class."""
-        return self
-
-    def to_dict(self) -> Dict:
-        """Convert model to a dictionary such that cls(**self.to_dict()) == self."""
-        self_dict = super().to_dict()
-        self_dict.pop("build_config", None)
-        self_dict.pop("mpi_session", None)
-        return self_dict
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_hybrid_patches.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_hybrid_patches.py
index 430add5a28..8d4d567da6 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_hybrid_patches.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_hybrid_patches.py
@@ -4,7 +4,7 @@ from _model_test_utils import get_small_model_config
 from torch.export import Dim
 
 from tensorrt_llm._torch.auto_deploy.export import apply_export_patches, torch_export_to_gm
-from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig
+from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
 from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device
 
 # NOTE: find example inputs with the same tokenization length to avoid seq concat.
@@ -51,7 +51,7 @@ def test_bamba_patches(
             "dtype": "bfloat16",
         },
     }
-    llm_args = AutoDeployConfig(**llm_args)
+    llm_args = LlmArgs(**llm_args)
 
     torch.manual_seed(0)
     if torch.cuda.is_available():
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_modeling_nemotron_h.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_modeling_nemotron_h.py
index d5d624e721..ae55b6ad4c 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_modeling_nemotron_h.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_modeling_nemotron_h.py
@@ -11,7 +11,7 @@ from transformers import AutoConfig, AutoModelForCausalLM
 from utils.llm_data import llm_models_root
 
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
-from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig
+from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
 from tensorrt_llm._torch.auto_deploy.models.custom.modeling_nemotron_h import NemotronHForCausalLM
 from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device
 
@@ -164,7 +164,7 @@ def test_custom_model_implementation_can_be_exported(
             "dtype": "bfloat16",
         },
     }
-    llm_args = AutoDeployConfig(**llm_args)
+    llm_args = LlmArgs(**llm_args)
 
     factory = llm_args.create_factory()
     model = factory.build_model("meta")
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py
index 79457fbfca..e03d366911 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py
@@ -4,7 +4,7 @@ import pytest
 from _model_test_utils import get_small_model_config
 from build_and_run_ad import ExperimentConfig, main
 
-from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig, LlmArgs, _ParallelConfig
+from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs, _ParallelConfig
 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine
 
 
@@ -12,15 +12,12 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
     # Verify that llm_args was captured
     assert llm_args is not None, "llm_args should have been captured"
 
-    # Check that llm_args is an instance of LlmArgs and also an instance of AutoDeployConfig
+    # Check that llm_args is an instance of LlmArgs.
     assert isinstance(llm_args, LlmArgs), f"Expected LlmArgs, got {type(llm_args)}"
-    assert isinstance(llm_args, AutoDeployConfig), (
-        f"Expected AutoDeployConfig, got {type(llm_args)}"
-    )
 
     # check that llm_args and experiment_config have the same args
-    expected_ad_config: AutoDeployConfig = experiment_config.args
-    expected_llm_args: LlmArgs = LlmArgs(**expected_ad_config.to_llm_kwargs())
+    expected_ad_config: LlmArgs = experiment_config.args
+    expected_llm_args: LlmArgs = LlmArgs(**expected_ad_config.model_dump())
     assert expected_llm_args == llm_args, f"Expected llm args {expected_llm_args}, got {llm_args}"
 
     # check expected parallel config