TensorRT-LLMs/tests/unittest/utils/test_medusa_utils.py

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from parameterized import parameterized
# isort: off
import torch
# isort: on
import copy
from argparse import Namespace
from typing import List
import tensorrt_llm
from tensorrt_llm.runtime.medusa_utils import (_medusa_setup, choices_2_paths,
expand_choices_if_needed,
get_packed_mask,
path_sorting_key)


#############################################
## Taken from Medusa and slightly modified ##
## Used for testing only ##
#############################################
def pad_path(path, length, pad_value=-2):
"""
Pad the given path list with a specific value up to a specified length.
Parameters:
- path (list): The original list that needs padding.
- length (int): The desired length of the padded list.
- pad_value (optional, default=-2): The value to use for padding.
Returns:
- list: A new list based on the original path but padded to the desired length.
Example:
>>> pad_path([1,2,3], 5)
[1, 2, 3, -2, -2]
Note:
If the given path is already longer than the specified length,
then no padding occurs, and the original path is returned.
"""
# Calculate the number of padding values needed by subtracting the length
# of the path from the desired length.
# Append the padding values to the original path and return the new list.
return path + [pad_value] * (length - len(path))


def generate_medusa_info(medusa_choices,
num_medusa_heads=None,
device="cuda"): # MODIFIED
"""
[MODIFIED]
NOTE:
tree_ids is different than the vanilla implementation.
Instead of TOPK to be fixed for all the heads, we can vary it based on the given choices.
Example: [[0], [1], [0,0]] would lead to topk for the 2 Medusa heads to be [2, 1].
And tree_ids would be [0, 1, 2, 3], instead of vanilla [0, 1, 2, 11]
Generate buffers for the Medusa structure based on the provided choices.
Parameters:
- medusa_choices (list): A nested list representing tree in the Medusa structure.
- device (str): Device to which the tensors should be moved. Default is "cuda".
Returns:
- dict: A dictionary containing buffers related to the Medusa structure.
"""
# Sort the medusa_choices based on their lengths and then their values
sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
medusa_len = len(sorted_medusa_choices) + 1
# Initialize depth_counts to keep track of how many choices have a particular depth
# MODIFIED
if num_medusa_heads is None:
num_medusa_heads = max([len(c) for c in medusa_choices])
medusa_topks = [0] * num_medusa_heads
#
depth_counts = []
prev_depth = 0
for path in sorted_medusa_choices:
depth = len(path)
if depth != prev_depth:
depth_counts.append(0)
depth_counts[depth - 1] += 1
medusa_topks[depth - 1] = max(medusa_topks[depth - 1],
path[-1] + 1) # MODIFIED
prev_depth = depth
# Create the attention mask for Medusa
medusa_attn_mask = torch.eye(medusa_len, medusa_len)
medusa_attn_mask[:, 0] = 1
start = 0
for i in range(len(depth_counts)):
for j in range(depth_counts[i]):
cur_medusa_choice = sorted_medusa_choices[start + j]
# retrieve ancestor position
if len(cur_medusa_choice) == 1:
continue
ancestor_idx = []
for c in range(len(cur_medusa_choice) - 1):
ancestor_idx.append(
sorted_medusa_choices.index(cur_medusa_choice[:c + 1]) + 1)
medusa_attn_mask[j + start + 1, ancestor_idx] = 1
start += depth_counts[i]
# Generate tree indices for the Medusa structure
medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
medusa_tree_indices[0] = 0
start = 0
for i in range(len(depth_counts)):
for j in range(depth_counts[i]):
cur_medusa_choice = sorted_medusa_choices[start + j]
medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + sum(
medusa_topks[:i]) + 1 # MODIFIED
start += depth_counts[i]
# Generate position IDs for the Medusa structure
medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
start = 0
for i in range(len(depth_counts)):
medusa_position_ids[start + 1:start + depth_counts[i] + 1] = i + 1
start += depth_counts[i]
# Generate retrieval indices for Medusa structure verification
retrieve_indices_nest = []
retrieve_paths = []
for i in range(len(sorted_medusa_choices)):
cur_medusa_choice = sorted_medusa_choices[-i - 1]
retrieve_indice = []
if cur_medusa_choice in retrieve_paths:
continue
else:
for c in range(len(cur_medusa_choice)):
retrieve_indice.append(
sorted_medusa_choices.index(cur_medusa_choice[:c + 1]))
retrieve_paths.append(cur_medusa_choice[:c + 1])
retrieve_indices_nest.append(retrieve_indice)
max_length = max([len(x) for x in retrieve_indices_nest])
retrieve_indices = [
pad_path(path, max_length) for path in retrieve_indices_nest
]
retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
retrieve_indices = retrieve_indices + 1
retrieve_indices = torch.cat([
torch.zeros(
(retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices
],
dim=1)
# Aggregate the generated buffers into a dictionary
medusa_buffers = {
"medusa_attn_mask":
medusa_attn_mask, #.unsqueeze(0).unsqueeze(0), # MODIFIED
"tree_indices": medusa_tree_indices,
"medusa_position_ids": medusa_position_ids,
"retrieve_indices": retrieve_indices,
# MODIFIED
"medusa_topks": medusa_topks,
}
# Move the tensors in the dictionary to the specified device
medusa_buffers = {
k:
v.clone().to(device)
if isinstance(v, torch.Tensor) else torch.tensor(v, device=device)
for k, v in medusa_buffers.items()
}
return medusa_buffers


#############################################
## End of imported code ##
#############################################
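

# Illustrative sketch, not part of the original tests: a tiny tree matching
# the example in the generate_medusa_info docstring. The helper name and the
# CPU device are assumptions made here for clarity only.
def _tiny_tree_example():
    # [[0], [1], [0, 0]] -> per-head top-k [2, 1], so tree indices are the
    # compact [0, 1, 2, 3] rather than the vanilla [0, 1, 2, 11].
    buffers = generate_medusa_info([[0], [1], [0, 0]], device="cpu")
    assert buffers["medusa_topks"].tolist() == [2, 1]
    assert buffers["tree_indices"].tolist() == [0, 1, 2, 3]
    assert buffers["medusa_position_ids"].tolist() == [0, 1, 1, 2]
    return buffers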


def _vanilla_medusa_setup(medusa_choices: List[List[int]],
num_medusa_heads=None):
"""
Just a wrapper around the vanilla (modified) method, in case we write our own in the future.
One flexibility added: medusa choices can also be given as only the full paths e.g.
Path-only style:
[[0,0,0,0],[0,1,0],[1,0],[1,1]] is equivalent to
Vanilla-style:
[[0], [0,0], [0,0,0], [0,0,0,0], [0,1], [0,1,0], [1], [1,0], [1,1]]
"""
medusa_choices = copy.deepcopy(medusa_choices)
medusa_choices = expand_choices_if_needed(medusa_choices)
medusa_info = generate_medusa_info(medusa_choices, num_medusa_heads)
medusa_topks = medusa_info['medusa_topks']
medusa_mask = medusa_info['medusa_attn_mask']
medusa_paths = medusa_info['retrieve_indices']
medusa_tree_ids = medusa_info['tree_indices']
medusa_position_offsets = medusa_info['medusa_position_ids']
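    # Pack the attention mask into the runtime's packed form; the [1:, 1:]
    # slice drops the root row/column so only the len(medusa_choices) draft
    # positions are packed.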
medusa_packed_mask = get_packed_mask(len(medusa_choices), medusa_mask[1:,
1:])
return Namespace(
medusa_mask=medusa_mask,
medusa_packed_mask=medusa_packed_mask,
medusa_topks=medusa_topks,
medusa_paths=medusa_paths,
medusa_tree_ids=medusa_tree_ids,
medusa_position_offsets=medusa_position_offsets,
)
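

# Hedged usage sketch, added for illustration (not from the original file):
# the path-only spelling from the docstring above expands to the same set of
# choices as the vanilla spelling, which is what test_all_buffers verifies
# for its (choices, paths) cases. The helper name is an assumption.
def _path_only_expansion_example():
    paths_only = [[0, 0, 0, 0], [0, 1, 0], [1, 0], [1, 1]]
    vanilla = [[0], [0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 1], [0, 1, 0], [1],
               [1, 0], [1, 1]]
    expanded = expand_choices_if_needed(copy.deepcopy(paths_only))
    assert sorted(expanded) == sorted(vanilla)
    return expanded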


class TestMedusaUtils(unittest.TestCase):

    def setUp(self):
        tensorrt_llm.logger.set_level('error')

    @parameterized.expand([
([[0], [0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 1], [0, 1, 0], [1], [1, 0],
[1, 1]], ),
([[0], [0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 1], [0, 1, 0], [1], [1, 0],
[1, 1], [2], [2, 0], [0, 2], [0, 2, 0]], ),
([[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3],
[4], [0, 4], [2, 0]], ),
(
[[0], [0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 1], [0, 1, 0], [1],
[1, 0], [1, 1]],
[[0, 0, 0, 0], [0, 1, 0], [1, 0], [1, 1]],
),
([[0], [0, 0], [0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 1]], ),
# test the original choices
(
[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3],
[0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6],
[0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9],
[8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1],
[0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6],
[0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0],
[0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2],
[8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0],
[0, 0, 0, 1], [1, 6], [0, 7, 0]], )
])
def test_all_buffers(self, choices, paths=None):
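        # Build reference buffers with the (modified) vanilla Medusa code and
        # compare them against tensorrt_llm.runtime.medusa_utils._medusa_setup.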
num_medusa_heads = max([len(c) for c in choices])
info = _vanilla_medusa_setup(choices)
vanilla_paths = sorted(info.medusa_paths.tolist(), key=path_sorting_key)
cinfo = _medusa_setup(choices)
test_paths = sorted(cinfo.medusa_paths.tolist(), key=path_sorting_key)
assert info.medusa_tree_ids.equal(
cinfo.medusa_tree_ids
), f"{info.medusa_tree_ids} \n\n!=\n\n {cinfo.medusa_tree_ids}"
assert vanilla_paths == test_paths, f"\n{vanilla_paths} \n\n!=\n\n {test_paths}"
assert info.medusa_position_offsets.equal(
cinfo.medusa_position_offsets
), f"{info.medusa_position_offsets} != {cinfo.medusa_position_offsets}"
assert info.medusa_mask.equal(
cinfo.medusa_mask
), f"{info.medusa_mask} \n\n!=\n\n {cinfo.medusa_mask}"
assert info.medusa_packed_mask.cuda().equal(
cinfo.medusa_packed_mask
), f"{info.medusa_packed_mask} != {cinfo.medusa_packed_mask}"
if paths is not None:
deduced_choices = expand_choices_if_needed(paths)
assert sorted(deduced_choices) == sorted(
choices), f"{deduced_choices} != {choices}"
deduced_paths, _, _, _ = choices_2_paths(num_medusa_heads, choices)
assert sorted(deduced_paths, key=path_sorting_key) == sorted(
paths, key=path_sorting_key), f"{deduced_paths} != {paths}"
pinfo = _medusa_setup(paths)
assert info.medusa_paths.sort(dim=0).values.equal(
pinfo.medusa_paths.sort(dim=0).values
), f"{info.medusa_paths.sort(dim=0).values} != {pinfo.medusa_paths.sort(dim=0).values}"
assert info.medusa_tree_ids.equal(
pinfo.medusa_tree_ids
), f"{info.medusa_tree_ids} != {pinfo.medusa_tree_ids}"
assert info.medusa_position_offsets.equal(
pinfo.medusa_position_offsets
), f"{info.medusa_position_offsets} != {pinfo.medusa_position_offsets}"
assert info.medusa_mask.equal(
pinfo.medusa_mask
), f"{info.medusa_mask} \n\n!=\n\n {pinfo.medusa_mask}"
assert info.medusa_packed_mask.cuda().equal(
pinfo.medusa_packed_mask
), f"{info.medusa_packed_mask} != {pinfo.medusa_packed_mask}"
return