TensorRT-LLMs/tensorrt_llm/runtime/medusa_utils.py
Kaiyu Xie 0f041b7b57
Update TensorRT-LLM (#1098)
* Update TensorRT-LLM

* update submodule

* Remove unused binaries
2024-02-18 15:48:08 +08:00

179 lines
6.7 KiB
Python

import copy
from argparse import Namespace
from functools import cmp_to_key
from typing import List
import numpy as np
import torch
from tensorrt_llm.logger import logger
def path_sorter(a, b):
    """
    Three-way comparator for two paths (indexable sequences).

    Elements are compared position by position over the length of the shorter
    path. Returns -1 or 1 at the first position where the paths differ, and 0
    when no position within that common length differs (the paths are equal,
    or one is a prefix of the other).
    """
    for left, right in zip(a, b):
        if left != right:
            if left < right:
                return -1
            return 1
    # No difference found within the common length.
    return 0
# Key function built from path_sorter, for use as sorted(..., key=path_sorting_key).
path_sorting_key = cmp_to_key(path_sorter)
def expand_choices_if_needed(medusa_choices: List[List[int]]):
    """
    Return Medusa choices in expanded form, expanding them when needed.

    Only the first choice with more than one element is inspected: if its
    one-element parent ``[choice[0]]`` is itself listed, the input is taken to
    be already expanded ("vanilla" style) and returned unchanged. Otherwise
    (or when every choice has a single element) every non-empty prefix of
    every choice is collected, de-duplicated, and returned as a new list of
    lists. The order of the expanded list follows set iteration order.
    """
    assert len(medusa_choices) > 0

    # Find the first multi-token choice, if any; it alone decides the style.
    probe = next((path for path in medusa_choices if len(path) > 1), None)
    if probe is not None:
        if [probe[0]] in medusa_choices:
            logger.debug(
                "Detected vanilla-style of Medusa choices. No need to expand."
            )
            # Parent found: assume the choices are already expanded.
            return medusa_choices
        logger.debug(
            "Detected path-only style of Medusa choices. Expanding ...")

    # Gather every non-empty prefix of every choice; the set removes duplicates.
    prefixes = set()
    for path in medusa_choices:
        for depth in range(1, len(path) + 1):
            prefixes.add(tuple(path[:depth]))
    return [list(prefix) for prefix in prefixes]
def get_packed_mask(num_medusa_tokens, medusa_mask):
    """
    Pack a 0/1 Medusa mask into rows of int32 bit-fields.

    The packed mask has ``num_medusa_tokens + 1`` rows: one extra token (from
    the original lm head) is placed in front of the Medusa tokens. Within a
    row, bit ``b`` of word ``w`` stands for column ``32 * w + b``; a word whose
    bit 31 is set is stored as the corresponding negative int32 value.

    Args:
        num_medusa_tokens: number of Medusa tokens (without the extra token).
        medusa_mask: 2-D tensor of 0/1 values; row ``t - 1`` supplies the bits
            of packed row ``t`` for columns ``1..``. Presumably it has
            ``num_medusa_tokens`` rows and columns (not checked here).

    Returns:
        An int32 tensor of shape
        ``(num_medusa_tokens + 1, ceil((num_medusa_tokens + 1) / 32))``.
    """
    num_tokens = num_medusa_tokens + 1
    num_packed_masks = (num_tokens + 32 - 1) // 32
    medusa_packed_mask = torch.zeros((num_tokens, num_packed_masks),
                                     dtype=torch.int32)
    for token_idx in range(num_tokens):
        if token_idx == 0:
            # Row 0 has only bit 0 set.
            medusa_packed_mask[0, 0] = 1
            continue

        # Prepend a set bit for the one extra new token from the original lm head.
        bits = [True] + medusa_mask[token_idx - 1, :].tolist()

        for mask_idx in range(num_packed_masks):
            chunk = bits[mask_idx * 32:(mask_idx + 1) * 32]
            if not chunk:
                break
            # Assemble the word with integer bit operations. This also handles
            # a trailing word holding a single valid bit, which the former
            # string-based conversion rejected (int("", 2) raises ValueError).
            word = 0
            for bit_pos, val in enumerate(chunk):
                if int(val):
                    word |= 1 << bit_pos
            # Reinterpret as signed int32 when bit 31 is set.
            if word >= 2**31:
                word -= 2**32
            medusa_packed_mask[token_idx, mask_idx] = word
    return medusa_packed_mask
def choices_2_paths(num_medusa_heads, choices):
    """
    Reduce a list of choices to the paths that are not a prefix of a longer one.

    Choices are visited from longest to shortest (stable for equal lengths).
    A choice is kept as a path only if it has not already been seen as a
    prefix of an earlier (longer or equal-length) choice. Every distinct
    prefix is recorded once and counted at its depth.

    Args:
        num_medusa_heads: length of ``level_counts``; each choice must not be
            longer than this.
        choices: list of choices (sequences of ints). It is not modified.

    Returns:
        A tuple ``(path_list, level_counts, paths, all_paths)`` where
        ``paths`` maps ``"a:b:c"`` keys to the kept choices, ``path_list`` is
        ``list(paths.values())``, ``all_paths`` maps the same style of key to
        every distinct prefix, and ``level_counts[i]`` is the number of
        distinct prefixes of length ``i + 1``.
    """
    paths = {}
    all_paths = {}
    level_counts = [0] * num_medusa_heads
    # sorted() is stable like list.sort(), but leaves the caller's list
    # untouched instead of reordering it as a hidden side effect.
    for choice in sorted(choices, key=len, reverse=True):
        tokens = [str(ci) for ci in choice]
        full_key = ":".join(tokens)
        if full_key not in all_paths:
            paths[full_key] = choice
        for depth in range(1, len(choice) + 1):
            key = ":".join(tokens[:depth])
            if key not in all_paths:
                all_paths[key] = choice[:depth]
                level_counts[depth - 1] += 1
    return list(paths.values()), level_counts, paths, all_paths
def get_medusa_topks(num_medusa_heads, paths):
    """
    Compute, per head position, one more than the largest index used there.

    Returns a list of length ``num_medusa_heads`` whose entry ``i`` is
    ``max(p[i]) + 1`` over all paths that have a position ``i``, or 0 when no
    path reaches that position.
    """
    topk_per_head = [0] * num_medusa_heads
    for path in paths:
        for head, choice in enumerate(path):
            topk_per_head[head] = max(topk_per_head[head], choice + 1)
    return topk_per_head
def get_medusa_tree(num_medusa_heads, medusa_topks, level_counts, paths):
    """
    Flatten the paths into per-node tree ids and position offsets.

    The first level contributes ``medusa_topks[0]`` nodes with ids
    ``0..medusa_topks[0] - 1`` and offset 0. For each deeper level, the paths
    are scanned in order and a new node is emitted whenever the (prefix,
    token) pair differs from the previously emitted one at that level, so
    paths are expected to arrive grouped by prefix.

    Returns:
        ``(medusa_tree_ids, medusa_position_offsets, tree_paths)``.
        ``tree_paths`` is a deep copy of ``paths`` in which every entry at
        level >= 1 is replaced by the flat index of its node
        (``cumsum(level_counts)`` offset plus rank within the level); entries
        at level 0 are left as they were. ``paths`` itself is not modified.
    """
    topk_offsets = np.cumsum([0] + medusa_topks)
    level_offsets = np.cumsum([0] + level_counts)
    tree_paths = copy.deepcopy(paths)

    first_level_width = medusa_topks[0]
    medusa_tree_ids = list(np.arange(first_level_width))
    medusa_position_offsets = [0] * first_level_width

    for level in range(1, num_medusa_heads):
        seen_prefix = "-1"
        seen_token = -1
        rank = -1  # index of the most recently emitted node within this level
        for path_idx, path in enumerate(paths):
            if level >= len(path):
                # This path ends before the current level.
                continue
            prefix_key = ":".join([str(k) for k in path[:level]])
            token = path[level]
            is_same_node = seen_prefix == prefix_key and seen_token == token
            if not is_same_node:
                # First time this (prefix, token) pair is seen: emit a node.
                medusa_position_offsets.append(level)
                medusa_tree_ids.append(token + topk_offsets[level])
                seen_prefix = prefix_key
                seen_token = token
                rank += 1
            tree_paths[path_idx][level] = level_offsets[level] + rank
    return medusa_tree_ids, medusa_position_offsets, tree_paths
def get_medusa_mask(medusa_tree_ids, medusa_paths):
    """
    Build the square 0/1 mask over tree nodes from the node paths.

    The mask has one row and column per entry of ``medusa_tree_ids``. Column 0
    is set for every row. For each path, every non-negative node in it gets a
    1 in the columns of all entries of the path up to and including itself.
    Negative entries are skipped.
    """
    num_nodes = len(medusa_tree_ids)
    mask = torch.zeros((num_nodes, num_nodes))
    mask[:, 0] = 1
    for path in medusa_paths:
        for depth, node in enumerate(path):
            if node < 0:
                continue
            # Mark every entry from the start of the path up to this node.
            for ancestor_pos in range(depth + 1):
                mask[node, path[ancestor_pos]] = 1
    return mask
def _medusa_setup(choices_or_paths, num_medusa_heads=None):
    """
    Build all Medusa tree tensors from a list of choices and move them to CUDA.

    Args:
        choices_or_paths: list of choices/paths (lists of ints). A deep copy
            is worked on.
        num_medusa_heads: number of heads; when None, the length of the
            longest choice is used.

    Returns:
        An ``argparse.Namespace`` with the CUDA tensors ``medusa_mask``,
        ``medusa_packed_mask``, ``medusa_topks``, ``medusa_paths``,
        ``medusa_tree_ids`` and ``medusa_position_offsets``.
    """
    ordered_choices = sorted(copy.deepcopy(choices_or_paths),
                             key=path_sorting_key)
    if num_medusa_heads is None:
        num_medusa_heads = max([len(choice) for choice in ordered_choices])

    leaf_paths, level_counts, _, _ = choices_2_paths(num_medusa_heads,
                                                     ordered_choices)
    leaf_paths = sorted(leaf_paths, key=path_sorting_key)
    topks = get_medusa_topks(num_medusa_heads, leaf_paths)
    tree_ids, position_offsets, tree_paths = get_medusa_tree(
        num_medusa_heads, topks, level_counts, leaf_paths)
    num_medusa_tokens = len(tree_ids)

    # Pad before converting to torch.Tensor: a leading -1 for the root slot and
    # trailing -2 up to num_medusa_heads entries. After the +1 shift below the
    # root becomes 0 and the padding becomes -1.
    padded_rows = [
        torch.tensor([-1] + path + ([-2] * (num_medusa_heads - len(path))))
        for path in tree_paths
    ]

    medusa_topks = torch.tensor(topks)
    medusa_paths = torch.stack(padded_rows) + 1
    medusa_tree_ids = torch.tensor([-1] + tree_ids) + 1
    medusa_position_offsets = torch.tensor([-1] + position_offsets) + 1
    medusa_mask = get_medusa_mask(medusa_tree_ids, medusa_paths)
    # The packed mask is built from the mask without its first row and column.
    medusa_packed_mask = get_packed_mask(num_medusa_tokens,
                                         medusa_mask[1:, 1:])

    host_tensors = {
        "medusa_mask": medusa_mask,
        "medusa_packed_mask": medusa_packed_mask,
        "medusa_topks": medusa_topks,
        "medusa_paths": medusa_paths,
        "medusa_tree_ids": medusa_tree_ids,
        "medusa_position_offsets": medusa_position_offsets,
    }
    return Namespace(
        **{name: tensor.cuda()
           for name, tensor in host_tensors.items()})