# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from parameterized import parameterized

# isort: off
import torch
# isort: on

import copy
from argparse import Namespace
from typing import List

import tensorrt_llm
from tensorrt_llm.runtime.medusa_utils import (_medusa_setup, choices_2_paths,
                                               expand_choices_if_needed,
                                               get_packed_mask,
                                               path_sorting_key)

#############################################
## Taken from Medusa and slightly modified ##
##          Used for testing only          ##
#############################################


def pad_path(path, length, pad_value=-2):
    """
    Pad the given path list with a specific value up to a specified length.

    Parameters:
    - path (list): The original list that needs padding.
    - length (int): The desired length of the padded list.
    - pad_value (optional, default=-2): The value to use for padding.

    Returns:
    - list: A new list based on the original path but padded to the desired length.

    Example:
    >>> pad_path([1,2,3], 5)
    [1, 2, 3, -2, -2]

    Note:
    If the given path is already longer than the specified length,
    then no padding occurs, and the original path is returned.
    """

    # Calculate the number of padding values needed by subtracting the length
    # of the path from the desired length.
    # Append the padding values to the original path and return the new list.
    return path + [pad_value] * (length - len(path))


def generate_medusa_info(medusa_choices,
                         num_medusa_heads=None,
                         device="cuda"):  # MODIFIED
    """
    [MODIFIED]
    NOTE:
    tree_ids differs from the vanilla implementation.
    Instead of a fixed TOPK for all heads, the top-k can vary per head based on the given choices.
    Example: [[0], [1], [0,0]] leads to top-k values of [2, 1] for the 2 Medusa heads,
    and tree_ids of [0, 1, 2, 3] instead of the vanilla [0, 1, 2, 11].

    Generate buffers for the Medusa structure based on the provided choices.

    Parameters:
    - medusa_choices (list): A nested list representing the tree of the Medusa structure.
    - device (str): Device to which the tensors should be moved. Default is "cuda".

    Returns:
    - dict: A dictionary containing buffers related to the Medusa structure.
    """

    # Sort the medusa_choices based on their lengths and then their values
    sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
    medusa_len = len(sorted_medusa_choices) + 1

    # Initialize depth_counts to keep track of how many choices have a particular depth
    # MODIFIED
    if num_medusa_heads is None:
        num_medusa_heads = max([len(c) for c in medusa_choices])
    medusa_topks = [0] * num_medusa_heads
    #
    depth_counts = []
    prev_depth = 0
    for path in sorted_medusa_choices:
        depth = len(path)
        if depth != prev_depth:
            depth_counts.append(0)
        depth_counts[depth - 1] += 1
        medusa_topks[depth - 1] = max(medusa_topks[depth - 1],
                                      path[-1] + 1)  # MODIFIED
        prev_depth = depth

    # Create the attention mask for Medusa
    medusa_attn_mask = torch.eye(medusa_len, medusa_len)
    medusa_attn_mask[:, 0] = 1
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_medusa_choice = sorted_medusa_choices[start + j]
            # retrieve ancestor position
            if len(cur_medusa_choice) == 1:
                continue
            ancestor_idx = []
            for c in range(len(cur_medusa_choice) - 1):
                ancestor_idx.append(
                    sorted_medusa_choices.index(cur_medusa_choice[:c + 1]) + 1)
            medusa_attn_mask[j + start + 1, ancestor_idx] = 1
        start += depth_counts[i]

    # Generate tree indices for the Medusa structure
    medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
    medusa_tree_indices[0] = 0
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_medusa_choice = sorted_medusa_choices[start + j]
            medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + sum(
                medusa_topks[:i]) + 1  # MODIFIED
        start += depth_counts[i]

    # Generate position IDs for the Medusa structure
    medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
    start = 0
    for i in range(len(depth_counts)):
        medusa_position_ids[start + 1:start + depth_counts[i] + 1] = i + 1
        start += depth_counts[i]

    # Generate retrieval indices for Medusa structure verification
    retrieve_indices_nest = []
    retrieve_paths = []
    for i in range(len(sorted_medusa_choices)):
        cur_medusa_choice = sorted_medusa_choices[-i - 1]
        retrieve_indice = []
        if cur_medusa_choice in retrieve_paths:
            continue
        else:
            for c in range(len(cur_medusa_choice)):
                retrieve_indice.append(
                    sorted_medusa_choices.index(cur_medusa_choice[:c + 1]))
                retrieve_paths.append(cur_medusa_choice[:c + 1])
        retrieve_indices_nest.append(retrieve_indice)
    max_length = max([len(x) for x in retrieve_indices_nest])
    retrieve_indices = [
        pad_path(path, max_length) for path in retrieve_indices_nest
    ]
    retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
    retrieve_indices = retrieve_indices + 1
    retrieve_indices = torch.cat([
        torch.zeros(
            (retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices
    ],
                                 dim=1)

    # Aggregate the generated buffers into a dictionary
    medusa_buffers = {
        "medusa_attn_mask":
        medusa_attn_mask,  #.unsqueeze(0).unsqueeze(0), # MODIFIED
        "tree_indices": medusa_tree_indices,
        "medusa_position_ids": medusa_position_ids,
        "retrieve_indices": retrieve_indices,
        # MODIFIED
        "medusa_topks": medusa_topks,
    }

    # Move the tensors in the dictionary to the specified device
    medusa_buffers = {
        k: v.clone().to(device)
        if isinstance(v, torch.Tensor) else torch.tensor(v, device=device)
        for k, v in medusa_buffers.items()
    }
    return medusa_buffers


#############################################
##           End of imported code          ##
#############################################
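
# Added worked example (not from the upstream Medusa code): for
# medusa_choices = [[0], [1], [0, 0]], generate_medusa_info() above yields
#   medusa_topks        = [2, 1]            # head 0 keeps top-2, head 1 keeps top-1
#   tree_indices        = [0, 1, 2, 3]
#   medusa_position_ids = [0, 1, 1, 2]
#   medusa_attn_mask    = [[1, 0, 0, 0],    # root attends only to itself
#                          [1, 1, 0, 0],    # node [0]
#                          [1, 0, 1, 0],    # node [1]
#                          [1, 1, 0, 1]]    # node [0, 0]: root, [0], itself
#   retrieve_indices    = [[0, 1, 3],       # path root -> [0] -> [0, 0]
#                          [0, 2, -1]]      # path root -> [1], padded (-2 + 1 = -1)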


def _vanilla_medusa_setup(medusa_choices: List[List[int]],
                          num_medusa_heads=None):
    """
    Just a wrapper around the vanilla (modified) method, in case we write our own in the future.
    One flexibility added: medusa choices can also be given as only the full paths, e.g.
    Path-only style:
        [[0,0,0,0], [0,1,0], [1,0], [1,1]] is equivalent to
    Vanilla-style:
        [[0], [0,0], [0,0,0], [0,0,0,0], [0,1], [0,1,0], [1], [1,0], [1,1]]
    """
    medusa_choices = copy.deepcopy(medusa_choices)
    medusa_choices = expand_choices_if_needed(medusa_choices)
    medusa_info = generate_medusa_info(medusa_choices, num_medusa_heads)
    medusa_topks = medusa_info['medusa_topks']
    medusa_mask = medusa_info['medusa_attn_mask']
    medusa_paths = medusa_info['retrieve_indices']
    medusa_tree_ids = medusa_info['tree_indices']
    medusa_position_offsets = medusa_info['medusa_position_ids']
    medusa_packed_mask = get_packed_mask(len(medusa_choices), medusa_mask[1:,
                                                                          1:])
    return Namespace(
        medusa_mask=medusa_mask,
        medusa_packed_mask=medusa_packed_mask,
        medusa_topks=medusa_topks,
        medusa_paths=medusa_paths,
        medusa_tree_ids=medusa_tree_ids,
        medusa_position_offsets=medusa_position_offsets,
    )
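

# Illustrative sketch (added; not part of the original test or of tensorrt_llm):
# _vanilla_medusa_setup() above relies on get_packed_mask to compress the 0/1
# tree attention mask into bit-fields. The helper below only sketches that
# general idea (one bit per mask column, 32 columns per word) and makes no
# claim about the exact dtype, padding, or word order that
# tensorrt_llm.runtime.medusa_utils.get_packed_mask actually uses.
def _pack_mask_rows_sketch(mask: torch.Tensor) -> torch.Tensor:
    """Pack each row of a 0/1 mask into integer words, 32 columns per word.

    e.g. a mask [[1, 0], [1, 1]] packs to [[1], [3]].
    """
    rows, cols = mask.shape
    num_words = (cols + 31) // 32  # number of words needed to cover all columns
    packed = torch.zeros(rows, num_words, dtype=torch.int64)
    for r in range(rows):
        for c in range(cols):
            if mask[r, c]:
                packed[r, c // 32] |= 1 << (c % 32)  # set bit c within its word
    return packed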


class TestMedusaUtils(unittest.TestCase):

    def setUp(self):
        tensorrt_llm.logger.set_level('error')

    @parameterized.expand([
        ([[0], [0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 1], [0, 1, 0], [1], [1, 0],
          [1, 1]], ),
        ([[0], [0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 1], [0, 1, 0], [1], [1, 0],
          [1, 1], [2], [2, 0], [0, 2], [0, 2, 0]], ),
        ([[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3],
          [4], [0, 4], [2, 0]], ),
        (
            [[0], [0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 1], [0, 1, 0], [1],
             [1, 0], [1, 1]],
            [[0, 0, 0, 0], [0, 1, 0], [1, 0], [1, 1]],
        ),
        ([[0], [0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 1]], ),
        # test the original choices
        (
            [[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3],
             [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6],
             [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9],
             [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1],
             [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6],
             [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0],
             [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2],
             [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0],
             [0, 0, 0, 1], [1, 6], [0, 7, 0]], )
    ])
    def test_all_buffers(self, choices, paths=None):
        num_medusa_heads = max([len(c) for c in choices])
        info = _vanilla_medusa_setup(choices)
        vanilla_paths = sorted(info.medusa_paths.tolist(), key=path_sorting_key)
        cinfo = _medusa_setup(choices)
        test_paths = sorted(cinfo.medusa_paths.tolist(), key=path_sorting_key)
        assert info.medusa_tree_ids.equal(
            cinfo.medusa_tree_ids
        ), f"{info.medusa_tree_ids} \n\n!=\n\n {cinfo.medusa_tree_ids}"
        assert vanilla_paths == test_paths, f"\n{vanilla_paths} \n\n!=\n\n {test_paths}"
        assert info.medusa_position_offsets.equal(
            cinfo.medusa_position_offsets
        ), f"{info.medusa_position_offsets} != {cinfo.medusa_position_offsets}"
        assert info.medusa_mask.equal(
            cinfo.medusa_mask
        ), f"{info.medusa_mask} \n\n!=\n\n {cinfo.medusa_mask}"
        assert info.medusa_packed_mask.cuda().equal(
            cinfo.medusa_packed_mask
        ), f"{info.medusa_packed_mask} != {cinfo.medusa_packed_mask}"
        if paths is not None:
            deduced_choices = expand_choices_if_needed(paths)
            assert sorted(deduced_choices) == sorted(
                choices), f"{deduced_choices} != {choices}"
            deduced_paths, _, _, _ = choices_2_paths(num_medusa_heads, choices)
            assert sorted(deduced_paths, key=path_sorting_key) == sorted(
                paths, key=path_sorting_key), f"{deduced_paths} != {paths}"
            pinfo = _medusa_setup(paths)
            assert info.medusa_paths.sort(dim=0).values.equal(
                pinfo.medusa_paths.sort(dim=0).values
            ), f"{info.medusa_paths.sort(dim=0).values} != {pinfo.medusa_paths.sort(dim=0).values}"
            assert info.medusa_tree_ids.equal(
                pinfo.medusa_tree_ids
            ), f"{info.medusa_tree_ids} != {pinfo.medusa_tree_ids}"
            assert info.medusa_position_offsets.equal(
                pinfo.medusa_position_offsets
            ), f"{info.medusa_position_offsets} != {pinfo.medusa_position_offsets}"
            assert info.medusa_mask.equal(
                pinfo.medusa_mask
            ), f"{info.medusa_mask} \n\n!=\n\n {pinfo.medusa_mask}"
            assert info.medusa_packed_mask.cuda().equal(
                pinfo.medusa_packed_mask
            ), f"{info.medusa_packed_mask} != {pinfo.medusa_packed_mask}"
        return
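

# Allow running this module directly (an added convenience; the suite is
# normally collected by the repository's test runner).
if __name__ == "__main__":
    unittest.main()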