diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index 0036535974..0829f6bac4 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -1003,16 +1003,28 @@ class Llama4VisionEncoder(nn.Module):
 
         self.dtype = self.pretrained_config.text_config.torch_dtype
 
-    def load_weights(self):
+    def load_weights(self, weights: Dict):
         module_dict = nn.ModuleDict({
             "vision_model":
             Llama4VisionModel(self.pretrained_config.vision_config),
             "multi_modal_projector":
             Llama4MultiModalProjector(self.pretrained_config),
         })
-        load_sharded_checkpoint(module_dict,
-                                self.pretrained_config._name_or_path,
-                                strict=False)
+
+        # If the named params are present in the weights, load them directly.
+        param_names = [name for name, _ in module_dict.named_parameters()]
+        if all(name in weights for name in param_names):
+            vision_encoder_weights = {
+                name: weights[name]
+                for name in param_names
+            }
+            module_dict.load_state_dict(vision_encoder_weights)
+
+        # Otherwise, load the weights from the checkpoint.
+        else:
+            load_sharded_checkpoint(module_dict,
+                                    self.pretrained_config._name_or_path,
+                                    strict=False)
 
         self.vision_model = module_dict["vision_model"].to(self.device)
         self.mm_projector = module_dict["multi_modal_projector"].to(self.device)
@@ -1295,7 +1307,7 @@ class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model,
 
     def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
         if not DISAGG:
-            self.mm_encoder.load_weights()
+            self.mm_encoder.load_weights(weights)
 
         # Temporarily detach mm_encoder so the TRT-LLM loader doesn't try to load it
         had_mm_encoder = hasattr(self, "mm_encoder")
diff --git a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py
index 7b3e74a1bb..9f96f146b8 100644
--- a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py
+++ b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py
@@ -266,10 +266,12 @@ class TestLlama4MinLatency(unittest.TestCase):
         attention_backend = "TRTLLM"
         metadata_cls = get_attention_backend(attention_backend).Metadata
 
-        if transformers.__version__ >= "4.55.0":
+        if transformers.__version__ >= "4.55.0" \
+                and transformers.__version__ < "4.56.1":
             self.skipTest(
-                "The transformers 4.55.0 has accuracy issues while 4.33.1 works fine. "
-                "https://nvbugspro.nvidia.com/bug/5441729")
+                "The transformers between 4.55.0 and 4.56.1 have accuracy "
+                "issues for Llama4. See: "
+                "https://github.com/huggingface/transformers/pull/40609")
 
         torch.random.manual_seed(0)
         config_dict = deepcopy(LLAMA_4_MAVERICK_TWO_LAYER_CONFIG)
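
The core change above is a fallback path for loading the Llama4 vision-encoder weights: load_weights now receives the full weights dict, and if every named parameter of the vision model and projector is already present there, the tensors are loaded directly via load_state_dict; otherwise the code still falls back to load_sharded_checkpoint on the checkpoint directory. Below is a minimal, self-contained sketch of that pattern with a toy module standing in for the real Llama4 vision stack; the helper name load_encoder_weights and the checkpoint_dir argument are illustrative only and not part of the patch.

# Sketch of the fallback weight-loading pattern; names here are hypothetical.
from typing import Dict

import torch
import torch.nn as nn


def load_encoder_weights(module: nn.Module, weights: Dict[str, torch.Tensor],
                         checkpoint_dir: str) -> None:
    # Collect the parameter names this module expects.
    param_names = [name for name, _ in module.named_parameters()]
    if all(name in weights for name in param_names):
        # Every tensor is already in the in-memory dict: slice out the
        # relevant entries and load them directly, with no extra disk reads.
        module.load_state_dict({name: weights[name] for name in param_names})
    else:
        # Fall back to reading the sharded HF checkpoint from disk.
        from transformers.modeling_utils import load_sharded_checkpoint
        load_sharded_checkpoint(module, checkpoint_dir, strict=False)


if __name__ == "__main__":
    toy = nn.ModuleDict({"projector": nn.Linear(8, 8)})
    in_memory = {k: v.clone() for k, v in toy.state_dict().items()}
    # The dict is complete, so the checkpoint directory is never touched here.
    load_encoder_weights(toy, in_memory, checkpoint_dir="/path/to/checkpoint")

Reusing the already-materialized weights dict avoids a second pass over the sharded checkpoint files when the caller has them in memory, which is what the non-disaggregated path in Llama4ForConditionalGeneration.load_weights now does by passing weights through to the encoder.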