[https://nvbugs/5537348][fix] Use device tensor index for MTP (#8062)

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
This commit is contained in:
Jin Li 2025-10-14 20:51:45 +08:00 committed by GitHub
parent 7b5ba7ca66
commit 4bac6b337e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1195,20 +1195,17 @@ class MTPEagleWorker(MTPWorker):
position_ids = position_ids.squeeze(0)
last_tokens_idx = torch.cumsum(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
last_tokens_idx_host = torch.cumsum(
attn_metadata.seq_lens, dim=0, dtype=torch.long) - 1
return position_ids, last_tokens_idx, last_tokens_idx_host
return position_ids, last_tokens_idx
position_ids, last_tokens_idx, last_tokens_idx_host = prepare_position_ids_and_last_tokens(
position_ids, last_tokens_idx = prepare_position_ids_and_last_tokens(
position_ids, attn_metadata)
inputs = self.prepare_drafter_inputs(
input_ids=input_ids,
position_ids=position_ids,
last_tokens_idx_host=last_tokens_idx_host,
hidden_states=hidden_states,
accepted_tokens=accepted_tokens,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)
inputs = self.prepare_drafter_inputs(input_ids=input_ids,
position_ids=position_ids,
last_tokens_idx=last_tokens_idx,
hidden_states=hidden_states,
accepted_tokens=accepted_tokens,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)
# Predict draft tokens
next_draft_tokens = []
@ -1355,7 +1352,7 @@ class MTPEagleWorker(MTPWorker):
self,
input_ids: torch.IntTensor,
position_ids: torch.IntTensor,
last_tokens_idx_host: torch.LongTensor,
last_tokens_idx: torch.LongTensor,
hidden_states: torch.Tensor,
accepted_tokens: torch.Tensor,
attn_metadata: AttentionMetadata,
@ -1370,9 +1367,7 @@ class MTPEagleWorker(MTPWorker):
device="cuda")
input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
input_ids_ctx[
last_tokens_idx_host[:
num_contexts]] = accepted_tokens[:num_contexts,
0]
last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]
# generation
input_ids_gen = accepted_tokens[num_contexts:, :].flatten()