Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[https://nvbugs/5537348][fix] Use device tensor index for MTP (#8062)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
commit 4bac6b337e
parent 7b5ba7ca66
@@ -1195,20 +1195,17 @@ class MTPEagleWorker(MTPWorker):
             position_ids = position_ids.squeeze(0)
             last_tokens_idx = torch.cumsum(
                 attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
-            last_tokens_idx_host = torch.cumsum(
-                attn_metadata.seq_lens, dim=0, dtype=torch.long) - 1
-            return position_ids, last_tokens_idx, last_tokens_idx_host
+            return position_ids, last_tokens_idx

-        position_ids, last_tokens_idx, last_tokens_idx_host = prepare_position_ids_and_last_tokens(
+        position_ids, last_tokens_idx = prepare_position_ids_and_last_tokens(
             position_ids, attn_metadata)
-        inputs = self.prepare_drafter_inputs(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            last_tokens_idx_host=last_tokens_idx_host,
-            hidden_states=hidden_states,
-            accepted_tokens=accepted_tokens,
-            attn_metadata=attn_metadata,
-            spec_metadata=spec_metadata)
+        inputs = self.prepare_drafter_inputs(input_ids=input_ids,
+                                             position_ids=position_ids,
+                                             last_tokens_idx=last_tokens_idx,
+                                             hidden_states=hidden_states,
+                                             accepted_tokens=accepted_tokens,
+                                             attn_metadata=attn_metadata,
+                                             spec_metadata=spec_metadata)

         # Predict draft tokens
         next_draft_tokens = []
@@ -1355,7 +1352,7 @@ class MTPEagleWorker(MTPWorker):
         self,
         input_ids: torch.IntTensor,
         position_ids: torch.IntTensor,
-        last_tokens_idx_host: torch.LongTensor,
+        last_tokens_idx: torch.LongTensor,
         hidden_states: torch.Tensor,
         accepted_tokens: torch.Tensor,
         attn_metadata: AttentionMetadata,
@@ -1370,9 +1367,7 @@ class MTPEagleWorker(MTPWorker):
                                         device="cuda")
             input_ids_ctx[:-1].copy_(input_prompt_ids[1:])
             input_ids_ctx[
-                last_tokens_idx_host[:
-                                     num_contexts]] = accepted_tokens[:num_contexts,
-                                                                      0]
+                last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]

             # generation
             input_ids_gen = accepted_tokens[num_contexts:, :].flatten()
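In short, the patch drops the host-side copy of the last-token indices (last_tokens_idx_host, a cumsum over attn_metadata.seq_lens on the CPU) and reuses the device-side last_tokens_idx (a cumsum over attn_metadata.seq_lens_cuda) when scattering accepted tokens into the drafter's context input ids. Below is a minimal, self-contained sketch of that indexing pattern; the tensor values and num_contexts are invented for illustration, and the comments on motivation are assumptions inferred from the commit title, not part of the commit itself.

    # Minimal sketch (not TensorRT-LLM code): the device-index pattern the
    # patch converges on. Seq lens, token values, and num_contexts are made up.
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Per-request sequence lengths, already resident on the device
    # (stand-in for attn_metadata.seq_lens_cuda).
    seq_lens_cuda = torch.tensor([3, 2, 4], device=device)

    # Index of each request's last token in the flattened token stream:
    # cumsum([3, 2, 4]) - 1 -> [2, 4, 8]. Computed on-device, as in the patch.
    last_tokens_idx = torch.cumsum(seq_lens_cuda, dim=0, dtype=torch.long) - 1

    num_contexts = 2  # pretend the first two requests are context requests
    total_tokens = int(seq_lens_cuda.sum())

    # Stand-ins for the diff's input_ids_ctx and accepted_tokens.
    input_ids_ctx = torch.zeros(total_tokens, dtype=torch.int32, device=device)
    accepted_tokens = torch.tensor([[11, 12], [21, 22], [31, 32]],
                                   dtype=torch.int32, device=device)

    # The fixed line: the scatter now uses a *device* index tensor, so the
    # whole operation stays on one device instead of consuming a host-side
    # cumsum of the sequence lengths (the removed last_tokens_idx_host).
    input_ids_ctx[last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0]

    print(input_ids_ctx)  # positions 2 and 4 now hold 11 and 21

Indexing a CUDA tensor with a CPU index tensor is legal in PyTorch, but it implies an extra host-to-device transfer and ties the write to host-side state; presumably a mismatch of that kind is what the referenced nvbug hit, though the commit message does not spell out the failure mode.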