Mirror of https://github.com/vllm-project/vllm.git (synced 2026-06-06 00:16:14 +00:00)
[Bug] Fix gemma4 MTP IMA issue when TP>1, CUDA error: an illegal memory access was encountered (#43909)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
@@ -501,6 +501,7 @@ class Gemma4MTP(nn.Module):
|
||||
config = vllm_config.speculative_config.draft_model_config.hf_config
|
||||
text_config = _get_text_config(config)
|
||||
self.config = config
|
||||
self._stable_full_lm_head_weight: torch.Tensor | None = None
|
||||
|
||||
self.model = Gemma4MultiTokenPredictor(
|
||||
vllm_config=vllm_config,
|
||||
@@ -567,6 +568,8 @@ class Gemma4MTP(nn.Module):
|
||||
)
|
||||
|
||||
def _get_full_lm_head_weight(self) -> torch.Tensor:
|
||||
if self._stable_full_lm_head_weight is not None:
|
||||
return self._stable_full_lm_head_weight
|
||||
lm_head_weight = self.lm_head.weight
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if tp_size > 1:
|
||||
@@ -574,7 +577,11 @@ class Gemma4MTP(nn.Module):
|
||||
lm_head_weight,
|
||||
dim=0,
|
||||
)
|
||||
return lm_head_weight[: self.masked_embedding.vocab_size]
|
||||
lm_head_weight = lm_head_weight[: self.masked_embedding.vocab_size]
|
||||
if tp_size > 1:
|
||||
lm_head_weight = lm_head_weight.contiguous()
|
||||
self._stable_full_lm_head_weight = lm_head_weight
|
||||
return lm_head_weight
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
@@ -599,5 +606,6 @@ class Gemma4MTP(nn.Module):
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into this module.

    The cached full LM-head weight is cleared first, so the next call to
    ``_get_full_lm_head_weight`` rebuilds it from the current
    ``lm_head.weight`` instead of returning a tensor cached earlier.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs to load.

    Returns:
        The set returned by ``AutoWeightsLoader.load_weights``
        (presumably the names of the parameters that were loaded --
        TODO confirm against the loader).
    """
    # Invalidate the cache before loading any weights.
    self._stable_full_lm_head_weight = None
    return AutoWeightsLoader(self).load_weights(
        weights,
        mapper=self.hf_to_vllm_mapper,
    )
|
||||
|
||||
Reference in New Issue
Block a user