Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-08 12:12:33 +08:00
[https://nvbugs/5441729][test] Fix test_modeling_llama_min_latency.py failures (#7478)
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
parent 9fe63dd8db
commit 6fc6f70a68
@@ -1003,16 +1003,28 @@ class Llama4VisionEncoder(nn.Module):
         self.dtype = self.pretrained_config.text_config.torch_dtype
 
-    def load_weights(self):
+    def load_weights(self, weights: Dict):
         module_dict = nn.ModuleDict({
             "vision_model":
             Llama4VisionModel(self.pretrained_config.vision_config),
             "multi_modal_projector":
             Llama4MultiModalProjector(self.pretrained_config),
         })
-        load_sharded_checkpoint(module_dict,
-                                self.pretrained_config._name_or_path,
-                                strict=False)
 
+        # If the named params are present in the weights, load them directly.
+        param_names = [name for name, _ in module_dict.named_parameters()]
+        if all(name in weights for name in param_names):
+            vision_encoder_weights = {
+                name: weights[name]
+                for name in param_names
+            }
+            module_dict.load_state_dict(vision_encoder_weights)
+
+        # Otherwise, load the weights from the checkpoint.
+        else:
+            load_sharded_checkpoint(module_dict,
+                                    self.pretrained_config._name_or_path,
+                                    strict=False)
+
         self.vision_model = module_dict["vision_model"].to(self.device)
         self.mm_projector = module_dict["multi_modal_projector"].to(self.device)
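A minimal standalone sketch of the new loading path in this hunk, using dummy nn.Linear modules and a made-up weights dict rather than the real Llama4 vision classes:

import torch
from torch import nn

# Dummy stand-ins for Llama4VisionModel / Llama4MultiModalProjector; only the
# parameter-name layout matters for this example.
module_dict = nn.ModuleDict({
    "vision_model": nn.Linear(4, 4),
    "multi_modal_projector": nn.Linear(4, 4),
})

# Pretend this is the full checkpoint dict handed to load_weights(); it may
# also carry unrelated language-model tensors.
weights = {name: torch.zeros_like(p) for name, p in module_dict.named_parameters()}
weights["language_model.layers.0.weight"] = torch.zeros(1)

param_names = [name for name, _ in module_dict.named_parameters()]
if all(name in weights for name in param_names):
    # All vision-encoder tensors are already in memory: load just that subset.
    module_dict.load_state_dict({name: weights[name] for name in param_names})
else:
    # Otherwise the real code falls back to transformers' load_sharded_checkpoint
    # against pretrained_config._name_or_path with strict=False.
    pass

The point of the new branch is that, when the caller already holds every vision-encoder tensor, the encoder no longer re-reads the sharded checkpoint from disk.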
@@ -1295,7 +1307,7 @@ class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model,
 
     def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
         if not DISAGG:
-            self.mm_encoder.load_weights()
+            self.mm_encoder.load_weights(weights)
 
         # Temporarily detach mm_encoder so the TRT-LLM loader doesn't try to load it
         had_mm_encoder = hasattr(self, "mm_encoder")
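The hunk itself only forwards the weights dict, but the detach/re-attach trick referenced in the surrounding comment is easy to miss; here is a toy sketch under assumed names (ToyModel and run_main_loader are placeholders, not the real TRT-LLM loader):

import torch
from torch import nn

class ToyModel(nn.Module):
    # Toy model with a sub-encoder that the main weight loader should skip.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.mm_encoder = nn.Linear(8, 8)

def run_main_loader(model: nn.Module, weights: dict) -> None:
    # Stand-in for the main loader: it only sees modules still attached.
    model.load_state_dict(weights, strict=False)

model = ToyModel()
weights = {"backbone.weight": torch.zeros(8, 8), "backbone.bias": torch.zeros(8)}

had_mm_encoder = hasattr(model, "mm_encoder")
mm_encoder = model.mm_encoder
del model.mm_encoder              # detach so the loader ignores it
try:
    run_main_loader(model, weights)
finally:
    if had_mm_encoder:
        model.mm_encoder = mm_encoder  # re-attach afterwards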
@@ -266,10 +266,12 @@ class TestLlama4MinLatency(unittest.TestCase):
         attention_backend = "TRTLLM"
         metadata_cls = get_attention_backend(attention_backend).Metadata
 
-        if transformers.__version__ >= "4.55.0":
+        if transformers.__version__ >= "4.55.0" \
+                and transformers.__version__ < "4.56.1":
             self.skipTest(
-                "The transformers 4.55.0 has accuracy issues while 4.33.1 works fine. "
-                "https://nvbugspro.nvidia.com/bug/5441729")
+                "The transformers between 4.55.0 and 4.56.1 have accuracy "
+                "issues for Llama4. See: "
+                "https://github.com/huggingface/transformers/pull/40609")
 
         torch.random.manual_seed(0)
         config_dict = deepcopy(LLAMA_4_MAVERICK_TWO_LAYER_CONFIG)
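For reference, a self-contained sketch of the version-gated skip pattern; it swaps the raw string comparison used in the test for packaging.version, which is my substitution, and the class and test names are placeholders:

import unittest

import transformers
from packaging import version

class TestVersionGatedSkip(unittest.TestCase):

    def test_llama4_two_layer(self):
        tf = version.parse(transformers.__version__)
        # Known-bad window for Llama4 accuracy; 4.56.1 picks up the upstream fix
        # (https://github.com/huggingface/transformers/pull/40609).
        if version.parse("4.55.0") <= tf < version.parse("4.56.1"):
            self.skipTest("transformers in [4.55.0, 4.56.1) has Llama4 accuracy issues")
        # ... the real test would build the two-layer Maverick config here ...
        self.assertTrue(True)

if __name__ == "__main__":
    unittest.main()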