Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[Arch] Freeze model_config (#4814)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Co-authored-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
commit b4ed4b22f3
parent 2384655c3a
@@ -66,22 +66,49 @@ class MoeLoadBalancerConfig:
 class ModelConfig(Generic[TConfig]):
     pretrained_config: Optional[TConfig] = None
     mapping: Mapping = field(default_factory=Mapping)

     # quantization configs
     quant_config: QuantConfig = field(default_factory=QuantConfig)
     # TODO(qijun): support per linear layer quantization
     quant_config_dict: Optional[Dict[str, QuantConfig]] = None
     # Delay weights creation to DecoderModelForCausalLM.__post_init__
     # to support mixed quantization.
     skip_create_weights_in_init: bool = False

     spec_config: Optional["SpecConfig"] = None
     lora_config: Optional["LoraConfig"] = None

     is_generation: bool = True
     max_num_tokens: int = 8192

     moe_max_num_tokens: Optional[int] = None
     moe_load_balancer: Optional[MoeLoadBalancerConfig] = None

     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'  # options can be CUTLASS, TRTLLM

     # If true, enable min-latency mode. Currently only used for Llama4.
     enable_min_latency: bool = False

     extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)

+    _frozen: bool = field(default=False, init=False, repr=False)
+
+    def __setattr__(self, key, value):
+        """
+        Prevent modification of frozen instance attributes.
+        However, we allow modification of 'extra_attrs' attributes for torch.compile
+        and 'pretrained_config' attributes for multimodal models. All the other
+        attributes are frozen.
+        This can be bypassed by manually setting '_frozen' to False. The design is
+        to discourage modifying the attributes unintentionally.
+        """
+        if self._frozen:
+            if key not in ('_frozen', 'extra_attrs', 'pretrained_config'):
+                raise AttributeError(
+                    f"Cannot modify ModelConfig.'{key}' - instance is frozen")
+        super().__setattr__(key, value)
+
     def __post_init__(self):
         if self.pretrained_config and hasattr(self.pretrained_config,
                                               "architectures"):
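
To make the freezing behaviour above concrete, here is a minimal, self-contained sketch of the same pattern. FrozenishConfig and its fields are illustrative stand-ins for this example, not the actual TensorRT-LLM class; only the '_frozen' flag, the whitelist in __setattr__, and the raised AttributeError mirror the diff.

from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class FrozenishConfig:
    pretrained_config: Optional[object] = None
    max_num_tokens: int = 8192
    extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)
    _frozen: bool = field(default=False, init=False, repr=False)

    def __setattr__(self, key, value):
        # Once frozen, only the whitelisted attributes may still change.
        if self._frozen and key not in ('_frozen', 'extra_attrs',
                                        'pretrained_config'):
            raise AttributeError(
                f"Cannot modify FrozenishConfig.'{key}' - instance is frozen")
        super().__setattr__(key, value)


cfg = FrozenishConfig()
cfg._frozen = True

cfg.pretrained_config = {"architectures": ["LlamaForCausalLM"]}  # still allowed
try:
    cfg.max_num_tokens = 4096                                     # rejected while frozen
except AttributeError as err:
    print(err)

cfg._frozen = False   # explicit, intentional bypass, as the docstring describes
cfg.max_num_tokens = 4096
cfg._frozen = True
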
@@ -221,10 +248,12 @@ class ModelConfig(Generic[TConfig]):
                 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
             quant_config.group_size = block_size[0]

-        return cls(pretrained_config=pretrained_config,
-                   quant_config=quant_config,
-                   quant_config_dict=layer_quant_config,
-                   **kwargs)
+        model_config = cls(pretrained_config=pretrained_config,
+                           quant_config=quant_config,
+                           quant_config_dict=layer_quant_config,
+                           **kwargs)
+        model_config._frozen = True
+        return model_config

     def get_bindings_model_config(self) -> "ModelConfigCpp":
         """
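
For reference, the factory shape that from_pretrained now follows (overrides arrive as constructor keyword arguments and the instance is frozen before it is returned) can be sketched standalone as below. FrozenConfig, its field, and the dummy checkpoint path are placeholders for illustration, not the real API.

from dataclasses import dataclass, field


@dataclass
class FrozenConfig:
    max_num_tokens: int = 8192
    _frozen: bool = field(default=False, init=False, repr=False)

    def __setattr__(self, key, value):
        if self._frozen and key != '_frozen':
            raise AttributeError(
                f"Cannot modify FrozenConfig.'{key}' - instance is frozen")
        super().__setattr__(key, value)

    @classmethod
    def from_pretrained(cls, checkpoint_dir, **kwargs):
        # Overrides are applied at construction time, then the instance is
        # frozen before being handed back to the caller.
        # checkpoint_dir is unused in this sketch.
        config = cls(**kwargs)
        config._frozen = True
        return config


cfg = FrozenConfig.from_pretrained("dummy/checkpoint", max_num_tokens=4096)
print(cfg.max_num_tokens)   # 4096
# cfg.max_num_tokens = 2048 would now raise AttributeError
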
@@ -25,7 +25,9 @@ class AutoModelForCausalLM(Generic[TModel, TConfig]):
                 f"Unknown architecture for AutoModelForCausalLM: {config.pretrained_config.architectures[0]}"
             )
         if issubclass(cls, DecoderModelForCausalLM):
+            config._frozen = False
             config.skip_create_weights_in_init = True
+            config._frozen = True
         extra_attrs = {}
         with model_extra_attrs(extra_attrs):
             model = cls(config)
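
The factory above temporarily unfreezes the config, flips one flag, and refreezes it. If that pattern spreads, a small helper along these lines could keep it tidy; this context manager is a hypothetical addition for illustration, not part of the commit.

from contextlib import contextmanager


@contextmanager
def temporarily_unfrozen(config):
    # Assumes the object exposes the '_frozen' flag introduced above.
    config._frozen = False
    try:
        yield config
    finally:
        config._frozen = True

# Usage, mirroring the change in AutoModelForCausalLM:
#     with temporarily_unfrozen(config):
#         config.skip_create_weights_in_init = True
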
@@ -717,7 +717,7 @@ class Llama4Model(DecoderModel):

         # If enable_min_latency is True, we will use min-latency mode.
         DecoderLayerClass = Llama4DecoderLayer
-        if model_config.pytorch_backend_config.enable_min_latency:
+        if model_config.enable_min_latency:
             from .modeling_llama_min_latency import Llama4MinLatencyDecoderLayer
             DecoderLayerClass = Llama4MinLatencyDecoderLayer

@@ -879,12 +879,14 @@ class LlamaForCausalLM(DecoderModelForCausalLM[LlamaModel, LlamaConfig]):
             trust_remote_code=True,
             attn_backend=model_config.attn_backend,
             moe_backend=model_config.moe_backend,
-            mapping=model_config.mapping)
-        draft_config.spec_config = model_config.spec_config
-        draft_config.max_num_tokens = model_config.max_num_tokens
-        draft_config.moe_max_num_tokens = model_config.moe_max_num_tokens
+            mapping=model_config.mapping,
+            spec_config=model_config.spec_config,
+            max_num_tokens=model_config.max_num_tokens,
+            moe_max_num_tokens=model_config.moe_max_num_tokens)

         draft_config.quant_config.kv_cache_quant_algo = \
             model_config.quant_config.kv_cache_quant_algo

         self.draft_model = Eagle3LlamaForCausalLM(
             draft_config, model_config.pretrained_config.num_hidden_layers)
         self.spec_worker = get_spec_worker(model_config.spec_config,

@@ -900,15 +900,16 @@ class PyTorchModelEngine(ModelEngine):
                  moe_load_balancer: Optional[MoeLoadBalancerConfig] = None,
                  lora_config: Optional[LoraConfig] = None,
                  **kwargs):
-        config = ModelConfig.from_pretrained(checkpoint_dir,
-                                             trust_remote_code=True,
-                                             **kwargs)
-        config.pytorch_backend_config = self.pytorch_backend_config
-        config.spec_config = self.spec_config
-        config.max_num_tokens = max_num_tokens
-        config.moe_max_num_tokens = moe_max_num_tokens
-        config.moe_load_balancer = moe_load_balancer
-        config.lora_config = lora_config
+        config = ModelConfig.from_pretrained(
+            checkpoint_dir,
+            trust_remote_code=True,
+            enable_min_latency=self.pytorch_backend_config.enable_min_latency,
+            spec_config=self.spec_config,
+            max_num_tokens=max_num_tokens,
+            moe_max_num_tokens=moe_max_num_tokens,
+            moe_load_balancer=moe_load_balancer,
+            lora_config=lora_config,
+            **kwargs)

         validate_and_set_kv_cache_quant(
             config, self.pytorch_backend_config.kv_cache_dtype)