diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 83fb1c15e3..9b7d1865c6 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -66,22 +66,49 @@ class MoeLoadBalancerConfig:
 class ModelConfig(Generic[TConfig]):
     pretrained_config: Optional[TConfig] = None
     mapping: Mapping = field(default_factory=Mapping)
+
+    # quantization configs
     quant_config: QuantConfig = field(default_factory=QuantConfig)
     # TODO(qijun): support per linear layer quantization
     quant_config_dict: Optional[Dict[str, QuantConfig]] = None
     # Delay weights creation to DecoderModelForCausalLM.__post_init__
     # to support mixed quantization.
     skip_create_weights_in_init: bool = False
+
+    spec_config: Optional["SpecConfig"] = None
+    lora_config: Optional["LoraConfig"] = None
+
     is_generation: bool = True
     max_num_tokens: int = 8192
+    moe_max_num_tokens: Optional[int] = None
     moe_load_balancer: Optional[MoeLoadBalancerConfig] = None

     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'  # options can be CUTLASS, TRTLLM

+    # If true, enable min-latency mode. Currently only used for Llama4.
+    enable_min_latency: bool = False
+
     extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)

+    _frozen: bool = field(default=False, init=False, repr=False)
+
+    def __setattr__(self, key, value):
+        """
+        Prevent modification of frozen instance attributes.
+        However, we allow modification of 'extra_attrs' attributes for torch.compile
+        and 'pretrained_config' attributes for multimodal models. All the other
+        attributes are frozen.
+        This can be bypassed by manually setting '_frozen' to False. The design is
+        to discourage modifying the attributes unintentionally.
+        """
+        if self._frozen:
+            if key not in ('_frozen', 'extra_attrs', 'pretrained_config'):
+                raise AttributeError(
+                    f"Cannot modify ModelConfig.'{key}' - instance is frozen")
+        super().__setattr__(key, value)
+
     def __post_init__(self):
         if self.pretrained_config and hasattr(self.pretrained_config,
                                               "architectures"):
@@ -221,10 +248,12 @@ class ModelConfig(Generic[TConfig]):
                 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
             quant_config.group_size = block_size[0]

-        return cls(pretrained_config=pretrained_config,
-                   quant_config=quant_config,
-                   quant_config_dict=layer_quant_config,
-                   **kwargs)
+        model_config = cls(pretrained_config=pretrained_config,
+                           quant_config=quant_config,
+                           quant_config_dict=layer_quant_config,
+                           **kwargs)
+        model_config._frozen = True
+        return model_config

     def get_bindings_model_config(self) -> "ModelConfigCpp":
         """
diff --git a/tensorrt_llm/_torch/models/modeling_auto.py b/tensorrt_llm/_torch/models/modeling_auto.py
index bfc92a2728..c26231eb5b 100644
--- a/tensorrt_llm/_torch/models/modeling_auto.py
+++ b/tensorrt_llm/_torch/models/modeling_auto.py
@@ -25,7 +25,9 @@ class AutoModelForCausalLM(Generic[TModel, TConfig]):
                 f"Unknown architecture for AutoModelForCausalLM: {config.pretrained_config.architectures[0]}"
             )
         if issubclass(cls, DecoderModelForCausalLM):
+            config._frozen = False
             config.skip_create_weights_in_init = True
+            config._frozen = True
         extra_attrs = {}
         with model_extra_attrs(extra_attrs):
             model = cls(config)
diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index 5cd59d5c85..5c616c20f5 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -717,7 +717,7 @@ class Llama4Model(DecoderModel):

         # If enable_min_latency is True, we will use min-latency mode.
         DecoderLayerClass = Llama4DecoderLayer
-        if model_config.pytorch_backend_config.enable_min_latency:
+        if model_config.enable_min_latency:
             from .modeling_llama_min_latency import Llama4MinLatencyDecoderLayer
             DecoderLayerClass = Llama4MinLatencyDecoderLayer

@@ -879,12 +879,14 @@ class LlamaForCausalLM(DecoderModelForCausalLM[LlamaModel, LlamaConfig]):
                 trust_remote_code=True,
                 attn_backend=model_config.attn_backend,
                 moe_backend=model_config.moe_backend,
-                mapping=model_config.mapping)
-            draft_config.spec_config = model_config.spec_config
-            draft_config.max_num_tokens = model_config.max_num_tokens
-            draft_config.moe_max_num_tokens = model_config.moe_max_num_tokens
+                mapping=model_config.mapping,
+                spec_config=model_config.spec_config,
+                max_num_tokens=model_config.max_num_tokens,
+                moe_max_num_tokens=model_config.moe_max_num_tokens)
+
             draft_config.quant_config.kv_cache_quant_algo = \
                 model_config.quant_config.kv_cache_quant_algo
+
             self.draft_model = Eagle3LlamaForCausalLM(
                 draft_config, model_config.pretrained_config.num_hidden_layers)
             self.spec_worker = get_spec_worker(model_config.spec_config,
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index cde42be09a..c92432d6a5 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -900,15 +900,16 @@ class PyTorchModelEngine(ModelEngine):
                     moe_load_balancer: Optional[MoeLoadBalancerConfig] = None,
                     lora_config: Optional[LoraConfig] = None,
                     **kwargs):
-        config = ModelConfig.from_pretrained(checkpoint_dir,
-                                             trust_remote_code=True,
-                                             **kwargs)
-        config.pytorch_backend_config = self.pytorch_backend_config
-        config.spec_config = self.spec_config
-        config.max_num_tokens = max_num_tokens
-        config.moe_max_num_tokens = moe_max_num_tokens
-        config.moe_load_balancer = moe_load_balancer
-        config.lora_config = lora_config
+        config = ModelConfig.from_pretrained(
+            checkpoint_dir,
+            trust_remote_code=True,
+            enable_min_latency=self.pytorch_backend_config.enable_min_latency,
+            spec_config=self.spec_config,
+            max_num_tokens=max_num_tokens,
+            moe_max_num_tokens=moe_max_num_tokens,
+            moe_load_balancer=moe_load_balancer,
+            lora_config=lora_config,
+            **kwargs)

         validate_and_set_kv_cache_quant(
             config, self.pytorch_backend_config.kv_cache_dtype)
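The core of the ModelConfig change above is a freeze-after-init dataclass: `from_pretrained` sets `_frozen = True` once construction is done, `__setattr__` then rejects writes except for an allowlist (`extra_attrs` for torch.compile, `pretrained_config` for multimodal models), and callers such as model_engine.py pass everything as constructor arguments instead of mutating the config afterwards. Below is a minimal, self-contained sketch of that pattern; the class name `FrozenAfterInitConfig` and its fields are illustrative stand-ins, not TensorRT-LLM APIs.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class FrozenAfterInitConfig:
    """Illustrative stand-in for a config object that freezes after construction."""
    max_num_tokens: int = 8192
    # Attributes that stay writable even when the instance is frozen.
    extra_attrs: Dict[str, Any] = field(default_factory=dict, repr=False, init=False)
    _frozen: bool = field(default=False, init=False, repr=False)

    def __setattr__(self, key, value):
        # Once frozen, reject writes to anything outside the allowlist.
        if self._frozen and key not in ('_frozen', 'extra_attrs'):
            raise AttributeError(f"Cannot modify '{key}' - instance is frozen")
        super().__setattr__(key, value)


cfg = FrozenAfterInitConfig(max_num_tokens=4096)
cfg._frozen = True                      # freeze once construction is complete

cfg.extra_attrs['compiled'] = True      # allowed: extra_attrs is on the allowlist
try:
    cfg.max_num_tokens = 1024           # rejected: instance is frozen
except AttributeError as err:
    print(err)

# Deliberate, temporary unfreeze, mirroring the modeling_auto.py hunk above.
cfg._frozen = False
cfg.max_num_tokens = 1024
cfg._frozen = True
```

The same motivation drives the model_engine.py and modeling_llama.py hunks: fields such as spec_config, moe_max_num_tokens, and enable_min_latency become declared constructor arguments rather than attributes bolted on after the fact, so a frozen instance can still be fully configured in one step.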