From e1105064b282bb807ba9c309741b40a3b64e2261 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Sat, 30 May 2026 10:34:33 -0400
Subject: [PATCH] [Bug] Fix gemma4 MTP IMA issue when TP>1, `CUDA error: an illegal memory access was encountered` (#43909)

Signed-off-by: yewentao256
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
---
 vllm/model_executor/models/gemma4_mtp.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/gemma4_mtp.py b/vllm/model_executor/models/gemma4_mtp.py
index c294ffc6f9a..122855400d9 100644
--- a/vllm/model_executor/models/gemma4_mtp.py
+++ b/vllm/model_executor/models/gemma4_mtp.py
@@ -501,6 +501,7 @@ class Gemma4MTP(nn.Module):
         config = vllm_config.speculative_config.draft_model_config.hf_config
         text_config = _get_text_config(config)
         self.config = config
+        self._stable_full_lm_head_weight: torch.Tensor | None = None
 
         self.model = Gemma4MultiTokenPredictor(
             vllm_config=vllm_config,
@@ -567,6 +568,8 @@ class Gemma4MTP(nn.Module):
         )
 
     def _get_full_lm_head_weight(self) -> torch.Tensor:
+        if self._stable_full_lm_head_weight is not None:
+            return self._stable_full_lm_head_weight
         lm_head_weight = self.lm_head.weight
         tp_size = get_tensor_model_parallel_world_size()
         if tp_size > 1:
@@ -574,7 +577,11 @@ class Gemma4MTP(nn.Module):
                 lm_head_weight,
                 dim=0,
             )
-        return lm_head_weight[: self.masked_embedding.vocab_size]
+        lm_head_weight = lm_head_weight[: self.masked_embedding.vocab_size]
+        if tp_size > 1:
+            lm_head_weight = lm_head_weight.contiguous()
+        self._stable_full_lm_head_weight = lm_head_weight
+        return lm_head_weight
 
     def compute_logits(
         self,
@@ -599,5 +606,6 @@ class Gemma4MTP(nn.Module):
         )
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        self._stable_full_lm_head_weight = None
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)