[Bug] Fix gemma4 MTP IMA issue when TP>1, CUDA error: an illegal memory access was encountered (#43909)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
Wentao Ye
2026-05-30 10:34:33 -04:00
committed by GitHub
parent 50c80d7923
commit e1105064b2
+9 -1
View File
@@ -501,6 +501,7 @@ class Gemma4MTP(nn.Module):
config = vllm_config.speculative_config.draft_model_config.hf_config
text_config = _get_text_config(config)
self.config = config
self._stable_full_lm_head_weight: torch.Tensor | None = None
self.model = Gemma4MultiTokenPredictor(
vllm_config=vllm_config,
@@ -567,6 +568,8 @@ class Gemma4MTP(nn.Module):
)
def _get_full_lm_head_weight(self) -> torch.Tensor:
if self._stable_full_lm_head_weight is not None:
return self._stable_full_lm_head_weight
lm_head_weight = self.lm_head.weight
tp_size = get_tensor_model_parallel_world_size()
if tp_size > 1:
@@ -574,7 +577,11 @@ class Gemma4MTP(nn.Module):
lm_head_weight,
dim=0,
)
return lm_head_weight[: self.masked_embedding.vocab_size]
lm_head_weight = lm_head_weight[: self.masked_embedding.vocab_size]
if tp_size > 1:
lm_head_weight = lm_head_weight.contiguous()
self._stable_full_lm_head_weight = lm_head_weight
return lm_head_weight
def compute_logits(
self,
@@ -599,5 +606,6 @@ class Gemma4MTP(nn.Module):
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
self._stable_full_lm_head_weight = None
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)