diff --git a/tensorrt_llm/models/gemma/model.py b/tensorrt_llm/models/gemma/model.py index f39ed97e51..a80b9e771d 100644 --- a/tensorrt_llm/models/gemma/model.py +++ b/tensorrt_llm/models/gemma/model.py @@ -157,10 +157,10 @@ class GemmaDecoderLayer(Module): if default_net().plugin_config.reduce_fusion else AllReduceFusionOp.NONE, residual=residual, - norm_weight=self.post_layernorm.weight.value, - norm_pre_residual_weight=self.pre_feedforward_layernorm.weight. - value if self.config.inter_layernorms else None, - eps=self.post_layernorm.eps)) + norm_weight=self.pre_feedforward_layernorm.weight.value, + norm_pre_residual_weight=self.post_layernorm.weight.value + if self.config.inter_layernorms else None, + eps=self.pre_feedforward_layernorm.eps)) if use_cache: attention_output, presents = attention_output