From 4bac6b337eee525ef5b6981b64ea3a77ee10fad7 Mon Sep 17 00:00:00 2001 From: Jin Li <59594262+liji-nv@users.noreply.github.com> Date: Tue, 14 Oct 2025 20:51:45 +0800 Subject: [PATCH] [https://nvbugs/5537348][fix] Use device tensor index for MTP (#8062) Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> --- tensorrt_llm/_torch/speculative/mtp.py | 27 +++++++++++--------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 1262b8d502..a3f4f83ac0 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -1195,20 +1195,17 @@ class MTPEagleWorker(MTPWorker): position_ids = position_ids.squeeze(0) last_tokens_idx = torch.cumsum( attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1 - last_tokens_idx_host = torch.cumsum( - attn_metadata.seq_lens, dim=0, dtype=torch.long) - 1 - return position_ids, last_tokens_idx, last_tokens_idx_host + return position_ids, last_tokens_idx - position_ids, last_tokens_idx, last_tokens_idx_host = prepare_position_ids_and_last_tokens( + position_ids, last_tokens_idx = prepare_position_ids_and_last_tokens( position_ids, attn_metadata) - inputs = self.prepare_drafter_inputs( - input_ids=input_ids, - position_ids=position_ids, - last_tokens_idx_host=last_tokens_idx_host, - hidden_states=hidden_states, - accepted_tokens=accepted_tokens, - attn_metadata=attn_metadata, - spec_metadata=spec_metadata) + inputs = self.prepare_drafter_inputs(input_ids=input_ids, + position_ids=position_ids, + last_tokens_idx=last_tokens_idx, + hidden_states=hidden_states, + accepted_tokens=accepted_tokens, + attn_metadata=attn_metadata, + spec_metadata=spec_metadata) # Predict draft tokens next_draft_tokens = [] @@ -1355,7 +1352,7 @@ class MTPEagleWorker(MTPWorker): self, input_ids: torch.IntTensor, position_ids: torch.IntTensor, - last_tokens_idx_host: torch.LongTensor, + last_tokens_idx: torch.LongTensor, hidden_states: torch.Tensor, accepted_tokens: torch.Tensor, attn_metadata: AttentionMetadata, @@ -1370,9 +1367,7 @@ class MTPEagleWorker(MTPWorker): device="cuda") input_ids_ctx[:-1].copy_(input_prompt_ids[1:]) input_ids_ctx[ - last_tokens_idx_host[: - num_contexts]] = accepted_tokens[:num_contexts, - 0] + last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0] # generation input_ids_gen = accepted_tokens[num_contexts:, :].flatten()