# Review notes. These `#` lines sit ahead of the `diff --git` header and are not part of any hunk;
# the patch text that follows is unchanged.
#
# Purpose: memoise the tensor returned by Gemma4MTP._get_full_lm_head_weight on the instance.
#
# Hunk 1 (constructor body): adds the attribute `_stable_full_lm_head_weight`, annotated
#   `torch.Tensor | None` and initialised to None, right after `self.config = config`.
# Hunk 2 (_get_full_lm_head_weight): adds an early return of the cached tensor when it is not None;
#   otherwise execution continues into the existing code that reads `self.lm_head.weight` and
#   `get_tensor_model_parallel_world_size()`.
# Hunk 3 (same method): replaces the direct `return lm_head_weight[: self.masked_embedding.vocab_size]`
#   with four steps: slice to `self.masked_embedding.vocab_size`, call `.contiguous()` only when
#   `tp_size > 1`, store the result in `_stable_full_lm_head_weight`, then return it.
# Hunk 4 (load_weights): resets `_stable_full_lm_head_weight` to None before `AutoWeightsLoader(self)`
#   loads the weights, so the next call to `_get_full_lm_head_weight` rebuilds the cached value.
#
# NOTE(review): one line of `_get_full_lm_head_weight` (old line 573) falls between hunks 2 and 3 and
#   is not shown here; the `lm_head_weight, dim=0, )` context looks like the tail of a gather call
#   inside `if tp_size > 1:` - confirm against the full file.
# NOTE(review): the only cache reset visible in this patch is in `load_weights`. If `lm_head.weight`
#   can be rebound or updated through any other path, the tensor cached for `tp_size > 1` could go
#   stale - verify against callers.
# NOTE(review): when `tp_size == 1` no `.contiguous()` call is made, so the cached value is the plain
#   slice of `self.lm_head.weight`; presumably it shares storage with the parameter - confirm that
#   this difference from the `tp_size > 1` path is intended.
# NOTE(review): the diff arrives here as a single physical line; presumably an extraction artefact.
#   Unified-diff tools need the original line breaks and indentation - TODO confirm the stored patch.
diff --git a/vllm/model_executor/models/gemma4_mtp.py b/vllm/model_executor/models/gemma4_mtp.py index c294ffc6f9a..122855400d9 100644 --- a/vllm/model_executor/models/gemma4_mtp.py +++ b/vllm/model_executor/models/gemma4_mtp.py @@ -501,6 +501,7 @@ class Gemma4MTP(nn.Module): config = vllm_config.speculative_config.draft_model_config.hf_config text_config = _get_text_config(config) self.config = config + self._stable_full_lm_head_weight: torch.Tensor | None = None self.model = Gemma4MultiTokenPredictor( vllm_config=vllm_config, @@ -567,6 +568,8 @@ class Gemma4MTP(nn.Module): ) def _get_full_lm_head_weight(self) -> torch.Tensor: + if self._stable_full_lm_head_weight is not None: + return self._stable_full_lm_head_weight lm_head_weight = self.lm_head.weight tp_size = get_tensor_model_parallel_world_size() if tp_size > 1: @@ -574,7 +577,11 @@ class Gemma4MTP(nn.Module): lm_head_weight, dim=0, ) - return lm_head_weight[: self.masked_embedding.vocab_size] + lm_head_weight = lm_head_weight[: self.masked_embedding.vocab_size] + if tp_size > 1: + lm_head_weight = lm_head_weight.contiguous() + self._stable_full_lm_head_weight = lm_head_weight + return lm_head_weight def compute_logits( self, @@ -599,5 +606,6 @@ class Gemma4MTP(nn.Module): ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + self._stable_full_lm_head_weight = None loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)