[Arch] Freeze model_config (#4814)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Co-authored-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
hlu1 authored on 2025-06-03 11:51:35 -07:00; committed by GitHub
parent 2384655c3a
commit b4ed4b22f3
4 changed files with 52 additions and 18 deletions


@@ -66,22 +66,49 @@ class MoeLoadBalancerConfig:
 class ModelConfig(Generic[TConfig]):
     pretrained_config: Optional[TConfig] = None
     mapping: Mapping = field(default_factory=Mapping)
     # quantization configs
     quant_config: QuantConfig = field(default_factory=QuantConfig)
     # TODO(qijun): support per linear layer quantization
     quant_config_dict: Optional[Dict[str, QuantConfig]] = None
     # Delay weights creation to DecoderModelForCausalLM.__post_init__
     # to support mixed quantization.
     skip_create_weights_in_init: bool = False
     spec_config: Optional["SpecConfig"] = None
     lora_config: Optional["LoraConfig"] = None
     is_generation: bool = True
     max_num_tokens: int = 8192
     moe_max_num_tokens: Optional[int] = None
     moe_load_balancer: Optional[MoeLoadBalancerConfig] = None
     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'  # options can be CUTLASS, TRTLLM
+    # If true, enable min-latency mode. Currently only used for Llama4.
+    enable_min_latency: bool = False
     extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)
+    _frozen: bool = field(default=False, init=False, repr=False)
+
+    def __setattr__(self, key, value):
+        """
+        Prevent modification of frozen instance attributes.
+        However, we allow modification of 'extra_attrs' attributes for torch.compile
+        and 'pretrained_config' attributes for multimodal models. All the other
+        attributes are frozen.
+        This can be bypassed by manually setting '_frozen' to False. The design is
+        to discourage modifying the attributes unintentionally.
+        """
+        if self._frozen:
+            if key not in ('_frozen', 'extra_attrs', 'pretrained_config'):
+                raise AttributeError(
+                    f"Cannot modify ModelConfig.'{key}' - instance is frozen")
+        super().__setattr__(key, value)
+
     def __post_init__(self):
         if self.pretrained_config and hasattr(self.pretrained_config,
                                               "architectures"):
@@ -221,10 +248,12 @@ class ModelConfig(Generic[TConfig]):
                     128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
                 quant_config.group_size = block_size[0]
-        return cls(pretrained_config=pretrained_config,
-                   quant_config=quant_config,
-                   quant_config_dict=layer_quant_config,
-                   **kwargs)
+        model_config = cls(pretrained_config=pretrained_config,
+                           quant_config=quant_config,
+                           quant_config_dict=layer_quant_config,
+                           **kwargs)
+        model_config._frozen = True
+        return model_config

     def get_bindings_model_config(self) -> "ModelConfigCpp":
         """

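Because `from_pretrained` above now flips `_frozen` to True before returning, callers get a read-only object: per-run settings have to be passed as keyword arguments rather than assigned afterwards. A usage sketch (the checkpoint path and the override value are placeholders):

config = ModelConfig.from_pretrained("path/to/checkpoint",
                                     trust_remote_code=True,
                                     max_num_tokens=4096)  # overrides go in as kwargs

# config.max_num_tokens = 8192   # would raise AttributeError: instance is frozen
config.extra_attrs = {}          # still allowed: whitelisted for torch.compile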

@@ -25,7 +25,9 @@ class AutoModelForCausalLM(Generic[TModel, TConfig]):
                 f"Unknown architecture for AutoModelForCausalLM: {config.pretrained_config.architectures[0]}"
             )
         if issubclass(cls, DecoderModelForCausalLM):
+            config._frozen = False
             config.skip_create_weights_in_init = True
+            config._frozen = True
         extra_attrs = {}
         with model_extra_attrs(extra_attrs):
             model = cls(config)
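The two added lines above are the intended escape hatch: unfreeze, make the one late change (`skip_create_weights_in_init`), and refreeze immediately. If this toggle were needed in more places, it could be wrapped in a small context manager; the helper below is hypothetical and not part of this commit:

from contextlib import contextmanager

@contextmanager
def unfrozen(config):
    # Hypothetical helper: temporarily lift the freeze, then restore it.
    config._frozen = False
    try:
        yield config
    finally:
        config._frozen = True

# Equivalent to the hunk above:
# with unfrozen(config):
#     config.skip_create_weights_in_init = True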


@@ -717,7 +717,7 @@ class Llama4Model(DecoderModel):
         # If enable_min_latency is True, we will use min-latency mode.
         DecoderLayerClass = Llama4DecoderLayer
-        if model_config.pytorch_backend_config.enable_min_latency:
+        if model_config.enable_min_latency:
             from .modeling_llama_min_latency import Llama4MinLatencyDecoderLayer
             DecoderLayerClass = Llama4MinLatencyDecoderLayer
@@ -879,12 +879,14 @@ class LlamaForCausalLM(DecoderModelForCausalLM[LlamaModel, LlamaConfig]):
                 trust_remote_code=True,
                 attn_backend=model_config.attn_backend,
                 moe_backend=model_config.moe_backend,
-                mapping=model_config.mapping)
-            draft_config.spec_config = model_config.spec_config
-            draft_config.max_num_tokens = model_config.max_num_tokens
-            draft_config.moe_max_num_tokens = model_config.moe_max_num_tokens
+                mapping=model_config.mapping,
+                spec_config=model_config.spec_config,
+                max_num_tokens=model_config.max_num_tokens,
+                moe_max_num_tokens=model_config.moe_max_num_tokens)
             draft_config.quant_config.kv_cache_quant_algo = \
                 model_config.quant_config.kv_cache_quant_algo
             self.draft_model = Eagle3LlamaForCausalLM(
                 draft_config, model_config.pretrained_config.num_hidden_layers)
             self.spec_worker = get_spec_worker(model_config.spec_config,

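One subtlety in the `LlamaForCausalLM` hunk above: the freeze is shallow. `draft_config.quant_config.kv_cache_quant_algo = ...` still works because it mutates the nested `QuantConfig` object; only rebinding `ModelConfig`'s own attributes goes through the frozen `__setattr__`. A self-contained sketch of the distinction (class names invented):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class InnerQuant:                       # stands in for QuantConfig
    kv_cache_quant_algo: Optional[str] = None

@dataclass
class OuterConfig:                      # stands in for ModelConfig
    quant_config: InnerQuant = field(default_factory=InnerQuant)
    _frozen: bool = field(default=False, init=False, repr=False)

    def __setattr__(self, key, value):
        if self._frozen and key != '_frozen':
            raise AttributeError(f"'{key}' is frozen")
        super().__setattr__(key, value)

cfg = OuterConfig()
cfg._frozen = True
cfg.quant_config.kv_cache_quant_algo = "FP8"   # fine: mutates the nested object
# cfg.quant_config = InnerQuant()              # would raise: rebinds a frozen attribute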

@@ -900,15 +900,16 @@ class PyTorchModelEngine(ModelEngine):
                     moe_load_balancer: Optional[MoeLoadBalancerConfig] = None,
                     lora_config: Optional[LoraConfig] = None,
                     **kwargs):
-        config = ModelConfig.from_pretrained(checkpoint_dir,
-                                             trust_remote_code=True,
-                                             **kwargs)
-        config.pytorch_backend_config = self.pytorch_backend_config
-        config.spec_config = self.spec_config
-        config.max_num_tokens = max_num_tokens
-        config.moe_max_num_tokens = moe_max_num_tokens
-        config.moe_load_balancer = moe_load_balancer
-        config.lora_config = lora_config
+        config = ModelConfig.from_pretrained(
+            checkpoint_dir,
+            trust_remote_code=True,
+            enable_min_latency=self.pytorch_backend_config.enable_min_latency,
+            spec_config=self.spec_config,
+            max_num_tokens=max_num_tokens,
+            moe_max_num_tokens=moe_max_num_tokens,
+            moe_load_balancer=moe_load_balancer,
+            lora_config=lora_config,
+            **kwargs)

         validate_and_set_kv_cache_quant(
             config, self.pytorch_backend_config.kv_cache_dtype)
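The practical migration rule for callers is visible in this last hunk: anything that used to be patched onto `config` after `from_pretrained` either becomes a keyword argument, or, if it genuinely must happen late (as in `AutoModelForCausalLM`), is wrapped in an explicit `_frozen` toggle. A condensed before/after, using field names from this hunk:

# Before this commit (now raises AttributeError on the frozen instance):
#   config = ModelConfig.from_pretrained(checkpoint_dir, trust_remote_code=True)
#   config.spec_config = self.spec_config
#   config.max_num_tokens = max_num_tokens

# After this commit:
config = ModelConfig.from_pretrained(checkpoint_dir,
                                     trust_remote_code=True,
                                     spec_config=self.spec_config,
                                     max_num_tokens=max_num_tokens)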