[https://nvbugs/5441729][test] Fix test_modeling_llama_min_latency.py failures (#7478)

Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
Author: Po-Han Huang (NVIDIA)
Date: 2025-10-13 15:35:02 +08:00, committed by GitHub
parent 9fe63dd8db
commit 6fc6f70a68
2 changed files with 22 additions and 8 deletions


@@ -1003,16 +1003,28 @@ class Llama4VisionEncoder(nn.Module):
         self.dtype = self.pretrained_config.text_config.torch_dtype
 
-    def load_weights(self):
+    def load_weights(self, weights: Dict):
         module_dict = nn.ModuleDict({
             "vision_model":
             Llama4VisionModel(self.pretrained_config.vision_config),
             "multi_modal_projector":
             Llama4MultiModalProjector(self.pretrained_config),
         })
-        load_sharded_checkpoint(module_dict,
-                                self.pretrained_config._name_or_path,
-                                strict=False)
+
+        # If the named params are present in the weights, load them directly.
+        param_names = [name for name, _ in module_dict.named_parameters()]
+        if all(name in weights for name in param_names):
+            vision_encoder_weights = {
+                name: weights[name]
+                for name in param_names
+            }
+            module_dict.load_state_dict(vision_encoder_weights)
+        # Otherwise, load the weights from the checkpoint.
+        else:
+            load_sharded_checkpoint(module_dict,
+                                    self.pretrained_config._name_or_path,
+                                    strict=False)
 
         self.vision_model = module_dict["vision_model"].to(self.device)
         self.mm_projector = module_dict["multi_modal_projector"].to(self.device)
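
The hunk above makes the vision encoder prefer tensors already present in the `weights` dict and fall back to reading the HF checkpoint from disk (via `load_sharded_checkpoint`) only when some expected parameter is missing. A minimal standalone sketch of that pattern, using a toy `nn.Linear` in place of the real TRT-LLM modules (all names below are illustrative, not the actual API):

import torch
import torch.nn as nn
from typing import Dict

def load_direct_or_fallback(module: nn.Module,
                            weights: Dict[str, torch.Tensor]) -> str:
    # Names of every parameter the module expects.
    param_names = [name for name, _ in module.named_parameters()]
    if all(name in weights for name in param_names):
        # All expected tensors are present: load them directly, filtering
        # out unrelated entries (e.g. language-model weights).
        module.load_state_dict({name: weights[name] for name in param_names})
        return "direct"
    # Otherwise the real code falls back to load_sharded_checkpoint(...),
    # which reads the weights from the checkpoint directory on disk.
    return "fallback"

toy = nn.Linear(4, 4)
full = {name: p.detach().clone() for name, p in toy.named_parameters()}
assert load_direct_or_fallback(toy, full) == "direct"
assert load_direct_or_fallback(toy, {}) == "fallback"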
@@ -1295,7 +1307,7 @@ class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model,
     def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
         if not DISAGG:
-            self.mm_encoder.load_weights()
+            self.mm_encoder.load_weights(weights)
             # Temporarily detach mm_encoder so the TRT-LLM loader doesn't try to load it
             had_mm_encoder = hasattr(self, "mm_encoder")

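The context lines above hint at the companion trick: `mm_encoder` is detached before the generic TRT-LLM loader runs, so the loader does not try to load the vision weights a second time, and is restored afterwards. A hedged sketch of that detach/reattach pattern with a toy module (class and attribute layout here are illustrative, not the actual implementation):

import torch.nn as nn

class ToyWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.mm_encoder = nn.Linear(4, 4)  # stand-in for the vision encoder

    def load_weights(self, weights):
        had_mm_encoder = hasattr(self, "mm_encoder")
        saved = self.mm_encoder if had_mm_encoder else None
        if had_mm_encoder:
            # nn.Module.__delattr__ removes the entry from _modules, so a
            # generic loader iterating named_parameters() no longer sees it.
            del self.mm_encoder
        try:
            # ... generic weight loading for the language model runs here ...
            pass
        finally:
            if had_mm_encoder:
                self.mm_encoder = saved  # reattach the submodule

ToyWrapper().load_weights({})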

@@ -266,10 +266,12 @@ class TestLlama4MinLatency(unittest.TestCase):
         attention_backend = "TRTLLM"
         metadata_cls = get_attention_backend(attention_backend).Metadata
 
-        if transformers.__version__ >= "4.55.0":
+        if transformers.__version__ >= "4.55.0" \
+                and transformers.__version__ < "4.56.1":
             self.skipTest(
-                "The transformers 4.55.0 has accuracy issues while 4.33.1 works fine. "
-                "https://nvbugspro.nvidia.com/bug/5441729")
+                "The transformers between 4.55.0 and 4.56.1 have accuracy "
+                "issues for Llama4. See: "
+                "https://github.com/huggingface/transformers/pull/40609")
 
         torch.random.manual_seed(0)
         config_dict = deepcopy(LLAMA_4_MAVERICK_TWO_LAYER_CONFIG)
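
One caveat worth noting about the gate above: it compares `transformers.__version__` as raw strings, which orders versions lexicographically (for example, "4.9.0" sorts above "4.55.0"). A sketch of the same 4.55.0-to-4.56.1 range check with numeric comparison via packaging.version (assuming the `packaging` package is available; the skip message is paraphrased):

from packaging.version import Version
import transformers

v = Version(transformers.__version__)
if Version("4.55.0") <= v < Version("4.56.1"):
    # In the test this would be self.skipTest(...).
    print(f"skip: transformers {v} has known Llama4 accuracy issues")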